diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3dad41a88c8212..db4b1581ae671b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,16 @@ # Contributing guidelines +## Pull Request Checklist + +Before sending your pull requests, make sure you have followed this list. + +- Read [contributing guidelines](CONTRIBUTING.md). +- Read [Code of Conduct](CODE_OF_CONDUCT.md). +- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- Check that your changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- Ensure your changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + ## How to become a contributor and submit your own code ### Contributor License Agreements @@ -79,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g., Changes to TensorFlow C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). -Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do: +Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: ```bash apt-get install -y clang-tidy diff --git a/README.md b/README.md index e1a50c87e26d49..6fb4486d0de9ff 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ ----------------- -| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | -|-----------------|---------------------|------------------|-------------------|---------------|---------------| -| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +| **`Documentation`** | +|-----------------| +| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while @@ -40,15 +40,6 @@ environment to install the nightly TensorFlow build. We support CPU and GPU packages on Linux, Mac, and Windows.
-**Individual whl files** -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/)) -* Mac CPU-only: [Python 
2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) -* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) -([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) - #### *Try your first TensorFlow program* ```shell $ python @@ -82,6 +73,30 @@ The TensorFlow project strives to abide by generally accepted best practices in [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + +## Continuous build status + +### Official Builds + +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | TBA | TBA | +| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | 
[pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) | + + +### Community Supported Builds + +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | +| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | + + ## For more information * [TensorFlow Website](https://www.tensorflow.org) diff --git a/RELEASE.md b/RELEASE.md index 2717c75740aeea..27f73b7fc6a524 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -6,7 +6,7 @@ * Added Gradient Boosted Trees as pre-made Estimators: BoostedTreesClassifier, BoostedTreesRegressor. * Add 3rd generation pipeline config for Cloud TPUs which improves performance and usability. * `tf.contrib.bayesflow` is moving out to it's own repo. -* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication. +* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication[1](#rpc-issue). ## Bug Fixes and Other Changes * `tf.data`: @@ -49,13 +49,14 @@ * Fix non-uniformity of orthogonal matrices. * Fix bug where multi-image Estimator eval summaries were not displayed correctly. +1 The cancellation logic of the RPC op contains a concurrency error. A fix has been submitted to master and will be part of the next release. 
+ ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: 4d55397500, Aghasy, Alan Du, Alan Lee, Alan Yee, Alex Wiltschko, Animesh Karnewar, Ankit Gupta, Anton Matosov, Aris L, Ben Barsdell, Brent Yi, Brett Koonce, Carl Thomé, cbockman, Chikanaga Tomoyuki, Chris Tava, CéDric Deltheil, Dahan Gong, Dalmo Cirne, Daniel Erenrich, David Norman, DavidNorman, Edd Wilder-James, Fanjin Zeng, Felix Abecassis, fo40225, George Sterpu, Giovanni Terlingen, Gor Baghdasaryan, Guillaume Klein, Hanchen Li, Ilya Polenov, Jakub Kolodziejczyk, Jason Sadler, Jayaram Bobba, Jerry Liu, jinghuangintel, Jiongyan Zhang (张炯衍), Joel Shor, Jong Wook Kim, Julian Eisenschlos, Karl Lessard, Krish Ravindranath, Loo Rong Jie, Lukas Geiger, Luke Iwanski, Mahmoud Abuzaina, ManHyuk, Marvin Richter, Maximilian Mitchell, Mohammad Ashraf Bhuiyan, msofka, Mustafa Kasap, Nathan Burnham, Nathan Luehr, Naveen Marri, ngc92, nio1814, Oleg Zabluda, Ou Changkun, Panos Ipeirotis, Paul Van Eck, Peter Lee, Piotr Czapla, qjivy, Rholais Lii, Rodrigo Formigone, Russell Klopfer, ryantimjohn, Sang Han, SebastiáN RamíRez, shengfuintel, Siby Jose Plathottam, Silver Chan, Stanislaw Antol, Taehoon Lee, Tarang Chugh, Ted Chang, Thomas Bastiani, Xian Xu, Xiaoming (Jason) Cui, Yan Facai (颜发才), yaox12, Yashal Shakti Kanungo, Yong Tang, Yuan (Terry) Tang, Yuxin Wu, Ziyue(Louis) Lu - # Release 1.7.0 ## Major Features And Improvements @@ -235,7 +236,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田 * Add `complex64` support to XLA compiler. * `bfloat` support is now added to XLA infrastructure. * Make `ClusterSpec` propagation work with XLA devices. - * Use a determinisitic executor to generate XLA graph. + * Use a deterministic executor to generate XLA graph. * `tf.contrib`: * `tf.contrib.distributions`: * Add `tf.contrib.distributions.Autoregressive`. @@ -403,14 +404,6 @@ answered questions, and were part of inspiring discussions. # Release 1.4.0 -## Major Features And Improvements -* `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of - the core TensorFlow API. - * The API is now subject to backwards compatibility guarantees. - -# Release 1.4.0 - ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. * [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of diff --git a/SECURITY.md b/SECURITY.md index a5ce3a62ee202f..0a4be37cbc2066 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -168,12 +168,12 @@ below). Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of -the progress being made towards a fix and announcement. +the progress being made towards a fix and announcement. In addition, please include the following information along with your report: * Your name and affiliation (if any). -* A description the technical details of the vulnerabilities. It is very +* A description of the technical details of the vulnerabilities. It is very important to let us know how we can reproduce your findings. * An explanation who can exploit this vulnerability, and what they gain when doing so -- write an attack scenario. 
This will help us evaluate your report @@ -246,5 +246,8 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= | Type | Versions affected | Reported by | Additional Information | |--------------------|:-----------------:|-----------------------|-----------------------------| +| TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-003.md) | +| GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-002.md) | +| BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-001.md) | | Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | diff --git a/WORKSPACE b/WORKSPACE index 4ddfb9a3832ea1..fd7570a80ae2ee 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -22,26 +22,10 @@ check_bazel_version_at_least("0.10.0") load("//tensorflow:workspace.bzl", "tf_workspace") -# Uncomment and update the paths in these entries to build the Android demo. -#android_sdk_repository( -# name = "androidsdk", -# api_level = 23, -# # Ensure that you have the build_tools_version below installed in the -# # SDK manager as it updates periodically. -# build_tools_version = "26.0.1", -# # Replace with path to Android SDK on your system -# path = "", -#) -# -#android_ndk_repository( -# name="androidndk", -# path="", -# # This needs to be 14 or higher to compile TensorFlow. -# # Please specify API level to >= 21 to build for 64-bit -# # archtectures or the Android NDK will automatically select biggest -# # API level that it supports without notice. -# # Note that the NDK version is not the API level. -# api_level=14) +load("//third_party/android:android_configure.bzl", "android_configure") +android_configure(name="local_config_android") +load("@local_config_android//:android.bzl", "android_workspace") +android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. 
tf_workspace() diff --git a/build b/build index 25744ada64402a..f8c345d169ee3f 100755 --- a/build +++ b/build @@ -11,4 +11,4 @@ pip uninstall -y tensorflow || true bazel build --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package --verbose_failures && bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg && -pip install /tmp/tensorflow_pkg/tensorflow-1.8.0rc1-cp27-cp27mu-linux_x86_64.whl +pip install /tmp/tensorflow_pkg/tensorflow-1.8.0-cp27-cp27mu-linux_x86_64.whl diff --git a/build_python3 b/build_python3 index a093bdd4873fe6..b0f6f2318a0b9b 100755 --- a/build_python3 +++ b/build_python3 @@ -11,4 +11,4 @@ pip3 uninstall -y tensorflow || true bazel build --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package --verbose_failures && bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg && -pip3 install /tmp/tensorflow_pkg/tensorflow-1.8.0rc1-cp35-cp35m-linux_x86_64.whl +pip3 install /tmp/tensorflow_pkg/tensorflow-1.8.0-cp35-cp35m-linux_x86_64.whl diff --git a/configure.py b/configure.py index 4f6fc8e70bc29b..46c637843bc2dd 100644 --- a/configure.py +++ b/configure.py @@ -498,10 +498,6 @@ def set_cc_opt_flags(environ_cp): if not is_ppc64le() and not is_windows(): write_to_bazelrc('build:opt --host_copt=-march=haswell') write_to_bazelrc('build:opt --define with_default_optimizations=true') - # TODO(mikecase): Remove these default defines once we are able to get - # TF Lite targets building without them. - write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -674,8 +670,9 @@ def valid_ndk_path(path): error_msg=('The path %s or its child file "source.properties" ' 'does not exist.') ) - - write_android_ndk_workspace_rule(android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', + check_ndk_level(android_ndk_home_path)) def create_android_sdk_rule(environ_cp): @@ -737,41 +734,12 @@ def valid_build_tools(version): error_msg=('The selected SDK does not have build-tools version %s ' 'available.')) - write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level) - - -def write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level): - print('Writing android_sdk_workspace rule.\n') - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_sdk_repository( - name="androidsdk", - api_level=%s, - path="%s", - build_tools_version="%s")\n -""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) - - -def write_android_ndk_workspace_rule(android_ndk_home_path): - print('Writing android_ndk_workspace rule.') - ndk_api_level = check_ndk_level(android_ndk_home_path) - if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: - print('WARNING: The API level of the NDK in %s is %s, which is not ' - 'supported by Bazel (officially supported versions: %s). Please use ' - 'another version. 
Compiling Android targets may result in confusing ' - 'errors.\n' % (android_ndk_home_path, ndk_api_level, - _SUPPORTED_ANDROID_NDK_VERSIONS)) - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_ndk_repository( - name="androidndk", - path="%s", - api_level=%s)\n -""" % (android_ndk_home_path, ndk_api_level)) + write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', + android_build_tools_version) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', + android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', + android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -784,18 +752,16 @@ def check_ndk_level(android_ndk_home_path): revision = re.search(r'Pkg.Revision = (\d+)', filedata) if revision: - return revision.group(1) - return None - - -def workspace_has_any_android_rule(): - """Check the WORKSPACE for existing android_*_repository rules.""" - with open(_TF_WORKSPACE, 'r') as f: - workspace = f.read() - has_any_rule = re.search(r'^android_[ns]dk_repository', - workspace, - re.MULTILINE) - return has_any_rule + ndk_api_level = revision.group(1) + else: + raise Exception('Unable to parse NDK revision.') + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + return ndk_api_level def set_gcc_host_compiler_path(environ_cp): @@ -845,8 +811,8 @@ def reformat_version_sequence(version_str, sequence_count): def set_tf_cuda_version(environ_cp): """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" ask_cuda_version = ( - 'Please specify the CUDA SDK version you want to use, ' - 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION + 'Please specify the CUDA SDK version you want to use. ' + '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): # Configure the Cuda SDK version to use. @@ -1226,6 +1192,9 @@ def set_tf_cuda_compute_capabilities(environ_cp): ask_cuda_compute_capabilities, default_cuda_compute_capabilities) # Check whether all capabilities from the input is valid all_valid = True + # Remove all whitespace characters before splitting the string + # that users may insert by accident, as this will result in error + tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: @@ -1428,6 +1397,10 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_build_strip_flag(): + write_to_bazelrc('build --strip=always') + + def set_windows_build_flags(): if is_windows(): # The non-monolithic build is not supported yet @@ -1558,23 +1531,18 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) + set_build_strip_flag() set_windows_build_flags() - if workspace_has_any_android_rule(): - print('The WORKSPACE file has at least one of ["android_sdk_repository", ' - '"android_ndk_repository"] already set. Will not ask to help ' - 'configure the WORKSPACE. 
Please delete the existing rules to ' - 'activate the helper.\n') - else: - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): - create_android_ndk_rule(environ_cp) - create_android_sdk_rule(environ_cp) + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See tools/bazel.rc for ' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 79adbe318c4c64..fb5a52e0c9e44c 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -19,6 +19,10 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_additional_binary_deps", ) +load( + "//tensorflow/tools/api/generator:api_gen.bzl", + "gen_api_init_files", # @unused +) # Config setting for determining if we are building for Android. config_setting( @@ -478,7 +482,7 @@ tf_cc_shared_object( # excludes all but a subset of function names. # On MacOS, the linker does not support version_script, but has an # an "-exported_symbols_list" command. -z defs disallows undefined -# symbols in object files and -s strips the output. +# symbols in object files. tf_cc_shared_object( name = "libtensorflow.so", @@ -492,7 +496,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow/c:version_script.lds)", ], @@ -518,7 +521,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_version_script.lds)", ], @@ -543,13 +545,16 @@ exports_files( ], ) +gen_api_init_files( + name = "python_api_gen", + srcs = ["api_template.__init__.py"], + root_init_template = "api_template.__init__.py", +) + py_library( name = "tensorflow_py", - srcs = ["__init__.py"], + srcs = [":python_api_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ - "//tensorflow/python", - "//tensorflow/tools/api/generator:python_api", - ], + deps = ["//tensorflow/python"], ) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index c8683e3976c90a..440e9f8dbd2f4b 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -22,9 +22,6 @@ # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -# pylint: disable=wildcard-import -from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin -# pylint: enable=wildcard-import from tensorflow.python.util.lazy_loader import LazyLoader contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py new file mode 100644 index 00000000000000..9b0d7d48afd058 --- /dev/null +++ b/tensorflow/api_template.__init__.py @@ -0,0 +1,43 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + +from tensorflow.python.util.lazy_loader import LazyLoader +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +del absolute_import +del division +del print_function + +# These symbols appear because we import the python package which +# in turn imports from tensorflow.core and tensorflow.python. They +# must come from this module. So python adds these symbols for the +# resolution to succeed. +# pylint: disable=undefined-variable +del python +del core +# pylint: enable=undefined-variable diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 18eeb2816807ec..b86b277ac3200b 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2097,7 +2097,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, for (int i = 0; i < size; ++i) { TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names_data.push_back(std::string(id.first)); tf_results->missing_unused_key_names[i] = tf_results->missing_unused_key_names_data.back().c_str(); tf_results->missing_unused_key_indexes[i] = id.second; diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index d3916bc16778a9..95b04f9058afdf 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -8368,3 +8368,90 @@ TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( return getnext_node; #endif } + +TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id, + TF_Status* status) { + assert(session); + { + tensorflow::mutex_lock c(session->graph->mu); + VLOG(1) << "Dequeuing named tensor with id " << tensor_id + << ", with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + } + + TF_Operation* dequeue_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str()); + if (dequeue_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the dequeue node in the TF graph."); + return nullptr; + } + + VLOG(1) << "Running the dequeue op"; + TF_Output output{dequeue_op, 0}; + TF_Tensor* ret; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0, + // output related parameters + /*outputs*/ &output, /*output_values*/ 
&ret, + /*noutputs*/ 1, + /*targets*/ nullptr, /*ntargets*/ 0, + /*run_metadata*/ nullptr, status); + if (VLOG_IS_ON(1) && status->status.ok()) { + tensorflow::Tensor tensor; + if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) { + VLOG(1) << "Dequeued tensor content: " << tensor.DebugString(); + } + } + return ret; +} + +void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TF_Tensor* tensor, TF_Status* status) { + assert(session); + { + tensorflow::mutex_lock c(session->graph->mu); + if (VLOG_IS_ON(1)) { + VLOG(1) << "Enqueuing named tensor with id " << tensor_id + << ", with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + tensorflow::Tensor internal_tensor; + if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) { + VLOG(1) << "Enqueu'ing tensor content: " + << internal_tensor.DebugString(); + } + } + } + + TF_Operation* enqueue_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str()); + if (enqueue_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the enqueue node in the TF graph."); + return; + } + + TF_Operation* placeholder_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str()); + if (placeholder_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the placeholder node as input to enqueue in the TF " + "graph."); + return; + } + + VLOG(1) << "Running the enqueue op"; + TF_Output input{placeholder_op, 0}; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1, + // output related parameters + /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0, + /*targets*/ &enqueue_op, /*ntargets*/ 1, + /*run_metadata*/ nullptr, status); + VLOG(1) << "Enqueuing is done."; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 88cb173cd25f42..20bdace40f1272 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -86,6 +86,35 @@ TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( TF_Graph* graph, const char* file_path, int batch_size, unsigned char is_mnist, TF_Status* status); +// On success, dequeues a tensor from a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_dequeue_", to be executed by this API call. + +// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is +// empty, this call is blocked. +// +// Tensors are enqueued via the corresponding TF enqueue op. +// TODO(hongm): Add support for `timeout_ms`. +TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, + int tensor_id, + TF_Status* status); + +// On success, enqueues `tensor` into a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_enqueue_", to be executed by this API call. It reads +// from a placeholder node "arg_tensor_enqueue_". +// +// `tensor` is still owned by the caller. This call will be blocked if the queue +// has reached its capacity, and will be unblocked when the queued tensors again +// drop below the capacity due to dequeuing. +// +// Tensors are dequeued via the corresponding TF dequeue op. +// TODO(hongm): Add support for `timeout_ms`. 
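For context, a minimal usage sketch (not part of this patch) of the two queue calls documented above. It assumes a `TF_Session*` created elsewhere whose graph already contains the required `fifo_queue_enqueue_0`, `fifo_queue_dequeue_0`, and `arg_tensor_enqueue_0` nodes for `tensor_id` 0; the helper name is hypothetical and error handling is abbreviated:

```cpp
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"

// Sketch only: `session` and `value` are created by the caller; the graph must
// already define the enqueue/dequeue/placeholder nodes for tensor_id 0.
void RoundTripNamedTensor(TF_Session* session, TF_Tensor* value) {
  TF_Status* status = TF_NewStatus();

  // Blocks if the queue is at capacity; `value` remains owned by the caller.
  TF_EnqueueNamedTensor(session, /*tensor_id=*/0, value, status);

  // Blocks until a tensor is available; the caller must delete the result.
  TF_Tensor* out = TF_DequeueNamedTensor(session, /*tensor_id=*/0, status);
  if (TF_GetCode(status) == TF_OK && out != nullptr) {
    TF_DeleteTensor(out);
  }
  TF_DeleteStatus(status);
}
```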
+TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TF_Tensor* tensor, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 2762f31e0ccebf..581e5bd1998c73 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1368,7 +1368,7 @@ TEST(CAPI, SavedModel) { } const tensorflow::string input_op_name = - tensorflow::ParseTensorName(input_name).first.ToString(); + std::string(tensorflow::ParseTensorName(input_name).first); TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); @@ -1376,7 +1376,7 @@ TEST(CAPI, SavedModel) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); const tensorflow::string output_op_name = - tensorflow::ParseTensorName(output_name).first.ToString(); + std::string(tensorflow::ParseTensorName(output_name).first); TF_Operation* output_op = TF_GraphOperationByName(graph, output_op_name.c_str()); ASSERT_TRUE(output_op != nullptr); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index cd19cf8d624d9b..c16aba666ee697 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index b1f7bdaa5420a5..74bc25a491ac01 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() { const auto& slice_proto = entry.slices(i); CHECK(filtered_keys .insert(EncodeTensorNameSlice( - v2_reader_->key().ToString() /* full var's name */, + std::string(v2_reader_->key()) /* full var's name */, TensorSlice(slice_proto))) .second); } @@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() { new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; + if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - string key = v2_reader_->key().ToString(); + string key = std::string(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); (*var_to_data_type_map)[key] = DataType(entry.dtype()); } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index d51f0520ac39d5..f2af81f04c5e1f 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -14,6 +14,7 @@ tf_gpu_library( name = "c_api", srcs = [ "c_api.cc", + "c_api_debug.cc", "c_api_internal.h", ], hdrs = ["c_api.h"], @@ -24,10 +25,10 @@ tf_gpu_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", @@ -45,10 +46,22 @@ tf_gpu_library( "//tensorflow:with_xla_support": [ 
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_device", ], "//conditions:default": [], }) + [ "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core:gpu_runtime", ], ) @@ -59,7 +72,6 @@ tf_gpu_library( visibility = ["//tensorflow:internal"], deps = [ ":c_api", - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", @@ -69,70 +81,65 @@ tf_gpu_library( "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) -tf_gpu_cc_test( - name = "c_api_test", - srcs = ["c_api_test.cc"], - extra_copts = tfe_xla_copts(), - tags = [ - "guitar", - "multi_gpu", +tf_gpu_library( + name = "c_api_test_util", + testonly = 1, + srcs = ["c_api_test_util.cc"], + hdrs = ["c_api_test_util.h"], + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", ], deps = [ ":c_api", "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", - "//tensorflow/core:test_main", ], ) -tf_gpu_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - copts = tf_copts(), - visibility = ["//tensorflow:internal"], - deps = select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", - ], - "//conditions:default": [ - "//tensorflow/c:c_api", - "//tensorflow/core:core_cpu", - "//tensorflow/core/common_runtime/eager:kernel_and_device", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], - }), -) - -tf_cc_test( - name = "runtime_test", - srcs = 
["runtime_test.cc"], +tf_gpu_cc_test( + name = "c_api_test", + srcs = [ + "c_api_debug_test.cc", + "c_api_test.cc", + ], + extra_copts = tfe_xla_copts(), + tags = [ + "guitar", + "multi_gpu", + ], deps = [ - ":runtime", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", + ":c_api", + ":c_api_test_util", + "//tensorflow/c:c_test_util", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3bf071f3abaac7..81221c4078bec9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -32,15 +31,22 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -67,10 +73,121 @@ string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? "cpu:0" : d->name(); } -#ifdef TENSORFLOW_EAGER_USE_XLA -std::atomic_int_fast64_t func_id_generator(0); -#endif // TENSORFLOW_EAGER_USE_XLA +tensorflow::Status GetAllRemoteDevices( + const std::vector& remote_workers, + tensorflow::WorkerCacheInterface* worker_cache, + std::unique_ptr* device_mgr) { + std::vector remote_devices; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. 
+ for (const string& remote_worker : remote_workers) { + tensorflow::Notification n; + tensorflow::NewRemoteDevices( + tensorflow::Env::Default(), worker_cache, remote_worker, + [&status, &n, &remote_devices]( + const tensorflow::Status& s, + std::vector* devices) { + status = s; + if (s.ok()) { + for (tensorflow::Device* d : *devices) { + remote_devices.push_back(d); + } + } + n.Notify(); + }); + n.WaitForNotification(); + } + std::unique_ptr remote_device_mgr( + new tensorflow::DeviceMgr(remote_devices)); + + TF_RETURN_IF_ERROR(status); + + *device_mgr = std::move(remote_device_mgr); + return tensorflow::Status::OK(); +} + +tensorflow::Status CreateRemoteContexts( + const std::vector& remote_workers, + tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + tensorflow::gtl::FlatMap* remote_contexts) { + for (int i = 0; i < remote_workers.size(); i++) { + const string& remote_worker = remote_workers[i]; + + tensorflow::eager::CreateContextRequest request; + tensorflow::eager::CreateContextResponse response; + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, + &parsed_name)) { + return tensorflow::errors::InvalidArgument( + "Unable to parse ", remote_worker, " as a device name"); + } + request.mutable_server_def()->set_job_name(parsed_name.job); + request.mutable_server_def()->set_task_index(parsed_name.task); + request.set_async(async); + auto* eager_client = remote_eager_workers->GetClient(remote_worker); + if (eager_client == nullptr) { + return tensorflow::errors::Internal( + "Cannot find a client for the given target:", remote_worker); + } + tensorflow::Notification n; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. + eager_client->CreateContextAsync( + &request, &response, [&status, &n](const tensorflow::Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(status); + + remote_contexts->emplace(remote_worker, response.context_id()); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, + TFE_Context** ctx) { + string worker_name = tensorflow::strings::StrCat( + "/job:", opts->server_def.job_name(), + "/replica:0/task:", opts->server_def.task_index()); + std::unique_ptr server; + TF_RETURN_IF_ERROR( + tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); + + TF_RETURN_IF_ERROR(server->Start()); + + std::vector remote_workers; + server->master_env()->worker_cache->ListWorkers(&remote_workers); + remote_workers.erase( + std::remove(remote_workers.begin(), remote_workers.end(), worker_name), + remote_workers.end()); + + std::unique_ptr remote_device_mgr; + TF_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + + std::shared_ptr channel_cache = + server->channel_cache(); + std::unique_ptr remote_eager_workers( + tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); + // Initialize remote eager workers. 
+ tensorflow::gtl::FlatMap remote_contexts; + TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, + remote_eager_workers.get(), + opts->async, &remote_contexts)); + + tensorflow::RemoteRendezvous* r = + server->worker_env()->rendezvous_mgr->Find(0); + + auto* device_mgr = server->worker_env()->device_mgr; + *ctx = new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, r, std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts); + + return tensorflow::Status::OK(); +} } // namespace extern "C" { @@ -91,6 +208,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status) { + if (!options->server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + } +} + TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -100,17 +226,23 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { + if (!opts->server_def.job_name().empty()) { + TFE_Context* ctx = nullptr; + status->status = NewRemoteAwareTFE_Context(opts, &ctx); + return ctx; + } + std::vector devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", &devices); - if (!status->status.ok()) { - return nullptr; - } + if (!status->status.ok()) return nullptr; std::unique_ptr device_mgr( new tensorflow::DeviceMgr(devices)); + tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); + return new TFE_Context(opts->session_options.options, opts->policy, opts->async, std::move(device_mgr), r); } @@ -119,7 +251,10 @@ void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->context.device_mgr()->ListDeviceAttributes(&list->response); + ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); + if (ctx->context.remote_device_mgr()) { + ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); + } return list; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index c06ce84a8c578a..1862af3ce2f505 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,6 +81,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + // Destroy an options object. 
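As a rough usage sketch (not part of this patch), a remote-aware eager context can be requested by serializing a `tensorflow::ServerDef` and passing it through `TFE_ContextOptionsSetServerDef` before `TFE_NewContext`. The helper name, job name, and worker addresses below are hypothetical, and every server listed in the ServerDef must already be running when the context is created:

```cpp
#include <string>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

// Sketch only: builds a two-task "worker" job; task 0 is this process.
// Check TF_GetCode(status) after each call in real code.
TFE_Context* NewRemoteContext(TF_Status* status) {
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name("worker");
  server_def.set_task_index(0);
  auto* job = server_def.mutable_cluster()->add_job();
  job->set_name("worker");
  (*job->mutable_tasks())[0] = "localhost:2222";
  (*job->mutable_tasks())[1] = "localhost:2223";

  const std::string proto = server_def.SerializeAsString();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetServerDef(opts, proto.data(), proto.size(), status);
  TFE_Context* ctx = TFE_NewContext(opts, status);  // returns nullptr on error
  TFE_DeleteContextOptions(opts);
  return ctx;
}
```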
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); @@ -181,6 +191,45 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status); +// Debugging/Profiling information for TFE_TensorHandle +// +// TFE_TensorDebugInfo contains information useful for debugging and +// profiling tensors. +typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo; + +// Retrieves TFE_TensorDebugInfo for `handle`. +// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and caller +// is responsible for deleting returned TFE_TensorDebugInfo. +// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to appropriate +// error and nullptr is returned. This function can block till the operation +// that produces `handle` has completed. +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status); + +// Deletes `debug_info`. +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of dimensions used to represent the tensor on its device. +// The number of dimensions used to reprensent the tensor on device can be +// different from the number returned by TFE_TensorHandleNumDims. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of elements in dimension `dim_index`. +// Tensor representation on device can be transposed from its representation +// on host. The data contained in dimension `dim_index` on device +// can correspond to the data contained in another dimension in on-host +// representation. The dimensions are indexed using the standard TensorFlow +// major-to-minor order (slowest varying dimension first), +// not the XLA's minor-to-major order. +// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns +// the number of elements in a dimension after padding. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index); + // Description of the TensorFlow op to execute. // // Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e., diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc new file mode 100644 index 00000000000000..5006b76f1981d0 --- /dev/null +++ b/tensorflow/c/eager/c_api_debug.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#ifdef TENSORFLOW_EAGER_USE_XLA +#include "tensorflow/compiler/jit/xla_device.h" +#endif // TENSORFLOW_EAGER_USE_XLA + +using tensorflow::int64; +using tensorflow::string; + +namespace { + +std::vector TensorShapeAsVector(TFE_TensorHandle* handle, + TF_Status* status) { + std::vector shape; + int rank = TFE_TensorHandleNumDims(handle, status); + if (!status->status.ok()) { + return shape; + } + shape.reserve(rank); + for (int i = 0; i < rank; ++i) { + shape.push_back(TFE_TensorHandleDim(handle, i, status)); + if (!status->status.ok()) { + return shape; + } + } + return shape; +} + +} // namespace + +extern "C" { + +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status) { + const tensorflow::Tensor* tensor; + status->status = handle->handle->Tensor(&tensor); + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::Device* device; + status->status = handle->handle->Device(&device); + if (!status->status.ok()) { + return nullptr; + } + +#ifdef TENSORFLOW_EAGER_USE_XLA + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. + tensorflow::XlaDevice* xla_device = + dynamic_cast(device); + if (xla_device != nullptr) { + tensorflow::XlaDevice::PaddedShapeFn shape_fn = + xla_device->metadata().padded_shape_fn(); + xla::Shape padded_shape; + status->status = shape_fn(*tensor, &padded_shape); + if (!status->status.ok()) { + return nullptr; + } + if (VLOG_IS_ON(3)) { + std::vector shape_to_log = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + // Ignore the status here as we are simply logging. + status->status = tensorflow::Status::OK(); + } else { + VLOG(3) << "Fully padded shape of [" + << tensorflow::str_util::Join(shape_to_log, ", ") << "] is " + << padded_shape.DebugString(); + } + } + + if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { + // Currently, the only case of XlaTensor containing a tuple shape is to + // represent 64 bit ints, doubles, and complex numbers (we don't support + // 64bit complex numbers). + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should only contain tuples of size 2. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // shape0 is not a const& because we will assign it to padded_shape below. + // It is illegal to assign a part of a message to itself. + xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); + const xla::Shape& shape1 = + xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); + if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should not contain nested tuples. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + if (!xla::ShapeUtil::Equal(shape0, shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "Subshapes of XlaTensors should be the same. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // Since the only case we handle here are two equal subshapes, we + // simply return one of them. The caller will interpret it as this + // shape directly storing the 64bit types. This approximation is good + // enough for this API's debugging use case. 
+ padded_shape = shape0; + } + + int rank = padded_shape.dimensions_size(); + std::vector dev_dims; + dev_dims.reserve(rank); + if (rank == 1) { + // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, + dev_dims.push_back(padded_shape.dimensions(0)); + } else { + for (int i = rank - 1; i >= 0; --i) { + int64 dim_index = padded_shape.layout().minor_to_major(i); + dev_dims.push_back(padded_shape.dimensions(dim_index)); + } + } + status->status = tensorflow::Status::OK(); + return new TFE_TensorDebugInfo(dev_dims); + } +#endif // TENSORFLOW_EAGER_USE_XLA + + // If the tensor is not an XLA tensor, the device shape is + // the same as regular tensor shape. + std::vector dev_dims = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + return nullptr; + } + return new TFE_TensorDebugInfo(dev_dims); +} + +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info) { + delete debug_info; +} + +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info) { + return debug_info->dev_dims.size(); +} + +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index) { + return debug_info->dev_dims[dim_index]; +} + +} // extern "C" diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc new file mode 100644 index 00000000000000..cddb9f6e00e9d6 --- /dev/null +++ b/tensorflow/c/eager/c_api_debug_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +TEST(CApiDebug, ScalarCPU) { + TFE_TensorHandle* h = TestScalarTensorHandle(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(0, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} + +TEST(CApiDebug, 2DCPU) { + TFE_TensorHandle* h = TestMatrixTensorHandle3X2(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(2, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + // Shape is the same for CPU tensors. 
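+ // The handle was built as a 3x2 matrix, so dimension 0 should report 3
+ // elements and dimension 1 should report 2.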
+ EXPECT_EQ(3, TFE_TensorDebugInfoOnDeviceDim(debug_info, 0)); + EXPECT_EQ(2, TFE_TensorDebugInfoOnDeviceDim(debug_info, 1)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 49e1aab1cef957..04a6efc47c5177 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -28,8 +28,8 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" @@ -37,6 +37,14 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/remote_device.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -51,6 +59,7 @@ struct TFE_ContextOptions { // true if async execution is enabled. bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; + tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -64,6 +73,23 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} + explicit TFE_Context( + const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + tensorflow::DeviceMgr* local_device_mgr, + tensorflow::Rendezvous* rendezvous, + std::unique_ptr server, + std::unique_ptr remote_eager_workers, + std::unique_ptr remote_device_mgr, + const tensorflow::gtl::FlatMap& + remote_contexts) + : context(opts, + static_cast( + default_policy), + async, local_device_mgr, rendezvous, std::move(server), + std::move(remote_eager_workers), std::move(remote_device_mgr), + remote_contexts) {} + tensorflow::EagerContext context; }; @@ -81,6 +107,14 @@ struct TFE_TensorHandle { tensorflow::TensorHandle* handle; }; +struct TFE_TensorDebugInfo { + TFE_TensorDebugInfo(const std::vector& dims) + : dev_dims(dims) {} + + // Fully-padded, minor-to-major. + std::vector dev_dims; +}; + struct TFE_Op { // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a // primitive operation. diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 701175e4943d1d..27ff5f7211b059 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,8 @@ limitations under the License. 
#include "tensorflow/c/eager/c_api.h" #include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -23,128 +25,14 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::string; namespace { -TFE_TensorHandle* DoubleTestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - double data[] = {1.0, 2.0, 3.0, 4.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_TensorHandle* TestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_TensorHandle* TestMatrixTensorHandle3X2() { - int64_t dims[] = {3, 2}; - double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, a, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, b, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, "transpose_b", 0); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); - - return op; -} - -TFE_TensorHandle* TestAxisTensorHandle() { - int64_t dims[] = {1}; - int data[] = {1}; - TF_Tensor* t = TF_AllocateTensor( - TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, - TFE_TensorHandle* axis) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "Min", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << 
TF_Message(status); - TFE_OpAddInput(op, input, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, axis, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpSetAttrBool(op, "keep_dims", 1); - TFE_OpSetAttrType(op, "Tidx", TF_INT32); - TF_DeleteStatus(status); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); - - return op; -} - -// If there is a GPU device, returns true and sets 'gpu_device_name' -// accordingly. -bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - const int num_devices = TF_DeviceListCount(devices); - for (int i = 0; i < num_devices; ++i) { - const string device_type(TF_DeviceListType(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - const string device_name(TF_DeviceListName(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - if (device_type == "GPU") { - *gpu_device_name = device_name; - LOG(INFO) << "Found GPU device " << device_name; - TF_DeleteDeviceList(devices); - return true; - } - } - TF_DeleteDeviceList(devices); - return false; -} - void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); @@ -220,6 +108,103 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } +tensorflow::ServerDef GetServerDef(int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name("localhost"); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name("localhost"); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +void TestRemoteExecute(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. 
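+ // The serialized copy (task 0) is what the client context below connects
+ // with, while the in-process gRPC server started next runs as task 1 and
+ // executes the remotely placed ops.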
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* h1_task1 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h0_task1); + TFE_DeleteTensorHandle(h1_task1); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } +TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -436,7 +421,7 @@ void TensorHandleSilentCopy(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -483,7 +468,7 @@ void TensorHandleSilentCopyLocal(bool async) { // Disable the test if no GPU is present. 
string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -524,7 +509,7 @@ void SetAndGetOpDevices(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_OpSetDevice(matmul, "GPU:0", status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); const char* device_name = TFE_OpGetDevice(matmul, status); @@ -588,7 +573,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_DeleteContextOptions(opts); TFE_TensorHandle* m1 = TestMatrixTensorHandle(); - TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(); TFE_Op* matmul = MatMulOp(ctx, m1, m2); TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0", status); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc new file mode 100644 index 00000000000000..5607c9dcb0bbec --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -0,0 +1,163 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/c_api_test_util.h" + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +using tensorflow::string; + +TFE_TensorHandle* TestScalarTensorHandle() { + float data[] = {1.0f}; + TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, b, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrBool(op, "transpose_a", 0); + TFE_OpSetAttrBool(op, "transpose_b", 0); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + +TFE_TensorHandle* TestAxisTensorHandle() { + int64_t dims[] = {1}; + int data[] = {1}; + TF_Tensor* t = TF_AllocateTensor( + TF_INT32, &dims[0], sizeof(dims) / 
sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrBool(op, "keep_dims", 1); + TFE_OpSetAttrType(op, "Tidx", TF_INT32); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); + + return op; +} + +bool GetDeviceName(TFE_Context* ctx, string* device_name, + const char* device_type) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) { + const string dev_type(TF_DeviceListType(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + const string dev_name(TF_DeviceListName(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (dev_type == device_type) { + *device_name = dev_name; + LOG(INFO) << "Found " << device_type << " device " << *device_name; + TF_DeleteDeviceList(devices); + return true; + } + } + TF_DeleteDeviceList(devices); + return false; +} diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h new file mode 100644 index 00000000000000..474cae67c89249 --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ +#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ + +#include "tensorflow/c/eager/c_api.h" + +#include "tensorflow/core/platform/types.h" + +// Return a tensor handle containing a float scalar +TFE_TensorHandle* TestScalarTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle(); + +// Return a tensor handle containing a 3x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(); + +// Return a tensor handle containing a 3x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle3X2(); + +// Return a matmul op multiplying `a` by `b`. 
+TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); + +// Return an 1-D INT32 tensor containing a single value 1. +TFE_TensorHandle* TestAxisTensorHandle(); + +// Return an op taking minimum of `input` long `axis` dimension. +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis); + +// If there is a device of type `device_type`, returns true +// and sets 'device_name' accordingly. +// `device_type` must be either "GPU" or "TPU". +bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name, + const char* device_type); + +#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 97c323b8722803..734e712daa39c0 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -48,7 +48,7 @@ struct OpTapeEntry { // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. - std::function backward_function_deleter; + std::function backward_function_deleter; }; // Map from tensor_id to internally-defined operation-id of the operation which @@ -104,14 +104,12 @@ class VSpace { gtl::ArraySlice output_gradients, std::vector* result) const = 0; + // Marks the following gradient as a result so it's not consumed by backward + // functions. + virtual void MarkAsResult(Gradient* gradient) const = 0; + // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; - - // Lets this VSpace know that it can release resources held by the - // `backward_function`, It will not be called again. - // `backward_function` must not be null. - virtual void ReleaseBackwardFunction( - BackwardFunction* backward_function) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and @@ -126,19 +124,21 @@ class GradientTape { GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } - bool ShouldRecord(gtl::ArraySlice tensor_ids); + bool ShouldRecord(gtl::ArraySlice tensor_ids, + gtl::ArraySlice dtypes); void Watch(int64 tensor_id); - void RecordOperation(const string& op_type, - gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, - BackwardFunction* backward_function, - const std::function& backward_function_deleter); + void RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, + const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -170,12 +170,32 @@ class GradientTape { // Template instantiations here +inline bool IsDtypeTrainable(DataType dtype) { + switch (dtype) { + case DT_HALF: + case DT_BFLOAT16: + case DT_FLOAT: + case DT_DOUBLE: + case DT_COMPLEX64: + case DT_COMPLEX128: + case DT_RESOURCE: + case DT_VARIANT: + return true; + default: + return false; + } +} + template bool GradientTape::ShouldRecord( - gtl::ArraySlice tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; + gtl::ArraySlice tensor_ids, + gtl::ArraySlice dtypes) { + CHECK_EQ(tensor_ids.size(), dtypes.size()); + for (int i = 0; i < tensor_ids.size(); ++i) { + if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { + if (IsDtypeTrainable(dtypes[i])) { + return true; + } } } return false; @@ -189,10 +209,12 
@@ void GradientTape::Watch(int64 tensor_id) { template void GradientTape::RecordOperation( const string& op_type, gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, BackwardFunction* backward_function, - const std::function& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { - backward_function_deleter(); + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id, input_dtypes)) { + backward_function_deleter(backward_function); return; } std::vector ids; @@ -247,7 +269,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { for (int64 id : op_it->second.input_tensor_id) { DeleteTrace(id); } - op_it->second.backward_function_deleter(); + op_it->second.backward_function_deleter(op_it->second.backward_function); op_tape_.erase(op_it); } @@ -332,8 +354,7 @@ BackpropInitialState PrepareBackprop( count_it->second++; } else { result.tensor_usage_counts[it] = 1; - if (sources_set.find(it) == sources_set.end() && - tensor_tape.find(it) != tensor_tape.end()) { + if (tensor_tape.find(it) != tensor_tape.end()) { tensor_stack.push_back(it); } } @@ -354,7 +375,8 @@ BackpropInitialState PrepareBackprop( // backward functions that will be used for gradient computation // has been transferred to `result`. for (const auto& op_pair : *op_tape) { - op_pair.second.backward_function_deleter(); + op_pair.second.backward_function_deleter( + op_pair.second.backward_function); } op_tape->clear(); } @@ -380,49 +402,39 @@ Status InitialGradients(const VSpace& vspace, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, - const gtl::FlatMap& tensor_usage_counts, gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; - if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { - if (!output_gradients.empty() && output_gradients[i] != nullptr) { - // TODO(apassos) figure out how to print debugging information here. 
- return errors::InvalidArgument( - "A gradient was provided for a tensor which is used as part of the " - "computation."); - } - } else { - if (output_gradients.empty() || output_gradients[i] == nullptr) { - auto tensor_it = tensor_tape.find(id); - if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { - auto op_it = op_tape.find(tensor_it->second); - if (op_it == op_tape.end()) { - return errors::Internal( - "Internal state of the gradient tape is invalid: " - "failed to find operation producing a tensor"); - } - bool found = false; - for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { - found = true; - (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); - break; - } - } - if (!found) { - return errors::Internal( - "Internal state of the gradient tape is invalid: " - "none of operations outputs match expected tensor"); + if (output_gradients.empty() || output_gradients[i] == nullptr) { + auto tensor_it = tensor_tape.find(id); + if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { + auto op_it = op_tape.find(tensor_it->second); + if (op_it == op_tape.end()) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); + } + bool found = false; + for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { + if (op_it->second.output_tensor_info[j].id == id) { + found = true; + (*result)[id].push_back( + vspace.Ones(op_it->second.output_tensor_info[j].shape, + op_it->second.output_tensor_info[j].dtype)); + break; } - } else { - // No record of the target tensor found on the tape, so no gradient - // needs to be computed from it. Do nothing. + } + if (!found) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); } } else { - (*result)[id].push_back(output_gradients[i]); + // No record of the target tensor found on the tape, so no gradient + // needs to be computed from it. Do nothing. 
} + } else { + (*result)[id].push_back(output_gradients[i]); } } return Status::OK(); @@ -451,13 +463,12 @@ Status GradientTape::ComputeGradient( InitialStack(state.op_tape, state.op_missing_tensor); gtl::FlatMap> gradients; Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, - tensor_tape_, state.op_tape, - state.tensor_usage_counts, &gradients); + tensor_tape_, state.op_tape, &gradients); auto cleanup = [this, &state]() { if (!persistent_) { // Release all backprop functions for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } }; @@ -509,10 +520,15 @@ Status GradientTape::ComputeGradient( } } else { any_gradient_nonzero = true; - out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + auto new_gradients = vspace.AggregateGradients(grad_it->second); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); + } else { + grad_it->second.clear(); + grad_it->second.push_back(new_gradients); + vspace.MarkAsResult(new_gradients); } + out_gradients.push_back(new_gradients); } } std::vector in_gradients; @@ -520,7 +536,7 @@ Status GradientTape::ComputeGradient( Status s = vspace.CallBackwardFunction(trace.backward_function, out_gradients, &in_gradients); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } if (!s.ok()) { cleanup(); @@ -529,7 +545,7 @@ Status GradientTape::ComputeGradient( } else { in_gradients.resize(trace.input_tensor_id.size()); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } for (Gradient* grad : out_gradients) { if (grad != nullptr) { diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh index 02a6a58b6153bb..7184ad68fb79f2 100755 --- a/tensorflow/c/generate-pc.sh +++ b/tensorflow/c/generate-pc.sh @@ -15,10 +15,12 @@ # ============================================================================== TF_PREFIX='/usr/local' +LIBDIR='lib' usage() { echo "Usage: $0 OPTIONS" echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-l, --libdir\tset lib directory (default: lib)" echo -e "-v, --version\tset TensorFlow version" echo -e "-h, --help\tdisplay this message" } @@ -26,7 +28,7 @@ usage() { [ $# == 0 ] && usage && exit 0 # read the options -ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@") +ARGS=$(getopt -o p:l:v:h --long prefix:,libdir:,version:,help -n $0 -- "$@") eval set -- "$ARGS" # extract options and their arguments into variables. 
@@ -38,6 +40,11 @@ while true ; do "") shift 2 ;; *) TF_PREFIX=$2 ; shift 2 ;; esac ;; + -l|--libdir) + case "$2" in + "") shift 2 ;; + *) LIBDIR=$2 ; shift 2 ;; + esac ;; -v|--version) case "$2" in "") shift 2 ;; @@ -55,7 +62,7 @@ echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" cat << EOF > tensorflow.pc prefix=${TF_PREFIX} exec_prefix=\${prefix} -libdir=\${exec_prefix}/lib +libdir=\${exec_prefix}/${LIBDIR} includedir=\${prefix}/include Name: TensorFlow diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d73121c7b701ec..d6a4f141b6bb8c 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -440,7 +440,7 @@ string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } - return name.ToString(); + return std::string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index c143b978338815..62a889181e787f 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -220,7 +220,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( for (const string& entry : node_constraints) { StringPiece s(entry); if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { - current_constraints.insert(s.ToString()); + current_constraints.insert(std::string(s)); } } } else { diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 52c177212a8c88..35a01e0341cb08 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -38,6 +38,7 @@ REGISTER_NO_GRADIENT_OP("NotEqual"); REGISTER_NO_GRADIENT_OP("LogicalAnd"); REGISTER_NO_GRADIENT_OP("LogicalOr"); REGISTER_NO_GRADIENT_OP("LogicalNot"); +REGISTER_NO_GRADIENT_OP("Floor"); // Conjugate helper function returns the conjugate of an Output if it // is complex valued. diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1b4c7c2688083e..fd7b6fe6625f27 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -31,7 +31,6 @@ using ops::AddN; using ops::BatchMatMul; using ops::Const; using ops::Div; -using ops::Greater; using ops::MatMul; using ops::Max; using ops::Maximum; @@ -46,7 +45,6 @@ using ops::RealDiv; using ops::SquaredDifference; using ops::Sub; using ops::Sum; -using ops::Where3; // TODO(andydavis) Test gradient function against numeric gradients output. 
// TODO(andydavis) As more gradients are added move common test functions diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 0cb3132e94e381..c73482d5f4d13a 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -255,6 +255,53 @@ Status LRNGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("LRN", LRNGradHelper); +Status SoftplusGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper); + +Status SoftsignGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper); + +Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalAvgPoolGrad( + scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), + grad_inputs[0], op.output(1), op.output(2), + internal::FractionalAvgPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper); + +Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalMaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), + op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index c4eba7ecb017fe..b4d457a9d14eb7 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -28,6 +28,8 @@ namespace { using ops::BiasAdd; using ops::Conv2D; using ops::Elu; +using ops::FractionalAvgPool; +using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; @@ -41,6 +43,8 @@ using ops::Relu; using ops::Relu6; using ops::Selu; using ops::Softmax; +using ops::Softplus; +using ops::Softsign; class NNGradTest : public ::testing::Test { protected: @@ -71,22 +75,30 @@ class NNGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - // Sets tensor with random values, ensuring that the max value is largest by - // a reasonable amount. - // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which - // perturbations by the numeric gradient computation in the gradient checker - // can change the max value if values are too close together. + // Sets tensor with random values, ensuring that every pair of elements are at + // least a reasonable amount apart. 
+ // This is an issue for max pooling operations, in which perturbations by the + // numeric gradient computation in the gradient checker can change the max + // value if a pool has values that are too close together. template - void SetRandomValuesWithBumpedMax(Tensor* tensor) { + void SetRandomValuesForMaxPooling(Tensor* tensor) { auto tensor_flat = tensor->flat(); - tensor_flat.setRandom(); - int32 max_index = 0; - for (size_t i = 1; i < tensor->NumElements(); i++) { - if (tensor_flat(i) > tensor_flat(max_index)) { - max_index = i; - } + // First set the array to an increasing sequence of values spaced + // a reasonable amount apart + T cur = 0; + for (size_t i = 0; i < tensor->NumElements(); i++) { + tensor_flat(i) = cur; + cur += 5e-2; + } + // Fischer-Yates shuffle the array + for (size_t i = tensor->NumElements() - 1; i >= 1; i--) { + // j <- random integer 0 <= j <= i + size_t j = random::New64() % (i + 1); + // swap values at i, j + T tmp = tensor_flat(i); + tensor_flat(i) = tensor_flat(j); + tensor_flat(j) = tmp; } - tensor_flat(max_index) += 1e-2; } Scope scope_; @@ -189,7 +201,7 @@ TEST_F(NNGradTest, MaxPoolGradHelper) { const std::vector strides{1, 2, 2, 1}; auto y = MaxPool(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -202,7 +214,7 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { Tensor strides = test::AsTensor({1, 2, 2, 1}, {4}); auto y = MaxPoolV2(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -215,7 +227,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) { const std::vector strides{1, 3, 3, 3, 1}; auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -248,5 +260,45 @@ TEST_F(NNGradTest, LRN){ RunTest(x, x_shape, y, x_shape); } +TEST_F(NNGradTest, SoftplusGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softplus(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, SoftsignGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softsign(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, FractionalAvgPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalAvgPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalAvgPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_shape, y.output, y_shape); +} + +TEST_F(NNGradTest, FractionalMaxPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. 
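+ // Deterministic(true) with fixed Seed/Seed2 values keeps the pseudo-random
+ // pooling regions stable across runs, so the numeric gradient check compares
+ // like with like.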
+ auto y = FractionalMaxPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalMaxPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_init_value, y.output, y_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 4ddddcb5863c9f..23e9dc40d23899 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include #include #include "tensorflow/core/framework/attr_value.pb.h" @@ -71,6 +72,15 @@ void GetNodeNameToNodeDefMap( } } +// Strips off the tensor part of the tensor_name to get the node_name. +const string GetNodeNameFromTensorName(string tensor_name) { + if (tensor_name[0] == '^') { + tensor_name.erase(0, 1); + } + std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); + return tensor_name_parts[0]; +} + // Gets the set of node names needed by `outputs` and the corresponding set of // variable nodes to convert. void GetReachableNodesAndVariables( @@ -83,10 +93,8 @@ void GetReachableNodesAndVariables( new std::unordered_set({"Variable", "VariableV2", "VarHandleOp"}); std::queue nodes_to_visit; - for (const string& tensor_name : outputs) { - // We need to strip off the tensor part to get the node name. - std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); - nodes_to_visit.push(tensor_name_parts[0]); + for (const string& output_tensor_name : outputs) { + nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name)); } // We do a traversal backwards from the outputs specified in the MetaGraphDef. while (!nodes_to_visit.empty()) { @@ -100,8 +108,8 @@ void GetReachableNodesAndVariables( if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { variable_node_names->insert(node->name()); } - for (const string& input : node->input()) { - nodes_to_visit.push(input); + for (const string& input_tensor_name : node->input()) { + nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name)); } } } diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index cd35fd3b95deec..979b23c3fc5f66 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -351,6 +351,56 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) { GraphDefEqual(frozen_graph_def, graph_def); } +TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) { + // Tensors from operations with multiple outputs get tensor suffixes when used + // in input fields of following nodes, i.e. split:0, split:1. + // Test that we traverse those correctly. 
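+ // (GetNodeNameFromTensorName is expected to map an input such as "split:1"
+ // back to the node name "split" while walking the graph.)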
+ SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2}); + Output axis = ops::Const(scope.WithOpName("axis"), 0, {}); + OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output; + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), split[1], b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithControlDependency) { + // Inputs that are control dependencies get tensor prefixes, + // i.e. ^control_dependency. + // Test that we traverse those correctly. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output source = ops::Const(scope.WithOpName("source"), 10.0f, {}); + Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source), + {10.0f, 10.0f}, {2}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + TEST_F(FreezeTest, GraphDefWithoutDependentVariables) { TestFreezeGraphWithoutDependentVariables(false); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 19e6bf68e77725..2119c8ec47f941 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -214,7 +214,6 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:support", "@llvm//:target", ], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 2cae85e8965216..0025842aead539 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -333,6 +333,20 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" : ""; + const string include_hlo_profile_printer_data_proto = + opts.gen_hlo_profile_printer_data + ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")" + : ""; + + // When HLO profiling is disabled we only forward declare the + // HloProfilePrinter protobuf. So we can only conditionally emit this code + // calling HloProfilePrinter::profile_counters_size. + const string assign_profile_counters_size = + opts.gen_hlo_profile_printer_data + ? "data->profile_counters_size = " + "data->hlo_profile_printer_data->profile_counters_size();" + : ""; + // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. 
*header = @@ -348,6 +362,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, #define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) {{INCLUDE_XLA_DATA_PROTO}} +{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}} #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" @@ -418,6 +433,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); + data->hlo_profile_printer_data = StaticHloProfilePrinterData(); + {{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); return *kStaticData; @@ -487,6 +504,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; return kShape; } + + // Metadata that can be used to pretty-print profile counters. + static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() { + static const xla::HloProfilePrinterData* kHloProfilePrinterData = + {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}; + return kHloProfilePrinterData; + } }; {{NS_END}} @@ -501,35 +525,41 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, + {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, + {"{{DECLS_FROM_OBJ_FILE}}", + str_util::Join(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, + {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", + metadata_result.hlo_profile_printer_data_access_shim}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, + {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}", + include_hlo_profile_printer_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", + metadata_result.program_shape_access_shim}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, - {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}, - {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, - {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", - metadata_result.program_shape_access_shim}}; + {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}}; str_util::ReplaceAllPairs(header, rewrites); return Status::OK(); } -static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) { +static string CreateUniqueIdentifier(const CodegenOpts& opts, + StringPiece suffix) { string result = "__tfcompile"; for (const string& n : opts.namespaces) { strings::StrAppend(&result, "_", n); } - strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape"); + strings::StrAppend(&result, "_", opts.class_name, "_", suffix); return result; } @@ -550,18 +580,31 @@ Status GenerateMetadata(const CodegenOpts& opts, // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives // a shim that evaluates to nullptr, which is what we want. 
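+  // (The same convention holds for CreateEmbeddedProtocolBuffers below: a
+  // null message in a ProtobufToEmbed entry yields a "nullptr" shim for that
+  // entry.)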
+ ProtobufToEmbed program_shape_protobuf{ + CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape", + program_shape.get()}; + + ProtobufToEmbed hlo_profile_printer_data_protobuf{ + CreateUniqueIdentifier(opts, "HloProfilePrinterData"), + "xla::HloProfilePrinterData", + compile_result.aot->hlo_profile_printer_data()}; + TF_ASSIGN_OR_RETURN( - EmbeddedProtocolBuffer embedded_program_shape, - CreateEmbeddedProtocolBuffer(opts.target_triple, - CreateUniqueIdentifierForProgramShape(opts), - "xla::ProgramShape", program_shape.get())); + EmbeddedProtocolBuffers embedded_protobufs, + CreateEmbeddedProtocolBuffers( + opts.target_triple, + {program_shape_protobuf, hlo_profile_printer_data_protobuf})); metadata_result->program_shape_access_shim = - std::move(embedded_program_shape.cpp_shim_expression); + std::move(embedded_protobufs.cpp_shims[0].expression); + metadata_result->hlo_profile_printer_data_access_shim = + std::move(embedded_protobufs.cpp_shims[1].expression); + metadata_result->header_variable_decls.emplace_back( + std::move(embedded_protobufs.cpp_shims[0].variable_decl)); metadata_result->header_variable_decls.emplace_back( - std::move(embedded_program_shape.cpp_variable_decl)); + std::move(embedded_protobufs.cpp_shims[1].variable_decl)); metadata_result->object_file_data = - std::move(embedded_program_shape.object_file_data); + std::move(embedded_protobufs.object_file_data); return Status::OK(); } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 3430b1f96cf4d3..83f2d3ee11d09d 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -44,6 +44,10 @@ struct CodegenOpts { // If true, generate program shape data for the ProgramShape method. bool gen_program_shape = false; + + // If true, emit a serialized HloProfilePrinterData protobuf that can be used + // to pretty print HLO profile counters. + bool gen_hlo_profile_printer_data = false; }; // Describes a generated metadata object file. @@ -57,6 +61,12 @@ struct MetadataResult { // GenerateMetadata. string program_shape_access_shim; + // hlo_profile_printer_data_access_shim is a C++ expression that constructs + // the xla::HloProfilePrinterData instance for the CompileResult passed to + // GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a + // C++ expression that evaluates to nullptr at runtime. + string hlo_profile_printer_data_access_shim; + // The contents of the object (".o") file. 
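+  // (For tfcompile this object file now carries both embedded protobufs
+  // referenced by the access shims declared above.)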
string object_file_data; }; diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 2642536c4f67eb..29bc9c13b889c8 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -172,7 +172,7 @@ TEST(CodegenTest, Golden) { fetch->set_name("myfetch"); CompileResult compile_result; compile_result.aot.reset( - new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5)); + new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index ac3b5873318873..6641d45e83020f 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -10,6 +10,7 @@ #define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) #include "tensorflow/compiler/xla/xla_data.pb.h" + #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" @@ -23,6 +24,7 @@ extern "C" void entry_point( extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; + namespace foo { namespace bar { @@ -54,9 +56,9 @@ namespace bar { // // Memory stats: // arg bytes total: 104 -// arg bytes aligned: 128 +// arg bytes aligned: 192 // temp bytes total: 126 -// temp bytes aligned: 224 +// temp bytes aligned: 320 class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. @@ -82,6 +84,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); + data->hlo_profile_printer_data = StaticHloProfilePrinterData(); + return data; }(); return *kStaticData; @@ -243,6 +247,13 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { }(); return kShape; } + + // Metadata that can be used to pretty-print profile counters. + static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() { + static const xla::HloProfilePrinterData* kHloProfilePrinterData = + nullptr; + return kHloProfilePrinterData; + } }; } // end namespace bar diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index e17a7c4bf67321..bbc35da2ef6d14 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -44,7 +44,7 @@ namespace { // Compiles the XLA computation into executable code. Status CompileXla(xla::CompileOnlyClient* client, - const xla::Computation& computation, + const xla::XlaComputation& computation, const xla::cpu::CpuAotCompilationOptions& aot_opts, CompileResult* compile_result) { // Retrieves arg and result layouts from the computation. 
@@ -62,7 +62,7 @@ Status CompileXla(xla::CompileOnlyClient* client, for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } - xla::CompileOnlyClient::AotComputationInstance instance; + xla::CompileOnlyClient::AotXlaComputationInstance instance; instance.computation = &computation; instance.argument_layouts = std::move(arg_layouts); instance.result_layout = &pshape->result(); @@ -93,14 +93,14 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, xla::CompileOnlyClient* client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); - xla::Computation computation; + xla::XlaComputation computation; TF_RETURN_IF_ERROR( ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr module, + TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); - // Serialize the SessionModule deterministically so that all the outputs of - // a tf_library genrule are deterministic. + // Serialize the HloSnapshot deterministically so that all the outputs of a + // tf_library genrule are deterministic. string proto; TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); TF_RETURN_IF_ERROR( @@ -110,6 +110,7 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, flags.target_triple, flags.target_cpu, flags.target_features, flags.entry_point, xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic); + return CompileXla(client, computation, aot_opts, compile_result); } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 0048eec93bbe10..4e27aafec77476 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -36,9 +36,8 @@ namespace tfcompile { using xla::llvm_ir::AsStringRef; -static std::unique_ptr CreateModuleWithEmbeddedProtocolBuffer( - llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine, - const ::tensorflow::protobuf::MessageLite& proto, +static void AddEmbeddedProtocolBufferToLlvmModule( + llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto, StringPiece unique_identifier, string* protobuf_array_symbol_name, int64* protobuf_array_size) { string protobuf_array_contents = proto.SerializeAsString(); @@ -46,19 +45,14 @@ static std::unique_ptr CreateModuleWithEmbeddedProtocolBuffer( strings::StrCat(unique_identifier, "_protobuf_array_contents"); *protobuf_array_size = protobuf_array_contents.size(); - std::unique_ptr module = - MakeUnique("embedded_data_module", *llvm_context); - llvm::Constant* protobuf_array_initializer = - llvm::ConstantDataArray::getString(*llvm_context, + llvm::ConstantDataArray::getString(module->getContext(), AsStringRef(protobuf_array_contents), /*AddNull=*/false); new llvm::GlobalVariable( *module, protobuf_array_initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); - - return module; } static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, @@ -88,7 +82,8 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, llvm::legacy::PassManager codegen_passes; if (target_machine->addPassesToEmitFile( - codegen_passes, ostream, llvm::TargetMachine::CGFT_ObjectFile)) { + codegen_passes, ostream, nullptr, + llvm::TargetMachine::CGFT_ObjectFile)) { return 
xla::InternalError( "Could not create pass pipeline to generate object file"); } @@ -115,42 +110,44 @@ GetTargetMachineFromTriple(StringPiece target_triple) { /*Features=*/"", llvm::TargetOptions(), llvm::None)); } -StatusOr CreateEmbeddedProtocolBuffer( - StringPiece target_triple, StringPiece symbol_prefix, - StringPiece qualified_cpp_protobuf_name, - const ::tensorflow::protobuf::MessageLite* proto) { +StatusOr CreateEmbeddedProtocolBuffers( + StringPiece target_triple, + gtl::ArraySlice protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); llvm::LLVMContext llvm_context; - string object_file, cpp_shim, cpp_variable_decl; - - if (proto) { - string protobuf_array_symbol_name; - int64 protobuf_array_size; - - std::unique_ptr module_with_serialized_proto = - CreateModuleWithEmbeddedProtocolBuffer( - &llvm_context, target_machine.get(), *proto, symbol_prefix, - &protobuf_array_symbol_name, &protobuf_array_size); - TF_ASSIGN_OR_RETURN(object_file, - CodegenModule(target_machine.get(), - std::move(module_with_serialized_proto))); - cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name, - protobuf_array_symbol_name, - protobuf_array_size); - - cpp_variable_decl = strings::StrCat("extern \"C\" char ", - protobuf_array_symbol_name, "[];"); - } else { - TF_ASSIGN_OR_RETURN( - object_file, - CodegenModule(target_machine.get(), - MakeUnique("empty_module", llvm_context))); - cpp_shim = "nullptr"; + std::unique_ptr module_with_serialized_proto = + MakeUnique("embedded_data_module", llvm_context); + + EmbeddedProtocolBuffers result; + + for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) { + string cpp_shim, cpp_variable_decl; + if (protobuf_to_embed.message) { + string protobuf_array_symbol_name; + int64 protobuf_array_size; + + AddEmbeddedProtocolBufferToLlvmModule( + module_with_serialized_proto.get(), *protobuf_to_embed.message, + protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name, + &protobuf_array_size); + cpp_shim = CreateCPPShimExpression( + protobuf_to_embed.qualified_cpp_protobuf_name, + protobuf_array_symbol_name, protobuf_array_size); + + cpp_variable_decl = strings::StrCat("extern \"C\" char ", + protobuf_array_symbol_name, "[];"); + } else { + cpp_shim = "nullptr"; + } + result.cpp_shims.push_back({cpp_shim, cpp_variable_decl}); } - return {{cpp_shim, cpp_variable_decl, object_file}}; + TF_ASSIGN_OR_RETURN(result.object_file_data, + CodegenModule(target_machine.get(), + std::move(module_with_serialized_proto))); + return result; } } // namespace tfcompile diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 8436e0ff67f352..4e194a6aba9a9e 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -21,51 +21,70 @@ limitations under the License. #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { namespace tfcompile { using xla::StatusOr; -// Represents a protocol buffer embedded into an object file and describes a way -// to access it at runtime. -struct EmbeddedProtocolBuffer { - // cpp_shim_expression is a C++ expression that creates an instance of said - // protocol buffer when executed. 
- string cpp_shim_expression; - - // cpp_variable_decl is an "extern C" array declaration that is used in - // cpp_shim_expression. It must be visible wherever cpp_shim_expression is - // emitted. - string cpp_variable_decl; - - // The contents of the object (".o") file the protocol buffer is embbed in. - // This needs to be linked in to any program that wants to execute - // cpp_variable_decl . +// Represents a set of protocol buffers embedded into an object file and +// describes how to access them at runtime. +struct EmbeddedProtocolBuffers { + // Each instance CPPShim describes how to generate C++ code to instantiate a + // protobuf instance from the corresponding static data emitted into the + // object file. + struct CPPShim { + // `expression` is a C++ expression that creates an instance of said + // protocol buffer when executed. + string expression; + + // `variable_decl` is an "extern C" array declaration that is used in + // `expression`. It must be visible wherever `expression` is emitted. + string variable_decl; + }; + + // Each cpp_shim corresponds to one embedded protocol buffer. + std::vector cpp_shims; + + // The contents of the object (".o") file the protocol buffers are embbed in. + // This needs to be linked in to any program that wants to execute any of the + // expressions in `cpp_shims`. string object_file_data; }; -// Creates an object file that contains `proto`. -// -// `proto` is allowed to be nullptr, in which case the generated C++ shim -// expression is just `nullptr`, and the generated object file does not define -// any symbols. +// Describes a protocol buffer to embed into an object file. +struct ProtobufToEmbed { + // `symbol_prefix` is prefix that is guaranteed to be unique across the binary + // or DSO the generated object file will be linked into. + string symbol_prefix; + + // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ + // namespace qualified) protocol buffer name. This is only used in + // CPPShim::expression so relatively qualified names are fine as long as + // they're valid wherever CPPShim::expression is emitted. + string qualified_cpp_protobuf_name; + + // `message` is the protocol buffer to be embedded. It is allowed to be + // nullptr, in which case the generated C++ shim expression is just `nullptr`, + // and the generated object file does not define any symbols. + const ::tensorflow::protobuf::MessageLite* message; +}; + +// Embeds a sequence of protocol buffers into an object file. // // `target_triple` is the target triple for the target architecture for the // generated object file. // -// `symbol_prefix` is prefix that is guaranteed to be unique across the binary -// or DSO the generated object file will be linked into. -// -// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ -// namespace qualified) protocol buffer name. This needs is only used in -// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified -// names are fine as long as they're valid wherever cpp_shim_expression -// is emitted. -StatusOr CreateEmbeddedProtocolBuffer( - StringPiece target_triple, StringPiece symbol_prefix, - StringPiece qualified_cpp_protobuf_name, - const ::tensorflow::protobuf::MessageLite* proto); +// `protobufs_to_embed` describes the protocol buffers to embed into the +// resulting object file. The C++ shim for protobufs_to_embed[i] is +// cpp_shims[i] in the returned EmbeddedProtocolBuffers instance. 
The contents +// of all the protocol buffers are embedded into a single .o file whose content +// is stored in the object_file_data field in the returned +// EmbeddedProtocolBuffers instance. +StatusOr CreateEmbeddedProtocolBuffers( + StringPiece target_triple, + gtl::ArraySlice protobufs_to_embed); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h index d085864f0012e4..d1a669ceb17b9f 100644 --- a/tensorflow/compiler/aot/runtime.h +++ b/tensorflow/compiler/aot/runtime.h @@ -25,8 +25,8 @@ namespace tensorflow { namespace tfcompile { namespace runtime { -// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 32; +// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. +static constexpr size_t kAlign = 64; // aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 // values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index 6d603a02eb4cea..06ec623eb2dce5 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -24,7 +24,7 @@ namespace runtime { namespace { TEST(Runtime, AlignmentValue) { - // We've chosen 32 byte alignment for the tfcompile runtime to mimic the + // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 @@ -39,13 +39,13 @@ TEST(Runtime, AlignedBufferBytes) { EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192); + EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -101,11 +101,11 @@ TEST(Runtime, MallocFreeContiguousBuffers) { EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); - EXPECT_EQ(bufD[2], add_ptr(base, 32)); + EXPECT_EQ(bufD[2], add_ptr(base, 64)); EXPECT_EQ(bufD[3], nullptr); - EXPECT_EQ(bufD[4], add_ptr(base, 64)); - EXPECT_EQ(bufD[5], add_ptr(base, 128)); - EXPECT_EQ(bufD[6], add_ptr(base, 160)); + EXPECT_EQ(bufD[4], add_ptr(base, 128)); + EXPECT_EQ(bufD[5], add_ptr(base, 192)); + EXPECT_EQ(bufD[6], add_ptr(base, 256)); for (int i = 0; i < 7; ++i) { const intptr_t size = sizesD[i]; if (size != -1) { diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index bb73cb19c57a65..0ecc3feeb6fef1 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -7,6 +7,10 @@ package( load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# We disable some tfcompile tests in the open source build with the +# "manual" tag to avoid making our OSS users build LLVM twice +# (once for host and once for target). 
+ test_suite( name = "all_tests", tags = ["manual"], @@ -15,6 +19,7 @@ test_suite( ":test_graph_tfadd_with_ckpt_saver_test", ":test_graph_tfadd_with_ckpt_test", ":test_graph_tfassert_eq_test", + ":test_graph_tfcond_test", ":test_graph_tffunction_test", ":test_graph_tfgather_test", ":test_graph_tfmatmul_test", @@ -55,6 +60,7 @@ genrule( "test_graph_tfadd_with_ckpt_saver.pb", "test_graph_tfadd_with_ckpt_saver.saver", "test_graph_tfassert_eq.pb", + "test_graph_tfcond.pb", "test_graph_tffunction.pb", "test_graph_tfgather.pb", "test_graph_tfmatmul.pb", @@ -118,6 +124,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfcond", + testonly = 1, + config = "test_graph_tfcond.config.pbtxt", + cpp_class = "CondComp", + graph = "test_graph_tfcond.pb", + tags = [ + "manual", + ], +) + tf_library( name = "test_graph_tffunction", testonly = 1, @@ -163,6 +180,15 @@ tf_library( tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) +tf_library( + name = "test_graph_tfmatmulandadd_with_profiling", + testonly = 1, + config = "test_graph_tfmatmulandadd.config.pbtxt", + cpp_class = "MatMulAndAddCompWithProfiling", + enable_xla_hlo_profiling = True, + graph = "test_graph_tfmatmulandadd.pb", +) + tf_library( name = "test_graph_tfsplits", testonly = 1, @@ -185,13 +211,18 @@ tf_cc_test( ":test_graph_tfadd_with_ckpt", ":test_graph_tfadd_with_ckpt_saver", ":test_graph_tfassert_eq", + ":test_graph_tfcond", ":test_graph_tffunction", ":test_graph_tfgather", ":test_graph_tfmatmul", ":test_graph_tfmatmulandadd", + ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_profile_printer", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 67767f55dae9b1..9ec7df163b1425 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir): f.write(saver.as_saver_def().SerializeToString()) +def tfassert_eq(_): + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + control_flow_ops.Assert( + math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') + math_ops.add(x, math_ops.negative(y), name='x_y_diff') + + +def tfcond(_): + p = array_ops.placeholder(dtypes.bool, name='p_hold') + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + z = control_flow_ops.cond(p, lambda: x, lambda: y) + array_ops.identity(z, name='result') + + def tfgather(_): params = array_ops.placeholder(dtypes.float32, name='params') indices = array_ops.placeholder(dtypes.int32, name='indices') @@ -126,14 +142,6 @@ def tfsplits(_): array_ops.identity(y, name='result') -def tfassert_eq(_): - x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = array_ops.placeholder(dtypes.int32, name='y_hold') - control_flow_ops.Assert( - math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') - math_ops.add(x, math_ops.negative(y), name='x_y_diff') - - def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -148,12 +156,13 @@ def main(_): write_graph(tfadd, FLAGS.out_dir) write_graph(tfadd_with_ckpt, FLAGS.out_dir) 
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) + write_graph(tfassert_eq, FLAGS.out_dir) + write_graph(tfcond, FLAGS.out_dir) + write_graph(tffunction, FLAGS.out_dir) write_graph(tfgather, FLAGS.out_dir) write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) - write_graph(tffunction, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) - write_graph(tfassert_eq, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt new file mode 100644 index 00000000000000..94a01ad4abfaab --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt @@ -0,0 +1,20 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "p_hold" } + shape {} +} +feed { + id { node_name: "x_hold" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "result" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 67dbd643bfc7bf..fee46280e9a0e7 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -21,19 +21,27 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h" #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h" #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h" #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace tfcompile { namespace { +using ::testing::HasSubstr; +using ::testing::IsSupersetOf; + TEST(TFCompileTest, Add) { AddComp add; EXPECT_EQ(add.arg0_data(), add.args()[0]); @@ -143,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); } +TEST(TFCompileTest, Cond) { + CondComp cond; + EXPECT_EQ(cond.arg0_data(), cond.args()[0]); + EXPECT_EQ(cond.arg1_data(), cond.args()[1]); + EXPECT_EQ(cond.arg2_data(), cond.args()[2]); + cond.arg1() = 10; + cond.arg2() = 20; + { + cond.arg0() = true; + const int32 expected_result = cond.arg1(); + EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.result0(), expected_result); + EXPECT_EQ(cond.result0_data()[0], expected_result); + EXPECT_EQ(cond.result0_data(), cond.results()[0]); + } + { + cond.arg0() = false; + const int32 expected_result = cond.arg2(); + EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.result0(), expected_result); + EXPECT_EQ(cond.result0_data()[0], expected_result); + EXPECT_EQ(cond.result0_data(), cond.results()[0]); + } +} + TEST(TFCompileTest, Gather) { GatherComp gather; EXPECT_EQ(gather.arg0_data(), gather.args()[0]); @@ -484,6 +517,56 @@ TEST(TFCompileTest, ProgramShape) { 
EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2)); } +TEST(TFCompileTest, HloProfiling) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + MatMulAndAddCompWithProfiling fn; + ASSERT_TRUE(fn.hlo_profiling_enabled()); + + fn.set_thread_pool(&device); + + // x = [[1, 2], [3, 4]] + fn.arg0(0, 0) = 1; + fn.arg0(0, 1) = 2; + fn.arg0(1, 0) = 3; + fn.arg0(1, 1) = 4; + + // y = [[10, 20], [30, 40]] + fn.arg1(0, 0) = 10; + fn.arg1(0, 1) = 20; + fn.arg1(1, 0) = 30; + fn.arg1(1, 1) = 40; + + EXPECT_TRUE(fn.Run()); + + string hlo_profile_as_string = + xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), + /*clock_rate_ghz=*/1.0); + VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; + + std::vector hlo_profile_lines = + tensorflow::str_util::Split(hlo_profile_as_string, '\n'); + + auto header = HasSubstr("Execution profile for"); + auto total_cycles_profile_line = HasSubstr("[total]"); + auto dot_profile_line = HasSubstr( + "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%arg1.0.1)"); + auto add_profile_line = HasSubstr( + "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%arg1.0.1)"); + auto tuple_profile_line = HasSubstr( + "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " + "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); + auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + + EXPECT_THAT(hlo_profile_lines, + IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, + add_profile_line, tuple_profile_line})); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 3a877c5337ff76..5c57fee326ca74 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -25,7 +25,8 @@ def tf_library(name, graph, config, visibility=None, testonly=None, tfcompile_flags=None, tfcompile_tool="//tensorflow/compiler/aot:tfcompile", - include_standard_runtime_deps=True, deps=None, tags=None): + include_standard_runtime_deps=True, + enable_xla_hlo_profiling=False, deps=None, tags=None): """Runs tfcompile to compile a TensorFlow graph into executable code. Given an invocation of tf_library(name="foo", ...), generates the following @@ -68,6 +69,8 @@ def tf_library(name, graph, config, include_standard_runtime_deps: If True, the standard list of kernel/runtime deps is added to deps. If False, deps must contain the full set of deps needed by the generated library. + enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, + and emit metadata that lets us pretty-print the gathered profile counters. deps: a list of deps to include on the build rules for the generated library, added to the standard deps if standard_runtime_deps is True. tags: tags to apply to subsidiary build rules. 
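For reference, the hooks that `enable_xla_hlo_profiling` adds to the generated class are the ones exercised by the `HloProfiling` test above. A minimal sketch of consuming the profile outside a test, assuming a `tf_library` target built with `enable_xla_hlo_profiling = True` and `cpp_class = "MatMulAndAddCompWithProfiling"` as in the test target; the Eigen thread-pool setup mirrors what `tfcompile_test.cc` does and the Eigen include is an assumption about that test's headers:

```c++
// Sketch only, not part of the patch. Assumes the tf_library-generated header
// below exists in the build (enable_xla_hlo_profiling = True).
#define EIGEN_USE_THREADS
#include <string>

#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"  // assumed, as in tfcompile_test.cc

std::string DumpHloProfile() {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  MatMulAndAddCompWithProfiling fn;
  if (!fn.hlo_profiling_enabled()) return "";  // profiling was not compiled in
  fn.set_thread_pool(&device);

  fn.Run();  // populates fn.profile_counters()
  return xla::PrintHloProfile(fn.hlo_profile_printer_data(),
                              fn.profile_counters(),
                              /*clock_rate_ghz=*/1.0);
}
```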
@@ -137,6 +140,10 @@ def tf_library(name, graph, config, flags = tfcompile_flags else: flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) + if enable_xla_hlo_profiling: + profiling_flag = "--xla_hlo_profile" + else: + profiling_flag = "" native.genrule( name=("gen_" + name), srcs=[ @@ -157,7 +164,7 @@ def tf_library(name, graph, config, " --out_header=$(@D)/" + header_file + " --out_metadata_object=$(@D)/" + metadata_object_file + " --out_function_object=$(@D)/" + function_object_file + - " " + flags), + " " + flags + " " + profiling_flag), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -220,6 +227,8 @@ def tf_library(name, graph, config, ] + (need_xla_data_proto and [ # If we're generating the program shape, we must depend on the proto. "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (enable_xla_hlo_profiling and [ + "//tensorflow/compiler/xla/service:hlo_profile_printer_data" ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 8ea014c2eede2c..839e1588b7be6c 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -100,6 +100,8 @@ Status Main(const MainFlags& flags) { if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } + codegen_opts.gen_hlo_profile_printer_data = + xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, &codegen_opts.namespaces)); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index c8b4b05c6fb22c..0a10c97e74f320 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -23,6 +23,7 @@ package( load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_gpu_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") @@ -134,7 +135,6 @@ cc_library( srcs = ["xla_tensor.cc"], hdrs = ["xla_tensor.h"], deps = [ - ":common", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", @@ -186,11 +186,12 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", "//tensorflow/core/kernels:variable_ops", - "@com_google_absl//absl/memory", ], ) @@ -227,6 +228,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:gpu_runtime", @@ -271,6 +273,7 @@ cc_library( name = "create_xla_launch_op", srcs = [ "create_xla_launch_op.cc", + 
"create_xla_launch_op.h", ], deps = [ ":common", @@ -280,6 +283,27 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "create_xla_launch_op_test", + srcs = [ + "create_xla_launch_op.h", + "create_xla_launch_op_test.cc", + ], + deps = [ + ":create_xla_launch_op", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -299,6 +323,7 @@ cc_library( ":common", ":shape_inference_helpers", ":union_find", + ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", @@ -319,6 +344,18 @@ cc_library( ], ) +cc_library( + name = "xla_cluster_util", + srcs = ["xla_cluster_util.cc"], + hdrs = ["xla_cluster_util.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core/kernels:bounds_check", + ], +) + cc_library( name = "union_find", hdrs = ["union_find.h"], @@ -370,6 +407,63 @@ tf_cc_test( ], ) +tf_cc_test( + name = "xla_launch_util_test", + size = "small", + srcs = ["xla_launch_util_test.cc"], + deps = [ + ":common", + ":xla_compilation_cache", + ":xla_launch_util", + ":xla_tensor", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core/kernels:variable_ops", + ], +) + +cc_library( + name = "xla_fusion_optimizer", + srcs = ["xla_fusion_optimizer.cc"], + hdrs = ["xla_fusion_optimizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":common", + ":union_find", + ":xla_cluster_util", + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ], +) + +tf_gpu_cc_test( + name = "xla_fusion_optimizer_test", + srcs = ["xla_fusion_optimizer_test.cc"], + deps = [ + ":common", + ":xla_cluster_util", + ":xla_fusion_optimizer", + "//tensorflow/core:graph", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. 
cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index 9a2bb000752755..b17ff589e2597f 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -40,7 +40,7 @@ static Status BuildLaunchNode( Graph* graph, Node** node) { NodeDef def; def.set_name(graph->NewName(nodename)); - def.set_op("_XlaLaunch"); + def.set_op("XlaLaunch"); def.set_device(device_name); AddNodeAttr("Tconstants", constant_dtypes, &def); AddNodeAttr("Targs", arg_dtypes, &def); @@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { node->input_types().begin() + num_constant_args, node->input_types().begin() + num_constant_args + num_nonconst_args); - // Build a _XlaLaunch operator to execute the function body. + // Build a XlaLaunch operator to execute the function body. Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 18d901323f1085..731b8ebfdc6262 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/jit/create_xla_launch_op.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" @@ -21,82 +22,194 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { -// Givens a NodeDef 'ndef' and the function library runtime 'flr', if -// 'ndef' is a call to a compilable function defined in 'flr', returns OK -// and fills in 'kernel' with a XlaLaunchOp kernel which computes the -// node. Otherwise, returns a non-OK. +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. // -// This routine is here so that FunctionLibraryRuntime can jit a -// specific function call as requested. -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::unique_ptr* kernel) { - bool xla_compile = false; - if (!flr->GetFunctionLibraryDefinition() - ->GetAttr(ndef, kXlaCompileAttr, &xla_compile) - .ok() || - !xla_compile) { - // Not marked as _XlaCompile=true. - return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op()); +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. + // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. 
+ explicit SinglePassSearch(const std::vector* values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. + bool ScanForValue(int value) { + while (current_index_ < values_->size() && + (*values_)[current_index_] <= value) { + if ((*values_)[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; } - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - if (!IsCompilable(flr, ndef)) { - // ndef is calling a function that XLA can't compile. - return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString()); + + private: + int current_index_; + const std::vector* values_; +}; + +Status CompilationRequested(const FunctionLibraryRuntime& flr, + const NodeDef& node_def) { + bool xla_compile = false; + // Check if op is marked _XlaCompile=true. + Status status = flr.GetFunctionLibraryDefinition()->GetAttr( + node_def, kXlaCompileAttr, &xla_compile); + if (!status.ok() || !xla_compile) { + if (VLOG_IS_ON(3)) { + if (!status.ok()) { + VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " + << node_def.op() << ". status=" << status.ToString(); + } else { + VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; + } + } + return Status(error::INVALID_ARGUMENT, ""); } + return Status::OK(); +} + +// Given a FunctionLibraryRuntime and a NodeDef calling a function in the +// runtime, returns this function's body in `fbody` as well as the indices +// of its constant and resource arguments. +// `fbody` is owned by `flr`. +// `constant_arg_indices` and `resource_arg_indices` should be empty vector. +// They are sorted in ascending order on this function's return. +Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + const FunctionBody** fbody, + std::vector* constant_arg_indices, + std::vector* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; - // If ndef is not instantiable, e.g., the function does not exist, + // If node_def is not instantiable, e.g., the function does not exist, // simply bail out. TF_RETURN_IF_ERROR( - flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); - const FunctionBody* fbody = flr->GetFunctionBody(handle); - CHECK(fbody); // Can't be nullptr since we just instantiated it. - std::vector const_args(fbody->arg_types.size()); + flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); + *fbody = flr->GetFunctionBody(handle); + CHECK(*fbody); // Can't be nullptr since we just instantiated it. + const DataTypeVector& arg_types = (*fbody)->arg_types; + std::vector const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { - // There is a const arg. Bail out. - return errors::InvalidArgument("Const arg: ", i, " in ", - DebugString(fbody->fdef)); + constant_arg_indices->push_back(i); + } + } + + // There can be hundreds of resource variables. Reserve the space for them. + // We don't reserve for constants above as they are usually few. 
+ resource_arg_indices->reserve(arg_types.size()); + for (int i = 0; i < arg_types.size(); ++i) { + if (arg_types[i] == DT_RESOURCE) { + resource_arg_indices->push_back(i); } } - NodeDef launch_def; - launch_def.set_name(ndef.name()); - launch_def.set_op("_XlaLaunch"); - launch_def.set_device(flr->device()->name()); - AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def); - AddNodeAttr("Nresources", 0, &launch_def); - AddNodeAttr("Targs", fbody->arg_types, &launch_def); - AddNodeAttr("Tresults", fbody->ret_types, &launch_def); - NameAttrList func; - func.set_name(ndef.op()); - *(func.mutable_attr()) = ndef.attr(); - AddNodeAttr("function", func, &launch_def); - - // TODO(b/32387911): Handles the host memory types across function - // calls properly. For now, we assume all inputs and outputs are on - // the device memory. + return Status::OK(); +} + +} // namespace + +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel) { + TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def)); + + VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString(); + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + if (!IsCompilable(flr, node_def)) { + // node_def is calling a function that XLA can't compile. + return errors::InvalidArgument("Not compilable: ", + node_def.ShortDebugString()); + } + + // Get function body, constant args, and resource args. + const FunctionBody* fbody = nullptr; + std::vector constant_arg_indices; + std::vector resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); + + // Set input and output memory types. MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. + // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(&constant_arg_indices); + SinglePassSearch resources_search(&resource_arg_indices); + for (int i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory. + input_memory_types[i] = HOST_MEMORY; + } + } + // One might wonder, about the case where a compile-time constant argument + // (which must be in host memory) is also used as an input into an op, + // e.g. Add, that expects its inputs in device memory. Here is how it + // works now. + // First, what do we mean by "op expects an input in XYZ memory"? + // There are two types of "ops" here: the tf2xla kernel and the HLO + // computation it builds. The tf2xla kernel needs to retrieve the actual + // numeric value of the compile-time constant tensors, so it really expects + // them to be on in host memory. However, for other inputs, it refers to them + // using xla::ComputationDataHandle, which is just a symbolic handle that + // xla::ComputationBuilder assigns. How does this handle gets assigned for + // constant arguments? 
Even constant arguments get an _Arg node in the graph + // instatiated for Function compilation. The tf2xla kernel for constant _Arg + // nodes takes the constant value, converts it to XlaLiteral, and feeds it + // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This + // constant XlaLiteral is included in the HLO graph, and subsequently, in + // the actual executable, which is copied to the device before being + // executed. Thus, when this executable runs, the constant is available in + // device memory. + + // XlaLaunch kernel keeps all outputs (including constants, which it copies), + // in device memory MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + // Create the kernel. + NameAttrList function; + function.set_name(node_def.op()); + *(function.mutable_attr()) = node_def.attr(); + Device* dev = flr->device(); Status s; OpKernelConstruction construction( DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &launch_def, + dev->GetAllocator(AllocatorAttributes()), &node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - kernel->reset(new XlaLocalLaunchOp(&construction)); + + *kernel = MakeUnique(&construction, constant_arg_indices, + resource_arg_indices, function); return s; } +namespace { + bool RegisterLaunchOpCreator() { RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp); return true; diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h new file mode 100644 index 00000000000000..98a22e351532c1 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FunctionLibraryRuntime; +class OpKernel; + +// Given a NodeDef 'node_def' and the function library runtime 'flr', if +// 'node_def' is a call to a compilable function defined in 'flr', returns OK +// and fills in 'kernel' with a XlaLaunchOp kernel which computes the +// node. Otherwise, returns a non-OK. +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc new file mode 100644 index 00000000000000..b75ab486b80e09 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/create_xla_launch_op.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +NodeDef ToNodeDef(const string& text) { + NodeDef node_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); + return node_def; +} + +// Create a FunctionDef that takes one resource and one regular param +FunctionDef XTimesY() { + return FunctionDefHelper::Define( + // Name + "XTimesY", + // Args + {"x: float", "y: resource"}, + // Return values + {"z: float"}, + // Attr def + {}, + // Nodes + { + {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}}, + {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}}, + }); +} + +class CreateXlaLaunchOpTest : public ::testing::Test { + protected: + void Init(const std::vector& flib) { + SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices_)); + + FunctionDefLibrary proto; + for (const auto& fdef : flib) { + *(proto.add_function()) = fdef; + } + lib_def_ = + MakeUnique(OpRegistry::Global(), proto); + OptimizerOptions opts; + device_mgr_ = MakeUnique(devices_); + pflr_ = MakeUnique( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + } + + FunctionLibraryRuntime* flr_; + std::vector devices_; + std::unique_ptr device_mgr_; + std::unique_ptr lib_def_; + std::unique_ptr pflr_; + + std::unique_ptr kernel_; +}; + +AttrValue BoolAttr(bool b) { + AttrValue v; + v.set_b(b); + return v; +} + +TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); + Init({fdef}); + + Status status = CreateXlaLaunchOp( + flr_, ToNodeDef(R"pb( + name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' + )pb"), &kernel_); + ASSERT_TRUE(status.ok()) << status.ToString(); + + EXPECT_EQ("XTimesY", kernel_->name()); + EXPECT_EQ("XTimesY", kernel_->type_string()); + + EXPECT_EQ(2, kernel_->num_inputs()); + EXPECT_EQ(DT_FLOAT, kernel_->input_type(0)); + EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]); + EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]); + + EXPECT_EQ(1, kernel_->num_outputs()); + 
EXPECT_EQ(DT_FLOAT, kernel_->output_type(0)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) { + FunctionDef fdef = XTimesY(); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f06debaf316c01..6d1e3325ebd35b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -240,7 +240,7 @@ class Encapsulator { // Once edges between compiled and outside_compilation clusters have been // replaced by send/recv ops, some dependencies may no longer be apparent. // A clustering pass finds all the dependencies between HC nodes that are only - // present as a result of edges between nodes in outside_compilaton clusters. + // present as a result of edges between nodes in outside_compilation clusters. // Suppose there is a path from outside_compilation cluster C in subgraph S // to outside_compilation cluster D in subgraph T. If S != T then a control // edge is added from the call node for S to the call node for T, which diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 34be4409a38119..5fee36f022a751 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -80,7 +80,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr* graph_out, FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate -// subgraphs pass and that should in turn be compiled via _XlaLaunch operators. +// subgraphs pass and that should in turn be compiled via XlaLaunch operators. extern const char* const kXlaCompiledKernelAttr; // Does `node` have the kXlaCompiledKernelAttr attribute? 
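The SinglePassSearch helper added in create_xla_launch_op.cc above exists so that the sorted constant and resource index lists can be merged with the full argument list in a single O(n) walk. A standalone sketch of that cursor walk, with made-up indices (the helper itself lives in an anonymous namespace, so the sketch re-implements the pattern inline rather than calling it):

```c++
// Sketch only, not part of the patch: classify each kernel input as
// HOST_MEMORY (compile-time constants and resource handles) or DEVICE_MEMORY
// by scanning two sorted index lists once. Indices are illustrative.
#include <cstdio>
#include <vector>

int main() {
  const int num_args = 7;
  const std::vector<int> constant_arg_indices = {0, 1};  // sorted
  const std::vector<int> resource_arg_indices = {5, 6};  // sorted

  size_t ci = 0, ri = 0;  // cursors advance monotonically: O(n) overall
  for (int i = 0; i < num_args; ++i) {
    bool host = false;
    if (ci < constant_arg_indices.size() && constant_arg_indices[ci] == i) {
      host = true;
      ++ci;
    } else if (ri < resource_arg_indices.size() &&
               resource_arg_indices[ri] == i) {
      host = true;
      ++ri;
    }
    std::printf("arg %d -> %s\n", i, host ? "HOST_MEMORY" : "DEVICE_MEMORY");
  }
  return 0;
}
```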
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 5ec24d39a2c40a..eef113a3547f0b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -1050,7 +1050,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_outside", "O1")); Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, shape2.opts()); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") @@ -1075,7 +1075,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", - {"D:o:0", "F:o:0"}, + {"F:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"ancestors", @@ -1123,13 +1123,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, b2.opts()); - Node* g = Binary(e, ops::NodeOut(recv2, 1), + Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") .WithControlInputs({recv2, e}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index bc68afb322b5cf..805bbc62c1e2e8 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -354,6 +354,16 @@ bool GraphCycles::IsReachableNonConst(int32 x, int32 y) { return reachable; } +bool GraphCycles::CanContractEdge(int32 a, int32 b) { + CHECK(HasEdge(a, b)) << "No edge exists from " << a << " to " << b; + RemoveEdge(a, b); + bool reachable = IsReachableNonConst(a, b); + // Restore the graph to its original state. + InsertEdge(a, b); + // If reachable, then contracting edge will cause cycle. + return !reachable; +} + bool GraphCycles::ContractEdge(int32 a, int32 b) { CHECK(HasEdge(a, b)); RemoveEdge(a, b); @@ -388,4 +398,8 @@ std::unordered_set GraphCycles::Successors(int32 node) { return rep_->nodes_[node]->out; } +std::unordered_set GraphCycles::Predecessors(int32 node) { + return rep_->nodes_[node]->in; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h index d11d6e27b1b7bb..44448fa3d787d0 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.h +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h @@ -85,6 +85,9 @@ class GraphCycles { // and returns false. bool ContractEdge(int32 a, int32 b); + // Return true if can contract edge, otherwise return false. + bool CanContractEdge(int32 a, int32 b); + // Return whether dest_node is reachable from source_node // by following edges. 
bool IsReachable(int32 source_node, int32 dest_node) const; @@ -115,6 +118,7 @@ class GraphCycles { bool CheckInvariants() const; std::unordered_set Successors(int32 node); + std::unordered_set Predecessors(int32 node); // ---------------------------------------------------- struct Rep; diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc index e47b782207e912..274f5938a1228b 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc @@ -494,6 +494,20 @@ TEST_F(GraphCyclesTest, ContractEdge) { EXPECT_TRUE(g_.HasEdge(1, 4)); } +TEST_F(GraphCyclesTest, CanContractEdge) { + ASSERT_TRUE(AddEdge(1, 2)); + ASSERT_TRUE(AddEdge(1, 3)); + ASSERT_TRUE(AddEdge(2, 3)); + ASSERT_TRUE(AddEdge(2, 4)); + ASSERT_TRUE(AddEdge(3, 4)); + + EXPECT_FALSE(g_.CanContractEdge(1, 3)); + EXPECT_FALSE(g_.CanContractEdge(2, 4)); + EXPECT_TRUE(g_.CanContractEdge(1, 2)); + EXPECT_TRUE(g_.CanContractEdge(2, 3)); + EXPECT_TRUE(g_.CanContractEdge(3, 4)); +} + static void BM_StressTest(int iters, int num_nodes) { while (iters > 0) { tensorflow::GraphCycles g; diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 23997a67166918..8617dd41a96e41 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -39,15 +39,15 @@ limitations under the License. namespace tensorflow { -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : OpKernel(ctx), device_type_(ctx->device_type()) { - const NameAttrList* func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); - function_ = *func; - DataTypeVector constant_types; - OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); - num_constant_args_ = constant_types.size(); - OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_)); +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + device_type_(ctx->device_type()), + function_(function) { if (device_type_ == DeviceType(DEVICE_CPU)) { platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { @@ -59,8 +59,8 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) } } -Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { +Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { const XlaDevice::Metadata* metadata; Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { @@ -92,8 +92,8 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } -void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOp::Compute " +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. @@ -114,7 +114,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // this is more obviously correct.) 
core::ScopedUnref cache_ref(cache); - const XlaDevice::Metadata* metadata; + const XlaDevice::Metadata* metadata = nullptr; Status s = XlaDevice::GetMetadata(ctx, &metadata); bool allocate_xla_tensors = s.ok(); @@ -126,7 +126,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { } std::map variables = - SnapshotResourceVariables(ctx, num_resource_args_); + SnapshotResourceVariables(ctx, resources_); xla::LocalClient* client = static_cast(cache->client()); @@ -150,30 +150,32 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; - options.device_type = &cache->device_type(); + options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); options.device_allocator = xla_allocator; - // TODO(b/77671268): We don't set variable_representation_shape_fn here. This - // is restricted to Variables, but we need something like this to apply to - // normal Tensors too. + if (metadata) { + options.shape_representation_fn = metadata->shape_representation_fn(); + } const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; std::map constant_args; - for (int i = 0; i < num_constant_args_; ++i) { + for (int i : constants_) { constant_args.insert({i, ctx->input(i)}); } - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, - variables, ctx, &kernel, &executable, - /*compile_options=*/nullptr)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + OP_REQUIRES_OK( + ctx, cache->Compile(options, function_, constant_args, variables, ctx, + &kernel, &executable, &compile_options)); VLOG(1) << "Executing XLA Computation..."; - XlaComputationLaunchContext launch_context( - num_resource_args_, client, xla_allocator, allocate_xla_tensors); + XlaComputationLaunchContext launch_context(client, xla_allocator, + allocate_xla_tensors); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. @@ -196,14 +198,69 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "Done"; } +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. 
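The `OP_REQUIRES_OK_RETURN` macro above exists so that the helpers defined next can bail out of a non-void function when an attribute lookup fails. Those helpers also rely on the `XlaLaunch` calling convention of "compile-time constants, then regular args, then resources": constant inputs occupy indices `[0, num_constants)` and resource inputs occupy the trailing indices. Here is a standalone sketch of that index arithmetic; the attribute counts are made up for illustration.

```cpp
#include <numeric>
#include <vector>

int main() {
  // Hypothetical op signature: 2 Tconstants, 3 Targs, 1 Nresources.
  const int num_constants = 2, num_args = 3, num_resources = 1;

  std::vector<int> constants(num_constants);
  std::iota(constants.begin(), constants.end(), 0);  // {0, 1}

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            num_constants + num_args);               // {5}
  return 0;
}
```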
+std::vector ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} + +std::vector ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; +} + +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + XlaLocalLaunchOp::~XlaLocalLaunchOp() { VLOG(1) << "XlaLocalLaunchOp destroyed"; } -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU), - XlaLocalLaunchOp); +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") .Device(DEVICE_GPU) .HostMemory("constants") .HostMemory("resources"), diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index 8f8e646f0ff6d9..8dfc4b382d5115 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -26,6 +26,41 @@ limitations under the License. namespace tensorflow { +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +class XlaLocalLaunchBase : public OpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + // Builds a XlaCompilationCache class suitable for the current device. + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache); + + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + DeviceType device_type_; + NameAttrList function_; + se::Platform::Id platform_id_; +}; + // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // which will be compiled and executed using XLA. 
The XlaLocalLaunchOp is // responsible for handling interactions with the TensorFlow executor. @@ -35,26 +70,12 @@ namespace tensorflow { // XlaLocalLaunchOp uses xla::LocalClient::Compile() and // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device // memory. -class XlaLocalLaunchOp : public OpKernel { +class XlaLocalLaunchOp : public XlaLocalLaunchBase { public: explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); ~XlaLocalLaunchOp() override; - void Compute(OpKernelContext* ctx) override; - private: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** compiler); - - DeviceType device_type_; - NameAttrList function_; - int num_constant_args_; - // Number of resource variable arguments. - int num_resource_args_; - - se::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8e2ee0f1d71bc1..74468266b9e983 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -41,11 +42,14 @@ limitations under the License. namespace tensorflow { -const char* const kXlaClusterAttr = "_XlaCluster"; -const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; - namespace { +// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward +// a ref tensor input to its output. +static bool AlwaysForwardsRefInput(const Node& node) { + return node.IsIdentity(); +} + bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient // is really a kind of function call and will be handled by @@ -60,6 +64,26 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { return false; } } + + // XLA does not offer guaranteed aliasing between the input and output of the + // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave + // such nodes out of XLA clusters. + if (AlwaysForwardsRefInput(node)) { + for (const Edge* incoming_edge : node.in_edges()) { + if (incoming_edge->IsControlEdge()) { + continue; + } + + Node* incoming_node = incoming_edge->src(); + if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { + VLOG(2) << "Not clustering " << node.def().ShortDebugString() + << " because of ref input " << incoming_node->name() << " " + << incoming_node->type_string(); + return false; + } + } + } + return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } @@ -165,16 +189,6 @@ bool IsCompilableCall(const NodeDef& call_def, return true; } -// Returns the DeviceType corresponding to 'device'. 
-Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device, &parsed)) { - return errors::Internal("Malformed assigned device '", device, "'"); - } - *device_type = DeviceType(parsed.type); - return Status::OK(); -} - // Tests whether `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node) { return std::find(node.input_types().begin(), node.input_types().end(), @@ -183,18 +197,11 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } -struct NodeCompare { - bool operator()(const Node* a, const Node* b) const { - return a->id() < b->id(); - } -}; -using OrderedNodeSet = std::set; - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // -// TODO(hpucha): Consider a black list instead of a white list as -// implemented below. +// TODO(hpucha): Remove this code since this functionality is subsumed by +// Grappler XlaFusionOptimizer. bool IsXlaFusable(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( @@ -364,7 +371,7 @@ Status FindCompilationCandidates( for (Node* node : graph.op_nodes()) { sorted_nodes.push_back(node); } - std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare()); + std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); for (Node* node : sorted_nodes) { VLOG(2) << "Fuel: " << fuel; @@ -379,9 +386,13 @@ Status FindCompilationCandidates( DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); + DeviceToDeviceType(node->assigned_device_name(), &device_type)); - if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue; + if (is_compilable_fn && !is_compilable_fn(node, device_type)) { + VLOG(2) << "Compilation rejected node: not compilable " << node->name() + << ": " << node->type_string(); + continue; + } const XlaOpRegistry::DeviceRegistration* registration; CHECK( @@ -430,46 +441,6 @@ struct Cluster { int representative = -1; }; -// Returns a string describing how an edge from src to dst would -// create a cycle. 
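`DeviceTypeOfDevice` is not deleted outright; it moves to the new `xla_cluster_util` module under the name `DeviceToDeviceType` (see the new files further down), and the call sites in this pass are updated accordingly. A small usage sketch, assuming the new header; the helper name and device string are illustrative only.

```cpp
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/framework/types.h"

tensorflow::Status DeviceTypeExample() {
  tensorflow::DeviceType device_type("");
  tensorflow::Status s = tensorflow::DeviceToDeviceType(
      "/job:localhost/replica:0/task:0/device:GPU:0", &device_type);
  if (s.ok()) {
    // device_type.type() is now "GPU"; a malformed name yields an
    // Internal error instead.
  }
  return s;
}
```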
-string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, - int dst) { - int32 max_path_size = graph.num_node_ids() + 1; - std::vector path(max_path_size); - int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data()); - if (path_size == 0) { - return ""; - } - - auto node_name = [&cycles, &graph](int node_id) { - if (!FastBoundsCheck(node_id, graph.num_node_ids())) { - return string("(null)"); - } - auto* node = graph.FindNodeId(node_id); - if (node == nullptr) { - return string("(null)"); - } - return node->name(); - }; - - string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); - path.resize(path_size); - for (int32 node_id : path) { - string ascii_art; - if (node_id == dst) { - ascii_art = "+-> "; - } else if (node_id != src) { - ascii_art = "| "; - } else { - ascii_art = "+-- "; - } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); - } - return description; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { @@ -575,84 +546,13 @@ Status MarkForCompilationPass::RunImpl( : Env::Default(), is_compilable_fn, &compilation_candidates)); - GraphCycles cycles; - for (int i = 0; i < graph->num_node_ids(); ++i) { - // We rely on the node IDs in the cycle detection graph being consecutive - // integers starting from 0. - CHECK_EQ(i, cycles.NewNode()); + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + return Status::OK(); } - // Compute the loop structure of the graph. - std::vector control_flow_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); - - // The clustering code must avoid adding cycles to the graph to prevent - // deadlock. However, the graph may contain loops, which would trigger the - // cycle detection code. To handle loops, we alter the structure of the cycle - // detection graph, disconnecting each loop from the enclosing graph. - // Specifically, we: - // * add a new "frame" node for each loop. - // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges - // to/from the corresponding frame node. In essence, we collapse the loop - // into a single node for the purpose of cycle detection in the enclosing - // graph. - // * the body of the loop should now be disconnected from the rest of the - // graph; we make it acyclic by breaking loop backedges (edges outgoing from - // "NextIteration" nodes. - - // Map from frame name strings to node IDs in the cycle detection graph. - std::unordered_map frame_nodes; - - // Get the cycle graph node ID for frame 'frame_name', or add one if none - // exists. - auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) { - int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; - if (frame_id < 0) { - // The emplace succeeded; we have not allocated a frame node yet. - frame_id = cycles.NewNode(); - } - return frame_id; - }; - - for (Edge const* edge : graph->edges()) { - if (edge->dst()->IsEnter()) { - // Lift edges to an "Enter" node to the corresponding frame node. 
- const string& frame_name = - control_flow_info[edge->dst()->id()].frame_name; - int dst = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(edge->src()->id(), dst)) { - return errors::Internal( - "Cycle detected when adding enter->frame edge: ", - DescribeCycle(cycles, *graph, edge->src()->id(), dst)); - } - continue; - } - if (edge->src()->IsExit()) { - // Lift edges from an "Exit" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->src()->id()].frame_name; - int src = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(src, edge->dst()->id())) { - return errors::Internal( - "Cycle detected when adding frame->exit edge: ", - DescribeCycle(cycles, *graph, src, edge->dst()->id())); - } - // Drop the original edge. - continue; - } - if (edge->src()->IsNextIteration()) { - // Break loop back-edges. - continue; - } - if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) { - // This should never happen. All cycles in the graph should contain - // a control flow operator. - return errors::Internal( - "Found cycle in graph without control flow operator during XLA " - "compilation: ", - DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); - } - } + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -670,6 +570,9 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. + // + // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for + // example, from the Grappler fusion pass). while (!worklist.empty()) { int from = worklist.front()->Get().representative; worklist.pop_front(); @@ -778,7 +681,7 @@ Status MarkForCompilationPass::RunImpl( // compilation. 
DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(n->assigned_device_name(), &device_type)); + DeviceToDeviceType(n->assigned_device_name(), &device_type)); const XlaOpRegistry::DeviceRegistration* registration; XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 703d8825d74ced..772c92d369e67f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) { } } +TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output add = ops::Add(root.WithOpName("add"), neg, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + +TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output identity = ops::Negate(root.WithOpName("identity"), neg); + Output add = ops::Add(root.WithOpName("add"), identity, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, + {"identity", cluster_name}, + {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 07320b43dab790..f2473d98ffd5da 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -17,7 +17,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_XlaLaunch") +REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") .Input("args: Targs") @@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch") .Attr("Tresults: list(type) >= 0") .Attr("function: func") // XLA random-number generation ops are stateful. - // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch. + // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch. .SetIsStateful() .Doc("XLA Launch Op. 
For use by the XLA JIT only."); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc new file mode 100644 index 00000000000000..70bd10336b824b --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -0,0 +1,161 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include + +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; + +namespace { +// Returns a string describing how an edge from src to dst would +// create a cycle. +string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, + int dst) { + int32 max_path_size = graph.num_node_ids() + 1; + std::vector path(max_path_size); + int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data()); + if (path_size == 0) { + return ""; + } + + auto node_name = [cycles, &graph](int node_id) { + if (!FastBoundsCheck(node_id, graph.num_node_ids())) { + return string("(null)"); + } + auto* node = graph.FindNodeId(node_id); + if (node == nullptr) { + return string("(null)"); + } + return node->name(); + }; + + string description; + strings::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); + path.resize(path_size); + for (int32 node_id : path) { + string ascii_art; + if (node_id == dst) { + ascii_art = "+-> "; + } else if (node_id != src) { + ascii_art = "| "; + } else { + ascii_art = "+-- "; + } + strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + } + return description; +} +} // namespace + +Status DeviceToDeviceType(const string& device, DeviceType* device_type) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device, &parsed)) { + return errors::Internal("Malformed assigned device '", device, "'"); + } + *device_type = DeviceType(parsed.type); + return Status::OK(); +} + +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { + for (int i = 0; i < graph->num_node_ids(); ++i) { + // We rely on the node IDs in the cycle detection graph being consecutive + // integers starting from 0. + CHECK_EQ(i, cycles->NewNode()); + } + + // Compute the loop structure of the graph. + std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); + + // The clustering code must avoid adding cycles to the graph to prevent + // deadlock. However, the graph may contain loops, which would trigger the + // cycle detection code. To handle loops, we alter the structure of the cycle + // detection graph, disconnecting each loop from the enclosing graph. 
+ // Specifically, we: + // * add a new "frame" node for each loop. + // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges + // to/from the corresponding frame node. In essence, we collapse the loop + // into a single node for the purpose of cycle detection in the enclosing + // graph. + // * the body of the loop should now be disconnected from the rest of the + // graph; we make it acyclic by breaking loop backedges (edges outgoing from + // "NextIteration" nodes. + + // Map from frame name strings to node IDs in the cycle detection graph. + std::unordered_map frame_nodes; + + // Get the cycle graph node ID for frame 'frame_name', or add one if none + // exists. + auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) { + int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; + if (frame_id < 0) { + // The emplace succeeded; we have not allocated a frame node yet. + frame_id = cycles->NewNode(); + } + return frame_id; + }; + + for (Edge const* edge : graph->edges()) { + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + int dst = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(edge->src()->id(), dst)) { + return errors::Internal( + "Cycle detected when adding enter->frame edge: ", + DescribeCycle(cycles, *graph, edge->src()->id(), dst)); + } + continue; + } + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + int src = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(src, edge->dst()->id())) { + return errors::Internal( + "Cycle detected when adding frame->exit edge: ", + DescribeCycle(cycles, *graph, src, edge->dst()->id())); + } + // Drop the original edge. + continue; + } + if (edge->src()->IsNextIteration()) { + // Break loop back-edges. + continue; + } + if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) { + // This should never happen. All cycles in the graph should contain + // a control flow operator. + return errors::Internal( + "Found cycle in graph without control flow operator during XLA " + "compilation: ", + DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h new file mode 100644 index 00000000000000..5b673bdc27fccb --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for clustering compilable graph nodes via XLA. 
+ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + +using OrderedNodeSet = std::set; + +// Returns the DeviceType corresponding to 'device'. +Status DeviceToDeviceType(const string& device, DeviceType* device_type); + +// Creates a graph representation to enable cycle detection when clustering. +// This representation handles loops in graph by disconnecting each loop from +// the enclosing graph. +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 6430975335f5ee..7ed609c4374806 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -122,8 +122,7 @@ Status XlaCompilationCache::BuildSignature( namespace { -// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. +// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. Status BuildArguments(const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 60458f6f3314b2..b1943d3e1a7e32 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -48,13 +48,12 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable) { std::map variables = GetVariables(ctx); - int64 num_resource_args = variables.size(); xla::LocalClient* client = metadata.client(); // Builds an XLA allocator for the device. 
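The new `xla_cluster_util.h` above centralizes the cycle-detection plumbing that `mark_for_compilation_pass.cc` previously carried inline. The following sketch shows how a clustering pass might drive it together with the new `GraphCycles::CanContractEdge`; it is illustrative only and is not the actual `MarkForCompilationPass` logic.

```cpp
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Illustrative sketch of a clustering pass built on the new utilities.
Status SketchClusteringPass(const Graph* graph) {
  GraphCycles cycles;
  Status status = CreateCycleDetectionGraph(graph, &cycles);
  if (!status.ok()) return status;

  // Node IDs in `cycles` are the same consecutive IDs used by `graph`,
  // with each loop collapsed into a synthetic frame node.
  for (const Edge* edge : graph->edges()) {
    const int src = edge->src()->id();
    const int dst = edge->dst()->id();
    if (cycles.HasEdge(src, dst) && cycles.CanContractEdge(src, dst)) {
      // src and dst could share a cluster without creating a cycle; a real
      // pass would also check device placement and then call ContractEdge.
    }
  }
  return Status::OK();
}

}  // namespace tensorflow
```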
XlaComputationLaunchContext launch_context( - num_resource_args, client, client->backend().memory_allocator(), true); + client, client->backend().memory_allocator(), true); launch_context.PopulateInputs(ctx, result, variables); @@ -152,16 +151,18 @@ Status XlaCompileOnDemandOp::Compile( core::ScopedUnref cache_ref(cache); XlaCompiler::Options options; - DeviceType device_type = metadata.jit_device_type(); - options.device_type = &device_type; + options.device_type = metadata.jit_device_type(); options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + options.shape_representation_fn = metadata.shape_representation_fn(); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, - /*compile_options=*/nullptr); + result, executable, &compile_options); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 23c6f3903f841a..7cc3d0e007ba29 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -29,11 +29,8 @@ limitations under the License. namespace tensorflow { // An OpKernel that compiles an op to an XLA computation and runs it. Unlike -// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a // vanilla TensorFlow op as long as the bridge supports it. -// -// Importantly _XlaLaunch assumes all input and output tensors are on the host, -// whereas XlacompileOnDemandOp works with tensors in device memory. class XlaCompileOnDemandOp : public OpKernel { public: explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index bc07dbd7bdf005..43648402f65c65 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -53,7 +53,9 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device)); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index c814b7eb029054..ed007d603ea1b3 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" @@ -49,6 +49,7 @@ limitations under the License. 
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { @@ -105,12 +106,33 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( return alloc_ptr; } +namespace { + +// Default PaddedShapeFn implementation that simply returns the unpadded +// on-device shape. This is accurate for CPU and GPU devices that neither +// transpose nor pad tensors. +Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { + const tensorflow::XlaTensor* xla_tensor = + tensorflow::XlaTensor::FromTensor(&tensor); + if (xla_tensor == nullptr) { + return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape); + } + + const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); + *shape = shaped_buffer.on_device_shape(); + return Status::OK(); +} + +} // namespace + /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, std::unique_ptr* device) { + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; @@ -129,17 +151,22 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice(options, attrs, device_ordinal, - DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal)); + device->reset(new XlaDevice( + options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, shape_representation_fn, + padded_shape_fn ? 
padded_shape_fn : DefaultPaddedShapeFn)); return Status::OK(); } -XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type) +XlaDevice::Metadata::Metadata( + int device_ordinal, se::Platform* platform, const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn) : device_ordinal_(device_ordinal), device_type_(device_type), - platform_(platform) {} + platform_(platform), + shape_representation_fn_(std::move(shape_representation_fn)), + padded_shape_fn_(std::move(padded_shape_fn)) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -170,17 +197,21 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } -XlaDevice::XlaDevice(const SessionOptions& options, - const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform, - bool transfer_as_literal) +XlaDevice::XlaDevice( + const SessionOptions& options, const DeviceAttributes& attrs, + int device_ordinal, const DeviceType& jit_device_name, + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn) : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name), + xla_metadata_(device_ordinal, platform, jit_device_name, + shape_representation_fn, padded_shape_fn), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), platform_(platform), - transfer_as_literal_(transfer_as_literal) { + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(shape_representation_fn) { VLOG(1) << "Created XLA device " << jit_device_name; } @@ -230,10 +261,10 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() { GetAllocator({}); // XlaDevice owns both gpu_device_info_ and // gpu_device_info_->default_context. - gpu_device_info_ = absl::make_unique(); + gpu_device_info_ = MakeUnique(); gpu_device_info_->stream = stream; - gpu_device_info_->default_context = - new XlaDeviceContext(stream, client(), transfer_as_literal_); + gpu_device_info_->default_context = new XlaDeviceContext( + stream, client(), transfer_as_literal_, shape_representation_fn_); set_tensorflow_gpu_device_info(gpu_device_info_.get()); } @@ -247,7 +278,8 @@ Status XlaDevice::FillContextMap(const Graph* graph, TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); // Call GetAllocator for the side-effect of ensuring the allocator is created. GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, + shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -260,11 +292,10 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - // When TraceMe profiling is off (which is the default), the - // following TraceMe constructor is simply a conditional test of - // false value. Measurements show that its overhead is negligible. 
- port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); + // When Xprof profiling is off (which is the default), constructing the + // activity is simple enough that its overhead is negligible. + tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->Compute(context); } @@ -272,8 +303,8 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); + tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->ComputeAsync(context, done); } @@ -295,7 +326,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_); + XlaTransferManager manager(stream, client(), transfer_as_literal_, + shape_representation_fn_); manager.CopyCPUTensorToDevice(&parsed, this, ©, [&n, &status](const Status& s) { status = s; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 3ae87308cc7cff..02e88ee6793e98 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -17,8 +17,7 @@ limitations under the License. // runtime. // // Operators assigned to an XlaDevice are compiled into XLA computations. -// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state -// is managed by XLA. +// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers. // // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), // under different names (e.g., XLA_CPU or XLA_GPU). @@ -27,6 +26,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -45,12 +45,19 @@ namespace tensorflow { class XlaDevice : public LocalDevice { public: + // Given a tensor, sets `xla::Shape*` the shape of tensor's representation + // on device, fully padded. On error, the contents of `xla::Shape*` + // are undefined. + typedef std::function PaddedShapeFn; + // Wrapper class to store metadata about the XlaDevice, where it can be // retrieved e.g., when lazily creating the XlaCompilationCache device. class Metadata { public: Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type); + const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn); // The index of the device on this host. 
int device_ordinal() const; @@ -58,11 +65,17 @@ class XlaDevice : public LocalDevice { se::Platform* platform() const; xla::LocalClient* client() const; const DeviceType& jit_device_type() const; + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { + return shape_representation_fn_; + } + const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; } private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + PaddedShapeFn padded_shape_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -76,16 +89,25 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. - static Status Create(const string& platform_name, const string& device_name, - int device_ordinal, const string& jit_device_name, - const SessionOptions& options, const string& name_prefix, - const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, - std::unique_ptr* device); - + // If padded_shape_fn is empty, a default implementation that returns + // the on-host shape is used. + static Status Create( + const string& platform_name, const string& device_name, + int device_ordinal, const string& jit_device_name, + const SessionOptions& options, const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device); + + // Creates a new XLA Device. + // If padded_shape_fn is empty, a default implementation that returns + // the logical on-device shape without padding is used. XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal); + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -102,6 +124,7 @@ class XlaDevice : public LocalDevice { Tensor* tensor) override; xla::LocalClient* client() const; + const Metadata& metadata() { return xla_metadata_; } xla::StatusOr GetStream(); // If not already set, create and set GpuDeviceInfo. @@ -116,8 +139,8 @@ class XlaDevice : public LocalDevice { // The name of the device that is used to compile Ops for this XlaDevice. DeviceType jit_device_name_; // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_; // Not owned. + se::Platform* platform_; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and @@ -126,6 +149,7 @@ class XlaDevice : public LocalDevice { // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // If set, holds default device context (that we must Unref) // and its stream. 
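`XlaDevice` now threads two callbacks through to its `Metadata` and device contexts: `shape_representation_fn`, which maps a tensor's logical shape to the shape XLA should use on device, and `padded_shape_fn`, which reports the fully padded on-device shape of an existing tensor. Both are optional; the CPU factory above passes empty functions, so the defaults (identity shape, unpadded on-device shape) apply, and the `XlaTransferManager` changes below consult the representation function when allocating device buffers. A hedged sketch of what a backend-specific `shape_representation_fn` could look like follows; the rounding rule is invented purely for illustration and is not a real TensorFlow or XLA policy.

```cpp
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

// Hypothetical example: a backend that wants the innermost dimension rounded
// up to a multiple of 8. Shows only the shape-in/shape-out contract of the
// callback; the policy itself is made up.
tensorflow::TensorShape ExampleShapeRepresentation(
    const tensorflow::TensorShape& shape, tensorflow::DataType dtype) {
  tensorflow::TensorShape result = shape;
  if (result.dims() > 0) {
    const tensorflow::int64 last = result.dim_size(result.dims() - 1);
    result.set_dim(result.dims() - 1, (last + 7) / 8 * 8);
  }
  return result;
}
```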
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index bf8c1886a02231..71e63b110b3b13 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -47,22 +47,33 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream, - xla::LocalClient* client, - bool transfer_as_literal) +XlaTransferManager::XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) : stream_(stream), client_(client), transfer_manager_(client->backend().transfer_manager()), - transfer_as_literal_(transfer_as_literal) {} + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(std::move(shape_representation_fn)) { + if (!shape_representation_fn_) { + shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) { + return shape; + }; + } +} Status XlaTransferManager::TransferLiteralToDevice( const Tensor& host_tensor, Tensor* device_tensor) const { - xla::Literal literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal)); - VLOG(1) << "Transfer to device as literal: " << literal.ToString(); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + xla::BorrowingLiteral literal( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(device_tensor)->shaped_buffer(); + VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + << shaped_buffer.ToString(); return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, shaped_buffer); } @@ -75,8 +86,17 @@ Status XlaTransferManager::TransferLiteralFromDevice( TF_ASSIGN_OR_RETURN(std::unique_ptr literal, transfer_manager_->TransferLiteralFromDevice( stream_->parent(), shaped_buffer)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString(); - return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); + VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " " + << shaped_buffer.ToString(); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. 
+ if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + return errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return Status::OK(); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -89,16 +109,21 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, << " " << reinterpret_cast( device_tensor->tensor_data().data()) - << " " << cpu_tensor->NumElements(); + << " " << cpu_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); const int64 total_bytes = cpu_tensor->TotalBytes(); XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); + + TensorShape shape = shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype()); if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer( - device_tensor->dtype(), device_tensor->shape(), client_, + device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal()); if (!s.ok()) { done(s); @@ -106,12 +131,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } } - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; if (transfer_as_literal_) { - status = TransferLiteralToDevice(*cpu_tensor, device_tensor); + Tensor reshaped_cpu_tensor; + if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { + done(errors::Internal( + "Tensor::CopyFrom failed when copying from CPU to XLA device")); + return; + } + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); } else { + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. Status block_status = stream_->BlockHostUntilDone(); @@ -142,7 +173,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_tensor->tensor_data().data()) << " " << reinterpret_cast(cpu_tensor->tensor_data().data()) - << device_tensor->NumElements(); + << " " << device_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); const int64 total_bytes = cpu_tensor->TotalBytes(); se::DeviceMemoryBase dev_src_ptr = @@ -171,9 +204,47 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal) - : manager_(stream, client, transfer_as_literal) {} +void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + // TODO(phawkins): replace this code with an asynchronous implementation. 
+ auto body = [&]() { + if (src_tensor.NumElements() == 0) { + return Status::OK(); + } + XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor); + XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor); + CHECK(xla_src && xla_dst) + << "Missing destination tensor for device-to-device copy"; + if (!xla_dst->has_shaped_buffer()) { + TensorShape shape = + shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()); + TF_RETURN_IF_ERROR( + xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, + stream_->parent()->device_ordinal())); + } + TF_RETURN_IF_ERROR( + xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const se::DeviceMemoryBase& from_buffer = + xla_src->shaped_buffer().buffers().element(index); + CHECK_EQ(buffer->size(), from_buffer.size()); + if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer, + buffer->size())) { + return errors::Internal("Device to device memcpy failed"); + } + return Status::OK(); + })); + return Status::OK(); + }; + done(body()); +} + +XlaDeviceContext::XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) + : manager_(stream, client, transfer_as_literal, + std::move(shape_representation_fn)) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -190,4 +261,10 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, done); } +void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index d7f5f1d2089892..ee346e5653bbf9 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -45,14 +46,19 @@ class XlaDeviceAllocator : public Allocator { // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); + + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + se::Stream* stream() const { return stream_; } private: @@ -69,7 +75,8 @@ class XlaTransferManager { // Transfer manager, for marshalling data to and from the device. xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. 
- bool transfer_as_literal_; + const bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -77,8 +84,9 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, @@ -86,6 +94,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + se::Stream* stream() const override { return manager_.stream(); } private: diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index f68dba6b6a26c0..5ecb1afa7bcec9 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_ops.h" +#include + #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_tensor.h" namespace tensorflow { @@ -26,4 +29,82 @@ void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) { << type_string() << " on an XLA device. This should never happen."; } +XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c) + : AsyncOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); +} + +void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context, + DoneCallback done) { + OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(), + errors::InvalidArgument( + "Variable and value dtypes don't match; respectively, ", + dtype_, " and ", context->input(1).dtype()), + done); + Var* variable = nullptr; + OP_REQUIRES_OK_ASYNC( + context, + LookupOrCreateResource( + context, HandleFromInput(context, 0), &variable, + [this, context](Var** ptr) { + *ptr = new Var(dtype_); + PersistentTensor unused; + Tensor* tmp; + AllocatorAttributes attr; + TF_RETURN_IF_ERROR(context->allocate_persistent( + dtype_, context->input(1).shape(), &unused, &tmp, attr)); + *(*ptr)->tensor() = *tmp; + return Status::OK(); + }), + done); + core::ScopedUnref s(variable); + + OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_, + errors::InvalidArgument( + "Trying to assign variable with wrong dtype. Expected ", + DataTypeString(variable->tensor()->dtype()), " got ", + DataTypeString(dtype_)), + done); + + const Tensor& value = context->input(1); + AllocatorAttributes attr; + + // Copying is unnecessary if we are the last user of the value tensor, we can + // just adopt the input tensor's buffer instead. + std::unique_ptr input_alias = context->forward_input( + 1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_, + value.shape(), DEVICE_MEMORY, attr); + mutex_lock ml(*variable->mu()); + variable->is_initialized = true; + if (input_alias) { + *variable->tensor() = *input_alias; + done(); + return; + } + + // Need to copy, but maybe we can re-use variable's buffer? 
+  if (!XlaTensor::RefCountIsOne(*variable->tensor()) ||
+      !variable->tensor()->shape().IsSameSize(value.shape())) {
+    // Copy to new buffer
+    PersistentTensor unused;
+    Tensor* tmp;
+    OP_REQUIRES_OK_ASYNC(context,
+                         context->allocate_persistent(dtype_, value.shape(),
+                                                      &unused, &tmp, attr),
+                         done);
+    *variable->tensor() = *tmp;
+  }
+
+  XlaDeviceContext* device_context =
+      static_cast<XlaDeviceContext*>(context->op_device_context());
+
+  variable->Ref();
+  device_context->CopyDeviceTensorToDevice(
+      value, variable->tensor(), [context, variable, done](Status status) {
+        variable->Unref();
+        context->SetStatus(status);
+        done();
+      });
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 498d25cf566a91..0c49286acd3aba 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -23,8 +23,10 @@ limitations under the License.
 #include "tensorflow/core/kernels/cast_op.h"
 #include "tensorflow/core/kernels/constant_op.h"
 #include "tensorflow/core/kernels/control_flow_ops.h"
+#include "tensorflow/core/kernels/identity_n_op.h"
 #include "tensorflow/core/kernels/identity_op.h"
 #include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/core/kernels/resource_variable_ops.h"
 #include "tensorflow/core/kernels/sendrecv_ops.h"
 #include "tensorflow/core/kernels/variable_ops.h"
 
@@ -32,7 +34,7 @@ namespace tensorflow {
 
 // Dummy OpKernel, used for kernels assigned to an XLA device that should be
 // compiled. Should never be called at runtime since such ops should be
-// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
+// rewritten to a XlaLaunch op. If it is called, it means the placer placed an
 // operator on an XLA device but the compiler did not compile it.
class XlaDeviceDummyOp : public OpKernel { public: @@ -40,8 +42,17 @@ class XlaDeviceDummyOp : public OpKernel { void Compute(OpKernelContext* ctx) override; }; +class XlaAssignVariableOp : public AsyncOpKernel { + public: + explicit XlaAssignVariableOp(OpKernelConstruction* c); + void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + + private: + DataType dtype_; +}; + #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \ + REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \ .Device(DEVICE) \ .HostMemory("constants") \ .HostMemory("resources"), \ @@ -63,13 +74,37 @@ class XlaDeviceDummyOp : public OpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \ + IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ \ REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ - ResourceHandleOp); + ResourceHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ + ReadVariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \ + XlaAssignVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ + ControlTriggerOp); \ + REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ + SwitchOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + LoopCondOp); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc new file mode 100644 index 00000000000000..96016521ea9022 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -0,0 +1,321 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +namespace tensorflow { + +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? +static bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "ShapeN" || + node.type_string() == "Rank" || node.type_string() == "Size"; +} + +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", + "StopGradient", /*"Snapshot",*/ + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", "StridedSliceGrad", 
"ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + +Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) { + VLOG(2) << "Here at fusion optimizer"; + + // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op. + // Once that happens, the expected interaction between this optimizer and when + // the global_jit_level is set is as follows: Fusion optimizer will replace + // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can + // be further compiled where possible via mark_for_compilation_pass. Note that + // this might lead to inefficient clustering, and it is best to use either the + // fusion optimizer or the global_jit flag, and not combine the two. + + // Create a Graph out of GraphDef. This is required currently because the + // helpers around clustering, encapsulation etc work on graphs. + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + Graph graph(function_library); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); + shape_refiner.set_require_shape_inference_fns(false); + shape_refiner.set_disable_constant_propagation(true); + ImportGraphDefOptions options; + // Graph optimization happens at the late stage of graph execution, when + // colocation constraints are already validated previously and the device + // placement of nodes has also completed, so there is no need to validate + // colocation constraints again. + options.validate_colocation_constraints = false; + options.validate_shape = false; + TF_RETURN_IF_ERROR( + ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + + // Collect nodes that can be fused via XLA, while ignoring those that + // explicitly ask for XLA: (*) nodes that are marked to be compiled + // explicitly. (*) nodes assigned to XLA device. + OrderedNodeSet compilation_candidates; + for (Node* node : graph.op_nodes()) { + // If there is a _XlaCompile annotation, ignore the node if it is + // true. Nodes are marked with this attr via experimental_jit_scope, and + // will be handled by the mark_for_compilation pass. + bool compile = false; + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + if (status.ok() && compile) { + continue; + } + // If there is already a _XlaCluster annotation, ignore the node. Nodes are + // marked with this attr to indicate they are already part of a cluster and + // hence ignored. + status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile); + if (status.ok()) { + continue; + } + + // If there is an explicit XLA device placement, ignore the node. + DeviceType device_type(""); + TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); + if (device_type.type_string().find("XLA") != string::npos) continue; + + // Assume all fusable ops are registered. + // TODO(hpucha): Check for registration if possible. + if (!IsXlaFusable(node->def())) { + continue; + } + + compilation_candidates.insert(node); + } + + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + *output = item.graph; + return Status::OK(); + } + + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + + // TODO(hpucha): Make clustering more robust. 
There are two known issues that + // we need to mitigate: (a) Non-resource variables can cause deadlocks + // when clustering changes order of execution. See b/77263461 for a specific + // example. (b) Queue operations can also cause deadlocks. See b/77261498 for + // example. + + struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph. + int representative = -1; + }; + + // Each compilation candidate belongs to a cluster. The cluster's + // representative names the node in the 'cycles' graph that represents the + // cluster. + std::vector> clusters(graph.num_node_ids()); + std::deque*> worklist; + for (Node* node : compilation_candidates) { + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); + worklist.push_back(&clusters[node->id()]); + } + + // Repeatedly contract edges between clusters that are on the same device, + // provided the contraction would not create a cycle. This is a simplified + // version of the clustering in mark_for_compilation_pass that also deals with + // nodes that are explicitly tagged to be compiled/clustered. + while (!worklist.empty()) { + int from = worklist.front()->Get().representative; + worklist.pop_front(); + + Node* node_from = graph.FindNodeId(from); + if (node_from->IsControlFlow()) { + // Control flow nodes aren't compilation candidates and should never + // appear. + return errors::Internal( + "Found control flow node in clustering worklist: ", + node_from->type_string()); + } + for (int to : cycles.Successors(from)) { + if (to >= graph.num_node_ids()) { + // Node is a "frame" node that is present only in the cycle detection + // graph. No clustering is possible. + continue; + } + Node* node_to = graph.FindNodeId(to); + if (compilation_candidates.find(node_to) == + compilation_candidates.cend()) { + continue; + } + + // Do not cluster across devices. + if (node_from->def().device() != node_to->def().device()) { + VLOG(2) << "Devices " << node_from->def().device() << " " + << node_to->def().device(); + VLOG(2) << "Device names " << node_from->assigned_device_name() << " " + << node_to->assigned_device_name(); + continue; + } + + // Ops that consume shapes cannot be the root of a cluster. This is an + // optimization. + if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { + continue; + } + + // If contracting the edge would create a cycle, bail out. + // However, just because we can't merge the clusters now does not mean + // we won't be able to merge them in the future. + // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge + // 1->3. But if we first contract 1->2 then we can later contract 1->3. + if (!cycles.ContractEdge(from, to)) continue; + + // Merge the clusters. ContractEdge uses 'from' as the number of the + // merged node, so make sure 'from' is the chosen representative. + clusters[from].Merge(&clusters[to]); + + worklist.push_back(&clusters[from]); + break; + } + } + + // Count the number of non-trivial elements in each cluster. + std::vector effective_cluster_sizes(graph.num_node_ids()); + for (const Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + // Identity nodes will be removed if the node gets marked for compilation. + // Therefore we don't want to count them towards the effective cluster size. 
+ if (n->def().op() != "Identity") { + effective_cluster_sizes[cluster]++; + } + } + + const int min_cluster_size = 2; + int num_clusters = 0; + for (auto size : effective_cluster_sizes) { + if (size >= min_cluster_size) { + VLOG(3) << "Cluster " << num_clusters << " " << size; + num_clusters++; + } + } + + // Names for each cluster. + std::unordered_map cluster_names; + // Sequence number generator to ensure clusters have unique names. + static std::atomic cluster_sequence_num; + + for (Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + + // Compile if this is a cluster of >= min_cluster_size compilable operators. + if (effective_cluster_sizes[cluster] >= min_cluster_size) { + string& name = cluster_names[cluster]; + + if (name.empty()) { + name = strings::StrCat("cluster_", cluster_sequence_num++); + } + n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; + } + } + + graph.ToGraphDef(output); + return Status::OK(); +} + +REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h new file mode 100644 index 00000000000000..3d2309e782d387 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { + +// Optimizes graphs by fusing ops where possible, resulting in more efficient +// execution. +class XlaFusionOptimizer : public grappler::CustomGraphOptimizer { + public: + XlaFusionOptimizer() {} + ~XlaFusionOptimizer() override {} + + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override { + return Status::OK(); + } + + string name() const override { return "xla-fusion"; }; + + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) override; + + void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item, + const GraphDef& optimize_output, double result) override { + // Nothing to do for XlaFusionOptimizer. + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc new file mode 100644 index 00000000000000..5736760a878dc8 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("UncompilableNullary").Output("o: float"); +REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); + +class XlaFusionOptimizerTest : public grappler::GrapplerTest { + protected: + std::unordered_map GetClusters(const GraphDef& graph) { + std::unordered_map ids; + for (const NodeDef& node : graph.node()) { + string cluster; + if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) { + CHECK(!cluster.empty()); + ids[node.name()] = cluster; + } + } + return ids; + } +}; + +TEST_F(XlaFusionOptimizerTest, Chains) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + Node* d = + ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + ops::UnaryOp("Relu", e, builder.opts().WithName("F")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(4, clusters.size()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_EQ(clusters["E"], clusters["F"]); + EXPECT_NE(clusters["B"], clusters["E"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, FusableOps) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(2, clusters.size()); + EXPECT_EQ(clusters["C"], clusters["E"]); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) { + GraphDef graph; + { + GraphDefBuilder 
builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp( + "Add", a, b, + builder.opts().WithName("C").WithDevice("/device:XLA_CPU")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + ops::UnaryOp("Cos", e, + builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true)); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, UncompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = + ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, CompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index ad4d44ce26224e..65aa5c35ede81f 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -50,7 +50,9 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, //XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. 
VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 9e098c46f422b4..661187f4a873b0 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -51,7 +51,9 @@ Status XlaInterpreterDeviceFactory::CreateDevices( TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device)); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 2a7f04271d4b7e..d0c7a936512570 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -38,14 +38,13 @@ using xla::ScopedShapedBuffer; using xla::ShapedBuffer; } // anonymous namespace -std::map SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables) { +std::map SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector& variables) { std::map snapshot; - int first_variable = ctx->num_inputs() - num_variables; - for (int i = 0; i < num_variables; ++i) { + for (int i : variables) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, first_variable + i); - OptionalTensor& tensor = snapshot[first_variable + i]; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); @@ -61,32 +60,35 @@ XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped) XlaAllocator::~XlaAllocator() {} -xla::StatusOr XlaAllocator::Allocate( +xla::StatusOr XlaAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { - void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size); + AllocationAttributes attrs; + attrs.no_retry_on_failure = !retry_on_failure; + void* data = + wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); if (data == nullptr) { return errors::ResourceExhausted("Out of memory while trying to allocate ", size, " bytes."); - } else { - return se::DeviceMemoryBase(data, size); } + return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size), + device_ordinal, this); } -Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) { - wrapped_->DeallocateRaw(mem->opaque()); +Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { + wrapped_->DeallocateRaw(mem.opaque()); return Status::OK(); } -namespace { +namespace internal { // Return the 'index''th subtree of the given ShapedBuffer as a // ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the // subtree, and sets the input's buffer pointers to nullptr for the subtree. 
ScopedShapedBuffer ExtractSubShapedBuffer( ShapedBuffer* shaped_buffer, int index, xla::DeviceMemoryAllocator* allocator) { - xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape( + const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape( shaped_buffer->on_host_shape(), index); - xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape( + const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape( shaped_buffer->on_device_shape(), index); ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, @@ -98,20 +100,23 @@ ScopedShapedBuffer ExtractSubShapedBuffer( sub_shape_tree.CopySubtreeFrom(shape_tree, /*source_base_index=*/{index}, /*target_base_index=*/{}); - for (auto& index_to_buffer : shape_tree) { - if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) { - index_to_buffer.second = se::DeviceMemoryBase(nullptr, 0); - } - } + shape_tree.ForEachMutableElement( + [index](const xla::ShapeIndex& shape_index, + tensorflow::se::DeviceMemoryBase* data) { + // shape_index is empty for the root node. Ignore that. + if (!shape_index.empty() && shape_index[0] == index) { + *data = tensorflow::se::DeviceMemoryBase(nullptr, 0); + } + }); return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator); } -} // namespace +} // namespace internal +using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( - int64 num_resource_args, xla::LocalClient* client, - xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors) - : num_resource_args_(num_resource_args), - client_(client), + xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors) + : client_(client), xla_allocator_(xla_allocator), allocate_xla_tensors_(allocate_xla_tensors) {} @@ -190,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs( OP_REQUIRES_OK( ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { - OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer( - const_tensor.dtype(), const_tensor.shape(), - client_, stream->parent()->device_ordinal())); - } Device* device = dynamic_cast(ctx->device()); OP_REQUIRES(ctx, device != nullptr, @@ -236,7 +236,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); ctx->set_output(i, output_tensor); } ++output_num; @@ -286,7 +286,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer, allocator); - output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); *variable->tensor() = output_tensor; } ++output_num; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 8a6ff3b0c75120..4390701ccbd0bc 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -22,6 +22,8 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -31,15 +33,17 @@ limitations under the License. namespace tensorflow { class XlaAllocator; -// Takes a snapshot of the values of resource variable arguments, which are -// the last `num_variables` arguments. We snapshot tensors that back +// Takes a snapshot of the values of resource variable arguments, whose +// indices are specified in `variables` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is // important that the shapes used for compilation match the true shapes of the // buffers. // -// Returns a map of TensorFlow argument index to resource variable. -std::map SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables); +// Returns a map of TensorFlow argument index to resource variable. If a +// resource variable is not initialized, the corresponding OptionalTensor +// will have its `present` field set to false. +std::map SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector& variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -48,9 +52,9 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { public: XlaAllocator(const se::Platform* platform, Allocator* wrapped); ~XlaAllocator() override; - xla::StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - Status Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) override; + xla::StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // The Tensorflow BFC allocator used on GPU allows host-side deallocation // before GPU execution takes place. Tensorflow uses the ordering of the main @@ -72,7 +76,7 @@ class XlaComputationLaunchContext { // Create a new launch context. 'allocate_xla_tensors' is true if allocated // output tensors and variables are always XlaTensors. If false they are // assumed to be "normal" device pointers. - XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client, + XlaComputationLaunchContext(xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors); @@ -92,7 +96,6 @@ class XlaComputationLaunchContext { const std::vector& arguments() const { return arg_ptrs_; } private: - int64 num_resource_args_; xla::LocalClient* client_; xla::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; @@ -140,6 +143,17 @@ class XlaTensorBuffer : public TensorBuffer { Allocator* allocator_; }; +// Exposed in this header file for microbenchmarking purposes, but this is an +// internal implementation detail. +namespace internal { +// Return the 'index''th subtree of the given ShapedBuffer as a +// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the +// subtree, and sets the input's buffer pointers to nullptr for the subtree. 
+xla::ScopedShapedBuffer ExtractSubShapedBuffer( + xla::ShapedBuffer* shaped_buffer, int index, + xla::DeviceMemoryAllocator* allocator); +} // namespace internal + } // namespace tensorflow #endif diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc new file mode 100644 index 00000000000000..a45932403ec176 --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains microbenchmarks for performance critical functions in +// xla_launch_util.cc. + +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs +// (cardinality of each non-leaf node's children). +void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = xla::ShapeUtil::MakeTupleShape(shapes); + } + xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr, + /*device_ordinal=*/0); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + // Extract a buffer from approximately the middle of the first level of the + // tree. + (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, + /*index=*/fan_out / 2, + /*allocator=*/nullptr) + .release(); + } +} + +BENCHMARK(BM_ExtractSubBuffer) + ->ArgPair(1, 4) + ->ArgPair(1, 8) + ->ArgPair(1, 32) + ->ArgPair(1, 64) + ->ArgPair(1, 128) + ->ArgPair(1, 256) + ->ArgPair(1, 512) + ->ArgPair(2, 4) + ->ArgPair(2, 8) + ->ArgPair(2, 32) + ->ArgPair(2, 64) + ->ArgPair(2, 128); + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index ce6456880bc1b3..3c44c4ae6df7f3 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -18,7 +18,7 @@ limitations under the License. 
namespace tensorflow { -/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) { +/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { if (tensor->NumElements() == 0) { return nullptr; } @@ -27,8 +27,8 @@ namespace tensorflow { return xla_tensor; } -/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { - return FromTensor(const_cast(tensor)); +/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) { + return tensor.RefCountIsOne(); } /*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor( @@ -52,20 +52,24 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, client->backend().transfer_manager()->HostShapeToDeviceShape( on_host_shape); - xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(), - device_ordinal); - for (auto& index_to_buffer : buffer.buffers()) { + xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, + client->backend().memory_allocator(), + device_ordinal); + for (auto& index_to_buffer : shaped_buffer.buffers()) { xla::Shape subshape = xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = client->backend().transfer_manager()->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(index_to_buffer.second, + TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer, client->backend().memory_allocator()->Allocate( device_ordinal, size, /*retry_on_failure=*/false)); + // Move our buffer into shaped_buffer, which takes ownership of it. + index_to_buffer.second = buffer.Forget(); } - set_shaped_buffer(xla::ScopedShapedBuffer( - std::move(buffer), client->backend().memory_allocator())); + VLOG(4) << shaped_buffer.ToString(); + + set_shaped_buffer(std::move(shaped_buffer)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 922a9189731209..c54001a999998f 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -34,10 +34,9 @@ class XlaTensor { public: // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast // fails. - static XlaTensor* FromTensor(Tensor* tensor); - // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast - // fails. - static const XlaTensor* FromTensor(const Tensor* tensor); + static XlaTensor* FromTensor(const Tensor* tensor); + + static bool RefCountIsOne(const Tensor& tensor); // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in // which case the returned value is shaped_buffer()->root_buffer(), or a @@ -54,7 +53,7 @@ class XlaTensor { // Some Tensors can have complex on-device shapes, including tuple shapes. To // manage the memory for these tensors a ShapedBuffer may be required. - // Return true if this TensorInfo contains a ShapedBuffer. + // Return true if this XlaTensor contains a ShapedBuffer. bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } // Return the contained ShapedBuffer. // REQUIRES: has_shaped_buffer() @@ -62,7 +61,11 @@ class XlaTensor { CHECK(has_shaped_buffer()); return *shaped_buffer_; } - // Mutates the TensorInfo to set the ShapedBuffer. + xla::ShapedBuffer& shaped_buffer() { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } + // Mutates the XlaTensor to set the ShapedBuffer. 
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = xla::MakeUnique(std::move(shaped_buffer)); @@ -72,7 +75,7 @@ class XlaTensor { // in on-demand mode to avoid re-copying values from the device if we know the // host value already. - // Return true if this TensorInfo contains a host tensor. + // Return true if this XlaTensor contains a host tensor. bool has_host_tensor() const { return host_tensor_ != nullptr; } // Return the contained host tensor. // REQUIRES: has_host_tensor() diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index f7bad39af082d2..20baa745af6d8e 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -42,7 +42,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform", "//tensorflow/python:random_seed", "//tensorflow/python:session", @@ -58,7 +58,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -72,7 +72,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -93,7 +93,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -111,7 +111,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:bitwise_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -120,6 +120,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "bucketize_op_test", + size = "small", + srcs = ["bucketize_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "categorical_op_test", size = "small", @@ -127,7 +140,7 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -141,7 +154,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -156,7 +169,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -170,7 +183,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", 
"//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -184,7 +197,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -196,9 +209,11 @@ tf_xla_py_test( name = "oom_test", size = "medium", srcs = ["oom_test.py"], + # TODO(b/80081500): Re-enable on GPU. Disabled on 2018-05-21. disabled_backends = [ "cpu", "cpu_ondemand", + "gpu", ], tags = [ # Allocates very large amounts of memory and does not work under TSAN. @@ -209,7 +224,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -225,7 +240,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -241,7 +256,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -263,7 +278,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -291,7 +306,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -300,10 +315,14 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], + tags = [ + "manual", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -322,8 +341,12 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", ], ) @@ -338,7 +361,7 @@ tf_xla_py_test( "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", ], @@ -352,7 +375,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -364,7 +387,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + 
"//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -380,7 +403,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -395,12 +418,27 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", ], ) +tf_xla_py_test( + name = "listdiff_op_test", + size = "small", + srcs = ["listdiff_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform_test", + "@six_archive//:six", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -408,7 +446,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -423,7 +461,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -435,7 +473,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -449,7 +487,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -462,7 +500,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -475,7 +513,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -490,7 +528,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -507,7 +545,7 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -522,7 +560,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -538,7 +576,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - 
"//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -551,7 +589,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", ], ) @@ -563,7 +601,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -575,7 +613,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -590,7 +628,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -603,7 +641,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:platform_test", @@ -618,7 +656,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -634,7 +672,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -647,7 +685,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/contrib/stateless", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -661,7 +699,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -680,7 +718,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -693,7 +731,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -707,7 +745,7 @@ tf_xla_py_test( srcs = ["fused_batchnorm_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn", @@ -726,7 +764,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -745,7 +783,7 @@ 
tf_xla_py_test( ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], @@ -755,11 +793,12 @@ tf_xla_py_test( name = "gather_test", size = "medium", srcs = ["gather_test.py"], + tags = ["noasan"], # times out, http://b/78599043 deps = [ ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -771,7 +810,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -784,21 +823,34 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) -gpu_py_test( +tf_xla_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +gpu_py_test( + name = "xla_device_gpu_test", + size = "small", + srcs = ["xla_device_gpu_test.py"], additional_deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", ], ) @@ -815,15 +867,23 @@ gpu_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", ], - # TODO(b/62961789): Test fails with SIGABRT - tags = [ - "manual", - "notap", +) + +gpu_py_test( + name = "dense_layer_test", + size = "small", + srcs = ["dense_layer_test.py"], + additional_deps = [ + "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:layers", + "//tensorflow/python:variables", ], ) @@ -866,7 +926,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:variables", @@ -881,7 +941,7 @@ gpu_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -919,7 +979,7 @@ tf_xla_py_test( srcs = ["fake_quant_ops_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -931,7 +991,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", 
], ) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index ec547e16cd9c91..9d3a889b1f54c8 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -29,51 +29,70 @@ class ArgMinMaxTest(xla_test.XLATestCase): - def _assertOpOutputMatchesExpected(self, op, inp, expected): - """Verifies that 'op' produces 'expected' when fed input 'inp' . + def _assertOpOutputMatchesExpected(self, op, axis, output_type, op_input, + expected): + """Verifies that 'op' produces 'expected' when fed input 'op_input'. Args: - op: operator to test - inp: numpy input array to use as input to 'op'. + op: argmin or argmax operator to test. + axis: integer axis to reduce across. + output_type: numpy datatype of the output to produce. + op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. """ with self.test_session() as session: with self.test_scope(): pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name="a") - output = op(pinp) - result = session.run(output, {pinp: inp}) + dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") + output = op(pinp, axis=axis, output_type=output_type) + result = session.run(output, {pinp: op_input}) self.assertAllEqual(result, expected) def testArgMinMax(self): # Complex numbers do not support argmin/argmax. minmax_types = set(self.numeric_types) - set(self.complex_types) for dtype in minmax_types: - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([1, 10, 27, 3, 3, 4], dtype=dtype), - expected=np.int32(2)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([0, 1, 0], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([0, 0], dtype=np.int32)) + # output_type is a numpy data type that is used to specify the desired + # output type of the op as well as to convert the Python number to the + # array scalar of the type.
+ for output_type in self.int_types: + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([1, 10, 27, 3, 3, 4], dtype=dtype), + expected=output_type(2)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([0, 1, 0], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([0, 0], dtype=output_type)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([3, 10, 27, 3, 2, 4], dtype=dtype), - expected=np.int32(4)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([1, 0, 1], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([1, 1], dtype=np.int32)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([3, 10, 27, 3, 2, 4], dtype=dtype), + expected=output_type(4)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([1, 0, 1], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([1, 1], dtype=output_type)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py new file mode 100644 index 00000000000000..fde9759a1c2098 --- /dev/null +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -0,0 +1,78 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for bucketize_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BucketizationOpTest(XLATestCase): + + def testInt(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual(expected_out, + sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) + + def testFloat(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual( + expected_out, + sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) + + def test2DInput(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]] + self.assertAllEqual( + expected_out, sess.run(op, + {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) + + def testInvalidBoundariesOrder(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Expected sorted boundaries"): + sess.run(op, {p: [-5, 0]}) + + def testBoundariesNotList(self): + with self.test_session(): + with self.assertRaisesRegexp(TypeError, "Expected list.*"): + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + math_ops._bucketize(p, boundaries=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py new file mode 100644 index 00000000000000..865f60ccab46ec --- /dev/null +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -0,0 +1,135 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for DenseLayer JIT compilation on the CPU and GPU devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.compiler import jit +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +jit_scope = jit.experimental_jit_scope + + +def GetRunMetadataLabels(run_metadata): + """Returns all labels in run_metadata.""" + labels = [] + for dev_stats in run_metadata.step_stats.dev_stats: + for node_stats in dev_stats.node_stats: + labels.append(node_stats.timeline_label) + return labels + + +def InLabels(labels, substr): + """Returns true iff one of the labels contains substr.""" + return any([substr in x for x in labels]) + + +def XlaLaunchOpCount(labels): + """Count how many XlaLaunch labels are present.""" + return sum("XlaLaunch(" in x for x in labels) + + +class DenseLayerTest(test.TestCase): + + def testDenseLayerAutoJit(self): + """Tests dense layer compilation in auto-jit mode. + + Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. + """ + + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = ( + config_pb2.OptimizerOptions.ON_1) + + with self.test_session(config=config) as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + def testDenseLayerJitScopeDefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer with static shape input tensor should be compiled into a single + XlaLaunch op by XLA. + """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + # No need to check whether ListDiff is compiled or not because ListDiff op + # is not used when input tensor shape is fully defined. + + def testDenseLayerJitScopeUndefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer uses shape op to get shape of input tensor if its shape is not + fully defined. XLA does not cluster shape op with other operators. But in + experimental_jit_scope, XLA is forced to compile shape op into its own + cluster, causing dense layer to be split into TWO XlaLaunch ops. 
+ """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(2, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 0a0d335ca76dd7..03d96a2cd8ab22 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -153,7 +153,7 @@ def _VerifyValues(self, dtype=data_type).reshape(filter_in_sizes) with self.test_session() as sess: if data_type == np.float32: - tolerance = 1e-5 + tolerance = 1e-4 else: self.assertEqual(data_type, np.float64) tolerance = 1e-8 @@ -339,7 +339,7 @@ def _GetVal(use_xla): gpu_value = _GetVal(use_xla=True) cpu_value = _GetVal(use_xla=False) - self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4) + self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-3) def testDepthwiseConv2DInputGradCompare(self): for index, (input_size, filter_size, output_size, stride, diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index bdd0185dfe4abe..4dff5f0f405fb1 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -24,10 +24,16 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.layers import convolutional +from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import googletest @@ -43,7 +49,7 @@ def testBasic(self): def testExecuteListOutputLen0(self): with self.test_scope(): - empty = constant_op.constant([], dtype=dtypes.int32) + empty = constant_op.constant([], dtype=dtypes.float32) result = array_ops.unstack(empty, 0) self.assertTrue(isinstance(result, list)) self.assertEqual(0, len(result)) @@ -51,7 +57,7 @@ def testExecuteListOutputLen0(self): def testExecuteListOutputLen1(self): with self.test_scope(): split_dim = constant_op.constant(1) - value = constant_op.constant([[0, 1, 2], [3, 4, 5]]) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) result = array_ops.split(value, 1, axis=split_dim) self.assertTrue(isinstance(result, list)) self.assertEqual(1, len(result)) @@ -60,7 +66,7 @@ def testExecuteListOutputLen1(self): def testExecuteListOutputLen3(self): with self.test_scope(): split_dim = constant_op.constant(1) - value = constant_op.constant([[0, 1, 2], [3, 4, 5]]) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) result = array_ops.split(value, 3, axis=split_dim) self.assertTrue(isinstance(result, list)) 
self.assertEqual(3, len(result)) @@ -111,6 +117,15 @@ def testAssignAddVariable(self): v.assign_add(2.0) self.assertEqual(3.0, v.numpy()) + def testReadAssignRead(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + val1 = v.read_value() + v.assign_add(2.0) + val2 = v.read_value() + self.assertEqual(1.0, val1.numpy()) + self.assertEqual(3.0, val2.numpy()) + def testGradient(self): def f(x): return x @@ -130,8 +145,189 @@ def f(): grads = backprop.implicit_grad(f)() self.assertEqual(2., grads[0][0].numpy()) + def testMultipleVariableReads(self): + # This test makes sure consecutive variable reads don't copy + # the underlying memory. + with self.test_scope(): + # Create 128MiB variables + var = resource_variable_ops.ResourceVariable( + array_ops.ones([32, 1024, 1024])) + + # Read the same variable 100 times. If the underlying tensor + # is not copied, this is a trivial operation. If it is copied, + # this will eat over 13GB and OOM. + values = [] + for _ in range(100): + values.append(var.value()) + + +class EagerFunctionTest(XLATestCase): + + def testBasic(self): + with self.test_scope(): + matmul = function.defun(math_ops.matmul, compiled=True) + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq = matmul(t, t, transpose_a=True) + self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) + + def testConv(self): + if 'GPU' in self.device: + # TODO(b/32333178) + self.skipTest('Current implementation of RandomStandardNormal kernel ' + 'is very slow on GPU, and has been blacklisted.') + with self.test_scope(): + data_format = 'channels_last' + conv = convolutional.Conv2D( + filters=1, kernel_size=2, padding='VALID', + data_format=data_format, activation=nn_ops.relu, + kernel_initializer=init_ops.ones_initializer(), + bias_initializer=init_ops.zeros_initializer()) + pool = pooling.MaxPooling2D(2, 2, data_format=data_format) + + def model(x): + x = conv(x) + return pool(x) + model = function.defun(model, compiled=True) + + x = array_ops.ones([1, 4, 4, 1]) + y = model(x) + self.assertAllEqual(y.numpy(), [[[[4.]]]]) + + def testReadVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun(compiled=True) + def f(): + return v.read_value() + + var = f() + self.assertEqual(1.0, var.numpy()) + + def testUpdateVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + def f(v): + v.assign_add(1.0) + return v + + f = function.defun(f, compiled=True) + + var = f(v) + self.assertEqual(2.0, var.numpy()) + + def testAllArgumentKinds(self): + """Test a complex function that takes different argument kinds. + + tf2xla machinery that translates, compiles, and runs defuns + classifies arguments into: compile-time constants, regular tensors, + and resources. This test creates a function with a mix of all these + kinds. Moreover, the order of function arguments is intentionally mixed up. + + This also tests the case when the same argument is a compile-time constant + as well as used in an operation that normally expects its inputs to be + in device memory - addition in this case. 
+ """ + with self.test_scope(): + def foo(c1, r1, v1, c2, v2, r2): + # c1 and c2 are compile-time constants + # r1 and r2 are regular tensors + # v1 and v2 are resource variables + a = c1 + r1 + b = math_ops.cast(c2, dtypes.float32) + v2 + c = array_ops.slice(v1, c1, c2) + d = r2 * v2 + return a, b, c, d + + foo = function.defun(foo, compiled=True) + + c1 = [0, 0] + c2 = array_ops.ones([2], dtype=dtypes.int32) + + r1 = array_ops.ones([2]) + r2 = [[2., 2.], [3., 3.]] + + v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]]) + v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]]) + + a, b, c, d = foo(c1, r1, v1, c2, v2, r2) + + self.assertAllEqual([1, 1], a.numpy()) + self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy()) + self.assertAllEqual([[1.]], c.numpy()) + self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) + + def testDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun(compiled=True) + def f(x): + x = v0 * v0 * x + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + +class ExcessivePaddingTest(XLATestCase): + """Test that eager execution works with TPU flattened tensors. + + Tensors that would normally be excessively padded when written + to TPU memory are reshaped to 1-D flat tensors. + + This test case verifies that such tensors work with eager execution. + + The flattening currently only happens on TPU, but tests should work + fine with all backends as flattening is transparent. + """ + + def testFromConstant(self): + with self.test_scope(): + # Create constant of shape [100, 2, 1]. This tensor would be + # excessively padded on TPU. + tensor = constant_op.constant(100 * [[[10.0], [2.0]]]) + # Use reduce_sum since it requires correctly working with + # a particular dimension. 
+ reduced = math_ops.reduce_sum(tensor, axis=1) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testFromOperation(self): + with self.test_scope(): + tensor = array_ops.ones([3, 100, 2, 2]) + reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3]) + self.assertAllEqual(100 * [12.0], reduced) + + def testAsFunctionInput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return math_ops.reduce_sum(x, axis=2) + + tensor = constant_op.constant(100 * [[[10.0, 2.0]]]) + reduced = f(tensor) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testAsFunctionOutput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return x * constant_op.constant(100 * [[[10.0, 2.0]]]) + + y = f(3) + reduced = math_ops.reduce_sum(y, axis=2) + self.assertAllEqual(100 * [[36.0]], reduced) + -if __name__ == "__main__": +if __name__ == '__main__': ops.enable_eager_execution( config=config_pb2.ConfigProto(log_device_placement=True)) googletest.main() diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index fbc3c994d163a5..8a3f4b0bdc7a61 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -24,12 +24,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -@test_util.with_c_api class FunctionTest(XLATestCase): def testFunction(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 42e637734c578f..7cf953ef25ef5d 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -65,9 +65,7 @@ def testBatch(self): join1 = array_ops.stack(split1) join2 = array_ops.stack(split2) batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2], - { - batch0: inp - }) + {batch0: inp}) # Verify that processing batch elements together is the same as separate self.assertAllClose(batch1, join1) @@ -401,9 +399,7 @@ def testAdjustRandomSaturation(self): x = array_ops.placeholder(dtypes.float32, shape=x_shape) with self.test_scope(): y_fused = self._adjust_saturation(x, - scale).eval(feed_dict={ - x: x_np - }) + scale).eval(feed_dict={x: x_np}) self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) @@ -412,7 +408,8 @@ class ResizeBilinearTest(XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, target_shape, - expected=None): + expected=None, + large_tolerance=False): if expected is None: self.fail("expected must be specified") with self.test_session() as sess, self.test_scope(): @@ -420,7 +417,11 @@ def _assertForwardOpMatchesExpected(self, resized = gen_image_ops.resize_bilinear( image, target_shape, align_corners=True) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) - self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) def _assertBackwardOpMatchesExpected(self, grads_np, @@ -555,6 +556,28 @@ def testAlignCorners3x3To9x9Grad(self): [[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]], dtype=np.float32)) + def testAlignCorners4x4To8x8(self): + 
self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3]], dtype=np.float32) + np.array( + [[0], [1], [2], [3]], dtype=np.float32)) * 7.0, [8, 8], + expected=3 * + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)), + large_tolerance=True) + + def testAlignCorners8x8To16x16(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0, + [16, 16], + expected=7 * (np.array( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], + dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + [12], [13], [14], [15]], + dtype=np.float32)), + large_tolerance=True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 1f7da659e5590b..6e0db54b7a74b2 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -78,10 +78,10 @@ def InLabels(labels, substr): def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline.""" + """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") class JitLaunchTest(test.TestCase): @@ -90,8 +90,8 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node - # actually ran. However, it is sometimes possible for _XlaLaunch ops to be + # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node + # actually ran. However, it is sometimes possible for XlaLaunch ops to be # constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: @@ -125,7 +125,7 @@ def _compare(self, fn, args, require_kernel_launch=True, noinline=None): for (x, y) in zip(compiled, direct): self.assertAllClose(x, y, rtol=1e-1) else: - self.assertAllClose(compiled, direct) + self.assertAllClose(compiled, direct, rtol=1e-2) def testNoOutputs(self): with session_lib.Session() as sess: @@ -441,14 +441,14 @@ def Forward(x): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "_XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaLaunch")) - # Compile the backprop. One _XlaLaunch. + # Compile the backprop. One XlaLaunch. 
labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "_XlaLaunch")) + self.assertTrue(InLabels(labels, "XlaLaunch")) class ElementWiseFusionTest(test.TestCase): @@ -482,14 +482,15 @@ def simpleTest(self, arg0, arg1, global_jit_level): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("_XlaLaunch(" in x for x in labels) + count = sum("XlaLaunch(" in x for x in labels) return output, count def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true" + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit") tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py new file mode 100644 index 00000000000000..45a04f0cf56e88 --- /dev/null +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for XLA listdiff operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ListDiffTest(xla_test.XLATestCase): + + def _testListDiff(self, x, y, out, idx): + for dtype in [dtypes.int32, dtypes.int64]: + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.test_session() as sess: + x_tensor = ops.convert_to_tensor(x, dtype=dtype) + y_tensor = ops.convert_to_tensor(y, dtype=dtype) + with self.test_scope(): + out_tensor, idx_tensor = array_ops.listdiff( + x_tensor, y_tensor, out_idx=index_dtype) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(out, tf_out) + self.assertAllEqual(idx, tf_idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) + + def testBasic1(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3]) + + def testBasic2(self): + self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3]) + + def testBasic3(self): + self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2]) + + def testDuplicates(self): + self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1], + y=[4, 2], + out=[1, 3, 3, 3, 1], + idx=[0, 3, 5, 6, 7]) + + def testRandom(self): + num_random_tests = 10 + int_low = -7 + int_high = 8 + max_size = 50 + for _ in xrange(num_random_tests): + x_size = np.random.randint(max_size + 1) + x = np.random.randint(int_low, int_high, size=x_size) + y_size = np.random.randint(max_size + 1) + y = np.random.randint(int_low, int_high, size=y_size) + out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y] + if out_idx: + out, idx = map(list, zip(*out_idx)) + else: + out = [] + idx = [] + self._testListDiff(list(x), list(y), out, idx) + + def testFullyOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[]) + + def testNonOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], + y=[5, 6], + out=[1, 2, 3, 4], + idx=[0, 1, 2, 3]) + + def testEmptyX(self): + self._testListDiff(x=[], y=[1, 2], out=[], idx=[]) + + def testEmptyY(self): + self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3]) + + def testEmptyXY(self): + self._testListDiff(x=[], y=[], out=[], idx=[]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py index 1434e965e3d7ea..d68d32057a3677 100644 --- a/tensorflow/compiler/tests/oom_test.py +++ b/tensorflow/compiler/tests/oom_test.py @@ -22,6 +22,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest @@ -42,20 +44,33 @@ def testOutputOutOfMemory(self): """ def test_loop(): - size = 2e8 + size = int(2e8) while True: with self.test_session(): - # Force the compiled code to not be constant by feeding in an addend. 
- p = array_ops.placeholder(dtypes.float32, shape=[]) + # Force the compiled code to not be constant by feeding in a + # parameter. + p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1]) with self.test_scope(): - # Create a large R1 tensor. - c = array_ops.zeros([size, 1]) + p + # Create a computation that produces a large R1 tensor as an + # intermediate result. Reduce it down so that if this file was + # compiled without --config=cuda, we don't force a D2H copy of a + # large tensor and potentially OOM the host. + # + # This is a bit tricky because XLA:GPU doesn't currently support RNG + # ops. Here we rely on the fact that XLA doesn't do algebraic + # simplifications on conv(<ones>, <parameter>). + c = math_ops.reduce_sum( + nn_ops.convolution( + array_ops.ones([1, size, 1]), + p, + padding='SAME', + data_format='NWC')) - c.eval(feed_dict={p: 1.0}) + c.eval(feed_dict={p: [[[1.0]], [[2.0]]]}) size *= 2 self.assertRaises(errors.ResourceExhaustedError, test_loop) -if __name__ == "__main__": +if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index d6c93088d4efff..70be22936a500f 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -76,6 +76,13 @@ def testRandomUniformIsInRange(self): self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) + def testTruncatedNormalIsNotConstant(self): + def rng(dtype): + return random_ops.truncated_normal(shape=[2], dtype=dtype) + + # TODO(b/34339814): implement inverse erf support for non-F32 types. + self._testRngIsNotConstant(rng, dtypes.float32) + def testTruncatedNormalIsInRange(self): count = 10000 # TODO(b/34339814): implement inverse erf support for non-F32 types. diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index e53efc3091d893..16f293891d56d7 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -619,8 +619,8 @@ std::vector OpTest::ImageDims(TensorFormat format, int batch, dims.push_back(dim); } break; - case FORMAT_NCHW_VECT_C: - LOG(FATAL) << "FORMAT_NCHW_VECT_C not supported."; + default: + LOG(FATAL) << "Tensor format " << ToString(format) << " not supported."; } return dims; } diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 2c084b04fa2f67..7420724bdbeab6 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import functools +import itertools import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase @@ -155,5 +156,68 @@ def testReduceAny(self): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) +class ReduceOpPrecisionTest(XLATestCase): + + def _testReduceSum(self, + expected_result, + dtype, + test_inputs, + rtol=1e-3, + atol=1e-4): + """Tests reduce sum on a list of input arrays. + + For each array in test_inputs, check that performing reduce sum on the array + produces a value that is close to the expected result. + + Args: + expected_result: the expected result. + dtype: the data type of the reduce sum operation. + test_inputs: a list of input arrays for the reduce sum operation. + rtol: the relative error. + atol: the absolute error.
+ """ + + for test_input in test_inputs: + with self.test_session() as sess: + with self.test_scope(): + a = array_ops.placeholder(dtype) + index = array_ops.placeholder(dtypes.int32) + out = math_ops.reduce_sum(a, index) + result = sess.run(out, { + a: np.array(test_input, dtype=dtype), + index: [0] + }) + # Compare the results using float32 type. + self.assertAllClose( + np.float32(result), + np.float32(expected_result), + rtol=rtol, + atol=atol) + + def testReduceSumF16(self): + """Tests the reduce sum of float16 doesn't lose too much precision.""" + + if np.float16 not in self.all_types: + return + + f16_max = np.finfo(np.float16).max + self._testReduceSum( + f16_max, np.float16, + itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3)) + + def testReduceSumBF16(self): + """Tests the reduce sum of bfloat16 doesn't lose too much precision.""" + + if dtypes.bfloat16.as_numpy_dtype not in self.all_types: + return + + bf16_max = np.float32(dtypes.bfloat16.max) + f32_max = dtypes.float32.max + value = min(bf16_max, f32_max - bf16_max) + self._testReduceSum( + dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype, + itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3)) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 4336ebdbd184a0..b6f8390a45d43b 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -86,6 +86,15 @@ def testDistributionOfStatelessRandomUniform(self): # seed were not fixed. self.assertTrue(self._chi_squared(y, 10) < 16.92) + def testRandomNormalIsFinite(self): + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless.stateless_random_uniform( + shape=[10000], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue(np.all(np.isfinite(y))) + def _normal_cdf(self, x): """Cumulative distribution function for a standard normal distribution.""" return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 7624d6e4b2e2ec..f332aa2e9b97e1 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -472,7 +472,9 @@ def _testTensorArrayGradientWriteReadType(self, dtype): self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1]) def testTensorArrayGradientWriteRead(self): - for dtype in self.numeric_types: + for dtype in self.float_types: + self._testTensorArrayGradientWriteReadType(dtype) + for dtype in self.complex_types: self._testTensorArrayGradientWriteReadType(dtype) def _testTensorArrayGradientWritePackConcatAndRead(self): diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ba79f393a8f9b2..689a4a1f4e02f5 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -209,7 +209,8 @@ def testFloatOps(self): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype)) + expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), + rtol=1e-5) self._assertOpOutputMatchesExpected( math_ops.floor, @@ -251,12 +252,12 @@ def 
testFloatOps(self): np.array([[1, 2]], dtype=dtype), expected=np.array([[0.540297, -0.41614]], dtype=dtype)) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype))) + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.rint, @@ -333,13 +334,19 @@ def testFloatOps(self): self._assertOpOutputMatchesExpected( nn_ops.elu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-6]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.selu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-5]], dtype=dtype), + expected=np.array( + [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]], + dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.relu, @@ -419,7 +426,9 @@ def testComplexOps(self): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), - expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)), + rtol=1e-6, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.reciprocal, @@ -441,13 +450,13 @@ def testComplexOps(self): np.array([[5j, 3 - 2j]], dtype=dtype), expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype))) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) self._assertOpOutputMatchesExpected( @@ -789,7 +798,9 @@ def _assertSoftplusMatchesExpected(self, features, dtype): zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected) + nn_ops.softplus, features, expected=expected, + rtol=1e-6, + atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 8ecad00f6e23b3..2c09b03d5a35cd 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -187,6 +187,25 @@ def testTraining(self): rtol=1e-4) self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4) + def testWriteOfAliasedTensor(self): + for dtype in self.numeric_types: + init = np.array([[1, 2j], [3, 4]]).astype(dtype) + update = np.array([[7, 1j], [2, 11]]).astype(dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + q = array_ops.identity(p) + x = v.read_value() + # Writes the value of 'p' to 'v', but keeps a reference to the original + # value of 'v' so the variable update cannot reuse its buffer. 
+ with ops.control_dependencies([x]): + y = v.assign(q) + result = sess.run([x, y, q], {p: update}) + self.assertAllClose(init, result[0]) + self.assertAllClose(update, result[1]) + self.assertAllClose(update, result[2]) + class StridedSliceAssignChecker(object): """Compares the results of a slice assignment using Tensorflow and numpy.""" diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py new file mode 100644 index 00000000000000..1e30ebd55d09fe --- /dev/null +++ b/tensorflow/compiler/tests/xla_device_gpu_test.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for XLA devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class XlaDeviceGpuTest(test.TestCase): + + def testCopiesToAndFromGpuWork(self): + """Tests that copies between GPU and XLA devices work.""" + if not test.is_gpu_available(): + return + + with session_lib.Session() as sess: + x = array_ops.placeholder(dtypes.float32, [2]) + with ops.device("GPU"): + y = x * 2 + with ops.device("device:XLA_CPU:0"): + z = y * y + with ops.device("GPU"): + w = y + z + result = sess.run(w, {x: [1.5, 0.5]}) + self.assertAllClose(result, [12., 2.], rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f5c228f8305d74..f0b010fa67f2ff 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,30 +18,40 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test -class XlaDeviceTest(test.TestCase): +class XlaDeviceTest(XLATestCase): def testCopies(self): - """Tests that copies between GPU and XLA devices work.""" - if not test.is_gpu_available(): - return - - with session_lib.Session() as sess: - x = array_ops.placeholder(dtypes.float32, [2]) - with ops.device("GPU"): - y = x * 2 - with ops.device("device:XLA_CPU:0"): - z = y * y - with ops.device("GPU"): - w = y + z - result = sess.run(w, {x: [1.5, 0.5]}) - self.assertAllClose(result, [12., 2.], rtol=1e-3) + """Tests that copies onto and off XLA devices work.""" + shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3], + [16384, 1], [1, 16384], [1, 20000, 1, 1]] + for dtype in self.numeric_types: + for shape in shapes: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(dtype, shape) + with self.test_scope(): + y = x + x + with ops.device("CPU"): + z = array_ops.identity(y) + + inputs = np.random.randint(-100, 100, shape).astype(dtype) + result = sess.run(z, {x: inputs}) + self.assertAllCloseAccordingToType(result, inputs + inputs) + + def testControlTrigger(self): + with self.test_session() as sess: + with self.test_scope(): + x = gen_control_flow_ops.control_trigger() + sess.run(x) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 382d3d1aa9cf27..b75bfe96073e6f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -82,7 +82,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -171,9 +171,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -218,7 +218,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -329,6 +328,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git 
a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 4f8bb8ad743afe..ea8d1b3d14939d 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -27,3 +27,25 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) + +tf_gen_op_wrapper_cc( + name = "xla_jit_op_gen", + out_ops_file = "ops/xla_jit_op", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], +) + +cc_library( + name = "xla_jit_ops", + srcs = ["ops/xla_jit_op.cc"], + hdrs = ["ops/xla_jit_op.h"], + deps = [ + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/jit/ops:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 8d1f2684909e87..42585ad4d8a17d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -282,7 +282,58 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, return Status::OK(); } -Status FunctionalizeLoop(Graph* graph, Frame* frame, +// Copy the FunctionDef of the given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. +Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. +Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function referred to by a 'SymbolicGradient' node is specified in its + // attribute 'f'. + if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists.
+ const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " << dump_graph::DumpGraphToFile("functionalize_before", *graph, @@ -489,6 +540,14 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } // Builds a While operator. NodeDef while_def; @@ -1365,6 +1424,12 @@ Status FunctionalizeCond::Functionalize(Graph* graph, // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { + return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); +} + +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); @@ -1434,7 +1499,8 @@ Status FunctionalizeControlFlow(Graph* graph, continue; } - TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 4d4ee3054c2914..d941041d155324 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -22,9 +22,13 @@ limitations under the License. namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. 
Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index e494f42e8ed254..14977a908ae2b0 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -299,6 +299,131 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. +Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { + FunctionDef fdef = FunctionDefHelper::Create( + "increment_fn", {"x:int32"}, {"add:int32"}, {}, + { + {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}}, + {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}}, + }, + {{"add", "add_0:z:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); + NodeDef increment_fn; + increment_fn.set_name(node_name); + increment_fn.set_op("increment_fn"); + *increment_fn.add_input() = "while/Identity"; + *increment_fn.add_input() = "^while/Identity"; + Status status; + graph->AddNode(increment_fn, &status); + return status; +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x]) +TEST(FunctionalizeControlFlow, NoinlineLoopBody) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source, + "while/while_context"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + + NodeDef next_iter; + next_iter.set_name("while/NextIteration"); + next_iter.set_op("NextIteration"); + *next_iter.add_input() = noinline_node_name; + (*next_iter.mutable_attr())["T"].set_type(DT_INT32); + + Status status; + Node* n = scope.graph()->AddNode(next_iter, &status); + TF_ASSERT_OK(status); + + // Remove the dummy node and add the loop backedge. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(n, 0, merge.output.node(), 1); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // Function increment_fn will be copied from lookup_lib to library. + TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Tests functionalizing OneLoopVar where the loop value is not used post the // loop. 
// Graph: diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index b20c1ffc7d8956..212f6f3966149c 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -51,6 +51,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, std::vector* args) { auto builder = ctx->builder(); + auto client = ctx->compiler()->client(); std::vector compile_time_constant_flags(expressions.size()); TF_RETURN_IF_ERROR( @@ -72,8 +73,10 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.kind = XlaCompiler::Argument::kConstant; TF_RET_CHECK(expressions[i]->resource() == nullptr) << "Input with resource is not yet implemented."; + TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( + expressions[i]->handle())); TF_ASSIGN_OR_RETURN(auto literal, - builder->ComputeConstant(expressions[i]->handle())); + client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); } else { @@ -205,14 +208,15 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; XlaCompiler::CompilationResult result; - - TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), - func, arguments, &result)); + TF_RETURN_IF_ERROR( + compiler->CompileFunction(compile_options, func, arguments, &result)); TF_RET_CHECK(arguments.size() == expressions.size()); - std::vector handles; + std::vector handles; for (int64 i = 0; i < expressions.size(); ++i) { if (arguments[i].kind == XlaCompiler::Argument::kConstant) { continue; @@ -226,11 +230,14 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, auto output_handle = b->Call(*result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. 
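+ // Constant outputs are emitted directly via SetConstantOutput and never
+ // become tuple elements, so computation_output below indexes only the
+ // non-constant outputs of the Call result.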
+ int computation_output = 0; for (int64 i = 0; i < n->num_outputs(); ++i) { if (result.outputs[i].is_constant) { xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { - xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + xla_op_context.SetOutput( + i, b->GetTupleElement(output_handle, computation_output)); + ++computation_output; } } return b->first_error(); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 00fd08b1a07507..edd2ab6301ee89 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -18,6 +18,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", "cholesky_op.cc", @@ -45,6 +46,7 @@ tf_kernel_library( "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", + "listdiff_op.cc", "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", @@ -114,8 +116,8 @@ tf_kernel_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -151,7 +153,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -167,7 +169,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -203,8 +205,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:argmax_op", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 5c9f66df101bfb..1e59868621475c 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -29,7 +29,7 @@ class AddNOp : public XlaOpKernel { OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("AddN requires at least one argument")); - xla::ComputationDataHandle sum = ctx->Input(0); + xla::XlaOp sum = ctx->Input(0); for (int i = 1; i < ctx->num_inputs(); ++i) { sum = ctx->builder()->Add(sum, ctx->Input(i)); } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 931175be1111ed..15e1815a4cf07f 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -48,9 +48,9 @@ class FusedBatchNormOp : public 
XlaOpKernel { OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); TensorShape input_shape = ctx->InputShape(0); int feature_index = @@ -62,7 +62,7 @@ class FusedBatchNormOp : public XlaOpKernel { input = builder->ConvertElementType(input, scale_type); if (is_training_) { - xla::ComputationDataHandle output = builder->BatchNormTraining( + xla::XlaOp output = builder->BatchNormTraining( input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the @@ -79,7 +79,7 @@ class FusedBatchNormOp : public XlaOpKernel { ctx->SetOutput(3, builder->GetTupleElement(output, 1)); ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = builder->BatchNormInference( + xla::XlaOp output = builder->BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), epsilon_, feature_index); ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); @@ -118,7 +118,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); DataType input_dtype = ctx->input_type(0); DataType scale_dtype = ctx->input_type(2); @@ -137,11 +137,11 @@ class FusedBatchNormGradOp : public XlaOpKernel { const int feature_index = GetTensorFeatureDimIndex(input_dims, data_format_); - xla::ComputationDataHandle x_backprop; - xla::ComputationDataHandle scale_backprop; - xla::ComputationDataHandle offset_backprop; + xla::XlaOp x_backprop; + xla::XlaOp scale_backprop; + xla::XlaOp offset_backprop; if (is_training_) { - xla::ComputationDataHandle output = + xla::XlaOp output = b->BatchNormGrad(activations, scale, mean, var, grad_backprop, epsilon_, feature_index); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 569950c2dfaeb6..642278ab994bf3 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -20,9 +20,8 @@ limitations under the License. namespace tensorflow { namespace { -void BatchToSpace(XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, DataType input_dtype, - const TensorShape& input_tensor_shape, +void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, + DataType input_dtype, const TensorShape& input_tensor_shape, gtl::ArraySlice block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); @@ -46,7 +45,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, ", 2] instead of ", xla::ShapeUtil::HumanString(crops.shape()))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const int64 batch_size = input_shape[0]; // Compute the product of the block_shape values. @@ -73,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, reshaped_shape[block_rank] = batch_size / block_num_elems; std::copy(input_shape.begin() + 1, input_shape.end(), reshaped_shape.begin() + block_rank + 1); - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. 
Permute dimensions of `reshaped` to produce `permuted` of shape // [batch / prod(block_shape), @@ -91,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, } std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation); + xla::XlaOp permuted = b->Transpose(reshaped, permutation); // 3. Reshape `permuted` to produce `reshaped_permuted` of shape // [batch / prod(block_shape), @@ -111,8 +110,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_permuted_shape.begin() + 1 + block_rank); - xla::ComputationDataHandle reshaped_permuted = - b->Reshape(permuted, reshaped_permuted_shape); + xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape); // 4. Crop the start and end of dimensions `[1, ..., M]` of // `reshaped_permuted` according to `crops` to produce the output of shape: @@ -139,7 +137,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, "Cropped size must be non-negative: start: ", crop_start, " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); } - xla::ComputationDataHandle output = + xla::XlaOp output = b->Slice(reshaped_permuted, start_indices, end_indices, strides); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index ed33b8ed2e823f..9d677f426650ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -60,7 +60,7 @@ class BiasOp : public XlaOpKernel { "of the input tensor: ", bias_shape.DebugString(), " vs. ", input_shape.DebugString())); - xla::ComputationDataHandle result = + xla::XlaOp result = ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim}); ctx->SetOutput(0, result); } @@ -103,7 +103,7 @@ class BiasAddGradOp : public XlaOpKernel { std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0); std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(), feature_dim + 1); - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); const DataType accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 2436a6074a11ad..f04cde878e9800 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -34,14 +34,13 @@ namespace { class NAME##Op : public XlaBinaryOp { \ public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ - xla::ComputationDataHandle Computation( \ - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, \ - const gtl::ArraySlice& lhs_shape, \ - const xla::ComputationDataHandle& rhs, \ + xla::XlaOp Computation( \ + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ + const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, \ const gtl::ArraySlice& rhs_shape, \ const BCast& broadcast_helper, \ const std::vector& extend_dimensions) override { \ - xla::ComputationBuilder* b = ctx->builder(); \ + xla::XlaBuilder* b = ctx->builder(); \ return HLO; \ } \ }; \ @@ -63,11 +62,8 @@ XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); // } else { // return x / y; // } -static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, - DataType dtype, - xla::ComputationDataHandle x, - xla::ComputationDataHandle y, - const BCast& broadcast_helper) { +static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); @@ -87,11 +83,8 @@ XLA_MAKE_BINARY(FloorDiv, // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); -static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b, - DataType dtype, - xla::ComputationDataHandle x, - xla::ComputationDataHandle y, - const BCast& broadcast_helper) { +static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero)); @@ -127,8 +120,7 @@ XLA_MAKE_BINARY(SqrtGrad, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), lhs, extend_dimensions)); -static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { return builder->Mul(x, x); } @@ -175,11 +167,11 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. 
void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); auto abs_shape = b->GetShape(abs); OP_REQUIRES_OK(ctx, abs_shape.status()); - auto abs_type = abs_shape.ValueOrDie()->element_type(); + auto abs_type = abs_shape.ValueOrDie().element_type(); auto result = b->Lt( abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type)); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc new file mode 100644 index 00000000000000..ca9a6b40688d1e --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BucketizeOp : public XlaOpKernel { + public: + explicit BucketizeOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("boundaries", &boundaries_)); + OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), + errors::InvalidArgument("Expected sorted boundaries")); + } + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + const DataType dtype = context->input_type(0); + xla::XlaOp input = context->Input(0); + + xla::XlaOp boundaries = builder->ConstantR1(boundaries_); + // TODO(phawkins): the following behavior matches the behavior of the core + // Bucketize kernel. However, comparing an int32 or int64 against float may + // lead to inaccurate bucketing due to rounding. 
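+ // The bucket index is the number of boundaries the input is >= to: the
+ // input is compared against every boundary (Ge) and the 0/1 results are
+ // summed, e.g. boundaries {0, 10, 100} and an input of 15 give comparisons
+ // {1, 1, 0}, i.e. bucket 2.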
+ if (dtype == DT_DOUBLE) { + input = builder->ConvertElementType(input, xla::F64); + boundaries = builder->ConvertElementType(boundaries, xla::F64); + } else { + input = builder->ConvertElementType(input, xla::F32); + } + xla::XlaOp comparison = builder->ConvertElementType( + builder->Ge(builder->Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = builder->Reduce( + comparison, /*init_value=*/builder->ConstantR0(0), + /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + context->SetOutput(0, buckets); + } + + private: + std::vector boundaries_; +}; + +REGISTER_XLA_OP(Name("Bucketize"), BucketizeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index c52b2dcb7e9ef8..e9d98c768572c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -33,9 +33,9 @@ class CastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); - xla::ComputationDataHandle output; + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp output; if (src_dtype_ == dst_dtype_) { output = input; @@ -72,9 +72,9 @@ class BitcastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); - xla::ComputationDataHandle output; + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp output; if (src_dtype_ == dst_dtype_) { output = input; diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 545aa364f937b2..835a7f568945f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -34,7 +34,7 @@ class CategoricalOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { // Get the logits - const xla::ComputationDataHandle& logits = ctx->Input(0); + const xla::XlaOp& logits = ctx->Input(0); TensorShape logits_shape = ctx->InputShape(0); int64 num_samples; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples)); @@ -56,7 +56,7 @@ class CategoricalOp : public XlaOpKernel { const int64 batch_size = logits_shape.dim_size(0); const int64 num_classes = logits_shape.dim_size(1); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); std::array uniform_shape_array = { {batch_size, num_samples, num_classes}}; @@ -78,7 +78,7 @@ class CategoricalOp : public XlaOpKernel { /*broadcast_dimensions=*/{0, 2}); TensorShape softmax_shape(uniform_shape_array); - xla::ComputationDataHandle argmax; + xla::XlaOp argmax; OP_REQUIRES_OK( ctx, XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index fdf75be7b11565..a00bc912f9f400 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -29,7 +29,7 @@ class ClipByValueOp : public XlaOpKernel { const TensorShape min_shape = ctx->InputShape(1); const TensorShape max_shape = ctx->InputShape(2); - 
xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto input = ctx->Input(0); auto min = ctx->Input(1); auto max = ctx->Input(2); diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 1a246e8df9b2cd..78285affa1c399 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -54,7 +54,7 @@ class ConcatBaseOp : public XlaOpKernel { // TODO(annarev): add a helper to support int64 input. const int32 concat_dim = literal.Get({}); - std::vector values; + std::vector values; std::vector shapes; OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); const int N = values.size(); @@ -70,13 +70,13 @@ class ConcatBaseOp : public XlaOpKernel { "[", -input_dims, ", ", input_dims, "), but got ", concat_dim)); - // Make a vector holding the ComputationDataHandles for each of - // the inputs that has non-zero elements. - std::vector input_data; + // Make a vector holding the XlaOp for each of the inputs that has non-zero + // elements. + std::vector input_data; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { - xla::ComputationDataHandle handle = values[i]; + xla::XlaOp handle = values[i]; const TensorShape& in_shape = shapes[i]; const bool in_is_scalar = IsLegacyScalar(in_shape); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 8f78b4c8f90cf0..59d06c654de18c 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -45,7 +45,7 @@ class ConstOp : public XlaOpKernel { ctx->SetInvalidOutput(0); return; } - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, // recognize that case and emit a scalar broadcast instead. diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index c0ee0c9c2ea849..627bad12f33c82 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -47,9 +47,8 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( } // Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. -xla::ComputationDataHandle CreateExpandedZero( - const TensorShape& filter_shape, DataType dtype, - xla::ComputationBuilder* builder) { +xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); return builder->Broadcast(XlaHelpers::Zero(builder, dtype), @@ -87,8 +86,8 @@ xla::ComputationDataHandle CreateExpandedZero( // // Finally compare A and broadcasted B in dimension 2 amd return the result at // the beginning of the comment. 
-xla::ComputationDataHandle CreateExpandedFilterMask( - const TensorShape& filter_shape, xla::ComputationBuilder* builder) { +xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); @@ -96,11 +95,11 @@ xla::ComputationDataHandle CreateExpandedFilterMask( // Create a M sized linspace and an M*N sized linspace that will be // broadcasted into perpendicular dimensions and compared. - xla::ComputationDataHandle input_feature_iota; + xla::XlaOp input_feature_iota; // DT_INT32 Iota will always return status::OK(). TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, &input_feature_iota)); - xla::ComputationDataHandle expanded_feature_iota; + xla::XlaOp expanded_feature_iota; TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature * depthwise_multiplier, &expanded_feature_iota)); @@ -126,10 +125,10 @@ xla::ComputationDataHandle CreateExpandedFilterMask( // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. -xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( - const TensorShape& filter_shape, DataType dtype, - const xla::ComputationDataHandle& filter, - xla::ComputationBuilder* builder) { +xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, + DataType dtype, + const xla::XlaOp& filter, + xla::XlaBuilder* builder) { int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = @@ -156,10 +155,11 @@ xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( } // Inverse of ExpandFilterForDepthwiseConvolution. 
-xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, - const xla::ComputationDataHandle& filter_backprop, - xla::ComputationBuilder* builder) { +xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, + const TensorShape& filter_shape, + DataType dtype, + const xla::XlaOp& filter_backprop, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); auto masked_expanded_filter = builder->Select( @@ -248,9 +248,9 @@ class ConvOp : public XlaOpKernel { "input and filter must have the same depth: ", in_depth, " vs ", input_shape.dim_size(feature_dim))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); - xla::ComputationDataHandle filter = ctx->Input(1); + xla::XlaOp filter = ctx->Input(1); TensorShape expanded_filter_shape = filter_shape; if (depthwise_) { filter = ExpandFilterForDepthwiseConvolution( @@ -288,7 +288,7 @@ class ConvOp : public XlaOpKernel { &unused_output_size, &padding[i].first, &padding[i].second)); } - xla::ComputationDataHandle conv = + xla::XlaOp conv = b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); @@ -391,7 +391,7 @@ class ConvBackpropInputOp : public XlaOpKernel { expanded_filter_shape, out_backprop_shape, dilations_, strides_, padding_, data_format_, &dims)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto filter = ctx->Input(1); auto out_backprop = ctx->Input(2); @@ -435,12 +435,11 @@ class ConvBackpropInputOp : public XlaOpKernel { } // Mirror the filter in the spatial dimensions. - xla::ComputationDataHandle mirrored_weights = - b->Rev(filter, kernel_spatial_dims); + xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims); // activation gradients // = gradients (with padding and dilation) mirrored_weights - xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated( + xla::XlaOp in_backprop = b->ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, lhs_dilation, rhs_dilation, dnums); @@ -546,9 +545,9 @@ class ConvBackpropFilterOp : public XlaOpKernel { expanded_filter_shape, out_backprop_shape, dilations_, strides_, padding_, data_format_, &dims)); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle activations = ctx->Input(0); - xla::ComputationDataHandle gradients = ctx->Input(2); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp activations = ctx->Input(0); + xla::XlaOp gradients = ctx->Input(2); // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. 
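 // A minimal sketch of the depthwise shape arithmetic used above (hypothetical
 // helper, assuming the [H, W, in_depth, multiplier] layout described earlier):
 //   TensorShape expanded = filter_shape;
 //   expanded.set_dim(filter_shape.dims() - 1,
 //                    filter_shape.dim_size(filter_shape.dims() - 2) *  // in_depth
 //                    filter_shape.dim_size(filter_shape.dims() - 1));  // multiplier
 // e.g. a 3x3 depthwise filter over 8 channels with multiplier 2 expands to
 // [3, 3, 8, 16], and ContractFilterForDepthwiseBackprop maps the filter
 // gradient back to the original [3, 3, 8, 2] shape.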
diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 3df8c00f1b8355..7fcd4170fb79a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -53,7 +53,7 @@ class CrossOp : public XlaOpKernel { } std::vector strides(in0_shape.dims(), 1); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto in0 = ctx->Input(0); auto in1 = ctx->Input(1); starts.back() = 0; diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 0cf03ceb948a51..01aa1a83e79679 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/bcast.h" @@ -75,7 +75,7 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { } // Call virtual method to emit the computation. - xla::ComputationDataHandle output = + xla::XlaOp output = Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle, rhs_shape.dim_sizes(), bcast, extend_dimension); @@ -85,11 +85,9 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(0, output); } -/* static */ std::pair -XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& lhs, - const xla::ComputationDataHandle& rhs, - const BCast& broadcast_helper) { +/* static */ std::pair XlaBinaryOp::Broadcast( + xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, + const BCast& broadcast_helper) { // Manually construct the broadcasting since MapN does not do // automatic broadcasting. The bcast helper ensures that // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 5bc1d5fb1f08fb..4f92dbc8740b69 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" @@ -30,7 +30,7 @@ namespace tensorflow { // inputs that can be broadcast to the same shape. The base class // contains pure virtual methods to override: description is a textual // description of the operation; and Computation adds the -// implementation of the operation to a xla::ComputationBuilder. For most +// implementation of the operation to a xla::XlaBuilder. For most // arithmetic Ops XLA handles the broadcasting automatically given the input // tensors. class XlaBinaryOp : public XlaOpKernel { @@ -55,10 +55,9 @@ class XlaBinaryOp : public XlaOpKernel { // higher-rank input should be matched when broadcasting the // lower-rank input. 
See comment below and the documentation on broadcasting // in the XLA documentation. - virtual xla::ComputationDataHandle Computation( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, - const gtl::ArraySlice& lhs_shape, - const xla::ComputationDataHandle& rhs, + virtual xla::XlaOp Computation( + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, + const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, const std::vector& extend_dimensions) = 0; @@ -67,11 +66,9 @@ class XlaBinaryOp : public XlaOpKernel { // Helper function that performs the broadcasting described by // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same // shape. - static std::pair - Broadcast(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& lhs, - const xla::ComputationDataHandle& rhs, - const BCast& broadcast_helper); + static std::pair Broadcast( + xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, + const BCast& broadcast_helper); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 96d7809f799563..23243f62462c63 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -50,8 +50,8 @@ class DepthToSpaceOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); @@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel { ") is not divisible by square of the block size (", block_size_, ")")); - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -141,8 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2], // block_size_, // depth / (block_size_ * block_size_)] - xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -152,8 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 765ea922a532a0..931705ba837153 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -25,10 +25,10 @@ namespace tensorflow { namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. 
-xla::StatusOr CreateDiagonal( - const xla::ComputationDataHandle& input, int64 last_dim_size, +xla::StatusOr CreateDiagonal( + const xla::XlaOp& input, int64 last_dim_size, tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, - xla::ComputationBuilder* builder) { + xla::XlaBuilder* builder) { // Create two matrices that have the following forms, and compare them: // // [[0, 0, 0, 0] [[0, 1, 2, 3] @@ -38,12 +38,11 @@ xla::StatusOr CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::ComputationDataHandle iota; + xla::XlaOp iota; TF_RETURN_IF_ERROR( XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); - xla::ComputationDataHandle iota_broadcast = - builder->Broadcast(iota, {last_dim_size}); - xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); + xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size}); + xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0}); // If this is a batched diagonal, broadcast the mask across the other // dimensions. @@ -65,8 +64,7 @@ xla::StatusOr CreateDiagonal( std::vector broadcast_dims(other_dims.begin(), other_dims.end()); broadcast_dims.push_back(1LL); broadcast_dims.push_back(last_dim_size); - xla::ComputationDataHandle input_broadcast = - builder->Reshape(input, broadcast_dims); + xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; xla::PrimitiveType element_type; @@ -74,7 +72,7 @@ xla::StatusOr CreateDiagonal( DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); auto broadcast_shape = xla::ShapeUtil::MakeShape(element_type, broadcast_dims); - xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); + xla::XlaOp zeros = Zeros(builder, broadcast_shape); input_broadcast = builder->Add(input_broadcast, zeros); return builder->Select(mask, input_broadcast, zeros); @@ -85,7 +83,7 @@ class DiagOp : public XlaOpKernel { explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("Diag op must have at an input")); @@ -96,7 +94,7 @@ class DiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); // Picture: // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] @@ -112,7 +110,7 @@ class DiagOp : public XlaOpKernel { auto diag_or_status = CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); OP_REQUIRES_OK(ctx, diag_or_status.status()); - xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); + xla::XlaOp diag = diag_or_status.ValueOrDie(); // Reshapes to the final shape. 
std::vector new_dims(dims.size() * 2); @@ -131,7 +129,7 @@ class DiagPartOp : public XlaOpKernel { explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -158,7 +156,7 @@ class DiagPartOp : public XlaOpKernel { new_dims.push_back(dims[i]); } - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); // TODO(b/30878775): use Slice with strides when supported, in place of // the Pad -> Reshape -> Slice. @@ -199,7 +197,7 @@ class MatrixDiagOp : public XlaOpKernel { explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("MatrixDiag op must have at an input")); @@ -210,7 +208,7 @@ class MatrixDiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); @@ -232,7 +230,7 @@ class MatrixDiagPartOp : public XlaOpKernel { explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -241,7 +239,7 @@ class MatrixDiagPartOp : public XlaOpKernel { errors::InvalidArgument("Expected 2 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = dims[last_dim]; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 800ef5ab98d70a..0419de78b2ee83 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -57,7 +57,7 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::ComputationDataHandle result = ctx->builder()->DynamicUpdateSlice( + xla::XlaOp result = ctx->builder()->DynamicUpdateSlice( ctx->Input(0), ctx->Input(1), ctx->Input(2)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index f2cd21ffb9ce88..dd4a1690877950 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -56,7 +56,7 @@ class DynamicStitchOp : public XlaOpKernel { std::vector indices_input; OP_REQUIRES_OK(ctx, ctx->ConstantInputList("indices", &indices_input)); - std::vector data; + std::vector data; std::vector data_shapes; OP_REQUIRES_OK(ctx, ctx->InputList("data", &data, &data_shapes)); @@ -136,7 +136,7 @@ class DynamicStitchOp : public XlaOpKernel { // Look up all the children expressions that represent the data // inputs. - std::vector input(indices.size()); + std::vector input(indices.size()); for (int input_num = 0; input_num < indices.size(); input_num++) { TensorShape new_shape; // first reshaped dimension is the number of indices for this input. @@ -166,7 +166,7 @@ class DynamicStitchOp : public XlaOpKernel { for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } - std::vector to_concat(number_of_indices); + std::vector to_concat(number_of_indices); for (int index_num = 0; index_num < number_of_indices; index_num++) { const auto& expression = input[src_input_vector[index_num]]; // Take the appropriate slice of data. diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 2fd27c5ca7e87c..493781a1e68b89 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -32,11 +32,10 @@ class EluOp : public XlaOpKernel { explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. 
void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); } }; @@ -47,7 +46,7 @@ class EluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return lhs * (1 + rhs). void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); const auto one = XlaHelpers::One(b, input_type(0)); const auto grad = ctx->Input(0); @@ -66,15 +65,14 @@ class SeluOp : public XlaOpKernel { explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), b->Mul(scale_alpha, expm1))); } @@ -86,9 +84,8 @@ class SeluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return lhs * (1 + rhs). void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index b2970eae20a3fb..6df01cabbf1d98 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -93,7 +93,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { input_shape.DebugString())); const int64 depth = input_shape.dim_size(feature_dim); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); // The following code is equivalent to: // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) @@ -110,7 +110,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { // Builds an identity matrix as a broadcast equality of iotas. 
// iota = np.arange(np.prod(ksize), depth) // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::ComputationDataHandle iota; + xla::XlaOp iota; TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, kernel_size * depth, &iota)); @@ -147,7 +147,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { &padding[i].first, &padding[i].second)); } - xla::ComputationDataHandle conv = + xla::XlaOp conv = builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 99470d70e709dd..8f0de0a524c908 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -44,23 +44,20 @@ void CpuNudge(const float min, const float max, const float quant_min, } // An XLA version of CpuNudge(). -void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, - const xla::ComputationDataHandle& min, - const xla::ComputationDataHandle& max, +void XlaNudge(xla::XlaBuilder* b, const DataType data_type, + const xla::XlaOp& min, const xla::XlaOp& max, const float quant_min_value, const float quant_max_value, - xla::ComputationDataHandle* nudged_min, - xla::ComputationDataHandle* nudged_max, - xla::ComputationDataHandle* scale) { + xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, + xla::XlaOp* scale) { *scale = b->Div(b->Sub(max, min), XlaHelpers::FloatLiteral(b, data_type, quant_max_value - quant_min_value)); - xla::ComputationDataHandle quant_min = + xla::XlaOp quant_min = XlaHelpers::FloatLiteral(b, data_type, quant_min_value); - xla::ComputationDataHandle zero_point_from_min = - b->Sub(quant_min, b->Div(min, *scale)); - xla::ComputationDataHandle quant_max = + xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale)); + xla::XlaOp quant_max = XlaHelpers::FloatLiteral(b, data_type, quant_max_value); - xla::ComputationDataHandle nudged_zero_point = + xla::XlaOp nudged_zero_point = b->Select(b->Le(zero_point_from_min, quant_min), quant_min, b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, b->Round(zero_point_from_min))); @@ -68,22 +65,18 @@ void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); } -xla::ComputationDataHandle Quantize( - xla::ComputationBuilder* b, const xla::ComputationDataHandle& input, - const DataType data_type, - const xla::ComputationDataHandle& nudged_input_min, - const xla::ComputationDataHandle& nudged_input_max, - const xla::ComputationDataHandle& input_scale) { - xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); - xla::ComputationDataHandle inv_scale = b->Div(one, input_scale); - xla::ComputationDataHandle half = - XlaHelpers::FloatLiteral(b, data_type, 0.5f); - - xla::ComputationDataHandle clamped = - b->Clamp(nudged_input_min, input, nudged_input_max); - xla::ComputationDataHandle clamped_shifted = - b->Sub(clamped, nudged_input_min); - xla::ComputationDataHandle rounded = +xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, + const DataType data_type, + const xla::XlaOp& nudged_input_min, + const xla::XlaOp& nudged_input_max, + const xla::XlaOp& input_scale) { + xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); + xla::XlaOp inv_scale = b->Div(one, input_scale); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f); + + 
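+ // The input is clamped to the nudged range, shifted to start at zero,
+ // snapped to the nearest multiple of input_scale (floor(x * inv_scale + 0.5)
+ // rounds to the nearest step), and shifted back onto the nudged grid.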
xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max); + xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min); + xla::XlaOp rounded = b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); return b->Add(b->Mul(rounded, input_scale), nudged_input_min); } @@ -111,18 +104,18 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min = + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min = XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); - xla::ComputationDataHandle nudged_input_max = + xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::ComputationDataHandle input_scale = + xla::XlaOp input_scale = XlaHelpers::FloatLiteral(b, data_type, input_scale_); - xla::ComputationDataHandle output = Quantize( - b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min, + nudged_input_max, input_scale); ctx->SetOutput(0, output); } @@ -159,23 +152,22 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle gradient = ctx->Input(0); + xla::XlaOp gradient = ctx->Input(0); const TensorShape gradient_shape = ctx->InputShape(0); - xla::ComputationDataHandle input = ctx->Input(1); + xla::XlaOp input = ctx->Input(1); const DataType data_type = ctx->input_type(1); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min = + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min = XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); - xla::ComputationDataHandle nudged_input_max = + xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::ComputationDataHandle between_nudged_min_max = + xla::XlaOp between_nudged_min_max = b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::ComputationDataHandle zeroes = b->Broadcast( - XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes()); - xla::ComputationDataHandle output = - b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type), + gradient_shape.dim_sizes()); + xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -204,18 +196,18 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - xla::ComputationDataHandle input_min = ctx->Input(1); - xla::ComputationDataHandle input_max = ctx->Input(2); + xla::XlaOp input_min = ctx->Input(1); + xla::XlaOp input_max = ctx->Input(2); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min, nudged_input_max, input_scale; XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::ComputationDataHandle output = Quantize( - b, input, data_type, nudged_input_min, 
nudged_input_max, input_scale); + xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min, + nudged_input_max, input_scale); ctx->SetOutput(0, output); } @@ -243,47 +235,43 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle gradient = ctx->Input(0); + xla::XlaOp gradient = ctx->Input(0); const TensorShape gradient_shape = ctx->InputShape(0); - xla::ComputationDataHandle input = ctx->Input(1); + xla::XlaOp input = ctx->Input(1); const DataType data_type = ctx->input_type(1); const DataType accumulation_type = XlaHelpers::SumAccumulationType(data_type); - xla::ComputationDataHandle input_min = ctx->Input(2); - xla::ComputationDataHandle input_max = ctx->Input(3); + xla::XlaOp input_min = ctx->Input(2); + xla::XlaOp input_max = ctx->Input(3); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min, nudged_input_max, input_scale; XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::ComputationDataHandle between_nudged_min_max = + xla::XlaOp between_nudged_min_max = b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type); - xla::ComputationDataHandle zeroes = - b->Broadcast(zero, gradient_shape.dim_sizes()); - xla::ComputationDataHandle output0 = - b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zero = XlaHelpers::Zero(b, data_type); + xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); - xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); - xla::ComputationDataHandle select1 = b->Select(below_min, gradient, zeroes); - xla::ComputationDataHandle reduce1 = b->ReduceAll( + xla::XlaOp below_min = b->Lt(input, nudged_input_min); + xla::XlaOp select1 = b->Select(below_min, gradient, zeroes); + xla::XlaOp reduce1 = b->ReduceAll( XlaHelpers::ConvertElementType(b, select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::ComputationDataHandle output1 = - XlaHelpers::ConvertElementType(b, reduce1, data_type); + xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); - xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); - xla::ComputationDataHandle select2 = b->Select(above_max, gradient, zeroes); - xla::ComputationDataHandle reduce2 = b->ReduceAll( + xla::XlaOp above_max = b->Gt(input, nudged_input_max); + xla::XlaOp select2 = b->Select(above_max, gradient, zeroes); + xla::XlaOp reduce2 = b->ReduceAll( XlaHelpers::ConvertElementType(b, select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::ComputationDataHandle output2 = - XlaHelpers::ConvertElementType(b, reduce2, data_type); + xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index a4f3c1c3ad9a92..933924cad1c7ca 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -62,9 +62,8 @@ class GenericFftOp : 
public XlaOpKernel { } } - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle fft = - b->Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length); ctx->SetOutput(0, fft); } @@ -82,9 +81,11 @@ class FFTOp : public GenericFftOp { explicit FFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("FFT"), FFTOp<1>); -REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>); -REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>); +REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX64), FFTOp<1>); +REGISTER_XLA_OP(Name("FFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64), + FFTOp<2>); +REGISTER_XLA_OP(Name("FFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64), + FFTOp<3>); template class IFFTOp : public GenericFftOp { @@ -92,9 +93,12 @@ class IFFTOp : public GenericFftOp { explicit IFFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>); -REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>); -REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>); +REGISTER_XLA_OP(Name("IFFT").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<1>); +REGISTER_XLA_OP(Name("IFFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<2>); +REGISTER_XLA_OP(Name("IFFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<3>); template class RFFTOp : public GenericFftOp { diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index eaa13b8dfacce9..e4467a0fb138ed 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -48,7 +48,7 @@ class FillOp : public XlaOpKernel { 0, {dims_shape.num_elements()}, &dims_literal)); // Convert the dims literal into a vector that we can pass to - // ComputationBuilder. + // XlaBuilder. std::vector broadcast; broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { @@ -56,7 +56,7 @@ class FillOp : public XlaOpKernel { } // Look up the value input, reshaping to a scalar if it was a // 'legacy' scalar (secretly a vector). - xla::ComputationDataHandle data = ctx->Input(1); + xla::XlaOp data = ctx->Input(1); if (value_shape.dims() > 0) { CHECK_EQ(value_shape.dims(), 1); data = ctx->builder()->Reshape(data, {}); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 0b79cb0916ee8a..d13e25bcddae16 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -26,13 +26,11 @@ limitations under the License. namespace tensorflow { -Status XlaGather(const xla::ComputationDataHandle& input, - const TensorShape& input_shape, - const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 axis, - bool indices_are_nd, DataType dtype, DataType index_type, - xla::ComputationBuilder* builder, - xla::ComputationDataHandle* gather_output) { +Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, const TensorShape& indices_shape, + int64 axis, bool indices_are_nd, DataType dtype, + DataType index_type, xla::XlaBuilder* builder, + xla::XlaOp* gather_output) { // There is no deep reason why we need this precondition, but this is the only // combination that is used and tested today. 
CHECK(!indices_are_nd || axis == 0); @@ -153,7 +151,7 @@ class GatherOp : public XlaOpKernel { explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto input = context->Input(0); auto input_shape = context->InputShape(0); auto indices = context->Input(1); @@ -182,7 +180,7 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, errors::InvalidArgument("indices must be int32 or int64")); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( context, XlaGather(input, input_shape, indices, indices_shape, axis, /*indices_are_nd=*/false, input_type(0), index_type, @@ -220,10 +218,10 @@ class GatherNdOp : public XlaOpKernel { indices_shape.dim_size(indices_shape.dims() - 1), " vs. ", params_shape.dims())); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto params = context->Input(0); auto indices = context->Input(1); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices, indices_shape, /*axis=*/0, /*indices_are_nd=*/true, params_type, diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index f9376f0eabdc0f..d898e43b858bac 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" @@ -33,13 +33,11 @@ namespace tensorflow { // If `indices_are_nd` is true, the last dimension of `indices` are treated as // a multidimensional index values. Otherwise, `indices` is treated as a tensor // of scalar indices. -Status XlaGather(const xla::ComputationDataHandle& input, - const TensorShape& input_shape, - const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 axis, - bool indices_are_nd, DataType dtype, DataType index_type, - xla::ComputationBuilder* builder, - xla::ComputationDataHandle* gather_output); +Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, const TensorShape& indices_shape, + int64 axis, bool indices_are_nd, DataType dtype, + DataType index_type, xla::XlaBuilder* builder, + xla::XlaOp* gather_output); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index eefbe55c815d80..8b9b026643cf35 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -37,7 +37,7 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { // TODO(b/35949885): There is duplication here with the handling of the // while_op. Refactor the common code out/rework. 
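As a standalone illustration of the two indexing modes that the `XlaGather` helper above distinguishes with `indices_are_nd` (illustrative semantics only, not the helper's implementation): plain gather picks whole slices along `axis` from scalar indices, while the N-d form treats the last dimension of `indices` as a full coordinate into the leading dimensions of the input.

```c++
#include <iostream>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// indices_are_nd == false, axis == 0: each scalar index selects a whole row.
Matrix GatherRows(const Matrix& params, const std::vector<int>& indices) {
  Matrix out;
  for (int i : indices) out.push_back(params[i]);
  return out;
}

// indices_are_nd == true: each entry of `indices` is a (row, col) coordinate.
std::vector<int> GatherNd(const Matrix& params,
                          const std::vector<std::pair<int, int>>& indices) {
  std::vector<int> out;
  for (const auto& ij : indices) out.push_back(params[ij.first][ij.second]);
  return out;
}

int main() {
  Matrix params = {{1, 2}, {3, 4}, {5, 6}};

  // Plain gather along axis 0 with indices {2, 0} selects rows 2 and 0.
  Matrix rows = GatherRows(params, {2, 0});                    // {{5, 6}, {1, 2}}

  // gather_nd with coordinates (2, 1) and (0, 0) selects single elements.
  std::vector<int> vals = GatherNd(params, {{2, 1}, {0, 0}});  // {6, 1}

  std::cout << rows[0][0] << " " << rows[1][1] << " " << vals[0] << " "
            << vals[1] << "\n";  // prints: 5 2 6 1
  return 0;
}
```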
void XlaIfOp::Compile(XlaOpKernelContext* ctx) { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); OP_REQUIRES(ctx, cond_type_ == DT_BOOL, errors::InvalidArgument( @@ -48,7 +48,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; - std::vector inputs(input_types_.size()); + std::vector inputs(input_types_.size()); std::vector arguments(input_types_.size()); for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; @@ -175,19 +175,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { "Mismatch in resource of then and else branch for resource ", i)); } - xla::ComputationDataHandle outputs = + xla::XlaOp outputs = b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, b->Tuple(inputs), *else_result.computation); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { - xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i); + xla::XlaOp output_handle = b->GetTupleElement(outputs, i); if (VLOG_IS_ON(2)) { LOG(INFO) << "Setting output " << i; auto shape_or = b->GetShape(output_handle); if (shape_or.ok()) { LOG(INFO) << "Shape for output " << i << ": " - << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie()); + << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); } else { LOG(INFO) << "Shape unknown for output " << i; } diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 5eeda79a935e81..1568b33679963c 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -23,10 +23,9 @@ namespace { // Converts 'input' from RGB format to HSV format. // 'shape' is the shape of the red/green/blue tensors. -std::array RGBToHSV( - XlaOpKernelContext* ctx, xla::ComputationBuilder* b, - const std::array& rgb, DataType dtype, - const TensorShape& shape) { +std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, + const std::array& rgb, + DataType dtype, const TensorShape& shape) { auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); @@ -54,12 +53,12 @@ std::array RGBToHSV( } // Converts 'input' from HSV format to RGB format. 
-std::array HSVToRGB( - xla::ComputationBuilder* b, - const std::array& hsv, DataType dtype) { - xla::ComputationDataHandle hue = hsv[0]; - xla::ComputationDataHandle saturation = hsv[1]; - xla::ComputationDataHandle value = hsv[2]; +std::array HSVToRGB(xla::XlaBuilder* b, + const std::array& hsv, + DataType dtype) { + xla::XlaOp hue = hsv[0]; + xla::XlaOp saturation = hsv[1]; + xla::XlaOp value = hsv[2]; auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); @@ -95,16 +94,16 @@ class RGBToHSVOp : public XlaOpKernel { errors::FailedPrecondition("input must have 3 channels but input has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; @@ -133,15 +132,15 @@ class HSVToRGBOp : public XlaOpKernel { errors::FailedPrecondition("input must have 3 channels but input has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle hue = + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp hue = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle saturation = + xla::XlaOp saturation = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle value = + xla::XlaOp value = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); @@ -174,9 +173,9 @@ class AdjustContrastOpV2 : public XlaOpKernel { errors::InvalidArgument("contrast_factor must be scalar: ", factor_shape.DebugString())); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle factor = context->Input(1); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp factor = context->Input(1); DataType type = context->input_type(0); @@ -221,19 +220,19 @@ class AdjustSaturationOp : public XlaOpKernel { errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle scale = context->Input(1); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp scale = context->Input(1); DataType type = context->input_type(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + 
xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; @@ -271,19 +270,19 @@ class AdjustHueOp : public XlaOpKernel { errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle delta = context->Input(1); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp delta = context->Input(1); DataType type = context->input_type(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index f36b3f594826c2..79d3a6979cec4c 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -99,28 +99,35 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( return dims; } -xla::ComputationDataHandle MakeBilinearResizeKernel( - xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size, - int64 channels) { - // Form a 2D convolution kernel like: - // 1 2 3 2 1 - // 2 4 6 4 2 - // 1/9 * 3 6 9 6 3 - // 2 4 6 4 2 - // 1 2 3 2 1 - // by multiplying two 1D kernels of the form: - // 1/3 * [1 2 3 2 1] - auto make_1d_kernel = [](int64 n) { - std::vector kernel(n * 2 - 1); - for (int64 i = 0; i < n; ++i) { - float v = (i + 1.0f) / n; - kernel[i] = v; - kernel[n * 2 - 2 - i] = v; - } - return kernel; - }; +// Form a 2D convolution kernel like: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 +// by multiplying two 1D kernels of the form: +// 1/3 * [1 2 3 2 1] +// If the 2D kernel would be very large, the 1D kernel can be applied once in +// each dimension due to the symmetry of the kernel along all axes to reduce the +// computational intensity. +std::vector Make1DKernel(int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = (i + 1.0f) / n; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; +} + +// Kernels with more than 16 spatial elements are considered intense and the +// kernel should be applied to each dimension independently. +const int64 kMax2DKernelSize = 16; - xla::ComputationDataHandle channels_iota; +xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels) { + xla::XlaOp channels_iota; // DT_INT32 Iota will always return status::OK().
TF_CHECK_OK( XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); @@ -133,16 +140,43 @@ xla::ComputationDataHandle MakeBilinearResizeKernel( xla::PrimitiveType::F32); return builder->Mul( builder->Mul(diag, - builder->ConstantR1(make_1d_kernel(kernel_size[1])), + builder->ConstantR1(Make1DKernel(kernel_size[1])), /*broadcast_dimensions=*/{1}), - builder->ConstantR1(make_1d_kernel(kernel_size[0])), + builder->ConstantR1(Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } -xla::ComputationDataHandle ResizeUsingDilationAndConvolution( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input, - const int num_spatial_dims, std::vector in_size, - std::vector out_size, const int64 channels) { +xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels, int64 dim) { + xla::XlaOp channels_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK( + XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + + auto diag = builder->ConvertElementType( + builder->Eq(builder->Broadcast( + channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); + if (dim == 1) { + return builder->Mul( + diag, builder->ConstantR1(Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}); + } + return builder->Mul(diag, + builder->ConstantR1(Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); +} + +xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, + const xla::XlaOp& input, + const int num_spatial_dims, + std::vector in_size, + std::vector out_size, + const int64 channels) { // Picture for a 1x3 to 1x4 resize: // stride = 2, kernel size = 3 // Input: @@ -163,20 +197,42 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution( dimension_numbers.add_output_spatial_dimensions(1 + i); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size); - xla::ComputationDataHandle kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - xla::ComputationDataHandle output = builder->ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp output; + // Split convolutions into independent dimensions if they would be a very + // large kernel.
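A standalone sketch of why the per-dimension split above is valid (plain `float` arrays, not XLA ops): the 2D bilinear kernel is the outer product of the two 1D kernels produced by `Make1DKernel`, so convolving with the 2D kernel is equivalent to convolving with the row kernel and then the column kernel, which is much cheaper when `2 * kernel_size - 1` is large.

```c++
#include <cstdio>
#include <vector>

// Mirrors Make1DKernel in the patch: a triangular kernel of length 2*n - 1,
// e.g. n = 3 -> {1/3, 2/3, 1, 2/3, 1/3}.
std::vector<float> Make1DKernel(int n) {
  std::vector<float> kernel(n * 2 - 1);
  for (int i = 0; i < n; ++i) {
    float v = (i + 1.0f) / n;
    kernel[i] = v;
    kernel[n * 2 - 2 - i] = v;
  }
  return kernel;
}

int main() {
  const int n = 3;
  std::vector<float> k = Make1DKernel(n);
  // The fused 2D kernel is the outer product k[i] * k[j]; printing it times
  // n*n reproduces the integer matrix from the comment in
  // image_resize_ops.cc (1 2 3 2 1 / 2 4 6 4 2 / 3 6 9 6 3 / ... , all * 1/9).
  for (size_t i = 0; i < k.size(); ++i) {
    for (size_t j = 0; j < k.size(); ++j) {
      std::printf("%2.0f ", k[i] * k[j] * n * n);
    }
    std::printf("\n");
  }
  return 0;
}
```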
+ if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + output = builder->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + output = builder->ConvGeneralDilated( + input, kernel0, {dims.stride[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.kernel_size[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + output = builder->ConvGeneralDilated( + output, kernel1, {1, dims.stride[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.kernel_size[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. @@ -189,10 +245,12 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution( return output; } -xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad, - const int num_spatial_dims, std::vector in_size, - std::vector grad_size, const int64 channels) { +xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, + const xla::XlaOp& grad, + const int num_spatial_dims, + std::vector in_size, + std::vector grad_size, + const int64 channels) { ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, grad_size); @@ -210,26 +268,63 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( } dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); - xla::ComputationDataHandle kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp output; + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = + builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. 
- for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + output = builder->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + if (in_size[0] == 1 && grad_size[0] > 1) { + kernel0 = + builder->Add(kernel0, builder->ConstantR1(grad_size[0], 0), + /*broadcast_dimensions=*/{0}); + } + if (in_size[1] == 1 && grad_size[1] > 1) { + kernel1 = + builder->Add(kernel1, builder->ConstantR1(grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - } - xla::ComputationDataHandle output = builder->ConvGeneralDilated( - grad, kernel, /*window_strides=*/dims.kernel_size, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = builder->ConvGeneralDilated( + grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.stride[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + output = builder->ConvGeneralDilated( + output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.stride[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. // Opposite of the slice performed by the forward op. @@ -258,7 +353,7 @@ class ResizeBilinearOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() == 4, @@ -283,7 +378,7 @@ class ResizeBilinearOp : public XlaOpKernel { const int num_spatial_dims = 2; - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. @@ -318,7 +413,7 @@ class ResizeBilinearOp : public XlaOpKernel { // from image of size axb -> cxd is same as resizing axb -> exf -> cxd. // // This makes the convolutions kernels smaller and the operation faster.
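To make the "resize in stages" remark above concrete, here is a standalone sketch. It assumes the kernel-size rule implied by the "1x3 to 1x4 resize: stride = 2, kernel size = 3" picture earlier in this file, namely kernel_size = (out - 1) / gcd(in - 1, out - 1); the `KernelSize` helper is hypothetical and only illustrates why staged resizes keep the convolution kernels small.

```c++
#include <cstdio>
#include <numeric>  // std::gcd (C++17)

// Hypothetical helper mirroring the kernel-size derivation implied by the
// 1x3 -> 1x4 example: kernel_size = (out - 1) / gcd(in - 1, out - 1).
long KernelSize(long in, long out) {
  long g = std::gcd(in - 1, out - 1);
  return (out - 1) / g;
}

int main() {
  // A single direct resize from 3 to 513 pixels needs kernel size 256...
  std::printf("3 -> 513 : kernel size %ld\n", KernelSize(3, 513));
  // ...while steps that roughly double the size (3 -> 5 -> 9 -> ... -> 513)
  // only ever need kernel size 2, which is the kind of factorization the
  // comment above describes (a x b -> e x f -> c x d).
  for (long in = 3; in < 513; in = 2 * in - 1) {
    std::printf("%ld -> %ld : kernel size %ld\n", in, 2 * in - 1,
                KernelSize(in, 2 * in - 1));
  }
  return 0;
}
```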
- xla::ComputationDataHandle output = input; + xla::XlaOp output = input; while (in_size != out_size) { if (in_size[0] != 1 && in_size[1] != 1) { std::vector k = { @@ -369,7 +464,7 @@ class ResizeBilinearGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape input_shape = ctx->InputShape(1); OP_REQUIRES(ctx, input_shape.dims() == 4, @@ -406,9 +501,9 @@ class ResizeBilinearGradOp : public XlaOpKernel { const int num_spatial_dims = 2; - xla::ComputationDataHandle grad = ctx->Input(0); + xla::XlaOp grad = ctx->Input(0); - xla::ComputationDataHandle output = grad; + xla::XlaOp output = grad; while (in_size != grad_size) { if (in_size[0] != 1 && in_size[1] != 1) { std::vector k = { diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 7bf4b435f526af..36eb4c75454ed8 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -61,10 +61,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { DataType index_type = output_type(0); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); - xla::ComputationDataHandle output; + xla::XlaOp output; if (is_min_) { OP_REQUIRES_OK(ctx, XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0), diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index b1f3c3c298ce0c..2c2d88486fda99 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -71,10 +71,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel { OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), errors::InvalidArgument( "ArgMax implementation requires a CustomCall on CPU")); - xla::ComputationBuilder& b = *ctx->builder(); + xla::XlaBuilder& b = *ctx->builder(); // XLA passes to the function, so it is not included here. - std::vector args; + std::vector args; args.push_back(ctx->Input(0)); args.push_back(b.ConstantLiteral( *xla::Literal::CreateR1(input_shape.dim_sizes()))); @@ -91,7 +91,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // Tell XLA to call the custom code, defined in // index_ops_kernel_argmax_float_1d.cc. - xla::ComputationDataHandle output; + xla::XlaOp output; switch (input_shape.dims()) { case 1: output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index c177f08d9c4687..1decf7d72d72bb 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -33,7 +33,7 @@ class L2LossOp : public XlaOpKernel { std::iota(dims.begin(), dims.end(), 0); DataType dtype = ctx->input_type(0); - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); // output = sum(t ** 2) / 2 const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc new file mode 100644 index 00000000000000..0388b4c830702e --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64 +// input. + +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +constexpr std::array kListDiffTypes = {DT_INT32, DT_INT64}; + +// ListDiffOp is an XLA kernel that supports constant-only x and y input. +class ListDiffOp : public XlaOpKernel { + public: + explicit ListDiffOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(0)), + errors::InvalidArgument("ListDiff expects x as a vector, not ", + context->InputShape(0).DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(1)), + errors::InvalidArgument("ListDiff expects y as a vector, not ", + context->InputShape(1).DebugString())); + + DataType val_type = context->expected_output_dtype(0); + DataType idx_type = context->expected_output_dtype(1); + + Status status; + switch (val_type) { + case DT_INT32: + status = ListDiffWithIndexType(context, idx_type); + break; + case DT_INT64: + status = ListDiffWithIndexType(context, idx_type); + break; + default: + // This should never happen since we restrict this kernel to only match + // inputs with supported Tensor datatype. 
+ status = errors::InvalidArgument("ListDiff expects x and y as either ", + "int32 or int64, not ", + DataTypeString(val_type)); + } + OP_REQUIRES_OK(context, status); + } + + private: + template + Status ListDiff(XlaOpKernelContext* context) { + std::vector x_input, y_input; + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input)); + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input)); + + std::unordered_set y_input_set; + y_input_set.reserve(y_input.size()); + for (auto y : y_input) { + y_input_set.insert(y); + } + + std::vector val_output; + std::vector idx_output; + auto x_size = x_input.size(); + for (Tidx i = 0; i < x_size; ++i) { + if (y_input_set.count(x_input[i]) > 0) { + continue; + } + val_output.push_back(x_input[i]); + idx_output.push_back(i); + } + + context->SetOutput(0, context->builder()->ConstantR1(val_output)); + context->SetOutput(1, context->builder()->ConstantR1(idx_output)); + return Status::OK(); + } + + template + Status ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) { + switch (idx_type) { + case DT_INT32: + return ListDiff(context); + case DT_INT64: + return ListDiff(context); + default: + return errors::InvalidArgument( + "ListDiff expects idx_out as either int32 or int64, not ", + DataTypeString(idx_type)); + } + } +}; + +REGISTER_XLA_OP(Name("ListDiff") + .TypeConstraint("T", kListDiffTypes) + .CompileTimeConstInput("x") + .CompileTimeConstInput("y"), + ListDiffOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 1cfee3070f384a..39fbf98a627491 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -38,8 +38,8 @@ class LRNOp : public XlaOpKernel { OP_REQUIRES(ctx, in_shape.dims() == 4, errors::InvalidArgument("in must be 4-dimensional")); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); // sqr_sum[a, b, c, d] = // sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) @@ -111,10 +111,10 @@ class LRNGradOp : public XlaOpKernel { "input_grads, input_image, and out_image should have the same " "shape")); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle in_grads = ctx->Input(0); - xla::ComputationDataHandle in_image = ctx->Input(1); - xla::ComputationDataHandle out_image = ctx->Input(2); + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp in_grads = ctx->Input(0); + xla::XlaOp in_image = ctx->Input(1); + xla::XlaOp out_image = ctx->Input(2); // This code is ported from tensorflow/core/kernels/lrn_op.cc. 
In Python // pseudo-code, the Eigen code does this for each spatial position: @@ -166,7 +166,7 @@ class LRNGradOp : public XlaOpKernel { auto dy_reduced = XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); - xla::ComputationDataHandle gradients = builder->Add( + xla::XlaOp gradients = builder->Add( builder->Mul(in_image, dy_reduced), builder->Mul(in_grads, builder->Pow(norm, builder->ConstantR0(-beta_)))); diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 886baf8115243a..6949b296f4b9af 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -66,8 +66,8 @@ class MatMulOp : public XlaOpKernel { a_shape.DebugString(), ", In[1]: ", b_shape.DebugString())); - xla::ComputationDataHandle a = ctx->Input(0); - xla::ComputationDataHandle b = ctx->Input(1); + xla::XlaOp a = ctx->Input(0); + xla::XlaOp b = ctx->Input(1); if (is_sparse_) { if (a_type_ == DT_BFLOAT16) { a = ctx->builder()->ConvertElementType(a, xla::F32); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index faa415a97b053b..fbd5dc0fdad448 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -44,10 +44,10 @@ class MatrixBandPartOp : public XlaOpKernel { errors::InvalidArgument("num_upper must be scalar, got shape ", num_upper_in_shape.DebugString())); - xla::ComputationBuilder* builder = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle num_lower = context->Input(1); - xla::ComputationDataHandle num_upper = context->Input(2); + xla::XlaBuilder* builder = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp num_lower = context->Input(1); + xla::XlaOp num_upper = context->Input(2); DataType input_type = context->input_type(0); DataType index_type = context->input_type(1); @@ -58,10 +58,10 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. - xla::ComputationDataHandle iota_m; + xla::XlaOp iota_m; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); - xla::ComputationDataHandle iota_n; + xla::XlaOp iota_n; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index b2940bdcff75a0..db53f6fef8d6bf 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -54,16 +54,16 @@ class MatrixSetDiagOp : public XlaOpKernel { input_shape.DebugString(), " and diagonal shape: ", diag_shape.DebugString())); - xla::ComputationBuilder* builder = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle diag = context->Input(1); + xla::XlaBuilder* builder = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp diag = context->Input(1); auto zero = XlaHelpers::Zero(builder, context->input_type(0)); // Create an indicator tensor that is true only on the diagonal. 
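The `offset` built from the two iotas in the MatrixBandPartOp hunk above is just `col - row`. For reference, the band selection it feeds can be sketched standalone (based on the documented semantics of MatrixBandPart, which are not fully visible in this hunk and are therefore an assumption here; a negative bound keeps that whole triangle):

```c++
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Keep element (i, j) if it lies within num_lower sub-diagonals and num_upper
// super-diagonals of the main diagonal. offset = j - i is the quantity the
// XLA kernel derives from broadcast(iota_n) - iota_m.
Matrix MatrixBandPart(const Matrix& in, int num_lower, int num_upper) {
  Matrix out = in;
  for (int i = 0; i < static_cast<int>(in.size()); ++i) {
    for (int j = 0; j < static_cast<int>(in[i].size()); ++j) {
      int offset = j - i;
      bool in_lower = num_lower < 0 || -offset <= num_lower;
      bool in_upper = num_upper < 0 || offset <= num_upper;
      if (!(in_lower && in_upper)) out[i][j] = 0.0f;
    }
  }
  return out;
}

int main() {
  Matrix m = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  Matrix banded = MatrixBandPart(m, /*num_lower=*/1, /*num_upper=*/0);
  for (const auto& row : banded) {
    for (float v : row) std::printf("%g ", v);
    std::printf("\n");  // prints: 1 0 0 / 4 5 0 / 0 8 9
  }
  return 0;
}
```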
- xla::ComputationDataHandle iota_m; + xla::XlaOp iota_m; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); - xla::ComputationDataHandle iota_n; + xla::XlaOp iota_n; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); auto indicator = builder->Eq(iota_m, builder->Broadcast(iota_n, {m}), diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 05a36a031ad73b..7e9de3ef9b245c 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -25,10 +25,11 @@ class MirrorPadOp : public XlaOpKernel { public: explicit MirrorPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {} - xla::StatusOr DoMirrorPad( - const xla::ComputationDataHandle& t, const xla::Shape& original_shape, - const xla::Literal& pad_literal, xla::ComputationBuilder* b) { - xla::ComputationDataHandle accum = t; + xla::StatusOr DoMirrorPad(const xla::XlaOp& t, + const xla::Shape& original_shape, + const xla::Literal& pad_literal, + xla::XlaBuilder* b) { + xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { auto t_rev = b->Rev(accum, {dimno}); @@ -76,12 +77,12 @@ class MirrorPadOp : public XlaOpKernel { OP_REQUIRES_OK( ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto in0 = ctx->Input(0); - xla::StatusOr> in0_shape = b->GetShape(in0); + xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); - xla::StatusOr accum_status = - DoMirrorPad(in0, *in0_shape.ValueOrDie(), pad_literal, b); + xla::StatusOr accum_status = + DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b); OP_REQUIRES_OK(ctx, accum_status.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index 8c8a9bbe787f32..65ab9da8d7ca05 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -24,8 +24,7 @@ namespace tensorflow { REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); // We register ControlTrigger as a no-op. This is correct since nodes seen -// by the XLA compiler are never dead. This may need rethinking when we add -// support for conditionals to XLA. -REGISTER_XLA_OP(Name("ControlTrigger"), NoOp); +// by the XLA compiler are never dead. 
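`DoMirrorPad` above walks the dimensions, reverses the accumulated tensor along each one, and concatenates slices of the reversed copy on both sides. A standalone 1-D sketch of that idea for REFLECT-style padding (illustrative semantics only, not the kernel's exact slicing; `MirrorPad1D` is a hypothetical helper):

```c++
#include <iostream>
#include <vector>

// REFLECT-style mirror padding of a 1-D sequence, built the way the XLA
// kernel builds it: reverse the input, then concatenate slices of the
// reversed copy before and after the original (the edge value is not
// repeated). Requires pad_lo and pad_hi to be smaller than in.size().
std::vector<int> MirrorPad1D(const std::vector<int>& in, int pad_lo,
                             int pad_hi) {
  std::vector<int> rev(in.rbegin(), in.rend());
  std::vector<int> out;
  // Last pad_lo elements of `rev`, excluding the edge element itself.
  out.insert(out.end(), rev.end() - pad_lo - 1, rev.end() - 1);
  out.insert(out.end(), in.begin(), in.end());
  // First pad_hi elements of `rev`, excluding the edge element itself.
  out.insert(out.end(), rev.begin() + 1, rev.begin() + 1 + pad_hi);
  return out;
}

int main() {
  std::vector<int> padded =
      MirrorPad1D({1, 2, 3, 4}, /*pad_lo=*/2, /*pad_hi=*/1);
  for (int v : padded) std::cout << v << " ";  // prints: 3 2 1 2 3 4 3
  std::cout << "\n";
  return 0;
}
```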
+REGISTER_XLA_OP(Name("ControlTrigger").CompilationOnly(), NoOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 9f7c9913802d31..cac2eea96eeed7 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -62,7 +62,7 @@ class OneHotOp : public XlaOpKernel { ctx, depth >= 0, errors::InvalidArgument("depth must be non-negative, got: ", depth)); - xla::ComputationDataHandle one_hot; + xla::XlaOp one_hot; OP_REQUIRES_OK( ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0), indices_shape, ctx->Input(0), ctx->Input(2), diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index a4318e29d2532f..aecaabb6dcf46b 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -43,7 +43,7 @@ class PackOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - std::vector values; + std::vector values; std::vector shapes; OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); const int num = values.size(); @@ -69,7 +69,7 @@ class PackOp : public XlaOpKernel { -expanded_num_dims, ", ", expanded_num_dims, ")")); - std::vector reshaped_inputs(num); + std::vector reshaped_inputs(num); TensorShape child_shape(shapes[0]); child_shape.InsertDim(axis, 1); diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 791351637aee61..7c95475e7b1f02 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -70,7 +70,7 @@ class PadOp : public XlaOpKernel { } // PadV2 added a "constant_values" input that indicates the pad value. - xla::ComputationDataHandle constant_values; + xla::XlaOp constant_values; if (ctx->num_inputs() == 3) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), errors::InvalidArgument("constant_values must be a scalar.")); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 5f635dd1bc6122..f8e7b48a0fd948 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -66,15 +66,15 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } // Method that builds an initial value to use in reductions. - virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) = 0; + virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; // The reduction operation to apply to each window. - virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx) = 0; + virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; // A post-processing operation to apply on the outputs of the ReduceWindow. 
- virtual xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) = 0; + virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) = 0; void Compile(XlaOpKernelContext* ctx) override { std::vector ksize = ksize_; @@ -110,7 +110,7 @@ class PoolingOp : public XlaOpKernel { " operator must have ", num_dims(), " dimensions")); - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); auto input = XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); auto reduce = ctx->builder()->ReduceWindow( @@ -135,17 +135,17 @@ class MaxPoolOp : public PoolingOp { : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ctx->input_type(0)) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + xla::XlaOp InitValue(xla::XlaBuilder* b) override { return XlaHelpers::MinValue(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { return ctx->GetOrCreateMax(reduction_type_); } - xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) override { + xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) override { return output; } }; @@ -176,9 +176,9 @@ REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); // Common computation shared between AvgPool and AvgPoolGrad. Divide each // element of an image by the count of elements that contributed to that // element during pooling. 
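A standalone sketch of the divide-by-count step described above (not the actual lowering; `AvgPool1DSame` is a hypothetical 1-D helper): with valid padding every output sees a full window, so dividing by the window size suffices, but with SAME padding the border outputs receive contributions from fewer elements and need a per-position count.

```c++
#include <algorithm>
#include <cstdio>
#include <vector>

// 1-D average pooling with SAME padding, stride 1: each output is the window
// sum divided by the number of in-bounds elements in that window. The XLA
// kernel gets the same effect from a sum ReduceWindow followed by an
// element-wise divide by these counts.
std::vector<float> AvgPool1DSame(const std::vector<float>& in, int ksize) {
  const int n = static_cast<int>(in.size());
  std::vector<float> out(n, 0.0f);
  for (int i = 0; i < n; ++i) {
    int lo = std::max(0, i - ksize / 2);
    int hi = std::min(n, i + ksize / 2 + 1);
    float sum = 0.0f;
    for (int j = lo; j < hi; ++j) sum += in[j];
    out[i] = sum / (hi - lo);  // divide by the *actual* element count
  }
  return out;
}

int main() {
  std::vector<float> out = AvgPool1DSame({1, 2, 3, 4}, /*ksize=*/3);
  for (float v : out) std::printf("%g ", v);  // prints: 1.5 2 3 3.5
  std::printf("\n");
  return 0;
}
```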
-static xla::ComputationDataHandle AvgPoolDivideByCount( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape, xla::Padding padding, +static xla::XlaOp AvgPoolDivideByCount( + XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape, xla::Padding padding, const std::vector& ksize, const std::vector& stride, int num_spatial_dims, TensorFormat data_format) { if (padding == xla::Padding::kValid) { @@ -234,17 +234,17 @@ class AvgPoolOp : public PoolingOp { /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + xla::XlaOp InitValue(xla::XlaBuilder* b) override { return XlaHelpers::Zero(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { return ctx->GetOrCreateAdd(reduction_type_); } - xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) override { + xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) override { return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, ksize_, stride_, num_spatial_dims_, data_format_); @@ -344,11 +344,10 @@ class MaxPoolGradOp : public XlaOpKernel { xla::PrimitiveType element_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); - xla::ComputationDataHandle init_value = - XlaHelpers::Zero(ctx->builder(), input_type(2)); + xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2)); auto select = CreateScalarGeComputation(element_type, ctx->builder()); auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); - xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter( + xla::XlaOp gradients = ctx->builder()->SelectAndScatter( input, select, ksize_, stride_, xla_padding, out_backprop, init_value, scatter); @@ -462,7 +461,7 @@ class AvgPoolGradOp : public XlaOpKernel { // The input gradients are computed by a convolution of the output gradients // and the filter, with some appropriate padding. See the comment at the top // of conv_grad_ops.h for details. - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); auto dtype = input_type(1); xla::Padding xla_padding = diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 4171e076ff6d9d..661cd5923e1023 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -35,7 +35,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); // Comments taken from semantics description at @@ -46,8 +46,8 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // m = max(abs(input_min), abs(input_max)) if range_given is true, // m = max(abs(min_elem(input)), // abs(max_elem(input))) otherwise. 
- xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input_min, input_max; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input_min, input_max; if (range_given_) { double input_min_value, input_max_value; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value)); @@ -55,14 +55,14 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value); input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value); } else { - const xla::Computation* fmax = ctx->GetOrCreateMax(data_type); - const xla::Computation* fmin = ctx->GetOrCreateMin(data_type); + const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); + const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); input_min = b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); input_max = b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); } - xla::ComputationDataHandle m = b->Max(b->Abs(input_min), b->Abs(input_max)); + xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max)); // Next, we choose our fixed-point quantization buckets, [min_fixed, // max_fixed]. If signed_input is true, this is @@ -85,7 +85,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // From this we compute our scaling factor, s: // // s = (max_fixed - min_fixed) / (2 * m). - xla::ComputationDataHandle s = + xla::XlaOp s = b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); @@ -93,7 +93,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // e is transformed into e': // // e' = (e * s).round_to_nearest() / s. - xla::ComputationDataHandle result = b->Div(b->Round(b->Mul(input, s)), s); + xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index c0994c434bca51..39149d56adb244 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,7 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. 
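A standalone numeric walk-through of the quantize-and-dequantize formulas in the comments above, i.e. the scale s = (max_fixed - min_fixed) / (2 * m) and the round trip e' = round(e * s) / s. The signed 8-bit bucket bounds (-127, 127) and the value of m used here are assumptions for illustration, not values taken from this hunk.

```c++
#include <cmath>
#include <cstdio>

int main() {
  // Assumed signed 8-bit fixed-point buckets:
  // min_fixed = -(2^(num_bits - 1) - 1), max_fixed = 2^(num_bits - 1) - 1.
  const int num_bits = 8;
  const double min_fixed = -(std::pow(2, num_bits - 1) - 1);  // -127
  const double max_fixed = std::pow(2, num_bits - 1) - 1;     //  127
  // m is the largest input magnitude, e.g. max(|input_min|, |input_max|);
  // assumed to be 3.0 for this example.
  const double m = 3.0;
  const double s = (max_fixed - min_fixed) / (2.0 * m);  // 254 / 6 = 42.333...

  // Round-trip a sample element e through the fixed-point grid.
  const double e = 1.234;
  const double e_prime = std::round(e * s) / s;
  // e * s = 52.239..., rounds to 52, so e' = 52 / 42.333... = 1.228346...
  std::printf("s = %.4f, e = %.4f -> e' = %.6f\n", s, e, e_prime);
  return 0;
}
```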
+#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -41,9 +42,9 @@ class RandomUniformOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle result = b->RngUniform( - XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -100,11 +101,11 @@ class RandomStandardNormalOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Normal distribution with a mean of 0 and a standard deviation of 1: - xla::ComputationDataHandle result = b->RngNormal( - XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -127,22 +128,16 @@ class TruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::Shape xla_element_shape = - xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); - xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); - xla::ComputationDataHandle candidate = - b->RngNormal(mean, stddev, xla_shape); + xla::XlaBuilder* b = ctx->builder(); - auto two_sd = [dtype](bool negate, xla::ComputationBuilder* b) { + auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { return XlaHelpers::FloatLiteral(b, dtype, negate ? 
-2.0 : 2.0); }; - auto out_of_range_mask = [two_sd](xla::ComputationDataHandle candidate, - xla::ComputationBuilder* b) { - xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b)); - xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b)); + auto out_of_range_mask = [two_sd](xla::XlaOp candidate, + xla::XlaBuilder* b) { + xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b)); + xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b)); return b->Or(too_large, too_small); }; @@ -152,37 +147,38 @@ class TruncatedNormalOp : public XlaOpKernel { // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd // candidate = select(out_of_range_mask, rng_normal(), candidate) // } - std::unique_ptr test_builder = - b->CreateSubBuilder("truncated_normal_test"); - { - auto* b = test_builder.get(); - xla::ComputationDataHandle candidate = - b->Parameter(0, xla_shape, "candidate"); - xla::ComputationDataHandle oor_mask = out_of_range_mask(candidate, b); - OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status()); - } - - std::unique_ptr body_builder = - b->CreateSubBuilder("truncated_normal_body"); - { - auto* b = body_builder.get(); - xla::ComputationDataHandle candidate = - b->Parameter(0, xla_shape, "candidate"); - xla::ComputationDataHandle to_resample = out_of_range_mask(candidate, b); - xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); - xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); - b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); - } - - xla::StatusOr test_computation = test_builder->Build(); - OP_REQUIRES_OK(ctx, test_computation.status()); - xla::StatusOr body_computation = body_builder->Build(); - OP_REQUIRES_OK(ctx, body_computation.status()); - xla::ComputationDataHandle result = - b->While(test_computation.ValueOrDie(), body_computation.ValueOrDie(), - candidate); - - ctx->SetOutput(0, result); + std::vector initial_values = { + // The current candidate. + b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), + // The to_resample mask, where 'true' identifies a location in the + // current candidate that is out of range and must be regenerated. + b->Broadcast(b->ConstantR0(true), shape.dim_sizes()), + // Is any element in the mask true? + b->ConstantR0(true)}; + auto condition = [&](gtl::ArraySlice values, + xla::XlaBuilder* b) -> xla::StatusOr { + // Continue while any element in the mask is true. + return values[2]; + }; + auto body = + [&](gtl::ArraySlice values, + xla::XlaBuilder* b) -> xla::StatusOr> { + xla::XlaOp candidate = values[0]; + xla::XlaOp to_resample = values[1]; + xla::XlaOp mean = XlaHelpers::Zero(b, dtype); + xla::XlaOp stddev = XlaHelpers::One(b, dtype); + candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), + candidate); + // Compute a new to_resample mask, and determine whether any value is + // still out of range. 
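The loop being assembled here keeps resampling only the out-of-range entries until none are left. A standalone host-side sketch of the same rejection scheme (mean 0, stddev 1, truncated to two standard deviations; a plain scalar loop instead of the per-element `to_resample` mask the XlaWhileLoop carries):

```c++
#include <cstdio>
#include <random>
#include <vector>

int main() {
  std::mt19937 rng(42);
  std::normal_distribution<double> normal(0.0, 1.0);

  // Rejection sampling for a truncated normal: any candidate outside
  // [-2, 2] is regenerated until it falls inside the range.
  std::vector<double> samples(8);
  for (double& x : samples) {
    do {
      x = normal(rng);
    } while (x < -2.0 || x > 2.0);
  }

  for (double x : samples) std::printf("%.3f ", x);
  std::printf("\n");  // every printed value lies within [-2, 2]
  return 0;
}
```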
+ to_resample = out_of_range_mask(candidate, b); + TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); + return std::vector{candidate, to_resample, done}; + }; + auto result = + XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()[0]); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index cb144bea9e429b..08894489ac77bb 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -65,7 +64,7 @@ class ReduceWindowOp : public XlaOpKernel { "rank (", padding_high_.size(), " vs. ", rank, ")")); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -95,15 +94,15 @@ class ReduceWindowOp : public XlaOpKernel { xla::ShapeUtil::HumanString(reducer.xla_output_shape))); // Wraps the reducer in a computation that unpacks the output tuple. - xla::Computation wrapper; + xla::XlaComputation wrapper; { - std::unique_ptr cb = + std::unique_ptr cb = builder->CreateSubBuilder("wrapper"); auto x = cb->Parameter(0, scalar_shape, "x"); auto y = cb->Parameter(1, scalar_shape, "y"); auto outputs = cb->Call(*reducer.computation, {x, y}); cb->GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); + xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(context, result.status()); wrapper = std::move(result.ValueOrDie()); } @@ -113,7 +112,7 @@ class ReduceWindowOp : public XlaOpKernel { padding[i] = {padding_low_[i], padding_high_[i]}; } - xla::ComputationDataHandle output = builder->ReduceWindowWithGeneralPadding( + xla::XlaOp output = builder->ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), wrapper, window_dimensions_, window_strides_, padding); context->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 812d258cd1677e..0f425637795e96 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -30,13 +30,11 @@ class SumOp : public XlaReductionOp { explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::Zero(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Add(scalar_lhs, scalar_rhs); } }; @@ -49,14 +47,12 @@ class ProdOp : public XlaReductionOp { : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* 
builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::One(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Mul(scalar_lhs, scalar_rhs); } }; @@ -69,14 +65,12 @@ class MinOp : public XlaReductionOp { explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::MaxValue(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Min(scalar_lhs, scalar_rhs); } }; @@ -88,14 +82,12 @@ class MaxOp : public XlaReductionOp { explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::MinValue(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Max(scalar_lhs, scalar_rhs); } }; @@ -108,20 +100,17 @@ class MeanOp : public XlaReductionOp { : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::Zero(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Add(scalar_lhs, scalar_rhs); } - xla::ComputationDataHandle BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced) override { + xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); return builder->Div(reduce_output, divisor); @@ -136,14 +125,12 @@ class AllOp : public XlaReductionOp { explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return builder->ConstantR0(true); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { 
builder->And(scalar_lhs, scalar_rhs); } }; @@ -155,14 +142,12 @@ class AnyOp : public XlaReductionOp { explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return builder->ConstantR0(false); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index f3181f0dadc2d3..2ecfb854a1c862 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -28,35 +28,33 @@ namespace tensorflow { // to override: description is a textual description of the mapped // function; InitialValue constructs the base case for the reduction; // BuildReducer adds the implementation of the reduction lambda to a -// xla::ComputationBuilder and BuildFinalizer adds the +// xla::XlaBuilder and BuildFinalizer adds the // implementation of the finalizer lambda (if there is one) to a -// xla::ComputationBuilder. +// xla::XlaBuilder. class XlaReductionOp : public XlaOpKernel { public: XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type); ~XlaReductionOp() override {} // Return the base case for the reduction. - virtual xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) = 0; + virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; // Implement the (scalar,scalar)->scalar lambda that should be // applied to each pair of elements to be reduced. The desired // computation should be added to 'builder' and // '(scalar_lhs,scalar_rhs)' are the function's inputs. - virtual void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) = 0; + virtual void BuildReducer(xla::XlaBuilder* builder, + const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) = 0; // Applies a transformation to the output of the reduction. The desired // computation should be added to 'builder'. Argument 'reduce_output' is the // output of the reduction. 'num_elements_reduced' is the number of elements // that contributed to the reduction. Returns the transformed reduction // output, Defaults to returning 'reduce_output' unchanged. 
- virtual xla::ComputationDataHandle BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced); + virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced); void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 64fe765ae9a945..4fd5bfd03999a7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -35,10 +35,9 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, // Unless BuildFinalizer is overridden the reduction has no // finalizer. -xla::ComputationDataHandle XlaReductionOp::BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced) { +xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced) { return reduce_output; } @@ -96,9 +95,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { string desc = ctx->op_kernel().name(); - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::ComputationBuilder r(b->client(), strings::StrCat(desc, "-reduction")); + xla::XlaBuilder r(strings::StrCat(desc, "-reduction")); xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); @@ -110,7 +109,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); - xla::Computation reduction_computation = r.Build().ConsumeValueOrDie(); + xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index 12a35529992e61..ba7d484d53d725 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -32,7 +32,7 @@ class ReluOp : public XlaOpKernel { explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); } @@ -43,7 +43,7 @@ class Relu6Op : public XlaOpKernel { explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Clamp the scalar input between 0 and 6. 
void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six)); @@ -56,7 +56,7 @@ class ReluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); @@ -71,7 +71,7 @@ class Relu6GradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index c283e3b02c2676..a711278638444b 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -45,7 +45,7 @@ class RetvalOp : public XlaOpKernel { // compilation. OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); auto is_constant = ctx->builder()->IsConstant(input); @@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel { } XlaContext& tc = XlaContext::Get(ctx); - if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { + if (tc.resolve_compile_time_constants() && + (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - // The core from which a return value is returned depends on the core - // assignment of the input to the retval .Since we can't change the core - // assignment of as this point, create a tuple/get-tuple-element - // combination so that the core will be set on them. - auto tuple_elem = - ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); - tc.AddRetval(index_, dtype_, tuple_elem); + TensorShape shape = ctx->InputShape(0); + TensorShape representation_shape = + tc.is_entry_computation() + ? tc.RepresentationShape(shape, ctx->input_type(0)) + : shape; + + xla::XlaOp output = input; + if (tc.is_entry_computation()) { + output = + ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + } else { + // The core from which a return value is returned depends on the + // device assignment of the input to the retval. 
Since we can't change + // the device assignment of "input" at this point, we must always + // introduce an operator here, even if the shape does not change. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + output = ctx->builder()->GetTupleElement( + ctx->builder()->Tuple({output}), 0); + } + tc.AddRetval(index_, dtype_, shape, output); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index e51d386926763e..2872a3c4d49d0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -48,7 +48,7 @@ class ReverseOp : public XlaOpKernel { ctx->SetOutput(0, ctx->Input(0)); return; } - // ComputationBuilder::Rev() requires concrete values for dimensions arg. + // XlaBuilder::Rev() requires concrete values for dimensions arg. xla::Literal lax; OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); std::vector revdims(x_shape.dims()); @@ -90,7 +90,7 @@ class ReverseV2Op : public XlaOpKernel { ctx->SetOutput(0, ctx->Input(0)); return; } - // ComputationBuilder::Rev() requires concrete values for dimensions arg. + // XlaBuilder::Rev() requires concrete values for dimensions arg. std::vector axes; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes)); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 6bc5d3adb091cd..5d1c05268493f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -54,7 +54,7 @@ class ReverseSequenceOp : public XlaOpKernel { "), ", "(", seq_lens_shape.num_elements(), " vs. ", input_shape.dim_size(batch_dim_))); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); const auto input = context->Input(0); const auto seq_lens = context->Input(1); @@ -106,20 +106,40 @@ class ReverseSequenceOp : public XlaOpKernel { seq_lens, body_builder->Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto indices = body_builder->Broadcast( + auto batch_element_indices = body_builder->Broadcast( XlaHelpers::Zero(body_builder.get(), seq_lens_type), {input_shape.dims()}); - indices = body_builder->DynamicUpdateSlice( - indices, body_builder->Reshape(i, {1}), + batch_element_indices = body_builder->DynamicUpdateSlice( + batch_element_indices, body_builder->Reshape(i, {1}), body_builder->Reshape( XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, batch_dim_), {1})); - // slice_indices is the offset of the start of the reversed sequence in - // the input. - auto slice_indices = body_builder->DynamicUpdateSlice( - indices, + // Slice out the current batch element and pad it out in the sequence + // dimension. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + slice_shape.set_dim(seq_dim_, max_seq_len); + auto slice = body_builder->DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); + auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + slice_shape.dim_size(seq_dim_)); + slice = body_builder->Pad( + slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); + + // Now slice out the reversed sequence from its actual start. 
+ // sequence_start_indices is the offset of the start of the reversed + // sequence in the input. The slice will go into the padding, however, we + // will mask off these elements and replace them with elements from the + // original input so their values do not matter. + auto sequence_start_indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = body_builder->DynamicUpdateSlice( + sequence_start_indices, body_builder->Sub(XlaHelpers::IntegerLiteral( body_builder.get(), seq_lens_type, max_seq_len), seq_len), @@ -127,18 +147,12 @@ class ReverseSequenceOp : public XlaOpKernel { XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, seq_dim_), {1})); - - // Slice out the reversed sequence. The slice will overflow the end of the - // sequence, and the contents of the overflow are implementation-defined. - // However, we will mask off these elements and replace them with elements - // from the original input so their values do not matter. - TensorShape slice_shape = input_shape; - slice_shape.set_dim(batch_dim_, 1); - auto slice = body_builder->DynamicSlice(output, slice_indices, - slice_shape.dim_sizes()); + slice = body_builder->DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, indices); + output = body_builder->DynamicUpdateSlice(output, slice, + batch_element_indices); body_builder->Tuple( {body_builder->Add( @@ -155,7 +169,7 @@ class ReverseSequenceOp : public XlaOpKernel { auto output = builder->GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. - xla::ComputationDataHandle iota; + xla::XlaOp iota; OP_REQUIRES_OK( context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); std::vector dims(input_shape.dims(), 1); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 4cfa28a0ce3d7d..1819fb543317ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -74,7 +74,7 @@ class ScanOp : public XlaOpKernel { return; } - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); std::vector window_strides(input_shape.dims(), 1); std::vector window_dims(input_shape.dims(), 1); @@ -91,8 +91,8 @@ class ScanOp : public XlaOpKernel { std::swap(padding[axis].first, padding[axis].second); } - xla::ComputationDataHandle init; - const xla::Computation* reducer; + xla::XlaOp init; + const xla::XlaComputation* reducer; if (sum_) { init = XlaHelpers::Zero(builder, dtype); reducer = ctx->GetOrCreateAdd(dtype); diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 8433a29c4e203c..f2c63b4f9083ad 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -102,7 +102,7 @@ class ScatterNdOp : public XlaOpKernel { OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape, updates_shape)); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), buffer_shape.dim_sizes()); auto indices = context->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc 
index 498342a98881df..664078ca16c6d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -62,16 +62,16 @@ class UnsortedSegmentSum : public XlaOpKernel { d, " differs ", data_shape.dim_size(d), " vs. ", indices_shape.dim_size(d))); } - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), buffer_shape.dim_sizes()); - auto combiner = - [](xla::ComputationDataHandle a, xla::ComputationDataHandle b, - xla::ComputationBuilder* builder) { return builder->Add(a, b); }; + auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { + return builder->Add(a, b); + }; auto result = XlaScatter(buffer, /*updates=*/data, indices, /*indices_are_vectors=*/false, combiner, builder); diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 8081d3c41c4363..f9f48164d63492 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -40,7 +40,7 @@ class SelectOp : public XlaOpKernel { "'then' and 'else' must have the same size. but received: ", then_shape.DebugString(), " vs. ", else_shape.DebugString())); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto cond_handle = ctx->Input(0); auto then_handle = ctx->Input(1); diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index d079b89861817a..9ce01d0d44509b 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 463788b8b461c3..bbf5ee8b12186a 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -43,8 +43,8 @@ class SoftmaxOp : public XlaOpKernel { const DataType type = input_type(0); auto logits = ctx->Input(0); - xla::ComputationBuilder* const b = ctx->builder(); - const xla::Computation& max_func = *ctx->GetOrCreateMax(type); + xla::XlaBuilder* const b = ctx->builder(); + const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = @@ -76,16 +76,15 @@ class SoftmaxOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); -std::pair -CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, - const xla::ComputationDataHandle& logits, - const xla::ComputationDataHandle& labels) { - const xla::Computation& max_func = *ctx->GetOrCreateMax(type); +std::pair CrossEntropyWithLogits( + XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits, + const xla::XlaOp& labels) { + const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); const int kBatchDim = 0; const int kClassDim = 1; - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); @@ -123,7 +122,7 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) // (where the division broadcasts along the batch dimension) - xla::ComputationDataHandle backprop = + xla::XlaOp backprop = b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); return {loss, backprop}; } @@ -150,7 +149,7 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { auto logits = ctx->Input(0); auto labels = ctx->Input(1); - xla::ComputationDataHandle loss, backprop; + xla::XlaOp loss, backprop; std::tie(loss, backprop) = CrossEntropyWithLogits(ctx, type, logits, labels); ctx->SetOutput(0, loss); @@ -191,10 +190,10 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { DataType logits_type = input_type(0); DataType indices_type = input_type(1); - xla::ComputationDataHandle indices = ctx->Input(1); + xla::XlaOp indices = ctx->Input(1); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle labels; + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp labels; OP_REQUIRES_OK(ctx, XlaHelpers::OneHot( builder, depth, /*axis=*/1, input_type(1), labels_shape, @@ -207,7 +206,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // Builds a vector of {batch_size} that is 0 if the index is in range, or // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. 
- xla::ComputationDataHandle nan_or_zero = builder->Select( + xla::XlaOp nan_or_zero = builder->Select( builder->And( builder->Le(XlaHelpers::Zero(builder, indices_type), indices), builder->Lt(indices, XlaHelpers::IntegerLiteral( @@ -218,7 +217,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { {batch_size})); labels = builder->Add(labels, nan_or_zero, {0}); - xla::ComputationDataHandle loss, backprop; + xla::XlaOp loss, backprop; std::tie(loss, backprop) = CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels); ctx->SetOutput(0, loss); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 01b46e160d1f1f..ec077924b5b5af 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -20,9 +20,8 @@ limitations under the License. namespace tensorflow { namespace { -void SpaceToBatch(XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, DataType input_dtype, - const TensorShape& input_tensor_shape, +void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, + DataType input_dtype, const TensorShape& input_tensor_shape, gtl::ArraySlice block_shape, const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); @@ -46,7 +45,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, ", 2] instead of ", xla::ShapeUtil::HumanString(paddings.shape()))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the // input according to `paddings` to produce `padded` of shape `padded_shape`. @@ -73,7 +72,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, errors::InvalidArgument( "The product of the block dimensions must be positive")); - xla::ComputationDataHandle padded = + xla::XlaOp padded = b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); // 2. Reshape `padded` to `reshaped_padded` of shape: @@ -101,8 +100,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_padded_shape.begin() + 1 + 2 * block_rank); - xla::ComputationDataHandle reshaped_padded = - b->Reshape(padded, reshaped_padded_shape); + xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape); // 3. Permute dimensions of `reshaped_padded` to produce // `permuted_reshaped_padded` of shape: @@ -121,7 +119,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, permutation[block_rank] = 0; std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::ComputationDataHandle permuted_reshaped_padded = + xla::XlaOp permuted_reshaped_padded = b->Transpose(reshaped_padded, permutation); // 4. 
Reshape `permuted_reshaped_padded` to flatten `block_shape` into the @@ -142,8 +140,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), output_shape.begin() + 1 + block_rank); - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped_padded, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 806fda632cde64..4c5886ee2a0f63 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -50,8 +50,8 @@ class SpaceToDepthOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); @@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -145,8 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_, block_size_, // depth] - xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -156,8 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 43c15e75380535..8958b2e7701e62 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -124,7 +124,7 @@ class SplitVOp : public XlaOpKernel { input_shape.dims(), "), but got ", split_dim_orig)); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); OP_REQUIRES(ctx, input_shape.dims() > 0, errors::InvalidArgument("Can't split a 0 dimensional input")); diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 1a78c7ab9be701..0fb05a2be7b103 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -38,13 +38,13 @@ limitations under the License. 
namespace tensorflow { namespace { -Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, +Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, TensorShape* stack_shape) { auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } - xla::Shape shape = *shape_or_status.ValueOrDie(); + xla::Shape shape = shape_or_status.ValueOrDie(); TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), stack_shape); @@ -60,9 +60,8 @@ Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, // // TODO(phawkins): consider changing the API of the stack operators to // allow an optional element shape at stack construction time. -Status MaybeInitializeStack(xla::ComputationBuilder* builder, - XlaResource* resource, DataType dtype, - const TensorShape& elem_shape) { +Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, + DataType dtype, const TensorShape& elem_shape) { if (resource->type() != dtype) { return errors::InvalidArgument( "Stack dtype is ", DataTypeString(resource->type()), @@ -75,8 +74,6 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder, if (!resource->initialized()) { // Stack has not been initialized. - xla::ComputationDataHandle zero = - XlaHelpers::Zero(builder, resource->type()); TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { @@ -111,7 +108,7 @@ class StackOp : public XlaOpKernel { // We defer initializing the Stack resource until we see the first push. // Otherwise we do not know the shape of the stack elements. - xla::ComputationDataHandle value; + xla::XlaOp value; XlaContext& xc = XlaContext::Get(ctx); XlaResource* resource; string name = strings::StrCat("Stack: ", stack_name_); @@ -138,7 +135,7 @@ class StackPushOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape elem_shape = ctx->InputShape(1); XlaResource* resource; @@ -147,9 +144,9 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = b->GetTupleElement(resource->value(), 0); - xla::ComputationDataHandle index = b->GetTupleElement(resource->value(), 1); - xla::ComputationDataHandle value = ctx->Input(1); + xla::XlaOp ta = b->GetTupleElement(resource->value(), 0); + xla::XlaOp index = b->GetTupleElement(resource->value(), 1); + xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
auto start_indices = @@ -184,7 +181,7 @@ class StackPopOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -199,9 +196,9 @@ class StackPopOp : public XlaOpKernel { TensorShape stack_shape; OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); - xla::ComputationDataHandle state = resource->value(); - xla::ComputationDataHandle ta = b->GetTupleElement(state, 0); - xla::ComputationDataHandle index = b->GetTupleElement(state, 1); + xla::XlaOp state = resource->value(); + xla::XlaOp ta = b->GetTupleElement(state, 0); + xla::XlaOp index = b->GetTupleElement(state, 1); index = b->Sub(index, b->ConstantR0(1)); OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); @@ -216,8 +213,7 @@ class StackPopOp : public XlaOpKernel { // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - xla::ComputationDataHandle read = - b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 5bb773d97fc5ce..a99d4ddc7c4956 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -30,9 +30,8 @@ namespace tensorflow { namespace { // Rotates a 32-bit integer 'v' left by 'distance' bits. -xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& v, - int distance) { +xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, + int distance) { return builder->Or( builder->ShiftLeft(v, builder->ConstantR0(distance)), builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); @@ -40,25 +39,24 @@ xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, // TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than // building XOR out of other bitwise operators. -xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& y) { +xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x, + const xla::XlaOp& y) { return builder->Or(builder->And(x, builder->Not(y)), builder->And(builder->Not(x), y)); } -using ThreeFry2x32State = std::array; +using ThreeFry2x32State = std::array; // Implements the ThreeFry counter-based PRNG algorithm. // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, +ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, ThreeFry2x32State input, ThreeFry2x32State key) { // Rotation distances specified by the Threefry2x32 algorithm. constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; ThreeFry2x32State x; - std::array ks; + std::array ks; // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. 
ks[2] = builder->ConstantR0(0x1BD11BDA); for (int i = 0; i < 2; ++i) { @@ -121,10 +119,9 @@ ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, // Returns a tensor of 'shape' random values uniformly distributed in the range // [minval, maxval) -xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& seed, - const TensorShape& shape, - double minval, double maxval) { +xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, + const TensorShape& shape, double minval, + double maxval) { // Split the seed into two 32-bit scalars to form a key. auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); @@ -178,9 +175,8 @@ xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, // p = sum_{i=1}^n gq[i]*w^i // } // return p*x -xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& x, - const TensorShape& shape) { +xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x, + const TensorShape& shape) { constexpr int kDegree = 9; constexpr std::array w_less_than_5_constants = { 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, @@ -220,7 +216,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); TensorShape shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); @@ -229,7 +225,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); - xla::ComputationDataHandle seed = ctx->Input(1); + xla::XlaOp seed = ctx->Input(1); ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); } @@ -257,9 +253,10 @@ class StatelessRandomNormalOp : public XlaOpKernel { OP_REQUIRES(ctx, seed_shape == TensorShape({2}), errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); - xla::ComputationDataHandle seed = ctx->Input(1); - xla::ComputationBuilder* builder = ctx->builder(); - auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); + xla::XlaOp seed = ctx->Input(1); + xla::XlaBuilder* builder = ctx->builder(); + auto uniform = + RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 6204aa4e27000f..55254c746e5eba 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -90,7 +90,7 @@ class StridedSliceOp : public XlaOpKernel { } } - xla::ComputationDataHandle slice = ctx->Input(0); + xla::XlaOp slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { slice = ctx->builder()->Rev(slice, dimensions_to_reverse); } @@ -168,7 +168,7 @@ class StridedSliceGradOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); - xla::ComputationDataHandle grad = ctx->Input(4); + xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. 
grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes()); @@ -255,7 +255,7 @@ class StridedSliceAssignOp : public XlaOpKernel { &strides_tensor)); TensorShape lhs_shape; - xla::ComputationDataHandle lhs; + xla::XlaOp lhs; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); const TensorShape rhs_shape = ctx->InputShape(4); @@ -284,7 +284,7 @@ class StridedSliceAssignOp : public XlaOpKernel { " does not match r-value shape ", rhs_shape.DebugString(), ". Automatic broadcasting not yet implemented.")); - xla::ComputationDataHandle rhs = ctx->Input(4); + xla::XlaOp rhs = ctx->Input(4); gtl::InlinedVector dimensions_to_reverse; gtl::InlinedVector slice_begin, slice_dims; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 000b50af6bd86b..9adee78a1fd1fb 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -47,7 +47,7 @@ namespace { // the TensorArray with elements of `elem_shape`. For both initialized and // uninitialized TensorArrays, checks that the tensor has a type compatible with // 'dtype' and shape compatible with 'elem_shape'. -Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, +Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { if (resource->kind() != XlaResource::kTensorArray) { @@ -64,9 +64,6 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, << resource->name() << " size " << resource->tensor_array_size(); if (!resource->initialized()) { - xla::ComputationDataHandle zero = - XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { @@ -77,7 +74,7 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, } TensorShape shape; TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TensorShape ta_shape; ta_shape.AddDim(resource->tensor_array_size()); @@ -114,23 +111,21 @@ Status CheckTensorArrayIsInitialized(const string& op_name, } Status GetTensorArrayShape(const XlaResource* resource, - xla::ComputationBuilder* builder, - TensorShape* shape) { + xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); shape->InsertDim(0, resource->tensor_array_size()); return Status::OK(); } -// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the +// Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the // relevant slice of 'operand'. 
-xla::ComputationDataHandle DynamicAddSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand, - const xla::ComputationDataHandle& update, - const gtl::ArraySlice& update_dims, - const xla::ComputationDataHandle& start_indices) { - xla::ComputationDataHandle current = +xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, + const xla::XlaOp& update, + const gtl::ArraySlice& update_dims, + const xla::XlaOp& start_indices) { + xla::XlaOp current = builder->DynamicSlice(operand, start_indices, update_dims); - xla::ComputationDataHandle sum = builder->Add(current, update); + xla::XlaOp sum = builder->Add(current, update); return builder->DynamicUpdateSlice(operand, sum, start_indices); } @@ -155,18 +150,18 @@ class TensorArrayOp : public XlaOpKernel { OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument("TensorArray size must be >= 0")); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. - xla::ComputationDataHandle value; + xla::XlaOp value; TensorShape shape; if (element_shape_.IsFullyDefined()) { CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); ta_shape.AppendShape(shape); - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_); + xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); value = b->Broadcast(zero, ta_shape.dim_sizes()); } @@ -202,7 +197,7 @@ class TensorArrayWriteOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape elem_shape = ctx->InputShape(2); @@ -213,10 +208,10 @@ class TensorArrayWriteOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value(); - xla::ComputationDataHandle index = ctx->Input(1); - xla::ComputationDataHandle value = ctx->Input(2); - xla::ComputationDataHandle flow = ctx->Input(3); + xla::XlaOp ta = resource->value(); + xla::XlaOp index = ctx->Input(1); + xla::XlaOp value = ctx->Input(2); + xla::XlaOp flow = ctx->Input(3); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = @@ -227,7 +222,7 @@ class TensorArrayWriteOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = b->Reshape(value, slice_shape.dim_sizes()); - xla::ComputationDataHandle written = + xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); OP_REQUIRES_OK(ctx, resource->SetValue(written)); @@ -249,7 +244,7 @@ class TensorArrayReadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -259,8 +254,8 @@ class TensorArrayReadOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value(); - xla::ComputationDataHandle index = ctx->Input(1); + xla::XlaOp ta = resource->value(); + xla::XlaOp index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
auto start_indices = @@ -270,8 +265,7 @@ class TensorArrayReadOp : public XlaOpKernel { auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; - xla::ComputationDataHandle read = - b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); @@ -293,7 +287,7 @@ class TensorArrayGatherOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -309,7 +303,7 @@ class TensorArrayGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); // Look for the case where the gather takes a simple slice from the // tensor array (0, 1, 2, 3, 4, ..., N) @@ -337,7 +331,7 @@ class TensorArrayGatherOp : public XlaOpKernel { } } - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( ctx, XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0, @@ -360,7 +354,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape value_shape = ctx->InputShape(2); @@ -375,11 +369,11 @@ class TensorArrayScatterOp : public XlaOpKernel { OP_REQUIRES(ctx, indices_shape.dims() >= 1, errors::InvalidArgument("indices must be rank 1")); const int num_indices = indices_shape.dim_size(0); - const xla::ComputationDataHandle indices = ctx->Input(1); + const xla::XlaOp indices = ctx->Input(1); - xla::ComputationDataHandle ta = resource->value(); - const xla::ComputationDataHandle value = ctx->Input(2); - const xla::ComputationDataHandle flow = ctx->Input(3); + xla::XlaOp ta = resource->value(); + const xla::XlaOp value = ctx->Input(2); + const xla::XlaOp flow = ctx->Input(3); // Look for the case where the scatter is for each sub-tensor in order. The // tensor array implementation allows for this to be a straight addition. 
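The fast path hinted at above applies when the scatter indices are exactly 0, 1, ..., n-1: scattering n leading sub-tensors in order reduces to a straight addition of the update block, with no per-index dynamic updates. A small sketch of the index check in plain C++ (illustrative only; the helper name is hypothetical):

```c++
#include <cstddef>
#include <vector>

// True if 'indices' is 0, 1, ..., n-1, i.e. the scatter writes each
// sub-tensor to its own position in order.
bool IndicesAreInOrder(const std::vector<int>& indices) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] != static_cast<int>(i)) return false;
  }
  return true;
}
```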
@@ -443,7 +437,7 @@ class TensorArrayConcatOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -453,7 +447,7 @@ class TensorArrayConcatOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); @@ -503,12 +497,12 @@ class TensorArraySplitOp : public XlaOpKernel { TensorShape elem_shape = value_shape; elem_shape.set_dim(0, length); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); TensorShape ta_shape; ta_shape.AddDim(resource->tensor_array_size()); @@ -520,8 +514,8 @@ class TensorArraySplitOp : public XlaOpKernel { "TensorArray's size is not equal to the size of lengths (", lengths.size(), " vs. ", resource->tensor_array_size(), ")")); - const xla::ComputationDataHandle value = ctx->Input(1); - const xla::ComputationDataHandle flow = ctx->Input(3); + const xla::XlaOp value = ctx->Input(1); + const xla::XlaOp flow = ctx->Input(3); OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(), errors::InvalidArgument("mismatched element count ", @@ -569,7 +563,7 @@ class TensorArrayGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 9aefcd4fc7f94a..e91075196bd841 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -112,7 +112,7 @@ class TileOp : public XlaOpKernel { flattened.push_back(i); flattened.push_back(i + output_shape.size()); } - xla::ComputationDataHandle output = + xla::XlaOp output = ctx->builder()->Reshape(broadcasted, flattened, output_shape); ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index f750f7003be288..34caefa050c0d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -30,8 +30,8 @@ class ResourceApplyGradientDescent : public XlaOpKernel { explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaOp handle; + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(1); TensorShape var_shape; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); @@ -63,12 +63,12 @@ class ResourceApplyMomentum : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; - xla::ComputationDataHandle var, accum; + xla::XlaOp var, accum; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); @@ -93,9 +93,9 @@ class ResourceApplyMomentum : public XlaOpKernel { errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(2); - xla::ComputationDataHandle grad = ctx->Input(3); - xla::ComputationDataHandle momentum = ctx->Input(4); + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp momentum = ctx->Input(4); accum = b->Add(b->Mul(accum, momentum), grad); if (use_nesterov_) { @@ -121,12 +121,12 @@ class ResourceApplyAdagrad : public XlaOpKernel { explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; - xla::ComputationDataHandle var, accum; + xla::XlaOp var, accum; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); @@ -146,8 +146,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(2); - xla::ComputationDataHandle grad = ctx->Input(3); + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); var = b->Sub( @@ -168,7 +168,7 @@ class ResourceApplyAdam : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape var_shape, m_shape, v_shape; - xla::ComputationDataHandle var, m, v; + xla::XlaOp var, m, v; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); @@ -213,25 +213,25 @@ class ResourceApplyAdam : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - 
xla::ComputationDataHandle beta1_power = ctx->Input(3); - xla::ComputationDataHandle beta2_power = ctx->Input(4); - xla::ComputationDataHandle lr = ctx->Input(5); - xla::ComputationDataHandle beta1 = ctx->Input(6); - xla::ComputationDataHandle beta2 = ctx->Input(7); - xla::ComputationDataHandle epsilon = ctx->Input(8); - xla::ComputationDataHandle grad = ctx->Input(9); + xla::XlaOp beta1_power = ctx->Input(3); + xla::XlaOp beta2_power = ctx->Input(4); + xla::XlaOp lr = ctx->Input(5); + xla::XlaOp beta1 = ctx->Input(6); + xla::XlaOp beta2 = ctx->Input(7); + xla::XlaOp epsilon = ctx->Input(8); + xla::XlaOp grad = ctx->Input(9); // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); - xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - xla::ComputationDataHandle alpha = + xla::XlaOp alpha = b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), b->Sub(one, beta1_power)); m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); @@ -255,12 +255,12 @@ class ResourceApplyRMSProp : public XlaOpKernel { explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(3); TensorShape var_shape, ms_shape, mom_shape; - xla::ComputationDataHandle var, ms, mom; + xla::XlaOp var, ms, mom; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); @@ -297,11 +297,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(3); - xla::ComputationDataHandle rho = ctx->Input(4); - xla::ComputationDataHandle momentum = ctx->Input(5); - xla::ComputationDataHandle epsilon = ctx->Input(6); - xla::ComputationDataHandle grad = ctx->Input(7); + xla::XlaOp lr = ctx->Input(3); + xla::XlaOp rho = ctx->Input(4); + xla::XlaOp momentum = ctx->Input(5); + xla::XlaOp epsilon = ctx->Input(6); + xla::XlaOp grad = ctx->Input(7); // ms <- rho * ms_{t-1} + (1-rho) * grad * grad // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) @@ -320,16 +320,16 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. 
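The accumulator rewrite referenced above follows from a one-line identity: rho * ms + (1 - rho) * grad^2 = ms + (grad^2 - ms) * (1 - rho), which is exactly what the Add/Sub/Mul chain below computes. A standalone scalar sketch of one RMSProp step in that form, outside the patch (plain floats, illustrative only):

```c++
#include <cmath>

// One scalar RMSProp step using the rewritten accumulator update
// ms += (grad^2 - ms) * (1 - rho), equivalent to rho*ms + (1-rho)*grad^2.
void RmsPropStep(float grad, float lr, float rho, float momentum,
                 float epsilon, float* var, float* ms, float* mom) {
  *ms += (grad * grad - *ms) * (1.0f - rho);
  *mom = momentum * *mom + lr * grad / std::sqrt(*ms + epsilon);
  *var -= *mom;
}
```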
- xla::ComputationDataHandle new_ms = b->Add( + xla::XlaOp new_ms = b->Add( ms, b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms), b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); - xla::ComputationDataHandle new_mom = + xla::XlaOp new_mom = b->Add(b->Mul(mom, momentum), b->Mul(b->Mul(grad, lr), b->Pow(b->Add(new_ms, epsilon), XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::ComputationDataHandle new_var = b->Sub(var, new_mom); + xla::XlaOp new_var = b->Sub(var, new_mom); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); @@ -341,10 +341,10 @@ REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes), void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape var_shape, accum_shape, linear_shape; - xla::ComputationDataHandle var, accum, linear; + xla::XlaOp var, accum, linear; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear)); @@ -399,12 +399,12 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, errors::InvalidArgument("lr_power is not a scalar: ", lr_power_shape.DebugString())); - xla::ComputationDataHandle grad = ctx->Input(3); - xla::ComputationDataHandle lr = ctx->Input(4); - xla::ComputationDataHandle l1 = ctx->Input(5); - xla::ComputationDataHandle l2 = ctx->Input(6); - xla::ComputationDataHandle l2_shrinkage; - xla::ComputationDataHandle lr_power; + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp lr = ctx->Input(4); + xla::XlaOp l1 = ctx->Input(5); + xla::XlaOp l2 = ctx->Input(6); + xla::XlaOp l2_shrinkage; + xla::XlaOp lr_power; if (has_l2_shrinkage) { l2_shrinkage = ctx->Input(7); lr_power = ctx->Input(8); @@ -421,26 +421,23 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, // var = (linear_clipped - linear) / quadratic // accum = new_accum - xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype, 2.0); - xla::ComputationDataHandle grad_to_use; + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + xla::XlaOp grad_to_use; if (has_l2_shrinkage) { grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); } else { grad_to_use = grad; } - xla::ComputationDataHandle new_accum = - b->Add(accum, b->Pow(grad_to_use, two)); - xla::ComputationDataHandle new_accum_lr_pow = - b->Pow(new_accum, b->Neg(lr_power)); - xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); + xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two)); + xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power)); + xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); linear = b->Add( linear, b->Sub(grad_to_use, b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); - xla::ComputationDataHandle linear_clipped = b->Clamp(b->Neg(l1), linear, l1); - xla::ComputationDataHandle quadratic = - b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); + xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1); + xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); var = b->Div(b->Sub(linear_clipped, linear), quadratic); accum = new_accum; diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 7cb47f908d4ff4..71a9fd051bfc8d 100644 --- 
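The FTRL path in `CompileFtrl` is easier to follow next to a scalar reference. The sketch below mirrors the per-element formulas from the comment and ops above, without the optional `l2_shrinkage` term; it is a host-side illustration only.

```c++
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative FTRL-Proximal step:
//   new_accum = accum + grad^2
//   linear   += grad - (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var
//   var       = (clamp(linear, -l1, l1) - linear)
//               / (new_accum^(-lr_power) / lr + 2 * l2)
//   accum     = new_accum
void FtrlStep(std::vector<float>& var, std::vector<float>& accum,
              std::vector<float>& linear, const std::vector<float>& grad,
              float lr, float l1, float l2, float lr_power) {
  for (std::size_t i = 0; i < var.size(); ++i) {
    const float g = grad[i];
    const float new_accum = accum[i] + g * g;
    const float new_accum_lr_pow = std::pow(new_accum, -lr_power);
    const float accum_lr_pow = std::pow(accum[i], -lr_power);
    linear[i] += g - (new_accum_lr_pow - accum_lr_pow) / lr * var[i];
    const float linear_clipped = std::min(std::max(linear[i], -l1), l1);
    const float quadratic = new_accum_lr_pow / lr + 2.0f * l2;
    var[i] = (linear_clipped - linear[i]) / quadratic;
    accum[i] = new_accum;
  }
}
```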
a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -33,9 +33,9 @@ namespace { public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ void Compile(XlaOpKernelContext* ctx) { \ - xla::ComputationBuilder* b = ctx->builder(); \ - xla::ComputationDataHandle x = ctx->Input(0); \ - xla::ComputationDataHandle y = COMPUTATION; \ + xla::XlaBuilder* b = ctx->builder(); \ + xla::XlaOp x = ctx->Input(0); \ + xla::XlaOp y = COMPUTATION; \ ctx->SetOutput(0, y); \ } \ }; \ @@ -100,8 +100,7 @@ XLAJIT_MAKE_UNARY(Cosh, XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); -// TODO(b/34703906): use a more accurate implementation of expm1. -XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); @@ -115,8 +114,7 @@ XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Log, b->Log(x)); -// TODO(b/34703906): use a more accurate implementation of log1p. -XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); +XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); XLAJIT_MAKE_UNARY(Invert, b->Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); @@ -124,9 +122,8 @@ XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. -static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, - DataType dtype, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, + const xla::XlaOp& x) { auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); @@ -148,9 +145,8 @@ XLAJIT_MAKE_UNARY(Rsqrt, b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. -static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b, - DataType dtype, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, + const xla::XlaOp& x) { auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); } @@ -162,26 +158,17 @@ XLAJIT_MAKE_UNARY(Sinh, b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -static xla::ComputationDataHandle Softplus( - xla::ComputationBuilder* b, DataType dtype, - const xla::ComputationDataHandle& features) { - xla::ComputationDataHandle threshold = - b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), - XlaHelpers::FloatLiteral(b, dtype, 2.0)); - // Value above which exp(x) may overflow, but softplus(x) == x - // is within machine epsilon. 
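On the `Round` helper above: its comment says it implements banker's rounding, i.e. values exactly halfway between two integers go to the even integer. The host's round-to-nearest-even gives the same behaviour and can serve as a reference point; the snippet below is for comparison only and is unrelated to the XLA graph the kernel builds.

```c++
#include <cmath>

// Round-half-to-even reference: under the default FE_TONEAREST rounding
// mode, std::nearbyint sends 0.5 -> 0, 1.5 -> 2, 2.5 -> 2, i.e. ties are
// rounded towards the even integer, matching the Round() kernel's intent.
float RoundHalfToEven(float x) { return std::nearbyint(x); }
```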
- xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold)); - // Value below which exp(x) may underflow, but softplus(x) == exp(x) - // is within machine epsilon. - xla::ComputationDataHandle too_small = b->Lt(features, threshold); - xla::ComputationDataHandle features_exp = b->Exp(features); - xla::ComputationDataHandle output = b->Select( - too_large, features, - b->Select(too_small, features_exp, - b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); - return output; -} -XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); +// softplus(x) = log(1 + exp(x)) +// +// This is not numerically stable when x is large, it can easily overflow. +// However, we can compute it as LogSumExp(x, 0): +// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0))) +// +// This is equivalent to: +// max(x, 0) + log1p(exp(-abs(x))) +XLAJIT_MAKE_UNARY(Softplus, + b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), + b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 71173f5aead477..a163fa0a5b3467 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -48,7 +48,7 @@ class ReadVariableOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK( ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); ctx->SetOutput(0, handle); @@ -57,7 +57,7 @@ class ReadVariableOp : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); +REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp); class AssignVariableOp : public XlaOpKernel { public: @@ -67,14 +67,14 @@ class AssignVariableOp : public XlaOpKernel { ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1))); } }; -REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp); +REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp); class AssignAddVariableOp : public XlaOpKernel { public: explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { DataType type = ctx->input_type(1); - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Add(handle, ctx->Input(1)); @@ -90,7 +90,7 @@ class AssignSubVariableOp : public XlaOpKernel { explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { DataType type = ctx->input_type(1); - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Sub(handle, ctx->Input(1)); @@ -105,19 +105,19 @@ class ResourceGatherOp : public XlaOpKernel { 
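The new Softplus expression is worth spelling out: `log(1 + exp(x))` overflows as soon as `exp(x)` does, while the rewritten `LogSumExp(x, 0) = max(x, 0) + log1p(exp(-|x|))` keeps every intermediate bounded. A small host-side check of the identity (using standard `<cmath>` functions, not the XLA ops):

```c++
#include <algorithm>
#include <cmath>
#include <cstdio>

// Stable softplus matching the identity used by the new kernel.
float SoftplusStable(float x) {
  return std::max(x, 0.0f) + std::log1p(std::exp(-std::fabs(x)));
}

int main() {
  std::printf("naive : %f\n", std::log1p(std::exp(100.0f)));  // inf: exp overflows
  std::printf("stable: %f\n", SoftplusStable(100.0f));        // ~100.0
  return 0;
}
```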
public: explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); DataType type = ctx->expected_output_dtype(0); TensorShape resource_shape; - xla::ComputationDataHandle resource_handle; + xla::XlaOp resource_handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, &resource_handle)); auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, /*axis=*/0, /*indices_are_nd=*/false, type, index_type, diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 0ff1b65ae9179d..5467c5d9946846 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -101,7 +101,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { ctx, MakeXlaCompilerArgumentsFromInputs( ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays)); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); XlaCompiler* compiler = ctx->compiler(); VLOG(1) << "Compiling body"; @@ -234,7 +234,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::ShapeUtil::HumanString(cond.xla_output_shape))); int num_inputs = body.input_mapping.size(); - std::vector inputs(num_inputs); + std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; if (ctx->input_type(input_num) == DT_RESOURCE) { @@ -246,24 +246,24 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } } - xla::ComputationDataHandle init = builder->Tuple(inputs); + xla::XlaOp init = builder->Tuple(inputs); VLOG(1) << "Building while loop"; // Wraps the condition in a computation that unpacks the output tuple. - xla::Computation cond_wrapper; + xla::XlaComputation cond_wrapper; { - std::unique_ptr cb = + std::unique_ptr cb = builder->CreateSubBuilder("cond_wrapper"); auto inputs = cb->Parameter(0, cond_input_shape, "inputs"); auto outputs = cb->Call(*cond.computation, {inputs}); cb->GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); + xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(ctx, result.status()); cond_wrapper = std::move(result.ValueOrDie()); } - xla::ComputationDataHandle while_result = + xla::XlaOp while_result = builder->While(cond_wrapper, *body.computation, init); // Sets non-variable outputs. 
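For readers new to `XlaWhileOp`: the emitted loop carries all state in a single tuple, the condition computation reads that tuple and produces a predicate (hence the `cond_wrapper` sub-builder that unpacks the wrapped condition's output), and the body returns the next tuple. A host-side analogue of that structure, purely illustrative:

```c++
#include <functional>
#include <utility>

// Shape of the loop XlaWhileOp builds: one tuple of carried state, a
// condition mapping state -> bool, and a body mapping state -> next state.
template <typename State>
State WhileLoop(const std::function<bool(const State&)>& cond,
                const std::function<State(const State&)>& body, State init) {
  State state = std::move(init);
  while (cond(state)) {
    state = body(state);
  }
  return state;
}
```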
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index fde1977c1b1834..ee7f5d510ab7a3 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -44,8 +44,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -62,9 +62,9 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -80,10 +80,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -91,6 +90,7 @@ cc_library( xla_test( name = "triangular_solve_test", srcs = ["triangular_solve_test.cc"], + tags = ["noasan"], # sometimes times out, http://b/78650012 deps = [ ":triangular_solve", "//tensorflow/compiler/xla:array2d", @@ -100,9 +100,9 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -121,8 +121,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -140,7 +140,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -160,8 +159,8 @@ 
cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 798f0fa78055e8..526694d5a0c712 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -25,24 +25,22 @@ limitations under the License. namespace tensorflow { -xla::StatusOr BatchDot( - xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, - bool conjugate_x, bool conjugate_y) { - TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, - builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, - builder->GetShape(y)); +xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, + xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, + bool conjugate_y) { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); // Check that both tensors have the same number of dimensions. There must be // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) { + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(*x_shape), " vs. ", - xla::ShapeUtil::HumanString(*y_shape)); + xla::ShapeUtil::HumanString(x_shape), " vs. ", + xla::ShapeUtil::HumanString(y_shape)); } - const int ndims = xla::ShapeUtil::Rank(*x_shape); + const int ndims = xla::ShapeUtil::Rank(x_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to BatchedDot must have rank >= 2: ", ndims); @@ -52,46 +50,46 @@ xla::StatusOr BatchDot( // valid. std::vector batch_dimension_numbers; for (int i = 0; i < ndims - 2; ++i) { - if (x_shape->dimensions(i) != y_shape->dimensions(i)) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { return errors::InvalidArgument( "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(*x_shape), " vs ", - xla::ShapeUtil::HumanString(*y_shape)); + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); } batch_dimension_numbers.push_back(i); } int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { return errors::InvalidArgument( "Dimensions ", x_inner_dim, " and ", y_inner_dim, " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(*y_shape), + xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(y_shape), " transpose: ", transpose_y); } // Check for zero lhs/rhs dim size. 
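As a reminder of what `BatchDot` computes once these shape checks pass: every leading (batch) dimension is matched element-wise and the trailing two dimensions are multiplied as matrices. A dense row-major host reference for the simplest `[batch, m, k] x [batch, k, n]` case, with no transposes or conjugation (illustrative only):

```c++
#include <cstddef>
#include <vector>

// out[b, i, j] = sum_p x[b, i, p] * y[b, p, j]
std::vector<float> BatchMatMul(const std::vector<float>& x,
                               const std::vector<float>& y, int batch, int m,
                               int k, int n) {
  std::vector<float> out(static_cast<std::size_t>(batch) * m * n, 0.0f);
  for (int b = 0; b < batch; ++b) {
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          acc += x[(b * m + i) * k + p] * y[(b * k + p) * n + j];
        }
        out[(b * m + i) * n + j] = acc;
      }
    }
  }
  return out;
}
```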
- if (xla::ShapeUtil::HasZeroElements(*x_shape) || - xla::ShapeUtil::HasZeroElements(*y_shape)) { + if (xla::ShapeUtil::HasZeroElements(x_shape) || + xla::ShapeUtil::HasZeroElements(y_shape)) { std::vector dimensions(batch_dimension_numbers.size()); for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); } int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape->dimensions(x_outer_dim)); - dimensions.push_back(y_shape->dimensions(y_outer_dim)); + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), + builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())), dimensions); } - if (x_shape->element_type() == xla::C64 && conjugate_x) { + if (x_shape.element_type() == xla::C64 && conjugate_x) { x = builder->Conj(x); } - if (y_shape->element_type() == xla::C64 && conjugate_y) { + if (y_shape.element_type() == xla::C64 && conjugate_y) { y = builder->Conj(y); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index b230e885f10f45..1acc72033b05e7 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -43,10 +43,10 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::StatusOr BatchDot( - xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, - bool conjugate_x = false, bool conjugate_y = false); +xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, + xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x = false, + bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 203365e2ab07e0..3f1384bc864abd 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -47,23 +47,21 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::StatusOr CholeskyUnblocked( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - const int n_dims = xla::ShapeUtil::Rank(*a_shape); - const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); - gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape->dimensions()), +xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, + const xla::XlaOp& a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(a_shape); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + gtl::ArraySlice 
major_dims(xla::AsInt64Slice(a_shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); - xla::ComputationDataHandle l = Zeros(builder, *a_shape); + xla::XlaOp l = Zeros(builder, a_shape); // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::ComputationDataHandle i, - gtl::ArraySlice loop_vars, - xla::ComputationBuilder* body_builder) - -> xla::StatusOr> { + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { xla::Shape col_shape; xla::Shape row_shape; for (int64 d : major_dims) { @@ -72,12 +70,12 @@ xla::StatusOr CholeskyUnblocked( } row_shape.add_dimensions(1); row_shape.add_dimensions(n); - row_shape.set_element_type(a_shape->element_type()); + row_shape.set_element_type(a_shape.element_type()); auto mask_zeros_row = Zeros(body_builder, row_shape); col_shape.add_dimensions(n); col_shape.add_dimensions(1); - col_shape.set_element_type(a_shape->element_type()); + col_shape.set_element_type(a_shape.element_type()); auto mask_zeros_col = Zeros(body_builder, col_shape); std::vector mask_vector(n); @@ -101,7 +99,7 @@ xla::StatusOr CholeskyUnblocked( TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, {i, i}, {1, 1})); // np.dot(row, np.swapaxes(row, -1, -2)) - xla::ComputationDataHandle diag_dot; + xla::XlaOp diag_dot; TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, /*transpose_x=*/false, /*transpose_y=*/true)); @@ -109,7 +107,7 @@ xla::StatusOr CholeskyUnblocked( // np.swapaxes(row, -1, -2))) auto l_ii = body_builder->Pow( body_builder->Sub(a_ii, diag_dot), - FloatLiteral(body_builder, a_shape->element_type(), 0.5)); + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); // a[..., i+1:, i] auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1)); @@ -140,7 +138,7 @@ xla::StatusOr CholeskyUnblocked( TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( body_builder, body_l, l_ii, {i, i})); - return std::vector{body_a, body_l}; + return std::vector{body_a, body_l}; }; TF_ASSIGN_OR_RETURN( @@ -152,22 +150,20 @@ xla::StatusOr CholeskyUnblocked( } // namespace -xla::StatusOr Cholesky( - xla::ComputationBuilder* builder, xla::ComputationDataHandle a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(*a_shape); +xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(a_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to Cholesky must have rank >= 2: ", ndims); } - const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { return errors::InvalidArgument( "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(*a_shape)); + xla::ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { @@ -179,7 +175,7 @@ xla::StatusOr Cholesky( // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only // execution." Proceedings of General Purpose GPUs. ACM, 2017. 
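`CholeskyUnblocked` follows the column-by-column recurrence sketched in its comment (the Cholesky-Banachiewicz scheme). The same recurrence on a single dense row-major matrix, as a plain C++ reference rather than the batched XLA loop:

```c++
#include <cmath>
#include <cstddef>
#include <vector>

// Returns the lower-triangular L with A = L * L^T for an n x n SPD matrix A.
std::vector<float> CholeskyLower(const std::vector<float>& a, int n) {
  std::vector<float> l(static_cast<std::size_t>(n) * n, 0.0f);
  for (int j = 0; j < n; ++j) {
    // l[j, j] = sqrt(a[j, j] - dot(l[j, :j], l[j, :j]))
    float diag = a[j * n + j];
    for (int p = 0; p < j; ++p) diag -= l[j * n + p] * l[j * n + p];
    l[j * n + j] = std::sqrt(diag);
    // l[i, j] = (a[i, j] - dot(l[i, :j], l[j, :j])) / l[j, j]
    for (int i = j + 1; i < n; ++i) {
      float s = a[i * n + j];
      for (int p = 0; p < j; ++p) s -= l[i * n + p] * l[j * n + p];
      l[i * n + j] = s / l[j * n + j];
    }
  }
  return l;
}
```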
- xla::ComputationDataHandle l = Zeros(builder, *a_shape); + xla::XlaOp l = Zeros(builder, a_shape); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -218,7 +214,7 @@ xla::StatusOr Cholesky( /*lower=*/true, /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/8)); + /*block_size=*/block_size)); TF_ASSIGN_OR_RETURN( l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); } diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 17da8d8b22d107..20fca7969ece27 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -30,9 +30,8 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::StatusOr Cholesky( - xla::ComputationBuilder* builder, xla::ComputationDataHandle a, - int64 block_size = 256); +xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, + int64 block_size = 256); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 45699233ea8b2a..d5a27abb2585f6 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -30,24 +30,19 @@ limitations under the License. namespace tensorflow { -xla::StatusOr XlaScatter( - const xla::ComputationDataHandle& buffer, - const xla::ComputationDataHandle& updates, - const xla::ComputationDataHandle& indices, bool indices_are_vectors, - const std::function& combiner, - xla::ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_shape, - builder->GetShape(buffer)); - TF_ASSIGN_OR_RETURN(std::unique_ptr updates_shape, - builder->GetShape(updates)); - TF_ASSIGN_OR_RETURN(std::unique_ptr indices_shape, - builder->GetShape(indices)); +xla::StatusOr XlaScatter( + const xla::XlaOp& buffer, const xla::XlaOp& updates, + const xla::XlaOp& indices, bool indices_are_vectors, + const std::function& + combiner, + xla::XlaBuilder* builder) { + TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); + TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); + TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); gtl::ArraySlice indices_dims = - xla::AsInt64Slice(indices_shape->dimensions()); + xla::AsInt64Slice(indices_shape.dimensions()); gtl::ArraySlice buffer_dims = - xla::AsInt64Slice(buffer_shape->dimensions()); + xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains // the indices to update. Otherwise the indices are all scalars. 
@@ -55,12 +50,12 @@ xla::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) { + if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", - xla::ShapeUtil::HumanString(*indices_shape), + xla::ShapeUtil::HumanString(indices_shape), ") must be <= the rank of the buffer (shape: ", - xla::ShapeUtil::HumanString(*buffer_shape), ")"); + xla::ShapeUtil::HumanString(buffer_shape), ")"); } indices_dims.pop_back(); } @@ -78,10 +73,10 @@ xla::StatusOr XlaScatter( // If any of the indexed dimensions are zero in the buffer, the update cannot // succeed since it updates a slice of size 1. for (int64 i = 0; i < num_index_dims; ++i) { - if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) { - return errors::InvalidArgument( - "Scatter dimension ", i, " is of size zero in tensor with shape ", - xla::ShapeUtil::HumanString(*buffer_shape)); + if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) { + return errors::InvalidArgument("Scatter dimension ", i, + " is of size zero in tensor with shape ", + xla::ShapeUtil::HumanString(buffer_shape)); } } @@ -111,18 +106,17 @@ xla::StatusOr XlaScatter( // index = dynamic-slice(indices, i) // update = dynamic-slice(updates, i) // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::ComputationDataHandle i, - gtl::ArraySlice loop_vars, - xla::ComputationBuilder* body_builder) { + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) { auto indices = loop_vars[0]; auto updates = loop_vars[1]; auto buffer = loop_vars[2]; auto zero_index = body_builder->ConstantLiteral( - xla::Literal::Zero(indices_shape->element_type())); + xla::Literal::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. - xla::ComputationDataHandle index; + xla::XlaOp index; auto indices_offset = body_builder->Reshape(i, {1}); if (indices_are_vectors) { indices_offset = body_builder->Pad(indices_offset, zero_index, @@ -180,12 +174,12 @@ xla::StatusOr XlaScatter( // Apply the update. buffer = body_builder->DynamicUpdateSlice(buffer, update, index); - return std::vector{indices, updates, buffer}; + return std::vector{indices, updates, buffer}; }; - TF_ASSIGN_OR_RETURN( - auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(), - body_fn, init, "scatter", builder)); + TF_ASSIGN_OR_RETURN(auto outputs, + XlaForEachIndex(num_indices, indices_shape.element_type(), + body_fn, init, "scatter", builder)); return outputs[2]; } diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 41e6d3b195ebf9..87309e10ede320 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { @@ -39,14 +39,12 @@ namespace tensorflow { // If a `combiner` is provided, updates are combined with the existing values in // the buffer using the combiner function. Otherwise, the updates replace the // existing values. 
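The scatter loop above dynamically slices one index and one update per iteration and writes the result back with `DynamicUpdateSlice`. Ignoring batching and vector-valued indices, the semantics reduce to the small host reference below (illustrative; in the real kernel the order in which updates are applied is implementation-defined):

```c++
#include <cstddef>
#include <functional>
#include <vector>

// 1-D scatter: combine each update into buffer[indices[i]], or overwrite
// the existing value when no combiner is supplied.
void ScatterRef(std::vector<float>& buffer, const std::vector<int>& indices,
                const std::vector<float>& updates,
                const std::function<float(float, float)>& combiner) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    float& dst = buffer[indices[i]];
    dst = combiner ? combiner(dst, updates[i]) : updates[i];
  }
}
```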
The order of updates is implementation-defined. -xla::StatusOr XlaScatter( - const xla::ComputationDataHandle& buffer, - const xla::ComputationDataHandle& updates, - const xla::ComputationDataHandle& indices, bool indices_are_vectors, - const std::function& combiner, - xla::ComputationBuilder* builder); +xla::StatusOr XlaScatter( + const xla::XlaOp& buffer, const xla::XlaOp& updates, + const xla::XlaOp& indices, bool indices_are_vectors, + const std::function& + combiner, + xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 9bf5821b54abe3..b4503601f94baa 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,21 +29,20 @@ limitations under the License. namespace tensorflow { -xla::StatusOr TriangularSolve( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, - builder->GetShape(b)); - if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) { +xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, + const xla::XlaOp& a, xla::XlaOp b, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { return errors::InvalidArgument( "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(*a_shape), " vs. ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs. ", + xla::ShapeUtil::HumanString(b_shape)); } - const int ndims = xla::ShapeUtil::Rank(*a_shape); + const int ndims = xla::ShapeUtil::Rank(a_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to TriangularSolve must have rank >= 2: ", ndims); @@ -51,30 +50,30 @@ xla::StatusOr TriangularSolve( // The batch dimensions must be equal. 
std::vector batch_dimensions; for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape->dimensions(i); - int64 b_size = b_shape->dimensions(i); + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); if (a_size != b_size) { return errors::InvalidArgument( "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(*a_shape), " vs ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } batch_dimensions.push_back(a_size); } - if (xla::ShapeUtil::GetDimension(*a_shape, -1) != - xla::ShapeUtil::GetDimension(*a_shape, -2)) { + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { return errors::InvalidArgument( "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(*a_shape)); + xla::ShapeUtil::HumanString(a_shape)); } - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) { + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { return errors::InvalidArgument( "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(*a_shape), " vs ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } if (block_size < 1) { @@ -83,26 +82,18 @@ xla::StatusOr TriangularSolve( block_size); } - // Applies a complex conjugation operation if `a` is complex and `conjugate_a` - // is true, otherwise returns its argument. - auto maybe_conj = [&](xla::ComputationBuilder* builder, - xla::ComputationDataHandle x) { - auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; - return perform_conj ? builder->Conj(x) : x; - }; - - std::map base_computations; + std::map base_computations; auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr { - xla::Computation& computation = base_computations[k]; + [&](int k) -> xla::StatusOr { + xla::XlaComputation& computation = base_computations[k]; if (computation.IsNull()) { - std::unique_ptr sub = builder->CreateSubBuilder( + std::unique_ptr sub = builder->CreateSubBuilder( tensorflow::strings::StrCat("trsm_base_", k)); auto a_param = sub->Parameter( 0, xla::ShapeUtil::MakeShape( - b_shape->element_type(), + b_shape.element_type(), PrependMajorDims(sub.get(), batch_dimensions, {k, k})), "a"); @@ -115,20 +106,25 @@ xla::StatusOr TriangularSolve( auto b_param = sub->Parameter( 1, xla::ShapeUtil::MakeShape( - b_shape->element_type(), + b_shape.element_type(), PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), "b"); - // We use a left-looking subroutine on the block diagonal in some common - // cases, while falling back to a recursive call in unsupported cases. The - // left-looking subroutine is written with a While loop and so yields much - // faster compile times. Moreover, the left-looking variant can give - // higher performance on smaller (sub)problems. + // We use a left-looking or right-looking subroutine on the block diagonal + // in the lower=true cases, while falling back to a recursive call in + // others. The left-looking and right-looking subroutines are written with + // a While loop and so yields much faster compile times. 
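For orientation, the left-looking subroutine used on the block diagonal handles the `left_side && lower && !transpose_a` case, which on a single matrix is ordinary forward substitution. A host-side reference (illustrative only; the XLA version is batched and built as a While loop):

```c++
#include <vector>

// Solve a * x = b where a is an m x m lower-triangular matrix and b is
// m x n, both dense row-major. Returns x.
std::vector<float> ForwardSubstitution(const std::vector<float>& a,
                                       const std::vector<float>& b, int m,
                                       int n) {
  std::vector<float> x(b);
  for (int i = 0; i < m; ++i) {
    for (int c = 0; c < n; ++c) {
      float s = x[i * n + c];
      // Subtract the contribution of the rows already solved.
      for (int p = 0; p < i; ++p) s -= a[i * m + p] * x[p * n + c];
      x[i * n + c] = s / a[i * m + i];
    }
  }
  return x;
}
```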
Moreover, they + // can give higher performance on smaller (sub)problems. if (left_side && lower) { TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, b_param, transpose_a, conjugate_a) .status()); + } else if (!left_side && lower) { + TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, + b_param, transpose_a, + conjugate_a) + .status()); } else { TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, left_side, lower, transpose_a, @@ -142,7 +138,7 @@ xla::StatusOr TriangularSolve( return &computation; }; - xla::ComputationDataHandle output = Zeros(builder, *b_shape); + xla::XlaOp output = Zeros(builder, b_shape); // Right-looking blocked triangular solve. // For an explanation of the algorithm, see the TRSM discussion in: @@ -165,13 +161,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -181,7 +179,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) if (i + k < n) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN( a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); @@ -215,13 +213,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -231,7 +231,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) if (i + k < m) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN( a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); @@ -264,13 +264,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + 
MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -280,7 +282,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) if (i - k >= 0) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN(a_slice_2, SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); @@ -314,13 +316,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -330,7 +334,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) if (i - k >= 0) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN(a_slice_2, SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); @@ -356,29 +360,23 @@ xla::StatusOr TriangularSolve( return output; } -xla::StatusOr TriangularSolveLeftLooking( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, - builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(*a_shape); +xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); std::vector batch_dimensions; for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape->dimensions(i); + int64 a_size = a_shape.dimensions(i); batch_dimensions.push_back(a_size); } - auto maybe_conj = [&](xla::ComputationBuilder* builder, - xla::ComputationDataHandle x) { - auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; - return perform_conj ? builder->Conj(x) : x; - }; - // The main computation is performed in a While loop. // Allocate the output and set its first or last row, @@ -387,14 +385,16 @@ xla::StatusOr TriangularSolveLeftLooking( // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] // else: // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::ComputationDataHandle output = Zeros(builder, *b_shape); + xla::XlaOp output = Zeros(builder, b_shape); { auto i = transpose_a ? 
m - 1 : 0; TF_ASSIGN_OR_RETURN(auto a_slice, SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - auto update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + auto update = builder->Div(b_slice, a_slice_conj); TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); } @@ -408,11 +408,11 @@ xla::StatusOr TriangularSolveLeftLooking( // The loop iteration counter is a scalar, incremented each iteration. xla::ShapeUtil::MakeShape(xla::S32, {}), // The output has the shape of b, with one row updated each iteration. - *b_shape, + b_shape, // The coefficient matrix a is a loop invariant. - *a_shape, + a_shape, // The right-hand-side matrix b is a loop invariant. - *b_shape}; + b_shape}; xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); auto init = builder->Tuple({init_i, output, a, b}); @@ -421,7 +421,7 @@ xla::StatusOr TriangularSolveLeftLooking( // def cond_fun(loop_carry): // i, output, a, b = loop_carry // return i >= 0 if transpose_a else i < m - std::unique_ptr condb = + std::unique_ptr condb = builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); { auto i = condb->GetTupleElement( @@ -451,7 +451,7 @@ xla::StatusOr TriangularSolveLeftLooking( // return (i + 1, output, a, b) // We have to do some extra FLOPs propagating zeros in the matrix multiply // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = + std::unique_ptr bodyb = builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); { auto input_tuple = bodyb->Parameter(0, tuple_shape, @@ -475,7 +475,7 @@ xla::StatusOr TriangularSolveLeftLooking( // But since we can't have intermediate array sizes depend on the loop // counter, we instead exploit the fact that we initialized the output to // all zeros and use that as zero-padding (doing unnecessary FLOPs). 
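The zero-padding observation above can be stated in one line for the non-transposed, left-looking orientation: because every row of `output` that has not been solved yet is still all zeros, the fixed-shape product used in the loop body satisfies

```latex
\[
  a_{i,:}\,\mathrm{output}
    \;=\; \sum_{j} a_{i,j}\,\mathrm{output}_{j,:}
    \;=\; \sum_{j<i} a_{i,j}\,\mathrm{output}_{j,:},
\]
```

so multiplying by the full, fixed-size `output` yields exactly the partial sum that a loop-counter-dependent slice would, at the cost of the extra FLOPs on known zeros.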
- xla::ComputationDataHandle a_row; + xla::XlaOp a_row; if (transpose_a) { TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, {zero, i}, {m, 1})); @@ -496,7 +496,9 @@ xla::StatusOr TriangularSolveLeftLooking( // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, {i, i}, {1, 1})); - auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt)); + TF_ASSIGN_OR_RETURN(auto a_elt_conj, + MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); + auto div_result = bodyb->Div(result_row, a_elt_conj); TF_ASSIGN_OR_RETURN(body_out, DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, div_result, {i, zero})); @@ -516,4 +518,130 @@ xla::StatusOr TriangularSolveLeftLooking( return builder->GetTupleElement(triangular_solve_left_looking_while, 1); } +xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); + } + + // The main computation is performed in a While loop. + xla::XlaOp output = Zeros(builder, b_shape); + + // Construct the initial loop carry tuple, + // if transpose_a: + // init = (0, output, a, b) + // else: + // init = (n-1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = builder->ConstantR0(transpose_a ? 0 : n - 1); + auto init = builder->Tuple({init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i < n if transpose_a else i >= 0 + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); + { + auto i = condb->GetTupleElement( + condb->Parameter(0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"), + 0); + if (transpose_a) { + condb->Lt(i, condb->ConstantR0(n)); + } else { + condb->Ge(i, condb->ConstantR0(0)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) + // else: + // a_row = a[..., :, i:i+1] + // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) + // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop counter. 
+ std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); + { + auto input_tuple = bodyb->Parameter( + 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = bodyb->GetTupleElement(input_tuple, 0); + auto body_out = bodyb->GetTupleElement(input_tuple, 1); + auto body_a = bodyb->GetTupleElement(input_tuple, 2); + auto body_b = bodyb->GetTupleElement(input_tuple, 3); + auto zero = bodyb->ConstantR0(0); + + // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, + // i:i+1]) But since we can't have intermediate array sizes depend on the + // loop counter, we instead exploit the fact that we initialized the output + // to all zeros and use that as zero-padding (doing unnecessary FLOPs). + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + // result = b - np.matmul(output, a) + auto result = bodyb->Sub(body_b, b_update); + // result_row = result[..., :, i:i+1] + TF_ASSIGN_OR_RETURN( + auto result_row, + DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); + + // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, i}, {1, 1})); + TF_ASSIGN_OR_RETURN(auto a_ii_conj, + MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); + auto div_result = bodyb->Div(result_row, a_ii_conj); + TF_ASSIGN_OR_RETURN(body_out, + DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, + div_result, {zero, i})); + + // if transpose_a: + // return (i + 1, body_out, a, b) + // else: + // return (i - 1, body_out, a, b) + auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? 1 : -1)); + bodyb->Tuple({next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = builder->While(cond, body, init); + return builder->GetTupleElement(triangular_solve_left_looking_while, 1); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index e32223bfdddda8..540c26b2473df9 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -57,14 +57,23 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. 
-xla::StatusOr TriangularSolve( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size = 256); +xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, + const xla::XlaOp& a, xla::XlaOp b, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + int64 block_size = 256); -xla::StatusOr TriangularSolveLeftLooking( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a); +xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a); + +xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 66170706291626..87ea4763f7c235 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -80,9 +80,9 @@ xla::Array2D AValsFull() { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -102,9 +102,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -124,9 +124,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -146,9 +146,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -168,9 +168,9 @@ XLA_TEST_F(TriangularSolveTest, 
SimpleRightUpperNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -191,9 +191,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -214,9 +214,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -237,9 +237,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -260,9 +260,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = @@ -288,9 +288,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = @@ -318,9 +318,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { } XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolveLeftLooking(&builder, a, b, @@ -340,9 +340,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { } XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", 
&builder, &b); auto result = TriangularSolveLeftLooking(&builder, a, b, diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 31d823ca336039..d9ff7e6259f3fb 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -27,15 +27,14 @@ limitations under the License. namespace tensorflow { -xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - const xla::Shape& shape) { +xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { return builder->Broadcast( builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); } -xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, double value) { +xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + double value) { switch (type) { case xla::F16: return builder->ConstantR0(static_cast(value)); @@ -57,9 +56,8 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, } } -xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, - int64 value) { +xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + int64 value) { xla::Literal literal; switch (type) { case xla::U8: @@ -112,17 +110,18 @@ xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, return builder->ConstantLiteral(literal); } -xla::StatusOr SliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - gtl::ArraySlice start, gtl::ArraySlice end) { +xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + gtl::ArraySlice start, + gtl::ArraySlice end) { TF_RET_CHECK(start.size() == end.size()); int64 n_minor_dims = start.size(); - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); + const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - n_minor_dims); @@ -140,7 +139,7 @@ xla::StatusOr SliceInMinorDims( return builder->Slice(x, padded_start, padded_end, strides); } -std::vector PrependMajorDims(xla::ComputationBuilder* builder, +std::vector PrependMajorDims(xla::XlaBuilder* builder, const gtl::ArraySlice& major_dims, const gtl::ArraySlice& indices) { std::vector output(indices.size() + major_dims.size()); @@ -149,16 +148,16 @@ std::vector PrependMajorDims(xla::ComputationBuilder* builder, return output; } -xla::StatusOr DynamicSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const std::vector& starts, +xla::StatusOr DynamicSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts, const gtl::ArraySlice& sizes) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); int64 n_minor_dims = starts.size(); TF_RET_CHECK(n_minor_dims == sizes.size()); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), /*pos=*/0, 
/*len=*/n_dims - sizes.size()); TF_ASSIGN_OR_RETURN(auto padded_starts, @@ -167,27 +166,29 @@ xla::StatusOr DynamicSliceInMinorDims( return builder->DynamicSlice(x, padded_starts, padded_sizes); } -xla::StatusOr UpdateSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start) { +xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start) { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. std::vector start_as_int32(start.begin(), start.end()); auto start_constant = builder->ConstantR1(start_as_int32); - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); - TF_ASSIGN_OR_RETURN(std::unique_ptr start_constant_shape, + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, builder->GetShape(start_constant)); const int64 start_length = - xla::ShapeUtil::GetDimension(*start_constant_shape, -1); + xla::ShapeUtil::GetDimension(start_constant_shape, -1); TF_RET_CHECK(start_length == n_dims); return builder->DynamicUpdateSlice(x, update, start_constant); } -xla::StatusOr UpdateSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); +xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); const int64 n_minor_dims = start.size(); TF_RET_CHECK(n_minor_dims <= n_dims); std::vector padded_start(n_dims, 0); @@ -196,22 +197,21 @@ xla::StatusOr UpdateSliceInMinorDims( return UpdateSlice(builder, x, update, padded_start); } -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, - const std::vector& starts) { +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, + const std::vector& starts) { TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(builder, x, starts)); return builder->DynamicUpdateSlice(x, update, padded_starts); } -xla::StatusOr PrependZerosInMajorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const std::vector& starts) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); +xla::StatusOr PrependZerosInMajorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); auto zero = builder->Reshape(builder->ConstantR0(0), {1}); - std::vector padded_starts(n_dims, zero); + std::vector padded_starts(n_dims, zero); for (int i = 0; i < starts.size(); ++i) { padded_starts[n_dims - starts.size() + i] = builder->Reshape(starts[i], {1}); @@ -219,10 +219,10 @@ xla::StatusOr PrependZerosInMajorDims( return builder->ConcatInDim(padded_starts, 0); } -xla::StatusOr TransposeInMinorDims( - xla::ComputationBuilder* 
builder, const xla::ComputationDataHandle& x) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); +xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_dims >= 2); std::vector permutation(n_dims); std::iota(permutation.begin(), permutation.end(), 0); @@ -230,4 +230,11 @@ xla::StatusOr TransposeInMinorDims( return builder->Transpose(x, permutation); } +xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, + const xla::XlaOp& x, bool conjugate) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? builder->Conj(x) : x; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index b684123f1363cf..3c120a2548576d 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,75 +16,79 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { // Returns a zero-filled tensor with shape `shape`. -xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - const xla::Shape& shape); +xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape); // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. -xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, double value); +xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + double value); // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. -xla::ComputationDataHandle PrependZerosInMajorDims( - xla::ComputationBuilder* builder, - gtl::ArraySlice starts); +xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, + gtl::ArraySlice starts); // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. -xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, int64 value); +xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + int64 value); // Builds a vector of zeros of length rank(x) with the last two values being // those in `starts`. -xla::StatusOr PrependZerosInMajorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const std::vector& starts); +xla::StatusOr PrependZerosInMajorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts); // Performs a slice in the minor dimensions of a Tensor. 
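Before the declarations continue, a short hypothetical usage sketch of the minor-dims convention these helpers share; the function name and shapes are illustrative only, and `x` is assumed to be a rank-3 value of shape [batch, m, n] built on the same `builder`.

```cpp
// Hypothetical sketch: slice x[..., 0:2, 0:3] with SliceInMinorDims. `start`
// and `end` name only the two minor (trailing) dimensions; every major
// (batch) dimension is taken whole, matching the implementation in util.cc.
xla::StatusOr<xla::XlaOp> SliceLeadingBlock(xla::XlaBuilder* builder,
                                            const xla::XlaOp& x) {
  return SliceInMinorDims(builder, x, /*start=*/{0, 0}, /*end=*/{2, 3});
}
```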
-xla::StatusOr SliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - gtl::ArraySlice start, gtl::ArraySlice end); +xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + gtl::ArraySlice start, + gtl::ArraySlice end); // Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. -std::vector PrependMajorDims(xla::ComputationBuilder* builder, +std::vector PrependMajorDims(xla::XlaBuilder* builder, const gtl::ArraySlice& major_dims, const gtl::ArraySlice& indices); // Performs a dynamic slice in the minor dimensions of a Tensor. -xla::StatusOr DynamicSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const std::vector& starts, - const gtl::ArraySlice& sizes); +xla::StatusOr DynamicSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts, const gtl::ArraySlice& sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update -xla::StatusOr UpdateSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start); +xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update -xla::StatusOr UpdateSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start); +xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start); -xla::StatusOr DynamicUpdateSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, - const std::vector& starts); +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, + const std::vector& starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::StatusOr TransposeInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x); +xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x); + +// Applies a complex conjugation operation if `a` is complex and `conjugate_a` +// is true, otherwise returns its argument. +xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, + const xla::XlaOp& x, bool conjugate); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc index b6bd33af2e42a4..265b39402c832f 100644 --- a/tensorflow/compiler/tf2xla/lib/util_test.cc +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -65,9 +64,9 @@ xla::Array3D BatchedAValsFull() { } XLA_TEST_F(UtilTest, Simple2dLookup) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, x, y; + xla::XlaOp a, x, y; auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); @@ -80,9 +79,9 @@ XLA_TEST_F(UtilTest, Simple2dLookup) { } XLA_TEST_F(UtilTest, Simple3dLookup) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, index; + xla::XlaOp a, index; auto a_data = CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); @@ -97,9 +96,9 @@ XLA_TEST_F(UtilTest, Simple3dLookup) { } XLA_TEST_F(UtilTest, SimpleSliceUpdate) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b, x, y; + xla::XlaOp a, b, x, y; auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter({{9, 1, -10}}, 1, "b", &builder, &b); auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); @@ -117,11 +116,11 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) { } XLA_TEST_F(UtilTest, RowBatchDot) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); int n = 4; - xla::ComputationDataHandle a, row, index; + xla::XlaOp a, row, index; auto a_data = CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 495d9c60780b0a..09ce594930efc0 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -20,24 +20,24 @@ limitations under the License. namespace tensorflow { -xla::StatusOr> XlaWhileLoop( +xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder) { + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::ComputationDataHandle& input : initial_values) { + for (const xla::XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); - var_shapes.push_back(std::move(*shape)); + var_shapes.push_back(std::move(shape)); } xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, - xla::ComputationBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](xla::XlaOp tuple, int arity, + xla::XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { elements[i] = builder->GetTupleElement(tuple, i); } @@ -45,20 +45,20 @@ xla::StatusOr> XlaWhileLoop( }; // Build the condition. 
- std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(strings::StrCat(name, "_condition")); { auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); - TF_ASSIGN_OR_RETURN( - auto result, + TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), - cond_builder.get())); + cond_builder.get()) + .status()); } TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. - std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(strings::StrCat(name, "_body")); { auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); @@ -78,38 +78,38 @@ xla::StatusOr> XlaWhileLoop( return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( +xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder) { - auto while_cond_fn = [&](gtl::ArraySlice values, - xla::ComputationBuilder* cond_builder) - -> xla::StatusOr { + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder) { + auto while_cond_fn = + [&](gtl::ArraySlice values, + xla::XlaBuilder* cond_builder) -> xla::StatusOr { return cond_builder->Lt( values[0], IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); }; - auto while_body_fn = [&](gtl::ArraySlice values, - xla::ComputationBuilder* body_builder) - -> xla::StatusOr> { - xla::ComputationDataHandle iteration = values[0]; + auto while_body_fn = [&](gtl::ArraySlice values, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); updated_values.push_back(body_builder->Add( iteration, body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); values.push_back( builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 2e67a0c99b6deb..5b6684c995889e 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -29,14 +29,14 @@ namespace tensorflow { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function( - gtl::ArraySlice, xla::ComputationBuilder*)> +typedef std::function(gtl::ArraySlice, + xla::XlaBuilder*)> LoopConditionFunction; // Function that builds a loop body. 
Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - gtl::ArraySlice, xla::ComputationBuilder*)> +typedef std::function>( + gtl::ArraySlice, xla::XlaBuilder*)> LoopBodyFunction; // Helper function for building an XLA while loop, where the values carried by @@ -47,27 +47,26 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. -xla::StatusOr> XlaWhileLoop( +xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder); + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::ComputationDataHandle, gtl::ArraySlice, - xla::ComputationBuilder*)> +typedef std::function>( + xla::XlaOp, gtl::ArraySlice, xla::XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( +xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder); + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2c3cd658e04623..43e1c1e9fecec1 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,7 +40,7 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && xla::ShapeUtil::ElementsIn(literal.shape()) == @@ -63,8 +63,8 @@ Status CopyLiteralToHostTensor(const xla::Literal& literal, return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor) { TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index f283b0236811f8..220bec15538c36 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -36,13 +36,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); // derivable from the type of , because multiple tensorflow types map // to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in // XLA). -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor); +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor); // Copies the contents of 'literal' to a previously allocated tensor // 'host_tensor'. The tensor and the literal must have the same number of // elements and the same type. 
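Returning to the loop helpers declared in while_loop.h above, the sketch below shows one hypothetical way XlaForEachIndex might be called; only XlaForEachIndex, IntegerLiteral, and the XlaBuilder methods come from this change, everything else is an illustrative name.

```cpp
// Hypothetical sketch: sum the indices 0..9 into a single loop-carried S32
// value using XlaForEachIndex from while_loop.h and IntegerLiteral from
// util.h.
xla::StatusOr<xla::XlaOp> SumFirstTenIndices(xla::XlaBuilder* builder) {
  auto body = [](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> values,
                 xla::XlaBuilder* b) -> xla::StatusOr<std::vector<xla::XlaOp>> {
    // values[0] carries the running sum; add the current index to it.
    return std::vector<xla::XlaOp>{b->Add(values[0], i)};
  };
  std::vector<xla::XlaOp> init = {IntegerLiteral(builder, xla::S32, 0)};
  TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> results,
                      XlaForEachIndex(/*num_iterations=*/10, xla::S32, body,
                                      init, "sum_first_ten", builder));
  return results[0];
}
```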
-Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 6051d7dffd7493..ac768b206e2a8d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -251,7 +251,7 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, - xla::Computation* computation) { + xla::XlaComputation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - compiler_options.device_type = &device_type; + compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; @@ -303,7 +302,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, } // InitGraph creates a graph based on the graph_def, that may then be converted -// to an xla::Computation via ConvertGraphToXla. +// to an xla::XlaComputation via ConvertGraphToXla. // // The graph is rewritten with _Arg and _Retval nodes, representing the inputs // and outputs of the function that will be compiled. Each feed id causes a new @@ -348,7 +347,7 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation) { + xla::XlaComputation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index 473c431b12d441..d02fc56c5b8f58 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -18,21 +18,21 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/core/framework/graph.pb.h" namespace tensorflow { -// Converts a tensorflow::GraphDef into an xla::Computation. The given `config` -// specifies the portion of the graph to convert, via feeds and fetches. Each -// feed is a positional input argument for the generated computation, while each -// fetch is a positional output argument. +// Converts a tensorflow::GraphDef into an xla::XlaComputation. The given +// `config` specifies the portion of the graph to convert, via feeds and +// fetches. Each feed is a positional input argument for the generated +// computation, while each fetch is a positional output argument. // // The computation is built in the context of the given `client`, which may // subsequently be used to compile or execute the computation. 
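A minimal calling sketch for the entry point described above, mirroring how tf2xla_test.cc (further down in this change) drives it; `graph_def` and `config` are assumed to be populated by the caller, and the wrapper name is illustrative only.

```cpp
// Minimal sketch: build an xla::XlaComputation from a GraphDef. The client
// obtained here can later be used to compile or execute the result.
Status BuildComputationFromGraphDef(const GraphDef& graph_def,
                                    const tf2xla::Config& config,
                                    xla::XlaComputation* computation) {
  xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
  return ConvertGraphDefToXla(graph_def, config, client, computation);
}
```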
Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation); + xla::XlaComputation* computation); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index b813668a9edd3a..84c133ffabe20d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -69,7 +69,7 @@ TEST(ConvertGraphDefToXla, Sum) { tf2xla::Config config = SumConfig(); xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); - xla::Computation computation; + xla::XlaComputation computation; TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 7ec85aa3cdec62..9203e8d9e607e9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -232,7 +232,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Push input nodes of the currently visited node to name_queue. for (const string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = id.first.ToString(); + const string node_name = std::string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index fcb0a4e63814b4..fe7ec633eca250 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" @@ -108,7 +109,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, // If no sharding metadata is found, XLA is free to use whatever device it // wants. In practice this usually has the effect of placing things on device // 0. - xla::ScopedShardingAssignment assign_sharding(b, op_sharding); + xla::XlaScopedShardingAssignment assign_sharding(b, op_sharding); op_kernel->Compute(context); b->ClearOpMetadata(); @@ -126,9 +127,7 @@ Status XlaCompilationDevice::MakeTensorFromProto( XlaExpression::XlaExpression() = default; -void XlaExpression::set_handle(const xla::ComputationDataHandle& h) { - handle_ = h; -} +void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; } void XlaExpression::set_constant_value(Tensor value) { has_constant_value_ = true; diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 0243ee332fbdca..d0b9e34e162f34 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -19,7 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" @@ -69,7 +69,7 @@ class XlaCompilationDevice : public LocalDevice { // A XlaExpression wraps an XLA computation. Each Tensor on an // XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor -// matches the shape of the subcomputation in the ComputationDataHandle. Each +// matches the shape of the subcomputation in the XlaOp. Each // expression is either a constant, or a function of previously-compiled // expressions. class XlaExpression { @@ -78,8 +78,8 @@ class XlaExpression { // handle() stores the XLA handle of the computation that the // expression represents. - void set_handle(const xla::ComputationDataHandle& h); - const xla::ComputationDataHandle& handle() const { return handle_; } + void set_handle(const xla::XlaOp& h); + const xla::XlaOp& handle() const { return handle_; } void set_constant_value(Tensor value); bool has_constant_value() const { return has_constant_value_; } @@ -90,7 +90,7 @@ class XlaExpression { private: // The XLA handle of the expression's computation. - xla::ComputationDataHandle handle_; + xla::XlaOp handle_; // If this expression is a constant with a known value, 'constant_value' is a // host-memory Tensor containing the value. Used to avoid invoking XLA for diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index c0e996768491a6..a8bd199675b4ad 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,10 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include #include +#include -#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -28,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -40,7 +38,6 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -86,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), - device_( - new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_({device_}) { - // We no longer need the device_type. 
- options_.device_type = nullptr; - + CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); @@ -110,10 +104,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); - // The default variable representation shape is the identity function. - if (!options_.variable_representation_shape_fn) { - options_.variable_representation_shape_fn = - [](const TensorShape& shape, DataType type) { return shape; }; + // The default shape representation function is the identity. + if (!options_.shape_representation_fn) { + options_.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return shape; }; } } @@ -230,20 +224,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, - xla::Shape* xla_shape) { + bool is_entry_computation, + xla::Shape* xla_shape) const { switch (arg.kind) { case XlaCompiler::Argument::kConstant: - return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), - xla_shape); - case XlaCompiler::Argument::kParameter: - return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + LOG(FATAL) << "Unreachable case"; + case XlaCompiler::Argument::kParameter: { + TensorShape shape = + is_entry_computation + ? options_.shape_representation_fn(arg.shape, arg.type) + : arg.shape; + return TensorShapeToXLAShape(arg.type, shape, xla_shape); + } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { TensorShape representation_shape = - options_.variable_representation_shape_fn(arg.shape, arg.type); + options_.shape_representation_fn(arg.shape, arg.type); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -337,16 +336,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, Status BuildComputation( const std::vector& args, const std::vector& arg_cores, - const std::vector& retvals, + const std::vector& retvals, const std::vector>& resources, - bool return_updated_values_for_all_resources, - xla::ComputationBuilder* builder, xla::Computation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, + bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, + xla::XlaComputation* computation, int* num_computation_outputs, + int* num_nonconst_outputs, + std::vector* outputs, std::vector* resource_updates) { - std::vector elems; + std::vector elems; elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + output.type = retvals[i].type; + output.shape = retvals[i].shape; + const XlaExpression& retval = retvals[i].expression; + if (retval.has_constant_value()) { + output.is_constant = true; + output.constant_value = retval.constant_value(); + } else { + output.is_constant = false; elems.push_back(retval.handle()); } } @@ -376,14 +384,12 @@ Status BuildComputation( const XlaCompiler::Argument& arg = args[resource->arg_num()]; const int core = arg_cores[resource->arg_num()]; DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = - resource->value().handle() != 
resource->initial_value().handle(); + bool modified = resource->value() != resource->initial_value(); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. for (const auto& grad : resource->tensor_array_gradients()) { modified = modified || - grad.second->value().handle() != - grad.second->initial_value().handle() || + grad.second->value() != grad.second->initial_value() || arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { @@ -398,11 +404,11 @@ Status BuildComputation( } // Request that the value be returned on a specific core. - xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - xla::ComputationDataHandle handle; + xla::XlaOp handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); // Since we can't change the sharding metadata of as this point, @@ -421,7 +427,7 @@ Status BuildComputation( builder->Tuple(elems); builder->ClearOpMetadata(); - xla::StatusOr computation_status = builder->Build(); + xla::StatusOr computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status(); } @@ -435,7 +441,7 @@ Status BuildComputation( // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context, + bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, bool is_entry_computation) { @@ -461,8 +467,7 @@ Status XlaCompiler::BuildArguments( // alias. XlaResource* resource; TF_RETURN_IF_ERROR(context->CreateResource( - arg.resource_kind, i, arg.name, arg.type, arg.shape, - xla::ComputationDataHandle(), + arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), /*tensor_array_size=*/arg.tensor_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); @@ -493,8 +498,8 @@ Status XlaCompiler::BuildArguments( std::vector arg_shapes(input_mapping->size()); for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR( - XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); + TF_RETURN_IF_ERROR(XLAShapeForArgument( + args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -531,9 +536,9 @@ Status XlaCompiler::BuildArguments( builder->SetOpMetadata(arg_metadata); // Build parameter handles for non-constant arguments. - std::vector arg_handles(input_mapping->size()); + std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { - xla::ComputationDataHandle tuple; + xla::XlaOp tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); @@ -544,15 +549,15 @@ Status XlaCompiler::BuildArguments( core == -1 ? 
xla::sharding_builder::AssignDevice(root_device) : xla::sharding_builder::AssignDevice(core); } - xla::ScopedShardingAssignment assign_tuple_sharding(builder, - tuple_sharding); + xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, + tuple_sharding); tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } else { tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; - xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); @@ -560,7 +565,7 @@ Status XlaCompiler::BuildArguments( } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; - xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = @@ -570,7 +575,8 @@ Status XlaCompiler::BuildArguments( builder->ClearOpMetadata(); - // Fill in the handles in non-constant arguments. + // Fill in the handles in non-constant arguments, and reshape parameters + // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; @@ -589,7 +595,15 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kParameter: - arg_expression.set_handle(arg_handles[i]); + // Reshape parameters back to their correct shapes. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + if (is_entry_computation) { + arg_expression.set_handle( + builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + } else { + arg_expression.set_handle(arg_handles[i]); + } break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: @@ -642,12 +656,71 @@ Status XlaCompiler::CompileSingleOp( return CompileGraph(options, name, std::move(graph), args, result); } +namespace { + +// Check that the ops of all non-functional nodes have been registered. +string ValidateFunctionDef(const FunctionDef* fdef, + const FunctionLibraryDefinition& flib_def) { + std::vector invalid_ops; + for (const NodeDef& node : fdef->node_def()) { + const string& op = node.op(); + if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { + invalid_ops.push_back(op); + } + } + return tensorflow::str_util::Join(invalid_ops, ", "); +} + +// Check that the graph doesn't have any invalid nodes (e.g. incompatible with +// given device_type, invalid data type, missing attributes...) 
+Status ValidateGraph(const Graph* graph, + const FunctionLibraryDefinition& flib_def, + const DeviceType& device_type, const string& name) { + std::vector invalid_ops; + for (const Node* node : graph->nodes()) { + if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { + continue; + } + const FunctionDef* fdef = flib_def.Find(node->def().op()); + if (fdef) { + string error_msg = ValidateFunctionDef(fdef, flib_def); + if (!error_msg.empty()) { + invalid_ops.push_back( + strings::StrCat(node->def().op(), ":{", error_msg, "}")); + } + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { + invalid_ops.push_back(node->def().op()); + continue; + } + TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); + if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { + invalid_ops.push_back(node->def().op()); + } + } + if (!invalid_ops.empty()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ":", + tensorflow::str_util::Join(invalid_ops, ", "))); + } + return Status::OK(); +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, const std::vector& args, CompilationResult* result) { - VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; + VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " @@ -661,13 +734,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Converts Tensorflow's graph control-flow constructs into functional // control-flow that can be compiled into XLA code. TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); - - xla::ComputationBuilder builder(client(), name); - XlaContext* context = - new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, - &options_.variable_representation_shape_fn); + FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), + graph.get(), local_flib_def_.get())); + + // Detect invalid nodes. + // FunctionalizeControlFlow may remove some nodes from the graph. 
+ TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, + options_.device_type, name)); + + xla::XlaBuilder builder(name); + XlaContext* context = new XlaContext( + this, &builder, options_.allow_cpu_custom_calls, + options.resolve_compile_time_constants, options.is_entry_computation, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); std::vector arg_expressions; @@ -683,36 +762,23 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; int num_computation_outputs; - result->computation = std::make_shared(); + result->computation = std::make_shared(); + result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->resource_updates)); + &num_nonconst_outputs, &result->outputs, &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - result->outputs.resize(context->retvals().size()); - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (retval.has_constant_value()) { - OutputDescription& output = result->outputs[i]; - output.shape = retval.constant_value().shape(); - output.is_constant = true; - output.constant_value = retval.constant_value(); - } - } - // Compute the output shapes, if there is a computation with non-constant + // Compute the XLA output shape, if there is a computation with non-constant // outputs. - auto computation_shape = client()->GetComputationShape(*result->computation); - if (!computation_shape.ok()) { - return computation_shape.status(); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, + client()->GetComputationShape(*result->computation)); - result->xla_output_shape.Swap( - computation_shape.ValueOrDie()->mutable_result()); + result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); @@ -727,23 +793,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - // Converts the output shapes to TensorShapes. 
- int computation_output = 0; - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (!retval.has_constant_value()) { - TF_RET_CHECK(computation_output < num_computation_outputs) - << "Computation has more outputs than expected"; - OutputDescription& output = result->outputs[i]; - output.is_constant = false; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape( - xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, - computation_output), - &output.shape)); - ++computation_output; - } - } return Status::OK(); } @@ -814,7 +863,7 @@ Status XlaCompiler::SetHostToDeviceMetadata( } Status XlaCompiler::GetHostComputeControlDependency( - const string& host_compute_name, xla::ComputationDataHandle* handle) { + const string& host_compute_name, xla::XlaOp* handle) { const auto iter = host_compute_control_output_.find(host_compute_name); if (iter == host_compute_control_output_.end()) { return errors::InvalidArgument( @@ -827,7 +876,7 @@ Status XlaCompiler::GetHostComputeControlDependency( } Status XlaCompiler::SetHostComputeControlDependency( - const string& host_compute_name, const xla::ComputationDataHandle& handle) { + const string& host_compute_name, const xla::XlaOp& handle) { if (host_compute_control_output_.find(host_compute_name) != host_compute_control_output_.end()) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 8f564f35ec8176..c93850ce270502 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -38,7 +39,7 @@ class XlaContext; // It does a symbolic execution of the graph starting from specific input // shapes, using a JIT device to convert operators into XLA computations. // -// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the +// XlaCompiler is typically invoked from an `XlaLaunch` operator once the // shapes of all input parameters to the computation are known. This is // because the symbolic execution requires known shapes for all operations. // @@ -67,6 +68,15 @@ class XlaContext; // _Retval values are ordered by _Retval index, whereas kResource values are // ordered by the original _Arg position of the variable. // +// If a shape representation function is provided as part of +// XlaCompiler::CompileOptions, kParameter arguments and return values to an +// entry computation will be reshaped in accordance to the shape function. +// Arguments and return values to a non-entry computation are not reshaped. +// Variable resource arguments are passed and returned in reshaped form, even +// for non-entry computations. This feature allows TensorFlow to keep on-device +// tensors with a different shape to their representation inside the XLA +// computation. +// // In both inputs and outputs, kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has // identical input and output signatures. 
By moving variable values @@ -171,7 +181,7 @@ class XlaCompiler { }; struct OutputDescription { - // Type and shape of the output. + // Type and shape of the output. The shape is the unflattened shape. DataType type; TensorShape shape; @@ -206,10 +216,12 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Input shapes of the computation. + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. std::vector xla_input_shapes; - // Output shape in XLA format. The output shape is always a tuple. + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. xla::Shape xla_output_shape; // TensorFlow shapes of outputs, together with the values of any @@ -227,13 +239,15 @@ class XlaCompiler { std::vector resource_updates; // The XLA computation built from the tensorflow subgraph. - std::shared_ptr computation; + std::shared_ptr computation; }; + typedef std::function + ShapeRepresentationFn; struct Options { - // Name of the compilation device to use. Needs to be live only during - // XlaCompiler's constructor. - const DeviceType* device_type = nullptr; + // Name of the compilation device to use. It must be set by the caller. + // The default empty value is invalid. + DeviceType device_type = DeviceType(""); xla::Client* client = nullptr; @@ -250,8 +264,7 @@ class XlaCompiler { // If set, the XLA representation of variables represented to XLA as the // shape given by this shape function. Variables are reshaped to this shape // on write, and reshaped to their original shape on read. - std::function - variable_representation_shape_fn; + ShapeRepresentationFn shape_representation_fn; // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation @@ -281,7 +294,7 @@ class XlaCompiler { const NameAttrList& fn_name_attrs, std::vector args, CompilationResult* result); - // Compiles a tensorflow::Graph into an xla::Computation. + // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. Status CompileGraph(const CompileOptions& options, string const& name, @@ -290,7 +303,7 @@ class XlaCompiler { CompilationResult* result); // Compiles a single Op, given by an OpKernelContext, into an - // xla::Computation. Similar to CompileFunction but takes a single Op as + // xla::XlaComputation. Similar to CompileFunction but takes a single Op as // input. Status CompileSingleOp(const CompileOptions& options, string const& name, OpKernelContext* ctx, @@ -300,7 +313,8 @@ class XlaCompiler { // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, + xla::Shape* xla_shape) const; // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. @@ -337,10 +351,9 @@ class XlaCompiler { // a given HostCompute Op as long as the names are unique within the // compilation. 
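Looping back to the ShapeRepresentationFn added to XlaCompiler::Options above, the sketch below shows a hypothetical way a caller might install one; the identity default in xla_compiler.cc simply returns the incoming shape, and the rounding rule here is invented purely for illustration.

```cpp
// Hypothetical example of a shape representation function: keep the identity
// behavior for everything except float tensors, whose minor-most dimension is
// rounded up to a multiple of 8 for the on-device layout. The rounding rule
// is invented for illustration only.
XlaCompiler::Options MakePaddedFloatOptions(xla::Client* client) {
  XlaCompiler::Options options;
  options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
  options.client = client;
  options.shape_representation_fn = [](const TensorShape& shape,
                                       DataType type) {
    TensorShape on_device_shape = shape;
    if (type == DT_FLOAT && shape.dims() > 0) {
      const int64 last = shape.dim_size(shape.dims() - 1);
      on_device_shape.set_dim(shape.dims() - 1, (last + 7) / 8 * 8);
    }
    return on_device_shape;
  };
  return options;
}
```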
Status GetHostComputeControlDependency(const string& host_compute_name, - xla::ComputationDataHandle* handle); - Status SetHostComputeControlDependency( - const string& host_compute_name, - const xla::ComputationDataHandle& handle); + xla::XlaOp* handle); + Status SetHostComputeControlDependency(const string& host_compute_name, + const xla::XlaOp& handle); const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } @@ -358,7 +371,7 @@ class XlaCompiler { // `args` are the arguments to the computation. Status BuildArguments(const Graph& graph, const std::vector& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, + bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, @@ -408,8 +421,7 @@ class XlaCompiler { std::unordered_map host_compute_sends_; std::unordered_map host_compute_recvs_; - std::unordered_map - host_compute_control_output_; + std::unordered_map host_compute_control_output_; TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 096dc7160bfc0a..5fbf4b952c6e6f 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -25,12 +25,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -43,8 +45,6 @@ namespace tensorflow { class XlaCompilerTest : public ::testing::Test { protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -56,7 +56,7 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = &cpu_device_type_; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); options.client = client_; options.flib_def = flib_def_.get(); return options; @@ -66,7 +66,6 @@ class XlaCompilerTest : public ::testing::Test { return compiler->local_flib_def_.get(); } - DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -164,7 +163,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT), REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT), DummyDuplicateOp); - // Tests compilation and execution of an empty graph. 
TEST_F(XlaCompilerTest, EmptyReturnValues) { XlaCompiler compiler(DefaultOptions()); @@ -226,7 +224,7 @@ TEST_F(XlaCompilerTest, Simple) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { @@ -321,7 +319,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE( + xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } { @@ -356,10 +355,80 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } +TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) { + // Define a function with one compile-time constant output and one + // data-dependent output. + // @function.Defun(noinline=True) + // foo(a) {b=7; return b, a; } + const Tensor seven = test::AsScalar(7); + FunctionDef fdef = FunctionDefHelper::Create( + "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {}, + { + {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}}, + }, + {{"a", "a_0"}, {"const", "Const:output:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0); + NodeDef foo; + foo.set_name("foo"); + foo.set_op("foo"); + *foo.add_input() = "input_arg"; + Status status; + scope.graph()->AddNode(foo, &status); + TF_ASSERT_OK(status); + NodeDef retval_1; + retval_1.set_name("retval_0"); + retval_1.set_op(FunctionLibraryDefinition::kRetOp); + *retval_1.add_input() = "foo"; + (*retval_1.mutable_attr())["T"].set_type(DT_INT32); + (*retval_1.mutable_attr())["index"].set_i(0); + scope.graph()->AddNode(retval_1, &status); + TF_ASSERT_OK(status); + NodeDef retval_2; + retval_2.set_name("retval_1"); + retval_2.set_op(FunctionLibraryDefinition::kRetOp); + *retval_2.add_input() = "foo:1"; + (*retval_2.mutable_attr())["T"].set_type(DT_INT32); + (*retval_2.mutable_attr())["index"].set_i(1); + scope.graph()->AddNode(retval_2, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + // Builds a description of the arguments. 
+ std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1}); + + XlaCompiler::Options options = DefaultOptions(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + options.flib_def = &flib_def; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = true; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph), args, &result)); + + ASSERT_EQ(2, result.outputs.size()); + EXPECT_TRUE(result.outputs[0].is_constant); + test::ExpectTensorEqual(result.outputs[0].constant_value, + test::AsScalar(7)); + EXPECT_FALSE(result.outputs[1].is_constant); +} + // Tests compilation and execution of a graph that adds two tensors. TEST_F(XlaCompilerTest, ResourceManager) { // Builds a graph that calls the dummy resource Op. @@ -433,21 +502,26 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { } for (int64 i = 1; i < test_count; ++i) { - auto m1 = - results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests(); - auto m2 = - results[i].computation->Snapshot().ValueOrDie()->entry().requests(); - // Check if every entry is the same. - for (auto& entry1 : m1) { - int64 key = entry1.first; - auto value1 = entry1.second; - auto entry2 = m2.find(key); - auto value2 = entry2->second; - EXPECT_TRUE(entry2 != m2.end()); - string str1, str2; - value1.AppendToString(&str1); - value2.AppendToString(&str2); - EXPECT_EQ(str1, str2); + const auto& m1 = results[i - 1].computation->proto(); + const auto& m2 = results[i].computation->proto(); + ASSERT_EQ(m1.computations_size(), m2.computations_size()); + // Check if every hlo computation is the same. + for (int k = 0; k < m1.computations_size(); k++) { + const auto& c1 = m1.computations(k); + const auto& c2 = m2.computations(k); + ASSERT_EQ(c1.instructions_size(), c2.instructions_size()); + for (int j = 0; j < c1.instructions_size(); j++) { + auto instr1 = c1.instructions(j); + auto instr2 = c2.instructions(j); + instr1.clear_name(); + instr2.clear_name(); + // The names of instructions were uniquified by the XlaBuilder, the rest + // of the fields should be identical. + string str1, str2; + instr1.AppendPartialToString(&str1); + instr2.AppendPartialToString(&str2); + EXPECT_EQ(str1, str2); + } } } } @@ -519,7 +593,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -742,13 +816,10 @@ TEST_F(XlaCompilerTest, Variables) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -// Tests a simple graph that reads and writes a variable, with a -// variable_representation_shape_fn passed to the compiler that flattens all -// variable tensors to vectors. 
-TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { +xla::StatusOr> BuildTestGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); @@ -759,7 +830,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_ASSERT_OK(scope.ToGraph(graph.get())); + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + return std::move(graph); +} + +// Tests a simple graph that reads and writes a variable, with a +// shape_representation_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); // Builds a description of the arguments. std::vector args(2); @@ -774,15 +853,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.variable_representation_shape_fn = [](const TensorShape& shape, - DataType type) { + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return TensorShape({shape.num_elements()}); }; XlaCompiler compiler(options); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; // Only reshape variables. + XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE( + xla::ShapeUtil::Compatible(program_shape->parameters(0), + xla::ShapeUtil::MakeShape(xla::S32, {2, 2}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. std::unique_ptr param0_literal = @@ -807,7 +904,149 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::Literal::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + +TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. 
+ XlaCompiler::Options options = DefaultOptions(); + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; // Reshape args and retvals. + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {4}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::Literal::CreateR1({4, 55, 1, -3}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({22, 11, 33, 404}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR1({27, 67, 35, 402}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({26, 66, 34, 401}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + +// Tests a graph which has a function with an invalid op. 
+TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + FunctionDef fn = FillFn(); + NodeDef* node = fn.add_node_def(); + node->set_name("Invalid"); + node->set_op("InvalidOp"); /* unsupported op */ + node = fn.add_node_def(); + node->set_name("Switch"); + node->set_op("Switch"); /* control flow node */ + *flib.add_function() = fn; + + TF_ASSERT_OK(flib_def_->AddFunctionDef(fn)); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + std::vector args; + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + << status.error_message(); +} + +// Tests a graph which has a node with invalid data type. +TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef shape; + shape.set_name("Shape"); + shape.set_op("Shape"); + (*shape.mutable_attr())["T"].set_type(DT_INT32); + (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */ + Status status; + Node* shape_node = graph->AddNode(shape, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(graph->source_node(), shape_node); + + std::vector args; + XlaCompiler::CompilationResult result; + XlaCompiler compiler(DefaultOptions()); + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "is not in the list of allowed values")) + << status.error_message(); } } // namespace diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 8423921086fec1..098072d33cd4eb 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,28 +63,32 @@ void XlaContext::set_args(std::vector args) { } XlaContext::XlaContext( - XlaCompiler* compiler, xla::ComputationBuilder* builder, + XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn) + shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants), - variable_representation_shape_fn_(variable_representation_shape_fn) {} + is_entry_computation_(is_entry_computation), + shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. void XlaContext::AddRetval(int retval_index, DataType type, - const xla::ComputationDataHandle& handle) { + const TensorShape& shape, const xla::XlaOp& handle) { VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; // Add the return value to the list being built up. if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - retvals_[retval_index].set_handle(handle); + XlaExpression e; + e.set_handle(handle); + retvals_[retval_index] = Retval{type, shape, e}; } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, @@ -94,23 +98,20 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - if (resolve_compile_time_constants_) { - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - retvals_[retval_index].set_constant_value(std::move(value)); - } else { - retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal)); - } + Tensor value; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); + XlaExpression e; + e.set_constant_value(value); + retvals_[retval_index] = Retval{dtype, value.shape(), e}; return Status::OK(); } -xla::ComputationBuilder* XlaContext::builder() { return builder_; } +xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( XlaResource::Kind kind, int arg_num, string name, DataType type, - TensorShape shape, const xla::ComputationDataHandle& handle, - int64 tensor_array_size, const std::set& tensor_array_gradients, - XlaResource** resource) { + TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, + const std::set& tensor_array_gradients, XlaResource** resource) { resources_.emplace_back( new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), handle, tensor_array_size, tensor_array_gradients)); @@ -118,16 +119,16 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, - DataType type) const { - return (*variable_representation_shape_fn_)(shape, type); +TensorShape XlaContext::RepresentationShape(const TensorShape& shape, 
+ DataType type) const { + return (*shape_representation_fn_)(shape, type); } -const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { return LookupOrCreate(type, &max_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Max() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">"); + xla::XlaBuilder b("max<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -137,11 +138,11 @@ const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { return LookupOrCreate(type, &min_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Min() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "min<" + type_string + ">"); + xla::XlaBuilder b("min<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -151,11 +152,11 @@ const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { return LookupOrCreate(type, &add_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Add() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">"); + xla::XlaBuilder b("add<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -165,11 +166,11 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { return LookupOrCreate(type, &mul_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Mul() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); + xla::XlaBuilder b("mul<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -179,9 +180,9 @@ const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { }); } -const xla::Computation* XlaContext::LookupOrCreate( +const xla::XlaComputation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, - const std::function& create) { + const std::function& create) { { const auto& entry = (*out)[type]; if (!entry.IsNull()) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 00fbaba37c5429..341bf6ff1f37fa 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -22,8 +22,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -42,32 +42,44 @@ class XlaContext : public ResourceBase { static XlaContext& Get(const OpKernelContext* ctx); static XlaContext& Get(const XlaOpKernelContext* ctx); - // Creates a new XlaContext. - XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, + // Creates a new XlaContext. See the documentation on the class data fields + // for descriptions of the arguments. + XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn); + shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; XlaCompiler* compiler() const { return compiler_; } - // Returns the ComputationBuilder that Ops use for compiling new - // expressions. - xla::ComputationBuilder* builder(); + // Returns the XlaBuilder that Ops use for compiling new expressions. + xla::XlaBuilder* builder(); bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } + bool resolve_compile_time_constants() const { + return resolve_compile_time_constants_; + } + bool is_entry_computation() const { return is_entry_computation_; } + const std::vector& args() const { return args_; } void set_args(std::vector args); - const std::vector& retvals() { return retvals_; } + struct Retval { + DataType type; + TensorShape shape; + // An XlaExpression representing the Retval's value. + XlaExpression expression; + }; + const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, - const xla::ComputationDataHandle& handle); + void AddRetval(int retval_index, DataType type, const TensorShape& shape, + const xla::XlaOp& handle); // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, @@ -79,8 +91,7 @@ class XlaContext : public ResourceBase { // Fails if the resource already exists. Status CreateResource(XlaResource::Kind kind, int arg_num, string name, DataType type, TensorShape shape, - const xla::ComputationDataHandle& handle, - int64 tensor_array_size, + const xla::XlaOp& handle, int64 tensor_array_size, const std::set& tensor_array_gradients, XlaResource** resource); @@ -89,29 +100,29 @@ class XlaContext : public ResourceBase { } // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`. - TensorShape VariableRepresentationShape(const TensorShape& shape, - DataType type) const; + // and `type`, or of an argument or return value of a top-level computation. + TensorShape RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. 
- const xla::Computation* GetOrCreateMax(const DataType type); + const xla::XlaComputation* GetOrCreateMax(const DataType type); // Get an XLA lambda to compute Min. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMin(const DataType type); + const xla::XlaComputation* GetOrCreateMin(const DataType type); // Get an XLA lambda to compute Add. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateAdd(const DataType type); + const xla::XlaComputation* GetOrCreateAdd(const DataType type); // Get an XLA lambda to compute Mul. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMul(const DataType type); + const xla::XlaComputation* GetOrCreateMul(const DataType type); // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -119,9 +130,8 @@ class XlaContext : public ResourceBase { private: XlaCompiler* const compiler_; - // The ComputationBuilder used to construct the subgraph's compiled - // representation. - xla::ComputationBuilder* builder_; + // The XlaBuilder used to construct the subgraph's compiled representation. + xla::XlaBuilder* builder_; // Allow ops to emit CustomCall operations for CPU. const bool allow_cpu_custom_calls_; @@ -135,25 +145,33 @@ class XlaContext : public ResourceBase { std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // A function that describes how variable shapes should be represented - // in XLA. Variable values will be reshaped to this shape. Must be non-null. + // Is this a top-level computation, or an inner computation (e.g., a while + // body)? + const bool is_entry_computation_; + + // A function that describes how the shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared shapes + // for computations. Must be non-null. const std::function* - variable_representation_shape_fn_; + shape_representation_fn_; // Cache of prebuilt computations indexed by their type. - using ComputationMap = std::map; + using ComputationMap = std::map; // Finds the value for the given type in out map if it already // exists or makes a new value with create function and keeps it the // map. The returned value != nullptr and is owned by the map. - const xla::Computation* LookupOrCreate( + const xla::XlaComputation* LookupOrCreate( DataType type, ComputationMap* out, - const std::function& create); + const std::function& create); // Cached computation to compute Max of two elements, specialized by type. ComputationMap max_func_; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 62a5114837e07f..f1594193af09c7 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -32,13 +32,12 @@ namespace tensorflow { namespace { -Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, bool is_min, - xla::ComputationDataHandle* argminmax) { - xla::ComputationDataHandle init_value; - const xla::Computation* reducer; +Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + bool is_min, xla::XlaOp* argminmax) { + xla::XlaOp init_value; + const xla::XlaComputation* reducer; if (is_min) { init_value = XlaHelpers::MaxValue(builder, input_type); reducer = ctx->GetOrCreateMin(input_type); @@ -50,13 +49,13 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, xla::PrimitiveType xla_output_type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); - xla::ComputationDataHandle input_max = builder->Reduce( - input, init_value, *reducer, /*dimensions_to_reduce=*/{axis}); + xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer, + /*dimensions_to_reduce=*/{axis}); std::vector broadcast_dims(input_shape.dims() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); // Compute a mask that has 1s for elements equal to the maximum. - xla::ComputationDataHandle partial_mask = builder->ConvertElementType( + xla::XlaOp partial_mask = builder->ConvertElementType( builder->Eq(input, input_max, broadcast_dims), xla_output_type); // In order to make identity elements for a bitwise And, we: @@ -65,23 +64,23 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, // 0xFF...F int32 bits_in_type = xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; - xla::ComputationDataHandle shift_amount = + xla::XlaOp shift_amount = XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); - xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic( + xla::XlaOp full_mask = builder->ShiftRightArithmetic( builder->ShiftLeft(partial_mask, shift_amount), shift_amount); // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its // index. - xla::ComputationDataHandle iota; + xla::XlaOp iota; const int64 axis_size = input_shape.dim_size(axis); TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); - xla::ComputationDataHandle product = + xla::XlaOp product = builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); // If there are multiple maximum elements, choose the one with the highest // index. 
- xla::ComputationDataHandle output = + xla::XlaOp output = builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), *ctx->GetOrCreateMax(output_type), /*dimensions_to_reduce=*/{axis}); @@ -91,36 +90,31 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, } // namespace -xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::MinValue(type)); } -xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::MaxValue(type)); } -xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::Zero(type)); } -xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::One(type)); } -xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) { switch (data_type) { case DT_HALF: return b->ConstantR0( @@ -137,16 +131,15 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, } } -xla::ComputationDataHandle XlaHelpers::IntegerLiteral( - xla::ComputationBuilder* b, DataType data_type, int64 value) { +xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type, + int64 value) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return ::tensorflow::IntegerLiteral(b, type, value); } -xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, - DataType data_type, - double value) { +xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, + double value) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return ::tensorflow::FloatLiteral(b, type, value); @@ -183,28 +176,24 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, +Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmax) { + DataType output_type, int axis, xla::XlaOp* argmax) { return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, axis, /*is_min=*/false, argmax); } -Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, +Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - 
xla::ComputationDataHandle* argmin) { + DataType output_type, int axis, xla::XlaOp* argmin) { return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, axis, /*is_min=*/true, argmin); } -Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, - int64 size, xla::ComputationDataHandle* iota) { +Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, + xla::XlaOp* iota) { TensorShape linspace_shape({size}); Tensor linspace; switch (dtype) { @@ -227,13 +216,10 @@ Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, return Status::OK(); } -Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, - int axis, DataType index_type, - const TensorShape& indices_shape, - const xla::ComputationDataHandle& indices, - const xla::ComputationDataHandle& on_value, - const xla::ComputationDataHandle& off_value, - xla::ComputationDataHandle* one_hot) { +Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, + DataType index_type, const TensorShape& indices_shape, + const xla::XlaOp& indices, const xla::XlaOp& on_value, + const xla::XlaOp& off_value, xla::XlaOp* one_hot) { const int indices_dims = indices_shape.dims(); const int output_dims = indices_dims + 1; @@ -267,7 +253,7 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::ComputationDataHandle one_hot_bool = builder->Eq( + xla::XlaOp one_hot_bool = builder->Eq( indices, builder->ConstantLiteral(linspace_literal), broadcast_dims); // Selects the user-provided off_value and on_value values. @@ -278,16 +264,15 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, } DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { - if (dtype == DT_BFLOAT16) { + if (dtype == DT_BFLOAT16 || dtype == DT_HALF) { return DT_FLOAT; } return dtype; } -xla::ComputationDataHandle XlaHelpers::ConvertElementType( - xla::ComputationBuilder* const builder, - const xla::ComputationDataHandle& operand, - const DataType new_element_type) { +xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, + const xla::XlaOp& operand, + const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); return builder->ConvertElementType(operand, convert_to); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 68ab93b64a5fa8..c3fdc5252e7436 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -30,41 +30,34 @@ class XlaHelpers { public: // Returns a handle representing the minimum value of a scalar // element of data_type. - static xla::ComputationDataHandle MinValue(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the maximum value of a scalar // element of data_type. 
- static xla::ComputationDataHandle MaxValue(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the zero value of a scalar // element of data_type. - static xla::ComputationDataHandle Zero(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the one value of a scalar // element of data_type. - static xla::ComputationDataHandle One(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type); // Returns the machine epsilon for floating-point type `data_type`, i.e., // the difference between 1.0 and the next representable value. - static xla::ComputationDataHandle Epsilon(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. - static xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* b, - DataType data_type, - int64 value); + static xla::XlaOp IntegerLiteral(xla::XlaBuilder* b, DataType data_type, + int64 value); // Returns a handle representing the given value of a floating-point scalar // element of data_type. - static xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* b, - DataType data_type, - double value); + static xla::XlaOp FloatLiteral(xla::XlaBuilder* b, DataType data_type, + double value); // Reshapes literal 'input' to have 'shape'. Both the original shape and // 'shape' must contain the same number of elements. @@ -75,38 +68,32 @@ class XlaHelpers { // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and // `input_dtype` are the shape and dtype of `input` respectively, and // `output_type` is the dtype to use for `argmax`. - static Status ArgMax(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmax); + static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + xla::XlaOp* argmax); // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and // `input_dtype` are the shape and dtype of `input` respectively, and // `output_type` is the dtype to use for `argmin`. - static Status ArgMin(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmin); + static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + xla::XlaOp* argmin); // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::ComputationBuilder* builder, DataType dtype, - int64 size, xla::ComputationDataHandle* iota); + static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, + xla::XlaOp* iota); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. 
`axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and // `off_value` represent the values to use for the on and off positions, // respectively. - static Status OneHot(xla::ComputationBuilder* builder, int64 depth, int axis, + static Status OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, - const xla::ComputationDataHandle& indices, - const xla::ComputationDataHandle& on_value, - const xla::ComputationDataHandle& off_value, - xla::ComputationDataHandle* one_hot); + const xla::XlaOp& indices, const xla::XlaOp& on_value, + const xla::XlaOp& off_value, xla::XlaOp* one_hot); // Certain DataTypes should use increased precision DataTypes when performing // reductions. This function remaps a given DataType to a higher precision @@ -115,10 +102,9 @@ class XlaHelpers { // A helper for creating a ConvertElementType xla op given a DataType rather // than the xla::PrimitiveType. - static xla::ComputationDataHandle ConvertElementType( - xla::ComputationBuilder* const builder, - const xla::ComputationDataHandle& operand, - const DataType new_element_type); + static xla::XlaOp ConvertElementType(xla::XlaBuilder* const builder, + const xla::XlaOp& operand, + const DataType new_element_type); }; } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1fe6e69ff2dc83..9e17756b27733e 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -112,10 +112,10 @@ void CollectNames(const T& entries, std::vector* nonempty_names, XlaJitCompiledCpuFunction::Compile( const GraphDef& graph_def, const tf2xla::Config& config, const xla::ExecutableBuildOptions& build_options) { - // Convert the graph_def into an xla::Computation. + // Convert the graph_def into an xla::XlaComputation. 
TF_ASSIGN_OR_RETURN(xla::LocalClient * client, xla::ClientLibrary::GetOrCreateLocalClient()); - xla::Computation computation; + xla::XlaComputation computation; TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client, &computation)); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c4bb90d58755f1..76c68d81af4dd9 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -30,7 +30,7 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { return context_->ValidateInputsAreSameShape(op); } -xla::ComputationBuilder* XlaOpKernelContext::builder() const { +xla::XlaBuilder* XlaOpKernelContext::builder() const { return XlaContext::Get(this).builder(); } @@ -38,9 +38,9 @@ xla::ComputationBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().handle() != 0 || + CHECK(expression->handle().builder() != nullptr || expression->resource() != nullptr); - VLOG(1) << "Fetched T" << expression->handle().handle(); + VLOG(1) << "Fetched T" << expression->handle(); return expression; } @@ -48,20 +48,18 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK_EQ(expression->handle().handle(), 0); + CHECK_EQ(expression->handle().builder(), nullptr); return const_cast(expression); } -// Retrieves the ComputationDataHandle from an input Tensor to an Op. This -// computation was constructed by an Op that executed previously and -// created the output Tensor using CreateOutputTensorFromComputation -// or CreateConstantOutputTensor. -static const xla::ComputationDataHandle& GetComputationFromTensor( - const Tensor& tensor) { +// Retrieves the XlaOp from an input Tensor to an Op. This computation was +// constructed by an Op that executed previously and created the output Tensor +// using CreateOutputTensorFromComputation or CreateConstantOutputTensor. +static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) { return CastExpressionFromTensor(tensor)->handle(); } -const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) { +const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } @@ -106,7 +104,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( return HostTensorToLiteral(temp, constant_literal); } - xla::ComputationDataHandle handle = expression->handle(); + xla::XlaOp handle = expression->handle(); if (new_shape != tensor.shape()) { // Reshape the handle to the desired shape. handle = builder()->Reshape(handle, new_shape.dim_sizes()); @@ -141,8 +139,17 @@ Status XlaOpKernelContext::ConstantInputReshaped( } // Ask the XLA compiler to evaluate the data handle to a literal. 
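  // With XlaBuilder this is a two-step process: first build a constant
  // subgraph rooted at `handle`, then ask the client to evaluate that
  // subgraph to a literal with the requested layout.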
+ xla::StatusOr constant_graph = + builder()->BuildConstantSubGraph(handle); + if (!constant_graph.ok()) { + return errors::Internal( + "Error getting a compile-time constant graph for ", + context_->op_kernel().name(), " input ", index, + ".\nError: ", constant_graph.status().error_message()); + } xla::StatusOr> computed = - builder()->ComputeConstant(handle, &layout); + compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), + &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, @@ -260,9 +267,9 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } -Status XlaOpKernelContext::InputList( - StringPiece name, std::vector* handles, - std::vector* shapes) { +Status XlaOpKernelContext::InputList(StringPiece name, + std::vector* handles, + std::vector* shapes) { OpInputList inputs; TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); handles->clear(); @@ -285,9 +292,9 @@ Status XlaOpKernelContext::ConstantInputList( return Status::OK(); } -Status XlaOpKernelContext::ReadVariableInput( - int index, DataType type, TensorShape* shape, - xla::ComputationDataHandle* value) { +Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, + TensorShape* shape, + xla::XlaOp* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -307,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput( } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = xla_context.VariableRepresentationShape( - variable->shape(), variable->type()); + TensorShape representation_shape = + xla_context.RepresentationShape(variable->shape(), variable->type()); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -334,8 +341,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } -void XlaOpKernelContext::SetOutput(int index, - const xla::ComputationDataHandle& handle) { +void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; auto shape = builder()->GetShape(handle); @@ -349,7 +355,7 @@ void XlaOpKernelContext::SetOutput(int index, // corresponds. TensorShape tensor_shape; OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(*shape.ValueOrDie(), &tensor_shape)); + XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, context_->allocate_output(index, tensor_shape, &output)); @@ -364,8 +370,8 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { xla::Literal literal; OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); - xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal); - CHECK_NE(handle.handle(), 0); + xla::XlaOp handle = builder()->ConstantLiteral(literal); + CHECK_NE(handle.builder(), nullptr); // Make the Tensor that will refer to the expression. 
Tensor* output = nullptr; @@ -386,8 +392,7 @@ void XlaOpKernelContext::SetInvalidOutput(int index) { OP_REQUIRES_OK(context_, context_->allocate_output(index, TensorShape({}), &output)); XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - xla::ComputationDataHandle handle; - handle.set_handle(0); + xla::XlaOp handle; expression->set_handle(handle); } @@ -410,8 +415,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { } Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, - xla::ComputationDataHandle handle) { - TF_RET_CHECK(handle.handle() != 0); + xla::XlaOp handle) { + TF_RET_CHECK(handle.builder() != nullptr); const XlaExpression* expression = CastExpressionFromTensor(context_->input(input_index)); @@ -425,13 +430,13 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, } TensorShape shape; TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); XlaContext& xla_context = XlaContext::Get(context_); TensorShape representation_shape = - xla_context.VariableRepresentationShape(shape, type); + xla_context.RepresentationShape(shape, type); if (shape != representation_shape) { handle = builder()->Reshape(handle, representation_shape.dim_sizes()); } @@ -457,22 +462,22 @@ void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, context_->CtxFailureWithWarning(file, line, s); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMax( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax( const DataType type) { return XlaContext::Get(context_).GetOrCreateMax(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMin( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin( const DataType type) { return XlaContext::Get(context_).GetOrCreateMin(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd( const DataType type) { return XlaContext::Get(context_).GetOrCreateAdd(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMul( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const DataType type) { return XlaContext::Get(context_).GetOrCreateMul(type); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 4e4b97e0cec8d1..667dc262ca03ca 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -58,8 +58,8 @@ class XlaOpKernelContext { public: explicit XlaOpKernelContext(OpKernelContext* context); - // Returns the XLA ComputationBuilder containing the output of compilation. - xla::ComputationBuilder* builder() const; + // Returns the XLA XlaBuilder containing the output of compilation. + xla::XlaBuilder* builder() const; // Inputs @@ -72,10 +72,10 @@ class XlaOpKernelContext { // Returns the shape of input 'index'. TensorShape InputShape(int index); - // Returns input 'index' as a ComputationDataHandle. 
Unlike + // Returns input 'index' as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. - const xla::ComputationDataHandle& Input(int index); + const xla::XlaOp& Input(int index); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -85,8 +85,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status InputList(StringPiece name, - std::vector* handles, + Status InputList(StringPiece name, std::vector* handles, std::vector* shapes); // Helper methods for constant inputs. @@ -132,10 +131,10 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } - // Sets output 'index' to the ComputationDataHandle 'handle'. + // Sets output 'index' to the XlaOp 'handle'. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. - void SetOutput(int index, const xla::ComputationDataHandle& handle); + void SetOutput(int index, const xla::XlaOp& handle); // Sets output 'index' to compile-time constant 'host_tensor', where // 'host_tensor' is a tensor in host memory. It is preferable to use @@ -168,14 +167,13 @@ class XlaOpKernelContext { // variable. Returns an error if the variable has not been initialized, or if // its type does not match `type`. Status ReadVariableInput(int index, DataType type, TensorShape* shape, - xla::ComputationDataHandle* value); + xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the // variable has been initialized with a different type or with a // different shape. - Status AssignVariable(int input_index, DataType type, - xla::ComputationDataHandle handle); + Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -205,22 +203,22 @@ class XlaOpKernelContext { // Gets an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMax(const DataType type); + const xla::XlaComputation* GetOrCreateMax(const DataType type); // Gets an XLA lambda to compute Min. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMin(const DataType type); + const xla::XlaComputation* GetOrCreateMin(const DataType type); // Gets an XLA lambda to compute Add. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateAdd(const DataType type); + const xla::XlaComputation* GetOrCreateAdd(const DataType type); // Gets an XLA lambda to compute Mul. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. 
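To make the header changes above concrete, here is a small sketch of an op kernel written against the updated interface, where inputs and outputs are `xla::XlaOp` values and `builder()` returns an `xla::XlaBuilder*`. The op name and its registration are hypothetical (they assume a matching TensorFlow op is defined elsewhere), and `builder()->Add` is a standard XlaBuilder method rather than something introduced by this patch.

```c++
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {

// Illustrative kernel: doubles its single input. "Double" is a made-up op.
class DoubleOp : public XlaOpKernel {
 public:
  explicit DoubleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const xla::XlaOp& x = ctx->Input(0);       // symbolic value, not a Tensor
    xla::XlaOp y = ctx->builder()->Add(x, x);  // builder() is an XlaBuilder*
    ctx->SetOutput(0, y);
  }
};

REGISTER_XLA_OP(Name("Double"), DoubleOp);

}  // namespace tensorflow
```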
- const xla::Computation* GetOrCreateMul(const DataType type); + const xla::XlaComputation* GetOrCreateMul(const DataType type); private: OpKernelContext* const context_; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index bbe808595d9583..4692038b61f687 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU"; static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { const OpDef* op_def; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def)); NodeDef node_def; node_def.set_name("_XlaLaunch-op"); - node_def.set_op("_XlaLaunch"); + node_def.set_op("XlaLaunch"); string kernel_class_name; TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, &kernel_class_name)); @@ -311,7 +311,7 @@ XlaOpRegistry& XlaOpRegistry::Instance() { XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = name.ToString(); + registration_->name = std::string(name); } XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { @@ -323,14 +323,14 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( gtl::ArraySlice devices) { registration_->has_device_whitelist = true; for (StringPiece device : devices) { - registration_->device_whitelist.insert(device.ToString()); + registration_->device_whitelist.insert(std::string(device)); } return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { registration_->has_device_whitelist = true; - registration_->device_whitelist.insert(device.ToString()); + registration_->device_whitelist.insert(std::string(device)); return *this; } @@ -347,7 +347,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[attr_name.ToString()]; + registration_->type_constraints[std::string(attr_name)]; types.insert(allowed); return *this; } @@ -355,7 +355,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, gtl::ArraySlice allowed) { std::set& types = - registration_->type_constraints[attr_name.ToString()]; + registration_->type_constraints[std::string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -364,7 +364,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( StringPiece input_name) { - registration_->compile_time_constant_inputs.insert(input_name.ToString()); + registration_->compile_time_constant_inputs.insert(std::string(input_name)); return *this; } @@ -394,7 +394,7 @@ XlaBackendRegistrar::XlaBackendRegistrar( StringPiece name, gtl::ArraySlice types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(name.ToString(), types, op_filter); + registry.RegisterBackend(std::string(name), types, op_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 
c2075b44b82ba2..540c65c597f20d 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -26,8 +26,7 @@ limitations under the License. namespace tensorflow { XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, - TensorShape shape, - const xla::ComputationDataHandle& initial_value, + TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, const std::set& tensor_array_gradients) : kind_(kind), @@ -41,11 +40,10 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, CHECK(kind_ != kInvalid); for (const string& gradient : tensor_array_gradients) { - tensor_array_gradients_[gradient].reset( - new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), - type_, shape_, xla::ComputationDataHandle(), - tensor_array_size_, /*tensor_array_gradients=*/{})); + tensor_array_gradients_[gradient].reset(new XlaResource( + /*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_, + xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); } } @@ -73,7 +71,7 @@ Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { return Status::OK(); } -Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { +Status XlaResource::SetValue(const xla::XlaOp& value) { if (type_ == DT_INVALID) { return errors::InvalidArgument( "Resource '", name_, @@ -83,7 +81,7 @@ Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { return Status::OK(); } -Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { +Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { if (type_ == DT_INVALID) { return errors::InvalidArgument( "Resource '", name_, @@ -121,9 +119,9 @@ Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { return Status::OK(); } -Status XlaResource::GetOrCreateTensorArrayGradient( - const string& source, xla::ComputationBuilder* builder, - XlaResource** gradient_out) { +Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, + xla::XlaBuilder* builder, + XlaResource** gradient_out) { VLOG(2) << "Gradient lookup for resource: " << name_ << " gradient: " << source; TF_RET_CHECK(kind_ == kTensorArray); @@ -132,7 +130,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient( TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - xla::ComputationDataHandle gradient_value = builder->Broadcast( + xla::XlaOp gradient_value = builder->Broadcast( XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, @@ -144,13 +142,12 @@ Status XlaResource::GetOrCreateTensorArrayGradient( return Status::OK(); } -Status XlaResource::Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const { +Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { if (tensor_array_gradients_.empty()) { *pack = value_; } else { TF_RET_CHECK(kind_ == kTensorArray); - std::vector elems; + std::vector elems; elems.push_back(value_); for (const auto& gradient : tensor_array_gradients_) { elems.push_back(gradient.second->value_); @@ -161,8 +158,8 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack, } Status XlaResource::SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder) { + const xla::XlaOp& pack, + xla::XlaBuilder* 
builder) { if (gradient_sources.empty()) { if (!initialized()) { initial_value_ = pack; diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 1bb2c7274ecdf0..9ce36d1aa76223 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -37,8 +37,7 @@ class XlaResource { }; XlaResource(Kind kind, int arg_num, string name, DataType type, - TensorShape shape, - const xla::ComputationDataHandle& initial_value, + TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, const std::set& tensor_array_gradients); @@ -69,16 +68,14 @@ class XlaResource { // this is the shape of each entry in the TensorArray/Stack. const TensorShape& shape() const { return shape_; } - const xla::ComputationDataHandle& value() const { return value_; } + const xla::XlaOp& value() const { return value_; } // Value of the resource at computation entry. Used to detect which // variables have new values that need to be written back. - const xla::ComputationDataHandle& initial_value() const { - return initial_value_; - } + const xla::XlaOp& initial_value() const { return initial_value_; } // A variable is initialized if it has a value. - bool initialized() const { return value_.handle() > 0; } + bool initialized() const { return value_.builder() != nullptr; } // Sets the type and shape of the resource. The type and shape of a resource // must not change once the variable has been initialized. @@ -86,17 +83,17 @@ class XlaResource { // Sets the current value of the resource. Returns an error if the type is not // set to a valid value. - Status SetValue(const xla::ComputationDataHandle& value); + Status SetValue(const xla::XlaOp& value); // Sets the current value of the resource to an all-zero value. - Status SetZeroValue(xla::ComputationBuilder* builder); + Status SetZeroValue(xla::XlaBuilder* builder); // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator // documentation for TensorArrayGradV3 for details. Status GetOrCreateTensorArrayGradient(const string& source, - xla::ComputationBuilder* builder, + xla::XlaBuilder* builder, XlaResource** gradient_out); // Packs a resource into a single XLA value `pack`, suitable for use as @@ -104,8 +101,7 @@ class XlaResource { // gradients, sets `*pack` to `value`. // For TensorArrays with gradients, packs the value and its gradient values in // a tuple; the gradients values are packed in order by source name. - Status Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const; + Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const; // Updates the resource with values from `pack`. If `gradient_sources` is // non-empty, treats `pack` as a tuple that represents a TensorArray and @@ -114,8 +110,7 @@ class XlaResource { // values. // Opposite of Pack(). 
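A recurring detail in the hunks above and below is how a "null" value is represented: a ComputationDataHandle was null when its integer handle was 0, whereas a default-constructed `xla::XlaOp` simply carries no builder, so validity checks compare `builder()` against `nullptr`. A small sketch of that idiom (the helper and example function are illustrative, not from the patch):

```c++
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

// Illustrative helper mirroring the checks used in the patch, e.g.
// XlaResource::initialized() and TF_RET_CHECK(handle.builder() != nullptr).
bool IsValid(const xla::XlaOp& op) { return op.builder() != nullptr; }

void Example(xla::XlaBuilder* b) {
  xla::XlaOp uninitialized;                      // IsValid(...) == false
  xla::XlaOp zero = b->ConstantR0<float>(0.0f);  // IsValid(...) == true
  (void)uninitialized;
  (void)zero;
}
```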
Status SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder); + const xla::XlaOp& pack, xla::XlaBuilder* builder); // TensorArray and Stack specific fields @@ -144,8 +139,8 @@ class XlaResource { DataType type_; TensorShape shape_; - xla::ComputationDataHandle value_; - xla::ComputationDataHandle initial_value_; + xla::XlaOp value_; + xla::XlaOp initial_value_; int64 tensor_array_size_ = -1; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1af9cb6d2ab15a..1b8e516770c3e2 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -53,7 +53,6 @@ xla_proto_library( deps = [ ":xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:session_proto", ], ) @@ -99,8 +98,9 @@ cc_library( hdrs = ["service_interface.h"], visibility = [":friends"], deps = [ + ":status", + ":xla_data_proto", ":xla_proto", - "//tensorflow/core:lib", ], ) @@ -244,6 +244,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protobuf_util", + ":status", ":status_macros", ":statusor", ":types", @@ -302,13 +303,13 @@ cc_library( ":array2d", ":array3d", ":array4d", - ":shape_tree", ":shape_util", ":sparse_index_array", ":status_macros", ":types", ":util", ":xla_data_proto", + "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) @@ -323,12 +324,30 @@ tf_cc_test( ":shape_util", ":test", ":types", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) +cc_library( + name = "error_spec", + hdrs = ["error_spec.h"], +) + +cc_library( + name = "literal_comparison", + srcs = ["literal_comparison.cc"], + hdrs = ["literal_comparison.h"], + deps = [ + ":error_spec", + ":literal_util", + ":util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "metric_table_report", srcs = ["metric_table_report.cc"], @@ -563,6 +582,7 @@ tf_cc_test( ":shape_util", ":test", ":xla_data_proto", + "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md index c93c39e180655e..39f8caaa961dc7 100644 --- a/tensorflow/compiler/xla/README.md +++ b/tensorflow/compiler/xla/README.md @@ -1 +1,7 @@ -This is the home of XLA. +


+ +XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear +algebra that optimizes TensorFlow computations. See the +[documentation](https://www.tensorflow.org/performance/xla/) for more details. diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 286d06d12ffca7..8f08d3b2e04670 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -63,7 +63,6 @@ cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":computation", ":global_data", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", @@ -76,7 +75,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -87,6 +86,7 @@ cc_library( hdrs = ["executable_build_options.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", @@ -99,17 +99,18 @@ cc_library( hdrs = ["local_client.h"], deps = [ ":client", - ":computation", ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", @@ -125,7 +126,6 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", - ":computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -161,47 +161,6 @@ cc_library( ], ) -cc_library( - name = "computation", - srcs = ["computation.cc"], - hdrs = ["computation.h"], - deps = [ - "//tensorflow/compiler/xla:service_interface", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:session_proto", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "computation_builder", - srcs = ["computation_builder.cc"], - hdrs = ["computation_builder.h"], - deps = [ - ":client", - ":computation", - ":global_data", - ":padding", - "//tensorflow/compiler/xla:array", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "sharding_builder", srcs = ["sharding_builder.cc"], 
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 328e1b8fa84e7b..3d596a6e65430b 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -64,7 +64,7 @@ StatusOr> Client::Transfer( } StatusOr> Client::TransferToServer( - const Literal& literal, const DeviceHandle* device_handle) { + const LiteralSlice& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; *request.mutable_literal() = literal.ToProto(); if (device_handle) { @@ -91,7 +91,7 @@ StatusOr> Client::TransferToServer( return MakeUnique(stub_, response.data()); } -Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, +Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, const DeviceHandle* device_handle) { TransferToInfeedRequest request; *request.mutable_literal() = literal.ToProto(); @@ -161,22 +161,6 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> Client::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr data, - Execute(computation, arguments, execution_options, execution_profile)); - - const Shape* shape_with_output_layout = nullptr; - if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); - } - return Transfer(*data, shape_with_output_layout); -} - StatusOr> Client::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -221,65 +205,11 @@ StatusOr> Client::ComputeConstant( return Literal::CreateFromProto(response.literal()); } -StatusOr Client::LoadSnapshot(const SessionModule& module) { - LoadComputationSnapshotRequest request; - *request.mutable_module() = module; - LoadComputationSnapshotResponse response; - - Status s = stub_->LoadComputationSnapshot(&request, &response); - if (!s.ok()) { - return s; - } - - VLOG(1) << "load snapshot response: " << response.ShortDebugString(); - return Computation(stub_, response.computation()); -} - StatusOr Client::LoadSnapshot(const HloSnapshot& module) { TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module()); return XlaComputation(module.hlo().hlo_module()); } -StatusOr> Client::Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_computation() = computation.handle(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - *request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); - VLOG(1) << execution_stats; - } - } - - return MakeUnique(stub_, 
response.output()); -} - StatusOr> Client::Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -320,41 +250,6 @@ StatusOr> Client::Execute( return MakeUnique(stub_, response.output()); } -StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { - ExecuteParallelRequest request; - - for (const ComputationInstance& computation : computations) { - ExecuteRequest single_request; - *single_request.mutable_computation() = computation.computation.handle(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0; i < computations.size(); ++i) { - outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); -} - StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteGraphParallelRequest request; @@ -372,7 +267,7 @@ StatusOr>> Client::ExecuteParallel( ExecuteParallelResponse response; VLOG(1) << "making execute-graph-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response); + Status s = stub_->ExecuteGraphParallel(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -401,7 +296,7 @@ StatusOr> Client::GetDeviceHandles( GetDeviceHandlesResponse response; VLOG(1) << "making get device request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->GetDeviceHandles(&request, &response); + Status s = stub_->GetDeviceHandles(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -449,24 +344,6 @@ StatusOr>> Client::DeconstructTuple( return std::move(handles); } -StatusOr Client::GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const { - ComputationStatsRequest request; - *request.mutable_computation() = computation.handle(); - *request.mutable_debug_options() = debug_options; - ComputationStatsResponse response; - - VLOG(1) << "making computation stats request"; - Status s = stub_->GetComputationStats(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - CHECK(response.has_stats()); - return response.stats(); -} - StatusOr Client::GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const { @@ -488,23 +365,6 @@ StatusOr Client::GetComputationStats( return response.stats(); } -StatusOr> Client::GetComputationShape( - const Computation& computation) { - GetComputationShapeRequest request; - *request.mutable_computation() = computation.handle(); - GetComputationShapeResponse response; - - VLOG(1) << "making get-computation-shape request"; - Status s = stub_->GetComputationShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return WrapUnique(response.release_program_shape()); -} - StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& 
result, computation.GetProgramShape()); @@ -527,28 +387,6 @@ StatusOr Client::GetShape(const GlobalData& data) { return response.shape(); } -StatusOr Client::ExecutionStatsAsString( - const Computation& computation, const ExecutionProfile& profile) { - TF_ASSIGN_OR_RETURN( - auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); - int64 total_flops = - computation_stats.flop_count() + computation_stats.transcendental_count(); - if (profile.compute_time_ns() > 0) { - int64 nanoseconds = profile.compute_time_ns(); - int64 cycle_count = profile.compute_cycle_count(); - double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( - "[Execution Statistics] flop count: ", computation_stats.flop_count(), - ", transcendental count: ", computation_stats.transcendental_count(), - ", compute execution time: ", nanoseconds, " nsec", - ", compute cycles: ", cycle_count, ", performance: ", gflops, - "gflop/s"); - } - return string("[Execution Statistics] not available."); -} - StatusOr Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index a63ff4c56d1dd7..68f0d0ac78c859 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,11 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,21 +51,6 @@ class Client { // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. - StatusOr> Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and returns the global - // data that was produced from the execution. - // * If execution_options is not nullptr, these options are passed to the - // service to affect how it compiles our computation. (The pointer does not - // need to live beyond this call.) - // * If execution_profile is not nullptr then the pointed-to ExecutionProfile - // will be filled with profile data from the execution. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -78,34 +62,6 @@ class Client { // executed on the devices associated with the handles by partitioning the // computation based on the attached sharding attributes. Otherwise, a // device is chosen by the service. 
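For orientation, a sketch of the client-side flow after this migration: computations are built with XlaBuilder, arguments are uploaded with TransferToServer (which now accepts a LiteralSlice, so a Literal converts implicitly), and execution goes through the XlaComputation overloads. The helper name, the choice of op, and the XlaBuilder methods Parameter/Add/Build are assumptions based on the surrounding API rather than lines from this patch.

```c++
#include <memory>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

// Sketch: build x + x, upload the argument, execute, and fetch the result.
xla::StatusOr<std::unique_ptr<xla::Literal>> RunDoubled(
    xla::Client* client, const xla::Literal& argument) {
  xla::XlaBuilder builder("double");
  xla::XlaOp x = builder.Parameter(0, argument.shape(), "x");
  builder.Add(x, x);  // the last enqueued op becomes the root
  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation, builder.Build());

  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::GlobalData> data,
                      client->TransferToServer(argument));
  return client->ExecuteAndTransfer(computation, {data.get()});
}
```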
- struct ComputationInstance { - const Computation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - ComputationInstance(const Computation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; - - // Executes a list ComputationInstances and returns global data produced from - // each computation. - StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); - - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct XlaComputationInstance { const XlaComputation& computation; std::vector arguments; @@ -125,7 +81,6 @@ class Client { // Executes a list XlaComputationInstances and returns global data produced // from each computation. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> ExecuteParallel( tensorflow::gtl::ArraySlice computations); @@ -152,14 +107,14 @@ class Client { // device (and its replicas if replication is enabled). Otherwise, data is // transferred to the default device (and its replicas). StatusOr> TransferToServer( - const Literal& literal, const DeviceHandle* device_handle = nullptr); + const LiteralSlice& literal, const DeviceHandle* device_handle = nullptr); // Transfer the given literal to the Infeed interface of the device. // // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - Status TransferToInfeed(const Literal& literal, int64 replica_id = 0, + Status TransferToInfeed(const LiteralSlice& literal, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); // Transfers from the Outfeed of the device. @@ -177,17 +132,6 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and transfers the result - // to the client as a literal. Parameters are defined the same as for - // Execute() and Transfer(). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -209,8 +153,6 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; @@ -223,12 +165,6 @@ class Client { const GlobalData& data); // Retrieves the statistics of the given computation. 
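Similarly, parallel execution now goes through XlaComputationInstance rather than the removed ComputationInstance. The sketch below assumes the new struct keeps the same constructor argument order as the old one; treat that ordering, the helper name, and the one-argument-per-computation setup as assumptions.

```c++
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/execution_options_util.h"

// Sketch: run two computations in a single round trip via ExecuteParallel.
xla::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>> RunPair(
    xla::Client* client, const xla::XlaComputation& a, xla::GlobalData* arg_a,
    const xla::XlaComputation& b, xla::GlobalData* arg_b) {
  std::vector<xla::Client::XlaComputationInstance> instances;
  instances.emplace_back(a, std::vector<xla::GlobalData*>{arg_a},
                         xla::CreateDefaultExecutionOptions(),
                         /*execution_profile=*/nullptr);
  instances.emplace_back(b, std::vector<xla::GlobalData*>{arg_b},
                         xla::CreateDefaultExecutionOptions(),
                         /*execution_profile=*/nullptr);
  return client->ExecuteParallel(instances);
}
```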
- StatusOr GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const; - - // Retrieves the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const; @@ -239,13 +175,6 @@ class Client { // As above, but returns the shape of the provided computation (parameter // types/names and return type). - StatusOr> GetComputationShape( - const Computation& computation); - - // As above, but returns the shape of the provided computation (parameter - // types/names and return type). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> GetComputationShape( const XlaComputation& computation); @@ -253,9 +182,6 @@ class Client { // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); - StatusOr LoadSnapshot(const SessionModule& module); - - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } @@ -263,8 +189,6 @@ class Client { private: // Returns the execution statistics (e.g., gflop/s) as a string from the // ExecutionProfile returned from an execution of the computation. - StatusOr ExecutionStatsAsString(const Computation& computation, - const ExecutionProfile& profile); StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 96e38bca010879..dc69d2097ebe14 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -21,24 +21,6 @@ limitations under the License. namespace xla { -StatusOr>> -CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector service_instances; - service_instances.reserve(computations.size()); - for (const AotComputationInstance& instance : computations) { - service_instances.push_back({}); - CompileOnlyService::AotComputationInstance& service_instance = - service_instances.back(); - TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); - service_instance.argument_layouts = instance.argument_layouts; - service_instance.result_layout = instance.result_layout; - } - return compiler_service_->CompileAheadOfTime(service_instances, options); -} - StatusOr>> CompileOnlyClient::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index c8725b8517484a..f9a7c31270c7a1 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,6 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -38,26 +37,7 @@ class CompileOnlyClient : public Client { CompileOnlyClient(const CompileOnlyClient&) = delete; void operator=(const CompileOnlyClient&) = delete; - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - const Computation* computation; - // Inform the compiler of the expected layout for arguments. - std::vector argument_layouts; - // Specifies the expected result layout. - const Shape* result_layout; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); - // A description of an xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { const XlaComputation* computation; // Inform the compiler of the expected layout for arguments. @@ -69,8 +49,6 @@ class CompileOnlyClient : public Client { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. The |options| parameter describes // the target for which the compiler should emit code. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc deleted file mode 100644 index e6c57bda0f0c4c..00000000000000 --- a/tensorflow/compiler/xla/client/computation.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation.h" - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -Computation::Computation() : parent_(nullptr) {} - -Computation::Computation(ServiceInterface* parent, - const ComputationHandle& handle) - : handle_(handle), parent_(parent) {} - -Computation::Computation(Computation&& computation) - : handle_(std::move(computation.handle_)), parent_(computation.parent_) { - computation.ResetWithoutFreeing(); -} - -void Computation::Reset() { - // TODO(b/34469253) deallocate any owned computation. 
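With client/computation.* deleted below, the snapshot path changes as well: a serialized computation arrives as an HloSnapshot rather than a SessionModule, LoadSnapshot returns an XlaComputation, and the program shape is queried from that object directly. A brief sketch, with the helper name chosen here for illustration:

```c++
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"

// Sketch: recover the parameter/result shapes of a snapshotted computation.
xla::StatusOr<xla::ProgramShape> ShapeFromSnapshot(
    xla::Client* client, const xla::HloSnapshot& snapshot) {
  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation,
                      client->LoadSnapshot(snapshot));
  return computation.GetProgramShape();
}
```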
- ResetWithoutFreeing(); -} - -StatusOr> Computation::Snapshot() const { - SnapshotComputationRequest request; - *request.mutable_computation() = handle_; - SnapshotComputationResponse response; - - TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response)); - - return WrapUnique(response.release_module()); -} - -Computation::~Computation() { Reset(); } - -Computation& Computation::operator=(Computation&& computation) { - if (&computation != this) { - Reset(); - handle_ = computation.handle_; - parent_ = computation.parent_; - computation.ResetWithoutFreeing(); - } - return *this; -} - -void Computation::ResetWithoutFreeing() { - handle_.Clear(); - parent_ = nullptr; -} - -StatusOr Computation::GetProgramShape() const { - GetComputationShapeRequest request; - *request.mutable_computation() = handle_; - GetComputationShapeResponse response; - - TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); - - return std::move(*response.mutable_program_shape()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h deleted file mode 100644 index a53fc9e9cf3470..00000000000000 --- a/tensorflow/compiler/xla/client/computation.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ - -#include - -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service_interface.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" - -namespace xla { - -// Wraps a ComputationHandle protobuf with a lifetime. Computation is -// movable and not copyable to capture the same kind of unique -// ownership that std::unique_ptr represents. -class Computation { - public: - // Creates a null Computation. - Computation(); - - // parent: stub for the service on which we will deallocate the computation - // when it is no longer needed. - // handle: the computation handle protobuf from the service. - Computation(ServiceInterface* parent, const ComputationHandle& handle); - - Computation(Computation&& computation); - - // Deallocates the computation. - ~Computation(); - - Computation& operator=(Computation&& computation); - - // Returns the underlying handle. - const ComputationHandle& handle() const { return handle_; } - - // Sets handle to a null state and clears any owned computation. - void Reset(); - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - StatusOr> Snapshot() const; - - // Returns true if this object is a null Computation. 
- bool IsNull() const { return parent_ == nullptr; } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - StatusOr GetProgramShape() const; - - private: - void ResetWithoutFreeing(); - - ComputationHandle handle_; // Handle that is wrapped by this class. - - // Stub that the handle is deallocated on when this object's lifetime ends. - ServiceInterface* parent_; - - TF_DISALLOW_COPY_AND_ASSIGN(Computation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc deleted file mode 100644 index 83c7cb17440213..00000000000000 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ /dev/null @@ -1,1574 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation_builder.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { - -ComputationBuilder::ComputationBuilder(Client* client, - const string& computation_name) - : name_(computation_name), client_(client) {} - -ComputationBuilder::~ComputationBuilder() {} - -void ComputationBuilder::NoteError(const Status& error) { - if (die_immediately_on_error_) { - LOG(FATAL) << "error building computation: " << error; - } - - if (first_error_.ok()) { - first_error_ = error; - first_error_backtrace_.CreateCurrent(/*skip_count=*/1); - } -} - -std::unique_ptr ComputationBuilder::CreateSubBuilder( - const string& computation_name) { - auto sub_builder = MakeUnique(client_, computation_name); - sub_builder->parent_builder_ = this; - sub_builder->die_immediately_on_error_ = die_immediately_on_error_; - return sub_builder; -} - -Status ComputationBuilder::PrepareComputation() { - TF_RETURN_IF_ERROR(first_error_); - - if (!computation_.IsNull()) { - return Status::OK(); - } - - ComputationRequest request; - request.set_name(name_); - ComputationResponse response; - - VLOG(2) << "making computation request"; - Status s = client_->stub()->Computation(&request, &response); - VLOG(2) << "done with computation request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - computation_ = Computation(client_->stub(), response.computation()); - return Status::OK(); -} - -Status ComputationBuilder::RunOp(OpRequest* op_request, - OpResponse* op_response) { - TF_RETURN_IF_ERROR(first_error_); - 
TF_RETURN_IF_ERROR(PrepareComputation()); - - // Fill in fields that are set on every OpRequest. - *op_request->mutable_computation() = computation_.handle(); - *op_request->mutable_metadata() = metadata_; - if (sharding_) { - *op_request->mutable_sharding() = *sharding_; - } - - const string& op_name = - OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name(); - VLOG(2) << "running op request: " << op_name; - Status status = client_->stub()->Op(op_request, op_response); - VLOG(2) << "done with op request: " << op_name; - return status; -} - -void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - } -} - -ComputationDataHandle ComputationBuilder::RunOpAndParseResponse( - OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - return ComputationDataHandle(); - } - if (op_response.output().handle() == 0) { - NoteError(InternalError("No output handle")); - return ComputationDataHandle(); - } - return op_response.output(); -} - -bool ComputationBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, Window* window) { - const auto verify_size = [&](const size_t x, const char* x_name) { - if (x == 0 || x == window_dimensions.size()) { - return true; - } else { - NoteError(InvalidArgument( - "%s", tensorflow::strings::StrCat( - "Window has different number of window dimensions than of ", - x_name, "\nNumber of window dimensions: ", - window_dimensions.size(), "\nNumber of ", x_name, ": ", x, - "\n") - .c_str())); // - return false; - } - }; - if (!verify_size(window_strides.size(), "window strides") || - !verify_size(padding.size(), "padding entries") || - !verify_size(lhs_dilation.size(), "lhs dilation factors") || - !verify_size(rhs_dilation.size(), "rhs dilation factors")) { - return false; - } - - window->Clear(); - for (size_t i = 0; i < window_dimensions.size(); i++) { - auto dim = window->add_dimensions(); - dim->set_size(window_dimensions[i]); - if (!window_strides.empty()) { - dim->set_stride(window_strides[i]); - } else { - dim->set_stride(1); - } - if (!padding.empty()) { - dim->set_padding_low(padding[i].first); - dim->set_padding_high(padding[i].second); - } else { - dim->set_padding_low(0); - dim->set_padding_high(0); - } - if (!lhs_dilation.empty()) { - dim->set_base_dilation(lhs_dilation[i]); - } else { - dim->set_base_dilation(1); - } - if (!rhs_dilation.empty()) { - dim->set_window_dilation(rhs_dilation[i]); - } else { - dim->set_window_dilation(1); - } - dim->set_window_reversal(false); - } - return true; -} - -ComputationDataHandle ComputationBuilder::ConstantLiteral( - const Literal& literal) { - OpRequest op_request; - ConstantRequest* request = op_request.mutable_constant_request(); - *request->mutable_literal() = literal.ToProto(); - VLOG(3) << "created constant: " << request->literal().ShortDebugString(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { - OpRequest op_request; - ParameterRequest* request = op_request.mutable_parameter_request(); - *request->mutable_shape() = shape; - request->set_parameter(parameter_number); - 
request->set_name(name); - return RunOpAndParseResponse(&op_request); -} - -StatusOr> ComputationBuilder::GetShapeWithoutNoteError( - const ComputationDataHandle& operand) { - GetLocalShapeRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - GetLocalShapeResponse response; - - VLOG(2) << "making get-shape request"; - TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response)); - VLOG(2) << "done with request"; - - TF_RET_CHECK(response.has_shape()); - std::unique_ptr shape = WrapUnique(response.release_shape()); - TF_RET_CHECK(shape != nullptr); - return std::move(shape); -} - -StatusOr> ComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - auto status_or_shape = GetShapeWithoutNoteError(operand); - if (!status_or_shape.ok()) { - NoteError(status_or_shape.status()); - return first_error_; - } - return status_or_shape; -} - -StatusOr ComputationBuilder::GetProgramShape() { - TF_RETURN_IF_ERROR(first_error_); - - GetComputationShapeRequest request; - *request.mutable_computation() = computation_.handle(); - GetComputationShapeResponse response; - - VLOG(2) << "making get-program-shape-request"; - Status status = client_->stub()->GetComputationShape(&request, &response); - VLOG(2) << "done with get-program-shape-request"; - - if (!status.ok()) { - first_error_ = status; - return status; - } - - TF_RET_CHECK(response.has_program_shape()); - return std::move(*response.mutable_program_shape()); -} - -ComputationDataHandle ComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { - OpRequest op_request; - SliceRequest* request = op_request.mutable_slice_request(); - *request->mutable_operand() = operand; - for (int64 index : start_indices) { - request->add_start_indices(index); - } - for (int64 index : limit_indices) { - request->add_limit_indices(index); - } - for (int64 index : strides) { - request->add_strides(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - NoteError(shape_status.status()); - return ComputationDataHandle{}; - } - const Shape& shape = *shape_status.ValueOrDie(); - std::vector starts(ShapeUtil::Rank(shape), 0); - std::vector limits(shape.dimensions().begin(), - shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); - starts[dimno] = start_index; - limits[dimno] = limit_index; - strides[dimno] = stride; - return Slice(operand, starts, limits, strides); -} - -ComputationDataHandle ComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { - OpRequest op_request; - DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_start_indices() = start_indices; - for (int64 index : slice_sizes) { - request->add_slice_sizes(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { 
- OpRequest op_request; - DynamicUpdateSliceRequest* request = - op_request.mutable_dynamic_update_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_update() = update; - *request->mutable_start_indices() = start_indices; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - OpRequest op_request; - ConcatenateRequest* request = op_request.mutable_concatenate_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - request->set_dimension(dimension); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { - OpRequest op_request; - BroadcastRequest* request = op_request.mutable_broadcast_request(); - *request->mutable_operand() = operand; - for (int64 size : broadcast_sizes) { - request->add_broadcast_sizes(size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - OpRequest op_request; - PadRequest* request = op_request.mutable_pad_request(); - *request->mutable_operand() = operand; - *request->mutable_padding_value() = padding_value; - *request->mutable_padding_config() = padding_config; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { - OpRequest op_request; - ReshapeRequest* request = op_request.mutable_reshape_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (int64 new_size : new_sizes) { - request->add_new_sizes(new_size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - std::vector dimensions(shape.ValueOrDie()->dimensions().size()); - std::iota(dimensions.begin(), dimensions.end(), 0); - return Reshape(operand, dimensions, new_sizes); -} - -ComputationDataHandle ComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - // Don't support out-of-order collapse here. - // Checks that the collapsed dimensions are in order and consecutive. - for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dimensions.size(); ++i) { - if (dimensions[i] - 1 != dimensions[i - 1]) { - NoteError(InvalidArgument( - "Collapsed dimensions are not in order and consecutive.")); - return ComputationDataHandle(); - } - } - - // Create a new sizes vector from the old shape, replacing the collapsed - // dimensions by the product of their sizes. 
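The comment just above (and the loop that follows) is the only non-trivial arithmetic in Collapse: consecutive dimensions are folded into the product of their sizes before a single Reshape is enqueued. A standalone restatement of that size computation, using plain std::vector for clarity; this is an illustration of the logic, not the removed method itself.

```c++
#include <cstdint>
#include <vector>

// Restates the new-sizes computation used by Collapse: dimensions between
// to_collapse.front() and to_collapse.back() are multiplied into one entry.
// Example: dims = {2, 3, 4, 5}, to_collapse = {1, 2}  ->  {2, 12, 5}.
std::vector<int64_t> CollapsedSizes(const std::vector<int64_t>& dims,
                                    const std::vector<int64_t>& to_collapse) {
  std::vector<int64_t> new_sizes;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (to_collapse.size() <= 1 || i <= to_collapse.front() ||
        i > to_collapse.back()) {
      new_sizes.push_back(dims[i]);
    } else {
      new_sizes.back() *= dims[i];
    }
  }
  return new_sizes;
}
```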
- StatusOr> shape_or_status = GetShape(operand); - if (!shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); - - VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dimensions, ","); - - if (dimensions.size() <= 1) { - // Not collapsing anything, trivially we can return the operand versus - // enqueueing a trivial reshape. - return operand; - } - - std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { - if (i <= dimensions.front() || i > dimensions.back()) { - new_sizes.push_back(original_shape->dimensions(i)); - } else { - new_sizes.back() *= original_shape->dimensions(i); - } - } - - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; - - return Reshape(operand, new_sizes); -} - -void ComputationBuilder::Trace(const string& tag, - const ComputationDataHandle& operand) { - OpRequest op_request; - TraceRequest* request = op_request.mutable_trace_request(); - request->set_tag(tag); - *request->mutable_operand() = operand; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Select( - const ComputationDataHandle& pred, const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false) { - return TernaryOp(TRIOP_SELECT, pred, on_true, on_false); -} - -ComputationDataHandle ComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - OpRequest op_request; - VariadicOpRequest* request = op_request.mutable_variadic_op_request(); - request->set_varop(VAROP_TUPLE); - for (const ComputationDataHandle& operand : elements) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - OpRequest op_request; - GetTupleElementRequest* request = - op_request.mutable_get_tuple_element_request(); - *request->mutable_operand() = tuple_data; - request->set_index(index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LT, lhs, 
rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - - DotDimensionNumbers dimension_numbers; - dimension_numbers.add_lhs_contracting_dimensions( - lhs_shape->dimensions_size() == 1 ? 0 : 1); - dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers) { - OpRequest op_request; - DotRequest* request = op_request.mutable_dot_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conv( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return ConvWithGeneralDimensions( - lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -bool ComputationBuilder::VerifyConvolution( - const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { - NoteError( - InvalidArgument("Convolution arguments must have same number of " - "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_dims = ShapeUtil::Rank(lhs_shape); - if (num_dims < 2) { - NoteError(InvalidArgument( - "Convolution expects argument arrays with >= 3 dimensions. 
" - "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_spatial_dims = num_dims - 2; - - const auto check_spatial_dimensions = - [&](const char* const field_name, - const tensorflow::protobuf::RepeatedField& - numbers) { - if (numbers.size() != num_spatial_dims) { - NoteError(InvalidArgument("Expected %d elements for %s, but got %d.", - num_spatial_dims, field_name, - numbers.size())); - return false; - } - for (int i = 0; i < numbers.size(); ++i) { - if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - NoteError( - InvalidArgument("Convolution %s[%d] is out of bounds: %lld", - field_name, i, numbers.Get(i))); - return false; - } - } - return true; - }; - return check_spatial_dimensions( - "input_spatial_dimensions", - dimension_numbers.input_spatial_dimensions()) && - check_spatial_dimensions( - "kernel_spatial_dimensions", - dimension_numbers.kernel_spatial_dimensions()) && - check_spatial_dimensions( - "output_spatial_dimensions", - dimension_numbers.output_spatial_dimensions()); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - NoteError(InternalError("failed to verify convolution")); - return ComputationDataHandle(); - } - - std::vector base_area_dimensions( - dimension_numbers.input_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < base_area_dimensions.size(); - ++i) { - base_area_dimensions[i] = - lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - return ConvGeneral(lhs, rhs, window_strides, - MakePadding(base_area_dimensions, window_dimensions, - window_strides, padding), - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - 
StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - // Error is recorded in VerifyConvolution. - return ComputationDataHandle(); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - OpRequest op_request; - ConvolveRequest* request = op_request.mutable_convolve_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - - if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, - rhs_dilation, request->mutable_window())) { - // Error is recorded in MakeWindow. - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Fft( - const ComputationDataHandle& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { - OpRequest op_request; - FftRequest* request = op_request.mutable_fft_request(); - *request->mutable_operand() = operand; - request->set_fft_type(fft_type); - for (int64 dim_len : fft_length) { - request->add_fft_length(dim_len); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, - const string& config) { - OpRequest op_request; - InfeedRequest* request = op_request.mutable_infeed_request(); - *request->mutable_shape() = shape; - *request->mutable_config() = config; - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, - const string& outfeed_config) { - OpRequest op_request; - OutfeedRequest* request = op_request.mutable_outfeed_request(); - request->set_outfeed_config(outfeed_config); - *request->mutable_operand() = operand; - *request->mutable_shape() = shape_with_layout; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands) { - OpRequest op_request; - CallRequest* request = op_request.mutable_call_request(); - *request->mutable_to_apply() = computation.handle(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { - OpRequest op_request; - CustomCallRequest* request = op_request.mutable_custom_call_request(); - request->set_call_target_name(call_target_name); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - OpRequest op_request; - 
HostComputeRequest* request = op_request.mutable_host_compute_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - request->set_channel_name(channel_name); - request->set_cost_estimate_ns(cost_estimate_ns); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Conj( - const ComputationDataHandle& operand) { - return Complex(Real(operand), Neg(Imag(operand))); -} - -ComputationDataHandle ComputationBuilder::Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); -} - -// TODO(b/65209188): Create a dedicated lowering for Xor -ComputationDataHandle ComputationBuilder::Xor( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return Or(And(Not(lhs), rhs, broadcast_dimensions), - And(lhs, Not(rhs), broadcast_dimensions)); -} - -ComputationDataHandle ComputationBuilder::Not( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NOT, operand); -} - -ComputationDataHandle ComputationBuilder::ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - 
return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Abs( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ABS, operand); -} - -ComputationDataHandle ComputationBuilder::Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Exp( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_EXP, operand); -} - -ComputationDataHandle ComputationBuilder::Floor( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_FLOOR, operand); -} - -ComputationDataHandle ComputationBuilder::Ceil( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_CEIL, operand); -} - -ComputationDataHandle ComputationBuilder::Round( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand); -} - -ComputationDataHandle ComputationBuilder::Log( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_LOG, operand); -} - -ComputationDataHandle ComputationBuilder::Sign( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIGN, operand); -} - -ComputationDataHandle ComputationBuilder::Cos( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_COS, operand); -} - -ComputationDataHandle ComputationBuilder::Sin( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIN, operand); -} - -ComputationDataHandle ComputationBuilder::Tanh( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_TANH, operand); -} - -ComputationDataHandle ComputationBuilder::Real( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_REAL, operand); -} - -ComputationDataHandle ComputationBuilder::Imag( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IMAG, operand); -} - -ComputationDataHandle ComputationBuilder::IsFinite( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IS_FINITE, operand); -} - -ComputationDataHandle ComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - OpRequest op_request; - TransposeRequest* request = op_request.mutable_transpose_request(); - *request->mutable_operand() = operand; - for (int64 dimension : permutation) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - OpRequest op_request; - ReverseRequest* request = op_request.mutable_reverse_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Sort( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SORT, operand); -} - 
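The element-wise methods above are thin wrappers that each enqueue a single UnaryOpRequest or BinaryOpRequest. A hedged sketch of how they compose under the pre-removal ComputationBuilder API shown in this diff; `Softplus` is an invented helper, not part of the API.

```c++
#include "tensorflow/compiler/xla/client/computation_builder.h"

// Hypothetical helper: softplus(x) = log(1 + exp(x)) becomes three enqueued
// instructions, one per wrapper call. The scalar constant broadcasts against
// x, as with the other element-wise binary ops.
xla::ComputationDataHandle Softplus(xla::ComputationBuilder* b,
                                    const xla::ComputationDataHandle& x) {
  auto one = b->ConstantR0<float>(1.0f);
  return b->Log(b->Add(one, b->Exp(x)));
}
```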
-ComputationDataHandle ComputationBuilder::SqrtF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(0.5), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BitcastConvertType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_bitcast_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SquareF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::ReciprocalF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(-1.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Neg( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NEGATE, operand); -} - -ComputationDataHandle ComputationBuilder::Clz( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_CLZ, operand); -} - -ComputationDataHandle ComputationBuilder::Clamp( - const ComputationDataHandle& min, const ComputationDataHandle& operand, - const ComputationDataHandle& max) { - return TernaryOp(TRIOP_CLAMP, min, operand, max); -} - -ComputationDataHandle ComputationBuilder::UnaryOp( - UnaryOperation unop, const ComputationDataHandle& operand) { - OpRequest op_request; - UnaryOpRequest* request = op_request.mutable_unary_op_request(); - request->set_unop(unop); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - OpRequest op_request; - BinaryOpRequest* request = op_request.mutable_binary_op_request(); - request->set_binop(binop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - for (int64 dimension : broadcast_dimensions) { - request->add_broadcast_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngOp( - RandomDistribution distribution, - 
tensorflow::gtl::ArraySlice parameters, - const Shape& shape) { - OpRequest op_request; - RngRequest* request = op_request.mutable_rng_request(); - request->set_distribution(distribution); - for (const ComputationDataHandle& param : parameters) { - *request->add_parameter() = param; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::TernaryOp( - TernaryOperation triop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) { - OpRequest op_request; - TernaryOpRequest* request = op_request.mutable_ternary_op_request(); - request->set_triop(triop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_ehs() = ehs; - return RunOpAndParseResponse(&op_request); -} - -Status ComputationBuilder::SetReturnValue( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - SetReturnValueRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - - SetReturnValueResponse response; - - VLOG(2) << "making set-handle-to-execute request"; - Status s = client_->stub()->SetReturnValue(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - return Status::OK(); -} - -StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - TF_RETURN_IF_ERROR(first_error_); - - IsConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - request.set_num_parameters(num_parameters); - IsConstantResponse response; - - VLOG(2) << "making IsConstant request"; - Status s = client_->stub()->IsConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - return response.is_constant(); -} - -StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - TF_RETURN_IF_ERROR(first_error_); - - ComputeConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; - } - for (const auto& param : parameters) { - *request.add_parameters() = param.ToProto(); - } - - ComputeConstantResponse response; - - VLOG(2) << "making compute-constant request"; - Status s = client_->stub()->ComputeConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - - VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return InternalError( - "no computed literal in the provided response in ComputeConstant " - "request"); - } - return Literal::CreateFromProto(response.literal()); -} - -ComputationDataHandle ComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const Computation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - OpRequest op_request; - MapRequest* request = op_request.mutable_map_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_to_apply() = computation.handle(); - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (const ComputationDataHandle& sop : static_operands) { - 
*request->add_static_operands() = sop; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); -} - -ComputationDataHandle ComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); -} - -ComputationDataHandle ComputationBuilder::While( - const Computation& condition, const Computation& body, - const ComputationDataHandle& init) { - OpRequest op_request; - WhileRequest* request = op_request.mutable_while_request(); - *request->mutable_condition() = condition.handle(); - *request->mutable_body() = body.handle(); - *request->mutable_init() = init; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - OpRequest op_request; - GatherRequest* gather_request = op_request.mutable_gather_request(); - *gather_request->mutable_input() = input; - *gather_request->mutable_gather_indices() = gather_indices; - *gather_request->mutable_dimension_numbers() = dimension_numbers; - for (int64 window_bound : window_bounds) { - gather_request->add_window_bounds(window_bound); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation) { - OpRequest op_request; - ConditionalRequest* request = op_request.mutable_conditional_request(); - *request->mutable_predicate() = predicate; - *request->mutable_true_operand() = true_operand; - *request->mutable_true_computation() = true_computation.handle(); - *request->mutable_false_operand() = false_operand; - *request->mutable_false_computation() = false_computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { - OpRequest op_request; - ReduceRequest* request = op_request.mutable_reduce_request(); - *request->mutable_operand() = operand; - *request->mutable_init_value() = init_value; - for (int64 dimension : dimensions_to_reduce) { - request->add_dimensions(dimension); - } - *request->mutable_to_apply() = computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReduceAll( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - std::vector all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie())); - std::iota(all_dimnos.begin(), all_dimnos.end(), 0); - return Reduce(operand, init_value, computation, all_dimnos); -} - -ComputationDataHandle ComputationBuilder::ReduceWindow( - const ComputationDataHandle& operand, - const 
ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - Status padding_valid = - ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides); - if (!padding_valid.ok()) { - first_error_ = padding_valid; - return ComputationDataHandle(); - } - - std::vector> padding_values = - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); -} - -ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - OpRequest op_request; - ReduceWindowRequest* request = op_request.mutable_reduce_window_request(); - *request->mutable_operand() = operand; - *request->mutable_to_apply() = computation.handle(); - *request->mutable_init_value() = init_value; - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormTraining( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormTrainingRequest* request = - op_request.mutable_batch_norm_training_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormInferenceRequest* request = - op_request.mutable_batch_norm_inference_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - *request->mutable_mean() = mean; - *request->mutable_variance() = variance; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormGrad( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& batch_mean, - const ComputationDataHandle& batch_var, - const ComputationDataHandle& grad_output, float epsilon, - int64 feature_index) { - OpRequest op_request; - BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_mean() = batch_mean; - *request->mutable_variance() = batch_var; 
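As a usage-level illustration of the ReduceWindow path above: the reduction operator is itself a Computation built with a child builder. The sketch below assumes the pre-removal API shown in this diff; `MaxPool2x2` and all shape values are invented.

```c++
#include <limits>

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Hypothetical 2x2 max-pooling over the trailing spatial dimensions of an
// NCHW operand, expressed with ReduceWindow.
xla::ComputationDataHandle MaxPool2x2(xla::Client* client,
                                      xla::ComputationBuilder* b,
                                      const xla::ComputationDataHandle& input) {
  // Build the binary reduction computation: (x, y) -> max(x, y).
  xla::ComputationBuilder max_builder(client, "max_f32");
  auto scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  auto x = max_builder.Parameter(0, scalar, "x");
  auto y = max_builder.Parameter(1, scalar, "y");
  max_builder.Max(x, y);  // last enqueued op becomes the root
  xla::Computation max_computation = max_builder.Build().ConsumeValueOrDie();

  // Pool with a 2x2 window and stride 2 over the two spatial dimensions.
  auto init = b->ConstantR0<float>(-std::numeric_limits<float>::infinity());
  return b->ReduceWindow(input, init, max_computation,
                         /*window_dimensions=*/{1, 1, 2, 2},
                         /*window_strides=*/{1, 1, 2, 2},
                         xla::Padding::kValid);
}
```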
- *request->mutable_grad_output() = grad_output; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - OpRequest op_request; - CrossReplicaSumRequest* request = - op_request.mutable_cross_replica_sum_request(); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - return SelectAndScatterWithGeneralPadding( - operand, select, window_dimensions, window_strides, - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding), - source, init_value, scatter); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - OpRequest op_request; - SelectAndScatterRequest* request = - op_request.mutable_select_and_scatter_request(); - *request->mutable_operand() = operand; - *request->mutable_select() = select.handle(); - *request->mutable_source() = source; - *request->mutable_init_value() = init_value; - *request->mutable_scatter() = scatter.handle(); - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReducePrecision( - const ComputationDataHandle& operand, const int exponent_bits, - const int mantissa_bits) { - OpRequest op_request; - ReducePrecisionRequest* request = - op_request.mutable_reduce_precision_request(); - *request->mutable_operand() = operand; - request->set_exponent_bits(exponent_bits); - request->set_mantissa_bits(mantissa_bits); - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Send(const ComputationDataHandle& operand, - const ChannelHandle& handle) { - OpRequest op_request; - SendRequest* request = op_request.mutable_send_request(); - *request->mutable_operand() = operand; - *request->mutable_channel_handle() = handle; - *op_request.mutable_computation() = computation_.handle(); - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, - const ChannelHandle& handle) { - OpRequest op_request; - RecvRequest* request = op_request.mutable_recv_request(); - *request->mutable_shape() = shape; - *request->mutable_channel_handle() = handle; - return RunOpAndParseResponse(&op_request); -} - -Computation ComputationBuilder::BuildAndNoteError() { - DCHECK(parent_builder_ != nullptr); - auto build_status = Build(); - if (!build_status.ok()) { - parent_builder_->NoteError( - 
AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); - return Computation(); - } - return build_status.ConsumeValueOrDie(); -} - -StatusOr ComputationBuilder::Build() { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); - } - - if (computation_.IsNull()) { - return FailedPrecondition("no computation was built"); - } - - return {std::move(computation_)}; -} - -/* static */ ConvolutionDimensionNumbers -ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(kConvBatchDimension); - dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_output_batch_dimension(kConvBatchDimension); - dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_kernel_output_feature_dimension( - kConvKernelOutputDimension); - dimension_numbers.set_kernel_input_feature_dimension( - kConvKernelInputDimension); - for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(i + 2); - dimension_numbers.add_kernel_spatial_dimensions(i + 2); - dimension_numbers.add_output_spatial_dimensions(i + 2); - } - return dimension_numbers; -} - -/* static */ StatusOr -ComputationBuilder::CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set({input_batch, input_feature, input_first_spatial, - input_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", - input_batch, input_feature, input_first_spatial, input_second_spatial); - } - if (std::set({kernel_output_feature, kernel_input_feature, - kernel_first_spatial, kernel_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", - kernel_output_feature, kernel_input_feature, kernel_first_spatial, - kernel_second_spatial); - } - if (std::set({output_batch, output_feature, output_first_spatial, - output_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", - output_batch, output_feature, output_first_spatial, - output_second_spatial); - } - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(input_batch); - dimension_numbers.set_input_feature_dimension(input_feature); - dimension_numbers.add_input_spatial_dimensions(input_first_spatial); - dimension_numbers.add_input_spatial_dimensions(input_second_spatial); - dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); - dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); - dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); - dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); - dimension_numbers.set_output_batch_dimension(output_batch); - dimension_numbers.set_output_feature_dimension(output_feature); - 
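A usage sketch of the convolution entry points tied to the default dimension numbers constructed above, assuming the pre-removal API in this diff; the function name and shapes are invented.

```c++
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Hypothetical 2D convolution relying on the defaults: activations laid out
// as {batch, feature, y, x} and the kernel as
// {output_feature, input_feature, y, x}.
xla::ComputationDataHandle DefaultLayoutConv(xla::ComputationBuilder* b) {
  auto input = b->Parameter(
      0, xla::ShapeUtil::MakeShape(xla::F32, {8, 3, 32, 32}), "input");
  auto kernel = b->Parameter(
      1, xla::ShapeUtil::MakeShape(xla::F32, {16, 3, 5, 5}), "kernel");
  // Conv() is equivalent to ConvWithGeneralDimensions() with
  // CreateDefaultConvDimensionNumbers(/*num_spatial_dims=*/2).
  return b->Conv(input, kernel, /*window_strides=*/{1, 1},
                 xla::Padding::kSame);
}
```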
dimension_numbers.add_output_spatial_dimensions(output_first_spatial); - dimension_numbers.add_output_spatial_dimensions(output_second_spatial); - return dimension_numbers; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h deleted file mode 100644 index 9431c2c459a564..00000000000000 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ /dev/null @@ -1,1065 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stacktrace.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Wraps an XLA client with a convenient interface for building up -// computations. Any errors encountered in building up the computation are -// deferred from being handled until Build() is called. -// -// Thread-compatible. -class ComputationBuilder { - public: - // client: client in which to build the computation. - // computation_name: name to use for the built computation. - ComputationBuilder(Client* client, const string& computation_name); - - ~ComputationBuilder(); - - // Returns the client the builder was initialized with. - Client* client() const { return client_; } - - // Returns the computation name. - const string& name() const { return name_; } - - // Sets OpMetadata that will be added to all instructions until cleared. - // - // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the Computation Builder. All subsequent - // instructions generated via this Computation Builder will have the same - // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } - - // Clears the HloMetadata state. - void ClearOpMetadata() { metadata_.Clear(); } - - // Sets an OpSharding that will be attached to all instructions until cleared. 
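To make the deferred-error model described above concrete, here is a minimal usage sketch assuming a connected xla::Client and the pre-removal ComputationBuilder API shown in this diff; `BuildAxpy` and the shape are invented.

```c++
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Each Op call returns a handle immediately; any error recorded along the
// way is surfaced only when Build() is called.
xla::StatusOr<xla::Computation> BuildAxpy(xla::Client* client) {
  xla::ComputationBuilder builder(client, "axpy");
  auto shape = xla::ShapeUtil::MakeShape(xla::F32, {1024});
  auto x = builder.Parameter(0, shape, "x");
  auto y = builder.Parameter(1, shape, "y");
  auto alpha = builder.ConstantR0<float>(2.0f);
  builder.Add(builder.Mul(alpha, x), y);  // root = last enqueued instruction
  // If any call above noted an error, Build() returns it here.
  return builder.Build();
}
```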
- void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - - // Clears the sharding. Ops will be sharded according to the default placement - // policy. - void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } - - // Returns the OpSharding that will be attached to all instructions. - const tensorflow::gtl::optional& sharding() const { - return sharding_; - } - - // Sets the builder to a mode where it will die immediately when an error is - // encountered, rather than producing it in a deferred fashion when Build() is - // called (which is the default). - void set_die_immediately_on_error(bool enabled) { - die_immediately_on_error_ = enabled; - } - - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); - - // Retrieves the (inferred) shape of the operand in the computation. - StatusOr> GetShape( - const ComputationDataHandle& operand); - - // Retrieves the (inferred) result for the current computation's shape. - StatusOr GetProgramShape(); - - // Enqueues a constant with the value of the given literal onto the - // computation. - ComputationDataHandle ConstantLiteral(const Literal& literal); - - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. - template - ComputationDataHandle ConstantR0(NativeT value); - template - ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice values); - ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values); - template - ComputationDataHandle ConstantR2( - std::initializer_list> values); - template - ComputationDataHandle ConstantFromArrayWithLayout( - const Array& values, const Layout& layout); - template - ComputationDataHandle ConstantFromArray(const Array& values); - template - ComputationDataHandle ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); - template - ComputationDataHandle ConstantR2FromArray2D(const Array2D& values); - template - ComputationDataHandle ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); - template - ComputationDataHandle ConstantR3FromArray3D(const Array3D& values); - template - ComputationDataHandle ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); - template - ComputationDataHandle ConstantR4FromArray4D(const Array4D& values); - - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. - template - ComputationDataHandle ConstantR1(int64 length, NativeT value); - - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. 
- // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); - - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); - - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); - - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes); - - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. - // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); - - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); - - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. 
- // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); - - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); - - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); - - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. - void Trace(const string& tag, const ComputationDataHandle& operand); - - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. - ComputationDataHandle Select(const ComputationDataHandle& pred, - const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false); - - // Enqueues a tuple-creation instruction onto the computation. - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); - - // Enqueues a tuple-element-get instruction onto the computation. - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); - - // Enqueues an equal-to comparison instruction onto the computation. - ComputationDataHandle Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a not-equal comparison instruction onto the computation. - ComputationDataHandle Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-or-equal comparison instruction onto the computation. - ComputationDataHandle Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-than comparison instruction onto the computation. 
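A short sketch tying together the slicing variants documented above, assuming the pre-removal API in this diff; the function name, shapes, and index values are invented (the {7, 8, 9} example follows the SliceInDim docs).

```c++
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Hypothetical usage of Slice, SliceInDim, and DynamicSlice.
void SliceExamples(xla::ComputationBuilder* b) {
  auto data = b->Parameter(
      0, xla::ShapeUtil::MakeShape(xla::F32, {7, 8, 9}), "data");

  // Static slice in one dimension: equivalent to array[:, 2:4:1, :].
  auto s1 = b->SliceInDim(data, /*start_index=*/2, /*limit_index=*/4,
                          /*stride=*/1, /*dimno=*/1);
  // The same region expressed with the general Slice() call.
  auto s2 = b->Slice(data, /*start_indices=*/{0, 2, 0},
                     /*limit_indices=*/{7, 4, 9}, /*strides=*/{1, 1, 1});

  // Dynamic slice: the start indices are a rank-1 operand, so they can be
  // computed at runtime; slice_sizes fixes the shape of the result.
  auto starts = b->ConstantR1<xla::int32>({0, 2, 0});
  auto s3 = b->DynamicSlice(data, starts, /*slice_sizes=*/{7, 2, 9});
  (void)s1; (void)s2; (void)s3;
}
```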
- ComputationDataHandle Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-than comparison instruction onto the computation. - ComputationDataHandle Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-or-equal comparison instruction onto the computation. - ComputationDataHandle Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a dot instruction onto the computation. - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - - // Enqueues a general dot instruction onto the computation. - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); - - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an - // error if either the input or the weight dimension numbers have conflicts. - static StatusOr CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial); - - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. - ComputationDataHandle Conv(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). - ComputationDataHandle ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. - ComputationDataHandle ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. 
- ComputationDataHandle ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. - ComputationDataHandle Fft(const ComputationDataHandle& operand, - FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); - - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. - ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); - - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. - // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. - void Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, const string& outfeed_config); - - // Enqueues a call instruction onto the computation. - ComputationDataHandle Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands); - - // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. - ComputationDataHandle CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - - // Enqueues a pseudo-op to represent host-side computation data-dependencies. - // During code generation, host send and receive operations will be generated - // to transfer |operands| to the host and a single result of |shape| back to - // the device. Host send/recv operations are emitted using |channel_name|. - // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO - // instruction scheduling. - ComputationDataHandle HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape); - - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. - ComputationDataHandle Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a complex conjugate instruction onto the computation. - ComputationDataHandle Conj(const ComputationDataHandle& operand); - - // Enqueues an add instruction onto the computation. 
- ComputationDataHandle Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a subtract instruction onto the computation. - ComputationDataHandle Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a multiply instruction onto the computation. - ComputationDataHandle Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a divide instruction onto the computation. - ComputationDataHandle Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a remainder instruction onto the computation. - ComputationDataHandle Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a max instruction onto the computation. - ComputationDataHandle Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a min instruction onto the computation. - ComputationDataHandle Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Element-wise logical operators - ComputationDataHandle And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Xor( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Not(const ComputationDataHandle& operand); - - ComputationDataHandle ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); - - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. - ComputationDataHandle ReduceAll(const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const Computation& computation); - - // Enqueues a windowed reduce instruction onto the computation. - ComputationDataHandle ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); - - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). 
- ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Returns the sum of the operand value across all replicas. All replicas - // supply one input to the sum and all replicas receive the resulting sum. - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); - - // Enqueues an operation that scatters the `source` array to the selected - // indices of each window. - ComputationDataHandle SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter); - - // As SelectAndScatter(), but the padding is given in the format - // returned by MakePadding(). - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter); - - // Enqueues an abs instruction onto the computation. - ComputationDataHandle Abs(const ComputationDataHandle& operand); - - // Enqueues a atan2 instruction onto the computation. - ComputationDataHandle Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues an exp instruction onto the computation. - ComputationDataHandle Exp(const ComputationDataHandle& operand); - - // Enqueues a floor instruction onto the computation. - ComputationDataHandle Floor(const ComputationDataHandle& operand); - - // Enqueues a ceil instruction onto the computation. - ComputationDataHandle Ceil(const ComputationDataHandle& operand); - - // Enqueues a round instruction onto the computation, rounding to nearest even - // with half-way cases rounding away from zero. - ComputationDataHandle Round(const ComputationDataHandle& operand); - - // Enqueues an log instruction (natural logarithm) onto the computation. - ComputationDataHandle Log(const ComputationDataHandle& operand); - - // Enqueues a sign instruction onto the computation. - ComputationDataHandle Sign(const ComputationDataHandle& operand); - - // Enqueues a cosine instruction onto the computation. - ComputationDataHandle Cos(const ComputationDataHandle& operand); - - // Enqueues a sine instruction onto the computation. - ComputationDataHandle Sin(const ComputationDataHandle& operand); - - // Enqueues a tanh instruction onto the computation. - ComputationDataHandle Tanh(const ComputationDataHandle& operand); - - // Enqueues a real-part instruction onto the computation. - ComputationDataHandle Real(const ComputationDataHandle& operand); - - // Enqueues an imaginary-part instruction onto the computation. - ComputationDataHandle Imag(const ComputationDataHandle& operand); - - // Enqueues a float32 sqrt instruction onto the computation. - // (float32 is specified as there is an implicit float32 0.5f constant - // exponent). 
-  ComputationDataHandle SqrtF32(const ComputationDataHandle& operand);
-
-  // Enqueues a float32 square instruction onto the computation.
-  // (float32 is specified as there is an implicit float32 2.0f constant
-  // exponent).
-  ComputationDataHandle SquareF32(const ComputationDataHandle& operand);
-
-  // Enqueues an lhs^rhs computation onto the computation.
-  ComputationDataHandle Pow(
-      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
-      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-  // Enqueues an operator that tests if the operand's values are finite, i.e.,
-  // not Inf or NaN. Defined only for floating-point types. Returns an array of
-  // booleans with the same shape where entries are true iff the corresponding
-  // entry was finite.
-  ComputationDataHandle IsFinite(const ComputationDataHandle& operand);
-
-  // Enqueues a convert instruction onto the computation that changes the
-  // element type of the operand array to primitive_type.
-  ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
-                                           PrimitiveType new_element_type);
-
-  // Enqueues a no-op instruction onto the computation that changes
-  // the element type of the operand array to primitive_type. The
-  // bit-widths of the source and destination element types must be
-  // identical.
-  ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand,
-                                           PrimitiveType new_element_type);
-
-  // Enqueues a float32 reciprocal instruction onto the computation.
-  // (float32 is specified as there is an implicit float32 -1.0f constant
-  // exponent).
-  //
-  // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
-  // shape of the operand.
-  ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);
-
-  // Enqueues a negate instruction onto the computation.
-  ComputationDataHandle Neg(const ComputationDataHandle& operand);
-
-  // Enqueues a count-leading-zeros instruction onto the computation.
-  ComputationDataHandle Clz(const ComputationDataHandle& operand);
-
-  // Enqueues a transpose instruction onto the computation.
-  ComputationDataHandle Transpose(
-      const ComputationDataHandle& operand,
-      tensorflow::gtl::ArraySlice<int64> permutation);
-
-  // Enqueues a reverse instruction onto the computation. The order of the
-  // elements in the given dimensions is reversed (i.e., the element at index i
-  // is moved to index dimension_size - 1 - i).
-  ComputationDataHandle Rev(const ComputationDataHandle& operand,
-                            tensorflow::gtl::ArraySlice<int64> dimensions);
-
-  // Enqueues a sort (in increasing order) instruction onto the computation.
-  ComputationDataHandle Sort(const ComputationDataHandle& operand);
-
-  // Enqueues a clamp instruction onto the computation.
-  ComputationDataHandle Clamp(const ComputationDataHandle& min,
-                              const ComputationDataHandle& operand,
-                              const ComputationDataHandle& max);
-
-  // Enqueues a map instruction onto the computation.
-  ComputationDataHandle Map(
-      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
-      const Computation& computation,
-      tensorflow::gtl::ArraySlice<int64> dimensions,
-      tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});
-
-  // Enqueues an N(mu, sigma) random number generation instruction onto the
-  // computation.
-  ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
-                                  const ComputationDataHandle& sigma,
-                                  const Shape& shape);
-
-  // Enqueues a U(a, b) random number generation instruction onto the
-  // computation. Returns values in the semi-open interval [a, b).
- ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); - - // Enqueues a while node onto the computation. - ComputationDataHandle While(const Computation& condition, - const Computation& body, - const ComputationDataHandle& init); - - // Enqueues a conditional node onto the computation. - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation); - - // Enqueues a ReducePrecision node onto the computation. - ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, - const int exponent_bits, - const int mantissa_bits); - - // Enqueues a Gather node onto the computation. - ComputationDataHandle Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); - - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. - void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); - - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. - ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters with index greater than or equal to - // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. - // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a - // compile-time constant without evaluating the computation. - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters = 0); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. - ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand, - const ComputationDataHandle& scale, - const ComputationDataHandle& offset, - float epsilon, int64 feature_index); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. - ComputationDataHandle BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, - int64 feature_index); - - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. 
-  //
-  // Returns a tuple of three elements:
-  //   - grad_operand: Gradient with respect to input `operand`
-  //   - grad_offset: Gradient with respect to input `offset`
-  //   - grad_scale: Gradient with respect to input `scale`
-  ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand,
-                                      const ComputationDataHandle& scale,
-                                      const ComputationDataHandle& batch_mean,
-                                      const ComputationDataHandle& batch_var,
-                                      const ComputationDataHandle& grad_output,
-                                      float epsilon, int64 feature_index);
-
-  // Computes the value of a constant indicated by a
-  // ComputationDataHandle using a non-optimized interpreter on the host.
-  //
-  // The operand must be from the computation currently being built -
-  // i.e., returned from this builder with no intervening call to
-  // Build(). This happens to currently work regardless of that, but
-  // that may stop working at any time.
-  //
-  // The operand must represent a constant value, which in this case
-  // means that it must not statically depend on any parameter of the
-  // computation that is being built other than the ones specified on the
-  // parameter list. The parameters in the list will be indexed by their
-  // parameter id property, so the number of parameters specified should be at
-  // least as many as the largest used parameter index.
-  //
-  // `IsConstant` can be used to test whether a computation is a compile-time
-  // constant without evaluating it. `ComputeConstant` only succeeds for
-  // computations where `IsConstant` returns true.
-  //
-  // This functionality can be useful when translating a computation
-  // into XLA where something that looked dynamic is required by
-  // XLA to be specified as a constant. E.g., the source
-  // computation (outside of XLA) may include a dynamic
-  // computation of the shape of something and ComputeConstant lets
-  // you determine what the value of that computation is in the case
-  // where the value can be determined at compile time.
-  //
-  // If output_layout is non-null, then the output of the computation
-  // will be stored using that layout.
-  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
-      const ComputationDataHandle& operand,
-      const Layout* output_layout = nullptr,
-      tensorflow::gtl::ArraySlice<Literal> parameters = {});
-
-  // Returns a new ComputationBuilder whose resultant Computation is used only
-  // by this ComputationBuilder. The sub-ComputationBuilder has the same
-  // die_immediately_on_error behavior as the parent.
-  std::unique_ptr<ComputationBuilder> CreateSubBuilder(
-      const string& computation_name);
-
-  // Modifies the computation being built so that executions of it
-  // will return the value associated with operand, rather than the
-  // last expression enqueued on the ComputationBuilder. Any subsequent
-  // operations added to the ComputationBuilder will not have any effect unless
-  // SetReturnValue is called again.
-  Status SetReturnValue(const ComputationDataHandle& operand);
-
-  // Builds the computation with the requested operations, or returns a non-ok
-  // status.
-  StatusOr<Computation> Build();
-
-  // Builds the computation with the requested operations, or notes an error in
-  // the parent ComputationBuilder and returns an empty computation if building
-  // failed. This function is intended to be used where the returned
-  // Computation is only used by the parent ComputationBuilder, and hence further
-  // operations on the returned Computation will simply fail with an error if an
-  // error occurred while building this computation.
If the built computation is - // to be used by a ComputationBuilder other than the parent ComputationBuilder - // then Build() should be used instead. - Computation BuildAndNoteError(); - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // ComputationDataHandle and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - private: - // Limited checking of convolution parameters. Returns false on - // error. - bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers); - - // The parent ComputationBuilder of a sub-ComputationBuilder. The - // parent_builder_ will be the nullptr if not a sub-ComputationBuilder. - ComputationBuilder* parent_builder_{nullptr}; - - // Helper function for creating a Window proto from user-supplied - // data. Returns true if the user-supplied data was valid. - bool MakeWindow(tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - Window* window); - - // Internal helper method that does the building for an arbitrary unary op. - ComputationDataHandle UnaryOp(UnaryOperation unop, - const ComputationDataHandle& operand); - - // Internal helper method that does the building for an arbitrary binary op. - // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. - ComputationDataHandle BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - - // Internal helper method that does the building for an arbitrary ternary op. - ComputationDataHandle TernaryOp(TernaryOperation triop, - const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - const ComputationDataHandle& ehs); - - // Internal helper method that does the building for a random number generator - // of a given distribution with an explicitly specified shape. - ComputationDataHandle RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); - - // Populates computation_ with a valid object or returns a failing status. - // This is used before any given operation is enqueued. - Status PrepareComputation(); - - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - // Helper function that runs the given op_request, filling in op_response. - // Before the op is run, PrepareComputation is called, and common fields in - // the op_request are filled in. - Status RunOp(OpRequest* op_request, OpResponse* op_response); - - // Helper function that calls RunOp and calls NoteError on failures. - void RunOpAndNoteError(OpRequest* op_request); - - // Helper function that calls RunOp and either returns the output computation - // data handle (on success) or a vacuous computation data handle (on failure). 
- ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request); - - // Helper function that implements GetShape without noting errors. This makes - // it easier to ensure the real GetShape will note errors on every error path. - StatusOr> GetShapeWithoutNoteError( - const ComputationDataHandle& operand); - - string name_; // Name to use for the built computation. - - // The first error encountered while building the computation. - // This is OK until the first error is encountered. - Status first_error_; - - // The saved stack trace from the point at which the first error occurred. - tensorflow::SavedStackTrace first_error_backtrace_; - - // The computation that operations are enqueued onto. - Computation computation_; - - // The client that the computation is created in. Not owned. - Client* client_; - - // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_ = false; - - // The metadata to attach to each op. This is structured as a "modal"-like - // operation, in order to simplify client code (and not sprinkle this metadata - // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_; - - // Sharding for this operator. This is structured as a "model"-like operation, - // in order to simplify client code, similar to metadata_. - tensorflow::gtl::optional sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); -}; - -template -ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0(value)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1( - tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, - NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline ComputationDataHandle ComputationBuilder::ConstantR1( - const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(*Literal::CreateR2(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( - const Array& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArray( - const Array& values) { - return ConstantLiteral(*Literal::CreateFromArray(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( - const Array2D& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( - const Array3D& values) { - return ConstantFromArray(values); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( - 
const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( - const Array4D& values) { - return ConstantFromArray(values); -} - -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -class ScopedShardingAssignment { - public: - ScopedShardingAssignment(xla::ComputationBuilder* builder, - tensorflow::gtl::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } - - ~ScopedShardingAssignment() { SetSharding(prev_sharding_); } - - private: - void SetSharding(const tensorflow::gtl::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } - - xla::ComputationBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 6e3c5cb484b8f1..7dee41f6a05025 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -87,6 +87,18 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { return dump_optimized_hlo_proto_to_; } +ExecutableBuildOptions& +ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { + return dump_unoptimized_hlo_proto_to_; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( tensorflow::StringPiece dirpath) { dump_per_pass_hlo_proto_to_ = dirpath.ToString(); diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 11f10983606fe0..9dc9be4423564f 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -64,6 +65,13 @@ class ExecutableBuildOptions { tensorflow::StringPiece dirpath); const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_unoptimized_hlo_proto_to() + const; + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). 
ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( @@ -76,6 +84,13 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_hlo_profile(bool enabled); tensorflow::gtl::optional hlo_profile() const; + void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + disabled_hlo_passes_.push_back(std::string(pass_name)); + } + const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + return disabled_hlo_passes_; + } + // Returns a string representation of the build options, suitable for // debugging. string ToString() const; @@ -87,8 +102,10 @@ class ExecutableBuildOptions { bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_unoptimized_hlo_proto_to_; tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; + std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index 40f59eaa68ebeb..2986d406001370 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -31,7 +31,7 @@ GlobalData::~GlobalData() { *request.mutable_data() = handle_; UnregisterResponse response; VLOG(1) << "requesting to unregister " << handle_.ShortDebugString(); - tensorflow::Status s = parent_->Unregister(&request, &response); + Status s = parent_->Unregister(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 59c4a53c05a454..d49d959a6c8112 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -22,8 +22,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", @@ -43,9 +41,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 63df449e0b3bdd..a1d34796ccfd86 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,7 +17,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -27,28 +28,6 @@ limitations under the License. 
namespace xla { namespace { -using InstructionGenerator = - ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, - const ComputationDataHandle&); - -Computation CreateScalarComputation(const string& name, PrimitiveType type, - ComputationBuilder* builder, - InstructionGenerator generator) { - std::unique_ptr b; - if (type == PRED) { - b = builder->CreateSubBuilder(name); - } else { - b = builder->CreateSubBuilder( - tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); - } - - const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - generator(b.get(), lhs, rhs); - return b->BuildAndNoteError(); -} - using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, @@ -71,71 +50,6 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, } // namespace -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "add", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Add(lhs, rhs); }); -} - -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "mul", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); -} - -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "ge", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Ge(lhs, rhs); }); -} - -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "max", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Max(lhs, rhs); }); -} - -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "min", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); -} - -Computation CreateScalarAndComputation(ComputationBuilder* builder) { - return CreateScalarComputation( - "and", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->And(lhs, rhs); }); -} - -Computation CreateScalarOrComputation(ComputationBuilder* builder) { - return CreateScalarComputation( - "or", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); }); -} - -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder) { - auto f = builder->ConstantR0(false); - Computation logical_or = CreateScalarOrComputation(builder); - TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, - builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); - std::iota(all_dimensions.begin(), all_dimensions.end(), 0); - return builder->Reduce(predicates, f, logical_or, all_dimensions); -} - XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( diff 
--git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f4d3fc801590fe..64b6b7d6335316 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,83 +18,38 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -// Creates a scalar add computation and returns it. -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar multiply computation and returns it. -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar ge computation and returns it. -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar max computation and returns it. -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar min computation and returns it. -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar logical AND computation and returns it. -Computation CreateScalarAndComputation(ComputationBuilder* builder); - -// Creates a scalar logical OR computation and returns it. -Computation CreateScalarOrComputation(ComputationBuilder* builder); - -// Returns whether any predicate in "predicates" is set. -// -// Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder); - -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Creates a scalar add computation and returns it. XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar multiply computation and returns it. XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar ge computation and returns it. XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar max computation and returns it. XlaComputation CreateScalarMaxComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar min computation and returns it. XlaComputation CreateScalarMinComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar logical AND computation and returns it. XlaComputation CreateScalarAndComputation(XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Creates a scalar logical OR computation and returns it. XlaComputation CreateScalarOrComputation(XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. 
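As a usage illustration of the XlaBuilder-based helpers that this change keeps in arithmetic.h, the following sketch reduces a rank-1 F32 parameter to its scalar sum with CreateScalarAddComputation. It is not part of the patch: `XlaBuilder::ConstantR0` and `XlaBuilder::Reduce` are assumed to mirror the removed ComputationBuilder methods of the same names, and `BuildVectorSum` is a hypothetical helper name.

```c++
// Illustrative sketch only; ConstantR0/Reduce are assumed analogues of the
// removed ComputationBuilder API and do not appear in this patch.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<xla::XlaComputation> BuildVectorSum() {
  xla::XlaBuilder builder("vector_sum");
  // Rank-1 F32 input of length 8 (arbitrary size for the example).
  auto input = builder.Parameter(
      /*parameter_number=*/0, xla::ShapeUtil::MakeShape(xla::F32, {8}),
      "input");
  // The scalar add computation serves as the reduction operator.
  xla::XlaComputation add =
      xla::CreateScalarAddComputation(xla::F32, &builder);
  builder.Reduce(input, builder.ConstantR0<float>(0.0f), add,
                 /*dimensions_to_reduce=*/{0});
  return builder.Build();  // Computation rooted at the Reduce above.
}
```

The same pattern underlies the removed Any() helper above, which reduced a predicate array with CreateScalarOrComputation over all of its dimensions.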
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 311dc4bdd72cfd..3380af9f303b1d 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -15,8 +15,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/testing.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -46,16 +45,14 @@ int64 DataSizeOfShape(const Shape& shape) {
   return total_size;
 }
-// Creates a ComputationDataHandle for an op that generates fake data with the
-// given shape.
-ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape,
-                                              ComputationBuilder* builder) {
+// Creates an XlaOp for an op that generates fake data with the given shape.
+XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) {
   if (ShapeUtil::IsArray(shape)) {
     return builder->Broadcast(
         builder->ConstantLiteral(Literal::One(shape.element_type())),
         AsInt64Slice(shape.dimensions()));
   }
-  std::vector<ComputationDataHandle> parts;
+  std::vector<XlaOp> parts;
   for (const Shape& s : shape.tuple_shapes()) {
     parts.push_back(BuildFakeDataOpOnDevice(s, builder));
   }
@@ -64,11 +61,10 @@ ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape,
 std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
                                                        Client* client) {
-  ComputationBuilder b(
-      client,
+  XlaBuilder b(
       tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
   BuildFakeDataOpOnDevice(shape, &b);
-  Computation computation = b.Build().ConsumeValueOrDie();
+  XlaComputation computation = b.Build().ConsumeValueOrDie();
   auto execution_options = CreateDefaultExecutionOptions();
   *execution_options.mutable_shape_with_output_layout() = shape;
@@ -96,21 +92,6 @@ std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
   return MakeFakeDataViaDeviceOrDie(shape, client);
 }
-std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
-    const Computation& computation, Client* client) {
-  auto program_shape =
-      client->GetComputationShape(computation).ConsumeValueOrDie();
-
-  // For every (unbound) parameter that the computation wants, we manufacture
-  // some arbitrary data so that we can invoke the computation.
-  std::vector<std::unique_ptr<GlobalData>> fake_arguments;
-  for (const Shape& parameter : program_shape->parameters()) {
-    fake_arguments.push_back(MakeFakeDataOrDie(parameter, client));
-  }
-
-  return fake_arguments;
-}
-
 std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
     const XlaComputation& computation, Client* client) {
   CHECK(computation.proto().has_program_shape())
diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h
index 1dc2622972d5fd..dc613099e2b42a 100644
--- a/tensorflow/compiler/xla/client/lib/testing.h
+++ b/tensorflow/compiler/xla/client/lib/testing.h
@@ -20,7 +20,6 @@ limitations under the License.
#include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -33,12 +32,6 @@ namespace xla { std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client); -// Returns vector of GlobalData handles of fake data (created using -// MakeFakeDataOrDie) that are correctly shaped arguments for the given -// computation. -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client); - // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1c1270590375ab..ae0308020d014e 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -48,30 +48,52 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, << "Must have a valid device ordinal that the executable was built for."; } -tensorflow::Status LocalExecutable::ValidateExecutionOptions( +Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { - const ComputationLayout& computation_layout = - executable_->module_config().entry_computation_layout(); + const ComputationLayout& host_computation_layout = + executable_->module_config().host_entry_computation_layout(); + const ComputationLayout& device_computation_layout = + executable_->module_config().device_entry_computation_layout(); // Check argument number, shapes, and layouts. 
- if (arguments.size() != computation_layout.parameter_count()) { + if (arguments.size() != host_computation_layout.parameter_count()) { return InvalidArgument( "invalid number of arguments for computation: expected %d, got %zu", - computation_layout.parameter_count(), arguments.size()); + host_computation_layout.parameter_count(), arguments.size()); + } + if (arguments.size() != device_computation_layout.parameter_count()) { + return InvalidArgument( + "invalid number of arguments for computation: expected %d, got %zu", + device_computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { - if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( + if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape( arguments[i]->on_host_shape())) { return InvalidParameterArgument( executable_.get(), i, - "Argument does not match shape or layout of computation parameter " + "Argument does not match host shape or layout of computation " + "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) + ShapeUtil::HumanString( + host_computation_layout.parameter_layout(i).shape()) .c_str(), ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } + if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape( + arguments[i]->on_device_shape())) { + return InvalidParameterArgument( + executable_.get(), i, + "Argument does not match device shape or layout of computation " + "parameter " + "%d: want %s, got %s", + i, + ShapeUtil::HumanString( + device_computation_layout.parameter_layout(i).shape()) + .c_str(), + ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str()); + } } if (run_options.stream() != nullptr) { @@ -163,7 +185,7 @@ StatusOr LocalExecutable::Run( run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); - if (executable_->dumping()) { + if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); } return executable_->ExecuteOnStreamWrapper( @@ -173,36 +195,36 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments) { - executable_->session_module()->set_execution_platform( + executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module())); + TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer result, executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module())); - TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); + TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot()); return std::move(result); } -tensorflow::Status LocalExecutable::RecordArguments( +Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module) { - session_module->clear_arguments(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*argument)); - *session_module->add_arguments() = literal->ToProto(); + *hlo_snapshot->add_arguments() = literal->ToProto(); } return Status::OK(); } 
-tensorflow::Status LocalExecutable::RecordResult( - const ShapedBuffer* result, SessionModule* session_module) { - session_module->clear_result(); +Status LocalExecutable::RecordResult(const ShapedBuffer* result, + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); - *session_module->mutable_result() = literal->ToProto(); + *hlo_snapshot->mutable_result() = literal->ToProto(); return Status::OK(); } @@ -239,25 +261,6 @@ Backend* LocalClient::mutable_backend() { return local_service_->mutable_backend(); } -StatusOr> LocalClient::Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options) { - ExecutableBuildOptions updated_options = options; - if (options.device_ordinal() == -1) { - updated_options.set_device_ordinal(default_device_ordinal()); - VLOG(3) << "Set device ordinal to default value of: " - << updated_options.device_ordinal(); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); -} - StatusOr> LocalClient::Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -301,6 +304,11 @@ StatusOr> LocalClient::ShapedBufferToLiteral( shaped_buffer); } +StatusOr LocalClient::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + return local_service_->GlobalDataToShapedBuffer(data, replica_number); +} + Status LocalClient::TransferToInfeedLocal(const Literal& literal, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index f306c520ede001..4d9e0d7cd9d6dd 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -19,12 +19,13 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -42,15 +43,6 @@ class LocalExecutable { const tensorflow::gtl::ArraySlice arguments, ExecutableRunOptions run_options); - // Return the layout (contained in a shape) of the result produced by the - // computation. - const Shape& result_layout() const { - return executable_->module_config() - .entry_computation_layout() - .result_layout() - .shape(); - } - // Return the options used to build the executable. const ExecutableBuildOptions& build_options() const { return build_options_; } @@ -67,25 +59,30 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. 
- tensorflow::Status ValidateExecutionOptions( + // + // The given ExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). + Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. + // + // The given ServiceExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments); // Records the arguments used to invoke the computation in a SessionModule // proto. - tensorflow::Status RecordArguments( + Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module); + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. - tensorflow::Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( @@ -116,17 +113,11 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // Build and return a LocalExecutable object. The executable is compiled using - // the given argument layouts and options. - StatusOr> Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. + // The given ExecutableBuildOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). StatusOr> Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -145,6 +136,11 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + // Transfer the given literal to the infeed queue of the given device. // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 0d6e207971ec64..507a2dc5f088e1 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -37,7 +37,6 @@ cc_library( ], ) -# TODO(b/74197823): Replace computation_builder with xla_builder. 
cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 1899983e442116..5e17cc4dfb0b22 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -57,16 +57,6 @@ bool CanBeRoot(HloOpcode opcode) { } } -StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) { - std::vector operand_shapes; - for (const XlaOp& operand : operands) { - TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); - operand_shapes.push_back(shape); - } - return operand_shapes; -} - } // namespace StatusOr XlaBuilder::GetShape(const XlaOp& op) const { @@ -76,12 +66,14 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { return instr->shape(); } -StatusOr XlaOp::GetShape() const { - if (builder_ == nullptr) { - return InvalidArgument( - "cannot GetShape for an invalid XlaOp with handle %lld", handle()); +StatusOr> XlaBuilder::GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const { + std::vector operand_shapes; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + operand_shapes.push_back(shape); } - return builder_->GetShape(*this); + return operand_shapes; } XlaBuilder::XlaBuilder(const string& computation_name) @@ -286,7 +278,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); CHECK(ShapeUtil::IsScalar(operand_shape) || ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); @@ -325,7 +317,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferUnaryOpShape(unop, operand_shape)); return AddInstruction(std::move(instr), unop, {operand}); @@ -337,8 +329,8 @@ XlaOp XlaBuilder::BinaryOp( tensorflow::gtl::ArraySlice broadcast_dimensions) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); @@ -374,12 +366,12 @@ XlaOp XlaBuilder::BinaryOp( updated_rhs = !should_broadcast_lhs ? 
broadcasted_operand : rhs; } - TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape()); + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs)); if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(instr.shape(), updated_lhs)); } - TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape()); + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs)); if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(instr.shape(), updated_rhs)); @@ -393,9 +385,9 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferTernaryOpShape( triop, lhs_shape, rhs_shape, ehs_shape)); @@ -437,7 +429,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } -XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { +XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); @@ -485,7 +477,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, XlaOp XlaBuilder::Broadcast( const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); @@ -633,7 +625,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferReshapeShape( operand_shape, dimensions, new_sizes)); @@ -647,7 +639,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice new_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); return Reshape(operand, dimensions, new_sizes); @@ -1002,7 +994,7 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const tensorflow::gtl::ArraySlice fft_length) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferFftShape(operand_shape, fft_type, 
fft_length)); @@ -1173,6 +1165,10 @@ XlaOp XlaBuilder::Exp(const XlaOp& operand) { return UnaryOp(HloOpcode::kExp, operand); } +XlaOp XlaBuilder::Expm1(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExpm1, operand); +} + XlaOp XlaBuilder::Floor(const XlaOp& operand) { return UnaryOp(HloOpcode::kFloor, operand); } @@ -1189,6 +1185,10 @@ XlaOp XlaBuilder::Log(const XlaOp& operand) { return UnaryOp(HloOpcode::kLog, operand); } +XlaOp XlaBuilder::Log1p(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog1p, operand); +} + XlaOp XlaBuilder::Sign(const XlaOp& operand) { return UnaryOp(HloOpcode::kSign, operand); } @@ -1225,7 +1225,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice permutation) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferTransposeShape(operand_shape, permutation)); @@ -1613,13 +1613,35 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { return NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); + auto b = CreateSubBuilder("sum"); + b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + TF_ASSIGN_OR_RETURN(auto computation, b->Build()); + return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{}, + /*channel_id=*/tensorflow::gtl::nullopt); + }); +} +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (!replica_group_ids.empty() || channel_id.has_value()) { + return Unimplemented( + "replica_group_ids and channel_id and is not supported in AllReduce"); + } + + HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + AddCalledComputation(computation, &instr); + return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, {operand}); }); @@ -1948,11 +1970,18 @@ StatusOr XlaBuilder::LookUpInstruction( const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); + if (op.builder_ == nullptr) { + return InvalidArgument( + "invalid XlaOp with handle %lld; the builder of this op is freed", + op.handle()); + } if (op.builder_ != this) { - return InvalidArgument("invalid XlaOp with handle %lld", op.handle()); + return InvalidArgument( + "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "it in builder '%s'", + op.handle(), op.builder_->name().c_str(), this->name().c_str()); } - TF_RET_CHECK(op.builder_ == this); if (op.handle() >= instructions_.size() || op.handle() < 0) { return InvalidArgument("no XlaOp value %lld", op.handle()); } diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 4955f1515d66af..532cae014848e1 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -13,10 +13,6 
@@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO(b/74197823): Replace computation_builder.h with this file. -// -// This is NOT YET ready to use. - #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ @@ -48,15 +44,11 @@ class XlaBuilder; // This represents an instruction that has been enqueued using the XlaBuilder. // This is used to pass to subsequent computations that depends upon the // instruction as an operand. -// -// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. class XlaOp { public: XlaOp() : handle_(0), builder_(nullptr) {} ~XlaOp() {} - StatusOr GetShape() const; - const XlaBuilder* builder() const { return builder_; } bool operator==(const XlaOp& rhs) const { @@ -87,8 +79,6 @@ class XlaOp { // A convenient interface for building up computations. // // Thread-compatible. -// -// TODO(b/74197823): Replace xla::ComputationBuilder with this one. class XlaBuilder { public: // computation_name: name to use for the built computation. @@ -139,7 +129,7 @@ class XlaBuilder { // Enqueues a constant with the value of the given literal onto the // computation. - XlaOp ConstantLiteral(const Literal& literal); + XlaOp ConstantLiteral(const LiteralSlice& literal); // Enqueues a constant onto the computation. Methods are templated on the // native host type (NativeT) which corresponds to a specific XLA @@ -542,6 +532,29 @@ class XlaBuilder { // supply one input to the sum and all replicas receive the resulting sum. XlaOp CrossReplicaSum(const XlaOp& operand); + // Enqueues an operation that does an AllReduce of the operand across cores. Here + // AllReduce means doing a reduction on the input operand across cores and then + // broadcasting the reduction result to those cores. The reduction function is + // defined by `computation`, which should be a commutative computation on + // scalars, e.g., add, min, or max. The way that AllReduce is applied is + // configured by: + // + // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, with 4 replicas, replica_group_ids={0,1,0,1} means + // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1. + // + // - `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied across models. + // + // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); + // Enqueues an operation that scatters the `source` array to the selected + // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -571,6 +584,9 @@ class XlaBuilder { // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); + // Enqueues an expm1 instruction onto the computation. + XlaOp Expm1(const XlaOp& operand); + // Enqueues a floor instruction onto the computation. XlaOp Floor(const XlaOp& operand); @@ -584,6 +600,9 @@ class XlaBuilder { // Enqueues an log instruction (natural logarithm) onto the computation.
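For completeness, a hedged caller-side sketch of the new two-argument `CrossReplicaSum` overload declared above, mirroring the sub-builder pattern used by the one-argument implementation in `xla_builder.cc`. The `CrossReplicaMax` helper name is hypothetical, and the sketch assumes `CreateSubBuilder` and `Max` are available on `XlaBuilder` at this revision; the subgroup and channel arguments are left at their defaults because the implementation still rejects non-default values as unimplemented.

```c++
namespace xla {

// Hypothetical helper: a cross-replica MAX, built by handing CrossReplicaSum
// an explicit scalar reduction computation instead of the default "sum".
// Assumes `x` is an F32 operand built on the same builder `b`.
XlaOp CrossReplicaMax(XlaBuilder* b, const XlaOp& x) {
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  auto sub = b->CreateSubBuilder("max");
  sub->Max(sub->Parameter(/*parameter_number=*/0, scalar, "a"),
           sub->Parameter(/*parameter_number=*/1, scalar, "b"));
  XlaComputation reduction = sub->Build().ConsumeValueOrDie();

  // replica_group_ids defaults to {} (one group) and channel_id to nullopt
  // (no cross-model reduction), the only combination supported so far.
  return b->CrossReplicaSum(x, reduction);
}

}  // namespace xla
```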
XlaOp Log(const XlaOp& operand); + // Enqueues an log1p instruction (log(x+1)) onto the computation. + XlaOp Log1p(const XlaOp& operand); + // Enqueues a sign instruction onto the computation. XlaOp Sign(const XlaOp& operand); @@ -847,6 +866,10 @@ class XlaBuilder { // computation and fills the root_id in the pointer. StatusOr GetProgramShape(int64* root_id) const; + // Returns shapes for the operands. + StatusOr> GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const; + // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful // operation such as `RngNormal` or `Infeed`. The visitor walks the @@ -981,8 +1004,6 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { // RAII-style object: sets the current sharding assignment in builder on // construction, and sets back to the previous assignment on destruction. -// -// TODO(b/74197823): This is a part of a NOT YET ready refactor. class XlaScopedShardingAssignment { public: XlaScopedShardingAssignment(xla::XlaBuilder* builder, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index ce984564d016ce..2df3ea3af0d4fc 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -76,7 +76,7 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { auto y = b.Parameter(1, y_shape, "y"); auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); - TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape()); + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); @@ -188,8 +188,10 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { builder.Add(p0, p0); auto statusor = builder.Build(); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Do not add XlaOp from builder b1 to builder main")); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "built by builder 'b1', but is trying to use it in builder 'main'")); } TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h index b70b57e9ffec40..0ffba208b1f868 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -25,8 +25,6 @@ limitations under the License. namespace xla { // The computation graph that the user builds up with the XlaBuilder. -// -// TODO(b/74197823): Replace xla::Computation with this one. class XlaComputation { public: XlaComputation() : unique_id_(-1) {} diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h new file mode 100644 index 00000000000000..a1463aa15941b9 --- /dev/null +++ b/tensorflow/compiler/xla/error_spec.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ +#define TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ + +namespace xla { + +// Structure describing permissible absolute and relative error bounds. +struct ErrorSpec { + explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) + : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} + + float abs; // Absolute error bound. + float rel; // Relative error bound. + + // If relaxed_nans is true then any result is valid if we are expecting NaNs. + // In effect, this allows the tested operation to produce incorrect results + // for inputs outside its mathematical domain. + bool relaxed_nans; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 99b8f0558e6e39..a472747bd174e3 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -45,17 +45,6 @@ stream_executor::Stream* ExecutableRunOptions::stream() const { return stream_; } -ExecutableRunOptions& ExecutableRunOptions::set_inter_op_thread_pool( - tensorflow::thread::ThreadPool* inter_op_thread_pool) { - inter_op_thread_pool_ = inter_op_thread_pool; - return *this; -} - -tensorflow::thread::ThreadPool* ExecutableRunOptions::inter_op_thread_pool() - const { - return inter_op_thread_pool_; -} - ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool) { intra_op_thread_pool_ = intra_op_thread_pool; diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index a306ae16ba4aee..416131be006e6e 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -65,12 +65,6 @@ class ExecutableRunOptions { ExecutableRunOptions& set_stream(stream_executor::Stream* stream); stream_executor::Stream* stream() const; - // Sets the thread pool on which to run parallel CPU backend - // computations. Does not take ownership. - ExecutableRunOptions& set_inter_op_thread_pool( - tensorflow::thread::ThreadPool* inter_op_thread_pool); - tensorflow::thread::ThreadPool* inter_op_thread_pool() const; - // Sets the thread pool device on which to run Eigen subcomputations. // Does not take ownership. 
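To make the `ErrorSpec` fields introduced above concrete: in the comparator added later in this change, an element only counts as a mismatch when it exceeds *both* the absolute and the relative bound, or when its NaN-ness disagrees with the expectation (subject to `relaxed_nans`). A standalone illustration of that acceptance rule, not the library code itself:

```c++
#include <cmath>
#include <cstdio>

// Stand-in for xla::ErrorSpec, for illustration only.
struct ErrorSpec {
  float abs;          // absolute error bound
  float rel;          // relative error bound
  bool relaxed_nans;  // if true, any result is accepted where a NaN was expected
};

// Returns true if `actual` is acceptably close to `expected` under `spec`.
// A value fails only if it exceeds BOTH bounds, or if the NaN-ness mismatches.
bool WithinErrorSpec(float expected, float actual, const ErrorSpec& spec) {
  const bool nan_mismatch =
      spec.relaxed_nans ? (!std::isnan(expected) && std::isnan(actual))
                        : (std::isnan(expected) != std::isnan(actual));
  if (nan_mismatch) return false;
  if (std::isnan(expected) && std::isnan(actual)) return true;
  const float abs_error = std::fabs(actual - expected);
  const float rel_error = abs_error / std::fabs(expected);
  return !(abs_error > spec.abs && rel_error > spec.rel);
}

int main() {
  const ErrorSpec spec{/*abs=*/1e-4f, /*rel=*/1e-3f, /*relaxed_nans=*/false};
  std::printf("%d\n", WithinErrorSpec(1000.0f, 1000.5f, spec));  // 1: rel bound ok
  std::printf("%d\n", WithinErrorSpec(1e-6f, 2e-6f, spec));      // 1: abs bound ok
  std::printf("%d\n", WithinErrorSpec(1.0f, 1.1f, spec));        // 0: both exceeded
}
```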
ExecutableRunOptions& set_intra_op_thread_pool( @@ -93,7 +87,6 @@ class ExecutableRunOptions { int device_ordinal_ = -1; DeviceAssignment* device_assignment_ = nullptr; stream_executor::Stream* stream_ = nullptr; - tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index fdc4bbdd8b162b..e8f29b83291a7c 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -64,6 +65,16 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor) { + Layout layout; + layout.set_format(DENSE); + for (int i = major_to_minor.size() - 1; i >= 0; i--) { + layout.add_minor_to_major(major_to_minor[i]); + } + return layout; +} + /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { Layout layout; layout.set_format(SPARSE); @@ -87,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + // Opaque and token types have empty layouts. + return Layout(); + } + // A Layout proto corresponds to a single array, not a tuple. - DCHECK(!ShapeUtil::IsTuple(shape)); + CHECK(ShapeUtil::IsArray(shape)); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -115,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsOpaque(*shape)) { - shape->clear_layout(); - } else { + } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); + } else { + // Opaque, token types etc. have no layout. + shape->clear_layout(); } } @@ -139,8 +156,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape( - const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutInShape(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape. if (shape.has_layout()) { @@ -149,30 +165,34 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : shape.tuple_shapes()) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } - return tensorflow::Status::OK(); - } else if (ShapeUtil::IsOpaque(shape)) { - if (shape.has_layout()) { - return InvalidArgument("opaque should not have a layout field"); - } - return tensorflow::Status::OK(); - } else { - // Array shape. 
+ return Status::OK(); + } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", ShapeUtil::HumanString(shape).c_str()); } return ValidateLayoutForShape(shape.layout(), shape); + } else { + // Token, opaque, etc. shape. + if (shape.has_layout()) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape( - const Layout& layout, const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, + const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (ShapeUtil::IsOpaque(shape)) { - return tensorflow::Status::OK(); + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); } if (layout.format() == INVALID_FORMAT) { @@ -224,7 +244,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } - return tensorflow::Status::OK(); + return Status::OK(); } /* static */ void LayoutUtil::ClearLayout(Shape* shape) { @@ -263,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsPadded(const Shape& shape) { - if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) || + if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) || shape.layout().padded_dimensions_size() == 0) { return false; } @@ -313,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); - } else if (ShapeUtil::IsOpaque(shape)) { + } else if (!ShapeUtil::IsArray(shape)) { + // Opaque, token types etc. ignore layout. return true; } return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; @@ -383,7 +404,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { namespace { // Internal helper for recursively copying layouts. 
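The array/opaque/token split above turns a silent pass into an explicit error when a non-array shape carries a layout. A hypothetical test in the style of `layout_util_test.cc` (which appears later in this diff), not part of the patch:

```c++
// Hypothetical sketch: non-array shapes must not carry layouts.
TEST_F(LayoutUtilTest, ValidateLayoutTokenVsArray) {
  // A dense array shape with a well-formed layout validates fine.
  Shape array = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
  EXPECT_IS_OK(LayoutUtil::ValidateLayoutInShape(array));

  // A token shape without a layout is fine; forcing a layout onto it is now
  // rejected with InvalidArgument instead of being accepted.
  Shape token = ShapeUtil::MakeTokenShape();
  EXPECT_IS_OK(LayoutUtil::ValidateLayoutInShape(token));
  *token.mutable_layout() = LayoutUtil::MakeLayout({});
  EXPECT_FALSE(LayoutUtil::ValidateLayoutInShape(token).ok());
}
```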
-tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { +Status CopyLayoutInternal(const Shape& src, Shape* dst) { if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); @@ -410,25 +431,21 @@ tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { dst->clear_layout(); } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace /* static */ -tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, - Shape* dst) { +Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return CopyLayoutInternal(src, dst); } /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) { - return false; - } if (ShapeUtil::IsTuple(lhs)) { - if (ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -437,9 +454,12 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, } } return true; - } else { + } else if (ShapeUtil::IsArray(lhs)) { return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && LayoutUtil::Equal(lhs.layout(), rhs.layout()); + } else { + // Layouts of non-array and non-tuple shapes is ignored. + return true; } } @@ -465,4 +485,25 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) { return out; } +/*static*/ size_t LayoutUtil::Hash(const Layout& layout) { + using tensorflow::hash; + using tensorflow::Hash64Combine; + + size_t hash_value = hash()(layout.format()); + + for (int64 minor_to_major : layout.minor_to_major()) { + hash_value = Hash64Combine(hash_value, hash()(minor_to_major)); + } + + for (int64 padded_dim : layout.padded_dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(padded_dim)); + } + + hash_value = + Hash64Combine(hash_value, hash()(layout.padding_value())); + hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); + + return hash_value; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c54eb2201b66a..739bbe73675c7f 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,9 +20,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Similar to MakeLayout, but take indices in reverse order. + static Layout MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) static Layout MakeSparseLayout(int64 max_sparse_elements); @@ -61,12 +65,12 @@ class LayoutUtil { static void SetToDefaultLayout(ProgramShape* program_shape); // Validates that the layout within the given shape is correct. 
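A similar hypothetical sketch for the `MakeLayoutFromMajorToMinor` helper added in this change (not part of the patch): the helper simply reverses the given order, so a major-to-minor listing produces the same layout as the equivalent minor-to-major listing passed to `MakeLayout`.

```c++
// Hypothetical sketch: major-to-minor is just MakeLayout's order reversed.
TEST_F(LayoutUtilTest, MakeLayoutFromMajorToMinorReversesOrder) {
  EXPECT_TRUE(
      LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}),
                        LayoutUtil::MakeLayoutFromMajorToMinor({2, 1, 0})));

  // A rank-2 row-major layout written both ways.
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
                                LayoutUtil::MakeLayoutFromMajorToMinor({0, 1})));
}
```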
- static tensorflow::Status ValidateLayoutInShape(const Shape& shape); + static Status ValidateLayoutInShape(const Shape& shape); // Validates that the provided layout satisfies invariants for the given // shape. - static tensorflow::Status ValidateLayoutForShape(const Layout& layout, - const Shape& shape); + static Status ValidateLayoutForShape(const Layout& layout, + const Shape& shape); // Clears the layout in the given Shape. After this function is called, // HasLayout will return false for the shape. @@ -179,8 +183,7 @@ class LayoutUtil { // tuples. 'src' and 'dst' need not be compatible but the two shapes must // have the same tuple structure (if any) and arrays must have the same // rank. within the shapes must have the same number of dimensions. - static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src, - Shape* dst); + static Status CopyLayoutBetweenShapes(const Shape& src, Shape* dst); // Returns true if the layouts of lhs and rhs are equal, false // otherwise. Recursively compares layouts of tuples. @@ -195,6 +198,9 @@ class LayoutUtil { static bool AreDimensionsConsecutive(const Layout& layout, tensorflow::gtl::ArraySlice dims); + // Compute a hash for `layout`. + static size_t Hash(const Layout& layout); + private: TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 4fd1d818e3e3b4..e4c825450dcd45 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { "elements, but shape is rank")); } +TEST_F(LayoutUtilTest, CopyTokenLayout) { + Shape src = ShapeUtil::MakeTokenShape(); + Shape dst = ShapeUtil::MakeTokenShape(); + + // Layouts are trivially the same for token types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyOpaqueLayout) { + Shape src = ShapeUtil::MakeOpaqueShape(); + Shape dst = ShapeUtil::MakeOpaqueShape(); + + // Layouts are trivially the same for opaque types and copying layouts should + // be a nop. 
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, ClearLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), @@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) { EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); } +TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) { + // Opaque and token types trivially have layouts. + for (Shape shape : + {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) { + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + LayoutUtil::ClearLayout(&shape); + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + } +} + TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index bc8405703b02dc..f42fb92359f40e 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -47,6 +47,12 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // Set cudnn batchnorm off by default; it does not provide a performance win // on average. flags->set_xla_gpu_use_cudnn_batchnorm(false); + + // Run all GPU work on one stream by default. Using multiple streams + // increases memory usage and we lack strong motivating benchmarks for tuning + // the heuristics needed to decide when to run on multiple streams. See + // b/77879207. + flags->set_xla_gpu_disable_multi_streaming(true); } // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc new file mode 100644 index 00000000000000..bf9679cafec72c --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -0,0 +1,741 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_comparison.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" + +using tensorflow::strings::Appendf; +using tensorflow::strings::Printf; +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + +namespace xla { +namespace literal_comparison { +namespace { + +// Helper function for comparing a floating point type, FloatT, bitwise equal +// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT +// -- on miscompare, a nice error message is given in the AssertionFailure. +template +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { + auto ulhs = tensorflow::bit_cast(lhs); + auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); + if (ulhs != urhs) { + return InvalidArgument( + "floating values are not bitwise-equal; and equality testing " + "was requested: %s=%g=%a vs %s=%g=%a", + StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, + StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double); + } + return Status::OK(); +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +Status CompareEqual(NativeT lhs, NativeT rhs) { + if (lhs == rhs) { + return Status::OK(); + } + return InvalidArgument("Expected equality of these values:\n %s\n %s", + StrCat(lhs).c_str(), StrCat(rhs).c_str()); +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. +template <> +Status CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(Eigen::half lhs, Eigen::half rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(float lhs, float rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(double lhs, double rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(complex64 lhs, complex64 rhs) { + auto res = CompareEqual(lhs.real(), rhs.real()); + if (!res.ok()) { + return res; + } + return CompareEqual(lhs.imag(), rhs.imag()); +} + +// A recursive function which iterates through every index of expected and +// actual literal and compares their values elementwise. Returns true if all +// elements are equal. +template +Status Equal(LiteralSlice expected, LiteralSlice actual, + tensorflow::gtl::MutableArraySlice multi_index, + int64 dimension) { + if (dimension == expected.shape().dimensions_size()) { + NativeT expected_value = expected.Get(multi_index); + NativeT actual_value = actual.Get(multi_index); + return CompareEqual(expected_value, actual_value); + } + + Status result; + for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { + multi_index[dimension] = i; + result.Update(Equal(expected, actual, multi_index, dimension + 1)); + } + return result; +} + +// Gets the total element count. For tuples, this is not the count of tuple +// elements, but the sum of elements of each tuple element. 
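The bit-cast comparison above exists because `operator==` on floating-point values both equates numbers the tests want to distinguish (`-0.0f == 0.0f`) and distinguishes values the tests want to equate (NaN never compares equal to itself). A standalone illustration, independent of the XLA helpers:

```c++
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Reads a float's raw 32-bit pattern; memcpy keeps this well-defined.
static std::uint32_t BitPattern(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return u;
}

static bool BitwiseEqual(float a, float b) {
  return BitPattern(a) == BitPattern(b);
}

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // operator== treats +0 and -0 as equal, but their bit patterns differ.
  std::printf("-0 == +0: %d, bitwise: %d\n", -0.0f == 0.0f,
              BitwiseEqual(-0.0f, 0.0f));
  // operator== never equates NaNs, but identical NaN payloads match bitwise.
  std::printf("nan == nan: %d, bitwise: %d\n", nan == nan,
              BitwiseEqual(nan, nan));
}
```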
+int64 RecursiveElementCount(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); + int64 total = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); + } + return total; + } else { + return ShapeUtil::ElementsIn(shape); + } +} + +// Returns whether the actual and expected values are mismatched with respect to +// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +template +bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { + if (relaxed_nans) { + return !std::isnan(expected) && std::isnan(actual); + } else { + return std::isnan(expected) != std::isnan(actual); + } +} + +template <> +bool NanMismatch(complex64 expected, complex64 actual, + bool relaxed_nans) { + return NanMismatch(expected.real(), actual.real(), relaxed_nans) || + NanMismatch(expected.imag(), actual.imag(), relaxed_nans); +} + +template <> +bool NanMismatch(half expected, half actual, bool relaxed_nans) { + return NanMismatch(static_cast(expected), + static_cast(actual), relaxed_nans); +} + +// Converts the given floating-point value to a string. +template +string FpValueToString(NativeT value) { + return Printf("%8.4g", static_cast(value)); +} + +template <> +string FpValueToString(complex64 value) { + return Printf("%8.4g + %8.4fi", value.real(), value.imag()); +} + +// Returns the absolute value of the given floating point value. This function +// is used instead of std::abs directly in order to allow type-dependent +// implementations for NearComparator. +template +float FpAbsoluteValue(NativeT value) { + return std::abs(value); +} + +template <> +float FpAbsoluteValue(bfloat16 value) { + return FpAbsoluteValue(static_cast(value)); +} + +template <> +float FpAbsoluteValue(half value) { + return FpAbsoluteValue(static_cast(value)); +} + +// Helper class for comparing floating-point literals within an error bound. +template +class NearComparator { + public: + // Compares the two array literals elementwise and returns a comparison + // result. The comparison is ok() if all actual and expected elements are + // within the given error bound. In case of error, the status contains a + // detailed message about the discrepancy. + static Status Compare(const LiteralSlice& expected, + const LiteralSlice& actual, ErrorSpec error, + bool detailed_message, + const MiscompareCallback& miscompare_callback) { + NearComparator comparator(expected, actual, error, + detailed_message, miscompare_callback); + return comparator.Run(); + } + + private: + // Data structure encapsulating metadata about a single element mismatch. + struct Mismatch { + NativeT actual; + NativeT expected; + float rel_error; + float abs_error; + + // The linear index of the failure within the shape. This linear index is + // from the 'actual' literal. 
+ int64 linear_index; + + bool operator<(const Mismatch& other) const { + return rel_error < other.rel_error; + } + + string ToString(const Shape& shape) const { + return Printf( + "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", + FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + Literal::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex(shape, + linear_index)) + .c_str(), + rel_error, abs_error); + } + }; + + NearComparator(const LiteralSlice& expected, const LiteralSlice& actual, + ErrorSpec error, bool detailed_message, + const MiscompareCallback& miscompare_callback) + : expected_(expected), + actual_(actual), + error_(error), + detailed_message_(detailed_message), + miscompare_callback_(miscompare_callback), + abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), + abs_error_buckets_(kErrorBucketBounds.size(), 0), + rel_error_buckets_(kErrorBucketBounds.size(), 0) {} + + // Runs the comparison between expected and actual literals. + Status Run() { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, ToStringTruncated(expected_)); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, ToStringTruncated(actual_)); + + // If the shapes mismatch, we simply fail the expectation instead of + // printing out data, as it's a type error rather than a value error. + TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); + if (!ShapeUtil::IsArray(expected_.shape())) { + return InvalidArgument("Expected array shape; got %s.", + ShapeUtil::HumanString(expected_.shape()).c_str()); + } + + mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); + mismatches_.PopulateWithValue(false); + + CompareLiterals(); + + if (num_mismatches_ == 0) { + return Status::OK(); + } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { + miscompare_callback_(expected_, actual_, mismatches_); + } + return InvalidArgument("%s", ErrorMessage().c_str()); + } + + // Insert the given absolute value into the absolute value bucket vector. The + // bounds of the buckets are given by kAbsValueBucketBounds. + void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { + // Adjust the bucket containing the absolute values of the 'actual' + // elements. + const float abs_value = FpAbsoluteValue(value); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + if (i == abs_value_buckets_.size() - 1 || + (abs_value >= kAbsValueBucketBounds[i] && + abs_value < kAbsValueBucketBounds[i + 1])) { + // The first value of the pair is the count of elements in the bucket, + // the second is the count of mismatches in the bucket. + abs_value_buckets_[i].first++; + if (is_mismatch) { + abs_value_buckets_[i].second++; + } + return; + } + } + } + + // Insert the given error into the given error bucket vector. + void UpdateErrorBucket( + float error, tensorflow::gtl::MutableArraySlice error_buckets) { + CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < error_buckets.size(); ++i) { + if (error >= kErrorBucketBounds[i]) { + error_buckets[i]++; + } + } + } + + // Compares the two given elements from the expected and actual literals at + // the given literal_index and keeps track of various mismatch statistics. 
+ void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { + const bool is_nan_mismatch = + NanMismatch(expected, actual, error_.relaxed_nans); + float abs_error; + float rel_error; + if (actual == expected) { + abs_error = 0; + rel_error = 0; + } else if (is_nan_mismatch) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is used + // for sorting a std::set of the top mismatchs, and a nan value here will + // result in undefined behavior because nan's do not satisfy the strict + // weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = FpAbsoluteValue(actual - expected); + rel_error = abs_error / FpAbsoluteValue(expected); + } + const bool is_abs_mismatch = abs_error > error_.abs; + const bool is_rel_mismatch = rel_error > error_.rel; + const bool is_mismatch = + is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + + // Update the error of the relative bucket only if the *absolute* error + // bound is exceeded and vice versa. + if (is_abs_mismatch) { + num_abs_mismatches_++; + UpdateErrorBucket(rel_error, &rel_error_buckets_); + } + if (is_rel_mismatch) { + num_rel_mismatches_++; + UpdateErrorBucket(abs_error, &abs_error_buckets_); + } + + UpdateAbsValueBucket(actual, is_mismatch); + + if (!is_mismatch) { + return; + } + + num_mismatches_++; + + // Keep track of the kTopRelativeErrorCount relative error mismatches. + if (top_rel_mismatches_.size() < kTopRelativeErrorCount || + rel_error > top_rel_mismatches_.begin()->rel_error) { + Mismatch mismatch = {actual, expected, rel_error, abs_error, + linear_index}; + top_rel_mismatches_.insert(mismatch); + if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { + top_rel_mismatches_.erase(top_rel_mismatches_.begin()); + } + } + + mismatches_.data()[linear_index] = true; + } + + // Compares the two literals elementwise. + void CompareLiterals() { + // Fast path optimization for the case were layouts match. + if (LayoutUtil::Equal(actual_.shape().layout(), + expected_.shape().layout())) { + tensorflow::gtl::ArraySlice expected_data = + expected_.data(); + tensorflow::gtl::ArraySlice actual_data = + actual_.data(); + const int64 len = expected_data.size(); + for (int64 i = 0; i < len; ++i) { + CompareValues(expected_data[i], actual_data[i], i); + } + return; + } + std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + CompareLiteralsSlow(0, &multi_index); + } + + // Slow path for CompareLiterals when 'actual' and 'expected' literals have + // different layouts. In this case, multidimensional indices are constructed + // and indexed for each element. + void CompareLiteralsSlow(int64 dimension, std::vector* multi_index) { + if (dimension == multi_index->size()) { + CompareValues(expected_.Get(*multi_index), + actual_.Get(*multi_index), + IndexUtil::MultidimensionalIndexToLinearIndex( + actual_.shape(), *multi_index)); + } else { + for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { + (*multi_index)[dimension] = i; + CompareLiteralsSlow(dimension + 1, multi_index); + } + } + } + + // Returns an error message string with a detailed breakdown of the + // mismatches. Called after calling Run(). + string ErrorMessage() { + string out; + int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); + + auto percent_string = [](float a, float b) { + float pct = b == 0.0 ? 
0.0 : 100.0 * a / b; + return Printf("%0.4f%%", pct); + }; + + Appendf(&out, + "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, + percent_string(num_mismatches_, element_count).c_str(), + ShapeUtil::HumanString(actual_.shape()).c_str(), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + if (num_nan_mismatches_ > 0) { + StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); + } + Appendf(&out, "Top relative error mismatches:\n"); + for (auto it = top_rel_mismatches_.rbegin(); + it != top_rel_mismatches_.rend(); ++it) { + StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + } + + if (!detailed_message_) { + return out; + } + + StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); + CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + const int64 bucket_size = abs_value_buckets_[i].first; + const int64 bucket_mismatches = abs_value_buckets_[i].second; + string mismatch_str = bucket_mismatches > 0 + ? Printf(", mismatches %lld", bucket_mismatches) + : ""; + Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count).c_str(), + mismatch_str.c_str()); + } + + auto print_accum_buckets = [&](const string& header, int64 total, + tensorflow::gtl::ArraySlice buckets) { + StrAppend(&out, header, ":\n"); + Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total).c_str()); + CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < kErrorBucketBounds.size(); ++i) { + Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total).c_str()); + } + }; + Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count).c_str()); + print_accum_buckets( + "Relative error breakdown of elements exceeding abs error bound", + num_abs_mismatches_, rel_error_buckets_); + Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count).c_str()); + print_accum_buckets( + "Absolute error breakdown of elements exceeding rel error bound", + num_rel_mismatches_, abs_error_buckets_); + return out; + } + + // 'actual' and 'expected' literals being compared. + LiteralSlice expected_; + LiteralSlice actual_; + + // The error bounds of the comparison. + ErrorSpec error_; + + // Whether to include detailed breakdown of mismatches in the error message. + bool detailed_message_; + + // Callback to invoke on miscompare. + MiscompareCallback miscompare_callback_; + + // Number of element element mismatches encountered so far. + int64 num_mismatches_ = 0; + + // Number of elements with a nan mismatch. + int64 num_nan_mismatches_ = 0; + + // Number of elements which exceed the absolute/relative error bound. + int64 num_abs_mismatches_ = 0; + int64 num_rel_mismatches_ = 0; + + // A Literal containing which elements did not match in the expected and + // actual literals. mismatches_ contains PREDs and is of the same sizes as + // the comparison literals. + Literal mismatches_; + + // The number of mismatches to report in the output, sorted by relative error + // magnitude. 
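The `top_rel_mismatches_` bookkeeping referred to above is a bounded `std::multiset`: a candidate is inserted when the set is not yet full or when it beats the current minimum, and the smallest element is then evicted. A standalone sketch of the same pattern (values stand in for relative errors):

```c++
#include <cstdio>
#include <set>

// Keeps only the `limit` largest values seen, like the comparator's list of
// worst relative-error mismatches.
class TopK {
 public:
  explicit TopK(std::size_t limit) : limit_(limit) {}

  void Insert(double value) {
    if (values_.size() < limit_ || value > *values_.begin()) {
      values_.insert(value);
      if (values_.size() > limit_) {
        values_.erase(values_.begin());  // evict the current smallest
      }
    }
  }

  const std::multiset<double>& values() const { return values_; }

 private:
  std::size_t limit_;
  std::multiset<double> values_;
};

int main() {
  TopK top(3);
  for (double v : {0.1, 7.0, 0.4, 2.5, 9.0, 0.2}) top.Insert(v);
  for (double v : top.values()) std::printf("%g ", v);  // prints: 2.5 7 9
  std::printf("\n");
}
```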
+ static constexpr int64 kTopRelativeErrorCount = 5; + + // The set of mismatches with the largest relative error. The size of this set + // is bounded by kTopRelativeErrorCount. + std::multiset top_rel_mismatches_; + + // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the + // bounds of these buckets. abs_value_buckets_ contains a pair for each + // bucket: the element count and failure count. + static constexpr std::array kAbsValueBucketBounds = { + 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; + std::vector> abs_value_buckets_; + + // Buckets for relative and absolute errors. The relative error buckets only + // contains those elements which exceed the *absolute* error bound, and vice + // versa. This makes it easy to see the effect of adjusting the relative (or + // absolute) error bound on the success of the comparison. kErrorBucketBounds + // are the lower bounds of the buckets in both vectors. The error buckets are + // a cumulative distribution so an error value may appear in more than one + // bucket. For example an error value of 0.003 may appear in the buckets + // bounded by 0.01, 0.1, and 1.0. + static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, + 0.01, 0.1, 1}; + std::vector abs_error_buckets_; + std::vector rel_error_buckets_; +}; + +template +constexpr std::array NearComparator::kAbsValueBucketBounds; +template +constexpr std::array NearComparator::kErrorBucketBounds; + +// Helper function for comparing two literals for nearness. Handles tuple-shapes +// via recursion. shape_index is the ShapeIndex of expected (or actual) +// currently being compared. +Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback, + const ShapeIndex& shape_index) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + + if (ShapeUtil::IsTuple(expected.shape())) { + Status return_status; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + const auto expected_element = LiteralSlice(expected, {i}); + const auto actual_element = LiteralSlice(actual, {i}); + ShapeIndex element_index = shape_index; + element_index.push_back(i); + Status res = + NearHelper(expected_element, actual_element, error, detailed_message, + miscompare_callback, element_index); + if (!res.ok()) { + string err_message = Printf("\nArray at shape index %s%s", + element_index.ToString().c_str(), + res.error_message().c_str()); + if (return_status.ok()) { + return_status = res; + } else { + return_status = AppendStatus(return_status, res.error_message()); + } + } + } + if (!return_status.ok() && shape_index.empty()) { + // Emit a top-level error message containing the top-level shape in case + // of mismatch. 
+ int64 total_elements = RecursiveElementCount(actual.shape()); + return_status = InvalidArgument( + "\nMismatches in shape %s (%lld elements):\n%s", + ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, + return_status.error_message().c_str()); + } + return return_status; + } + + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { + switch (expected.shape().element_type()) { + case BF16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F32: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case C64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + default: + LOG(FATAL) << "Unsupported primitive type in near comparator: " + << PrimitiveType_Name(expected.shape().element_type()) + << ". Must be floating-point type."; + } + } + + // Non-floating point literal. + return literal_comparison::Equal(expected, actual); +} + +} // namespace + +Status EqualShapes(const Shape& expected, const Shape& actual) { + if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { + return InvalidArgument("tupleness-mismatch! want: %s got %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (ShapeUtil::IsTuple(expected)) { + if (ShapeUtil::TupleElementCount(expected) != + ShapeUtil::TupleElementCount(actual)) { + return InvalidArgument( + "want tuple element count: %lld got tuple element count: %lld", + ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); + } + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + Status result = + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + if (!result.ok()) { + return AppendStatus(result, StrCat("mismatch in tuple index", i)); + } + } + } else { + if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + return InvalidArgument("want rank of %s got rank of %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (expected.element_type() != actual.element_type()) { + return InvalidArgument( + "mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()).c_str(), + PrimitiveType_Name(actual.element_type()).c_str()); + } + if (expected.dimensions_size() != actual.dimensions_size()) { + return InvalidArgument("want dimensions_size %d got dimensions_size %d", + expected.dimensions_size(), + actual.dimensions_size()); + } + for (int i = 0; i < expected.dimensions_size(); ++i) { + if (expected.dimensions(i) != actual.dimensions(i)) { + return InvalidArgument( + "mismatch in dimension #%d expected: %s actual: %s", i, + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + } + } + return Status::OK(); +} + +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, actual.ToString()); + + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector multi_index(expected.shape().dimensions_size(), 0); + Status result; + switch 
(expected.shape().element_type()) { + case PRED: + result = Equal(expected, actual, &multi_index, 0); + break; + case U8: + result = Equal(expected, actual, &multi_index, 0); + break; + case S32: + result = Equal(expected, actual, &multi_index, 0); + break; + case S64: + result = Equal(expected, actual, &multi_index, 0); + break; + case U32: + result = Equal(expected, actual, &multi_index, 0); + break; + case U64: + result = Equal(expected, actual, &multi_index, 0); + break; + case BF16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F32: + result = Equal(expected, actual, &multi_index, 0); + break; + case F64: + result = Equal(expected, actual, &multi_index, 0); + break; + case C64: + result = Equal(expected, actual, &multi_index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update( + Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); + } + break; + } + default: + LOG(FATAL) + << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + if (result.ok()) { + return Status::OK(); + } + + return AppendStatus(result, + tensorflow::strings::Printf( + "\nat index: %s\nexpected: %s\nactual: %s", + Literal::MultiIndexAsString(multi_index).c_str(), + ToStringTruncated(expected).c_str(), + ToStringTruncated(actual).c_str())); +} + +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback) { + return NearHelper(expected, actual, error, detailed_message, + miscompare_callback, + /*shape_index=*/{}); +} + +string ToStringTruncated(const LiteralSlice& literal) { + return RecursiveElementCount(literal.shape()) < 1000 + ? literal.ToString() + : "[TRUNCATED, Literal with more than 1000 values]"; +} + +} // namespace literal_comparison +} // namespace xla diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h new file mode 100644 index 00000000000000..00a13e361932e7 --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Library for comparing literals without taking a dependency on testing +// libraries. + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ + +#include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace literal_comparison { + +// Returns ok if the given shapes have the same rank, dimension sizes, and +// primitive types. 
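A hedged usage sketch of the new comparison entry points (the `CheckClose`/`CheckEqual` wrappers are hypothetical; header paths are the ones introduced by this change, and passing a null `MiscompareCallback` simply disables the miscompare hook):

```c++
#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/literal_comparison.h"

namespace xla {

// Returns OK if `actual` matches `expected` to within 1e-4 absolute or
// 1e-4 relative error; on failure the status carries the mismatch breakdown.
Status CheckClose(const LiteralSlice& expected, const LiteralSlice& actual) {
  return literal_comparison::Near(expected, actual,
                                  ErrorSpec(/*aabs=*/1e-4, /*arel=*/1e-4),
                                  /*detailed_message=*/true,
                                  /*miscompare_callback=*/nullptr);
}

// Exact comparison (bitwise for floating-point element types).
Status CheckEqual(const LiteralSlice& expected, const LiteralSlice& actual) {
  return literal_comparison::Equal(expected, actual);
}

}  // namespace xla
```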
+Status EqualShapes(const Shape& expected, const Shape& actual); + +// Returns ok if the expected and actual literals are (bitwise) equal for all +// elements in the literal. Also, asserts that the rank, dimensions sizes, and +// primitive type are equal. +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual); + +using MiscompareCallback = + std::function; + +// Inspects whether the expected and actual literals are within the given error +// bound for all elements. Also, inspects whether the rank, dimensions sizes, +// and dimension bounds are equivalent. +// +// Tuples are matched recursively. +// +// When comparing tensors of non-floating-point type, this inspects for exact +// equality, ignoring the ErrorSpec. +// +// If the shape of the literals is neither a complex/floating-point tensor nor a +// tuple which contains a complex/floating-point tensor, Near() is equivalent to +// Equal(). We don't raise an error in this case, because we want to allow +// callers to call Near() even if they have no preconceptions about the shapes +// being compared. +// +// If detailed_message is true, then the error message in the assertion result +// will contain a more detailed breakdown of mismatches. +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback); + +// Calling ToString on a literal with over 100 million elements takes around +// 3 minutes. The utility of printing a literal with >1000 elements is +// questionable, especially when writing the Literal proto to disk is orders +// of magnitude faster. +string ToStringTruncated(const LiteralSlice& literal); + +} // namespace literal_comparison +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index bb6dd4f9098aef..61afc311a70293 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -61,8 +62,49 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Return a literal with all arrays of type FromNativeT converted to type +// ToNativeT in the given literal. +template +std::unique_ptr ConvertType(LiteralSlice literal) { + // First construct shape of the result. + Shape result_shape(literal.shape()); + ShapeUtil::ForEachMutableSubshape( + &result_shape, [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == + primitive_util::NativeToPrimitiveType()) { + subshape->set_element_type( + primitive_util::NativeToPrimitiveType()); + } + }); + auto result = MakeUnique(result_shape); + + // Then copy over the data from 'literal' converting FromNativeT values to + // ToNativeT values as necessary. 
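The new literal_comparison.h above exposes test-framework-independent entry points for exact and approximate comparison. Below is a minimal usage sketch; the helper name, the CreateR2 factory from literal_util.h, and the ErrorSpec constructor arguments are assumptions for illustration, not part of this change.

```c++
#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

// Hypothetical helper: exact comparison first, tolerance-based fallback.
Status CheckLiteralsClose() {
  auto expected = Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto actual = Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.00001f}});

  // Bitwise comparison: fails here because one element differs slightly.
  if (literal_comparison::Equal(*expected, *actual).ok()) {
    return Status::OK();
  }

  // Tolerance-based comparison: succeeds if every element is within the bound.
  ErrorSpec error(/*aabs=*/1e-4, /*arel=*/1e-4);  // assumed (abs, rel) ctor
  return literal_comparison::Near(*expected, *actual, error,
                                  /*detailed_message=*/false,
                                  /*miscompare_callback=*/nullptr);
}

}  // namespace xla
```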
+ ShapeUtil::ForEachSubshape( + literal.shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + if (subshape.element_type() == + primitive_util::NativeToPrimitiveType()) { + auto src = literal.data(shape_index); + auto dest = result->data(shape_index); + for (int64 i = 0; i < src.size(); ++i) { + dest[i] = static_cast(src[i]); + } + } else { + TF_CHECK_OK(result->CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + } + } + }); + return result; +} + } // namespace +LiteralBase::~LiteralBase() {} + std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); return out; @@ -94,99 +136,89 @@ Literal::StrideConfig::StrideConfig( Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} -Literal::Literal(const Shape& shape, bool allocate_arrays) - : shape_(shape), pieces_(shape), owns_buffers_(true) { - CHECK(LayoutUtil::HasLayout(shape)); - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - const Shape& subshape = piece.subshape(); - if (ShapeUtil::IsArray(subshape)) { - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(subshape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(subshape.layout()); - piece.set_buffer( - new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType( - subshape.element_type())]); - piece.set_sparse_indices(new SparseIndexArray( - max_sparse_elements, ShapeUtil::Rank(subshape))); - } else { - piece.set_buffer(new char[piece.size_bytes()]); - } +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); + + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else { + CHECK(ShapeUtil::IsArray(shape)); + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. 
+ const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); } else { - piece.set_buffer(nullptr); + piece->set_buffer(new char[piece->size_bytes()]); } } } } -Literal::~Literal() { DeallocateBuffers(); } +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); -void Literal::DeallocateBuffers() { - if (owns_buffers_) { - for (auto& pair : pieces_) { - Piece& piece = pair.second; - if (piece.buffer() != nullptr) { - delete[] piece.buffer(); - delete piece.sparse_indices(); - } - } - } + SetPiece(*shape_, root_piece_, allocate_arrays); } -Literal::Literal(Literal&& other) { - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; } - owns_buffers_ = other.owns_buffers_; +} - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); +void Literal::DeallocateBuffers() { + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); } +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } + Literal& Literal::operator=(Literal&& other) { - DeallocateBuffers(); - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. 
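SetPiece above builds a tree of Piece nodes that mirrors the (possibly nested tuple) shape, allocating a buffer only at array leaves. The following is a toy, std-only analog of that recursion, using stand-in types rather than the XLA ones:

```c++
#include <cstdint>
#include <memory>
#include <vector>

// Stand-in "shape": either a leaf with a byte size or a tuple of children.
struct ToyShape {
  int64_t leaf_bytes = 0;          // used when children is empty
  std::vector<ToyShape> children;  // non-empty => tuple
};

// Stand-in "piece": a buffer at leaves, child pieces for tuples.
struct ToyPiece {
  std::unique_ptr<char[]> buffer;
  std::vector<ToyPiece> children;
};

ToyPiece BuildPieceTree(const ToyShape& shape, bool allocate_arrays) {
  ToyPiece piece;
  if (!shape.children.empty()) {
    // Tuple: recurse into each element, mirroring the shape tree.
    for (const ToyShape& child : shape.children) {
      piece.children.push_back(BuildPieceTree(child, allocate_arrays));
    }
  } else if (allocate_arrays) {
    // Array leaf: allocate (zeroed) storage for its data.
    piece.buffer.reset(new char[shape.leaf_bytes]());
  }
  return piece;
}
```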
- for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = other.owns_buffers_; - - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); + return *this; } -std::unique_ptr Literal::CreateFromShape(const Shape& shape) { +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(shape); - for (auto& pair : literal->pieces_) { - Piece& piece = pair.second; - if (ShapeUtil::IsArray(piece.subshape())) { - memset(piece.untyped_data(), 0, piece.size_bytes()); - } - } + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); return literal; } -const SparseIndexArray* Literal::sparse_indices( +const SparseIndexArray* LiteralBase::sparse_indices( const ShapeIndex& shape_index) const { return piece(shape_index).sparse_indices(); } @@ -201,9 +233,19 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } +/* static */ std::unique_ptr Literal::ConvertBF16ToF32( + const LiteralSlice& bf16_literal) { + return ConvertType(bf16_literal); +} + +/* static */ std::unique_ptr Literal::ConvertF32ToBF16( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + template Status Literal::CopySliceFromInternal( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); @@ -263,7 +305,7 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const Literal& src_literal, +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_index, tensorflow::gtl::ArraySlice dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); @@ -292,22 +334,21 @@ std::vector Literal::DecomposeTuple() { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), /*allocate_arrays=*/false)); Literal& element = elements.back(); - for (auto& pair : element.pieces_) { - const ShapeIndex& index = pair.first; - Piece& dest_piece = pair.second; - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); - - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece.set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - } + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. 
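The hunk above also adds the ConvertBF16ToF32 / ConvertF32ToBF16 wrappers around ConvertType. A small round-trip sketch follows; the function name is hypothetical and the values are chosen to be exactly representable in bfloat16:

```c++
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void BF16RoundTrip() {
  // 1.0 and 2.5 are exact in bfloat16, so converting F32 -> BF16 -> F32
  // reproduces the original values bit-for-bit.
  auto as_f32 = Literal::CreateR1<float>({1.0f, 2.5f});
  auto as_bf16 = Literal::ConvertF32ToBF16(*as_f32);
  auto back = Literal::ConvertBF16ToF32(*as_bf16);
  CHECK(*back == *as_f32);
}

}  // namespace xla
```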
+ dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); } // Set this literal to be nil-shaped. *this = Literal(); @@ -350,7 +391,9 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFrom(const Literal::Piece& src) { +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); if (ShapeUtil::Equal(subshape(), src.subshape())) { // If the layouts are equal it's faster just to memcpy. memcpy(buffer(), src.buffer(), src.size_bytes()); @@ -387,7 +430,7 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { return Status::OK(); } -Status Literal::CopyFrom(const Literal& src_literal, +Status Literal::CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index, const ShapeIndex& src_shape_index) { const Shape& dest_subshape = @@ -400,36 +443,32 @@ Status Literal::CopyFrom(const Literal& src_literal, ShapeUtil::HumanString(dest_subshape).c_str(), ShapeUtil::HumanString(src_subshape).c_str()); } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Determine if this index is in the part of this literal that we want to - // copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - continue; - } - - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } - - TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); - } - return Status::OK(); + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. 
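CopyFrom above only touches pieces whose ShapeIndex starts with dest_shape_index, and remaps the remainder of the index onto src_shape_index. Here is a stand-alone sketch of just that prefix check and remapping, using a plain vector instead of xla::ShapeIndex:

```c++
#include <cstdint>
#include <vector>

using ToyShapeIndex = std::vector<int64_t>;  // stand-in for xla::ShapeIndex

// Returns true and fills *src_index when `index` lies under `dest_prefix`,
// mirroring the subtree selection done inside Literal::CopyFrom.
bool MapDestIndexToSrc(const ToyShapeIndex& index,
                       const ToyShapeIndex& dest_prefix,
                       const ToyShapeIndex& src_prefix,
                       ToyShapeIndex* src_index) {
  if (index.size() < dest_prefix.size()) return false;
  for (size_t i = 0; i < dest_prefix.size(); ++i) {
    if (index[i] != dest_prefix[i]) return false;  // outside the copied subtree
  }
  *src_index = src_prefix;
  for (size_t i = dest_prefix.size(); i < index.size(); ++i) {
    src_index->push_back(index[i]);  // append the suffix below the prefix
  }
  return true;
}
```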
+ ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); } Status Literal::MoveFrom(Literal&& src_literal, @@ -443,37 +482,32 @@ Status Literal::MoveFrom(Literal&& src_literal, ShapeUtil::HumanString(src_literal.shape()).c_str()); } - if (!(owns_buffers_ && src_literal.owns_buffers_)) { - return InvalidArgument( - "Source and destination literals must both own their buffers (ie, not " - "be views)"); - } + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } - for (auto& pair : src_literal.pieces_) { - const ShapeIndex& src_index = pair.first; - Piece& src_piece = pair.second; - if (!ShapeUtil::IsArray(src_piece.subshape())) { - continue; - } + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - } + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - src_literal.shape_ = ShapeUtil::MakeNil(); - src_literal.pieces_ = ShapeTree(src_literal.shape_); - src_literal.piece({}).set_subshape(&src_literal.shape_); return Status::OK(); } -Status Literal::CopySliceFrom(const Literal& src_literal, +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { @@ -742,7 +776,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return CreateR2FromArray2D(*value); } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. 
Shape new_shape = shape(); @@ -754,7 +788,7 @@ std::unique_ptr Literal::Relayout( return result; } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) @@ -773,7 +807,48 @@ std::unique_ptr Literal::Relayout( return result; } -StatusOr> Literal::Reshape( +StatusOr> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr result = MakeUnique(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. + std::vector scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast(result->untyped_data()); + const char* source_data = static_cast(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + +StatusOr> LiteralBase::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); @@ -787,7 +862,8 @@ StatusOr> Literal::Reshape( } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -801,7 +877,79 @@ StatusOr> Literal::Reshape( return std::move(output); } -std::unique_ptr Literal::Transpose( +/* static */ std::unique_ptr Literal::ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal) { + int64 new_num_elements = 1; + for (int64 i = 0; i < new_dimensions.size(); ++i) { + new_num_elements *= new_dimensions[i]; + } + CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + CHECK_EQ(new_dimensions.size(), minor_to_major.size()); + + auto new_literal = MakeUnique( + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); + + // Create a new shape with the given minor-to-major layout. This shape is used + // solely for converting linear address to multi-dimensional addresses when + // writing elements to the new literal. 
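A minimal usage sketch for the new Broadcast method added above; the wrapper function name is hypothetical. The `dimensions` argument states which result dimension each input dimension maps onto:

```c++
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

StatusOr<std::unique_ptr<Literal>> BroadcastRowExample() {
  // f32[2] -> f32[2,3]; input dim 0 maps onto result dim 0, so element (i, j)
  // of the result equals element i of the row.
  auto row = Literal::CreateR1<float>({1.0f, 2.0f});
  return row->Broadcast(ShapeUtil::MakeShape(F32, {2, 3}), /*dimensions=*/{0});
}

}  // namespace xla
```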
+ Shape shape_with_layout = new_literal->shape(); + *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + + // Copy data into new literal, element-by-element. + for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { + std::vector from_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); + std::vector to_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); + switch (literal.shape().element_type()) { + case PRED: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U8: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case C64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + default: + LOG(FATAL) << "Unhandled primitive element type: " + << PrimitiveType_Name(literal.shape().element_type()); + } + } + + return new_literal; +} + +std::unique_ptr LiteralBase::Transpose( tensorflow::gtl::ArraySlice permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) @@ -832,15 +980,31 @@ std::unique_ptr Literal::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - std::unique_ptr new_literal = CreateFromShape(permuted_shape); - DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(), - root_piece().size_bytes()); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); return new_literal; } -std::unique_ptr Literal::Slice( +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + +std::unique_ptr LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; @@ -857,71 +1021,37 @@ std::unique_ptr Literal::Slice( const auto result_shape = ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); - - auto result_literal = MakeUnique(result_shape); - - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice 
indices, float /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - float value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); case C64: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, complex64 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - complex64 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case S32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - int32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case U32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - uint32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); } } -Literal Literal::Clone() const { +Literal LiteralBase::Clone() const { Literal result(shape()); TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr Literal::CloneToUnique() const { +std::unique_ptr LiteralBase::CloneToUnique() const { auto result = MakeUnique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } -string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); switch (subshape.element_type()) { @@ -961,8 +1091,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, } } -string Literal::GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsSparseArray(subshape)); switch (subshape.element_type()) { @@ -1016,7 +1146,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, } } -StatusOr Literal::GetIntegralAsS64( +StatusOr LiteralBase::GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -1039,6 +1169,27 @@ StatusOr Literal::GetIntegralAsS64( } } +size_t LiteralBase::Hash() const { + using tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsTuple(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, 
Hash64(static_cast(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); @@ -1069,7 +1220,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, return Status::OK(); } -tensorflow::gtl::ArraySlice Literal::GetSparseIndex( +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1081,10 +1232,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } -Literal Literal::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape_)); - CHECK_GT(ShapeUtil::ElementsIn(shape_), 0); - switch (shape_.element_type()) { +Literal LiteralBase::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); + switch (shape().element_type()) { case PRED: return std::move(*Literal::CreateR0(GetFirstElement())); // 8 bit types. @@ -1120,11 +1271,11 @@ Literal Literal::GetFirstScalarLiteral() const { case U64: return std::move(*Literal::CreateR0(GetFirstElement())); default: - LOG(FATAL) << "Unhandled primitive type " << shape_.element_type(); + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); } } -void Literal::Piece::SortSparseElements() { +void LiteralBase::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: SortSparseElementsInternal(); @@ -1175,7 +1326,7 @@ void Literal::Piece::SortSparseElements() { } template -void Literal::Piece::SortSparseElementsInternal() { +void LiteralBase::Piece::SortSparseElementsInternal() { CHECK(LayoutUtil::IsSparseArray(subshape())); int64 num_elements = sparse_indices()->index_count(); auto values = data(); @@ -1186,9 +1337,11 @@ void Literal::Piece::SortSparseElementsInternal() { namespace { -void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -1347,13 +1500,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } // namespace -int64 Literal::sparse_element_count() const { +int64 LiteralBase::sparse_element_count() const { CHECK(LayoutUtil::IsSparseArray(shape())); return sparse_indices()->index_count(); } -string Literal::ToString(bool print_layout) const { +string LiteralBase::ToString(bool print_layout) const { std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } @@ -1361,7 +1515,7 @@ string Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; - for (const Literal* element : elements) { + for (const auto* element : elements) { element_shapes.push_back(element->shape()); } auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); @@ -1371,6 +1525,19 @@ string Literal::ToString(bool print_layout) const { return literal; } +/* static */ 
std::unique_ptr Literal::MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (const auto& element : elements) { + element_shapes.push_back(element.shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + } + return literal; +} + /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; @@ -1386,7 +1553,7 @@ string Literal::ToString(bool print_layout) const { return literal; } -void Literal::EachCellAsString( +void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::HasZeroElements(shape())) { @@ -1402,7 +1569,7 @@ void Literal::EachCellAsString( namespace { template std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const Literal& src_literal, const ConverterType& converter) { + const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1418,7 +1585,8 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1427,7 +1595,7 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); }; @@ -1442,12 +1610,12 @@ BitcastBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const Literal& src_literal) { +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); @@ -1465,7 +1633,7 @@ std::unique_ptr ConvertToC64(const Literal& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { @@ -1485,7 +1653,7 @@ std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, template StatusOr> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type, + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ @@ -1520,7 +1688,8 @@ StatusOr> ConvertIfDestTypeMatches( } StatusOr> ConvertSwitch( - const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) { + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { 
TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { return literal.CloneToUnique(); @@ -1554,12 +1723,12 @@ StatusOr> ConvertSwitch( } // namespace -StatusOr> Literal::Convert( +StatusOr> LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> Literal::BitcastConvert( +StatusOr> LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { @@ -1574,7 +1743,7 @@ StatusOr> Literal::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> Literal::ConvertToShape( +StatusOr> LiteralBase::ConvertToShape( const Shape& dest_shape, bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { if (round_f32_to_bf16 && shape().element_type() == F32 && @@ -1589,7 +1758,7 @@ StatusOr> Literal::ConvertToShape( } std::vector elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralView::Create(*this, {i}); + auto element = LiteralSlice(*this, {i}); TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); @@ -1601,8 +1770,8 @@ StatusOr> Literal::ConvertToShape( } template -bool Literal::Piece::EqualElementsInternal( - const Literal::Piece& other, std::vector* multi_index) const { +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { if (multi_index->size() == ShapeUtil::Rank(subshape())) { return (Get(*multi_index) == other.Get(*multi_index)); } @@ -1616,7 +1785,7 @@ bool Literal::Piece::EqualElementsInternal( return true; } -bool Literal::Piece::EqualElements(const Literal::Piece& other) const { +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); std::vector multi_index; @@ -1644,28 +1813,28 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const { case C64: return EqualElementsInternal(other, &multi_index); default: - LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); } } -bool Literal::operator==(const Literal& other) const { +bool LiteralBase::operator==(const LiteralBase& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - } - return true; + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); } namespace { @@ -1683,11 +1852,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, } // namespace -bool Literal::IsAll(int8 value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const 
ShapeIndex& index, + const Piece& piece) { if (!ShapeUtil::IsArray(piece.subshape())) { - continue; + return true; } auto piece_is_all = [&]() { @@ -1740,41 +1909,41 @@ bool Literal::IsAll(int8 value) const { if (!piece_is_all()) { return false; } - } - return true; + return true; + }); } -bool Literal::IsAllFloat(float value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - default: + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; - } - }; - if (!piece_is_all()) { - return false; - } - } - return true; + } + return true; + }); } -bool Literal::IsAllComplex(complex64 value) const { +bool LiteralBase::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: return AllElementsEqualValue(root_piece().data(), @@ -1784,93 +1953,93 @@ bool Literal::IsAllComplex(complex64 value) const { } } -bool Literal::IsAllFirst() const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Empty shapes are not all the first element since there is no first - // element. 
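A short sketch of what IsAll and IsAllFloat report on small literals; the helper name is hypothetical and the behavior follows the doc comments for these predicates:

```c++
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void IsAllExamples() {
  auto zeros = Literal::CreateR2<int32>({{0, 0}, {0, 0}});
  CHECK(zeros->IsAll(0));  // every element equals the integer value

  auto halves = Literal::CreateR1<float>({0.5f, 0.5f});
  CHECK(halves->IsAllFloat(0.5f));  // value is cast to float, then compared
  CHECK(!halves->IsAll(0));         // 0.5 does not equal 0
}

}  // namespace xla
```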
- if (ShapeUtil::HasZeroElements(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 8 bit types - case S8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; } - case U8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - default: + + // Empty shapes are not all the first element since there is no first + // element. 
+ if (ShapeUtil::HasZeroElements(piece.subshape())) { return false; - } - }; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; - if (!piece_is_all()) { - return false; - } - } - return true; + if (!piece_is_all()) { + return false; + } + return true; + }); } -bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1912,7 +2081,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace -void Literal::Piece::WriteToProto(LiteralProto* proto) const { +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_shape() = subshape(); switch (subshape().element_type()) { case PRED: @@ -1968,12 +2137,12 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { } } -const void* Literal::Piece::untyped_data() const { +const void* LiteralBase::Piece::untyped_data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } -void* Literal::Piece::untyped_data() { +void* LiteralBase::Piece::untyped_data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } @@ -1994,7 +2163,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These conditions should have been checked in Literal::CreateFromProto. 
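The serialization paths shown here (Piece::WriteToProto / CopyFromProto, plus the ToProto and CreateFromProto entry points in the following hunks) support a simple round trip. A hedged sketch, with a hypothetical helper name:

```c++
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Serialize a literal (tuples become nested tuple_literals in the proto) and
// rebuild an equivalent Literal; CreateFromProto requires shapes with layouts.
StatusOr<std::unique_ptr<Literal>> RoundTripThroughProto(
    const Literal& literal) {
  LiteralProto proto = literal.ToProto();
  return Literal::CreateFromProto(proto);
}

}  // namespace xla
```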
TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); @@ -2061,21 +2230,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { return Status::OK(); } -LiteralProto Literal::ToProto() const { +LiteralProto LiteralBase::ToProto() const { LiteralProto proto; - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - } + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); if (LayoutUtil::IsSparseArray(shape())) { CopyToRepeatedField(proto.mutable_sparse_indices(), @@ -2097,33 +2264,40 @@ StatusOr> Literal::CreateFromProto( auto literal = MakeUnique(proto.shape()); - for (auto& pair : literal->pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - TF_RET_CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } - if (ShapeUtil::IsTuple(piece.subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece.subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece.subshape()), - proto_element->tuple_literals_size()); - } - continue; - } + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); - TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); - TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); - } return std::move(literal); } -const void* Literal::untyped_data(const ShapeIndex& shape_index) const { +/* static */ string Literal::MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index) { + return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } @@ -2131,11 +2305,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } -int64 Literal::size_bytes(const ShapeIndex& shape_index) const { +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { return piece(shape_index).size_bytes(); } -string Literal::GetR1U8AsString() const { +string LiteralBase::GetR1U8AsString() 
const { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(shape().element_type(), U8); @@ -2143,51 +2317,55 @@ string Literal::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } -/* static */ const LiteralView LiteralView::Create( - const Literal& literal, const ShapeIndex& view_root) { - return LiteralView(literal, view_root); -} +void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { + CHECK(ShapeUtil::IsTuple(shape)); + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); -LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { - shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); - pieces_ = ShapeTree(shape_); - owns_buffers_ = false; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); - ShapeIndex src_index = view_root; - for (int64 i : index) { - src_index.push_back(i); + if (ShapeUtil::IsTuple(subshape)) { + BuildPieceSubtree(subshape, &child_piece); } - const Piece& src_piece = literal.piece(src_index); - piece.set_buffer(src_piece.buffer()); - piece.set_sparse_indices(src_piece.sparse_indices()); - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + + piece->emplace_back(std::move(child_piece)); } } -LiteralView::~LiteralView() {} +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} -LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} -LiteralView& LiteralView::operator=(const LiteralView& other) { - CopyFrom(other); - return *this; +BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) + : LiteralBase(), shape_(shape) { + CHECK(ShapeUtil::IsArray(shape_)); + CHECK_NE(src_buf_ptr, nullptr); + CHECK(LayoutUtil::HasLayout(shape_)); + + root_piece_ = Piece(); + root_piece_.set_buffer(const_cast(src_buf_ptr)); + root_piece_.set_subshape(&shape_); } -void LiteralView::CopyFrom(const LiteralView& other) { - // We can't use the default copy-constructor/copy-assignment because - // Piece::subshape_ points to subshapes within the Shape of the owning - // Literal/LiteralView. 
- shape_ = other.shape(); - pieces_ = other.pieces_; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +BorrowingLiteral::BorrowingLiteral( + tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) + : LiteralBase(), shape_(shape) { + CHECK(ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsNestedTuple(shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_)); + root_piece_ = Piece(); + root_piece_.set_subshape(&shape_); + BuildPieceSubtree(shape_, &root_piece_); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + const auto& src_shape = shape_.tuple_shapes(i); + CHECK(ShapeUtil::IsArray(src_shape)); + root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } - owns_buffers_ = false; } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 8aa19222dc4b91..1e26eb7ad4098b 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -52,14 +51,509 @@ limitations under the License. namespace xla { +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { + public: + virtual ~LiteralBase() = 0; + + // Literals are equal if they have compatible shapes and the same data + // values. Layout is not compared. + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } + + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. + LiteralProto ToProto() const; + + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + int64 size_bytes(const ShapeIndex& shape_index = {}) const; + + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. 
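BorrowingLiteral, introduced above, wraps caller-owned memory without copying it. A usage sketch, assuming that ShapeUtil::MakeShape yields a shape with a default layout and that the buffer outlives the view:

```c++
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

float SumWithoutCopying(const std::vector<float>& host_buffer) {
  // Wrap the existing host buffer as a non-owning literal view; no data is
  // copied, so host_buffer must stay alive while `view` is in use.
  Shape shape = ShapeUtil::MakeShape(
      F32, {static_cast<int64>(host_buffer.size())});
  BorrowingLiteral view(reinterpret_cast<const char*>(host_buffer.data()),
                        shape);

  float sum = 0.0f;
  for (float v : view.data<float>()) {
    sum += v;
  }
  return sum;
}

}  // namespace xla
```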
+ template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; + + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. + StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; + + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; + + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. 
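A small sketch of the element-access API documented above (Get, GetFirstElement, EachCell); the function name and values are illustrative only:

```c++
#include "tensorflow/compiler/xla/literal_util.h"

namespace xla {

float ReadElements() {
  auto m = Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});

  // Element access takes a multi-index checked against the dimension sizes.
  float corner = m->Get<float>({1, 0});       // 3.0f
  float first = m->GetFirstElement<float>();  // element at (0, 0): 1.0f

  // Visit every cell together with its multi-index and value.
  float sum = 0.0f;
  m->EachCell<float>(
      [&sum](tensorflow::gtl::ArraySlice<int64> indices, float value) {
        sum += value;
      });
  return corner + first + sum;  // 3 + 1 + 10
}

}  // namespace xla
```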
+ bool IsAllFloat(float value) const; + + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; + + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Returns the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Returns a literal scalar representing the first element. + Literal GetFirstScalarLiteral() const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} + // minor-to-major dimension layout and the value in the cell at any given + // logical index (i0, i1) will be the same. + // + // For tuple shaped literals, shape_index should be used to select the inner + // array that the new layout applies to. 
+ // + // Note: this is useful when the client wants to ensure that a value placed in + // the XLA allocation tracker has a particular layout; for efficiency + // purposes or avoiding unimplemented operation/layout combinations. + std::unique_ptr Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; + + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The + // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. + StatusOr> Reshape( + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by broadcasting this literal with `dimensions` to + // yield a literal of shape `result_shape`. + StatusOr> Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const; + + // Creates a new literal by reordering the dimensions of this literal. + // The given `permutation` must be a permutation of the dimension numbers + // in the original literal, and it specifies the order of the new dimensions + // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). + // For example, a transpose call on a literal of shape [3 x 8 x 4] and + // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. + std::unique_ptr Transpose( + tensorflow::gtl::ArraySlice permutation) const; + + // Creates a sub-array from this literal by extracting the indices + // [start_index, limit_index) of each dimension. The result literal has the + // same rank and layout as for the given literal. The number of indices in + // start_indices and limit_indices must be the rank of the literal, and the + // indices follow the order of the dimensions. + // This literal must be an array. + std::unique_ptr Slice( + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices) const; + + // Creates a literal with a prepended dimension with bound "times"; e.g. a + // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this + // literal replicated four times. + // This literal must be an array. + template + std::unique_ptr Replicate(int64 times) const; + + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + // + // Note: It's an antipattern to use this method then immediately call + // Literal::Populate on the result (since that results in zero initialization, + // then reinitialization. Conside if a call to MakeUnique(shape), + // followed by the call to Literal::Populate can be used instead. + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. 
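A brief sketch of the shape-manipulating helpers documented above (Reshape, Transpose, Slice). This is illustrative only and not part of the patch; the function name and values are assumptions.

```c++
#include "tensorflow/compiler/xla/literal_util.h"

// Produces new literals from an existing 2x3 F32 literal.
void ShapeManipulationExample() {
  std::unique_ptr<xla::Literal> m =
      xla::Literal::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});

  // Reshape to [3, 2]; the element count must stay the same.
  std::unique_ptr<xla::Literal> reshaped =
      m->Reshape({3, 2}).ConsumeValueOrDie();

  // Reorder dimensions with permutation {1, 0}: the result shape is [3, 2].
  std::unique_ptr<xla::Literal> transposed = m->Transpose({1, 0});

  // Take rows [0, 1) and columns [1, 3): a 1x2 literal {{2, 3}}.
  std::unique_ptr<xla::Literal> sliced = m->Slice({0, 1}, {1, 3});
}
```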
+ template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the size of children pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor functions that recursively traverses the piece and calls the + // given function at each child piece. The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. 
+ template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. 
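Because `Piece` and `root_piece()` are protected, the visitor methods above are reachable only from `LiteralBase`, its subclasses, or the declared friends. A hypothetical member function (not part of the patch) that sums the bytes of every array-shaped subpiece might look like the following sketch.

```c++
// Hypothetical addition to LiteralBase, shown only to illustrate the
// ForEachSubpiece visitor; the traversal visits the piece tree in pre-order.
int64 TotalArrayBytes() const {
  int64 total = 0;
  root_piece().ForEachSubpiece(
      [&total](const ShapeIndex& /*index*/, const Piece& piece) {
        if (ShapeUtil::IsArray(piece.subshape())) {
          total += piece.size_bytes();  // bytes of this array's buffer.
        }
      });
  return total;
}
```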
+ std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class Literal; + friend class LiteralSlice; + friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; +}; + // Class representing literal values in XLA. // -// TODO(b/67651157): The methods in this class should be reduced to a minimal -// set of methods which construct Literals and accessors methods. Other methods -// which perform computation on Literals (Reshape, Slice, etc) should be moved -// elsewhere, and perhaps combined with evaluator code which operates on -// Literals. -class Literal { +// The underlying buffer and shape is always owned by this class. +class Literal : public LiteralBase { public: Literal() : Literal(ShapeUtil::MakeNil()) {} @@ -74,48 +568,162 @@ class Literal { Literal(const Literal& other) = delete; Literal& operator=(const Literal& other) = delete; Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); Literal& operator=(Literal&& other); - // Literals are equal if they have compatible shapes and the same data - // values. Layout is not compared. - bool operator==(const Literal& other) const; - bool operator!=(const Literal& other) const { return !(*this == other); } + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } - // Serialize to and from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); - LiteralProto ToProto() const; + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. 
See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); - // Return the shape of the literal. - const Shape& shape() const { return shape_; } + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const LiteralSlice& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + + // Copies the values from src_literal, starting at src_base shape indexes, + // to this literal, starting at dest_base, where the copy size in each + // dimension is specified by copy_size. + // The src_literal and this literal must have the same primitive type, + // src_base+copy_size must fit the source literal dimensions, as well as + // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // element, then copy_size must be 0 in these dimensions while the + // corresponding base indices being 0. + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const LiteralSlice& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); + + // Sets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); + // Overloads of Set for array literals. CHECKs if the literal is not + // array-shaped and dense. + template + void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. + template + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); + + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + + // Populate this literal with the given values. Examples: + // + // // Populate with floats. + // Array2D float_values = ... 
+ // literal.PopulateR2FromArray2D(values); + // + // // Populate with int32s. + // literal.PopulateR2({{1, 2}, {3, 4}}); + // + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. + template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); + + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); - // TODO(b/67651157): Remove this accessor. Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return &shape_; } + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); - // Returns a (Mutable)ArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. - template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + // Fills this literal with the given value. template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + void PopulateWithValue(NativeT value); - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. - const SparseIndexArray* sparse_indices( - const ShapeIndex& shape_index = {}) const; - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + // Factory methods below. + // - // Returns a pointer to (or size of) the underlying buffer holding the array - // at the given shape index. CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. - const void* untyped_data(const ShapeIndex& shape_index = {}) const; - void* untyped_data(const ShapeIndex& shape_index = {}); - int64 size_bytes(const ShapeIndex& shape_index = {}) const; + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -163,10 +771,6 @@ class Literal { values, const Layout& layout); - // Returns this literal's data as a string. This literal must be a rank-1 U8 - // array. - string GetR1U8AsString() const; - // Creates a literal with a sparse layout and the given indices and values. // The shape is initialized from the given dimensions. The minor dimension of // the indices array must equal the rank of the shape (i.e. 
size of the @@ -206,171 +810,16 @@ class Literal { tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, tensorflow::gtl::ArraySlice values, bool sort = true); - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. - template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromShape(const Shape& shape); - - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); - - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const Literal& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); - - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); - - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); - - // Copies one element from src_literal[src_index] to (*this)[dest_index]. - Status CopyElementFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); - - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. 
This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); - - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); - - // Creates a new value that has the equivalent value as this literal, but - // conforms to new_layout; e.g. a literal matrix that was in {0, 1} - // minor-to-major dimension layout can be re-layed-out as {1, 0} - // minor-to-major dimension layout and the value in the cell at any given - // logical index (i0, i1) will be the same. - // - // For tuple shaped literals, shape_index should be used to select the inner - // array that the new layout applies to. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; - - // An overload of Relayout which changes the layout of the entire shape rather - // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; - - // Creates a new literal by reshaping this literal to have the given - // dimensions. The total number of elements must not change; The - // implementation currently only supports monotonic dim0-major layouts. - // This literal must be an array. - StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; - - // Creates a new literal by reordering the dimensions of this literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - // This literal must be an array. - std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; - - // Creates a sub-array from this literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this - // literal replicated four times. - // This literal must be an array. - template - std::unique_ptr Replicate(int64 times) const; - - // Converts this literal to another primitive type using - // static_cast<>. Returns an error if the conversion is not possible. This - // literal must be array-shaped. 
- StatusOr> Convert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to another primitive type using a bitcast - // conversion. The to and from primitive types must have the same bit - // width. Returns an error if the conversion is not possible. This literal - // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to the given shape. Returns an error is the - // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. - StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; - // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Creates a scalar literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); - // Creates a scalar literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Creates a scalar literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); - // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithDescendingLayout( @@ -420,83 +869,10 @@ class Literal { // Creates a literal that projects the (x, y) dimensions given in values into // the z and p dimensions given. - template - static std::unique_ptr CreateR4Projected( - std::initializer_list> values, - int64 projection_p, int64 projection_z); - - // Clones this literal into a new Literal, or new std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // Gets or sets an element in the literal at the given index. The multi_index - // is CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - - // Overloads of Get and Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Appends the given element to the literal. 
If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. - void SortSparseElements(const ShapeIndex& shape_index = {}); - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // Returns a literal scalar representing the first element. - Literal GetFirstScalarLiteral() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; - - // As Set(), but truncates `value` to the literal element type before storing. - // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); + template + static std::unique_ptr CreateR4Projected( + std::initializer_list> values, + int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template @@ -507,6 +883,9 @@ class Literal { static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); + static std::unique_ptr MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements); + // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; @@ -538,136 +917,104 @@ class Literal { return MakeTupleOwned(std::move(v)); } - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; - - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. 
- template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); - - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); - - // A parallel version of Populate(). This can be used if the generator is - // thread-safe and the values for the shape's different elements are - // independent. - template - Status PopulateParallel(const FnType& generator); + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; + // Creates a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. 
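A short sketch of building a tuple literal and taking it apart again, per the MakeTuple/MakeTupleOwned factories and DecomposeTuple described above. The function name and values are illustrative, not part of the patch.

```c++
#include "tensorflow/compiler/xla/literal_util.h"

// Builds a (f32[], f32[2,2]) tuple literal, then moves its elements back out.
void TupleExample() {
  auto scalar = xla::Literal::CreateR0<float>(1.0f);
  auto matrix = xla::Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});

  // MakeTuple copies the element literals; the inputs stay valid.
  std::unique_ptr<xla::Literal> tuple =
      xla::Literal::MakeTuple({scalar.get(), matrix.get()});

  // DecomposeTuple moves the elements out; 'tuple' is left as a nil shape.
  std::vector<xla::Literal> elements = tuple->DecomposeTuple();
}
```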
+ static std::unique_ptr ConvertBF16ToF32( + const LiteralSlice& bf16_literal); + + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertF32ToBF16( + const LiteralSlice& f32_literal); + + // Creates a literal with a new shape with the given new dimensions using the + // data in the given input literal. For reshaping purposes the (flat) data + // buffer of the input literal is assumed to have the given minor_to_major + // layout order. + static std::unique_ptr ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal); + + // Creates a literal with the supplied shape, and uses the provided value + // generator to populate the literal's values. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation, and using the engine as entropy generator. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, typename E, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev); + + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, T mean, T stddev); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. - // - // If the literal is not a complex value, this always returns false. // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). - // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; + // End of factory methods. - // Literal consists entirely of the first element of the literal. - bool IsAllFirst() const; + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will + // be returned for a 2-dimensional index with dimension 0 index equal to 7, + // dimension 1 equal to 8. + static string MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index); - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + private: + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. 
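The BF16/F32 conversion helpers declared above can be exercised as follows; this is a sketch with an illustrative function name, and it relies on the implicit Literal-to-LiteralSlice conversion introduced by this change.

```c++
#include "tensorflow/compiler/xla/literal_util.h"

// Round-trips an F32 literal through BF16; tuple literals convert recursively.
void Bf16RoundTripExample() {
  auto f32 = xla::Literal::CreateR1<float>({0.5f, 1.0f, 2.0f});
  std::unique_ptr<xla::Literal> bf16 = xla::Literal::ConvertF32ToBF16(*f32);
  std::unique_ptr<xla::Literal> back = xla::Literal::ConvertBF16ToF32(*bf16);
}
```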
+ void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); - // Return the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); } - // Return the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - protected: - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); + Piece& root_piece() const override { return *root_piece_; }; // Internal template helper for the Literal::CopySliceFrom(), matching its // arguments one by one. template - Status CopySliceFromInternal(const Literal& src_literal, + Status CopySliceFromInternal(const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size); @@ -695,162 +1042,69 @@ class Literal { int64 minor_loop_size = 1; }; - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Return the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); - - // Return the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. 
- return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. - Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - }; - - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return *pieces_.mutable_element(shape_index); - } - const Piece& piece(const ShapeIndex& shape_index) const { - return pieces_.element(shape_index); - } - - // Returns the piece at the root of the shape (empty ShapeIndex). - Piece& root_piece() { return piece({}); } - const Piece& root_piece() const { return piece({}); } + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; - // Deallocate the buffers held by this literal (if the literal owns the - // buffer). - void DeallocateBuffers(); + Piece* root_piece_ = nullptr; // Implementation details shared between Populate() and PopulateParallel() template Status PopulateInternal(const FnType& generator, bool parallel); - Shape shape_; - ShapeTree pieces_; - - // Whether the buffers held in pieces_ are owned by this Literal. - bool owns_buffers_; - - // LiteralView must access and manipulate Pieces of other Literals. - friend class LiteralView; -}; // namespace xla + // Deallocate the buffers held by this literal. + void DeallocateBuffers(); + friend class LiteralBase; +}; std::ostream& operator<<(std::ostream& out, const Literal& literal); -// A read-only view of a Literal. A LiteralView contains pointers to buffers -// owned by the viewed Literal. -// -// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable -// and mutable) similar to (Mutable)ArraySlice. -class LiteralView : public Literal { +// A read-only view of a Literal. A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { public: - // Create and return a view of the given literal rooted at the given shape - // index within the given literal. A factory is used rather than a public - // constructor because only const LiteralViews are supported. 
It's still - // possible to create non-const LiteralViews via the copy constructors, but - // the factory method makes it a bit less likely. Implementing literal slices - // will fix this undesirable situation (b/71550060). - static const LiteralView Create(const Literal& literal, - const ShapeIndex& view_root = {}); + LiteralSlice() : LiteralBase() {} + + // Implicit conversion constructors. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); - LiteralView(const LiteralView& other); - LiteralView& operator=(const LiteralView& other); + private: + const Piece& root_piece() const override { return *root_piece_; }; + + const Piece* root_piece_; // Not owned. +}; - virtual ~LiteralView(); +// A read-only Literal where the underlying buffers are never owned by this +// class. +class BorrowingLiteral : public LiteralBase { + public: + BorrowingLiteral() : LiteralBase() {} + + // 'src_buf_ptr' is not owned by this class and must outlive the + // lifetime of this class. It points to an appropirately sized buffer with + // data interpretered as indicated by 'shape'. + // This constructor is only used for array shapes. + BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Similar as above, except to be used for constructing non-nested tuples. + BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + const Shape& shape); + // TODO(b/79707221): adding constructors for nested tuples as well. private: - LiteralView(const Literal& literal, const ShapeIndex& view_root); + // Recursively builds the subtree for the given piece and sets the subshapes + // of the given piece with the given shape. + void BuildPieceSubtree(const Shape& shape, Piece* piece); - // Helper for the copy constructor and copy assignment operator. - void CopyFrom(const LiteralView& other); + // Accessor for the root piece of this literal. + const Piece& root_piece() const override { return root_piece_; }; + Piece root_piece_; + + // Shape of this literal. 
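To make the two new non-owning classes concrete, here is a usage sketch. The function name, buffer, and values are illustrative; the backing memory must match the layout implied by the shape and must outlive the BorrowingLiteral, as the comments above require.

```c++
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Wraps caller-owned memory and creates a read-only view into a tuple element.
void NonOwningViewsExample() {
  // BorrowingLiteral: interpret an existing buffer as an f32[2,2] literal.
  float backing[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
  xla::BorrowingLiteral borrowed(reinterpret_cast<const char*>(backing), shape);

  // LiteralSlice: view element {0} of a tuple literal without copying data.
  auto scalar = xla::Literal::CreateR0<float>(7.0f);
  auto tuple = xla::Literal::MakeTuple({scalar.get()});
  xla::LiteralSlice element0(*tuple, /*view_root=*/{0});
}
```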
+ const Shape shape_; }; template -tensorflow::gtl::ArraySlice Literal::Piece::data() const { +tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -863,7 +1117,7 @@ tensorflow::gtl::ArraySlice Literal::Piece::data() const { } template -tensorflow::gtl::MutableArraySlice Literal::Piece::data() { +tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -876,7 +1130,7 @@ tensorflow::gtl::MutableArraySlice Literal::Piece::data() { } template -NativeT Literal::Piece::Get( +NativeT LiteralBase::Piece::Get( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -884,15 +1138,15 @@ NativeT Literal::Piece::Get( } template -void Literal::Piece::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { +void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)] = value; } template -tensorflow::gtl::ArraySlice Literal::data( +tensorflow::gtl::ArraySlice LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } @@ -904,13 +1158,13 @@ tensorflow::gtl::MutableArraySlice Literal::data( } template -inline NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT Literal::Get( +inline NativeT LiteralBase::Get( tensorflow::gtl::ArraySlice multi_index) const { return root_piece().Get(multi_index); } @@ -1157,13 +1411,13 @@ template } template -NativeT Literal::GetFirstElement() const { +NativeT LiteralBase::GetFirstElement() const { return data().at(0); } template -NativeT Literal::GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { CHECK( LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); return data(shape_index)[sparse_element_number]; @@ -1196,7 +1450,7 @@ template } template -void Literal::EachCell( +void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { @@ -1372,7 +1626,7 @@ template } template -std::unique_ptr Literal::Replicate(int64 times) const { +std::unique_ptr LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { @@ -1407,6 +1661,38 @@ std::unique_ptr Literal::Replicate(int64 times) const { return literal; } +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + TF_RET_CHECK(shape.element_type() == type); + auto literal = MakeUnique(shape); + TF_RETURN_IF_ERROR(literal.get()->Populate( + [&](tensorflow::gtl::ArraySlice indexes) { + return generator(indexes); + })); + return std::move(literal); +} + 
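A usage sketch of the CreateRandomLiteral factories (illustrative function name): the no-engine overload used here seeds its own std::minstd_rand0 internally and delegates to the generator-based overload defined just above.

```c++
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Creates an f32[128,64] literal whose values are drawn from N(0, 1).
void RandomLiteralExample() {
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 64});
  std::unique_ptr<xla::Literal> literal =
      xla::Literal::CreateRandomLiteral<xla::F32>(shape, /*mean=*/0.0f,
                                                  /*stddev=*/1.0f)
          .ConsumeValueOrDie();
}
```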
+template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + std::normal_distribution generator(mean, stddev); + return CreateRandomLiteral( + shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { + return generator(*engine); + }); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { + std::minstd_rand0 engine; + return CreateRandomLiteral(shape, &engine, mean, stddev); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 61046784e05623..f127cee0fdc126 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -974,7 +975,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { Literal::CreateR1({2.0, 4.0}).get(), &nil_literal}); - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); @@ -985,7 +986,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { /*src_shape_index=*/{})); // The matrix element should be unchanged. - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); @@ -1065,7 +1066,7 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1107,7 +1108,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. 
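Editor's note: the `CreateRandomLiteral` helpers added above come in three flavors: one taking a per-element generator callback, one taking a caller-supplied random engine plus mean/stddev, and a convenience overload that defaults to `std::minstd_rand0`. The sketch below shows the convenience overload; because the template parameter lists were lost in the extracted hunk, selecting the primitive type via the `<xla::F32>` template argument is an assumption inferred from the `shape.element_type() == type` check in the implementation above.

```c++
// Hedged sketch of the CreateRandomLiteral convenience overload.
#include <memory>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

std::unique_ptr<xla::Literal> MakeGaussianLiteral() {
  const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});

  // Seeds a std::minstd_rand0 internally and fills the literal with draws
  // from a normal distribution N(mean, stddev).
  auto literal_or = xla::Literal::CreateRandomLiteral<xla::F32>(
      shape, /*mean=*/0.0f, /*stddev=*/1.0f);
  if (!literal_or.ok()) {
    return nullptr;  // In real code, propagate literal_or.status() instead.
  }
  return literal_or.ConsumeValueOrDie();
}
```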
@@ -1373,36 +1374,36 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } -TEST_F(LiteralUtilTest, LiteralViewTest) { +TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); - EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); - EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); - EXPECT_EQ(LiteralView::Create(nil, {}), nil); + EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); + EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(nil, {}), nil); - EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); + EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); + EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); } -TEST_F(LiteralUtilTest, MutatingLiteralView) { +TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. 
- const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); EXPECT_EQ( nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); @@ -1418,19 +1419,57 @@ TEST_F(LiteralUtilTest, MutatingLiteralView) { 555.0f); } -TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { +TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); - const auto tuple_view = - LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); - const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); + const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { + std::vector int64_values = {1, 2, 3}; + const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); + + BorrowingLiteral literal(reinterpret_cast(int64_values.data()), + literal_shape); + + EXPECT_EQ(literal.Get({0}), 1); + EXPECT_EQ(literal.Get({1}), 2); + EXPECT_EQ(literal.Get({2}), 3); +} + +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) { + std::vector one_two_three = {1, 2, 3}; + const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); + + std::vector hundred = {100}; + const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1}); + + std::vector src_buf_ptrs; + src_buf_ptrs.emplace_back( + reinterpret_cast(one_two_three.data())); + src_buf_ptrs.emplace_back(reinterpret_cast(hundred.data())); + auto literal_tuple = BorrowingLiteral( + src_buf_ptrs, + ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape})); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{0}), + 1); + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{1}), + 100); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{1}, /*shape_index=*/{0}), + 2); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{2}, /*shape_index=*/{0}), + 3); +} + TEST_F(LiteralUtilTest, LiteralMove) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); @@ -1533,11 +1572,11 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_EQ(literal.Get({1, 1}), 4.0); } -TEST_F(LiteralUtilTest, LiteralViewCopy) { +TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralView::Create(*matrix); - LiteralView matrix_view_copy(matrix_view); + const auto matrix_view = LiteralSlice(*matrix); + LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); EXPECT_EQ(matrix_view_copy.Get({0, 1}), 2.0); @@ -1771,5 +1810,35 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(*broadcasted_literal, 
*Literal::CreateR2({{1, 1}, {2, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 2}, {1, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { + std::unique_ptr literal = Literal::CreateR0(9); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{9, 9}, {9, 9}})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 8db8c6f3de84a6..3c74e070da529b 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -86,11 +86,10 @@ const typename Collection::value_type::second_type& FindOrDefault( // Inserts the key-value pair into the collection. Dies if key was already // present. -template -void InsertOrDie(Collection* const collection, - const typename Collection::value_type::first_type& key, - const typename Collection::value_type::second_type& data) { - auto p = collection->insert(std::make_pair(key, data)); +template +void InsertOrDie(Collection* const collection, Key&& key, Value&& value) { + auto p = collection->insert( + std::make_pair(std::forward(key), std::forward(value))); CHECK(p.second) << "duplicate key: " << key; } @@ -101,9 +100,10 @@ bool ContainsKey(const Collection& collection, const Key& key) { } // Inserts `value` into `set`. Dies if it was already present. -template -void InsertOrDie(Set* const set, const typename Set::value_type& value) { - CHECK(set->insert(value).second) << "duplicate value: " << value; +template +void InsertOrDie(Set* const set, Value&& value) { + CHECK(set->insert(std::forward(value)).second) + << "duplicate value: " << value; } } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index ecb87bd8893276..83834c1ff65ea2 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -12,6 +12,7 @@ py_library( deps = [ ":pywrap_xla", "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/service:hlo_proto_py", ], ) @@ -49,9 +50,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 24e17abbe06197..f808990cadeab5 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -15,8 +15,9 @@ limitations under the License. 
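Editor's note: the `map_util.h` hunk above generalizes both `InsertOrDie` overloads from const-reference parameters to perfect forwarding, which lets move-only keys and values (for example `std::unique_ptr`) be inserted without a copy. Below is a self-contained sketch of the same pattern using only standard containers; the function here is a stand-in for illustration, not the file's actual code.

```c++
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Stand-in for the forwarding InsertOrDie in map_util.h: key and value are
// forwarded into the container, so rvalue arguments are moved rather than
// copied, and the insertion must succeed.
template <typename Collection, typename Key, typename Value>
void InsertOrDie(Collection* collection, Key&& key, Value&& value) {
  const bool inserted =
      collection
          ->insert(std::make_pair(std::forward<Key>(key),
                                  std::forward<Value>(value)))
          .second;
  assert(inserted && "duplicate key");
  (void)inserted;
}

int main() {
  std::map<std::string, std::unique_ptr<int>> m;
  // With the previous const& signature the unique_ptr value would have to be
  // copied, which does not compile; forwarding moves it into the map instead.
  InsertOrDie(&m, std::string("answer"), std::make_unique<int>(42));
  assert(*m.at("answer") == 42);
  return 0;
}
```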
#include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/default/thread_annotations.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace xla { @@ -104,25 +105,25 @@ static StatusOr ToBuffer(LocalClient* client, } /* static */ -LocalShapedBuffer* LocalShapedBuffer::FromLiteral( +StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); - ScopedShapedBuffer buf = [&] { + StatusOr buf = [&] { if (shape_with_layout) { std::unique_ptr relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, *relaid) - .ConsumeValueOrDie(); + return ToBuffer(client, /*device_ordinal=*/0, *relaid); } - return ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie(); + return ToBuffer(client, /*device_ordinal=*/0, argument); }(); - return new LocalShapedBuffer(std::move(buf)); + TF_RETURN_IF_ERROR(buf.status()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -std::unique_ptr LocalShapedBuffer::ToLiteral() const { +StatusOr> LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); - return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie(); + return client->ShapedBufferToLiteral(*shaped_buffer()); } CompiledLocalComputation::CompiledLocalComputation( @@ -197,8 +198,6 @@ StatusOr> CompiledLocalComputation::Execute( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); options.set_allocator(client->backend().memory_allocator()); - options.set_inter_op_thread_pool( - client->backend().inter_op_thread_pool()); options.set_intra_op_thread_pool( client->backend().eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); @@ -242,7 +241,6 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( // Execute ExecutableRunOptions options; options.set_allocator(client->backend().memory_allocator()); - options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool()); options.set_intra_op_thread_pool( client->backend().eigen_intra_op_thread_pool_device()); ScopedShapedBuffer result_buffer = @@ -251,7 +249,7 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( return new LocalShapedBuffer(std::move(result_buffer)); } -LocalComputation::LocalComputation(Computation computation) +LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( @@ -274,18 +272,31 @@ StatusOr LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } -const Computation& LocalComputation::computation() const { +const XlaComputation& LocalComputation::computation() const { return computation_; } +string LocalComputation::GetSerializedProto() const { + string result; + if (!computation_.proto().SerializeToString(&result)) { + LOG(ERROR) << "Failed to serialize the HloModuleProto."; + return ""; + } + return result; +} + StatusOr LocalComputation::GetReturnValueShape() const { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation_.GetProgramShape()); return std::move(*program_shape.mutable_result()); } +LocalOp::LocalOp(const XlaOp& op) : op_(op) {} + +const XlaOp& LocalOp::op() const { return op_; } + 
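Editor's note: in the hunks above, `LocalShapedBuffer::FromLiteral` and `ToLiteral` stop calling `ConsumeValueOrDie` and instead return `StatusOr`, so a failed host/device transfer is reported to the caller (and ultimately surfaced to Python through the `StatusOr` typemaps added later in this patch) rather than crashing the process. A hedged sketch of how a C++ caller consumes the new signatures follows; the wrapper function and its error handling are illustrative only and rely on the declarations in `local_computation_builder.h` as modified here.

```c++
#include <memory>

#include "tensorflow/compiler/xla/python/local_computation_builder.h"

xla::Status RoundTripThroughDevice(const xla::Literal& argument) {
  using xla::swig::LocalShapedBuffer;

  // FromLiteral now propagates transfer errors instead of CHECK-failing.
  auto buffer_or = LocalShapedBuffer::FromLiteral(
      argument, /*shape_with_layout=*/tensorflow::gtl::optional<xla::Shape>());
  if (!buffer_or.ok()) {
    return buffer_or.status();
  }
  std::unique_ptr<LocalShapedBuffer> buffer(buffer_or.ValueOrDie());

  // ToLiteral likewise returns a StatusOr instead of dying on failure.
  auto literal_or = buffer->ToLiteral();
  if (!literal_or.ok()) {
    return literal_or.status();
  }
  return xla::Status::OK();
}
```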
LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) - : builder_(GetOrCreateLocalClient(), computation_name) {} + : builder_(computation_name) {} void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); @@ -294,19 +305,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } StatusOr LocalComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); return new LocalComputation(std::move(computation)); } -ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { return builder_.Parameter(parameter_number, shape, name); } std::unique_ptr LocalComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - return builder_.GetShape(operand).ConsumeValueOrDie(); + const LocalOp& operand) { + auto result = MakeUnique(); + *result = builder_.GetShape(operand.op()).ValueOrDie(); + return result; } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -314,222 +327,236 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { return program_shape.result(); } -ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } -void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, +void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand, shape, outfeed_config); + builder_.Outfeed(operand.op(), shape, outfeed_config); } -ComputationDataHandle LocalComputationBuilder::ConstantLiteral( - const Literal& literal) { +LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return builder_.ConstantLiteral(literal); } -ComputationDataHandle LocalComputationBuilder::Broadcast( - const ComputationDataHandle& operand, +LocalOp LocalComputationBuilder::Broadcast( + const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand, broadcast_sizes); + return builder_.Broadcast(operand.op(), broadcast_sizes); } -ComputationDataHandle LocalComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - return builder_.Pad(operand, padding_value, padding_config); +LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand.op(), padding_value.op(), padding_config); } -ComputationDataHandle LocalComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, +LocalOp LocalComputationBuilder::Reshape( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand, dimensions, new_sizes); + return builder_.Reshape(operand.op(), dimensions, new_sizes); } -ComputationDataHandle LocalComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand, dimensions); +LocalOp 
LocalComputationBuilder::Collapse( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - return builder_.CrossReplicaSum(operand); +LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { + return builder_.CrossReplicaSum(operand.op()); } -ComputationDataHandle LocalComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, +LocalOp LocalComputationBuilder::Slice( + const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand, start_indices, limit_indices, strides); + return builder_.Slice(operand.op(), start_indices, limit_indices, strides); } -ComputationDataHandle LocalComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, + int64 limit_index, int64 stride, + int64 dimno) { + return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, + dimno); } -ComputationDataHandle LocalComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, +LocalOp LocalComputationBuilder::DynamicSlice( + const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand, start_indices, slice_sizes); + return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - return builder_.DynamicUpdateSlice(operand, update, start_indices); +LocalOp LocalComputationBuilder::DynamicUpdateSlice( + const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices) { + return builder_.DynamicUpdateSlice(operand.op(), update.op(), + start_indices.op()); } -ComputationDataHandle LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - return builder_.ConcatInDim(operands, dimension); +LocalOp LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, int64 dimension) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.ConcatInDim(xla_ops, dimension); } -ComputationDataHandle -LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, +LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter) { + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter) { return builder_.SelectAndScatterWithGeneralPadding( - operand, select.computation(), window_dimensions, window_strides, padding, - 
source, init_value, scatter.computation()); + operand.op(), select.computation(), window_dimensions, window_strides, + padding, source.op(), init_value.op(), scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - return builder_.Tuple(elements); +LocalOp LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + std::vector xla_ops; + xla_ops.reserve(elements.size()); + for (const auto& op : elements) { + xla_ops.push_back(op.op()); + } + + return builder_.Tuple(xla_ops); } -ComputationDataHandle LocalComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data, index); +LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { + return builder_.GetTupleElement(tuple_data.op(), index); } -ComputationDataHandle LocalComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return builder_.Dot(lhs, rhs); +LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { + return builder_.Dot(lhs.op(), rhs.op()); } -ComputationDataHandle LocalComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::DotGeneral( + const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs, rhs, dimension_numbers); + return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, + return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, + padding, lhs_dilation, rhs_dilation, dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand, new_element_type); +LocalOp LocalComputationBuilder::ConvertElementType( + const LocalOp& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand.op(), new_element_type); } -ComputationDataHandle LocalComputationBuilder::Call( +LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { - return builder_.Call(local_computation.computation(), operands); + tensorflow::gtl::ArraySlice operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.Call(local_computation.computation(), xla_ops); } -ComputationDataHandle LocalComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand, permutation); +LocalOp LocalComputationBuilder::Transpose( + const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand.op(), permutation); } -ComputationDataHandle 
LocalComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand, dimensions); +LocalOp LocalComputationBuilder::Rev( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, +LocalOp LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - return builder_.Map(operands, local_computation.computation(), dimensions, - static_operands); + tensorflow::gtl::ArraySlice static_operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + + std::vector static_xla_ops; + static_xla_ops.reserve(static_operands.size()); + for (const auto& op : static_operands) { + static_xla_ops.push_back(op.op()); + } + + return builder_.Map(xla_ops, local_computation.computation(), dimensions, + static_xla_ops); } -ComputationDataHandle LocalComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::Reduce( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand, init_value, local_computation.computation(), - dimensions_to_reduce); + return builder_.Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } -ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { return builder_.ReduceWindowWithGeneralPadding( - operand, init_value, local_computation.computation(), window_dimensions, - window_strides, padding); + operand.op(), init_value.op(), local_computation.computation(), + window_dimensions, window_strides, padding); } -ComputationDataHandle LocalComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return builder_.RngNormal(mu, sigma, shape); +LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, + const LocalOp& sigma, + const Shape& shape) { + return builder_.RngNormal(mu.op(), sigma.op(), shape); } -ComputationDataHandle LocalComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return builder_.RngUniform(a, b, shape); +LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { + return builder_.RngUniform(a.op(), b.op(), shape); } -ComputationDataHandle LocalComputationBuilder::While( - const LocalComputation& condition, const LocalComputation& body, - const ComputationDataHandle& init) { - return builder_.While(condition.computation(), body.computation(), init); +LocalOp LocalComputationBuilder::While(const LocalComputation& condition, + const LocalComputation& body, + const LocalOp& init) { + return 
builder_.While(condition.computation(), body.computation(), init.op()); } -ComputationDataHandle LocalComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, +LocalOp LocalComputationBuilder::Conditional( + const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional(predicate, true_operand, - true_computation.computation(), false_operand, - false_computation.computation()); + return builder_.Conditional( + predicate.op(), true_operand.op(), true_computation.computation(), + false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - return builder_.IsConstant(operand, num_parameters); +StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { + return builder_.IsConstant(operand.op()); } -StatusOr> LocalComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - return builder_.ComputeConstant(operand, output_layout, parameters); +StatusOr LocalComputationBuilder::BuildConstantSubGraph( + const LocalOp& operand) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + builder_.BuildConstantSubGraph(operand.op())); + return new LocalComputation(std::move(computation)); } #define _FORWARD(method_name, return_sig, args_sig, args) \ @@ -537,23 +564,19 @@ StatusOr> LocalComputationBuilder::ComputeConstant( return builder_.method_name args; \ } -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand), (operand)) - -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ - (lhs, rhs, broadcast_dimensions)) - -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs), \ - (lhs, rhs, ehs)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs.op(), rhs.op(), broadcast_dimensions)) + +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ + (lhs.op(), rhs.op(), ehs.op())) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index e1048909ab29c2..9ac13b65231c93 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -17,9 +17,10 @@ limitations under the License. 
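Editor's note: the rewritten `_FORWARD_*` macros above generate one forwarding method per listed op; each generated method now takes `LocalOp` arguments, unwraps them with `.op()`, and forwards to the underlying `XlaBuilder`. For readability, here is an approximate expansion of `_FORWARD_TRIOP(Select)` under the new macros (whitespace and parameter names follow the macro text; the expansion is illustrative, not code added by the patch).

```c++
// Approximate expansion of _FORWARD_TRIOP(Select) in
// local_computation_builder.cc with the LocalOp-based macros.
LocalOp LocalComputationBuilder::Select(const LocalOp& lhs, const LocalOp& rhs,
                                        const LocalOp& ehs) {
  return builder_.Select(lhs.op(), rhs.op(), ehs.op());
}
```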
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -59,12 +60,14 @@ StatusOr > TransferFromOutfeedLocalReplica( // client. class LocalShapedBuffer { public: - static LocalShapedBuffer* FromLiteral( + static StatusOr FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout); + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; - std::unique_ptr ToLiteral() const; + + StatusOr > ToLiteral() const; private: ScopedShapedBuffer shaped_buffer_; @@ -95,25 +98,42 @@ class CompiledLocalComputation { std::unique_ptr executable_; }; -// Wraps a Computation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. class LocalComputation { public: - LocalComputation(Computation computation); + LocalComputation(XlaComputation computation); StatusOr Compile( const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); - const Computation& computation() const; + const XlaComputation& computation() const; + + // Returns the HloModuleProto contained in the XlaComputation in the + // serialized binary format. Logs an internal error and returns an empty + // string on failure. + string GetSerializedProto() const; // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; private: - Computation computation_; + XlaComputation computation_; +}; + +// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// to be made available to Python via SWIG. +class LocalOp { + public: + LocalOp(const XlaOp& op); + + const XlaOp& op() const; + + private: + XlaOp op_; }; // Wraps the ComputationBuilder API in order to: @@ -133,166 +153,137 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success. StatusOr Build(); - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); + LocalOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); - std::unique_ptr GetShape(const ComputationDataHandle& operand); + std::unique_ptr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. 
StatusOr GetReturnValueShape(); - ComputationDataHandle Infeed(const Shape& shape); + LocalOp Infeed(const Shape& shape); - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + void Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config); - ComputationDataHandle ConstantLiteral(const Literal& literal); + LocalOp ConstantLiteral(const Literal& literal); - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + LocalOp Broadcast(const LocalOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, + const PaddingConfig& padding_config); - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + LocalOp CrossReplicaSum(const LocalOp& operand); - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); + LocalOp SliceInDim(const LocalOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); + LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices); - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, + LocalOp SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter); + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter); - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(tensorflow::gtl::ArraySlice elements); - 
ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); + LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); + LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, + const DotDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + LocalOp ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); + LocalOp ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type); - ComputationDataHandle Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + LocalOp Call(const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); + LocalOp Transpose(const LocalOp& operand, + tensorflow::gtl::ArraySlice permutation); - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + LocalOp Map(tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, + LocalOp ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding); - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); + LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape); - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); + LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - ComputationDataHandle While(const LocalComputation& condition, - const LocalComputation& body, - const ComputationDataHandle& init); + LocalOp While(const 
LocalComputation& condition, const LocalComputation& body, + const LocalOp& init); - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, - const LocalComputation& false_computation); + LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, + const LocalOp& false_operand, + const LocalComputation& false_computation); - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters); + StatusOr IsConstant(const LocalOp& operand); - StatusOr > ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand)) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -336,7 +327,7 @@ class LocalComputationBuilder { #undef _FORWARD_TRIOP private: - ComputationBuilder builder_; + XlaBuilder builder_; }; // Functions for freeing resources from the Python side. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index ac792e8189bda9..536b93c6f9381a 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,9 +22,8 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> int // ArraySlice <- sequence of int -// ArraySlice <- sequence of int +// ArraySlice <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) @@ -91,12 +90,9 @@ limitations under the License. // One central reason for the Python-side indirection is that the // Python-side objects produced by the typemaps in this file are // further packaged up by xla_client before being passed on. For -// instance, xla_client wraps the long produced for a C++ -// ComputationDataHandle in a Python ComputationDataHandle proto, -// rather than exposing a raw long outside of the client. Similarly, -// the Python pair produced for a C++ Shape is further wrapped in a -// Python class (xla_client.Shape) so as not to expose the raw pair -// externally. 
+// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. // // Other SWIG object wrappers (e.g. of LocalComputation) are further // wrapped by xla_client in order to set up a custom destructor that @@ -124,6 +120,7 @@ using namespace xla; using namespace xla::swig; namespace xla { + namespace swig { bool GetIntAttr(PyObject* o, const char* field, int64* result) { @@ -177,27 +174,25 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} -// ComputationDataHandle - -%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { - const int64 handle = numpy::PyIntOrPyLongToLong($input); - if (handle == -1 && PyErr_Occurred()) { +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::CompiledLocalComputation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } - temp.set_handle(handle); - $1 = &temp; } -%typemap(out) ComputationDataHandle { - $result = numpy::LongToPyIntOrPyLong($1.handle()); -} - -%typemap(out) StatusOr { +%typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); { auto* $1 = value; - $typemap(out, xla::swig::CompiledLocalComputation*) + $typemap(out, xla::swig::LocalShapedBuffer*) } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -288,33 +283,23 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ComputationDataHandle +// ArraySlice -%typemap(in) tensorflow::gtl::ArraySlice - (std::vector temps) { +%typemap(in) tensorflow::gtl::ArraySlice( + std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - SWIG_fail; - } - const int64 handle = numpy::PyIntOrPyLongToLong(py_int); - if (handle == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - temps[i].set_handle(handle); - Py_DECREF(py_int); + temps.push_back(*op); Py_DECREF(o); } $1 = temps; @@ -866,6 +851,11 @@ tensorflow::ImportNumpy(); })) { return nullptr; } + if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { + build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); })) { @@ -921,6 +911,8 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalComputation::GetSerializedProto; +%unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index dc6f5fe5fcc067..68648a3a176363 100644 --- 
a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -340,13 +340,13 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const Literal& literal) { +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); PyObject* tuple = PyTuple_New(num_elements); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM( - tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i}))); + PyTuple_SET_ITEM(tuple, i, + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); } return tuple; } else { @@ -431,7 +431,7 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 9656cb1c31c39d..64f0aae0f9790f 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -74,7 +74,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. -PyObject* PyObjectFromXlaLiteral(const Literal& literal); +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,7 +90,7 @@ StatusOr > XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array); template @@ -101,7 +101,8 @@ void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { } template -void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { +void CopyLiteralToNumpyArray(const LiteralSlice& literal, + PyArrayObject* py_array) { NativeT* dest = static_cast(PyArray_DATA(py_array)); auto source = literal.data(); std::copy(source.begin(), source.end(), dest); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index f6809b6b871d7e..11611ac61287da 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -28,6 +28,7 @@ from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api +from tensorflow.compiler.xla.service import hlo_pb2 # Most functions are snake_case for consistency with other modules, whereas @@ -335,20 +336,6 @@ def _wrap_shape(shape_info): return Shape.array_shape(dtype, dims) -def _wrap_data_handle(handle): - cdh = xla_data_pb2.ComputationDataHandle() - cdh.handle = handle - return cdh - - -def _unwrap_data_handle(handle_proto): - return handle_proto.handle - - -def _unwrap_data_handles(handle_protos): - return [_unwrap_data_handle(cdh) for cdh in handle_protos] - - def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -366,6 +353,7 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None self.dump_optimized_hlo_proto_to = None + self.dump_unoptimized_hlo_proto_to = None 
self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False @@ -424,6 +412,17 @@ def __init__(self, c_local_computation, is_compiled): assert isinstance(c_local_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation + def GetProto(self): + """Get the HloModuleProto proto object in this local computation. + + Returns: + An HloModuleProto proto object that has the whole-graph information. + """ + + serialized = self.c_local_computation.GetSerializedProto() + proto = hlo_pb2.HloModuleProto.FromString(serialized) + return proto + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): """Compiles an un-compiled local computation. @@ -535,9 +534,9 @@ def Infeed(self, shape): queue for subsequent use in the computation. Returns: - A ComputationDataHandle message. + A LocalOp. """ - return _wrap_data_handle(self._client.Infeed(shape)) + return self._client.Infeed(shape) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -545,9 +544,7 @@ def Outfeed(self, operand): Outfeed operations enqueue data, using the given operand, onto the XLA outfeed queue for subsequent dequeue via the client API. """ - self._client.Outfeed( - _unwrap_data_handle(operand), self.GetShape(operand), - ''.encode('utf-8')) + self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): """Enqueues a constant op onto the computation. @@ -557,10 +554,10 @@ def Constant(self, value): to one of the supported types. Returns: - A ComputationDataHandle message. + A LocalOp. """ value = require_numpy_array_layout(value) - return _wrap_data_handle(self._client.ConstantLiteral(value)) + return self._client.ConstantLiteral(value) def ConstantF32Scalar(self, value): """Convenience method to enqueue a scalar F32 constant op. @@ -569,7 +566,7 @@ def ConstantF32Scalar(self, value): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float32)) @@ -580,7 +577,7 @@ def ConstantF64Scalar(self, value): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float64)) @@ -591,7 +588,7 @@ def ConstantS32Scalar(self, value): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int32)) @@ -602,7 +599,7 @@ def ConstantS64Scalar(self, value): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int64)) @@ -613,7 +610,7 @@ def ConstantPredScalar(self, value): value: a boolean value. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.bool)) @@ -629,15 +626,14 @@ def ParameterWithShape(self, shape, name=None, parameter_num=None): parameters, use it for *all* parameters to avoid clashes. Returns: - A ComputationDataHandle message. + A LocalOp. """ if name is None: name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - return _wrap_data_handle( - self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + return self._client.Parameter(parameter_num, shape, name.encode('utf8')) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. 
@@ -649,7 +645,7 @@ def ParameterFromNumpy(self, value, name=None, parameter_num=None): parameter_num: as in ParameterWithShape. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) @@ -658,14 +654,13 @@ def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. Args: - operand: the operand ComputationDataHandle to broadcast. + operand: the operand LocalOp to broadcast. sizes: an iterable of broadcast sizes. Returns: - A ComputationDataHandle representing the added broadcast op. + A LocalOp representing the added broadcast op. """ - return _wrap_data_handle( - self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + return self._client.Broadcast(operand, sizes) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -675,10 +670,9 @@ def Concatenate(self, operands, dimension): dimension: the dimension in which to perform the concatenation. Returns: - A ComputationDataHandle representing the added concatenate op. + A LocalOp representing the added concatenate op. """ - return _wrap_data_handle( - self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + return self._client.ConcatInDim(operands, dimension) def ConvertElementType(self, operand, new_element_type): """Enqueues an element type conversion operation onto the computation. @@ -688,14 +682,12 @@ def ConvertElementType(self, operand, new_element_type): new_element_type: the target primitive type. Returns: - A ComputationDataHandle representing the added conversion op. + A LocalOp representing the added conversion op. """ - return _wrap_data_handle( - self._client.ConvertElementType( - _unwrap_data_handle(operand), new_element_type)) + return self._client.ConvertElementType(operand, new_element_type) def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + return _wrap_shape(self._client.GetShape(operand)) def GetReturnValueShape(self): return _wrap_shape(self._client.GetReturnValueShape()) @@ -707,40 +699,35 @@ def Pad(self, operand, padding_value, padding_config): """Enqueues a Pad operation onto the computation. Args: - operand: ComputationDataHandle representing the array to pad. - padding_value: ComputationDataHandle representing the scalar pad value. + operand: LocalOp representing the array to pad. + padding_value: LocalOp representing the scalar pad value. padding_config: either an xla_data_pb2.PaddingConfig or a list of integer triples (edge_padding_low, edge_padding_high, interior_padding) representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added Pad op. + A LocalOp representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) - return _wrap_data_handle( - self._client.Pad(_unwrap_data_handle(operand), - _unwrap_data_handle(padding_value), - padding_config)) + return self._client.Pad(operand, padding_value, padding_config) def Reshape(self, operand, dimensions, new_sizes): """Enqueues a reshape op onto the computation. Args: - operand: ComputationDataHandle representing the array to be reshaped. + operand: LocalOp representing the array to be reshaped. dimensions: sequence of integers encoding the order in which dimensions are collapsed or None, in which case dimensions are flattened in order. 
new_sizes: sequence of integers encoding the new dimension sizes (shape). Returns: - A ComputationDataHandle representing the added Reshape op. + A LocalOp representing the added Reshape op. """ if dimensions is None: ndim = len(self.GetShape(operand).dimensions()) dimensions = tuple(range(ndim)) - return _wrap_data_handle( - self._client.Reshape( - _unwrap_data_handle(operand), dimensions, new_sizes)) + return self._client.Reshape(operand, dimensions, new_sizes) def CrossReplicaSum(self, operand): """CrossReplicaSum op. @@ -749,67 +736,56 @@ def CrossReplicaSum(self, operand): operand: the operand to sum across replica instances. Returns: - A ComputationDataHandle that has the sum of the value among all replicas. + A LocalOp that has the sum of the value among all replicas. """ - return _wrap_data_handle( - self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + return self._client.CrossReplicaSum(operand) def Collapse(self, operand, dimensions): """Collapse op.""" - return _wrap_data_handle( - self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + return self._client.Collapse(operand, dimensions) def Trans(self, operand): """Specialized matrix transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + return self._client.Transpose(operand, [1, 0]) def Transpose(self, operand, permutation): """Transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), permutation)) + return self._client.Transpose(operand, permutation) def Rev(self, operand, dimensions): """Rev op.""" - return _wrap_data_handle( - self._client.Rev(_unwrap_data_handle(operand), dimensions)) + return self._client.Rev(operand, dimensions) def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin """Clamp op.""" - return _wrap_data_handle( - self._client.Clamp(_unwrap_data_handle(min), - _unwrap_data_handle(operand), - _unwrap_data_handle(max))) + return self._client.Clamp(min, operand, max) def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. Args: - operand: ComputationDataHandle for array of dimension N and type T over + operand: LocalOp for array of dimension N and type T over which the windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. window_strides: sequence of N integers for the strides of the window. padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: ComputationDataHandle for array of type T with values to scatter. - init_value: ComputationDataHandle of scalar type T for initial out value. + source: LocalOp for array of type T with values to scatter. + init_value: LocalOp of scalar type T for initial out value. scatter: Computation of type (T, T) -> T to apply to each scatter source element with its destination element. Returns: - A ComputationDataHandle representing the added SelectAndScatter op. + A LocalOp representing the added SelectAndScatter op. 
""" pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.SelectAndScatterWithGeneralPadding( - _unwrap_data_handle(operand), select.c_local_computation, - window_dimensions, window_strides, pads, - _unwrap_data_handle(source), _unwrap_data_handle(init_value), - scatter.c_local_computation)) + return self._client.SelectAndScatterWithGeneralPadding( + operand, select.c_local_computation, window_dimensions, window_strides, + pads, source, init_value, scatter.c_local_computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -817,17 +793,13 @@ def Select(self, pred, on_true, on_false): Constructs an output array from elements of two input arrays, based on the values of a predicate array. """ - return _wrap_data_handle( - self._client.Select( - _unwrap_data_handle(pred), - _unwrap_data_handle(on_true), - _unwrap_data_handle(on_false))) + return self._client.Select(pred, on_true, on_false) def Slice(self, operand, start_indices, limit_indices, strides=None): """Enqueues a slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_indices: iterable of N integers containing the starting indices of the slice for each dimension. limit_indices: iterable of N integers containing the ending indices @@ -836,207 +808,177 @@ def Slice(self, operand, start_indices, limit_indices, strides=None): each dimension. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ if strides is None: start_indices = list(start_indices) strides = [1] * len(start_indices) - return _wrap_data_handle( - self._client.Slice( - _unwrap_data_handle(operand), start_indices, limit_indices, - strides)) + return self._client.Slice(operand, start_indices, limit_indices, strides) def SliceInDim(self, operand, start_index, limit_index, stride, dimno): """Enqueues a slice-in-dimension operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_index: an integer containing the start index of the slice. limit_index: an integer containing the end index of the slice. stride: an integer containing the stride size for the slice. dimno: an integer indicating the dimension along which to slice. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ - return _wrap_data_handle( - self._client.SliceInDim( - _unwrap_data_handle(operand), start_index, limit_index, stride, - dimno)) + return self._client.SliceInDim(operand, start_index, limit_index, stride, + dimno) def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. - start_indices: ComputationDataHandle for the 1D array of N integers + operand: LocalOp for the N dimensional array to be sliced. + start_indices: LocalOp for the 1D array of N integers containing the starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. Returns: - A ComputationDataHandle representing the added DynamicSlice op. + A LocalOp representing the added DynamicSlice op. 
""" - return _wrap_data_handle( - self._client.DynamicSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(start_indices), - slice_sizes)) + return self._client.DynamicSlice(operand, start_indices, slice_sizes) def DynamicUpdateSlice(self, operand, update, start_indices): """Enqueues a dynamic update slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be updated. + operand: LocalOp for the N dimensional array to be updated. update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. Returns: - A ComputationDataHandle representing the added DynamicUpdateSlice op. + A LocalOp representing the added DynamicUpdateSlice op. """ - return _wrap_data_handle( - self._client.DynamicUpdateSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(update), - _unwrap_data_handle(start_indices))) + return self._client.DynamicUpdateSlice(operand, update, start_indices) def Tuple(self, *ops): """Enqueues a tuple operation onto the computation. Args: - ops: a sequence of tuple operands (each a ComputationDataHandle). + ops: a sequence of tuple operands (each a LocalOp). Returns: - A ComputationDataHandle representing the added Tuple op. + A LocalOp representing the added Tuple op. """ - return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + return self._client.Tuple(ops) def GetTupleElement(self, tup, index): """Enqueues a 'get tuple element' operation onto the computation. Args: - tup: the tuple operand (a ComputationDataHandle). + tup: the tuple operand (a LocalOp). index: numeric index to select from the tuple. Returns: - A ComputationDataHandle representing the added GetTupleElement op. + A LocalOp representing the added GetTupleElement op. """ - return _wrap_data_handle( - self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + return self._client.GetTupleElement(tup, index) def Call(self, computation_to_apply, operands): """Enqueues a call operation onto the computation. Args: computation_to_apply: a Computation object. - operands: an iterable of ComputationDataHandle. The number and types of + operands: an iterable of LocalOp. The number and types of operands must match the arity of computation_to_apply. Returns: - A ComputationDataHandle representing the added call op. + A LocalOp representing the added call op. """ - return _wrap_data_handle( - self._client.Call(computation_to_apply.c_local_computation, - _unwrap_data_handles(operands))) + return self._client.Call(computation_to_apply.c_local_computation, operands) def Map(self, operands, computation_to_apply, dimensions, static_operands=()): """Enqueues a map operation onto the computation. Args: - operands: an iterable of ComputationDataHandle. + operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. static_operands: auxiliary arguments passed to the applied computation. Returns: - A ComputationDataHandle representing the added Map op. + A LocalOp representing the added Map op. 
""" - return _wrap_data_handle( - self._client.Map( - _unwrap_data_handles(operands), - computation_to_apply.c_local_computation, - dimensions, - _unwrap_data_handles(static_operands))) + return self._client.Map(operands, computation_to_apply.c_local_computation, + dimensions, static_operands) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a Computation object - binary reduction function. dimensions: sequence of dimensions (integers) to reduce on. Returns: - A ComputationDataHandle representing the added Reduce op. + A LocalOp representing the added Reduce op. """ - return _wrap_data_handle( - self._client.Reduce( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - dimensions)) + return self._client.Reduce(operand, init_value, + computation_to_apply.c_local_computation, + dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): """Enqueues a windowed reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a binary reduction function (Computation). window_dimensions: dimensions of window (sequence of integers). window_strides: strides for window (sequence of integers). padding: PaddingType representing either 'SAME' or 'VALID' padding. Returns: - A ComputationDataHandle representing the added ReduceWindow op. + A LocalOp representing the added ReduceWindow op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.ReduceWindowWithGeneralPadding( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads)) + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. Args: - mu: A ComputationDataHandle to an F32 scalar specifying the mean. - sigma: A ComputationDataHandle to an F32 scalar specifying the standard + mu: A LocalOp to an F32 scalar specifying the mean. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of F32 values. + Returns: a LocalOp to the generated array of F32 values. """ shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) - return _wrap_data_handle( - self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + return self._client.RngNormal(mu, sigma, shape) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. 
Args: - a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) specifying the low end of the interval [a, b) over which values are generated. - b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) specifying the high end of the interval [a, b) over which values are generated. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of values with the + Returns: a LocalOp to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. """ shape = Shape.array_shape(self.GetShape(a).element_type(), dims) - return _wrap_data_handle( - self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + return self._client.RngUniform(a, b, shape) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -1044,112 +986,105 @@ def While(self, cond, body, init): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: a ComputationDataHandle for the initial parameter, which has type T + init: a LocalOp for the initial parameter, which has type T - Returns: a ComputationDataHandle representing the While operation. + Returns: a LocalOp representing the While operation. """ - return _wrap_data_handle( - self._client.While(cond.c_local_computation, - body.c_local_computation, - _unwrap_data_handle(init))) + return self._client.While(cond.c_local_computation, + body.c_local_computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): """Enqueues a Conditional operation onto the computation. Args: - predicate: a ComputationDataHandle to test, which has scalar type PRED - true_operand: a ComputationDataHandle of type T_0 + predicate: a LocalOp to test, which has scalar type PRED + true_operand: a LocalOp of type T_0 true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a ComputationDatahandle of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a ComputationDataHandle representing the Conditional operation. + Returns: a LocalOp representing the Conditional operation. """ - return _wrap_data_handle( - self._client.Conditional( - _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), - true_computation.c_local_computation, - _unwrap_data_handle(false_operand), - false_computation.c_local_computation)) + return self._client.Conditional( + pred, true_operand, true_computation.c_local_computation, false_operand, + false_computation.c_local_computation) - def IsConstant(self, operand, num_parameters=0): - """Enqueues an IsConstant operation onto the computation. + def IsConstant(self, operand): + """Checks whether the given operand is a compile-time constant. Args: operand: a ComputationDataHandle to test. - num_parameters: optional int, number of computation parameters to treat as - constant (default 0). Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on parameters with index greater than or - equal to `num_parameters`. + meaning its value does not depend on any parameters, or on stateful + operators such as `RngNormal` or `Infeed`. 
""" - return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + return self._client.IsConstant(operand) + + def BuildConstantSubGraph(self, operand): + """Builds a constant sub graph. + + Args: + operand: a LocalOp to test. + Returns: a LocalComputation that is rooted on the given `operand` which is a + compile-time constant. + """ + return self._client.BuildConstantSubGraph(operand) def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. Args: - lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. - rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. + rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a ComputationDataHandle representing the Dot operation. + Returns: a LocalOp representing the Dot operation. """ - return _wrap_data_handle( - self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + return self._client.Dot(lhs, rhs) def DotGeneral(self, lhs, rhs, dimension_numbers): """Enqueues a general dot operation onto the computation. Args: - lhs: ComputationDataHandle for the left-hand-side array. - rhs: ComputationDataHandle for the right-hand-side array. + lhs: LocalOp for the left-hand-side array. + rhs: LocalOp for the right-hand-side array. dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a ComputationDataHandle representing the DotGeneral operation. + Returns: a LocalOp representing the DotGeneral operation. """ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return _wrap_data_handle( - self._client.DotGeneral( - _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), - dimension_numbers)) + return self._client.DotGeneral(lhs, rhs, dimension_numbers) def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. - Returns: a ComputationDataHandle representing the Conv operation. + Returns: a LocalOp representing the Conv operation. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - pads, - (), - (), - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), + (), dimension_numbers) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. 
+ rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of kernel strides. padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. @@ -1159,14 +1094,9 @@ def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, A ComputationdataHandle representing the added ConvWithGeneralPadding op. """ dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1183,6 +1113,61 @@ def _GetConvDimensionNumbers(self, num_spatial_dims): dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) return dimension_numbers + def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, + rhs_dilation, dimension_numbers): + """Enqueues a ConvGeneralDilated operation onto the computation. + + Args: + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of integer dilation factors. + rhs_dilation: length-N array-like of integer dilation factors. + dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a + triple (lhs_spec, rhs_spec, out_spec) where each element is a string of + length N+2 identifying by position (1) batch dimensions in lhs, rhs, and + the output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions + in rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers + consistent with the Conv operation with two spatial dimensions, one + could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate + dimension numbers consistent with the TensorFlow Conv2D operation, one + could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of + convolution dimension specification, window strides are associated with + spatial dimension character labels according to the order in which the + labels appear in the rhs_spec string, so that window_strides[0] is + matched with the dimension corresponding to the first character + appearing in rhs_spec that is not 'I' or 'O'. + + Returns: a LocalOp representing the ConvGeneralDilated operation. 
+ """ + if not isinstance(dimension_numbers, + xla_data_pb2.ConvolutionDimensionNumbers): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + dimension_numbers.input_spatial_dimensions.extend( + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) + dimension_numbers.output_spatial_dimensions.extend( + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. @@ -1196,15 +1181,14 @@ def forward_to_local_builder_with_handles(target_method, is_binop=False): """Generate a forwarding method that wraps/unwraps data handles.""" def forward(self, *args, **kwargs): - unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + arg_list = list(args) - if is_binop and len(unwrapped_args) < 3: - unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + if is_binop and len(arg_list) < 3: + arg_list.append(kwargs.get('broadcast_dimensions', ())) - return _wrap_data_handle( - target_method( - self._client, # pylint: disable=protected-access - *unwrapped_args)) + return target_method( + self._client, # pylint: disable=protected-access + *arg_list) return forward diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c073c02040e4d2..375e720f9b433f 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -164,6 +164,16 @@ def testSum2DF32(self): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testGetProto(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + built = c.Build() + proto = built.GetProto() # HloModuleProto + self.assertTrue(len(proto.computations) == 1) + self.assertTrue(len(proto.computations[0].instructions) == 3) + def testSum2DF64(self): c = self._NewComputation() c.Add( @@ -509,6 +519,46 @@ def testConvWithGeneralPaddingF32(self): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=result) + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + 
self._ExecuteAndCompareClose(c, expected=result) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = ("NHWC", "OIHW", "CWNH") + c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))), + c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index df9dbc58308f04..c289c84cff7438 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -572,7 +572,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module("ReferenceUtil"); + HloModuleConfig config; + HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 28d6a8c3fe85fa..8fa6961d197dce 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -265,9 +265,9 @@ class ReferenceUtil { const Array3D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 3); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; for (int i = 0; i < 3; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -299,9 +299,9 @@ class ReferenceUtil { const Array4D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 4); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; for (int i = 0; i < 4; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -330,13 +330,14 @@ class ReferenceUtil { return result; } - // Slices with modulo-wrapping. 
+ // Slices with index clamping template - static std::vector ModSlice1D(const tensorflow::gtl::ArraySlice& input, - int64 start, int64 size) { + static std::vector ClampSlice1D( + const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { - result.push_back(input[(start + i) % input.size()]); + result.push_back(input[(start + i)]); } return result; } @@ -552,12 +553,11 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 3); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector pad_low(3); - std::vector pad_high(3); - std::vector pad_interior(3); - std::vector output_bounds(3); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()}; + int64 pad_low[3]; + int64 pad_high[3]; + int64 pad_interior[3]; + int64 output_bounds[3]; for (int64 i = 0; i < 3; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); @@ -573,7 +573,7 @@ class ReferenceUtil { Array3D result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector indices = {0, 0, 0}; + int indices[] = {0, 0, 0}; for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { @@ -611,12 +611,12 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 4); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector pad_low(4); - std::vector pad_high(4); - std::vector pad_interior(4); - std::vector output_bounds(4); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + int64 pad_low[4]; + int64 pad_high[4]; + int64 pad_interior[4]; + int64 output_bounds[4]; for (int64 i = 0; i < 4; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 977f8637873a4b..0d56a9a477b159 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -55,7 +55,7 @@ tf_cc_test( deps = [ ":grpc_stub", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index b559ee4b5a345d..313f11a9a95715 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "grpc++/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" @@ -84,7 +84,7 @@ TEST_F(GRPCClientTestBase, ItsAlive) { } TEST_F(GRPCClientTestBase, AxpyTenValues) { - ComputationBuilder builder(client_.get(), "axpy_10"); + XlaBuilder builder("axpy_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); @@ -101,8 +101,8 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - LiteralTestUtil::ExpectNear(*expected_literal, *result_literal, - ErrorSpec(0.0001)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + ErrorSpec(0.0001))); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 0b100bd108e239..4e1435fa30a24c 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -27,24 +27,11 @@ namespace xla { return std::move(grpc_service); } -::grpc::Status DelegateRPC(std::function op) { - tensorflow::Status s = op(); +::grpc::Status DelegateRPC(std::function op) { + Status s = op(); return tensorflow::ToGrpcStatus(s); } -::grpc::Status GRPCService::Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Computation(arg, result); }); -} - -::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context, - const OpRequest* arg, OpResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Op(arg, result); }); -} - ::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) { @@ -60,26 +47,11 @@ ::grpc::Status GRPCService::DeconstructTuple(::grpc::ServerContext* context, }); } -::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - return DelegateRPC([this, arg, results]() { - return service_->SetReturnValue(arg, results); - }); -} - -::grpc::Status GRPCService::Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) { +::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, + const ExecuteGraphRequest* arg, + ExecuteResponse* result) { return DelegateRPC( - [this, arg, result]() { return service_->Execute(arg, result); }); -} - -::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ExecuteAsync(arg, result); }); + [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); } ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, @@ -129,20 +101,6 @@ ::grpc::Status GRPCService::ResetDevice(::grpc::ServerContext* context, [this, arg, result]() { return service_->ResetDevice(arg, result); }); } -::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* 
context, - const IsConstantRequest* arg, - IsConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->IsConstant(arg, result); }); -} - -::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ComputeConstant(arg, result); }); -} - ::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) { @@ -150,43 +108,4 @@ ::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, [this, arg, result]() { return service_->GetShape(arg, result); }); } -::grpc::Status GRPCService::GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationShape(arg, result); - }); -} - -::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->GetLocalShape(arg, result); }); -} - -::grpc::Status GRPCService::GetComputationStats( - ::grpc::ServerContext* context, const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationStats(arg, result); - }); -} - -::grpc::Status GRPCService::SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->SnapshotComputation(arg, result); - }); -} - -::grpc::Status GRPCService::LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->LoadComputationSnapshot(arg, result); - }); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index fad74375bd59f7..5cd573167ae8c0 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -31,13 +31,6 @@ class GRPCService : public grpc::XlaService::Service { static StatusOr> NewService( se::Platform* platform = nullptr); - ::grpc::Status Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) override; - - ::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg, - OpResponse* result) override; - ::grpc::Status Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) override; @@ -46,17 +39,9 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - ::grpc::Status Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) override; - - ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, + const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; ::grpc::Status 
WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, @@ -82,38 +67,10 @@ class GRPCService : public grpc::XlaService::Service { const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - ::grpc::Status IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) override; - - ::grpc::Status ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - ::grpc::Status GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) override; - ::grpc::Status GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ::grpc::Status GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - ::grpc::Status GetComputationStats(::grpc::ServerContext* context, - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - ::grpc::Status SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - ::grpc::Status LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - private: std::unique_ptr<::xla::Service> service_; diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index e1f2b0abe39b10..7b8ab158e1396d 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -20,82 +20,56 @@ namespace xla { GRPCStub::~GRPCStub() = default; -tensorflow::Status MakeRPC( +Status MakeRPC( const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) { ::grpc::ClientContext context; ::grpc::Status s = rpc_method(&context); return tensorflow::FromGrpcStatus(s); } -tensorflow::Status GRPCStub::TransferToClient( - const TransferToClientRequest* request, - TransferToClientResponse* response) { +Status GRPCStub::TransferToClient(const TransferToClientRequest* request, + TransferToClientResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToClient(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToServer( - const TransferToServerRequest* request, - TransferToServerResponse* response) { +Status GRPCStub::TransferToServer(const TransferToServerRequest* request, + TransferToServerResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToServer(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToInfeed( - const TransferToInfeedRequest* request, - TransferToInfeedResponse* response) { +Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request, + TransferToInfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToInfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferFromOutfeed( - const TransferFromOutfeedRequest* request, - TransferFromOutfeedResponse* response) { +Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request, + TransferFromOutfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return 
grpc_stub_->TransferFromOutfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, - ResetDeviceResponse* response) { +Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, + ResetDeviceResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ResetDevice(context, *request, response); }); } -tensorflow::Status GRPCStub::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->LoadComputationSnapshot(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Execute(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteGraph(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteParallel( - const ExecuteParallelRequest* request, ExecuteParallelResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteParallel(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::ExecuteGraphParallel( +Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -103,38 +77,21 @@ tensorflow::Status GRPCStub::ExecuteGraphParallel( }); } -tensorflow::Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteAsync(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::WaitForExecution( - const WaitForExecutionRequest* request, - WaitForExecutionResponse* response) { +Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, + WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->WaitForExecution(context, *request, response); }); } -tensorflow::Status GRPCStub::DeconstructTuple( - const DeconstructTupleRequest* request, - DeconstructTupleResponse* response) { +Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, + DeconstructTupleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->DeconstructTuple(context, *request, response); }); } -tensorflow::Status GRPCStub::GetComputationStats( - const ComputationStatsRequest* request, - ComputationStatsResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationStats(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::GetComputationGraphStats( +Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -142,81 
+99,28 @@ tensorflow::Status GRPCStub::GetComputationGraphStats( }); } -tensorflow::Status GRPCStub::GetComputationShape( - const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationShape(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::GetShape(const GetShapeRequest* request, - GetShapeResponse* response) { +Status GRPCStub::GetShape(const GetShapeRequest* request, + GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetShape(context, *request, response); }); } -tensorflow::Status GRPCStub::GetDeviceHandles( - const GetDeviceHandlesRequest* request, - GetDeviceHandlesResponse* response) { +Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request, + GetDeviceHandlesResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetDeviceHandles(context, *request, response); }); } -tensorflow::Status GRPCStub::CreateChannelHandle( - const CreateChannelHandleRequest* request, - CreateChannelHandleResponse* response) { +Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, + CreateChannelHandleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->CreateChannelHandle(context, *request, response); }); } -// Methods used by ComputationBuilder. -tensorflow::Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Computation(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::Op(const OpRequest* request, - OpResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->CreateOp(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetLocalShape(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::SetReturnValue( - const SetReturnValueRequest* request, SetReturnValueResponse* responses) { - return MakeRPC([this, request, responses](::grpc::ClientContext* context) { - return grpc_stub_->SetReturnValue(context, *request, responses); - }); -} - -tensorflow::Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->IsConstant(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::ComputeConstant( - const ComputeConstantRequest* request, ComputeConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ComputeConstant(context, *request, response); - }); -} - -tensorflow::Status GRPCStub::ComputeConstantGraph( +Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -224,18 +128,9 @@ tensorflow::Status GRPCStub::ComputeConstantGraph( }); } -// Methods used by Computation. 
-tensorflow::Status GRPCStub::SnapshotComputation( - const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->SnapshotComputation(context, *request, response); - }); -} - // Methods used by GlobalData. -tensorflow::Status GRPCStub::Unregister(const UnregisterRequest* request, - UnregisterResponse* response) { +Status GRPCStub::Unregister(const UnregisterRequest* request, + UnregisterResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Unregister(context, *request, response); }); diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index fd9810d4f1a5e0..8dfcb761387d60 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -28,105 +28,51 @@ class GRPCStub : public ServiceInterface { explicit GRPCStub(grpc::XlaService::Stub* stub) : grpc_stub_(stub) {} ~GRPCStub() override; - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; - tensorflow::Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) override; + Status ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) override; - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) override; - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; - tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* request, - ExecuteParallelResponse* response) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) override; - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) 
override; - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* request, - ComputationStatsResponse* response) override; - - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; - - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; - - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; - - // Methods used by ComputationBuilder. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; - - // Methods used by Computation. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Methods used by GlobalData. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; grpc::XlaService::Stub* service() { return grpc_stub_; } diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index c47164ee1b7657..92eb19ec0f9696 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -75,19 +75,7 @@ service XlaService { rpc GetShape(GetShapeRequest) returns (GetShapeResponse) { } - // Requests the program shape of the referenced computation. - rpc GetComputationShape(GetComputationShapeRequest) - returns (GetComputationShapeResponse) { - } - // Requests the statistics of the given computation. - rpc GetComputationStats(ComputationStatsRequest) - returns (ComputationStatsResponse) { - } - - // Requests the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
rpc GetComputationGraphStats(ComputationGraphStatsRequest) returns (ComputationStatsResponse) { } @@ -121,15 +109,6 @@ service XlaService { rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) { } - // Tests if an expression is a compile-time constant. - rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) { - } - - // Computes the value of a constant expression. - rpc ComputeConstant(ComputeConstantRequest) - returns (ComputeConstantResponse) { - } - // Computes the value of a constant expression. The request contains the // computation graph for the constant expression. rpc ComputeConstantGraph(ComputeConstantGraphRequest) @@ -165,20 +144,6 @@ service XlaService { rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { } - // Computation creates a new computation with the given name. - // A unique ComputationHandle is returned. - rpc Computation(ComputationRequest) returns (ComputationResponse) { - } - - // Adds a new op to a computation. - rpc CreateOp(OpRequest) returns (OpResponse) { - } - - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - rpc Execute(ExecuteRequest) returns (ExecuteResponse) { - } - // Invokes the provided computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. @@ -188,38 +153,13 @@ service XlaService { // Invokes the provided list of computations in parallel with the provided // global data for each computation. Returns a list of global data output and // execution timing. - rpc ExecuteParallel(ExecuteParallelRequest) - returns (ExecuteParallelResponse) { - } - - // Invokes the provided list of computations in parallel with the provided - // global data for each computation. Returns a list of global data output and - // execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc ExecuteGraphParallel(ExecuteGraphParallelRequest) returns (ExecuteParallelResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) { - } - // Waits until the given execution (aysnchronously launched) is complete, and // returns the global data output. rpc WaitForExecution(WaitForExecutionRequest) returns (WaitForExecutionResponse) { } - - // Serializes a computation to proto form, so it can be loaded via - // LoadComputationSnapshot. - rpc SnapshotComputation(SnapshotComputationRequest) - returns (SnapshotComputationResponse) { - } - - // Loads a computation from a captured snapshot. 
- rpc LoadComputationSnapshot(LoadComputationSnapshotRequest) - returns (LoadComputationSnapshotResponse) { - } } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index e794ec12c7b7af..1f2de0c9553933 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -12,21 +12,26 @@ package_group( ], ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( - name = "session_proto", - srcs = ["session.proto"], + name = "hlo_proto", + srcs = ["hlo.proto"], visibility = ["//visibility:public"], deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) -xla_proto_library( - name = "hlo_proto", +tf_proto_library_py( + name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. srcs = ["hlo.proto"], - deps = ["//tensorflow/compiler/xla:xla_data_proto"], + visibility = ["//visibility:public"], ) xla_proto_library( @@ -200,7 +205,22 @@ tf_cc_test( cc_library( name = "hlo_evaluator", - srcs = ["hlo_evaluator.cc"], + srcs = [ + "hlo_evaluator.cc", + "hlo_evaluator_typed_visitor.h", + "hlo_evaluator_typed_visitor_bfloat16.cc", + "hlo_evaluator_typed_visitor_bool.cc", + "hlo_evaluator_typed_visitor_complex64.cc", + "hlo_evaluator_typed_visitor_double.cc", + "hlo_evaluator_typed_visitor_float.cc", + "hlo_evaluator_typed_visitor_half.cc", + "hlo_evaluator_typed_visitor_int32.cc", + "hlo_evaluator_typed_visitor_int64.cc", + "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_uint32.cc", + "hlo_evaluator_typed_visitor_uint64.cc", + "hlo_evaluator_typed_visitor_uint8.cc", + ], hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", @@ -233,7 +253,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -256,7 +276,9 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "hlo_clone_context.h", "hlo_computation.h", + "hlo_domain_metadata.h", "hlo_instruction.h", "hlo_module.h", "hlo_opcode.h", @@ -280,6 +302,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -319,8 +342,8 @@ tf_cc_test( ":hlo", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -358,6 +381,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -368,6 +392,7 @@ tf_cc_test( srcs = ["hlo_matchers_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], 
@@ -389,12 +414,14 @@ tf_cc_test( srcs = ["hlo_instruction_test.cc"], deps = [ ":hlo", + ":hlo_parser", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -411,6 +438,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -515,45 +543,6 @@ tf_cc_test( ], ) -cc_library( - name = "user_computation", - srcs = ["user_computation.cc"], - hdrs = ["user_computation.h"], - deps = [ - ":hlo", - ":session_proto", - ":shape_inference", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "user_computation_test", - srcs = ["user_computation_test.cc"], - deps = [ - ":hlo_matchers", - ":user_computation", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "platform_util", srcs = ["platform_util.cc"], @@ -602,7 +591,6 @@ cc_library( ":compilation_cache", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":execution_tracker", @@ -613,10 +601,8 @@ cc_library( ":hlo_module_config", ":hlo_proto_util", ":platform_util", - ":session_proto", ":source_map_util", ":transfer_manager", - ":user_computation", ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", @@ -644,7 +630,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":hlo", @@ -653,7 +638,6 @@ cc_library( ":platform_util", ":service", ":shaped_buffer", - ":user_computation", ":versioned_computation_handle", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_layout", @@ -678,7 +662,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":platform_util", ":service", "//tensorflow/compiler/xla:status_macros", @@ -744,6 +727,23 @@ cc_library( ], ) +tf_cc_test( + name = "shaped_buffer_test", + srcs = ["shaped_buffer_test.cc"], + deps = [ + ":cpu_plugin", + ":device_memory_allocator", + ":platform_util", + ":shaped_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:ptr_util", + "//tensorflow/core:test", + ], +) + cc_library( name 
= "executable", srcs = ["executable.cc"], @@ -759,7 +759,6 @@ cc_library( ":hlo_graph_dumper", ":hlo_proto", ":pool", - ":session_proto", ":shaped_buffer", ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", @@ -781,6 +780,7 @@ cc_library( srcs = ["compiler.cc"], hdrs = ["compiler.h"], deps = [ + ":buffer_value", ":executable", ":hlo", ":hlo_module_config", @@ -837,7 +837,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -857,33 +856,12 @@ cc_library( ], ) -cc_library( - name = "computation_tracker", - srcs = ["computation_tracker.cc"], - hdrs = ["computation_tracker.h"], - deps = [ - ":hlo", - ":hlo_module_config", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "channel_tracker", srcs = ["channel_tracker.cc"], hdrs = ["channel_tracker.h"], deps = [ ":hlo", - ":session_proto", - ":user_computation", ":versioned_computation_handle", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -918,33 +896,6 @@ tf_cc_test( ], ) -cc_library( - name = "liveness_util", - srcs = ["liveness_util.cc"], - hdrs = ["liveness_util.h"], - deps = [ - ":hlo", - ":hlo_dataflow_analysis", - ":logical_buffer", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - ], -) - -tf_cc_test( - name = "liveness_util_test", - srcs = ["liveness_util_test.cc"], - deps = [ - ":hlo", - ":liveness_util", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "buffer_liveness", srcs = [ @@ -956,7 +907,6 @@ cc_library( deps = [ ":hlo", ":hlo_ordering", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -993,6 +943,7 @@ cc_library( ], deps = [ ":buffer_liveness", + ":buffer_value_containers", ":heap_simulator", ":hlo", ":hlo_proto", @@ -1015,8 +966,8 @@ tf_cc_test( srcs = ["buffer_assignment_test.cc"], deps = [ ":buffer_assignment", + ":buffer_value", ":call_graph", - ":computation_tracker", ":copy_insertion", ":cpu_plugin", ":flatten_call_graph", @@ -1030,9 +981,9 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1047,7 +998,6 @@ cc_library( ":hlo_dataflow_analysis", ":hlo_proto", ":hlo_value", - ":liveness_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1069,9 +1019,9 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - 
"//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1080,11 +1030,11 @@ cc_library( srcs = ["heap_simulator.cc"], hdrs = ["heap_simulator.h"], deps = [ + ":buffer_value", + ":buffer_value_containers", ":hlo", ":hlo_ordering", ":hlo_proto", - ":liveness_util", - ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1096,10 +1046,11 @@ tf_cc_test( name = "heap_simulator_test", srcs = ["heap_simulator_test.cc"], deps = [ + ":buffer_value", ":heap_simulator", ":hlo", ":hlo_ordering", - ":logical_buffer", + ":hlo_value", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", @@ -1164,15 +1115,16 @@ tf_cc_test( name = "hlo_scheduling_test", srcs = ["hlo_scheduling_test.cc"], deps = [ + ":buffer_value", ":hlo", ":hlo_ordering", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1205,6 +1157,7 @@ tf_cc_test( deps = [ ":hlo_matchers", ":instruction_fusion", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -1247,13 +1200,11 @@ cc_library( deps = [ ":hlo", ":hlo_pass", - ":hlo_query", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], @@ -1330,6 +1281,43 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "batch_dot_simplification", + srcs = ["batch_dot_simplification.cc"], + hdrs = ["batch_dot_simplification.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "batch_dot_simplification_test", + srcs = ["batch_dot_simplification_test.cc"], + deps = [ + ":batch_dot_simplification", + ":hlo", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -1343,9 +1331,9 @@ tf_cc_test( deps = [ ":gather_expander", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], 
) @@ -1651,23 +1639,21 @@ tf_cc_test( name = "hlo_cost_analysis_test", srcs = ["hlo_cost_analysis_test.cc"], deps = [ - ":computation_tracker", ":cpu_plugin", ":hlo", ":hlo_cost_analysis", ":local_service", ":service", - ":user_computation", ":versioned_computation_handle", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1697,8 +1683,10 @@ tf_cc_test( ":cpu_plugin", ":hlo_cost_analysis", ":hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -1748,11 +1736,38 @@ tf_cc_test( ], ) +cc_library( + name = "buffer_value", + srcs = ["buffer_value.cc"], + hdrs = ["buffer_value.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "buffer_value_containers", + hdrs = ["buffer_value_containers.h"], + deps = [ + ":buffer_value", + ":logical_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + cc_library( name = "logical_buffer", srcs = ["logical_buffer.cc"], hdrs = ["logical_buffer.h"], deps = [ + ":buffer_value", ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:shape_util", @@ -1768,6 +1783,7 @@ cc_library( srcs = ["hlo_value.cc"], hdrs = ["hlo_value.h"], deps = [ + ":buffer_value", ":hlo", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", @@ -1820,6 +1836,44 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_liveness_analysis", + srcs = ["hlo_liveness_analysis.cc"], + hdrs = ["hlo_liveness_analysis.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_value", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_liveness_analysis_test", + srcs = ["hlo_liveness_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_liveness_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_buffer", srcs = ["hlo_buffer.cc"], @@ -1956,10 +2010,12 @@ cc_library( deps = [ ":computation_layout", ":hlo", + ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", + 
":tuple_simplifier", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1983,7 +2039,6 @@ cc_library( ":hlo_graph_dumper", ":hlo_ordering", ":hlo_pass", - ":liveness_util", ":logical_buffer", ":tuple_simplifier", "//tensorflow/compiler/xla:status_macros", @@ -2030,6 +2085,24 @@ cc_library( ], ) +cc_library( + name = "hlo_module_dce", + srcs = ["hlo_module_dce.cc"], + hdrs = ["hlo_module_dce.h"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_liveness_analysis", + ":hlo_pass", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], @@ -2065,13 +2138,13 @@ cc_library( hdrs = ["hlo_rematerialization.h"], deps = [ ":buffer_liveness", + ":buffer_value", ":call_graph", ":flatten_call_graph", ":hlo", ":hlo_dce", ":hlo_ordering", ":hlo_scheduling", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -2119,6 +2192,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_module_dce_test", + srcs = ["hlo_module_dce_test.cc"], + deps = [ + ":hlo", + ":hlo_module_dce", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "layout_assignment_test", srcs = ["layout_assignment_test.cc"], @@ -2135,9 +2229,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2185,6 +2279,7 @@ cc_library( hdrs = ["hlo_cse.h"], deps = [ ":hlo", + ":hlo_domain_map", ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -2207,6 +2302,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -2248,6 +2344,79 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_domain_map", + srcs = ["hlo_domain_map.cc"], + hdrs = ["hlo_domain_map.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_sharding_metadata", + srcs = ["hlo_sharding_metadata.cc"], + hdrs = [ + "hlo_sharding_metadata.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_domain_isolator", + srcs = 
["hlo_domain_isolator.cc"], + hdrs = ["hlo_domain_isolator.h"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + ], +) + +cc_library( + name = "hlo_domain_remover", + srcs = ["hlo_domain_remover.cc"], + hdrs = ["hlo_domain_remover.h"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_map", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_domain_test", + srcs = ["hlo_domain_test.cc"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_remover", + ":hlo_parser", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_element_type_converter", srcs = ["hlo_element_type_converter.cc"], @@ -2276,8 +2445,14 @@ tf_cc_test( cc_library( name = "device_memory_allocator", - srcs = ["device_memory_allocator.cc"], - hdrs = ["device_memory_allocator.h"], + srcs = [ + "device_memory_allocator.cc", + "owning_device_memory.cc", + ], + hdrs = [ + "device_memory_allocator.h", + "owning_device_memory.h", + ], deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2312,6 +2487,24 @@ cc_library( ], ) +xla_test( + name = "elemental_ir_emitter_test", + srcs = ["elemental_ir_emitter_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "hlo_module_config", srcs = ["hlo_module_config.cc"], @@ -2383,7 +2576,6 @@ tf_cc_test( srcs = ["hlo_tfgraph_builder_test.cc"], deps = [ ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", @@ -2436,6 +2628,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -2445,6 +2638,7 @@ tf_cc_test( srcs = ["transpose_folding_test.cc"], deps = [ ":hlo", + ":hlo_matchers", ":shape_inference", ":transpose_folding", "//tensorflow/compiler/xla:literal_util", @@ -2452,7 +2646,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2488,7 +2683,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + 
"//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2592,12 +2787,11 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", - "@com_google_absl//absl/memory", ], ) @@ -2629,8 +2823,8 @@ tf_cc_test( ":tuple_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2653,9 +2847,10 @@ tf_cc_test( deps = [ ":while_util", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2681,6 +2876,34 @@ tf_cc_test( ":hlo_matchers", ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "while_loop_constant_sinking", + srcs = ["while_loop_constant_sinking.cc"], + hdrs = ["while_loop_constant_sinking.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_constant_sinking_test", + srcs = ["while_loop_constant_sinking_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_constant_sinking", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], @@ -2712,3 +2935,96 @@ cc_library( "//tensorflow/core:lib", ], ) + +cc_library( + name = "indexed_array_analysis", + srcs = ["indexed_array_analysis.cc"], + hdrs = ["indexed_array_analysis.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", + ], +) + +tf_cc_test( + name = "indexed_array_analysis_test", + srcs = ["indexed_array_analysis_test.cc"], + deps = [ + ":hlo_matchers", + ":indexed_array_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo", + ":hlo_lexer", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + 
"//tensorflow/compiler/xla:window_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_casting_utils", + hdrs = ["hlo_casting_utils.h"], + deps = [":hlo"], +) + +tf_cc_test( + name = "hlo_casting_utils_test", + srcs = ["hlo_casting_utils_test.cc"], + deps = [ + ":hlo_casting_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 8e785de68cb1fb..dc5f1b31bf8510 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -92,26 +92,6 @@ bool ReshapeIsBitcast( valid_bitcast_callback(operand->shape(), reshape->shape()); } -// Adds a scalar computation to the module to enable optimizations with dot -// converting into reduction. -HloComputation* CreateScalarBinaryComputation(HloModule* module, - PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - HloComputation* scalar_computation = - module->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_computation; -} - -} // namespace - // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain // algebraic expressions to simplified forms. Note: This only supports // simplifications that simply look at the operands of an instruction. 
For the @@ -177,6 +157,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub) override; + Status HandleMap(HloInstruction* map) override; + Status HandleMaximum(HloInstruction* maximum) override; Status HandleMinimum(HloInstruction* minimum) override; @@ -220,8 +202,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloComputation* AddReduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( shape, hlo, zero, {dim}, AddReduce_computation)); @@ -252,10 +233,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand); - // A Reshape or Broadcast that feeds an element-wise operation with a unique - // non-scalar operand can sink to after the operation. - StatusOr TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast); + // A Broadcast that feeds an element-wise operation with a unique non-scalar + // operand can sink to after the operation. + StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast); // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -291,6 +272,26 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + StatusOr OptimizeDotOfGather(HloInstruction* dot); + + HloComputation* GetOrCreateScalarAddComputation() { + if (scalar_add_computation_) { + return scalar_add_computation_; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + scalar_add_computation_ = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return scalar_add_computation_; + } + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -309,8 +310,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; + + // Cached computation for adding two scalar F32. 
+ HloComputation* scalar_add_computation_ = nullptr; }; +} // namespace + bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, @@ -499,13 +505,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } static HloInstruction* BuildTupleConstant(HloComputation* computation, - const Literal& literal) { + const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { elems.push_back( - BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); + BuildTupleConstant(computation, LiteralSlice(literal, {i}))); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { @@ -912,6 +918,134 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( return add_result; } +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.rhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0 || + dnums.rhs_batch_dimensions_size() != 0 || + dot->shape().dimensions_size() != 2) { // dot output 2D + VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations."; + return nullptr; + } + + // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)). + // Currently a Gather is a DynamicSlice. + auto is_dynamic_slice_constant_combination = + [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) { + // First operand is a DynamicSlice(Constant). + if (a->opcode() != HloOpcode::kDynamicSlice) { + return false; + } + auto* dynamic_slice_op = a->operand(0); + if (dynamic_slice_op->opcode() != HloOpcode::kConstant) { + return false; + } + // Second operand is a Constant. + if (b->opcode() != HloOpcode::kConstant) { + return false; + } + // The DynamicSlice output is a vector. + const Shape& dynamic_slice_shape = a->shape(); + if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) { + return false; + } + // Constant size is the same before and after slice in the contracting + // dimension, otherwise we either must precompute for all possible slice + // indices or dot is invalid. + const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape(); + if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) != + dynamic_slice_shape.dimensions(a_contracting_dimension)) { + return false; + } + return true; + }; + + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + + if (!is_dynamic_slice_constant_combination( + lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) && + !is_dynamic_slice_constant_combination( + rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) { + VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or " + "dot(ctB, DS(ctA)), where the two constants have equal " + "contracting dimensions."; + return nullptr; + } + + // LHS is DynamicSlice: + // input: dot(DS(ctA), ctB)) + // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}. + // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. 
+ // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}. + + // RHS is DynamicSlice: + // input: dot(ctA, DS(ctB)) + // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}). + // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. + // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. + + bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + + // ctA: + HloInstruction* left_operand = + lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs; + // ctB: + HloInstruction* right_operand = + lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0); + // Build ctA x ctB. + const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); + const int n = + right_operand->shape().dimensions(1 - rhs_contracting_dimension); + auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); + auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( + memoized_shape, left_operand, right_operand, dnums)); + // Get pair {start, 0} or {0, start}. + HloInstruction* original_start_indices = + lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); + // Position of start: + int index_of_non_zero_start = lhs_is_dynamic_slice + ? 1 - lhs_contracting_dimension + : 1 - rhs_contracting_dimension; + // Position of zero: + int index_of_zero_start = 1 - index_of_non_zero_start; + + // Slice out start and 0 components and reorder if necessary. + auto indices_type = original_start_indices->shape().element_type(); + Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); + Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); + HloInstruction* non_zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_non_zero_start}, + {index_of_non_zero_start + 1}, {1})); + HloInstruction* zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_zero_start}, + {index_of_zero_start + 1}, {1})); + HloInstruction* new_start_indices = + lhs_is_dynamic_slice + ? computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {non_zero_start, zero_start}, 0)) + : computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {zero_start, non_zero_start}, 0)); + + // Build DynamicSlice(ctA x ctB). + const int new_slice_m = lhs_is_dynamic_slice ? 1 : m; + const int new_slice_n = lhs_is_dynamic_slice ? n : 1; + auto* memoized_lookup = + computation_->AddInstruction(HloInstruction::CreateDynamicSlice( + dot->shape(), memoized_inst, new_start_indices, + {new_slice_m, new_slice_n})); + + return memoized_lookup; +} + Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { HloInstruction *lhs, *rhs; CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); @@ -941,6 +1075,17 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_concat_optimized); } + // Simplify dot(ConstA, Gather(Index, ConstB)) to: + // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately + // batched version of dot. 
+ TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, + OptimizeDotOfGather(dot)); + if (dot_of_gather_optimized) { + VLOG(10) << "Replaced dot(constA, gather(i, constB)) with " + "gather(i, dot*(constA, constB))"; + return ReplaceInstruction(dot, dot_of_gather_optimized); + } + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); @@ -1160,7 +1305,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // broadcast after the unary element-wise operation. TF_ASSIGN_OR_RETURN( bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); + TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); changed_ |= sink_succeeded; if (sink_succeeded) { return Status::OK(); @@ -1412,15 +1557,16 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor:: - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast) { +StatusOr +AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast) { + TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); bool changed = false; - if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + if (ShapeUtil::IsScalar(broadcast->shape())) { return false; } - HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); - for (HloInstruction* user : reshape_or_broadcast->users()) { + HloInstruction* operand = broadcast->mutable_operand(0); + for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; } @@ -1438,55 +1584,50 @@ StatusOr AlgebraicSimplifierVisitor:: continue; } - int64 reshape_or_broadcast_operand_index = -1; // Find the unique non-scalar operand or continue if there isn't one. 
- int64 scalar_count = 0; - for (int64 i = 0; i < user->operand_count(); ++i) { - if (ShapeUtil::IsScalar(user->operand(i)->shape())) { - ++scalar_count; - } else { - reshape_or_broadcast_operand_index = i; + int64 scalar_broadcast_count = 0; + int64 broadcast_use_count = 0; + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + ++scalar_broadcast_count; + } else if (broadcast == user_operand) { + ++broadcast_use_count; } } - if (scalar_count != user->operand_count() - 1) { + if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { continue; } - VLOG(4) << "Sinking reshape or broadcast after user:"; - VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString(); + std::vector new_operands; + new_operands.reserve(user->operand_count()); + + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()), + user_operand->mutable_operand(0), {}))); + } else { + CHECK_EQ(broadcast, user_operand); + new_operands.push_back(operand); + } + } + VLOG(4) << "Sinking broadcast after user:"; + VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old user: " << user->ToString(); - CHECK_EQ(user->operand(reshape_or_broadcast_operand_index), - reshape_or_broadcast); - auto new_user_operands = user->operands(); - new_user_operands[reshape_or_broadcast_operand_index] = operand; - auto new_user = computation_->AddInstruction(user->CloneWithNewOperands( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(operand->shape().dimensions()), - LayoutUtil::MinorToMajor(operand->shape())), - new_user_operands)); + HloInstruction* new_user = + computation_->AddInstruction(user->CloneWithNewOperands( + ShapeUtil::ChangeElementType(operand->shape(), + user->shape().element_type()), + new_operands)); VLOG(4) << " new user: " << new_user->ToString(); - HloInstruction* new_reshape_or_broadcast = nullptr; - if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) { - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user)); - } else { - TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user, reshape_or_broadcast->dimensions())); - } - VLOG(4) << " new reshape/broadcast: " - << new_reshape_or_broadcast->ToString(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast)); + HloInstruction* new_broadcast = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + user->shape(), new_user, broadcast->dimensions())); + VLOG(4) << " new broadcast: " << new_broadcast->ToString(); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); changed = true; } return changed; @@ -1529,16 +1670,6 @@ 
Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } } - // A Reshape that feeds a unary element-wise operation can sink the - // reshape after the unary element-wise operation. - TF_ASSIGN_OR_RETURN( - bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - changed_ |= sink_succeeded; - if (sink_succeeded) { - return Status::OK(); - } - // Make this a bitcast if possible. if (is_layout_sensitive_ && ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { @@ -1643,6 +1774,15 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely. + if (ShapeUtil::ElementsIn(reduce->shape()) == + ShapeUtil::ElementsIn(arg->shape())) { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); + } // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. @@ -1687,15 +1827,6 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } } - if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape()) || - ShapeUtil::HasZeroElements(arg->shape())) { - auto reshape = computation_->AddInstruction( - HloInstruction::CreateReshape(reduce->shape(), arg)); - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateMap(reduce->shape(), - {init_value, reshape}, function)); - } return Status::OK(); } @@ -1715,7 +1846,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateMap(reduce_window->shape(), - {operand, reduce_window->mutable_operand(1)}, + {reduce_window->mutable_operand(1), operand}, function)); } @@ -2045,6 +2176,39 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } +Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { + auto* map_computation = map->to_apply(); + auto* map_root = map_computation->root_instruction(); + if (map_root->opcode() == HloOpcode::kParameter) { + ReplaceInstructionIfSameShape( + map, map->mutable_operand(map_root->parameter_number())); + return Status::OK(); + } + if (map_root->opcode() == HloOpcode::kConstant) { + if (!ShapeUtil::IsScalar(map_root->shape())) { + return Status::OK(); + } + auto clone = map_root->CloneWithNewOperands(map_root->shape(), {}); + if (ShapeUtil::IsScalar(map->shape())) { + return ReplaceWithNewInstruction(map, std::move(clone)); + } + return ReplaceWithNewInstruction( + map, + HloInstruction::CreateBroadcast( + map->shape(), computation_->AddInstruction(std::move(clone)), {})); + } + std::vector new_operands; + for (auto* root_operand : map_root->operands()) { + if (root_operand->opcode() != HloOpcode::kParameter) { + return Status::OK(); + } + new_operands.push_back( + map->mutable_operand(root_operand->parameter_number())); + } + auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands); + return ReplaceWithNewInstruction(map, std::move(clone)); +} + Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { // Match the following tree: // min_operand 
operand diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 20c549562d5153..cda157f9fac163 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -142,6 +143,39 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMap); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, zero)); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); @@ -1317,32 +1351,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param")); - HloInstruction* movable_reshape = - builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), - HloOpcode::kMaximum, movable_reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - 
simplifier.Run(&module()).ValueOrDie(); - EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Maximum(param, zero))); -} - // Regression test for a bug in the reshape sinking transformation, where // moving a reshape to a scalar led to a crash. TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { @@ -1699,14 +1707,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1732,8 +1740,8 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -1751,7 +1759,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1766,14 +1774,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1789,14 +1797,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1924,12 +1932,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { 
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(b.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - if (!simplifier.Run(&module).ValueOrDie()) { + if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } auto* root = computation->root_instruction(); @@ -2044,15 +2052,15 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Minimum(param0, min_value), max_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2074,15 +2082,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2105,15 +2113,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2135,15 +2143,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2167,8 +2175,8 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, fmax, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2176,7 +2184,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2201,8 +2209,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, slice); @@ -2211,10 +2219,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2242,8 +2250,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reshape); @@ -2251,7 +2259,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2260,7 +2268,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. 
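// Illustrative note, not part of the original patch: the FoldPadIntoReduceWindow
// test above exercises the rewrite ReduceWindow(Pad(op, v), v) -> ReduceWindow(op, v)
// with the explicit pad absorbed into the reduce-window's padding config. The
// rewrite is sound when the pad value equals the reduce-window init value, since
// every window position then reads either a real element of `op` or the shared
// identity value, in both forms. A minimal standalone C++ sketch of that
// equivalence for a 1-D sum (names below are hypothetical, not XLA APIs):
#include <cassert>
#include <vector>

// 1-D "reduce-window" sum with window size `w`, stride 1, `pad` implicit
// padding positions on each side, and `init` used for out-of-bounds reads.
static std::vector<float> ReduceWindowSum(const std::vector<float>& x, int w,
                                          int pad, float init) {
  const int n = static_cast<int>(x.size()) + 2 * pad;
  std::vector<float> out;
  for (int start = 0; start + w <= n; ++start) {
    float acc = init;
    for (int k = 0; k < w; ++k) {
      const int i = start + k - pad;  // index into the unpadded operand
      acc += (i >= 0 && i < static_cast<int>(x.size())) ? x[i] : init;
    }
    out.push_back(acc);
  }
  return out;
}

int main() {
  const std::vector<float> op = {1, 2, 3, 4};
  const float init = 0.0f;  // pad value == init value: the folding precondition
  // Explicit pad by one element on each side, then reduce-window with no padding...
  const std::vector<float> padded = {init, 1, 2, 3, 4, init};
  const std::vector<float> a = ReduceWindowSum(padded, /*w=*/3, /*pad=*/0, init);
  // ...equals a reduce-window over the unpadded operand with window padding 1.
  const std::vector<float> b = ReduceWindowSum(op, /*w=*/3, /*pad=*/1, init);
  assert(a == b);
  return 0;
}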
@@ -2289,7 +2297,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. @@ -2312,15 +2320,15 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2341,7 +2349,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2374,7 +2382,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. @@ -2397,15 +2405,15 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2431,12 +2439,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); @@ -2962,5 +2970,208 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, DotOfConcatSimplificationTest, ::testing::ValuesIn(kDotOfConcatTestSpecs)); + +struct DotOfGatherTestSpec { + int64 m; + int64 k; + int64 n; + int s; // start index for dynamic slice on the non-contracting dimension + int64 lcd; // left contracting dimension + int64 rcd; // right contracting dimension + bool neg; // is negative testcase +}; + +class DotOfGatherSimplificationTest + : public HloVerifiedTestBase, + public ::testing::WithParamInterface {}; + +// input: dot(DS(ctA), ctB)) +// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. +// => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. +TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.m); + + // For negative tests, increase k of the dynamic slice argument to prevent the + // optimization (constants ctA, ctB must have equal contracting dimensions). + int64 k_increase = spec.neg ? 5 : 0; + int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + int32 start_row = (spec.lcd == 0) ? 0 : spec.s; + int32 start_col = (spec.lcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({start_row, start_col}))); + int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + + int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? 
spec.n : spec.k; + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = 1; + int64 dot_col_size = spec.n; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +// input: dot(ctA, DS(ctB)) +// where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}). +// => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. +TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.n); + + int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k; + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + // For negative tests increase k of the dynamic slice argument to prevent the + // optimization + int64 k_increase = spec.neg ? 5 : 0; + int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + int32 start_row = (spec.rcd == 0) ? 0 : spec.s; + int32 start_col = (spec.rcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({start_row, start_col}))); + int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.rcd == 0) ? 
1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = spec.m; + int64 dot_col_size = 1; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +std::vector DotOfGatherPositiveNegativeTests() { + std::vector positives = { + // "Classical dot", i.e. matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to + // dot(ct, ct) before DotOfGather optimization kicks in. + // Contract on rows: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + // Reverse matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + // Contract on columns: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + }; + std::vector all; + for (int i = 0; i < positives.size(); i++) { + DotOfGatherTestSpec positive_test = positives[i]; + all.push_back(positive_test); + DotOfGatherTestSpec negative_test = positive_test; + negative_test.neg = true; + all.push_back(negative_test); + } + return all; +} + +INSTANTIATE_TEST_CASE_P( + DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, + ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index cf1231bcce4d00..95b4cb6d2e6940 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -101,7 +101,7 @@ StatusOr AllocationTracker::RegisterInternal( return result; } -tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { +Status AllocationTracker::Unregister(const GlobalDataHandle& data) { tensorflow::mutex_lock 
lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; @@ -130,7 +130,7 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { for (auto& shaped_buffer : it->second) { shaped_buffer.reset(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( @@ -220,8 +220,10 @@ void AllocationTracker::AddAllocationOrIncrementRefCount( AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; auto it = allocation_map.find(device_memory.opaque()); if (it == allocation_map.end()) { - allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, - /*ref_count=*/1}; + allocation_map[device_memory.opaque()] = { + OwningDeviceMemory(device_memory, device_ordinal, + backend_->memory_allocator()), + /*ref_count=*/1}; } else { it->second.ref_count++; } @@ -235,13 +237,12 @@ Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory, Allocation& allocation = it->second; TF_RET_CHECK(allocation.ref_count >= 1); if (allocation.ref_count == 1) { - TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( - device_ordinal, &device_memory)); + allocation.device_memory.Free(); allocation_map.erase(it); } else { allocation.ref_count--; } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 1174fa641c06ae..a7d8927cf7e90d 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -76,10 +76,7 @@ class AllocationTracker { // Data structure encapsulating single memory allocation on the device. struct Allocation { // The pointer to this allocation. - se::DeviceMemoryBase device_memory; - - // The device that the memory is allocated on. - int device_ordinal; + OwningDeviceMemory device_memory; // This is the number of times this memory allocation is referred to by // registered data handles. @@ -126,7 +123,10 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. - tensorflow::gtl::FlatMap opaque_to_allocation_map_ + // + // This is not a TF FlatMap because (currently) FlatMap (and therefore + // AllocationMap) is not movable. + std::unordered_map opaque_to_allocation_map_ GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index b1d616ec3506f9..349b32451a697d 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -138,9 +138,6 @@ Backend::Backend( << "Service found no devices for backend " << platform_->Name() << '.'; if (platform->id() == se::host::kHostPlatformId) { - inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( - tensorflow::Env::Default(), "xla_inter_op", - tensorflow::port::NumSchedulableCPUs())); const int num_threads = intra_op_parallelism_threads > 0 ? 
intra_op_parallelism_threads : tensorflow::port::NumSchedulableCPUs(); @@ -155,10 +152,6 @@ int Backend::default_device_ordinal() const { return default_stream_executor()->device_ordinal(); } -tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { - return inter_op_thread_pool_.get(); -} - const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { if (intra_op_thread_pool_wrapper_ == nullptr) { diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index d32a0a400d8bd5..6546602473e338 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -140,10 +140,6 @@ class Backend { // be equivalent to an executable compiled for the other. StatusOr devices_equivalent(int device_ordinal_a, int device_ordinal_b); - // For the host platform, returns the threadpool to use when scheduling - // parallel operators. For other platforms, returns NULL. - tensorflow::thread::ThreadPool* inter_op_thread_pool() const; - // For the host platform, returns the configured eigen threadpool device to be // used for scheduling work. For other platforms, returns NULL. const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; @@ -178,9 +174,6 @@ class Backend { // The default memory allocator to use. std::unique_ptr memory_allocator_; - // For the CPU backend, a threadpool for scheduling parallel operators. - std::unique_ptr inter_op_thread_pool_; - // For the CPU backend, an Eigen threadpool device for use by Eigen code. std::unique_ptr intra_op_thread_pool_wrapper_; }; diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc new file mode 100644 index 00000000000000..2099916509acdb --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" + +namespace xla { +StatusOr +BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot) { + const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers(); + HloInstruction *lhs = batch_dot->mutable_operand(0), + *rhs = batch_dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + + std::vector degenerate_dims; + for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { + if (lhs_shape.dimensions(batch_dim) == 1) { + degenerate_dims.push_back(batch_dim); + } + } + + if (degenerate_dims.empty()) { + return false; + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, + ElideDegenerateDims(lhs, degenerate_dims)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, + ElideDegenerateDims(rhs, degenerate_dims)); + + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.clear_lhs_batch_dimensions(); + new_dim_numbers.clear_rhs_batch_dimensions(); + + for (int64 i = 0, e = dim_numbers.lhs_batch_dimensions_size() - + degenerate_dims.size(); + i < e; i++) { + new_dim_numbers.add_lhs_batch_dimensions(i); + new_dim_numbers.add_rhs_batch_dimensions(i); + } + + new_dim_numbers.set_lhs_contracting_dimensions( + 0, + new_dim_numbers.lhs_contracting_dimensions(0) - degenerate_dims.size()); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, + new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, + MakeReshapeHlo(batch_dot->shape(), new_dot)); + + VLOG(2) << "Replaced " << batch_dot->ToString() << " with " + << new_dot->ToString(); + + TF_RETURN_IF_ERROR( + batch_dot->parent()->ReplaceInstruction(batch_dot, new_dot_reshaped)); + + return true; +} + +tensorflow::StringPiece BatchDotSimplification::name() const { + return "batch-dot-simplification"; +} + +StatusOr BatchDotSimplification::Run(HloModule* module) { + bool changed = false; + std::vector dot_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); + } + for (HloInstruction* dot_instr : dot_instrs) { + TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, + ElideDegenerateBatchDimensionFromBatchDot(dot_instr)); + changed |= elided_batch_dim_from_one; + } + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h new file mode 100644 index 00000000000000..c0ca8d8ebac1a3 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +// Simplifies batch dot operations. +// +// Normally these would live in the algebraic simplifier, but we want to run +// this to fixpoint (this pass reaches fixed point in one execution) before we +// run the DotDecomposer. +class BatchDotSimplification : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override; + tensorflow::StringPiece name() const override; + + private: + StatusOr ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc new file mode 100644 index 00000000000000..38f1a5d3a645f9 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class BatchDotSimplificationTest : public HloVerifiedTestBase {}; + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1,9] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3,7] parameter(1) + ROOT dot = f32[1,9,7] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b = f32[9,1,7,1,3] parameter(1) + ROOT dot = f32[9,1,7,1] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/2))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b 
= f32[9,1,7,1,20,3] parameter(1) + ROOT dot = f32[9,1,7,1,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={5} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/3))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,19,3] parameter(0) + b = f32[9,1,7,1,3,20] parameter(1) + ROOT dot = f32[9,1,7,1,19,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={5}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 38086bd7e12184..598718c72c6941 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -15,35 +15,32 @@ limitations under the License. #include "tensorflow/compiler/xla/service/batchnorm_expander.h" -#include #include -#include -#include #include #include #include -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. 
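// Illustrative note, not part of the original patch: in training mode the
// expansion performed by this visitor computes, per feature with N elements,
//     E[X]   = sum(X) / N
//     Var[X] = E[X^2] - E[X]^2
//     Y      = (X - E[X]) * rsqrt(Var[X] + epsilon) * scale + offset
// where rsqrt(v) is realized as pow(v, -0.5) through a scalar Map computation.
// A minimal standalone C++ sketch of the same arithmetic (hypothetical names,
// not XLA APIs):
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};  // one feature, N = 4
  const float scale = 2.f, offset = 0.5f, epsilon = 1e-3f;
  const float n = static_cast<float>(x.size());
  float sum = 0.f, sum_sq = 0.f;
  for (float v : x) { sum += v; sum_sq += v * v; }
  const float mean = sum / n;                    // E[X]
  const float var = sum_sq / n - mean * mean;    // E[X^2] - E^2[X]
  const float rsqrt = std::pow(var + epsilon, -0.5f);  // 1/sqrt(Var[X] + eps)
  for (float v : x) {
    std::printf("%f\n", (v - mean) * rsqrt * scale + offset);
  }
  return 0;
}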
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { @@ -80,30 +77,88 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} - HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + HloComputation* GetOrCreateScalarAddComputation( + PrimitiveType primitive_type) { + HloComputation** scalar_add_computation = + &scalar_add_computations_[primitive_type]; + if (*scalar_add_computation) { + return *scalar_add_computation; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + *scalar_add_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return *scalar_add_computation; } - // Current HloComputation instance the BatchNormExpander is - // traversing. - HloComputation* computation_; + // TODO(b/80534766): Remove maps after performance issues with scalar + // broadcasts are resolved on all backends. + HloComputation* GetOrCreateScalarRsqrtComputation( + PrimitiveType primitive_type) { + HloComputation** scalar_rsqrt_computation = + &scalar_rsqrt_computations_[primitive_type]; + if (*scalar_rsqrt_computation) { + return *scalar_rsqrt_computation; + } - bool rewrite_training_op_; - bool rewrite_inference_op_; - bool rewrite_grad_op_; - bool use_fusion_; + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( + shape, b.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(-0.5f))))); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kPower, scalar_lhs, scalar_rhs)); + *scalar_rsqrt_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return *scalar_rsqrt_computation; + } - // Whether rewrite has occurred. 
- bool changed_ = false; + std::unique_ptr Rsqrt(HloInstruction* operand) { + return HloInstruction::CreateMap( + operand->shape(), {operand}, + GetOrCreateScalarRsqrtComputation(operand->shape().element_type())); + } + + HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type, + int64 element_count) { + HloComputation** scalar_mean_computation = + &scalar_mean_computations_[std::pair( + primitive_type, element_count)]; + if (*scalar_mean_computation) { + return *scalar_mean_computation; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( + shape, b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0( + 1.0f / static_cast(element_count)))))); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs)); + *scalar_mean_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return *scalar_mean_computation; + } + + std::unique_ptr Mean(int64 element_count, + HloInstruction* operand) { + return HloInstruction::CreateMap( + operand->shape(), {operand}, + GetOrCreateScalarMeanComputation(operand->shape().element_type(), + element_count)); + } // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -127,8 +182,29 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { changed_ = true; return Status::OK(); } + // Current HloComputation instance the BatchNormExpander is + // traversing. + HloComputation* computation_; + + bool rewrite_training_op_; + bool rewrite_inference_op_; + bool rewrite_grad_op_; + bool use_fusion_; + + // Whether rewrite has occurred. + bool changed_ = false; + + // Cached computations for adding two scalars. + tensorflow::gtl::FlatMap + scalar_add_computations_; + tensorflow::gtl::FlatMap + scalar_rsqrt_computations_; + tensorflow::gtl::FlatMap, HloComputation*> + scalar_mean_computations_; }; +} // namespace + bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, @@ -156,6 +232,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); // Expand batch norm training into smaller HLO ops. 
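// Illustrative sketch, not part of the original patch: the helpers introduced
// above build one scalar computation per element type (or per (type, N) pair
// for Mean), cache it, and then apply it elementwise via Map. A plain-C++
// analogue of that get-or-create-then-map pattern, shown here for rsqrt
// (hypothetical names, not XLA APIs):
#include <cmath>
#include <functional>
#include <unordered_map>
#include <vector>

using ScalarFn = std::function<float(float)>;

// Builds the "scalar rsqrt computation" once per key and reuses it afterwards.
static const ScalarFn& GetOrCreateRsqrt(
    int key, std::unordered_map<int, ScalarFn>* cache) {
  auto it = cache->find(key);
  if (it == cache->end()) {
    it = cache->emplace(key, [](float v) { return std::pow(v, -0.5f); }).first;
  }
  return it->second;
}

// Applies a scalar computation elementwise, like an HLO Map instruction.
static std::vector<float> Map(const std::vector<float>& operand,
                              const ScalarFn& fn) {
  std::vector<float> out;
  out.reserve(operand.size());
  for (float v : operand) out.push_back(fn(v));
  return out;
}

int main() {
  std::unordered_map<int, ScalarFn> cache;
  const std::vector<float> var_plus_eps = {0.25f, 1.0f, 4.0f};
  // Mirrors add(Rsqrt(var_add_epsilon)) in the pass: elementwise 1/sqrt(v).
  const std::vector<float> rsqrt =
      Map(var_plus_eps, GetOrCreateRsqrt(/*key=*/0, &cache));
  (void)rsqrt;
  return 0;
}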
@@ -165,12 +245,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -182,8 +257,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = - add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = add(HloInstruction::CreateBroadcast( + operand_shape, + add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -199,11 +275,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // X^2. - auto operand_squared = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = + add_binary(operand_shape, HloOpcode::kMultiply, operand, operand); // Sum[X]. auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, dimensions_without_feature, @@ -229,56 +305,47 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( } // E[X]. - auto mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); + auto mean = add(Mean(elements_per_feature_int64, sum)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); + auto square_mean = add(Mean(elements_per_feature_int64, squared_sum)); // E^2[X]. - auto mean_square = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, mean, mean)); + auto mean_square = + add_binary(feature_shape, HloOpcode::kMultiply, mean, mean); // Var[X]. - auto var = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); + auto var = + add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square); auto var_broadcasted = add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. 
- auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd, + scaled_normalized, offset_broadcasted); auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); @@ -320,8 +387,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( + operand_shape, + computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(epsilon_literal))), + {})); std::vector dimensions_without_feature; @@ -338,6 +408,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); auto scale_broadcasted = add( @@ -353,30 +427,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. 
- auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( @@ -424,6 +491,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); @@ -439,26 +510,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = + auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon_activation = add( + HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {})); + auto epsilon_feature = + add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {})); std::vector dimensions_without_feature; @@ -478,29 +543,24 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, - epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = + add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation))); + + auto rsqrt_var_add_epsilon = add(Rsqrt( + add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature))); // X - E[X]. 
- auto activation_minus_mean = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); + auto activation_minus_mean = add_binary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted); // Grad[Y] * (X - E[X]). auto grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + activation_minus_mean); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // sum(Grad[Y] * (X - E[X])). auto sum_grad_output_times_activiation_minus_mean = @@ -529,9 +589,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( } // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, - sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); + auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, + sum_grad_output_times_activiation_minus_mean, + rsqrt_var_add_epsilon); // I2 = Sum(Grad[Y]) auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, @@ -543,39 +603,40 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( {feature_index})); // I4 = (X - E[X]) * I3 - auto i4 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); + auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3, + activation_minus_mean); // I5 = I4 / (Var[X] + epsilon) - auto i5 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, i4, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)))); + auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4, + add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation)); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = + add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, - elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = + add(Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon)); - auto i1 = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto elements_per_feature_literal = + Literal::CreateR0(elements_per_feature_int64); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + add(HloInstruction::CreateBroadcast( + activation_shape, elements_per_feature, {}))); // I6 = I1 - I2 - I5 - auto i6 = add(HloInstruction::CreateBinary( + auto i6 = add_binary( activation_shape, HloOpcode::kSubtract, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - i1, i2)), - i5)); + add_binary(activation_shape, HloOpcode::kSubtract, i1, 
i2), i5); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, i6)); + auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6); auto tuple = HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); if (batch_norm->has_sharding()) { diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 08d0152e3cfcfc..1b8b2d20450357 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -182,15 +182,26 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape()) || - !bfloat16_support_->SupportsMixedPrecisions(*crs)) { - return DefaultAction(crs); - } - // First use DefaultAction() to handle the operands. It can't handle // tuple-shaped output. TF_RETURN_IF_ERROR(DefaultAction(crs)); + if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return Status::OK(); + } + + // If the output is not a tuple, we don't need special handling. + if (!ShapeUtil::IsTuple(crs->shape())) { + return Status::OK(); + } + + // If crs is the root instruction, we should keep its original output type. + // The root instruction implicitly has a use from being the result of the + // computation, and the code below does not take this use into account. + if (crs == computation_->root_instruction()) { + return Status::OK(); + } + // Then do per-tuple-element handling on the output. 
std::vector> per_tuple_element_gtes( crs->operand_count()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 28e71c2054f59b..7fd1e733e96da9 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); + + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("add"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b})); + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, + sum)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( @@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* tuple = builder.AddInstruction( HloInstruction::CreateTuple({gte_a, convert_gte_b})); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(FoldConversions(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 1afaefd9df9c57..9926661dd30600 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + 
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, + reduction)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 43ebe92c5ec1c9..ed0746980f87ac 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -33,7 +33,7 @@ BFloat16Propagation::BFloat16Propagation( const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} -void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( +void BFloat16Propagation::DetermineFusionComputationPrecision( HloInstruction* fusion) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) { @@ -48,15 +48,13 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( auto root = fusion->fused_instructions_computation()->root_instruction(); // Adjust root's element types according to the fusion's output shape. - ShapeUtil::ForEachMutableSubshape( - root->mutable_shape(), [&](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32) { + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { return; } - if (ShapeUtil::GetSubshape(fusion->shape(), index).element_type() == - BF16) { - subshape->set_element_type(BF16); - changed_ = true; + if (OutputTypeAfterChange(fusion, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(root, index, BF16); VLOG(2) << "Fused root " << root->ToString() << " at shape index " << index << " changed to BF16 precision for fusion " << fusion->ToString(); @@ -67,13 +65,101 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( auto insts = fusion->fused_instructions_computation()->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert( + computations_visited_in_backward_pass_.insert( fusion->fused_instructions_computation()); + + RevertIfFusionInternalBF16Changes(fusion); +} + +void BFloat16Propagation::RevertIfFusionInternalBF16Changes( + HloInstruction* fusion) { + auto has_changes = [this](HloInstruction* inst) { + auto it = changes_to_bf16_.find(inst); + return it != changes_to_bf16_.end() && !it->second.empty(); + }; + + auto root = fusion->fused_instructions_computation()->root_instruction(); + tensorflow::gtl::FlatSet changed_root_buffers; + + auto root_changes_it = changes_to_bf16_.find(root); + if (root_changes_it != changes_to_bf16_.end()) { + for (const auto& index : root_changes_it->second) { + for (const HloValue* value : + dataflow_->GetValueSet(root, index).values()) { + changed_root_buffers.insert(value); + } + } + } + + auto aliases_changed_root_buffer = + [this, &changed_root_buffers](const HloInstruction* inst) { + bool aliasing = false; + ShapeUtil::ForEachSubshape( + inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (aliasing) { + // Skip if aliasing is already found. 
+ return; + } + // Only F32 buffers are considered for changing to BF16 in this + // pass. + if (subshape.element_type() != F32) { + return; + } + for (const HloValue* value : + dataflow_->GetValueSet(inst, index).values()) { + if (ContainsKey(changed_root_buffers, value)) { + aliasing = true; + break; + } + } + }); + return aliasing; + }; + + for (auto inst : + fusion->fused_instructions_computation()->MakeInstructionPostOrder()) { + if (inst->opcode() == HloOpcode::kParameter) { + continue; + } + if (aliases_changed_root_buffer(inst)) { + continue; + } + if (inst->opcode() == HloOpcode::kFusion) { + bool parameter_reverted = false; + for (int64 i = 0; i < inst->operand_count(); ++i) { + if (has_changes(inst->mutable_operand(i))) { + // Changes on the operand have not been reverted. + continue; + } + auto* fused_parameter = inst->fused_parameter(i); + if (has_changes(fused_parameter)) { + changes_to_bf16_.erase(fused_parameter); + parameter_reverted = true; + } + } + if (parameter_reverted) { + RevertIfFusionInternalBF16Changes(inst); + } + } + if (!has_changes(inst)) { + continue; + } + bool revert_changes = true; + for (auto operand : inst->operands()) { + if (has_changes(operand)) { + revert_changes = false; + break; + } + } + if (revert_changes) { + changes_to_bf16_.erase(inst); + } + } } -void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( +void BFloat16Propagation::DetermineWhileComputationsPrecision( HloInstruction* while_hlo) { CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile); @@ -86,16 +172,14 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( auto body_root = body->root_instruction(); HloComputation* condition = while_hlo->while_condition(); - ShapeUtil::ForEachMutableSubshape( - body_root->mutable_shape(), - [this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32) { + ShapeUtil::ForEachSubshape( + body_root->shape(), [this, while_hlo, body_root]( + const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { return; } - if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() == - BF16) { - subshape->set_element_type(BF16); - changed_ = true; + if (OutputTypeAfterChange(while_hlo, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16); VLOG(2) << "While body root " << body_root->ToString() << " at shape index " << index << " changed to BF16 precision for while " @@ -106,30 +190,30 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( auto body_insts = body->MakeInstructionPostOrder(); for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert(body); + computations_visited_in_backward_pass_.insert(body); auto condition_insts = condition->MakeInstructionPostOrder(); for (auto inst_it = condition_insts.rbegin(); inst_it != condition_insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert(condition); + computations_visited_in_backward_pass_.insert(condition); } bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { - auto value_set = dataflow_->GetValueSet(&hlo, 
index); + auto& value_set = dataflow_->GetValueSet(&hlo, index); for (const HloValue* value : value_set.values()) { if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { return false; } - if (value->shape().element_type() == BF16) { + if (ValueTypeAfterChange(value) == BF16) { continue; } for (const HloUse& use : value->uses()) { - if (!ContainsKey(instructions_visited_in_mutation_pass_, + if (!ContainsKey(instructions_visited_in_backward_pass_, use.instruction)) { // We don't know yet whether use.instruction will consume BF16 since it // hasn't been visited. Although we visit instructions in reverse @@ -145,26 +229,23 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // precision, or a called computation's parameters have been changed to // BF16 for fusions or whiles. if (use.instruction->opcode() == HloOpcode::kFusion) { - const auto* fused_parameter = + auto* fused_parameter = use.instruction->fused_parameter(use.operand_number); - if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) { return false; } continue; } else if (use.instruction->opcode() == HloOpcode::kWhile) { - const auto* cond_parameter = + auto* cond_parameter = use.instruction->while_condition()->parameter_instruction( use.operand_number); - if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) { return false; } - const auto* body_parameter = + auto* body_parameter = use.instruction->while_body()->parameter_instruction( use.operand_number); - if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) { return false; } continue; @@ -174,19 +255,20 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, continue; } // If the op propagates precision and it outputs a BF16, then it's OK to - // supply BF16 also as the input. In the backward mutation pass, the users - // shapes should have already been processed. + // supply BF16 also as the input. In the backward pass, the users shapes + // should have already been processed. 
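The AllUsersConsumeBF16 changes above boil down to one rule: a value may be recorded as BF16 only if every use either already reads BF16 at the corresponding operand or parameter position, or merely forwards the operand precision to an output that has itself been marked BF16. A minimal sketch of that decision, using stand-in types rather than the real HLO classes (Use, Precision and the field names are mine, for illustration only):

```cpp
#include <vector>

enum class Precision { F32, BF16 };

struct Use {
  bool consumes_bf16;                 // e.g. a fused/while parameter already set to BF16
  bool propagates_operand_precision;  // e.g. tuple-like ops that pass precision through
  Precision output_after_change;      // output type once pending changes are applied
};

// Returns true if demoting the producing value to BF16 cannot hurt any user.
bool AllUsersConsumeBF16(const std::vector<Use>& uses) {
  for (const Use& use : uses) {
    if (use.consumes_bf16) continue;
    if (use.propagates_operand_precision &&
        use.output_after_change == Precision::BF16) {
      continue;
    }
    return false;  // at least one user still needs F32
  }
  return true;
}
```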
PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; if (use.instruction->opcode() == HloOpcode::kTuple || (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && ShapeUtil::IsTuple(use.instruction->shape()))) { - user_output_type = ShapeUtil::GetSubshape( - ShapeUtil::GetSubshape(use.instruction->shape(), - {use.operand_number}), - use.operand_index) - .element_type(); + ShapeIndex use_output_index{use.operand_number}; + for (int64 i : use.operand_index) { + use_output_index.push_back(i); + } + user_output_type = + OutputTypeAfterChange(use.instruction, use_output_index); } else { - user_output_type = use.instruction->shape().element_type(); + user_output_type = OutputTypeAfterChange(use.instruction, {}); } if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( *use.instruction, use.operand_number) && @@ -199,8 +281,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, return true; } -void BFloat16Propagation::DetermineAndMutateInstructionPrecision( - HloInstruction* hlo, bool skip_parameters) { +void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, + bool skip_parameters) { // We handle any fusion computation or while body/condition after the // instruction is handled, because we need to know the output shape of a // fusion or while before propagating inside its computations. @@ -209,12 +291,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( [this, hlo, &postpone_processing_called_computations] { if (!postpone_processing_called_computations) { if (hlo->opcode() == HloOpcode::kFusion) { - DetermineAndMutateFusionComputationPrecision(hlo); + DetermineFusionComputationPrecision(hlo); } else if (hlo->opcode() == HloOpcode::kWhile) { - DetermineAndMutateWhileComputationsPrecision(hlo); + DetermineWhileComputationsPrecision(hlo); } } - instructions_visited_in_mutation_pass_.insert(hlo); + instructions_visited_in_backward_pass_.insert(hlo); }); if (hlo->opcode() == HloOpcode::kWhile && @@ -245,9 +327,9 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( CHECK(hlo->parent() != nullptr); if (hlo == hlo->parent()->root_instruction()) { if (!hlo->parent()->IsFusionComputation()) { - ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape, + ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */, const ShapeIndex& index) { - if (subshape.element_type() != F32) { + if (OutputTypeAfterChange(hlo, index) != F32) { return; } for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { @@ -269,13 +351,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( return; } - ShapeUtil::ForEachMutableSubshape( - hlo->mutable_shape(), - [hlo, this](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() == F32 && + ShapeUtil::ForEachSubshape( + hlo->shape(), + [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) { + if (OutputTypeAfterChange(hlo, index) == F32 && AllUsersConsumeBF16(*hlo, index)) { - subshape->set_element_type(BF16); - changed_ = true; + AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16); VLOG(2) << "HloInstruction output at shape index " << index << " changed to BF16 precision: " << hlo->ToString(); } @@ -308,26 +389,24 @@ void BFloat16Propagation::AdjustCalledComputationParameters( CHECK_EQ(operands.size(), computation->num_parameters()); for (int64 i = 0; i < operands.size(); ++i) { auto parameter = computation->parameter_instruction(i); - ShapeUtil::ForEachMutableSubshape( - parameter->mutable_shape(), - 
[this, i, hlo, &operands, parameter](Shape* subshape, + ShapeUtil::ForEachSubshape( + parameter->shape(), + [this, i, hlo, &operands, parameter](const Shape& /* subshape */, const ShapeIndex& index) { if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) { return; } PrimitiveType operand_type = - ShapeUtil::GetSubshape(operands[i]->shape(), index) - .element_type(); - if (subshape->element_type() == operand_type) { + OutputTypeAfterChange(operands[i], index); + if (OutputTypeAfterChange(parameter, index) == operand_type) { return; } - CHECK(operand_type == F32 || operand_type == BF16); - subshape->set_element_type(operand_type); - changed_ = true; + AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type); VLOG(2) << "Called computation parameter " << parameter->ToString() << " at shape index " << index - << " adjusted to match operand in HLO " - << hlo->ToString(); + << " adjusted to " + << (operand_type == BF16 ? "BF16" : "F32") + << " to match operand in HLO " << hlo->ToString(); }); } }; @@ -348,51 +427,48 @@ void BFloat16Propagation::AdjustCalledComputationParameters( void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { auto adjust_computation = [this, hlo](HloComputation* computation, - const Shape& output_shape) { + HloInstruction* output) { // Adjust root. HloInstruction* root = computation->root_instruction(); - ShapeUtil::ForEachMutableSubshape( - root->mutable_shape(), [this, hlo, root, &output_shape]( - Shape* subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { - return; - } - const PrimitiveType output_type = - ShapeUtil::GetSubshape(output_shape, index).element_type(); - if (subshape->element_type() == output_type) { - return; - } - CHECK(output_type == F32 || output_type == BF16); - subshape->set_element_type(output_type); - // It's possible that output_type is F32, but the root instruction's - // type is BF16; e.g., a fusion node's output was changed to BF16 - // initially but then adjusted back to F32, and the fusion computation - // is now being adjusted after the fusion node. - if (output_type == F32) { - for (const auto* value : - dataflow_->GetValueSet(root, index).values()) { - // We rely on the fact that this adjustment works in reverse - // topological order so that called computation will be - // processed later. Adding the value to - // values_that_must_be_kept_as_f32_ will ensure the - // correctness of the adjustment for HLOs that will be - // processed later. - values_that_must_be_kept_as_f32_.insert(value); - } - } - changed_ = true; - VLOG(2) << "Called computation root " << root->ToString() - << " at shape index " << index - << " adjusted to match output shape of " << hlo->ToString(); - }); + ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output]( + const Shape& /* subshape */, + const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { + return; + } + const PrimitiveType output_type = OutputTypeAfterChange(output, index); + if (OutputTypeAfterChange(root, index) == output_type) { + return; + } + AddToOrRemoveFromBF16ChangeSet(root, index, output_type); + // It's possible that output_type is F32, but the root instruction's + // type is BF16; e.g., a fusion node's output was changed to BF16 + // initially but then adjusted back to F32, and the fusion computation + // is now being adjusted after the fusion node. 
+ if (output_type == F32) { + for (const auto* value : dataflow_->GetValueSet(root, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order so that called computation will be + // processed later. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the + // correctness of the adjustment for HLOs that will be + // processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + VLOG(2) << "Called computation root " << root->ToString() + << " at shape index " << index << " adjusted to " + << (output_type == BF16 ? "BF16" : "F32") + << " to match output shape of " << hlo->ToString(); + }); }; switch (hlo->opcode()) { case HloOpcode::kFusion: - adjust_computation(hlo->fused_instructions_computation(), hlo->shape()); + adjust_computation(hlo->fused_instructions_computation(), hlo); break; case HloOpcode::kWhile: - adjust_computation(hlo->while_body(), hlo->shape()); + adjust_computation(hlo->while_body(), hlo); break; default: break; @@ -409,16 +485,19 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { auto hlo = *inst_it; auto adjust_hlo_output = [this, hlo, ¶meter_changed]( - Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32 && subshape->element_type() != BF16) { + const Shape& /* subshape */, + const ShapeIndex& index) { + auto output_type = OutputTypeAfterChange(hlo, index); + if (output_type != F32 && output_type != BF16) { return; } PrimitiveType type = BF16; for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { - if (value->shape().element_type() == BF16) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { continue; } - CHECK_EQ(value->shape().element_type(), F32); + CHECK_EQ(value_type, F32); type = F32; break; } @@ -437,16 +516,17 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( values_that_must_be_kept_as_f32_.insert(value); } } - if (type != subshape->element_type()) { - subshape->set_element_type(type); + if (type != output_type) { + AddToOrRemoveFromBF16ChangeSet(hlo, index, type); VLOG(2) << "HloInstruction output at shape index " << index - << " adjusted to " << *subshape << ": " << hlo->ToString(); + << " adjusted to " << (type == BF16 ? 
"BF16" : "F32") << ": " + << hlo->ToString(); if (hlo->opcode() == HloOpcode::kParameter) { parameter_changed = true; } } }; - ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output); + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); AdjustCalledComputationRoot(hlo); if (hlo->opcode() == HloOpcode::kWhile) { // We need to run on the while body and condition repeatedly until a fixed @@ -463,8 +543,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), &visited_in_while)) { visited_in_while.clear(); - ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), - adjust_hlo_output); + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); AdjustCalledComputationRoot(hlo); } visited_computations->insert(visited_in_while.begin(), @@ -478,7 +557,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( return parameter_changed; } -Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( +void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { std::list computations_topological_order = module->MakeComputationPostOrder(); @@ -490,7 +569,9 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( } ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved); } +} +Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { // We could have changed a fusion computation's root shape to have a different // precision than the fusion node's output, if the fusion root does not // define a buffer (e.g., a tuple). Now we add conversions after such fusion @@ -517,7 +598,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( // (2) after adding conversion // (3) after tuple simplifier and DCE. bool needs_tuple_simplifier = false; - for (auto computation : computations_topological_order) { + for (auto computation : module->MakeComputationPostOrder()) { auto insts = computation->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { auto hlo = *inst_it; @@ -587,7 +668,14 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); } } + if (needs_tuple_simplifier) { + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + } + return Status::OK(); +} +Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { // We may have converted some constants from F32 to BF16, so adjust the // constant literals in such cases. We do this here instead of when the // constant node's is changed because 1) the HloInstruction interface does not @@ -598,8 +686,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( // can avoid repeated conversions. // // TODO(b/73833576): Consider resetting literal in HloInstruction. 
- bool needs_dce = needs_tuple_simplifier; - for (auto computation : computations_topological_order) { + for (auto computation : module->MakeComputationPostOrder()) { for (auto hlo : computation->MakeInstructionPostOrder()) { if (hlo->opcode() != HloOpcode::kConstant) { continue; @@ -612,23 +699,13 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); - needs_dce = true; } } } - - if (needs_tuple_simplifier) { - TupleSimplifier tuple_simplifier; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - } - if (needs_dce) { - HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); - } return Status::OK(); } -Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { +Status BFloat16Propagation::SkipNoopConversions(HloModule* module) { for (auto computation : module->computations()) { for (auto hlo : computation->MakeInstructionPostOrder()) { if (hlo->opcode() != HloOpcode::kConvert) { @@ -643,7 +720,6 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { if (is_root) { computation->set_root_instruction(source); } - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(hlo)); } } return Status::OK(); @@ -652,8 +728,18 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { // The algorithm first does a forward pass (parameters to root) to determine a // set of instructions to consider using bfloat16, then does a backward pass to // determine the precisions of those instructions according to the need of -// their users. +// their users. During the backward pass, the potential changes are stored in +// changes_to_bf16_ which are subject to further adjustments then applied to the +// HLOs. StatusOr BFloat16Propagation::Run(HloModule* module) { + consider_using_bfloat16_.clear(); + instructions_visited_in_backward_pass_.clear(); + computations_visited_in_backward_pass_.clear(); + values_that_must_be_kept_as_f32_.clear(); + caller_counts_.clear(); + changes_to_bf16_.clear(); + changed_ = false; + TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); std::list computations_topological_order = @@ -686,8 +772,24 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { } auto insts = (*comp_it)->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, - /*skip_parameters=*/true); + DetermineInstructionPrecision(*inst_it, + /*skip_parameters=*/true); + } + } + + // It's possible that an instruction does not define a buffer, but the + // defining instruction's shape has changed. So we need to adjust the output + // shapes of instructions according to the HLO values they refer to. + ResolveInconsistencyOfAliasingBuffers(module); + + // Apply the changes in changes_to_bf16_. + for (auto& change : changes_to_bf16_) { + auto shape = change.first->mutable_shape(); + for (const auto& index : change.second) { + auto subshape = ShapeUtil::GetMutableSubshape(shape, index); + CHECK_EQ(subshape->element_type(), F32); + subshape->set_element_type(BF16); + changed_ = true; } } @@ -695,15 +797,56 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { return false; } - // It's possible that an instruction does not define a buffer, but the - // defining instruction's shape has changed. 
So we need to adjust the output - // shapes of instructions according to the HLO values they refer to. - TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module)); + TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module)); + TF_RETURN_IF_ERROR(ResolveConvertedConstants(module)); // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> - // BF16), so we remove them now. - TF_RETURN_IF_ERROR(RemoveNoopConversions(module)); + // BF16), so we skip them now. + TF_RETURN_IF_ERROR(SkipNoopConversions(module)); + + { + // We may have dead HLOs after ResolveInconsistentFusions, + // ResolveConvertedConstants and SkipNoopConversions. + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } return true; } +PrimitiveType BFloat16Propagation::OutputTypeAfterChange( + HloInstruction* hlo, const ShapeIndex& index) const { + PrimitiveType type_on_hlo = + ShapeUtil::GetSubshape(hlo->shape(), index).element_type(); + if (type_on_hlo != F32) { + return type_on_hlo; + } + auto it = changes_to_bf16_.find(hlo); + if (it == changes_to_bf16_.end()) { + return type_on_hlo; + } + return ContainsKey(it->second, index) ? BF16 : F32; +} + +PrimitiveType BFloat16Propagation::ValueTypeAfterChange( + const HloValue* value) const { + auto hlo = value->defining_instruction(); + const auto& position = value->defining_position(); + return OutputTypeAfterChange(hlo, position.index); +} + +void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet( + HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) { + if (target_type == BF16) { + auto& entry = changes_to_bf16_[hlo]; + entry.insert(index); + } else { + CHECK_EQ(target_type, F32); + auto it = changes_to_bf16_.find(hlo); + if (it == changes_to_bf16_.end()) { + return; + } + it->second.erase(index); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 1744e9db90aeff..de0355ddfca127 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -85,30 +86,39 @@ class BFloat16Propagation : public HloPassInterface { tensorflow::gtl::FlatSet consider_using_bfloat16_; // *************************** - // Functions called and state produced by the backward mutation pass (from - // root to parameters). + // Functions called and state produced by the backward pass (from root to + // parameters) that finds opportunities to use BF16. - // Determines the precision for the given instruction in the mutation pass. - void DetermineAndMutateInstructionPrecision(HloInstruction* hlo, - bool skip_parameters); + // Determines the precision for the given instruction in the + // opportunity-finding pass. + void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters); - // Special handling in the mutation pass for fusion computations. + // Special handling in the opportunity-finding pass for fusion computations. // // Precondition: hlo->opcode() == kFusion - void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion); + void DetermineFusionComputationPrecision(HloInstruction* fusion); - // Special handling in the mutation pass for while computations. 
+ // Reverts changes to BF16 that will not propagate outside a fusion + // computation. This avoids BF16 casts overhead inside a fusion which won't + // save memory bandwidth. + // + // Precondition: hlo->opcode() == kFusion + void RevertIfFusionInternalBF16Changes(HloInstruction* fusion); + + // Special handling in the opportunity-finding pass for while computations. // // Precondition: hlo->opcode() == kWhile - void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo); + void DetermineWhileComputationsPrecision(HloInstruction* while_hlo); - // The set of HloInstructions that have been visited in the mutation pass. + // The set of HloInstructions that have been visited in the + // opportunity-finding pass. tensorflow::gtl::FlatSet - instructions_visited_in_mutation_pass_; + instructions_visited_in_backward_pass_; - // The set of HloComputations that have been visited in the mutation pass. + // The set of HloComputations that have been visited in the + // opportunity-finding pass. tensorflow::gtl::FlatSet - computations_visited_in_mutation_pass_; + computations_visited_in_backward_pass_; // *************************** // Functions called by the final inconsistency resolving pass. @@ -116,7 +126,7 @@ class BFloat16Propagation : public HloPassInterface { // Adjusts the output shapes of HloInstructions such that if two // HloInstructions have aliasing buffers in their outputs, they must have the // same precision. - Status ResolveInconsistencyOfAliasingBuffers(HloModule* module); + void ResolveInconsistencyOfAliasingBuffers(HloModule* module); // Resolves inconsistency of aliasing buffers for the given computation, and // recursively runs on a while instruction's condition and body until a fixed @@ -134,9 +144,19 @@ class BFloat16Propagation : public HloPassInterface { void AdjustCalledComputationRoot(HloInstruction* hlo); // *************************** - // Removes no-op conversions (same source and target shapes) that can be - // produced this pass. - Status RemoveNoopConversions(HloModule* module); + // Functions called after changes in changes_to_bf16_ are applied. + + // Resolves inconsistencies introduced by this pass for fusions with + // tuple-type output. + Status ResolveInconsistentFusions(HloModule* module); + + // Converts the literals in kConstant HLOs which have their types changed to + // BF16 by this pass. + Status ResolveConvertedConstants(HloModule* module); + + // Skips no-op conversions (same source and target shapes) that can be + // produced this pass, i.e., replaces them in their uses with their operands. + Status SkipNoopConversions(HloModule* module); // *************************** // Functions called and state used by two or more passes. @@ -146,6 +166,23 @@ class BFloat16Propagation : public HloPassInterface { bool AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const; + // The output element type of the HLO at the given shape index after changes + // in changes_to_bf16_ are applied. + PrimitiveType OutputTypeAfterChange(HloInstruction* hlo, + const ShapeIndex& index) const; + + // The element type of the HLO value after changes in changes_to_bf16_ are + // applied. + PrimitiveType ValueTypeAfterChange(const HloValue* value) const; + + // If target_type == BF16, adds the HLO at the given index to + // changes_to_bf16_; otherwise, target_type must be F32 and this function + // removes the HLO at the given index from changes_to_bf16_ if it was earlier + // added. 
+ void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo, + const ShapeIndex& index, + PrimitiveType target_type); + // The set of F32 HLO values that must be kept in F32. tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; @@ -153,10 +190,28 @@ class BFloat16Propagation : public HloPassInterface { // module. Populated at the beginning of this pass. tensorflow::gtl::FlatMap caller_counts_; + // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which + // are subject to further adjustment, then finally applied to the HLOs. This + // avoids setting changed_ to true but all changes are reverted during + // adjustment. + struct IndexHasher { + int64 operator()(const ShapeIndex& index) const { + int64 hash = 0; + for (int64 i : index) { + hash = tensorflow::Hash64Combine(hash, std::hash()(i)); + } + return hash; + } + }; + tensorflow::gtl::FlatMap> + changes_to_bf16_; + + // Whether the last processed HLO module has been changed by this pass. + bool changed_ = false; + const BFloat16Support* bfloat16_support_; std::unique_ptr dataflow_; - - bool changed_ = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 183db1652e498e..5e1499ee6b6ef3 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -149,12 +149,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_TRUE(OutputsBF16(dot->operand(1))); EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( dot->operand(0)->literal(), - *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))); - LiteralTestUtil::ExpectEqual( + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + EXPECT_TRUE(LiteralTestUtil::Equal( dot->operand(1)->literal(), - *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); } // Tests that BF16 can be propagated through nested tuples. @@ -323,6 +323,37 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { EXPECT_TRUE(OutputsBF16(b_f1)); } +// Tests that changes to BF16 that cannot be propagated outside a fusion are +// discarded. 
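The changes_to_bf16_ map declared in the header hunk above keys a set of ShapeIndex values with a custom IndexHasher that folds the index elements together via Hash64Combine, so pending F32-to-BF16 changes can be tracked per (instruction, shape index) pair and applied only after the adjustment passes settle. A standalone sketch of the same folding idea, with std::hash and a simplified combine step standing in for TensorFlow's Hash64Combine (the combine constant and the `const void*` key are my simplifications, not the library's):

```cpp
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Stand-in for ShapeIndex: a path of child indices into a nested tuple shape.
using ShapeIndex = std::vector<int64_t>;

struct IndexHasher {
  size_t operator()(const ShapeIndex& index) const {
    size_t hash = 0;
    for (int64_t i : index) {
      // Simplified combine step; the real code uses tensorflow::Hash64Combine.
      hash ^= std::hash<int64_t>()(i) + 0x9e3779b97f4a7c15ULL + (hash << 6) + (hash >> 2);
    }
    return hash;
  }
};

// Pending-change bookkeeping, keyed by instruction, with one shape index per entry.
using ChangeSet =
    std::unordered_map<const void*, std::unordered_set<ShapeIndex, IndexHasher>>;
```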
+TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add_f = builder_f.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); + HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), fusion); +} + // Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion // outputs are only used by a dot, and 3) one element of the tuple is used by // an add in the fusion computation, then the propagation pass should create a diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index dbe45e932cdeed..682c3865797c85 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -134,6 +135,7 @@ Status GatherComputationsByAllocationType( worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. break; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -292,112 +294,6 @@ BufferAllocationProto BufferAllocation::ToProto() const { return proto; } -std::pair> -BufferAllocation::ComputePeakMemoryLogicalBuffers() const { - if (HeapTraces().empty()) { - // Just return the largest LogicalBuffer in the allocation. - const LogicalBuffer* largest_buffer = nullptr; - int64 largest_size = 0; - for (const auto& pair : assigned_buffers()) { - const LogicalBuffer* buffer = pair.first; - int64 size = pair.second.size; - if (largest_buffer == nullptr) { - largest_buffer = buffer; - largest_size = size; - continue; - } - // Tie-break with LogicalBuffer::Id so the return value is stable relative - // to changing addresses. 
- if (size > largest_size || - ((size == largest_size) && (largest_buffer->id() > buffer->id()))) { - largest_buffer = buffer; - largest_size = size; - } - } - CHECK(largest_buffer != nullptr) - << "No logical buffers in allocation: " << ToString(); - return {largest_size, {largest_buffer}}; - } - - // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical - // buffers in this allocation. - tensorflow::gtl::FlatMap - id_to_buffer; - tensorflow::gtl::FlatMap buffer_sizes; - for (const auto& pair : assigned_buffers()) { - const LogicalBuffer* buffer = pair.first; - const OffsetSize& offset_size = pair.second; - id_to_buffer[buffer->id()] = buffer; - buffer_sizes[buffer] = offset_size.size; - } - - // Returns how much the given event increases the total size of live - // buffers. Can be negative. - auto memory_delta = [this, &id_to_buffer, &buffer_sizes]( - const HeapSimulatorTrace::Event& event) -> int64 { - const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); - const int64 buffer_size = buffer_sizes.at(buffer); - if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { - return buffer_size; - } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { - // Sharing a buffer does not change the live set size for the purposes of - // the heap simulator. Even though the shared-with buffer may be smaller, - // the entire allocation remains live. - return 0; - } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { - return -1 * buffer_size; - } - LOG(FATAL) << "Unknown event kind: " << event.kind(); - }; - - int64 total_max_live_size = 0; - std::vector live_buffers_vector; - for (const HeapSimulatorTrace& heap_trace : HeapTraces()) { - // First compute the size of the maximal live set. - int64 max_live_size = 0; - int64 live_size = 0; - for (const auto& event : heap_trace.events()) { - live_size += memory_delta(event); - if (max_live_size < live_size) { - max_live_size = live_size; - } - } - - // Next gather the set of logical buffers live at the earliest point of - // maximal live set size. - tensorflow::gtl::FlatSet live_buffers; - live_size = 0; - for (const auto& event : heap_trace.events()) { - const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); - if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { - InsertOrDie(&live_buffers, buffer); - } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { - // Nothing to do. - } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { - CHECK(ContainsKey(live_buffers, buffer)); - live_buffers.erase(buffer); - } - - live_size += memory_delta(event); - if (live_size == max_live_size) { - break; - } - } - CHECK_EQ(live_size, max_live_size); - total_max_live_size += max_live_size; - - live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(), - live_buffers.end()); - } - - // Stabily sort the live buffers. 
- std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); - return {total_max_live_size, live_buffers_vector}; -} - string BufferAllocation::ToString() const { string output; Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); @@ -610,6 +506,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, BufferAllocation* allocation = NewEmptyAllocation(size, is_thread_local, is_reusable, buffer.color()); AddAssignment(allocation, buffer, /*offset=*/0, size); + allocation->peak_buffers_.push_back(&buffer); return allocation; } @@ -680,6 +577,10 @@ void BufferAssignment::CombineTempAllocations() { CHECK_EQ(temp_allocation.HeapTraces().size(), 1); combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front()); } + combined_allocation->peak_buffers_.insert( + combined_allocation->peak_buffers_.end(), + temp_allocation.peak_buffers_.begin(), + temp_allocation.peak_buffers_.end()); } // Replace all existing temporary allocations with the new combined // allocations. @@ -800,7 +701,7 @@ BufferAssignmentProto BufferAssignment::ToProto() const { BufferAssignmentProto::BufferAlias* proto_alias = proto.add_buffer_aliases(); LogicalBufferProto::Location proto_alias_location = - LogicalBuffer::ToLocationProto(*alias.instruction(), alias.index()); + BufferValue::ToLocationProto(*alias.instruction(), alias.index()); proto_alias->set_source_buffer_id(buffer.id()); proto_alias->mutable_location()->Swap(&proto_alias_location); } @@ -1184,7 +1085,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1212,7 +1115,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1228,6 +1133,89 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( return Status::OK(); } +namespace { + +// Computes and returns the set of logical buffers live at the point of maximal +// liveness in the given heap trace. LogicalBuffers are (stabily) sorted by id. +std::vector ComputePeakMemoryLogicalBuffers( + const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) { + // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical + // buffers in this allocation. + tensorflow::gtl::FlatMap + id_to_buffer; + tensorflow::gtl::FlatMap buffer_sizes; + for (const auto& pair : allocation.assigned_buffers()) { + const LogicalBuffer* buffer = pair.first; + const BufferAllocation::OffsetSize& offset_size = pair.second; + id_to_buffer[buffer->id()] = buffer; + buffer_sizes[buffer] = offset_size.size; + } + + // Returns how much the given event increases the total size of live + // buffers. Can be negative. 
+ auto memory_delta = [&id_to_buffer, &buffer_sizes]( + const HeapSimulatorTrace::Event& event) -> int64 { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + const int64 buffer_size = buffer_sizes.at(buffer); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + return buffer_size; + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Sharing a buffer does not change the live set size for the purposes of + // the heap simulator. Even though the shared-with buffer may be smaller, + // the entire allocation remains live. + return 0; + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + return -1 * buffer_size; + } + LOG(FATAL) << "Unknown event kind: " << event.kind(); + }; + + // First compute the size of the maximal live set. + int64 max_live_size = 0; + int64 live_size = 0; + for (const auto& event : heap_trace.events()) { + live_size += memory_delta(event); + if (max_live_size < live_size) { + max_live_size = live_size; + } + } + + // Next gather the set of logical buffers live at the earliest point of + // maximal live set size. + tensorflow::gtl::FlatSet live_buffers; + live_size = 0; + for (const auto& event : heap_trace.events()) { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + InsertOrDie(&live_buffers, buffer); + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Nothing to do. + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + CHECK(ContainsKey(live_buffers, buffer)); + live_buffers.erase(buffer); + } + + live_size += memory_delta(event); + if (live_size == max_live_size) { + break; + } + } + CHECK_EQ(live_size, max_live_size); + + std::vector live_buffers_vector; + live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(), + live_buffers.end()); + + // Stabily sort the live buffers. + std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); + return live_buffers_vector; +} + +} // namespace + void BufferAssigner::AssignBuffersFromHeapSimulator( const HeapSimulator::Result& result, BufferAssignment* assignment, LogicalBuffer::Color color) { @@ -1242,10 +1230,15 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( BufferAllocation* allocation = assignment->NewEmptyAllocation( result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color); for (const auto& buffer_chunk : result.chunk_map) { - const LogicalBuffer& buffer = *buffer_chunk.first; + // TODO(lauj) Remove this down_cast after downstream users of + // BufferAllocation::assigned_buffers() are updated to use BufferValue. 
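The ComputePeakMemoryLogicalBuffers helper moved into the anonymous namespace above replays the heap-simulator trace twice: a first pass accumulates the per-event memory delta to find the maximal live size, and a second pass stops at the earliest point where that size is reached and records the buffers live there. A minimal sketch of that two-pass idea on a toy event list; the Event struct and PeakLiveSet name are mine, and SHARE_WITH events (delta 0) are omitted for brevity:

```cpp
#include <algorithm>
#include <cstdint>
#include <set>
#include <utility>
#include <vector>

struct Event {
  int64_t buffer_id;
  int64_t delta;  // +size on ALLOC, -size on FREE
};

// Returns the peak total live size and the ids live at the earliest peak point.
std::pair<int64_t, std::set<int64_t>> PeakLiveSet(const std::vector<Event>& trace) {
  // Pass 1: find the maximal total live size over the trace.
  int64_t live = 0, peak = 0;
  for (const Event& e : trace) {
    live += e.delta;
    peak = std::max(peak, live);
  }
  // Pass 2: replay until the earliest point where the peak is reached.
  std::set<int64_t> live_ids;
  live = 0;
  for (const Event& e : trace) {
    if (e.delta > 0) {
      live_ids.insert(e.buffer_id);
    } else {
      live_ids.erase(e.buffer_id);
    }
    live += e.delta;
    if (live == peak) break;
  }
  return {peak, live_ids};
}
```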
+ const LogicalBuffer& buffer = + *CHECK_NOTNULL(dynamic_cast(buffer_chunk.first)); const HeapSimulator::Chunk& chunk = buffer_chunk.second; assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } + allocation->peak_buffers_ = + ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace); VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString(); allocation->AddHeapTrace(result.debug_trace); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 3086d0e2ca0026..ad0b0bf7c25d71 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -206,17 +206,15 @@ class BufferAllocation { return heap_traces_; } - // Compute and return the LogicalBuffers which are live at the point of peak - // memory usage for the given allocation. The point of peak memory usage is - // the point at which the total size of all live logical buffers is - // maximal. If peak memory is reached at multiple points, the set of logical - // buffers live at the earliest maximal point is returned. The vector is - // stabily asserted by LogicalBuffer::Index. - // - // The return value is a pair of total size of the logical buffers at peak, - // and the buffers themselves. - std::pair> - ComputePeakMemoryLogicalBuffers() const; + // Returns the LogicalBuffers which are live at the point of peak memory usage + // for this allocation. The point of peak memory usage is the point at which + // the total size of all live logical buffers is maximal. If peak memory is + // reached at multiple points, the set of logical buffers live at the earliest + // maximal point is returned. The vector is stabily sorted by + // LogicalBuffer::Index. + const std::vector& PeakMemoryLogicalBuffers() const { + return peak_buffers_; + } // Get the number of bytes lost to fragmentation. This is equal to the // difference between the size of the allocation and the size of the maximal @@ -291,6 +289,9 @@ class BufferAllocation { int64 fragmentation_bytes_ = 0; std::vector heap_traces_; + + // Set of buffers live at the point of peak memory usage for this allocation. + std::vector peak_buffers_; }; // Add stream operators for nicer output of CHECK/RET_CHECK failures. @@ -414,10 +415,10 @@ class BufferAssignment { // Only BufferAssigner can build or modify BufferAssignments. friend class BufferAssigner; - explicit BufferAssignment(const HloModule* module, - std::unique_ptr liveness, - LogicalBuffer::SizeFunction buffer_size, - LogicalBuffer::AlignmentFunction color_alignment) + BufferAssignment(const HloModule* module, + std::unique_ptr liveness, + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) : module_(module), liveness_(std::move(liveness)), buffer_size_(std::move(buffer_size)), diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 513a8785bbd52b..7e86c33687e595 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -32,12 +32,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -81,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { class BufferAssignmentTest : public HloTestBase { protected: - BufferAssignmentTest() : computation_tracker_() {} + BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, @@ -251,9 +251,6 @@ class BufferAssignmentTest : public HloTestBase { return total_size; } - // Computation tracker for nested computations. - ComputationTracker computation_tracker_; - // Shapes for use in the examples. Shape s32_ = ShapeUtil::MakeShape(xla::S32, {}); Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {}); @@ -1519,12 +1516,8 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { // single logical buffer should be exactly the logical buffer in that // allocation. const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); - int64 peak_size; - std::vector peak_buffers; - - std::tie(peak_size, peak_buffers) = - mul_buffer.ComputePeakMemoryLogicalBuffers(); - EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(f32vec100_)); + const std::vector& peak_buffers = + mul_buffer.PeakMemoryLogicalBuffers(); ASSERT_EQ(peak_buffers.size(), 1); EXPECT_EQ(peak_buffers[0]->instruction(), mul); } @@ -1555,6 +1548,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0)); // Make the root tiny so no interior nodes can share its buffer. auto root = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); auto module = CreateNewModule(); @@ -1569,12 +1563,10 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { EXPECT_TRUE(buffer.IsPreallocatedTempBuffer()); ASSERT_EQ(buffer.assigned_buffers().size(), 4); - int64 peak_size; - std::vector peak_buffers; - std::tie(peak_size, peak_buffers) = buffer.ComputePeakMemoryLogicalBuffers(); + const std::vector& peak_buffers = + buffer.PeakMemoryLogicalBuffers(); // The peak live set should be concat and its inputs. 
- EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {400}))); ASSERT_EQ(peak_buffers.size(), 3); std::vector peak_instructions; for (const LogicalBuffer* logical_buffer : peak_buffers) { @@ -1583,6 +1575,69 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat)); } +TEST_F(BufferAssignmentTest, PeakBuffersWhile) { + auto module = CreateNewModule(); + const Shape shape = ShapeUtil::MakeShape(F32, {123, 123}); + HloComputation* condition; + { + auto b = HloComputation::Builder(TestName() + ".cond"); + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + condition = module->AddEmbeddedComputation(b.Build()); + } + HloComputation* body; + { + auto b = HloComputation::Builder(TestName() + ".body"); + auto param = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + b.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param)); + body = module->AddEmbeddedComputation(b.Build()); + } + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param)); + auto while_op = builder.AddInstruction( + HloInstruction::CreateWhile(shape, condition, body, copy)); + // This broadcast should get a temporary allocation which is merged with the + // allocation for the while. Peak buffers should include the while and the + // broadcast. + auto bcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {123, 123, 123}), while_op, {0, 1})); + builder.AddInstruction(HloInstruction::CreateReverse( + ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); + const std::vector& peak_buffers = + buffer.PeakMemoryLogicalBuffers(); + ASSERT_EQ(peak_buffers.size(), 2); + + // The peak buffers should include the broadcast and one of the colocated + // buffers of the while (body param, condition param, body root, or the while + // itself). 
+ const LogicalBuffer* bcast_buffer; + const LogicalBuffer* nonbcast_buffer; + if (peak_buffers[0]->instruction() == bcast) { + bcast_buffer = peak_buffers[0]; + nonbcast_buffer = peak_buffers[1]; + } else { + bcast_buffer = peak_buffers[1]; + nonbcast_buffer = peak_buffers[0]; + } + EXPECT_EQ(bcast_buffer->instruction(), bcast); + EXPECT_TRUE( + nonbcast_buffer->instruction() == copy || + nonbcast_buffer->instruction() == while_op || + nonbcast_buffer->instruction() == body->parameter_instruction(0) || + nonbcast_buffer->instruction() == body->root_instruction() || + nonbcast_buffer->instruction() == condition->parameter_instruction(0)); +} + class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( @@ -1626,7 +1681,7 @@ class WhileBufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } - static int64 ByteSizeOf(const LogicalBuffer& buffer) { + static int64 ByteSizeOf(const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*)); } @@ -1641,7 +1696,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1738,7 +1793,7 @@ ENTRY %test_module { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. @@ -1816,7 +1871,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { }; // Build the entry computation as described in the comment above. - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); @@ -1884,7 +1939,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1929,7 +1984,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -1994,7 +2049,7 @@ static bool IsPostOrderTraversal( } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -2073,7 +2128,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 37982aaef9eddd..810d597e730c18 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +43,7 @@ StatusOr> BufferLiveness::Run( return std::move(liveness); } -tensorflow::Status BufferLiveness::Analyze() { +Status BufferLiveness::Analyze() { TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto* computation : module_->computations()) { if (computation->IsFusionComputation()) { @@ -71,7 +70,7 @@ tensorflow::Status BufferLiveness::Analyze() { } XLA_VLOG_LINES(3, ToString()); - return tensorflow::Status::OK(); + return Status::OK(); } string BufferLiveness::ToString() const { @@ -105,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (auto user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user, - points_to_analysis())) { + if (points_to_analysis().DoesNotUseOperandBuffer(alias.instruction(), + alias.index(), user)) { continue; } if (user != b.instruction() && @@ -132,9 +131,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // the qualifications specified in CanShareOperandBufferWithUser. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { if (b.instruction()->IsUserOf(alias.instruction()) && - !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), - b.instruction(), b.index(), - points_to_analysis())) { + !points_to_analysis().CanShareOperandBufferWithUser( + alias.instruction(), alias.index(), b.instruction(), b.index())) { return false; } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 11834a5127e383..cdd3cf4032ef69 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -89,7 +89,7 @@ class BufferLiveness { // Perform buffer liveness analysis. This method must be called prior to // MayInterfere or MaybeLiveOut. - tensorflow::Status Analyze(); + Status Analyze(); // Returns true if the live range of the buffer of 'a' is strictly before the // live range of the buffer of 'b' (they do not overlap). diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc new file mode 100644 index 00000000000000..2bc556a9e27013 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/buffer_value.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +BufferValue::BufferValue(HloInstruction* instruction, const ShapeIndex& index, + Id id) + : id_(id) { + const Shape& shape = ShapeUtil::GetSubshape(instruction->shape(), index); + is_array_ = ShapeUtil::IsArray(shape); + is_tuple_ = ShapeUtil::IsTuple(shape); +} + +BufferValue::~BufferValue() {} + +std::ostream& operator<<(std::ostream& out, const BufferValue& buffer) { + out << buffer.ToString(); + return out; +} + +/*static*/ LogicalBufferProto::Location BufferValue::ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index) { + LogicalBufferProto::Location proto; + proto.set_computation_name(instruction.parent()->name()); + proto.set_instruction_name(instruction.name()); + for (const int64 index_entry : index) { + proto.add_shape_index(index_entry); + } + return proto; +} + +LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const { + LogicalBufferProto proto; + proto.set_id(id()); + proto.set_size(size_fn(*this)); + LogicalBufferProto::Location proto_location = + ToLocationProto(*instruction(), index()); + proto.mutable_defined_at()->Swap(&proto_location); + if (has_color()) { + proto.set_color(color().value()); + } + return proto; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h new file mode 100644 index 00000000000000..f4be16e0843f64 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/int_type.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Abstract class describing a value used by one of the dataflow analyses - +// TuplePointsToAnalysis or HloDataflowAnalysis. +// TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused. +// +// XLA arrays are trivially a single BufferValue. 
Tuples are made up of more +// than one BufferValue: an BufferValue for the pointer vector, and an +// BufferValue for each child element. +// +// Every BufferValue is defined by a particular instruction and most +// instructions define only a single BufferValue. Instructions which define a +// single BufferValue include array-shaped instructions such as Add but also +// includes Tuple-shaped instructions such as Tuple. The Tuple instruction +// defines a single BufferValue which is a vector of pointers to the values +// containing the Tuple instruction's operands. Though the result of the Tuple +// instruction includes multiple values only the top-level BufferValue (the +// vector of pointers) is defined by the Tuple instruction. The values +// containing the tuple elements are defined by earlier instructions, usually +// the operands of the Tuple instruction. +// +// Instructions which construct both the tuple *and* the tuple elements define +// more than one BufferValue. This includes (at least) tuple-shaped Constant, +// Parameter, Infeed and While instructions. These tuple-shaped instructions do +// not assemble a tuple from existing BufferValues like the Tuple instruction +// does, but rather define all the BufferValues in the tuple. +// +// Some instructions, such as Bitcast, define no buffers. These instructions +// simply forward buffers from their operands. +// +// The BufferValue object describes which HLO instruction defines a buffer and +// where within that instruction's output shape the buffer is defined. The +// location within the output shape is indicated by BufferValue::index() which +// is defined identically to the index used in ShapeUtil::GetSubshape(). +// Examples: +// +// %add = Add(%foo, %bar) +// %tuple_constant = Constant({1, {42, 43}}) +// +// %add defines a single array-shaped buffer BufferValue(%add, {}) which holds +// the array result of the add operation. The nested-tuple-shaped +// %tuple_constant defines 5 buffers described by the following BufferValue +// objects: +// +// BufferValue(%tuple_constant, {}) // "Top-level" buffer: vector of +// // pointers to BufferValues at +// // indices {0} and {1} +// BufferValue(%tuple_constant, {0}) // Holds value "1" +// BufferValue(%tuple_constant, {1}) // Holds nested tuple: vector of +// // pointers to BufferValues at +// // indices {1, 0} and {1, 1} +// BufferValue(%tuple_constant, {1, 0}) // Holds value "42" +// BufferValue(%tuple_constant, {1, 1}) // Holds value "43" + +class BufferValue { + public: + TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); + + // Id is a unique identifier for the BufferValue to facilitate efficient + // collections of BufferValues with stable iteration order. + using Id = int64; + + // Functions which return the size and alignment of a logical buffer in bytes. + using SizeFunction = std::function; + using AlignmentFunction = std::function; + + virtual ~BufferValue(); + + Id id() const { return id_; } + + // Return the instruction that defines the buffer. + virtual HloInstruction* instruction() const = 0; + + // Return the index within the output of the instruction where the buffer is + // defined. Index used defined as in ShapeUtil::GetSubshape() + virtual const ShapeIndex& index() const = 0; + + // Return the color of the BufferValue. Differently colored buffers can not be + // parts of the same allocation. 
+ Color color() const { + CHECK_NE(color_, kInvalidColor) + << "Should not query the color of a buffer that was never colored"; + return color_; + } + + void set_color(Color color) { + CHECK_NE(color, kInvalidColor) + << "Should not set the color of a buffer to the invalid color"; + color_ = color; + } + + bool has_color() const { return color_ != kInvalidColor; } + + // Return the shape of the buffer. This reference points into the shape field + // of the instruction defining the buffer. Therefore, the returned shape will + // contain the layout of instruction, if any. + virtual const Shape& shape() const = 0; + + // Returns true if this buffer is the top-level output buffer of the defining + // HLO instruction. This is equivalent to index == {}. + bool IsTopLevel() const { return index().empty(); } + + // Whether this buffer contains a tuple. + bool IsTuple() const { return is_tuple_; } + + // Whether this buffer contains an array. + bool IsArray() const { return is_array_; } + + // operator< is required for std::set. + bool operator<(const BufferValue& other) const { return id_ < other.id_; } + + virtual string ToString() const = 0; + + // TODO(lauj) rename LogicalBufferProto to BufferValueProto. + LogicalBufferProto ToProto(const SizeFunction& size_fn) const; + + // Returns the LogicalBufferProto::Location that serializes the given + // instruction and index. + static LogicalBufferProto::Location ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index); + + const Color kInvalidColor = Color(-1); + + protected: + BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id); + + private: + // The definining instruction and index are not stored here; they can be found + // in the LogicalBuffer and HloValue subclasses. This class exists only to + // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by + // allowing abstract use of LogicalBuffer or HloValue. After those migrations + // are complete, this class should be deleted (b/78906445). Because we plan to + // delete LogicalBuffer and this class, we don't refactor all the shared + // features from LogicalBuffer and HloValue into this class. + Id id_ : 62; + bool is_array_ : 1; + bool is_tuple_ : 1; + Color color_ = kInvalidColor; +}; + +std::ostream& operator<<(std::ostream& out, const BufferValue& buffer); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h new file mode 100644 index 00000000000000..305914fca828f1 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ + +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/core/lib/gtl/compactptrset.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Define various containers of BufferValues, and utilities to convert from +// containers of LogicalBuffers to containers of BufferValues. + +using BufferValueCompactPointerSet = + tensorflow::gtl::CompactPointerSet; +template +BufferValueCompactPointerSet ToBufferValueCompactPointerSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueCompactPointerSet output; + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +using BufferValueFlatSet = tensorflow::gtl::FlatSet; +template +BufferValueFlatSet ToBufferValueFlatSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueFlatSet output; + output.reserve(logical_buffer_container.size()); + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a8053d15e12431..a23427f00ccd88 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index c7763f2ca3e684..52f33a1318e91d 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -19,8 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index c9f78a0f9f1c0e..d8fdccf9bbf1c1 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -22,7 +22,6 @@ limitations under the License. 
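Note: taken together, the new `BufferValue` base class and the container helpers above let a pass be written against one value type regardless of which dataflow analysis produced the values. A small sketch under that assumption (the helpers `SizeOfValue` and `TotalSize` are illustrative; it relies on `LogicalBuffer` deriving from `BufferValue`, as introduced in this change):

```c++
// Sketch: a size function written against BufferValue, applied to a
// container of LogicalBuffers via the new conversion helper.
#include <vector>

#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Works for values produced by either TuplePointsToAnalysis (LogicalBuffer)
// or HloDataflowAnalysis (HloValue), since both derive from BufferValue.
int64 SizeOfValue(const BufferValue& value) {
  return ShapeUtil::ByteSizeOf(value.shape(), sizeof(void*));
}

int64 TotalSize(const std::vector<const LogicalBuffer*>& logical_buffers) {
  // Widen to BufferValue once, then use only the analysis-agnostic interface.
  BufferValueFlatSet values = ToBufferValueFlatSet(logical_buffers);
  int64 total = 0;
  for (const BufferValue* value : values) {
    total += SizeOfValue(*value);
  }
  return total;
}

}  // namespace xla
```

This mirrors the migration path named in the header comment: code can move from `LogicalBuffer` to `HloValue` without touching helpers that only need the shared `BufferValue` interface.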
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -69,70 +68,34 @@ CompileOnlyService::CompileAheadOfTime( for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_program_shape()); - const DebugOptions& debug_options = options.debug_options(); - const auto& program_shape = instance.computation.program_shape(); - ExecutionOptions execution_options; - *execution_options.mutable_debug_options() = debug_options; - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(program_shape, instance.argument_layouts, - &execution_options)); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - HloModule::CreateFromProto(instance.computation, *module_config)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); - hlo_modules.push_back(std::move(hlo_module)); - } - - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); -} - -StatusOr>> -CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - for (const AotComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - const DebugOptions& debug_options = options.debug_options(); - // Dump computation proto state if flag is set. + // Dump computation proto if flag is set. 
const string& directory_path = debug_options.xla_dump_computations_to(); if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); + HloSnapshot hlo_snapshot; + *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); + "computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); const string& per_host_path = tensorflow::io::JoinPath( directory_path, tensorflow::port::Hostname()); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename, - *session_module)); + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); } - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - + const auto& program_shape = instance.computation.program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, user_computation)); + CreateModuleConfig(program_shape, instance.argument_layouts, + &execution_options)); - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, *module_config, - /*include_unreachable_instructions=*/true)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(instance.computation, *module_config)); TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index c10609e67fcdec..e6a66c202d6e0d 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -38,24 +38,7 @@ class CompileOnlyService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |CompileOnlyClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& Options); - // A description of a xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { HloModuleProto computation; std::vector argument_layouts; @@ -65,58 +48,36 @@ class CompileOnlyService : public Service { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); - // Override Service methods that require or imply the existence of an - // execute backend. Note that this does not include TransferToClient, as - // computing constants produces global data that we may wish to transfer. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override { + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override { + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 8b01a6c4b5004d..6f06bba6798bdf 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,6 +28,13 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); +std::vector> +Compiler::ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return {}; +} + /* static */ std::map* Compiler::GetPlatformCompilerFactories() { static auto* r = new std::map; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 5c14591d93cc99..6c52ffd800d19d 100644 --- 
a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -24,8 +24,11 @@ limitations under the License. #include #include #include +#include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" @@ -33,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -152,6 +156,16 @@ class Compiler { std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; + // Returns the backend configurations that the backend will consider for the + // given HLO. Returns no configurations if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::vector> + ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> @@ -181,9 +195,9 @@ class Compiler { // Returns a function that computes the size in bytes of a given // logical buffer. - std::function BufferSizeBytesFunction() { + std::function BufferSizeBytesFunction() { HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); - return [shape_size](const LogicalBuffer& buffer) { + return [shape_size](const BufferValue& buffer) { return shape_size(buffer.shape()); }; } diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index d2d4f14fcec35f..cb61f3da39fb8e 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -23,12 +23,15 @@ limitations under the License. namespace xla { -ComputationLayout::ComputationLayout(const ProgramShape& program_shape) +ComputationLayout::ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts) : result_layout_(program_shape.result()) { for (auto& shape : program_shape.parameters()) { parameter_layouts_.emplace_back(shape); } - SetToDefaultLayout(); + if (ignore_layouts) { + SetToDefaultLayout(); + } } void ComputationLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 80e102411c7885..53c3a3f7b73868 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -34,8 +34,9 @@ class ComputationLayout { public: // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the - // ProgramShape are ignored. - explicit ComputationLayout(const ProgramShape& program_shape); + // ProgramShape are ignored if ignore_layouts is true. 
+ explicit ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts = true); // Returns the layout of a particular parameter. const ShapeLayout& parameter_layout(int64 param_no) const { diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc deleted file mode 100644 index 70e25eebdb068d..00000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/computation_tracker.h" - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" - -using ::tensorflow::strings::Appendf; - -namespace xla { - -ComputationTracker::ComputationTracker() : next_computation_(1) {} - -ComputationHandle ComputationTracker::NewComputation( - const string& computation_name) { - tensorflow::mutex_lock lock(computation_mutex_); - ComputationHandle computation_handle; - int64 handle_value = next_computation_++; - computation_handle.set_handle(handle_value); - opaque_to_computation_[handle_value] = - MakeUnique(computation_name, computation_handle); - return computation_handle; -} - -StatusOr ComputationTracker::LoadSessionModule( - const SessionModule& session_module) { - tensorflow::mutex_lock lock(computation_mutex_); - - // For each embedded computation, create a new computation based on its - // serialized data, and place the mapping from the old computation handle to - // the new computation handle. - - // Build a mapping from old embedded computation handles to new computation - // handles. We build the ID mapping first since the embedded computations are - // in no particular order and may refer to each other. - std::map old_to_new; - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - if (!old_to_new.emplace(old_handle, AllocateHandle()).second) { - return InvalidArgument("Duplicate embedded computation handle %lld", - old_handle); - } - } - - // Create a new computation from each serialized embedded computation. 
- for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - const ComputationHandle& new_handle = old_to_new[old_handle]; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - computation, new_handle, old_to_new)); - } - - // Finally, place the entry computation in the tracker with all of the - // remappings populated from the above. - const int64 old_handle = session_module.entry().computation_handle().handle(); - TF_ASSIGN_OR_RETURN( - old_to_new[old_handle], - LoadSessionComputation(session_module.entry(), &old_to_new)); - return old_to_new[old_handle]; -} - -StatusOr> -ComputationTracker::SnapshotComputation(const ComputationHandle& computation) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation)); - const VersionedComputationHandle entry_versioned_handle = - user_computation->GetVersionedHandle(); - std::set visited; - std::list post_order; - { - tensorflow::mutex_lock lock(computation_mutex_); - ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order); - } - auto session_module = MakeUnique(); - *session_module->mutable_entry() = - Resolve(entry_versioned_handle.handle) - .ValueOrDie() - ->CloneSessionComputation(entry_versioned_handle.version); - for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) { - *session_module->add_embedded_computations() = - Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version); - } - return std::move(session_module); -} - -StatusOr ComputationTracker::Resolve( - const ComputationHandle& computation) const { - tensorflow::mutex_lock lock(computation_mutex_); - return ResolveInternal(computation); -} - -ComputationHandle ComputationTracker::AllocateHandle() { - int64 handle_value = next_computation_++; - ComputationHandle result; - result.set_handle(handle_value); - return result; -} - -StatusOr ComputationTracker::LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) { - TF_RET_CHECK(old_to_new != nullptr); - const ComputationHandle new_handle = AllocateHandle(); - (*old_to_new)[session_computation.computation_handle().handle()] = new_handle; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - session_computation, new_handle, *old_to_new)); - return new_handle; -} - -StatusOr ComputationTracker::ResolveInternal( - const ComputationHandle& computation) const { - auto it = opaque_to_computation_.find(computation.handle()); - if (it == opaque_to_computation_.end()) { - return NotFound("computation handle not found: %lld", computation.handle()); - } - UserComputation* user_computation = it->second.get(); - return user_computation; -} - -void ComputationTracker::ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const { - if (visited->count(versioned_handle) > 0) { - CHECK_EQ(1, visited->count(versioned_handle)); - return; - } - - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - std::vector embedded_handles = - computation->GetEmbeddedComputations(versioned_handle.version); - - for (const auto& embedded_handle : embedded_handles) { - ComputeComputationPostOrder(embedded_handle, visited, post_order); - } - - visited->insert(versioned_handle); - post_order->push_back(versioned_handle); -} - -StatusOr> 
ComputationTracker::BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(computation_mutex_); - - VLOG(1) << "BuildHloModule(" << entry_handle - << ", include_unreachable_instructions=" - << include_unreachable_instructions << ")"; - XLA_VLOG_LINES(1, ToStringInternal()); - - TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, - ResolveInternal(entry_handle.handle)); - - // Build a topological sort of the entry and any embedded computations as a - // list. The root of the computation will be the last element in the list. - std::set visited; - std::list post_order; - ComputeComputationPostOrder(entry_handle, &visited, &post_order); - - // Map from ComputationHandle value and computation version to HloComputation. - std::map hlo_computations; - - // The resolver lambda resolves VersionedHandles to embedded - // HloComputation*. This is required by UserComputation::BuildHloComputation - // when lowering calling operations (map, reduce etc). - auto resolver = [&hlo_computations]( - const VersionedComputationHandle& versioned_handle) -> HloComputation* { - CHECK_GT(hlo_computations.count(versioned_handle), 0); - return hlo_computations.at(versioned_handle); - }; - - // Print the post-order list for this entry computation. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Visiting UserComputations in post order:"; - for (const VersionedComputationHandle& versioned_handle : post_order) { - VLOG(2) << " " << versioned_handle; - } - } - - string module_name = - tensorflow::strings::StrCat(entry_computation->name(), "_module"); - auto module = MakeUnique(module_name, entry_handle, config); - for (auto versioned_handle : post_order) { - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - computation->BuildHloComputation(versioned_handle.version, resolver, - config.debug_options(), - include_unreachable_instructions)); - - // Add the newly created computation to VersionedHandle-to-HloComputation - // map. - DCHECK_EQ(0, hlo_computations.count(versioned_handle)); - hlo_computations[versioned_handle] = hlo_computation.get(); - - if (computation == entry_computation) { - module->AddEntryComputation(std::move(hlo_computation)); - } else { - module->AddEmbeddedComputation(std::move(hlo_computation)); - } - } - - return std::move(module); -} - -string ComputationTracker::ToString() const { - tensorflow::mutex_lock lock(computation_mutex_); - return ToStringInternal(); -} - -string ComputationTracker::ToStringInternal() const { - string out; - Appendf(&out, "ComputationTracker(%p):\n", this); - for (const auto& handle_computation : opaque_to_computation_) { - int64 handle = handle_computation.first; - const std::unique_ptr& computation = - handle_computation.second; - Appendf(&out, " %4lld : %s \"%s\"\n", handle, - computation->GetVersionedHandle().ToString().c_str(), - computation->name().c_str()); - } - return out; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h deleted file mode 100644 index d42d66adefe7fa..00000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Tracks computations for the XLA service; computations can be registered -// with a UserComputation instance and can be resolved from a handle for later -// use. -// -// This class is also capable of serializing/deserializing computations that it -// tracks (and to serialize properly you need to serialize all referred-to -// computations as well). -class ComputationTracker { - public: - ComputationTracker(); - - // Creates a new UserComputation object and returns the corresponding - // ComputationHandle for it. - // - // Precondition: user_computation is not already present in the map. - ComputationHandle NewComputation(const string& computation_name); - - // Restores session data for a computation that has been serialized, and - // allocates a new computation handle for it. - StatusOr LoadSessionModule( - const SessionModule& session_module); - - // Snapshots a computation (referenced by the provided handle) at its latest - // version, returning a module where it is the entry, and any referred-to - // computations are entrained as "embedded" (non-entry) computations. - StatusOr> SnapshotComputation( - const ComputationHandle& computation); - - // Resolves a ComputationHandle to a UserComputation that is present in the - // map. - StatusOr Resolve( - const ComputationHandle& computation) const; - - // Builds an HLO module using the specified computation as the entry. The - // module will include the entry computation as well as all computations which - // are called directly or indirectly from the entry computation via operations - // like "map". config is the HLO module configuration to use for the - // constructed module. - // If include_unreachable_instructions is true, then instructions - // which are not reachable from the root are lowered into HloInstructions - // including unreachable parameters. This ensures the entry HloComputation has - // the same program shape (ProgramShape) as the entry UserComputation. 
- StatusOr> BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions = true) const; - - string ToString() const; - - private: - // Bumps the next_computation_ number and returns the allocated number wrapped - // in a ComputationHandle. - ComputationHandle AllocateHandle() - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Loads a session computation into a UserComputation, registers it, and - // returns the computation handle of the registered computation. If old_to_new - // is provided, it is used for remapping references to computations present in - // session_computation. - // - // old_to_new will be updated with the mapping from session_computation's old - // handle to the returned handle value, and may not be null. - StatusOr LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Internal implementation of Resolve method which requires, but does not - // acquire the mutex. - StatusOr ResolveInternal( - const ComputationHandle& computation) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Builds a post order sort of a computation ("entry") and all of its embedded - // computations including all transitively embedded computations. An embedded - // computation (the callee) will always appear in the sort before the - // computation which calls the embedded computation (the caller). Necessarily, - // the entry computation is the last element in the sort. visited and - // post_order should be empty when calling. post_order contains the post order - // sort when the function return. - void ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Guards the computation mapping. Marked mutable so that the Resolve method - // can remain const; Resolve does't really modify the tracker in any way, but - // it has to lock the mutex for safety. - mutable tensorflow::mutex computation_mutex_; - - // The next sequence number to assign to a computation, guarded by the same - // mutex as the mapping as they'll be mutated at the same time. - int64 next_computation_ GUARDED_BY(computation_mutex_); - - // Mapping from ComputationHandle value to the corresponding registered - // UserComputation object. - std::map> opaque_to_computation_ - GUARDED_BY(computation_mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index e560abc87f8456..e9ec796121fff2 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -35,7 +35,7 @@ namespace xla { // Tries to replace a conditional with a call operation of the corresponding // computation. If the given conditional has a constant predicate, tries to -// replace it with a call to its true/false computation as appropirate and then +// replace it with a call to its true/false computation as appropriate and then // inline that computation. // // Returns true if it made a change to the graph. 
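Note: the rewrite described in the comment above amounts to the following simplified sketch (the function name and error handling are illustrative; the actual pass additionally inlines the resulting call and handles more cases):

```c++
// Sketch: replace a conditional whose predicate is a constant with a call to
// the branch that would execute.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

StatusOr<bool> TryRewriteConstantConditional(HloInstruction* conditional) {
  if (conditional->opcode() != HloOpcode::kConditional ||
      conditional->operand(0)->opcode() != HloOpcode::kConstant) {
    return false;
  }
  const bool take_true_branch =
      conditional->operand(0)->literal().Get<bool>({});
  HloComputation* branch = take_true_branch ? conditional->true_computation()
                                            : conditional->false_computation();
  // Operand 1 feeds the true computation, operand 2 the false computation.
  HloInstruction* branch_arg =
      conditional->mutable_operand(take_true_branch ? 1 : 2);
  Status replaced = conditional->parent()->ReplaceWithNewInstruction(
      conditional,
      HloInstruction::CreateCall(conditional->shape(), {branch_arg}, branch));
  if (!replaced.ok()) {
    return replaced;
  }
  // The real pass then inlines the resulting kCall into the caller.
  return true;
}

}  // namespace xla
```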
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 40519ecc799c8f..33d8338809d4e8 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -65,7 +64,7 @@ struct SpecialCaseCopyPolicy { // output tuple. bool copy_root_replicated_buffers = false; // If true, insert a copy if a buffer coming from a constant or a parameter - // is found wihtin the output tuple. + // is found within the output tuple. bool copy_parameters_and_constants = false; }; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 04fda3b2df5745..278bb1bebfa1a0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -103,6 +103,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:batch_dot_simplification", "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -125,12 +126,14 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", @@ -169,11 +172,13 @@ cc_library( ":orc_jit_memory_mapper", ":runtime_fp16", ":runtime_conv2d", + ":runtime_conv2d_mkl", ":runtime_fft", ":runtime_fork_join", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ -293,6 +298,15 @@ cc_library( ], ) +cc_library( + name = "target_machine_features_fake", + testonly = 1, + hdrs = ["target_machine_features_fake.h"], + deps = [ + ":target_machine_features", + ], +) + cc_library( name = "ir_function", srcs = ["ir_function.cc"], @@ -334,6 +348,7 @@ cc_library( deps = [ ":cpu_options", ":cpu_runtime", + ":ir_emission_utils", ":target_machine_features", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", @@ -363,10 +378,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - 
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -406,7 +421,6 @@ cc_library( "//tensorflow/core:lib", "@llvm//:analysis", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:ipo", "@llvm//:mc", "@llvm//:object", @@ -470,6 +484,27 @@ cc_library( ], ) +cc_library( + name = "runtime_conv2d_mkl", + srcs = [ + "runtime_conv2d_mkl.cc", + ], + hdrs = ["runtime_conv2d_mkl.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + ":runtime_conv2d", + ":runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_helpers", + "//third_party/eigen3", + ] + if_mkl([ + "@mkl_dnn", + "//third_party/mkl:intel_binary_blob", + ]), +) + cc_library( name = "runtime_fft", srcs = [ @@ -482,7 +517,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -544,6 +578,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], @@ -599,6 +649,7 @@ tf_cc_test( deps = [ ":cpu_instruction_fusion", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -636,6 +687,7 @@ cc_library( hdrs = ["ir_emission_utils.h"], deps = [ ":cpu_runtime", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", @@ -648,14 +700,15 @@ tf_cc_test( srcs = ["ir_emission_utils_test.cc"], deps = [ ":ir_emission_utils", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -666,6 +719,7 @@ cc_library( deps = [ ":dot_op_emitter", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", @@ -679,6 +733,7 @@ tf_cc_test( srcs = ["cpu_layout_assignment_test.cc"], deps = [ ":cpu_layout_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -703,6 +758,7 @@ cc_library( deps = [ ":cpu_runtime", ":ir_emission_utils", + 
":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -717,6 +773,7 @@ tf_cc_test( srcs = ["conv_canonicalization_test.cc"], deps = [ ":conv_canonicalization", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", @@ -755,6 +812,7 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", + ":target_machine_features", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", @@ -767,6 +825,7 @@ tf_cc_test( deps = [ ":cpu_executable", ":parallel_task_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -889,3 +948,17 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +tf_cc_test( + name = "cpu_eigen_tensor_alignment_test", + size = "small", + srcs = ["cpu_eigen_tensor_alignment_test.cc"], + deps = [ + ":dot_op_emitter", + ":ir_emission_utils", + ":target_machine_features_fake", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 2136aeb3877685..0985b9297fe487 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -33,7 +33,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { if (hlo->opcode() == HloOpcode::kConvolution && - !PotentiallyImplementedAsEigenConvolution(*hlo)) { + !PotentiallyImplementedAsEigenConvolution(*hlo, + target_machine_features_)) { const ConvolutionDimensionNumbers& dnums = hlo->convolution_dimension_numbers(); auto input_batch_dim = dnums.input_batch_dimension(); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 9b2c3d82eb673c..e6fd1499edd009 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,12 +33,19 @@ namespace cpu { // convolutions can run faster. 
class ConvCanonicalization : public HloPassInterface { public: + explicit ConvCanonicalization( + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) {} + ~ConvCanonicalization() override {} tensorflow::StringPiece name() const override { return "convolution-canonicalization"; } StatusOr Run(HloModule* module) override; + + private: + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 968f53d5c70665..375b017b09263c 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -89,7 +90,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); @@ -146,7 +151,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 3c0c367df30639..25b18eff20f901 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" @@ -81,12 +82,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" @@ -118,10 +121,12 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { CpuAotCompilationResult::CpuAotCompilationResult( ObjectFileData object_file_data, BufferSizes buffer_sizes, - int64 result_buffer_index) + int64 result_buffer_index, + std::unique_ptr hlo_profile_printer_data) : object_file_data_(std::move(object_file_data)), buffer_sizes_(std::move(buffer_sizes)), - result_buffer_index_(result_buffer_index) {} + result_buffer_index_(result_buffer_index), + hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {} CpuAotCompilationResult::~CpuAotCompilationResult() = default; @@ -171,14 +176,13 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: static StatusOr> GetCandidatesForComputation( - HloComputation* computation, + const HloComputation& computation, const std::unordered_map& assigned_indices) { std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( &hlo_to_profile_idx, assigned_indices); - TF_RETURN_IF_ERROR( - computation->Accept(&profile_candidates_for_computation)); + TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation)); return hlo_to_profile_idx; } @@ -229,7 +233,10 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(); @@ -246,8 +253,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. 
pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(&target_machine_features); { auto& pass = pipeline.AddPass>("simplification"); @@ -258,7 +266,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); - pipeline.AddPass(); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -270,16 +277,19 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass(); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); } + pipeline.AddPass(); pipeline.AddPass( - [](const HloInstruction& dot, - const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot) + [&target_machine_features]( + const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return PotentiallyImplementedAsEigenDot(dot, target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -287,12 +297,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + pipeline.AddPass(); + ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout()); + module->mutable_device_entry_computation_layout(), + &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -312,8 +325,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // and thread synchronization dependencies which would likely increase // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. - pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); + pipeline.AddPass( + max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -423,6 +436,41 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) { return Status::OK(); } +Status CreateHloProfilingArtifacts( + const HloModule& module, + std::unordered_map* + instruction_to_profile_idx, + std::unordered_map* + computation_to_profile_idx, + std::unique_ptr* hlo_profile_index_map, + std::unique_ptr* hlo_profile_printer_data) { + *hlo_profile_index_map = MakeUnique(module); + const HloComputation& entry_computation = *module.entry_computation(); + + TF_ASSIGN_OR_RETURN( + *instruction_to_profile_idx, + CollectProfileCandidates::GetCandidatesForComputation( + entry_computation, + (*hlo_profile_index_map)->instruction_to_profile_idx())); + + auto shape_size_bytes = [](const Shape& shape) { + // On the cpu, opaques are pointers. 
+ if (ShapeUtil::IsOpaque(shape)) { + return static_cast(sizeof(void*)); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + }; + + HloCostAnalysis cost_analysis(shape_size_bytes); + TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis)); + *hlo_profile_printer_data = + CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis); + *computation_to_profile_idx = + (*hlo_profile_index_map)->computation_to_profile_idx(); + + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::RunHloPasses( @@ -431,7 +479,13 @@ StatusOr> CpuCompiler::RunHloPasses( VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + std::unique_ptr jit_target_machine = + SimpleOrcJIT::InferTargetMachineForJIT( + CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config())); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, + jit_target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -477,28 +531,9 @@ StatusOr> CpuCompiler::RunBackend( std::unique_ptr hlo_profile_index_map; std::unique_ptr hlo_profile_printer_data; if (module->config().hlo_profiling_enabled()) { - hlo_profile_index_map = MakeUnique(*module); - - TF_ASSIGN_OR_RETURN( - instruction_to_profile_idx, - CollectProfileCandidates::GetCandidatesForComputation( - entry_computation, - hlo_profile_index_map->instruction_to_profile_idx())); - - auto shape_size_bytes = [](const Shape& shape) { - // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { - return static_cast(sizeof(void*)); - } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); - }; - - HloCostAnalysis cost_analysis(shape_size_bytes); - TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); - hlo_profile_printer_data = - CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis); - computation_to_profile_idx = - hlo_profile_index_map->computation_to_profile_idx(); + TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( + *module, &instruction_to_profile_idx, &computation_to_profile_idx, + &hlo_profile_index_map, &hlo_profile_printer_data)); } std::unique_ptr cpu_executable; @@ -515,7 +550,8 @@ StatusOr> CpuCompiler::RunBackend( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); + CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -540,10 +576,11 @@ StatusOr> CpuCompiler::RunBackend( // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. 
+ LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), - jit->target_machine(), jit->external_constant_pool()); + &target_machine_features, jit->external_constant_pool()); for (auto embedded_computation : entry_computation->MakeEmbeddedComputationsList()) { @@ -685,7 +722,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); + TF_RETURN_IF_ERROR( + RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -714,12 +752,22 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, proto, xla_dump_optimized_hlo_proto_to, module->name())); } + std::unordered_map instruction_to_profile_idx; + std::unordered_map computation_to_profile_idx; + std::unique_ptr hlo_profile_index_map; + std::unique_ptr hlo_profile_printer_data; + + if (module->config().hlo_profiling_enabled()) { + TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( + *module, &instruction_to_profile_idx, &computation_to_profile_idx, + &hlo_profile_index_map, &hlo_profile_printer_data)); + } + + LLVMTargetMachineFeatures target_machine_features(target_machine.get()); IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*instruction_to_profile_idx=*/ - std::unordered_map{}, - /*computation_to_profile_idx=*/ - std::unordered_map{}, - target_machine.get(), + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), + &target_machine_features, /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : @@ -760,6 +808,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_RETURN_IF_ERROR(verify_status); } + XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(llvm_module)); + Disassembler disassembler(*target_machine); CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, @@ -793,7 +843,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, results.emplace_back(MakeUnique( std::move(object_file_data), std::move(buffer_sizes), - result_slice.index())); + result_slice.index(), std::move(hlo_profile_printer_data))); } VLOG(1) << "Compilation finished"; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 151af38438a980..e56f9f01134f84 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -76,10 +77,16 @@ class CpuAotCompilationOptions : public AotCompilationOptions { class CpuAotCompilationResult : public AotCompilationResult { public: - CpuAotCompilationResult(ObjectFileData object_file_data, - BufferSizes buffer_sizes, int64 result_buffer_index); + CpuAotCompilationResult( + ObjectFileData object_file_data, BufferSizes buffer_sizes, + int64 result_buffer_index, + std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult(); + HloProfilePrinterData* hlo_profile_printer_data() const { + return hlo_profile_printer_data_.get(); + } + const ObjectFileData& object_file_data() const { return object_file_data_; } const BufferSizes& buffer_sizes() const { return buffer_sizes_; } int64 result_buffer_index() const { return result_buffer_index_; } @@ -97,6 +104,10 @@ class CpuAotCompilationResult : public AotCompilationResult { // result of the computation. This buffer should be passed into the output // parameter when calling the compiled computation. const int64 result_buffer_index_; + + // Contains an instance of HloProfilePrinterData if HLO profiling is enabled, + // otherwise is nullptr. + std::unique_ptr hlo_profile_printer_data_; }; // CPU-targeting implementation of the XLA Compiler interface. @@ -138,7 +149,8 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module, bool is_aot_compile); + Status RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc new file mode 100644 index 00000000000000..8727c72b6e4251 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace cpu { +namespace { + +// Test that we don't call into Eigen with tensors too small to be aligned +// reliably. 
+ +class CpuEigenTensorAlignmentTest : public ::testing::Test {}; + +TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { + string hlo_string = R"( +HloModule DotOperation + +ENTRY DotOperation { + arg0 = f32[5,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloInstruction* dot = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE( + PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenDot( + *dot, target_machine_with_full_alignment)); +} + +TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { + string hlo_string = R"( +HloModule ConvOperation + +ENTRY ConvOperation { + arg0 = f32[1,2,1] parameter(0) + arg1 = f32[1,1,1] parameter(1) + ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, dim_labels=b0f_0io->b0f +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloInstruction* conv = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_full_alignment)); +} +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index aabf4d5161e3af..cf43b74c699ca8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -73,7 +73,7 @@ CpuExecutable::CpuExecutable( Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { + std::vector* buffers) { CHECK_EQ(buffers->size(), assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); @@ -201,59 +201,18 @@ Status CpuExecutable::ExecuteComputeFunction( return Status::OK(); } -static void LogLiveAddresses( - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - if (!VLOG_IS_ON(3)) { - return; - } - - CHECK_EQ(buffers.size(), buffers_in_result.size()); - std::vector live_out_buffers; - for (int i = 0; i < buffers.size(); ++i) { - if (buffers_in_result[i]) { - live_out_buffers.push_back(buffers[i].opaque()); - } - } - VLOG(3) << "Live addresses in output marking found " - << live_out_buffers.size() << " addresses:\n" - << tensorflow::str_util::Join( - live_out_buffers, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); -} - -static Status DeallocateTempBuffers( - DeviceMemoryAllocator* allocator, se::Stream* stream, - 
tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - // Keep those buffers in the output of the marked live because they are needed - // by the service. They will be deallocated by the service. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR( - allocator->Deallocate(stream->parent()->device_ordinal(), &alloc)); - } - } - - return Status::OK(); -} - StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice allocated_buffers, - std::vector* buffers_in_result) { + tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( - /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), - run_options->allocator(), stream->parent()->device_ordinal()); + /*on_host_shape=*/host_result_shape(), + /*on_device_shape=*/host_result_shape(), run_options->allocator(), + stream->parent()->device_ordinal()); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. + // Move OwningDeviceMemory values which contain the array(s) of the result + // into the respective location in ScopedShapedBuffer which is returned to the + // caller. TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus( [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); @@ -272,10 +231,9 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( CHECK(!slice.allocation()->is_entry_computation_parameter()); const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + OwningDeviceMemory& buffer = buffers[buffer_index]; CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer; - (*buffers_in_result)[buffer_index] = true; + *device_memory = buffer.Forget(); return Status::OK(); })); return std::move(result_buffer); @@ -291,23 +249,21 @@ StatusOr CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - - // Free all buffers not in the result. 
- TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), + arguments, unowning_buffers, + hlo_execution_profile)); - return std::move(result_buffer); + return CreateResultShapedBuffer(run_options, &buffers); } StatusOr CpuExecutable::ExecuteAsyncOnStream( @@ -323,30 +279,53 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - - LogLiveAddresses(buffers, buffers_in_result); - - host_stream->EnqueueTask([this, run_options, arguments, buffers, - buffers_in_result, memory_allocator, stream]() { - // Failing a CHECK here is not great, but I don't see an obvious way to - // return a failed Status asynchronously. - TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, - buffers, - /*hlo_execution_profile=*/nullptr)); - TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); - }); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + CreateResultShapedBuffer(run_options, &buffers)); - return std::move(result_buffer); + // At this point, `unowning_buffers` contains unowning pointers to all of our + // buffers, and `buffers` contains owning pointers to the non-live-out + // buffers. Enqueue a task which keeps alive the non-live-out buffers. + // + // Logically we want this lambda to capture `buffers` by move, but ultimately our + // functor needs to be wrapped in an std::function, and that requires its + // functor to be copyable. Thus we perpetrate the hack of capturing buffers + // "by shared pointer". + // + // We also need to change the types of some of the variables we capture: + // run_options needs to change from a pointer to a value type, and arguments + // needs to change from an ArraySlice into a vector. We use a struct instead + // of a lambda to make this explicit. + struct AsyncRunTask { + CpuExecutable* executable; + ServiceExecutableRunOptions run_options; + std::vector arguments; + std::vector unowning_buffers; + std::shared_ptr> buffers; + + void operator()() { + // Failing a CHECK here is not great, but I don't see an obvious way to + // return a failed Status asynchronously.
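The comment in the new `ExecuteAsyncOnStream` body describes a general C++ pattern rather than anything XLA-specific: `std::function` only accepts copyable callables, so move-only state (here the owning buffers) is kept alive through a `shared_ptr` held by an explicit functor struct. The sketch below is a minimal, self-contained illustration of that pattern only; every name in it is made up for the example and is not part of this patch.

```c++
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Stand-in for a move-only resource such as an owning device buffer.
struct OwningBuffer {
  explicit OwningBuffer(int id) : id(id) {}
  OwningBuffer(OwningBuffer&&) = default;
  OwningBuffer& operator=(OwningBuffer&&) = default;
  OwningBuffer(const OwningBuffer&) = delete;
  OwningBuffer& operator=(const OwningBuffer&) = delete;
  int id;
};

// Explicit functor instead of a lambda: the "capture list" is spelled out as
// members, and the shared_ptr member keeps the functor copyable even though
// the buffers themselves are move-only.
struct AsyncTask {
  std::shared_ptr<std::vector<OwningBuffer>> buffers;
  void operator()() const {
    // The buffers stay alive until the last copy of this task is destroyed.
  }
};

int main() {
  std::vector<OwningBuffer> buffers;
  buffers.emplace_back(0);

  // Capturing `buffers` by move directly would make the callable move-only,
  // which std::function rejects; wrapping it in a shared_ptr sidesteps that.
  std::function<void()> task =
      AsyncTask{std::make_shared<std::vector<OwningBuffer>>(std::move(buffers))};
  task();
}
```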
+ TF_CHECK_OK(executable->ExecuteComputeFunction( + &run_options.run_options(), arguments, unowning_buffers, + /*hlo_execution_profile=*/nullptr)); + } + }; + host_stream->EnqueueTask(AsyncRunTask{ + this, *run_options, + std::vector(arguments.begin(), arguments.end()), + unowning_buffers, + std::make_shared>(std::move(buffers))}); + + return std::move(result); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 68ad38cba88720..8dd47bfb865e8a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -92,7 +92,7 @@ class CpuExecutable : public Executable { // buffer is assigned for this element. Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers); + std::vector* buffers); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. @@ -102,16 +102,12 @@ class CpuExecutable : public Executable { tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); - // Creates a ScopedShapedBuffer for holding the result of the computation. The - // addresses (DeviceMemoryBases) are set according to buffer assignment. - // 'buffers_in_result' should point to a vector of the same size as - // 'allocated_buffers'. An element in buffers_in_result is set to true if the - // corresponding buffer is live out of the computation (and thus contained in - // the returned ShapedBuffer). + // Creates a ScopedShapedBuffer for holding the result of the computation, + // moving buffers out of allocated_buffers and into the result as appropriate. + // The addresses are set according to buffer assignment. StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice allocated_buffers, - std::vector* buffers_in_result); + tensorflow::gtl::MutableArraySlice buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 0fc5a746bbbc76..b40d264c03aba6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -34,6 +34,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || + hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kReverse || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 6ed1cd31b18f63..97e10a89a209c0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -157,37 +158,95 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { EXPECT_EQ(dot, computation->root_instruction()); } -TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { - HloComputation::Builder builder(TestName()); - HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); - HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1024, 256}), "arg1")); +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion - HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg1)); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[1024,256] parameter(1) + exponential = s32[1024,256] exponential(arg1) + transpose = s32[256,1024] transpose(exponential), dimensions={1,0} + ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + HloComputation* computation = module->entry_computation(); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { return candidate_operands; }, TransposeFolding::NeverFoldTranspose); - EXPECT_TRUE(transpose_folding.Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); - EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} + +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[256,1] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[1,256] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding 
transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)); +} + +TEST_F(InstructionFusionTest, + DotOperationFusion_TransposeFusion_LHS_NonDefault) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[256,1] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)); } class OpcodeFusionTest : public InstructionFusionTest { @@ -697,6 +756,154 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { Not(op::Fusion())); } +struct GatherLoopFusionTestSpec { + string test_name; + string hlo_computation_text; + + static string Name( + const ::testing::TestParamInfo& info) { + return info.param.test_name; + } +}; + +class GatherLoopFusionTest + : public OpcodeFusionTest, + public ::testing::WithParamInterface {}; + +TEST_P(GatherLoopFusionTest, GatherLoopFusion) { + const GatherLoopFusionTestSpec& spec = GetParam(); + string hlo_string = tensorflow::strings::StrCat( + "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); +} + +std::vector GetGatherLoopFusionTestSpecs() { + std::vector result; + + result.push_back({"FusedTensorFlowGatherV2", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[3,2] broadcast(one), dimensions={} + ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} + ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) +} +)"}); + + 
result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNd_0", R"( +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNd_1", R"( +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedDynamicSlice", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[1,1] broadcast(one), dimensions={} + ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedBatchDynamicSlice", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} + ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) +} +)"}); + + return result; +} + +INSTANTIATE_TEST_CASE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest, + ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), + GatherLoopFusionTestSpec::Name); } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index e8117377e61a4e..aa872d5ec9e759 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -100,7 +100,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) { + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) { const HloInstruction* convolution = instruction; const HloInstruction* lhs_instruction = convolution->operand(0); const HloInstruction* rhs_instruction = convolution->operand(1); @@ -126,7 +127,8 @@ Status 
CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction)) { + } else if (PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, // rhs, and output need to be row-major. @@ -139,13 +141,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - // dot is a kDot or a kTransposeDot fusion node. In the latter case, if - // it represents X @ X, it may have just one operand. - if (dot->operand_count() > 1) { - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - } + const HloInstruction* rhs_instruction = dot->operand(1); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); @@ -181,7 +179,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index c8edbb9e15a5b6..3c4fe68b830d96 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" @@ -27,12 +28,17 @@ namespace cpu { // layout constraints for operands and results of library calls. class CpuLayoutAssignment : public LayoutAssignment { public: - explicit CpuLayoutAssignment(ComputationLayout* entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit CpuLayoutAssignment( + ComputationLayout* entry_computation_layout, + const TargetMachineFeatures* target_machine_features) + : LayoutAssignment(entry_computation_layout), + target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} protected: Status AddBackendConstraints(LayoutConstraints* constraints) override; + + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 6ba030fff3bbc5..429fc7b78608da 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -49,7 +50,12 @@ class CpuLayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout) { - cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -311,7 +317,12 @@ static StatusOr RunDotOutputFusion( result.addend_fusion_param = fusion_instruction->operand( fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); - cpu::CpuLayoutAssignment layout_assignment(&computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(&computation_layout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index f9c51f243c47b8..e75fcb6bc9719f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -22,6 +22,8 @@ namespace { const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; +const char* const kXlaEnableExperimentalLlvmIrGemm = + "xla_enable_experimental_llvm_ir_gemm"; } // namespace @@ -54,6 +56,12 @@ tensorflow::gtl::optional LlvmIrGemvTilingFactor( return tensorflow::gtl::nullopt; } +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index be62ff3cc1af23..106dfbbc62dfba 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -26,6 +26,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); +bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); tensorflow::gtl::optional LlvmIrGemvTilingFactor( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 872b0be1f8a8ec..54c52bc08f9c53 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -37,6 +37,7 @@ 
extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; +extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32"; extern const char* const kMKLMatMulF32SymbolName = "__xla_cpu_runtime_MKLMatMulF32"; extern const char* const kMKLMatMulF64SymbolName = @@ -50,6 +51,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index e392e231b4c71b..aa0e96712302e8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -44,6 +44,7 @@ namespace runtime { extern const char* const kEigenMatMulF16SymbolName; extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; +extern const char* const kMKLConvF32SymbolName; extern const char* const kMKLMatMulF32SymbolName; extern const char* const kMKLMatMulF64SymbolName; extern const char* const kMKLSingleThreadedMatMulF32SymbolName; @@ -51,6 +52,7 @@ extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 9b39e7f5765ae5..d97802ee45d6ad 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -88,8 +88,8 @@ CpuTransferManager::CpuTransferManager() : GenericTransferManager(se::host::kHostPlatformId, /*pointer_size=*/sizeof(void*)) {} -Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status CpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 3ecb0d23649837..6dfc666f09dfa6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -38,7 +38,7 @@ class CpuTransferManager : public GenericTransferManager { ~CpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) override; 
Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 495fecc4aa8b3c..d77076546f404a 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -41,17 +42,17 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { namespace { -// Loads a tile of values from a 2D tensor. -class TileLoader { +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { public: - // Constructs a TileLoader that will load a tile consisting of + // Constructs a MemoryTile that can operate on tiles consisting of // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at // `major_dim_offset` in the major dimension. The tile size along the minor // dimension is the vector size, and that is implicitly determined by `vsl`. - TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, llvm::Value* matrix, int64 matrix_size_along_minor_dim, llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl) { + : vsl_(vsl), ir_builder_(ir_builder) { pointers_.reserve(tile_size_along_major_dim); for (int64 i = 0; i < tile_size_along_major_dim; i++) { llvm::Value* total_offset = ir_builder->CreateMul( @@ -61,9 +62,10 @@ class TileLoader { } } - // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at - // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the - // minor dimension. + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. std::vector LoadTile(llvm::Value* minor_dim_offset) const { std::vector result; result.reserve(pointers_.size()); @@ -73,11 +75,104 @@ class TileLoader { return result; } + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(tensorflow::gtl::ArraySlice tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. 
+ std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], ir_builder_->CreateAdd(minor_dim_offset, + ir_builder_->getInt64(j)))); + } + } + return result; + } + private: VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* ir_builder_; std::vector pointers_; }; +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + int64 has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", + tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the // layout of the vector does not matter). This implementation uses a tiling // scheme to improve performance. @@ -139,38 +234,46 @@ class TileLoader { // TODO(sanjoy): We should investigate if using gather loads and scatter stores // can be used here have the same inner loop for both column-major and row-major // matrix-vector products. 
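To make the tiling scheme described above concrete, here is a minimal scalar C++ sketch of the computation the column-major GEMV emitter generates IR for. It is an illustrative reference only, not the emitter's output or part of this patch: the real code emits vectorized LLVM IR through `VectorSupportLibrary`, keeps the inner row loop in vector registers of width `tile_rows`, and handles row/column remainders in dedicated epilogue loops rather than the `std::min` clamps used here. The function name is made up for this sketch.

```c++
#include <algorithm>
#include <cstdint>
#include <vector>

// result[r] += sum_c lhs[r, c] * rhs[c], with lhs shaped [M, K] and stored
// column major ("{0,1}" layout), i.e. lhs[r, c] == lhs_data[c * m + r].
// Assumes *result already holds the addend (or zeros) on entry.
void ColumnMajorGemvReference(const std::vector<float>& lhs_data,
                              const std::vector<float>& rhs, int64_t m,
                              int64_t k, int64_t tile_rows, int64_t tile_cols,
                              std::vector<float>* result) {
  for (int64_t c0 = 0; c0 < k; c0 += tile_cols) {       // "dot.outer.tiled"
    int64_t c_end = std::min(c0 + tile_cols, k);        // column remainder
    for (int64_t r0 = 0; r0 < m; r0 += tile_rows) {     // "dot.inner.tiled"
      int64_t r_end = std::min(r0 + tile_rows, m);      // row remainder (epilogue)
      for (int64_t c = c0; c < c_end; ++c) {
        float rhs_c = rhs[c];                           // broadcast RHS element for this column
        for (int64_t r = r0; r < r_end; ++r) {          // one vector FMA per row tile in the emitter
          (*result)[r] += lhs_data[c * m + r] * rhs_c;
        }
      }
    }
  }
}
```

The loop labels in the comments correspond to the `ksl_.For` names used by the emitter below.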
-class ColumnMajorMatrixVectorProductEmitter { +class ColumnMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, - int64 tile_rows, int64 tile_cols, - int64 m, int64 k, llvm::Value* lhs, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"col_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { - CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast(tile_rows_))); + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), + ir_builder_, "") { + CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: void EmitOuterLoopBody(llvm::Value* column, int64 column_count, bool is_first_column); - TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m_, + MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m(), /*major_dim_offset=*/column_start, /*tile_size_along_major_dim=*/column_count); } @@ -187,18 +290,14 @@ class ColumnMajorMatrixVectorProductEmitter { return result; } - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column); void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -210,25 +309,25 @@ class ColumnMajorMatrixVectorProductEmitter { void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( llvm::Value* column, int64 column_count, bool is_first_column) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, /*column_count=*/column_count); std::vector rhs_tile = LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, + EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, /*columns=*/column_count, is_first_column); EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); } void ColumnMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. 
- int64 column_remainder = k_ % tile_cols_; - int64 column_limit = k_ - column_remainder; + int64 column_remainder = k() % tile_cols(); + int64 column_limit = k() - column_remainder; ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols_, is_first_column); + EmitOuterLoopBody(column, tile_cols(), is_first_column); }); if (column_remainder != 0) { @@ -238,14 +337,14 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, const std::vector& rhs_tile, + MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column) { - int64 row_limit = m_ - (m_ % tile_rows_); + int64 row_limit = m() - (m() % tile_rows()); ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows_, [&](llvm::Value* row) { + /*step=*/tile_rows(), [&](llvm::Value* row) { std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); llvm::Value* accumulator = is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) : vsl_.GetZeroVector()) @@ -259,8 +358,8 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m_ - (m_ % tile_rows_); - if (row_start == m_) { + int64 row_start = m() - (m() % tile_rows()); + if (row_start == m()) { return; } @@ -280,11 +379,11 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); llvm::Value* total_offset = - ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); + ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); ksl_.For( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); @@ -364,51 +463,55 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer // epilogue loop to deal with the C,D submatrix. 
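For the row-major case, the scalar sketch below (again an illustrative reference with a hypothetical name, not the emitter's output) mirrors the split described above: a tiled loop over whole multiples of `tile_cols` and a scalar epilogue for the leftover columns, with one accumulator per row of the current row tile. In the emitter each per-row accumulator is a whole vector of width `tile_cols` that is horizontally reduced, and combined with the optional addend, after the loops.

```c++
#include <algorithm>
#include <cstdint>
#include <vector>

// result[r] = sum_c lhs[r, c] * rhs[c], with lhs shaped [M, K] and stored
// row major, i.e. lhs[r, c] == lhs_data[r * k + c].
void RowMajorGemvReference(const std::vector<float>& lhs_data,
                           const std::vector<float>& rhs, int64_t m, int64_t k,
                           int64_t tile_rows, int64_t tile_cols,
                           std::vector<float>* result) {
  const int64_t column_limit = k - (k % tile_cols);     // boundary between tiled loop and epilogue
  for (int64_t r0 = 0; r0 < m; r0 += tile_rows) {       // "dot.outer.tiled" (+ row remainder)
    const int64_t r_end = std::min(r0 + tile_rows, m);
    for (int64_t r = r0; r < r_end; ++r) {
      float vector_acc = 0.0f;                          // a vector accumulator in the emitter
      for (int64_t c = 0; c < column_limit; ++c) {      // "dot.inner.tiled"
        vector_acc += lhs_data[r * k + c] * rhs[c];
      }
      float scalar_acc = 0.0f;                          // "dot.inner.epilg.inner"
      for (int64_t c = column_limit; c < k; ++c) {
        scalar_acc += lhs_data[r * k + c] * rhs[c];
      }
      (*result)[r] = vector_acc + scalar_acc;           // plus the addend, when present
    }
  }
}
```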
-class RowMajorMatrixVectorProductEmitter { +class RowMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, - llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* addend, llvm::Value* result, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"row_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { - CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast(tile_cols_))); + vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") { + CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: - TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k_, + MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k(), /*major_dim_offset=*/row_start, /*tile_size_along_major_dim=*/row_count); } void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators); void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -420,7 +523,7 @@ class RowMajorMatrixVectorProductEmitter { void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, int64 row_count) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, /*row_count=*/row_count); std::vector vector_accumulators; std::vector scalar_accumulators; @@ -428,7 +531,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); } - EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, + EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, &vector_accumulators); EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, &scalar_accumulators); @@ -465,12 +568,12 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, void RowMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. 
- int64 row_remainder = m_ % tile_rows_; - int64 row_limit = m_ - row_remainder; + int64 row_remainder = m() % tile_rows(); + int64 row_limit = m() - row_remainder; ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); @@ -478,14 +581,14 @@ void RowMajorMatrixVectorProductEmitter::Emit() { } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, int64 rows, + MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators) { - int64 column_limit = k_ - (k_ % tile_cols_); + int64 column_limit = k() - (k() % tile_cols()); ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols_, [&](llvm::Value* col) { + /*step=*/tile_cols(), [&](llvm::Value* col) { std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); for (int i = 0; i < rows; i++) { llvm::Value* old_sum = (*vector_accumulators)[i].Get(); @@ -498,18 +601,18 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators) { - int64 column_start = k_ - (k_ % tile_cols_); - if (column_start == k_) { + int64 column_start = k() - (k() % tile_cols()); + if (column_start == k()) { return; } for (int r = 0; r < rows; r++) { llvm::Value* total_offset = ir_builder_->CreateMul( ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), - ir_builder_->getInt64(k_)); + ir_builder_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, + ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), /*step=*/1, [&](llvm::Value* scalar_col) { llvm::Value* product = vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), @@ -520,18 +623,336 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } } +// This class implements a tiled matrix multiplication algorithm, intended for +// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto, +// Kazushige, and Robert Van De Geijn. "High-performance implementation of the +// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008): +// 4). +// +// This only supports canonical dot operations (i.e. where the lhs contraction +// dimension is 1 and the rhs contraction dimension is 0) over row major +// matrices. +class MatrixMatrixBlockPanelEmitter { + public: + // Describe the dimensions of the GEBP kernel. These will usually not be the + // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP + // kernels with smaller dimensions. + class Dimensions { + public: + explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} + + int64 m() const { return m_; } + int64 k() const { return k_; } + int64 n() const { return n_; } + + string ToString() const { + return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); + } + + private: + const int64 m_; + const int64 k_; + const int64 n_; + }; + + // Represents the configuration of the GEBP emitter. 
The LLVM IR emitted by + // the emitter, modulo the LLVM values holding the input and output buffers, + // must be a function of the instance of `Config` passed to it. + // + // `dims` holds the matrix multiplication dimensions. + // + // `max_vectorization_width` is the maximum vector width (i.e. the width of + // the largest vector register we will use). This can be larger than the + // largest vector register supported by the machine -- LLVM will legalize + // these large vector widths into legally sized vectors. + // `min_vectorization_width` is the smallest vector width the emitter will use + // -- below that it will devolve to using a scalar loop. + // + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // ] in the RHS. + class Config { + public: + explicit Config(PrimitiveType scalar_type, Dimensions dims, + int64 max_vectorization_width, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k) + : scalar_type_(scalar_type), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + min_vectorization_width_(min_vectorization_width), + tile_size_m_(tile_size_m), + tile_size_k_(tile_size_k) {} + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), + "_", max_vectorization_width(), "_", min_vectorization_width(), "_", + tile_size_m(), "_", tile_size_k()); + } + + PrimitiveType scalar_type() const { return scalar_type_; } + Dimensions dims() const { return dims_; } + int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 min_vectorization_width() const { return min_vectorization_width_; } + + int64 tile_size_m() const { return tile_size_m_; } + int64 tile_size_k() const { return tile_size_k_; } + + private: + PrimitiveType scalar_type_; + Dimensions dims_; + int64 max_vectorization_width_; + int64 min_vectorization_width_; + int64 tile_size_m_; + int64 tile_size_k_; + }; + + // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. + explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : lhs_(lhs), + rhs_(rhs), + result_(result), + config_(config), + ir_builder_(ir_builder), + ksl_(ir_builder_) { + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GT(tile_size_k(), 0); + } + + void Emit(); + + private: + // This emits a loop that loops over the `n` dimension in multiples of + // `max_vectorization_width` as much as possible and then emits a remainder + // epilogue. + void EmitLoopOverN(); + + // This emits a loop that loops over the `k` dimension in multiples of + // `tile_size_k` as much as possible and then emits a remainder epilogue. + void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + + // This emits a loop that loops over the `m` dimension in multiples of + // `tile_size_m` as much as possible and then emits a remainder epilogue. + void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits the inner reduction loop. 
This inner reduction loop multiplies + // a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the + // RHS of size [`tile_size_k`, vls->vector_width()] to update a tile of size + // [`tile_size_m`, vls->vector_width()] in the result. + void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } + + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + Config config_; + + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; +}; + +void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); } + +void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. + + int64 current_vectorization_width = max_vectorization_width(); + int64 n_start = 0; + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); + if (n_start != n_end) { + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, + ir_builder_, "gebp"); + EmitLoopOverK(&vsl, GetInt64(n_start), GetInt64(n_end)); + n_start = n_end; + } + current_vectorization_width /= 2; + } + + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); + ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + llvm::Value* n_i_next = + ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); + EmitLoopOverK(&vsl, n_i, n_i_next); + }); + } +} + +void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { + int64 k_start = 0; + int64 k_end = dims().k() - (dims().k() % tile_size_k()); + if (k_end != k_start) { + EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); + k_start = k_end; + } + + if (k_start != dims().k()) { + EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); + } +} + +void MatrixMatrixBlockPanelEmitter::EmitLoopOverM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end, + tile_size_m(), GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), + GetInt64(dims().m())); + } +} + +// The tiling scheme is as follows: +// +// Let the 
LHS be: +// +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ +// .. .. +// +// and the RHS be: +// +// +----+----+----+----+ +// | p0 | p1 | p2 | p3 | . +// +----+----+----+----+ . +// | q0 | q1 | q2 | q3 | . +// +----+----+----+----+ +// | r0 | r1 | r2 | r3 | . +// +----+----+----+----+ . +// ...... ...... +// +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. +// +// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: +// +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ +// +// +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result. For example, +// L[0,_,_] is computed as: +// +// +----+----+----+----+ +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + +// +----+----+----+----+ +----+----+----+----+ +// +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ +// +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... 
+// +-------------------+-------------------+-------------------+--------- +void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + + ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i, + tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + ksl_.For( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + std::vector rhs_tile = rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = + result_memory_tile.LoadTile(n_i); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_memory_tile.StoreTile(result_tile, n_i); + }); + }); + }); +} + } // namespace -DotOpEmitter::DotOpEmitter( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_(dot), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -541,23 +962,99 @@ DotOpEmitter::DotOpEmitter( hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, +/* static */ Status DotOpEmitter::EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, 
rhs_array, addend_array, - executable_run_options_value, ir_builder, - hlo_module_config, target_machine_features); + DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, + addend_array, executable_run_options_value, + ir_builder, hlo_module_config, + target_machine_features); return dot_emitter.Emit(); } -bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } +bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( + const DotOpEmitter::MatMultDims& mat_mult_dims) { + if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) { + return false; + } + + if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { + return false; + } + + PrimitiveType primitive_type = dot_.shape().element_type(); + + switch (primitive_type) { + default: + return false; + + case F32: + case F64: + case S32: + case S64: + break; + } + + if (!(mat_mult_dims.lhs_column_major == mat_mult_dims.rhs_column_major && + mat_mult_dims.rhs_column_major == mat_mult_dims.target_column_major)) { + return false; + } + + llvm::Value* lhs = lhs_array_.GetBasePointer(); + llvm::Value* rhs = rhs_array_.GetBasePointer(); + llvm::Value* target = target_array_.GetBasePointer(); + int64 m = mat_mult_dims.m; + int64 k = mat_mult_dims.k; + int64 n = mat_mult_dims.n; + + if (mat_mult_dims.lhs_column_major) { + std::swap(lhs, rhs); + std::swap(m, n); + } + + int64 size_bytes = m * n * ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + ir_builder_->CreateMemSet( + target, ir_builder_->getInt8(0), size_bytes, + target_machine_features_.minimum_alignment_for_allocation(size_bytes)); + + int64 max_vector_width = + target_machine_features_.vector_register_num_elements( + *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + + MatrixMatrixBlockPanelEmitter::Config config( + /*scalar_type=*/primitive_type, + MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_vector_width, + /*min_vectorization_width=*/std::min(4, max_vector_width), + /*tile_size_m=*/3, /*tile_size_k=*/5); + + VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " + << config.GetCacheKey(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs, rhs, target, + [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { + MatrixMatrixBlockPanelEmitter gebp_emitter( + config, /*lhs=*/lhs, /*rhs=*/rhs, + /*result=*/target, ir_builder_); + gebp_emitter.Emit(); + }); + + return true; +} bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (dot_.shape().dimensions_size() != 2) { @@ -580,7 +1077,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.m == 1) { bool rhs_effectively_row_major = - transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; + mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major; if (rhs_effectively_row_major) { k = mat_mult_dims.k; m = mat_mult_dims.n; @@ -596,7 +1093,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.n == 1) { bool lhs_effectively_column_major = - transpose_lhs_ ^ mat_mult_dims.lhs_column_major; + mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major; if (lhs_effectively_column_major) { m = mat_mult_dims.m; k = mat_mult_dims.k; @@ -611,7 +1108,7 @@ bool 
DotOpEmitter::EmitLlvmIrDotIfProfitable() { } if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return false; + return EmitExperimentalGebpDotIfEnabled(mat_mult_dims); } int64 tiling_factor = GetGemvTilingFactor(); @@ -644,47 +1141,39 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = vector_register_element_size; - int64 tile_cols = tiling_factor; - - string kernel_name = tensorflow::strings::StrCat( - "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + ColumnMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { ColumnMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = tiling_factor; - int64 tile_cols = vector_register_element_size; - - string kernel_name = tensorflow::strings::StrCat( - "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + RowMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { RowMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } @@ -692,7 +1181,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { return true; } -tensorflow::Status DotOpEmitter::Emit() { +Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. 
// @@ -736,23 +1225,17 @@ tensorflow::Status DotOpEmitter::Emit() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_)) { + if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { return EmitCallToRuntime(); } // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. - int64 lhs_reduction_dimension = 0; - if (ShapeUtil::Rank(lhs_shape) >= 2) { - lhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1); - } - int64 rhs_reduction_dimension = 0; - if (ShapeUtil::Rank(rhs_shape) >= 2) { - rhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2); - } + int64 lhs_reduction_dimension = + dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = + dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -876,10 +1359,10 @@ tensorflow::Status DotOpEmitter::Emit() { // loop. ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitScalarDot() { +Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; llvm::Value* lhs_value = @@ -904,12 +1387,10 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitCallToRuntime() { - DCHECK(ShapesAreLegalForRuntimeDot()); - +Status DotOpEmitter::EmitCallToRuntime() { // The signature of the Eigen runtime matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, @@ -918,8 +1399,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. 
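For orientation, the sketch below spells out plausible reference semantics for a call with this shape: `out = lhs * rhs` with all three buffers column major, `out` of size m x n, `lhs` of size m x k, `rhs` of size k x n, and each transpose flag indicating that the corresponding buffer actually holds the transposed matrix. This is an assumption-labeled illustration, not the actual `__xla_cpu_runtime_*` implementation; the column-major convention is inferred from the operand/dimension swapping in the surrounding code, and the function name is invented for this sketch.

```c++
#include <cstdint>

// Hypothetical reference semantics for a matmul entry point shaped like the
// runtime call emitted here. Not the real XLA CPU runtime function.
void ReferenceMatMulF32(const void* /*run_options*/, float* out,
                        const float* lhs, const float* rhs, int64_t m,
                        int64_t n, int64_t k, int32_t transpose_lhs,
                        int32_t transpose_rhs) {
  // Column major: element (row, col) of an R x C matrix lives at col * R + row.
  auto lhs_at = [&](int64_t row, int64_t col) {
    return transpose_lhs ? lhs[row * k + col] : lhs[col * m + row];
  };
  auto rhs_at = [&](int64_t row, int64_t col) {
    return transpose_rhs ? rhs[row * n + col] : rhs[col * k + row];
  };
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        acc += lhs_at(i, p) * rhs_at(p, j);
      }
      out[j * m + i] = acc;  // out is column major as well
    }
  }
}
```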
- bool multi_threaded = - hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool multi_threaded = ShouldUseMultiThreadedEigen(); bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; @@ -990,8 +1470,8 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; - bool transpose_lhs = transpose_lhs_; - bool transpose_rhs = transpose_rhs_; + bool transpose_lhs = mat_mult_dims.lhs_non_canonical; + bool transpose_rhs = mat_mult_dims.rhs_non_canonical; if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); @@ -1011,7 +1491,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { ir_builder_->getInt64(mat_mult_dims.k), ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); - return tensorflow::Status::OK(); + return Status::OK(); } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { @@ -1019,12 +1499,18 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); - - return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), - lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), - rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, - LayoutUtil::Minor(rhs_shape.layout(), 0) == 0}; + const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); + + return { + /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), + /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)), + /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)), + /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0, + /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0, + /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1, + /*target_column_major=*/ + LayoutUtil::Minor(target_array_.GetShape().layout(), 0) == 0}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -1064,19 +1550,39 @@ static bool IsRank2WithNoPadding(const Shape& shape) { // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. -static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { +static bool AreValidGemmShapes( + const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { // The inputs and the output must // 1) be matrices with no padding, and // 2) have an allowed element type. 
PrimitiveType output_primitive_type = output_shape.element_type(); - return (output_primitive_type == F64 || output_primitive_type == F32 || - output_primitive_type == F16) && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); + if (!(output_primitive_type == F64 || output_primitive_type == F32 || + output_primitive_type == F16)) { + return false; + } + + if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape))) { + return false; + } + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || + !is_aligned(output_shape)) { + return false; + } + + return true; } -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features) { // For certain types of Dot, we can call Eigen if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -1093,28 +1599,18 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), + target_machine_features)) { + const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; } } - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 9d748eb81f7850..d88ccea0dbc845 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -31,7 +31,9 @@ limitations under the License. namespace xla { namespace cpu { -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column // major. Returns nullopt if there is no such operand or if `hlo` is not a dot @@ -55,17 +57,16 @@ class DotOpEmitter { // dimensions as the result, and the result is computed as `addend_array` + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported // for Matrix-vector products. 
- static tensorflow::Status EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + static Status EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); private: - DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, const llvm_ir::IrArray& target_array, + DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -75,18 +76,18 @@ class DotOpEmitter { const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. - tensorflow::Status Emit(); + Status Emit(); // Emits instructions to perform a scalar dot product (a multiply of the // LHS and RHS) and store the results in the target. - tensorflow::Status EmitScalarDot(); + Status EmitScalarDot(); // Emit an LLVM IR implementation of the dot operation if we can. Returns // true if an LLVM IR implementation was emitted. bool EmitLlvmIrDotIfProfitable(); // Emits a call to the CPU runtime to perform the matrix multiply. - tensorflow::Status EmitCallToRuntime(); + Status EmitCallToRuntime(); // Emits a series of nested loops for iterating over an operand array in the // dot operation. Loops are constructed in major to minor dimension layout @@ -99,10 +100,6 @@ class DotOpEmitter { llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix); - // Our runtime operation requires that all arrays have the same layout, - // no padding, and a rank of two. - bool ShapesAreLegalForRuntimeDot() const; - // Represents the dimensions of a matrix-matrix multiply operation. struct MatMultDims { // The number of rows in the LHS. @@ -115,11 +112,20 @@ class DotOpEmitter { // The number of columns on the RHS. int64 n; - // True if the LHS matrix column major. + // True if the LHS matrix is column major. bool lhs_column_major; - // True if the RHS matrix column major. + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; + + // True if the RHS matrix is column major. bool rhs_column_major; + + // True if the RHS contraction dimension is not 0. + bool rhs_non_canonical; + + // True if the result matrix is column major. + bool target_column_major; }; // Get the MatMultDims instance for the dot product this DotOpEmitter @@ -127,6 +133,8 @@ class DotOpEmitter { // of rank 2 as well). MatMultDims GetMatMultDims() const; + bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims); + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector // registers. int64 GetGemvTilingFactor() const { @@ -135,9 +143,18 @@ class DotOpEmitter { .value_or(kDefaultTilingFactor); } + // Returns true if we should use an experimental implementation of GEMM + // (general matrix matrix multiplication) if possible. 
+ bool EnableExperimentalLlvmIrGemm() const { + return options::EnableExperimentalLlvmIrGemm(hlo_module_config_); + } + + // Returns true if we should call into multi-threaded Eigen routines. + bool ShouldUseMultiThreadedEigen() { + return hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + } + const HloInstruction& dot_; - const bool transpose_lhs_; - const bool transpose_rhs_; const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc index 7dcc4ca7fa08b4..c5628655915875 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -26,13 +26,13 @@ limitations under the License. namespace xla { namespace cpu { -void ExternalConstantPool::Insert(string name, const Literal& literal, +void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, int64 alignment) { CHECK(!ShapeUtil::IsTuple(literal.shape())); CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); CHECK(entries_.find(name) == entries_.end()); - int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); void* raw_pointer = tensorflow::port::AlignedMalloc( literal_size, std::max(alignment, sizeof(void*))); CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index 8008a56df4dbf1..0677f5f0b58005 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -43,7 +43,7 @@ class ExternalConstantPool { // The constant pool copies out the contents of `literal` into a buffer it // owns -- it does not keep pointers to `literal`, or to memory owned by // `literal`. - void Insert(string name, const Literal& literal, int64 alignment); + void Insert(string name, const LiteralSlice& literal, int64 alignment); // Find the constant with name `name` in this constant pool. If there isn't // such constant, return nullptr. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index f209a69e3cd0f8..b560b7531c0d24 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -24,8 +24,25 @@ limitations under the License. namespace xla { namespace cpu { +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features) { + CHECK(ShapeUtil::IsArray(shape)); + CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout())); + + // We don't require a layout to be set on `shape`. This only works on CPU + // because we don't pad our tensors or otherwise have complicated data tiling + // schemes. 
+ + int64 allocation_size_bytes = + ShapeUtil::ElementsIn(shape) * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); + return target_machine_features.minimum_alignment_for_allocation( + allocation_size_bytes); +} + bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution) { + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features) { // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. @@ -35,6 +52,18 @@ bool PotentiallyImplementedAsEigenConvolution( // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); const Shape& kernel_shape = convolution.operand(1)->shape(); + const Shape& output_shape = convolution.shape(); + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(input_shape) || !is_aligned(kernel_shape) || + !is_aligned(output_shape)) { + return false; + } + if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; @@ -71,7 +100,6 @@ bool PotentiallyImplementedAsEigenConvolution( } } - const Shape& output_shape = convolution.shape(); return dnums.input_batch_dimension() == 0 && dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && dnums.output_batch_dimension() == 0 && diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index 34b2003916933f..68fbc7caaa9bfe 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -17,13 +17,20 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace cpu { bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution); + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features); + +// Computes the minimum alignment guaranteed for a tensor of shape `shape` on +// the target machine. +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features); // Dynamic loop bounds are specified as an array of dimension index // [start, limit) pairs of ir values (one for each partitioned outer dimension). diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc index 215f48c4cc1a1a..530ebce854fedf 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -15,8 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -34,12 +35,17 @@ ENTRY Conv { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* entry_computation = module->entry_computation(); HloInstruction* conv_instr = entry_computation->root_instruction(); - EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr)); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution( + *conv_instr, target_machine_features)); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 0b08ad8da3cf17..59223fddac2f5f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -83,7 +83,7 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine_features, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), @@ -94,7 +94,7 @@ IrEmitter::IrEmitter( alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), is_top_level_computation_(false), - target_machine_features_(target_machine), + target_machine_features_(*target_machine_features), external_constant_pool_(external_constant_pool) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() @@ -160,41 +160,59 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -Status IrEmitter::HandleConstant(HloInstruction* constant) { - VLOG(2) << "HandleConstant: " << constant->ToString(); - const Literal& literal = constant->literal(); - llvm::GlobalVariable* global_for_const; +llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { + llvm::Constant* result; // We avoid creating large constants in the LLVM IR since LLVM is not // efficient for large constant arrays. We still emit "small enough" constant // arrays into the Ir, in the off chance the LLVM optimizer can do something // interesting with it. + // + // TODO(b/29904935): Remove the large constant pool. 
const int kMaxInternalConstantSizeInBytes = 128; if (external_constant_pool_ && ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { string global_name = tensorflow::strings::StrCat( "constant_global_", external_global_constant_counter_++); - global_for_const = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/IrShapeType(literal.shape()), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, /*Name=*/AsStringRef(global_name)); - global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); external_constant_pool_->Insert(global_name, literal, MinimumAlignmentForShape(literal.shape())); + result = result_global; } else { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); - global_for_const = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/initializer->getType(), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/initializer, /*Name=*/""); - global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + result = llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); + } + return result; +} + +Status IrEmitter::HandleConstant(HloInstruction* constant) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + const Literal& literal = constant->literal(); + llvm::Constant* global_for_const; + + auto it = emitted_literals_.find(&literal); + if (it != emitted_literals_.end()) { + global_for_const = it->second; + } else { + global_for_const = EmitGlobalForLiteral(literal); + emitted_literals_[&literal] = global_for_const; } emitted_value_[constant] = global_for_const; VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); @@ -214,32 +232,6 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { } } -// Calculate the alignment of a buffer with a particular size. -int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { - // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on - // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for - // allocations smaller than kMallocAlignmentThreshold bytes and at least - // alignment 16 for allocations greater than or equal to - // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound - // by explicitly allocating the memory with posix_memalign. This is - // complicated by our desire to allow parameter buffers created by clients to - // be consumed directly by the JIT. - if (buffer_size == 0) { - // No need to align empty buffers. - return 1; - } - - const int64 kMallocAlignmentThreshold = 512; - - int pointer_size = module_->getDataLayout().getPointerSize(); - int buffer_alignment = buffer_size >= kMallocAlignmentThreshold - ? 2 * pointer_size - : pointer_size; - DCHECK_GT(buffer_alignment, 0); - - return buffer_alignment; -} - // Calculate the alignment of a buffer allocated for a given primitive type. 
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); @@ -264,7 +256,7 @@ int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); - return MinimumAlignmentForBufferSize(buffer_size); + return target_machine_features_.minimum_alignment_for_allocation(buffer_size); } void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, @@ -277,7 +269,8 @@ void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size) { - int alignment = MinimumAlignmentForBufferSize(buffer_size); + int alignment = + target_machine_features_.minimum_alignment_for_allocation(buffer_size); if (alignment > 1) { llvm_ir::SetAlignmentMetadataForLoad(load, alignment); } @@ -517,7 +510,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32, BF16})); + /*supported_types=*/{F32, BF16, S32})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -814,13 +807,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { "Dot with multiple contracting dimensions not implemented."); } - if (dnums.lhs_contracting_dimensions(0) != - std::min(lhs->shape().dimensions_size() - 1, 1) || - dnums.rhs_contracting_dimensions(0) != 0) { - return Unimplemented( - "Dot with non-standard contracting dimensions not implemented."); - } - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -837,8 +823,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, /*addend_array=*/nullptr, + *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, target_machine_features_); } @@ -854,7 +839,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - if (PotentiallyImplementedAsEigenConvolution(*convolution)) { + // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support + // different data layouts. + if (PotentiallyImplementedAsEigenConvolution(*convolution, + target_machine_features_)) { const Shape& lhs_shape = lhs->shape(); const Shape& rhs_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); @@ -942,16 +930,26 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - bool multi_threaded_eigen = + bool multi_threaded = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool use_mkl_dnn = + hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); + + // TODO(b/78639006) Singlethread MKL conv2d is not implemented due to the + // potential race condition by setting the omp_num_threads. const char* fn_name = primitive_type == F16 - ? (multi_threaded_eigen + ? (multi_threaded ? 
runtime::kEigenConvF16SymbolName : runtime::kEigenSingleThreadedConvF16SymbolName) - : (multi_threaded_eigen - ? runtime::kEigenConvF32SymbolName + : (multi_threaded + ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName + : runtime::kEigenConvF32SymbolName) : runtime::kEigenSingleThreadedConvF32SymbolName); + if (!multi_threaded && use_mkl_dnn) { + LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded " + "conv2d function."; + } llvm::Function* conv_func = llvm::cast( module_->getOrInsertFunction(fn_name, conv_type)); conv_func->setCallingConv(llvm::CallingConv::C); @@ -1010,12 +1008,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // We will accumulate the products into this sum to calculate // the output entry at the given index. PrimitiveType lhs_element_type = lhs->shape().element_type(); + llvm::Type* lhs_llvm_type = + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_), - "convolution_sum_address", &ir_builder_, + lhs_llvm_type, "convolution_sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(lhs_element_type)); - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address); + llvm::Value* constant_zero = + llvm::Constant::getNullValue(lhs_llvm_type); + ir_builder_.CreateStore(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_); std::vector kernel_spatial(num_spatial_dims); @@ -1169,7 +1169,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); @@ -1191,16 +1197,45 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { } Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { - if (hlo_module_config_.replica_count() == 1) { - // When there is a single replica, a cross replica sum is the identity - // function, and the buffer assignment expects a copy (we could eliminate - // these at the HLO level as an optimization). - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on CPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on CPU."); + } + + // When there is a single replica, a cross replica sum is the identity + // function, and the buffer assignment expects a copy. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + + // CRS with one operand and one replica is simply the identity function. + if (crs->operand_count() == 1) { return EmitMemcpy(*crs->operand(0), *crs); } - // TODO(b/33011107): Support cross replica sum on CPU. 
- return Unimplemented("CrossReplicaSum is not implemented on CPU."); + // CRS with multiple operands and one replica produces a (one-deep) tuple. + std::vector operand_ptrs; + for (int64 i = 0; i < crs->operand_count(); ++i) { + llvm::Value* in_ptr = GetEmittedValueFor(crs->operand(i)); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice, + assignment_.GetUniqueSlice(crs, {i})); + + const Shape& operand_shape = crs->operand(i)->shape(); + CHECK(ShapeUtil::IsArray(operand_shape)) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + + // TODO(b/63762267): Be more aggressive about specifying alignment. + ir_builder_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, + ShapeUtil::ByteSizeOf(operand_shape)); + } + llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &ir_builder_, module_); + return Status::OK(); } // Fills up the free variables in 'index_with_free_var' with values from @@ -2061,44 +2096,7 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) { Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); - if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { - DCHECK(root->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - fusion->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - fusion->operand(rhs_parameter->parameter_number()); - - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*root, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32, F64})); - - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - - Shape target_shape = fusion->shape(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array = GetIrArrayFor(fusion); - VLOG(2) << "HandleFusion kTransposeDot: "; - VLOG(2) << " lhs operand: " - << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); - VLOG(2) << " rhs operand: " - << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); - VLOG(2) << " target: " - << llvm_ir::DumpToString(*target_array.GetBasePointer()); - - // Dot operation is complicated so we delegate to a helper class. 
- TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *root, root->operand(0)->IsRank2Transpose(), - root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); - return Status::OK(); - } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, - assignment_)) { + if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); @@ -2141,9 +2139,9 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { GetIrArrayFor(fusion->operand(addend_param_number))); TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2538,8 +2536,12 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { // nothing to do since the result was already written directly into the output // buffer. VLOG(2) << "FinishVisit root: " << root->ToString(); - llvm::Value* root_value = GetEmittedValueFor(root); - VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); + if (root->opcode() == HloOpcode::kOutfeed) { + VLOG(2) << " outfeed with value: " + << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0))); + } else { + VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root)); + } auto record_complete_computation = [&](llvm::Value* prof_counter) { if (prof_counter) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 0f2f3d1817d6e8..32c536e18fee86 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -76,7 +76,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine, ExternalConstantPool* external_constant_pool); ~IrEmitter() override; @@ -514,9 +514,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Calculate the alignment of a buffer allocated for a given primitive type. int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type); - // Calculate the alignment of a buffer with a particular size. - int MinimumAlignmentForBufferSize(int64 buffer_size); - // Returns the number of bytes within the shape. int64 ByteSizeOf(const Shape& shape) const; @@ -530,15 +527,32 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); + // Returns a ConstExpr bitcast. 
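+  // Globals emitted here are cached in emitted_literals_ (declared below), so
+  // a literal that appears several times in a module is emitted only once; see
+  // HandleConstant in ir_emitter.cc for the cache lookup.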
+ llvm::Constant* EmitGlobalForLiteral(const Literal& literal); + const HloModuleConfig& hlo_module_config_; bool is_top_level_computation_; - TargetMachineFeatures target_machine_features_; + const TargetMachineFeatures& target_machine_features_; int64 external_global_constant_counter_ = 0; ExternalConstantPool* external_constant_pool_; + struct LiteralPtrHashFunctor { + size_t operator()(const Literal* literal) const { return literal->Hash(); } + }; + + struct LiteralPtrEqualityFunctor { + bool operator()(const Literal* lhs, const Literal* rhs) const { + return *lhs == *rhs; + } + }; + + tensorflow::gtl::FlatMap + emitted_literals_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 557aa4a6bfc2ef..2e55181eed867a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -33,8 +33,8 @@ namespace cpu { // emitters for function and function argument access. // The llvm::Function is created with the standard function signature // used in the XLA CPU backend (see ir_function.cc for argument details). -// In addtion IrFunction saves the callers IR insert point during contruction, -// and restores it after desctruction. +// In addition IrFunction saves the callers IR insert point during construction, +// and restores it after destruction. // // Example usage: // diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index fb28280fade307..4fa5984b0466b1 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -38,7 +38,7 @@ class SimpleCostModel : public ParallelCostModel { const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism_, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -63,7 +63,7 @@ class DefaultCostModel : public ParallelCostModel { int64 max_parallelism; // Calculate flops-to-bytes-ratio for 'instruction'. const int64 bytes_accessed = - std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction)); const float flops_to_bytes_ratio = cost_analysis_->flop_count(*instruction) / static_cast(bytes_accessed); @@ -93,7 +93,7 @@ class DefaultCostModel : public ParallelCostModel { } // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -104,7 +104,9 @@ class DefaultCostModel : public ParallelCostModel { ParallelTaskAssignment::ParallelTaskAssignment( const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { + const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module, + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. 
auto cost_analysis = MakeUnique(shape_size); @@ -127,7 +129,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Emit custom loops (kSelectAndScatter). // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. @@ -139,8 +141,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || (opcode == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) || + PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_) || (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || ShapeUtil::IsTuple(instruction->shape())) { @@ -231,7 +235,8 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( void ParallelTaskAssigner::ComputeTargetParallelTasks( HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { ParallelTaskAssignment parallel_task_assignment(max_parallelism_, - shape_size_function_, module); + shape_size_function_, module, + &target_machine_features_); // Compute parallel task counts for all instructions in 'module'. for (auto* computation : module->computations()) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 7140dabe516cd7..8becc8fa23424d 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -39,7 +40,8 @@ class ParallelTaskAssignment { // 'module': the containing HloModule. ParallelTaskAssignment(const int64 max_parallelism, const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module); + HloModule* module, + const TargetMachineFeatures* target_machine_features); ~ParallelTaskAssignment() {} // Computes and returns the target parallel task count for 'instruction'. @@ -47,6 +49,7 @@ class ParallelTaskAssignment { private: std::unique_ptr cost_model_; + const TargetMachineFeatures& target_machine_features_; }; // ParallelTaskAssigner computes target parallel task counts for all HLOs @@ -63,8 +66,11 @@ class ParallelTaskAssigner : public HloPassInterface { // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. 
ParallelTaskAssigner(const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size) - : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {} + const HloCostAnalysis::ShapeSizeFunction& shape_size, + const TargetMachineFeatures* target_machine_features) + : max_parallelism_(max_parallelism), + shape_size_function_(shape_size), + target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} tensorflow::StringPiece name() const override { @@ -94,6 +100,7 @@ class ParallelTaskAssigner : public HloPassInterface { int64 max_parallelism_; HloCostAnalysis::ShapeSizeFunction shape_size_function_; + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 13eb75a57213b1..fc2efbaf9a22b0 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,6 +32,19 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { // Use any value larger than 2 since we only test whether a module is // parallelized or not const int max_parallelism_ = 10; + + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; + + ParallelTaskAssignmentTest() + : target_machine_features_([](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }) {} + + StatusOr RunParallelTaskAssigner(HloModule* module) { + return cpu::ParallelTaskAssigner(max_parallelism_, shape_size_func_, + &target_machine_features_) + .Run(module); + } }; TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { @@ -45,9 +59,7 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -74,9 +86,7 @@ TEST_F(ParallelTaskAssignmentTest, )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -92,9 +102,7 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -108,9 +116,7 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, 
RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc new file mode 100644 index 00000000000000..c60580d6e763c6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" +#include +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int64; + +#ifdef INTEL_MKL +#include +#include "mkldnn.hpp" +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" + +namespace { + +// Downcast an int64 to int and check if value is in range. +int ToInt(int64 input) { + int output = static_cast(input); + if (static_cast(output) != input) { + std::cerr << "Error occurred in downcasting int64 to int32: Value " << input + << " is out-of-range for type int32. \n"; + exit(1); + } + return output; +} + +using mkldnn::convolution_direct; +using mkldnn::convolution_forward; +using mkldnn::engine; +using mkldnn::memory; +using mkldnn::padding_kind; +using mkldnn::primitive; +using mkldnn::prop_kind; +using mkldnn::reorder; +using mkldnn::stream; + +template +void MKLConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, int64 input_batch, int64 input_rows, + int64 input_cols, int64 input_channels, int64 kernel_rows, + int64 kernel_cols, int64 kernel_channels, int64 kernel_filters, + int64 output_rows, int64 output_cols, int64 row_stride, + int64 col_stride, int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, + int64 lhs_row_dilation, int64 lhs_col_dilation, + int64 rhs_row_dilation, int64 rhs_col_dilation) { + auto cpu_engine = engine(engine::cpu, 0); + + // Create a vector primitive to hold the network. + std::vector net; + + // Since memory::dims takes int for each dimension, we downcast the int64 + // values to int using the ToInt function defined above. + memory::dims conv1_src_dim = {ToInt(input_batch), ToInt(input_channels), + ToInt(input_rows), ToInt(input_cols)}; + memory::dims conv1_weights_dim = {ToInt(kernel_filters), + ToInt(kernel_channels), ToInt(kernel_rows), + ToInt(kernel_cols)}; + memory::dims conv1_dst_dim = {ToInt(input_batch), ToInt(kernel_filters), + ToInt(output_rows), ToInt(output_cols)}; + memory::dims conv1_strides = {ToInt(row_stride), ToInt(col_stride)}; + // Note: In MKL_DNN dilation starts from 0. + memory::dims conv1_dilates = {ToInt(rhs_row_dilation - 1), + ToInt(rhs_col_dilation - 1)}; + memory::dims conv1_padding_l = {ToInt(padding_top), ToInt(padding_left)}; + memory::dims conv1_padding_r = {ToInt(padding_bottom), ToInt(padding_right)}; + + // Create memory for user data. 
Input and output data have format of NHWC and + // kernel data has format of HWIO. + // Note that as a convention in MKL-DNN, the dimensions of the data is always + // described in NCHW/IOHW, regardless of the actual layout of the data. + auto user_src_memory = + memory({{{conv1_src_dim}, memory::data_type::f32, memory::format::nhwc}, + cpu_engine}, + lhs); + auto user_weights_memory = memory( + {{{conv1_weights_dim}, memory::data_type::f32, memory::format::hwio}, + cpu_engine}, + rhs); + auto user_dst_memory = + memory({{{conv1_dst_dim}, memory::data_type::f32, memory::format::nhwc}, + cpu_engine}, + out); + + // Create memory descriptors for convolution data with no specified format for + // best performance. + auto conv1_src_mem_desc = memory::desc( + {conv1_src_dim}, memory::data_type::f32, memory::format::any); + auto conv1_weights_mem_desc = memory::desc( + {conv1_weights_dim}, memory::data_type::f32, memory::format::any); + auto conv1_dst_mem_desc = memory::desc( + {conv1_dst_dim}, memory::data_type::f32, memory::format::any); + + // Create a convolution. + auto conv1_desc = convolution_forward::desc( + prop_kind::forward_inference, convolution_direct, conv1_src_mem_desc, + conv1_weights_mem_desc, conv1_dst_mem_desc, conv1_strides, conv1_dilates, + conv1_padding_l, conv1_padding_r, padding_kind::zero); + auto conv1_prim_desc = + convolution_forward::primitive_desc(conv1_desc, cpu_engine); + + // Create reorders for data and weights if layout requested by convolution is + // different from NCHW/OIHW. + auto conv1_src_memory = user_src_memory; + if (memory::primitive_desc(conv1_prim_desc.src_primitive_desc()) != + user_src_memory.get_primitive_desc()) { + conv1_src_memory = memory(conv1_prim_desc.src_primitive_desc()); + net.push_back(reorder(user_src_memory, conv1_src_memory)); + } + + auto conv1_weights_memory = user_weights_memory; + if (memory::primitive_desc(conv1_prim_desc.weights_primitive_desc()) != + user_weights_memory.get_primitive_desc()) { + conv1_weights_memory = memory(conv1_prim_desc.weights_primitive_desc()); + net.push_back(reorder(user_weights_memory, conv1_weights_memory)); + } + + // Check if output need layout conversion. If yes, create memory for + // intermediate layer of conv1_dst_memory. + bool need_output_conversion = + (memory::primitive_desc(conv1_prim_desc.dst_primitive_desc()) != + user_dst_memory.get_primitive_desc()); + auto conv1_dst_memory = need_output_conversion + ? memory(conv1_prim_desc.dst_primitive_desc()) + : user_dst_memory; + + // Create convolution primitive and add it to net. + net.push_back(convolution_forward(conv1_prim_desc, conv1_src_memory, + conv1_weights_memory, conv1_dst_memory)); + if (need_output_conversion) { + net.push_back(reorder(conv1_dst_memory, user_dst_memory)); + } + stream(stream::kind::eager).submit(net).wait(); +} +} // namespace +#endif // INTEL_MKL + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLConvF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, + int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels, + int64 kernel_rows, int64 kernel_cols, int64 kernel_channels, + int64 kernel_filters, int64 output_rows, int64 output_cols, + int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, int64 lhs_row_dilation, + int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { +#ifdef INTEL_MKL + // Since MKL_DNN cannot handle transposed convolution, this is handled by + // Eigen. 
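+  // An lhs (input) dilation greater than one is what a transposed convolution
+  // lowers to, and MKLConvImpl above only forwards the rhs (kernel) dilation
+  // to the MKL-DNN primitive (via conv1_dilates), so those calls fall through
+  // to the multi-threaded Eigen kernel with the same argument list.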
+ if (lhs_row_dilation > 1 || lhs_col_dilation > 1) { + __xla_cpu_runtime_EigenConvF32( + run_options_ptr, out, lhs, rhs, input_batch, input_rows, input_cols, + input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, col_stride, + padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); + } else { + MKLConvImpl(nullptr, out, lhs, rhs, input_batch, input_rows, input_cols, + input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, + col_stride, padding_top, padding_bottom, padding_left, + padding_right, lhs_row_dilation, lhs_col_dilation, + rhs_row_dilation, rhs_col_dilation); + } +#else + std::cerr << "Attempt to call MKL Conv2D runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +#endif // INTEL_MKL +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h new file mode 100644 index 00000000000000..b239e71d231c52 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ + +#include +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_MKLConvF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 input_batch, + tensorflow::int64 input_rows, tensorflow::int64 input_cols, + tensorflow::int64 input_channels, tensorflow::int64 kernel_rows, + tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels, + tensorflow::int64 kernel_filters, tensorflow::int64 output_rows, + tensorflow::int64 output_cols, tensorflow::int64 row_stride, + tensorflow::int64 col_stride, tensorflow::int64 padding_top, + tensorflow::int64 padding_bottom, tensorflow::int64 padding_left, + tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation, + tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, + tensorflow::int64 rhs_col_dilation); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e0247..0bf693edd0b985 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -204,7 +196,8 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; + // Unsupported FFT type + abort(); } } @@ -230,7 +223,8 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + // Unsupported FFT rank + abort(); } } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 00000000000000..2613ddb12704ae --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 00000000000000..dcd133d012cf07 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index b3f4609d465efb..167aa4adda995a 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -19,10 +19,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -48,13 +48,13 @@ int main(int argc, char** argv) { client->TransferToServer(*param1_literal).ConsumeValueOrDie(); // Build computation. - xla::ComputationBuilder builder(client, ""); + xla::XlaBuilder builder(""); auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto add = builder.Add(p1, p0, {0}); - xla::StatusOr computation_status = builder.Build(); - xla::Computation computation = computation_status.ConsumeValueOrDie(); + xla::StatusOr computation_status = builder.Build(); + xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); // Execute and transfer result of computation. xla::ExecutionProfile profile; diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 42fe955f1917e0..d12c5396148d32 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -115,7 +115,7 @@ ShapePartitionIterator::ShapePartitionIterator( for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { const int64 dim_size = shape_.dimensions(dimensions_[i]); dimension_partition_sizes_[i] = - std::max(1LL, dim_size / dimension_partition_counts_[i]); + std::max(int64{1}, dim_size / dimension_partition_counts_[i]); } // Calculate the partition strides for each dimension. diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h index 33d02b70e61e33..db2cda2936c834 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.h +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h @@ -38,7 +38,7 @@ namespace cpu { // // [0, 1), [1, 2), [2, 3), [3, 4), [4, 5) [5, 8) // -// Note that the last partition has residule because the dimension size is +// Note that the last partition has residual because the dimension size is // not a multiple of the partition count. // // diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index b7ce5bbe474823..c4c90515ac7ec2 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -31,12 +31,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -72,23 +74,33 @@ llvm::StringRef GetHostCpuName() { } } // namespace +/*static*/ std::unique_ptr +SimpleOrcJIT::InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level) { + std::unique_ptr target_machine( + llvm::EngineBuilder() + .setTargetOptions(target_options) + .setOptLevel(opt_level) + .selectTarget( + /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", + /*MCPU=*/GetHostCpuName(), + /*MAttrs=*/DetectMachineAttributes())); + CHECK(target_machine != nullptr); + return target_machine; +} + SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook) - : target_machine_( - CHECK_NOTNULL(llvm::EngineBuilder() - .setTargetOptions(target_options) - .setOptLevel(opt_level) - .selectTarget( - /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/GetHostCpuName(), - /*MAttrs=*/DetectMachineAttributes()))), + : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), symbol_resolver_(llvm::orc::createLegacyLookupResolver( + execution_session_, [this](const std::string& name) -> llvm::JITSymbol { return this->ResolveRuntimeSymbol(name); }, @@ -178,6 +190,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); @@ -190,6 +203,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index f4260a95bc4555..1851a3ee0bb97b 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -95,6 +95,12 @@ class SimpleOrcJIT { return &external_constant_pool_; } + // Creates an 
llvm::TargetMachine suitable for JITting code that will run on + // the current machine. + static std::unique_ptr InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level); + private: llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index eeb049737dddd1..a0cd8ee2d2be10 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { namespace cpu { -llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( +llvm::TargetTransformInfo* LLVMTargetMachineFeatures::GetTargetTransformInfoFor( const llvm::Function& function) const { auto it = target_transform_info_cache_.find(&function); if (it == target_transform_info_cache_.end()) { @@ -31,5 +31,30 @@ llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( return &it->second; } +int64 LLVMTargetMachineFeatures::minimum_alignment_for_allocation( + int64 size_bytes) const { + // GLibc malloc returns a pointer with alignment 8 on 32-bit platforms and 16 + // on 64-bit platforms. TCMalloc returns a pointer with alignment 8 for + // allocations smaller than kMallocAlignmentThreshold bytes and at least + // alignment 16 for allocations greater than or equal to + // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound + // by explicitly allocating the memory with posix_memalign. This is + // complicated by our desire to allow parameter buffers created by clients to + // be consumed directly by the JIT. + if (size_bytes == 0) { + // No need to align empty buffers. + return 1; + } + + const int64 kMallocAlignmentThreshold = 512; + + int pointer_size = target_machine_->getPointerSize(0); + int buffer_alignment = + size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size : pointer_size; + DCHECK_GT(buffer_alignment, 0); + + return buffer_alignment; +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 703942615e552d..8b00ae9e47eeed 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -24,43 +24,68 @@ limitations under the License. namespace xla { namespace cpu { -// Wraps an llvm::TargetMachine and parses out some information that feeds into -// LLVM IR code generation decisions. +// Abstract interface for classes providing information about the target we're +// compiling for. class TargetMachineFeatures { public: static constexpr int kX86AvxVectorByteSize = 32; - TargetMachineFeatures(llvm::TargetMachine* target_machine) - : target_machine_(target_machine) {} + // Input and output tensor buffers must be aligned to this many bytes if we + // want to call an Eigen backed GEMM or Convolution. + static constexpr int kEigenExpectedTensorAlignment = 16; // Return the vectorization factor, which is the number of bytes of data // explicitly vectorized routines will try to process at once. - int vectorization_factor_in_bytes() const { - // Ideally this should be a function of the cache line size (which we can - // get from llvm::TargetTransformInfo::getCacheLineSize) of the target - // machine. Guess a value of 128 bytes for now. 
- return 128; - } + virtual int vectorization_factor_in_bytes() const = 0; // Return the size of the largest vector size in bytes. We need to pass in // "function" since llvm functions can contain annotations for specializing // them to specific micro-architectures (though currently XLA does not use // this functionality). - int vector_register_byte_size(const llvm::Function& function) const { - llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); - return tti->getRegisterBitWidth(/*Vector=*/true) / 8; - } + virtual int vector_register_byte_size( + const llvm::Function& function) const = 0; // Return the number of elements of type `type` that can fit into the largest // vector register available. We need to pass in "function" since llvm // functions can contain annotations for specializing them to specific // micro-architectures (though currently XLA does not use this functionality). + virtual int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const = 0; + + // Returns the minimum alignment for a buffer of size size_bytes. + virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0; + + virtual ~TargetMachineFeatures() = default; +}; + +// Implements the TargetMachineFeatures interface using an llvm::TargetMachine. +class LLVMTargetMachineFeatures : public TargetMachineFeatures { + public: + static constexpr int kX86AvxVectorByteSize = 32; + + LLVMTargetMachineFeatures(llvm::TargetMachine* target_machine) + : target_machine_(target_machine) {} + + int vectorization_factor_in_bytes() const override { + // Ideally this should be a function of the cache line size (which we can + // get from llvm::TargetTransformInfo::getCacheLineSize) of the target + // machine. Guess a value of 128 bytes for now. + return 128; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); + return tti->getRegisterBitWidth(/*Vector=*/true) / 8; + } + int vector_register_num_elements(const llvm::Function& function, - PrimitiveType type) const { + PrimitiveType type) const override { return vector_register_byte_size(function) / (primitive_util::BitWidth(type) / 8); } + int64 minimum_alignment_for_allocation(int64 size_bytes) const override; + private: llvm::TargetTransformInfo* GetTargetTransformInfoFor( const llvm::Function& function) const; diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h new file mode 100644 index 00000000000000..ffc6927cbe1a2b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" + +namespace xla { +namespace cpu { +// Delegates calls to minimum_alignment_for_allocation to a user provided +// std::function, crashes on all other methods. +// +// Primarily useful for testing. +class TargetMachineFeaturesWithFakeAlignmentLogic + : public TargetMachineFeatures { + public: + explicit TargetMachineFeaturesWithFakeAlignmentLogic( + std::function fake_alignment_logic) + : fake_alignment_logic_(std::move(fake_alignment_logic)) {} + + int vectorization_factor_in_bytes() const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int64 minimum_alignment_for_allocation(int64 size_bytes) const override { + return fake_alignment_logic_(size_bytes); + } + + private: + std::function fake_alignment_logic_; +}; +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD new file mode 100644 index 00000000000000..66ae5ef0f66e90 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -0,0 +1,176 @@ +# Description: +# Tests for LLVM-based CPU backend for XLA. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "cpu_codegen_test", + testonly = True, + hdrs = ["cpu_codegen_test.h"], + deps = [ + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_fusion_test", + srcs = ["cpu_fusion_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_bytesizeof_test", + srcs = ["cpu_bytesizeof_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_external_constants_test", + srcs = ["cpu_external_constants_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + ], +) + +tf_cc_test( + name = "cpu_noalias_test", + srcs = ["cpu_noalias_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm//:core", + ], +) + +tf_cc_test( + name = "cpu_intrinsic_test", + srcs = ["cpu_intrinsic_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_eigen_dot_operation_test", + srcs = ["cpu_eigen_dot_operation_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_infeed_test", + srcs = ["cpu_infeed_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + 
"//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_literal_caching_test", + srcs = ["cpu_literal_caching_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_outfeed_test", + srcs = ["cpu_outfeed_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc new file mode 100644 index 00000000000000..d5bbe7677ace67 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/test.h" + +class CpuByteSizeOfTest : public ::testing::Test {}; + +TEST_F(CpuByteSizeOfTest, ARM32) { + llvm::DataLayout data_layout( + "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"); + auto tuple_shape = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {})}); + EXPECT_EQ(xla::llvm_ir::ByteSizeOf(tuple_shape, data_layout), + data_layout.getPointerSize(0 /* default address space */)); +} + +TEST_F(CpuByteSizeOfTest, ARM64) { + llvm::DataLayout data_layout("e-m:e-i64:64-i128:128-n32:64-S128"); + auto tuple_shape = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {})}); + EXPECT_EQ(xla::llvm_ir::ByteSizeOf(tuple_shape, data_layout), + data_layout.getPointerSize(0 /* default address space */)); +} diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h new file mode 100644 index 00000000000000..7c8d07a10baf55 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h @@ -0,0 +1,30 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ + +#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h" + +namespace xla { +namespace cpu { + +// Tests that verify IR emitted by the CPU backend is as expected. +class CpuCodegenTest : public LLVMIRGenTestBase {}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc new file mode 100644 index 00000000000000..6fcce42eaa4599 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that we call into Eigen for dot operations as needed. 
+ +#include +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +struct DotTestSpec { + PrimitiveType primitive_type; + string filecheck_lines; +}; + +string DotTestSpecToString(const ::testing::TestParamInfo& info) { + return PrimitiveType_Name(info.param.primitive_type); +} + +class CpuEigenDotOperationTest + : public CpuCodegenTest, + public ::testing::WithParamInterface { + protected: + void CompileAndCheck(std::unique_ptr entry_computation, + const string& filecheck_lines) { + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(entry_computation)); + + CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, + filecheck_lines, + /*match_optimized_ir=*/true); + } +}; + +TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { + HloComputation::Builder builder(TestName()); + DotTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(spec.primitive_type, {128, 128}); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "input")); + + builder.AddInstruction( + HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + CompileAndCheck(builder.Build(), spec.filecheck_lines); +} + +TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { + HloComputation::Builder builder(TestName()); + DotTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(spec.primitive_type, {128, 128}); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "input")); + HloInstruction* lhs_transposed = builder.AddInstruction( + HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); + + builder.AddInstruction( + HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + CompileAndCheck(builder.Build(), spec.filecheck_lines); +} + +std::vector GetDotTestCases() { + std::vector result; + result.push_back( + {F16, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF16)"}); + result.push_back( + {F32, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF32)"}); + result.push_back( + {F64, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF64)"}); + return result; +} + +INSTANTIATE_TEST_CASE_P(CpuEigenDotOperationTestInstantiation, + CpuEigenDotOperationTest, + ::testing::ValuesIn(GetDotTestCases()), + DotTestSpecToString); + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc new file mode 100644 index 00000000000000..faac927027c48e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { +class CpuExternalConstantsTest : public CpuCodegenTest { + public: + void TestWithArray(int64 rows, int64 cols, const char* filecheck_pattern) { + HloComputation::Builder builder(TestName()); + + Array2D backing_array(rows, cols); + backing_array.FillUnique(); + + auto shape = ShapeUtil::MakeShape(F32, {rows, cols}); + + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2FromArray2D(backing_array))); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); + + std::unique_ptr module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CompileAndVerifyIr(std::move(module), filecheck_pattern, + /*match_optimized_ir=*/false); + } +}; + +TEST_F(CpuExternalConstantsTest, Basic) { + TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( +CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +)"); +} + +TEST_F(CpuExternalConstantsTest, BasicNegative) { + // The constant array in this test case is small enough that there is no need + // to externalize it. + TestWithArray(/*rows=*/4, /*cols=*/4, R"( +CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 +CHECK: @0 = private constant [16 x float] {{.*}}, align 8 +)"); +} +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc new file mode 100644 index 00000000000000..23e7a3de4d8188 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +class CpuFusionTest : public HloTestBase { + protected: + CpuFusionTest() {} + + ErrorSpec error_spec_{0.0001, 1e-5}; +}; + +TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { + auto builder = HloComputation::Builder(TestName()); + auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + Shape vshape = input_literal1->shape(); + + auto input1 = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal1))); + auto input2 = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal2))); + + auto add1 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kAdd, input1, input2)); + builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, add1)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + auto fusion_instruction = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + EXPECT_EQ(HloOpcode::kNegate, + fusion_instruction->fused_expression_root()->opcode()); + // There should be four fused instructions: 2 parameters, the add, and the + // negate. + EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. 
+ LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); +} + +TEST_F(CpuFusionTest, FuseElementwiseOpChain) { + auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + Shape vshape = input_literal->shape(); + + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil)); + auto floor = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp)); + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + auto fusion_instruction = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + EXPECT_EQ(HloOpcode::kMultiply, + fusion_instruction->fused_expression_root()->opcode()); + // There should be 7 fused instructions: 2 parameters and the fused + // operations. + EXPECT_EQ(7, fusion_instruction->fused_instruction_count()); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, + error_spec_); +} + +TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { + // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the + // middle. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + Shape vshape = input_literal->shape(); + + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + + auto cshape = ShapeUtil::MakeShape(F32, {6}); + auto concatenate = builder.AddInstruction( + HloInstruction::CreateConcatenate(cshape, {ceil, ceil}, /*dimension=*/0)); + + // Build an x+y computation to use in a reduce. + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto embedded_builder = HloComputation::Builder("f32+f32"); + embedded_builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, + embedded_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "x")), + embedded_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "y")))); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); + + // This is a nop reduction. 
+ auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + cshape, + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1}), concatenate)), + /*init_value=*/ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{1}, add_f32)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce)); + auto floor = builder.AddInstruction( + HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp)); + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); + + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + + auto fusion_instruction1 = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); + EXPECT_EQ(HloOpcode::kMultiply, + fusion_instruction1->fused_expression_root()->opcode()); + // There should be 5 fused instructions in the root fusion instruction: 2 + // parameters, multiply, floor, and exp. + EXPECT_EQ(5, fusion_instruction1->fused_instruction_count()) + << fusion_instruction1->fused_instructions_computation()->ToString(); + + auto fusion_instruction2 = reduce->operand(0); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); + EXPECT_EQ(HloOpcode::kReshape, + fusion_instruction2->fused_expression_root()->opcode()); + // There should be 5 fused instructions in the second fusion instruction: 1 + // parameter, negate, ceil, concat, and reshape. + EXPECT_EQ(5, fusion_instruction2->fused_instruction_count()) + << fusion_instruction2->fused_instructions_computation()->ToString(); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, + *result, error_spec_); +} + +TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { + // Test that the operands of an instruction to be fused are considered in the + // proper order to avoid duplication. Test input: + // + // constant = {...} + // negate = neg(constant) + // ceil = ceil(negate) + // add1 = add(negate, ceil) + // add2 = add(ceil, negate) + // + // In this example, the operands of both add1 and add2 should be fused in the + // order {ceil, negate} even though they have different orders in their + // operand vectors. Test for this problem by counting the number of nodes in + // each fusion instruction to ensure that negate is not duplicated. 
+ auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({1.0, 2.0, 3.0}); + Shape vshape = input_literal->shape(); + + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, constant)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + + auto add1 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, negate, ceil)); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, ceil, negate)); + + // Tie together the two adds with a tuple to create a single root. + auto result = + builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); + + // Create computation and module. + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + // Run fusion. + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + auto fusion1 = result->operand(0); + auto fusion2 = result->operand(1); + EXPECT_EQ(HloOpcode::kFusion, fusion1->opcode()); + EXPECT_EQ(HloOpcode::kFusion, fusion2->opcode()); + + // Each fusion instruction should have 4 fused instruction inside: add, ceil, + // negate, and the fused parameter. + EXPECT_EQ(4, fusion1->fused_instruction_count()); + EXPECT_EQ(4, fusion2->fused_instruction_count()); + + // Each fusion instruction should have one parameter and the parameter should + // be the constant. + EXPECT_EQ(1, fusion1->operand_count()); + EXPECT_EQ(constant, fusion1->operand(0)); + EXPECT_EQ(1, fusion2->operand_count()); + EXPECT_EQ(constant, fusion2->operand(0)); +} + +TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { + // Verify that expensive operations will not be fused if the fusion results in + // duplication. Test code: + // + // constant = 42.0 + // exp1 = exp(constant) + // negate1 = negate(exp1) + // exp2 = exp(constant) + // negate2 = negate(exp2) + // tuple = tuple(negate1, negate2, exp2) + // + // exp1 should be fused down into negate1, but exp2 will not be fused into + // negate2 because this will result in duplication of the expensive exp + // computation. The duplication is caused by the other use of exp2 in the + // tuple. + auto builder = HloComputation::Builder(TestName()); + auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + Shape shape = constant->shape(); + + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp1)); + + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant)); + auto negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp2)); + + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({negate1, negate2, exp2})); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The only fusion instruction should be operand 0 of the tuple (formerly + // negate1). 
+ EXPECT_EQ(HloOpcode::kFusion, tuple->operand(0)->opcode()); + EXPECT_EQ(HloOpcode::kNegate, tuple->operand(1)->opcode()); + EXPECT_EQ(HloOpcode::kExp, tuple->operand(2)->opcode()); + + auto fusion_inst = tuple->operand(0); + // There should be three fused instructions: negate2, exp2, and the fused + // parameter. + EXPECT_EQ(3, fusion_inst->fused_instruction_count()); + EXPECT_EQ(1, fusion_inst->operand_count()); + EXPECT_EQ(constant, fusion_inst->operand(0)); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc new file mode 100644 index 00000000000000..dd63b998e9b6d0 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -0,0 +1,294 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class InfeedTest : public ClientLibraryTestBase { + protected: + // Transfers the given literal to the infeed interface of the device, and + // check if the returned data from Infeed HLO is same as the literal. + void TestInfeedRoundTrip(const Literal& literal) { + // TODO(b/31037751) Explicitly reset the Infeed state so that the + // test is not affected by the state from the previous tests by + // adding ClearInfeed if necessary when it is implemented. For now + // don't use ResetDevice since it is not implemented on CPU. + ASSERT_IS_OK(client_->TransferToInfeed(literal)); + XlaBuilder builder(TestName()); + builder.Infeed(literal.shape()); + if (ShapeUtil::IsTuple(literal.shape())) { + // TODO(b/30609564): Use ComputeAndCompareLiteral instead. 
+ ComputeAndCompareTuple(&builder, literal, {}); + } else { + ComputeAndCompareLiteral(&builder, literal, {}); + } + } +}; + +TEST_F(InfeedTest, SingleInfeedR0Bool) { + TestInfeedRoundTrip(*Literal::CreateR0(true)); +} + +TEST_F(InfeedTest, SingleInfeedR1U32) { + TestInfeedRoundTrip(*Literal::CreateR1({1, 2, 3})); +} + +TEST_F(InfeedTest, SingleInfeedR2F32) { + TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64)); +} + +TEST_F(InfeedTest, SingleInfeedR3F32) { + TestInfeedRoundTrip( + *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); +} + +TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { + const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); + const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); + + TestInfeedRoundTrip( + *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); + + TestInfeedRoundTrip( + *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); +} + +TEST_F(InfeedTest, SingleInfeedR4S32) { + TestInfeedRoundTrip(*Literal::CreateR4( + {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, + {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); +} + +TEST_F(InfeedTest, SingleInfeedTuple) { + TestInfeedRoundTrip( + *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR0(false).get()})); +} + +TEST_F(InfeedTest, SingleInfeedEmptyTuple) { + TestInfeedRoundTrip(*Literal::MakeTuple({})); +} + +// Tests Infeed operation used in a while loop, as in the code below. The +// computation is launched asynchronously, and then infeed data is transferred. +// +// float acc = 0.0f; +// while (acc < 40.0f) { +// acc += reduce_add(Infeed()); +// } +// return acc; +// TODO(b/30671675) enable this test once asynchronous execution is +// implemented for CPU. +TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { + XlaBuilder builder(TestName()); + const auto infeed_shape = ShapeUtil::MakeShape(F32, {3}); + const auto result_shape = ShapeUtil::MakeShape(F32, {}); + + // Create a computation for the condition: repeat until (prev < 40.0f) holds. + XlaComputation condition; + { + XlaBuilder builder("condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0(40.0f), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + // Create a computation for the body: add the reduced value of the Infeed + // data to the result variable. + XlaComputation body; + { + XlaBuilder builder("body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto infeed = builder.Infeed(infeed_shape); + auto addend = + builder.Reduce(infeed, builder.ConstantR0(0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + builder.Add(prev, addend); + body = builder.Build().ConsumeValueOrDie(); + } + // Create a While node with computations for the condition and the body. + auto init = builder.ConstantR0(0.0f); + builder.While(condition, body, init); + + // Build and asynchronously launch the computation. + auto computation = builder.Build().ConsumeValueOrDie(); + std::unique_ptr result; + tensorflow::Thread* computation_thread = + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions{}, "computation_thread", [&] { + result = client_->Execute(computation, {}, &execution_options_) + .ValueOrDie(); + }); + + // Send 5 Infeed data of shape F32[3]. 
+ ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({1, 2, 3}))); + ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({4, 5, 6}))); + ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({7, 8, 9}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*Literal::CreateR1({10, 11, 12}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*Literal::CreateR1({13, 14, 15}))); + + delete computation_thread; // Joins the thread. + auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); + + // Only the first 3 infeed data should be added. + LiteralTestUtil::ExpectR0Near(45.0f, *result_literal, ErrorSpec{1e-7}); +} + +// Tests two Infeed operations with a total order. The order is enforced by +// using the result of the first while loop as the initial value of the second +// while loop. The shapes of both Infeeds are Tuples, where the first tuple +// element (R1F32) is for the data to reduce and accumulate, and the second +// tuple element (PRED) to indicate whether the loop should continue. The +// computation is launched asynchronously, and then infeed data is transferred. +// +// float acc = 0.0f; +// continue = true; +// while (!continue) { +// (data, continue) = Infeed(shape1); +// acc += reduce_add(data) +// } +// continue = true; +// while(!continue) { +// (data, continue) = Infeed(shape2); +// acc += reduce_add(data) +// } +// return acc; +// TODO(b/30671675) enable this test once asynchronous execution is +// implemented for CPU. +TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { + XlaBuilder builder(TestName()); + const auto infeed1_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(PRED, {})}); + const auto infeed2_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(PRED, {})}); + const auto result_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(PRED, {})}); + + // Create a computation for the condition: repeat until the second tuple + // element is false. + XlaComputation condition; + { + XlaBuilder builder("condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.GetTupleElement(prev, 1); + condition = builder.Build().ConsumeValueOrDie(); + } + + // A lambda that builds the body computation of a while loop with the given + // infeed shape, and returns the computation with the ownership. + // + // The body adds the reduced value of the Infeed data (first tuple element) + // to the previous accumulator, and returns the accumulator and the continue + // flag (second tuple element) as a tuple. + const auto build_body = [this, &result_shape](const Shape& infeed_shape) { + XlaComputation body; + XlaBuilder builder("body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto infeed = builder.Infeed(infeed_shape); + auto addend = builder.Reduce( + builder.GetTupleElement(infeed, 0), builder.ConstantR0(0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + auto result = builder.Add(builder.GetTupleElement(prev, 0), addend); + builder.Tuple({result, builder.GetTupleElement(infeed, 1)}); + return builder.Build().ConsumeValueOrDie(); + }; + + // Create the first while loop with infeed1_shape. + auto init = builder.Tuple( + {builder.ConstantR0(0.0f), builder.ConstantR0(true)}); + auto while1 = builder.While(condition, build_body(infeed1_shape), init); + auto result1 = builder.Tuple( + {builder.GetTupleElement(while1, 0), builder.ConstantR0(true)}); + + // Create the second while loop with infeed2_shape. 
Note that the result from + // the first while loop is used as the initial value. + auto while2 = builder.While(condition, build_body(infeed2_shape), result1); + builder.GetTupleElement(while2, 0); + + // Build the computation. + auto computation = builder.Build().ConsumeValueOrDie(); + + // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({3, 4}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({5, 6}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({7, 8}).get(), + Literal::CreateR0(false).get()}))); + + // Asynchronously launch the execution on the device. + std::unique_ptr result; + tensorflow::Thread* computation_thread = + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions{}, "computation_thread", [&] { + result = client_->Execute(computation, {}, &execution_options_) + .ValueOrDie(); + }); + + // Wait for a second to ensure testing that the execution is waiting on the + // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). + sleep(1); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({7, 8, 9}).get(), + Literal::CreateR0(false).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({4, 5, 6}).get(), + Literal::CreateR0(true).get()}))); + + // Wait for the execution to be done, and transfer the result. + delete computation_thread; // Joins the thread. + auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); + + // Only the first 6 infeed data should be added. + LiteralTestUtil::ExpectR0Near(66.0f, *result_literal, ErrorSpec{1e-7}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc new file mode 100644 index 00000000000000..973aac8766f5aa --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +const char* const kTriple_x86_64 = "x86_64-pc-linux"; +const char* const kTriple_android_arm = "armv7-none-android"; + +struct IntrinsicTestSpec { + HloOpcode opcode; + tensorflow::StringPiece triple; + tensorflow::StringPiece features; + tensorflow::StringPiece check_lines; +}; + +// Tests that unary functions get lowered using intrinsic calls. +class CpuUnaryIntrinsicTest + : public CpuCodegenTest, + public ::testing::WithParamInterface { + public: + static string Name(const ::testing::TestParamInfo& info) { + auto spec = info.param; + + string opcode = HloOpcodeString(spec.opcode); + opcode[0] = toupper(opcode[0]); + + string triple{spec.triple.data(), spec.triple.size()}; + if (triple == kTriple_x86_64) { + triple = "x86_64"; + } else if (triple == kTriple_android_arm) { + triple = "android_arm"; + } else { + triple = "Unknown"; + } + + string features{spec.features.data(), spec.features.size()}; + if (!features.empty()) { + std::replace_if(features.begin(), features.end(), + [](char c) { return c != '_' && !isalnum(c); }, '_'); + } else { + features = ""; + } + + return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(), + features.empty() ? "" : "_With", + features.c_str()); + } +}; + +// Creates a module with a call to the unary op, and tests if the +// compiler replaced it with a call to the intrinsic. +TEST_P(CpuUnaryIntrinsicTest, DoIt) { + HloComputation::Builder builder(TestName()); + IntrinsicTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(F32, {1024}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + builder.AddInstruction( + HloInstruction::CreateUnary(param_shape, spec.opcode, param)); + std::unique_ptr computation = builder.Build(); + + string triple{spec.triple.data(), spec.triple.size()}; + string features{spec.features.data(), spec.features.size()}; + + CpuAotCompilationOptions options{ + /*triple=*/triple, /*cpu_name=*/"", /*features=*/features, + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + string check_lines{spec.check_lines.data(), spec.check_lines.size()}; + + CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, check_lines, + /*match_optimized_ir=*/true); +} + +IntrinsicTestSpec CpuUnaryIntrinsicTestCases[] = { + // The intrinsics are always inlined, so we match a line from it instead of + // a function call. 
+ + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_x86_64, "", + R"(CHECK: fmul fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_x86_64, "+avx", + R"(CHECK: fmul fast <8 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_android_arm, "+neon", + R"(CHECK: fmul fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_x86_64, "", + R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_x86_64, "+avx", + R"(CHECK: fcmp fast uge <8 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_android_arm, "", + R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_x86_64, "", + R"(CHECK: fadd fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_x86_64, "+avx", + R"(CHECK: fadd fast <8 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_android_arm, "", + R"(CHECK: fadd fast <4 x float> )"}}; + +INSTANTIATE_TEST_CASE_P(CpuUnaryIntrinsicTestInstantiation, + CpuUnaryIntrinsicTest, + ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), + CpuUnaryIntrinsicTest::Name); + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc new file mode 100644 index 00000000000000..27044b1d62027e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuDuplicateConstantsTest : public CpuCodegenTest {}; + +TEST_F(CpuDuplicateConstantsTest, RepeatedArrayConstants) { + // We use a while loop here to force the two constant HloInstructions to be in + // different computations. Otherwise the HLO optimizer itself CSEs them. 
+ const string hlo_text = R"( +HloModule RepeatedConstants + +while_body { + arg_body = f32[2,3,2] parameter(0) + ROOT const = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) +} + +while_cond { + arg_cond = f32[2,3,2] parameter(0) + ROOT unknown = pred[] infeed() +} + +ENTRY main { + param = f32[2,3,2] parameter(0) + const_a = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) + const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body + + out0 = () outfeed(f32[2,3,2] const_a) + ROOT out1 = () outfeed(f32[2,3,2] const_b) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [12 x float] +CHECK-NOT: private constant [12 x float] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { + // We use a while loop here to force the two constant HloInstructions to be in + // different computations. Otherwise the HLO optimizer itself CSEs them. + const string hlo_text = R"( +HloModule RepeatedConstants + +while_body { + arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) +} + +while_cond { + arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT unknown = pred[] infeed() +} + +ENTRY main { + param = f32[2,3,2] parameter(0) + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body + + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [1 x float] +CHECK: private constant [2 x float] +CHECK-NOT: private constant [1 x float] +CHECK-NOT: private constant [2 x float] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc new file mode 100644 index 00000000000000..3b6b0ed7406561 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { + +class CpuNoAliasTest : public CpuCodegenTest {}; + +// Creates a simple HLO ir_module (runs concat(concat(x, y), x)), and then +// inspects the aliasing information for loads to its buffers. +TEST_F(CpuNoAliasTest, Concat) { + HloComputation::Builder builder(TestName()); + + std::unique_ptr literal = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + HloInstruction* concat1 = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 4}), {param_x, param_y}, 1)); + HloInstruction* concat2 = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 6}), {concat1, param_x}, 1)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. + auto status_or_buffer_assn = BufferAssigner::Run( + hlo_module.get(), MakeUnique(hlo_module.get()), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return /*alignment=*/1; }); + ASSERT_EQ(status_or_buffer_assn.status(), Status::OK()); + + llvm::LLVMContext context; + llvm_ir::AliasAnalysis aa(*hlo_module, *status_or_buffer_assn.ValueOrDie(), + &context); + + // Construct an LLVM module containing loads that we annotate as being from + // the buffers in the HLO module. We'll inspect these loads to ensure that + // they have the expected alias information. 
+ llvm::Module ir_module("test", context); + llvm::Function* func = llvm::cast( + ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))); + llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); + llvm::IRBuilder<> ir_builder(bb); + auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); + llvm_ir::IrArray::Index zero2D({zero, zero}); + + llvm::ArrayType* array2d_type = llvm::ArrayType::get( + llvm::ArrayType::get(llvm::Type::getFloatTy(context), 100), 100); + + { + llvm::Value* param_x_val = + ir_module.getOrInsertGlobal("param_x", array2d_type); + llvm_ir::IrArray param_x_array(param_x_val, param_shape); + aa.AddAliasingInformationToIrArray(*param_x, ¶m_x_array); + param_x_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_param_x_array"); + } + + { + llvm::Value* concat1_val = + ir_module.getOrInsertGlobal("concat1", array2d_type); + auto shape = ShapeUtil::MakeShape(F32, {2, 4}); + llvm_ir::IrArray concat1_array(concat1_val, shape); + aa.AddAliasingInformationToIrArray(*concat1, &concat1_array); + concat1_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_concat1_array"); + } + + { + llvm::Value* concat2_val = + ir_module.getOrInsertGlobal("concat2", array2d_type); + auto shape = ShapeUtil::MakeShape(F32, {2, 6}); + llvm_ir::IrArray concat2_array(concat2_val, shape); + aa.AddAliasingInformationToIrArray(*concat2, &concat2_array); + concat2_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_concat2_array"); + } + + // Check the AA info in the loads. + const char* filecheck_pattern = R"( + CHECK: %read_param_x_array = load {{.*}} !noalias [[param_x_noalias:![0-9]+]] + CHECK: %read_concat1_array = load {{.*}} !alias.scope [[concat1_scope:![0-9]+]], !noalias [[concat1_noalias:![0-9]+]] + CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] + CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 + CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 + CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]} + CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]} + CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]} + )"; + + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_match, + RunFileCheck(llvm_ir::DumpModuleToString(ir_module), filecheck_pattern)); + EXPECT_TRUE(filecheck_match); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc new file mode 100644 index 00000000000000..1ee279290b6fcf --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuOutfeedTest : public CpuCodegenTest {}; + +TEST_F(CpuOutfeedTest, OutfeedRoot) { + const string hlo_text = R"( +HloModule Outfeed + +ENTRY main { + const_a = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) + + ROOT out = () outfeed(f32[2,3,2] const_a) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [12 x float] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 6479bf76aab581..edcaec584997b1 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -143,6 +143,12 @@ class VectorSupportLibrary { llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements, int64 scale) { + return ComputeOffsetPointer( + base_pointer, + ir_builder_->CreateMul(ir_builder_->getInt64(scale), offset_elements)); + } llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, int64 offset_elements) { return ComputeOffsetPointer(base_pointer, diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index af48f4ab6e506d..cc1695b7f86380 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -25,7 +25,7 @@ namespace xla { // Creates an HloPassPipeline containing multiple HloPasses that can // despecialize an optimized HloModule. This is useful to run an HloModule -// optimized for one specfic platform on a different platform (undoing platform +// optimized for one specific platform on a different platform (undoing platform // specific passes) with matching numerics for comparison. 
// // Current despecialization passes are Defuser, ImplicitBroadcastRemover, diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 35db4fd2a22cc1..e228bb56bce8fe 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -29,7 +29,7 @@ StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} -StatusOr<se::DeviceMemoryBase> StreamExecutorMemoryAllocator::Allocate( +StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); @@ -40,22 +40,17 @@ StatusOr<se::DeviceMemoryBase> StreamExecutorMemoryAllocator::Allocate( tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, device_ordinal); } - return result; + return OwningDeviceMemory(result, device_ordinal, this); } -tensorflow::Status StreamExecutorMemoryAllocator::Deallocate( - int device_ordinal, se::DeviceMemoryBase* mem) { - if (!mem->is_null()) { +Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, + se::DeviceMemoryBase mem) { + if (!mem.is_null()) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); - // We make a local copy of 'mem' so the original is not zeroed out by the - // Deallocate() call below. This gives us a better chance of - // catching double-free bugs, since Deallocate silently succeeds for null - // values. - se::DeviceMemoryBase mem_copy(*mem); - stream_executor->Deallocate(&mem_copy); + stream_executor->Deallocate(&mem); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr<se::StreamExecutor*> StreamExecutorMemoryAllocator::GetStreamExecutor( diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index da45c4d45a1c56..d87b86caf0d3ac 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -37,28 +38,29 @@ class DeviceMemoryAllocator { : platform_(platform) {} virtual ~DeviceMemoryAllocator() {} + // Allocates memory on the device. + // + // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory + // must not be null. If size == 0, must return a null OwningDeviceMemory. + // // 'retry_on_failure': If false, and the first attempt to allocate the memory // fails, the allocation should return immediately without retrying. An // example use case is optional scratch spaces where a failure has only // performance impact. - // - // Allocate() should return a null pointer for a size-0 allocation. - // Deallocate() must be a no-op for null pointers. - virtual StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, - uint64 size, - bool retry_on_failure) = 0; + virtual StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) = 0; // Two-arg version of Allocate(), which sets retry-on-failure to true. 
// // (We don't simply use a default argument on the virtual Allocate function // because default args on virtual functions are disallowed by the Google // style guide.) - StatusOr Allocate(int device_ordinal, uint64 size) { + StatusOr Allocate(int device_ordinal, uint64 size) { return Allocate(device_ordinal, size, /*retry_on_failure=*/true); } - virtual tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) = 0; + // Must be a nop for null pointers. + virtual Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) = 0; // Return the platform that the allocator allocates memory on. const se::Platform* platform() const { return platform_; } @@ -68,6 +70,7 @@ class DeviceMemoryAllocator { virtual bool AllowsAsynchronousDeallocation() const = 0; protected: + friend class OwningDeviceMemory; const se::Platform* platform_; }; @@ -79,14 +82,13 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { const se::Platform* platform, tensorflow::gtl::ArraySlice stream_executors); - StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; // Pull in two-arg overload that sets retry_on_failure to true. using DeviceMemoryAllocator::Allocate; - tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; bool AllowsAsynchronousDeallocation() const override; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 0528b076027603..64678d9d745097 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -138,6 +138,9 @@ class DfsHloVisitorBase { virtual Status HandleExp(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleExpm1(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleFloor(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -150,6 +153,9 @@ class DfsHloVisitorBase { virtual Status HandleClz(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleLog1p(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleCos(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -191,6 +197,10 @@ class DfsHloVisitorBase { return HandleElementwiseUnary(hlo); } + virtual Status HandleDomain(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 38b5efa9fb2cdb..9a8bab353ef6b1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -418,8 +418,12 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } case HloOpcode::kExp: return EmitExp(op->shape().element_type(), operand_value); + case HloOpcode::kExpm1: + return EmitExpm1(op->shape().element_type(), operand_value); case HloOpcode::kLog: return EmitLog(op->shape().element_type(), operand_value); + case HloOpcode::kLog1p: + return EmitLog1p(op->shape().element_type(), operand_value); case HloOpcode::kCos: return 
EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: @@ -493,6 +497,22 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex( op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } + case HloOpcode::kLog1p: { + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto one = llvm::ConstantFP::get(llvm_ty, 1.0); + auto a_plus_one = ir_builder_->CreateFAdd(a, one); + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(a_plus_one, a_plus_one), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -523,6 +543,20 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } + case HloOpcode::kExpm1: { + // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); + auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); + auto real_result = + ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one); + auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b); + return EmitComposeComplex(op, real_result, imag_result); + } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai)) @@ -975,6 +1009,28 @@ StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto negative_half = llvm::ConstantFP::get(type, -0.5); + // When x is large, the naive evaluation of ln(x + 1) is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto for_large_x, + EmitLog(prim_type, ir_builder_->CreateFAdd(x, one))); + // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. 
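+  // For |x| below the threshold we keep only the first two terms, x - x^2/2;
+  // the FMul below evaluates this as ((-0.5 * x) + 1) * x.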
+ auto for_small_x = ir_builder_->CreateFMul( + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one), + x); + const auto kAntilogarithmIsSmallThreshold = 1e-4; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, @@ -993,6 +1049,29 @@ StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto half = llvm::ConstantFP::get(type, 0.5); + // When the exponent is large, the naive evaluation of e^(x) - 1 is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); + auto for_large_x = ir_builder_->CreateFSub(exp_x, one); + // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. + // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. + auto x_squared = ir_builder_->CreateFAdd(x, x); + auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half); + auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two); + const auto kExponentIsSmallThreshold = 1e-5; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { @@ -1344,6 +1423,482 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( }; } +StatusOr ElementalIrEmitter::EmitElementalSelect( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + return ir_builder_->CreateSelect( + ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), + on_true_value, on_false_value); +} + +StatusOr ElementalIrEmitter::EmitElementalClamp( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + TF_ASSIGN_OR_RETURN(llvm::Value * min_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * max_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + PrimitiveType prim_type = hlo->shape().element_type(); + if 
(primitive_util::IsFloatingPointType(prim_type)) { + return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + } else if (primitive_util::IsIntegralType(prim_type)) { + bool is_signed = primitive_util::IsSignedIntegralType(prim_type); + return EmitIntegralMin( + max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); + } else { + return Unimplemented("Clamp unimplemented for %s", + PrimitiveType_Name(prim_type).c_str()); + } +} + +StatusOr ElementalIrEmitter::EmitElementalConcatenate( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& target_index) const { + const int64 concat_dim = hlo->dimensions(0); + auto source_index = target_index; + + llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. + CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + init_block->getTerminator() == nullptr); + + llvm::BasicBlock* exit_block; + if (ir_builder_->GetInsertPoint() == init_block->end()) { + exit_block = llvm_ir::CreateBasicBlock( + /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + } else { + exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(), + AsStringRef(IrName(hlo, "merge"))); + init_block->getTerminator()->eraseFromParent(); + } + + llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); + llvm::PHINode* output = ir_builder_->CreatePHI( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); + auto prior_insert_point = ir_builder_->GetInsertPoint(); + + ir_builder_->SetInsertPoint(init_block); + + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + ++operand_idx) { + const HloInstruction* operand = hlo->operand(operand_idx); + auto true_block = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_from_operand", operand_idx), + ir_builder_); + auto false_block = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_not_from_operand", operand_idx), + ir_builder_); + auto concat_dim_size = + llvm::ConstantInt::get(source_index[concat_dim]->getType(), + operand->shape().dimensions(concat_dim)); + ir_builder_->CreateCondBr( + ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size), + true_block, false_block); + + // Create the terminator of the true block before calling operand + // generators, because they require non-degenerate basic blocks. + ir_builder_->SetInsertPoint( + llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(source_index)); + output->addIncoming(value, ir_builder_->GetInsertBlock()); + + // Subtract the size of the concat dimension of the current operand + // from the source index. + ir_builder_->SetInsertPoint(false_block); + source_index[concat_dim] = + ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); + } + + ir_builder_->CreateUnreachable(); + ir_builder_->SetInsertPoint(exit_block, prior_insert_point); + return output; +} + +StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + // Emit IR to read dynamic start indices from hlo->operand(1). 
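+  // The start index for each dimension is clamped to
+  // [0, operand_dim_size - output_dim_size], so the sliced window always
+  // lies inside the operand.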
+ const HloInstruction* input_hlo = hlo->operand(0); + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1))(dim_index)); + + // Clamp the start index so that the sliced portion fits in the operand: + // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); + llvm::Value* operand_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(operand_dim_size, output_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; + } + + llvm_ir::IrArray::Index input_index(rank); + for (int64 i = 0; i < rank; ++i) { + // Emit IR which computes: + // input_index = start_index + offset_index + input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]); + } + return operand_to_generator.at(input_hlo)(input_index); +} + +StatusOr ElementalIrEmitter::EmitElementalGather( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + const Shape& operand_shape = hlo->operand(0)->shape(); + const Shape& indices_shape = hlo->operand(1)->shape(); + const Shape& output_shape = hlo->shape(); + + const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers(); + + const llvm_ir::ElementGenerator& operand_generator = + operand_to_generator.at(hlo->operand(0)); + const llvm_ir::ElementGenerator& indices_generator = + operand_to_generator.at(hlo->operand(1)); + + // This is the index into `operand` that holds the element we want to + // generate. This index "unsafe" as in the components in here may be + // out of bounds. + IrArray::Index unsafe_operand_index; + + // First copy in the window indices to unsafe_operand_index. + for (int64 i = 0, e = operand_shape.dimensions_size(), + unsafe_operand_index_dim = 0; + i < e; i++) { + if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + } else { + unsafe_operand_index.push_back( + index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); + } + } + + // This is the index of the index vector in the gather_indices tensor. 
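+  // It is assembled from the output dimensions that are not window
+  // dimensions; if the index vector is an explicit dimension of the indices
+  // shape, a placeholder slot is reserved at index_vector_dim and filled in
+  // one component at a time below.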
+ IrArray::Index gather_index_index; + { + std::vector gather_index_index_components; + for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + gather_index_index.push_back(index[i]); + } + } + + if (gather_index_index.size() != indices_shape.dimensions_size()) { + gather_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + } + } + + auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, + int64 dim) { + llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc( + index_component, ir_builder_->getInt64Ty()); + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = + ir_builder_->CreateAdd( + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], + gather_dim_component_extended); + }; + + if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, 0); + } else { + int64 index_vector_size = + indices_shape.dimensions(dim_numbers.index_vector_dim()); + for (int64 i = 0; i < index_vector_size; i++) { + gather_index_index[dim_numbers.index_vector_dim()] = + ir_builder_->getInt64(i); + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, i); + } + } + + IrArray::Index safe_operand_index; + for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { + safe_operand_index.push_back(ir_builder_->CreateURem( + unsafe_operand_index[i], + ir_builder_->getInt64(operand_shape.dimensions(i)))); + } + + return operand_generator(safe_operand_index); +} + +StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + const HloInstruction* input_hlo = hlo->operand(0); + const HloInstruction* update_hlo = hlo->operand(1); + const HloInstruction* start_hlo = hlo->operand(2); + // Calculate slice start/end indices. + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + llvm_ir::IrArray::Index slice_limit_index(rank); + // Slice intersection gathers (ANDs) conditions on all ranks for which + // 'input' is set to 'update' + llvm::Value* slice_intersection = ir_builder_->getTrue(); + + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(start_hlo)(dim_index)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. 
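+  // With the clamped start, the limit index is start + update_dim_size, and
+  // the element at 'index' is taken from 'update' only when
+  // start <= index < limit holds in every dimension.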
+ start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); + llvm::Value* input_dim_size = llvm::ConstantInt::get( + index[i]->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + index[i]->getType(), update_hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(input_dim_size, update_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; + slice_limit_index[i] = + ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); + + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); + } + + // Emit: + // if (slice_intersection) -> return data from 'update'. + // else -> return data from 'input'. + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + "ret_value_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + slice_intersection, "slice_intersection", ir_builder_); + + // Handle true BB (return data from 'update') + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + // Compute update index for intersection case. + llvm_ir::IrArray::Index update_index(rank); + for (int64 i = 0; i < rank; ++i) { + update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); + } + TF_ASSIGN_OR_RETURN(llvm::Value * true_value, + operand_to_generator.at(update_hlo)(update_index)); + ir_builder_->CreateStore(true_value, ret_value_addr); + + // Handle false BB (return data from 'input') + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * false_value, + operand_to_generator.at(input_hlo)(index)); + ir_builder_->CreateStore(false_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + return ir_builder_->CreateLoad(ret_value_addr); +} + +StatusOr ElementalIrEmitter::EmitElementalPad( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& padded_index) const { + auto index = padded_index; + llvm::Value* in_bounds = ir_builder_->getTrue(); + for (size_t i = 0; i < index.size(); ++i) { + auto index_typed_const = [=](int64 n) { + return llvm::ConstantInt::get(index[i]->getType(), n); + }; + const auto& pad_dim = hlo->padding_config().dimensions(i); + index[i] = ir_builder_->CreateSub( + index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = ir_builder_->CreateAnd( + in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpEQ( + index_typed_const(0), + ir_builder_->CreateURem( + index[i], index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + index[i] = ir_builder_->CreateSDiv( + index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSLT( + index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + 
"in_bounds"); + } + + // if (in_bounds) { + // ret_value = operand0[index]; // source + // } else { + // ret_value = *operand1; // padding + // } + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + "pad_result_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); + ir_builder_->CreateStore(operand_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder_->CreateStore(padding_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + // Don't create phi(operand_value, padding_value) here, because invoking + // operand_to_generator may create new basic blocks, making the parent + // of operand_value or padding_value no longer a predecessor of + // if_data.after_block. + return ir_builder_->CreateLoad(ret_value_addr); +} + +StatusOr ElementalIrEmitter::EmitElementalDot( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& dot_result_index) const { + auto lhs_generator = operand_to_generator.at(hlo->operand(0)); + auto rhs_generator = operand_to_generator.at(hlo->operand(1)); + + const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers(); + int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0); + int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0); + + int64 contracted_dim_size = + hlo->operand(0)->shape().dimensions(lhs_contracting_dim); + int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); + int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); + + std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( + IrName(hlo, "inner"), ir_builder_->getInt64(0), + ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), + ir_builder_); + + SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); + PrimitiveType primitive_type = hlo->shape().element_type(); + llvm::Type* primitive_type_llvm = + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); + llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( + primitive_type_llvm, "dot_acc", ir_builder_); + ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), + accumulator_alloca); + + SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); + + // This is the inner reduction loop for a dot operation that produces + // one element in the output. If the operands to the dot operation have + // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E]. 
+ // Given an output index [a,b,c,d,e] in the result, we compute: + // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) + + IrArray::Index lhs_index, rhs_index; + + for (int64 i = 0; i < lhs_dims - 1; i++) { + lhs_index.push_back(dot_result_index[i]); + } + lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue()); + + for (int64 i = 0; i < rhs_dims - 1; i++) { + rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); + } + rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); + + llvm::Value* current_accumulator = + ir_builder_->CreateLoad(accumulator_alloca); + TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); + TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); + llvm::Value* next_accumulator; + if (primitive_util::IsComplexType(primitive_type)) { + llvm::Value* product_real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); + llvm::Value* product_imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value))); + next_accumulator = ir_builder_->CreateInsertValue( + current_accumulator, + ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), + product_real), + {0}); + next_accumulator = ir_builder_->CreateInsertValue( + next_accumulator, + ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), + product_imag), + {1}); + } else if (primitive_util::IsFloatingPointType(primitive_type)) { + next_accumulator = ir_builder_->CreateFAdd( + current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value)); + } else { + next_accumulator = ir_builder_->CreateAdd( + current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value)); + } + ir_builder_->CreateStore(next_accumulator, accumulator_alloca); + + SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); + return ir_builder_->CreateLoad(accumulator_alloca); +} + llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) @@ -1358,10 +1913,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kReal: @@ -1411,43 +1968,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSelect: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); - TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); - TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); - return ir_builder_->CreateSelect( - ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), - on_true_value, on_false_value); + return EmitElementalSelect(hlo, operand_to_generator, index); }; case HloOpcode::kClamp: return [this, hlo, 
&operand_to_generator]( const IrArray::Index& index) -> StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * min_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); - TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); - TF_ASSIGN_OR_RETURN(llvm::Value * max_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); - PrimitiveType prim_type = hlo->shape().element_type(); - if (primitive_util::IsFloatingPointType(prim_type)) { - return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); - } else if (primitive_util::IsIntegralType(prim_type)) { - bool is_signed = primitive_util::IsSignedIntegralType(prim_type); - return EmitIntegralMin( - max_value, EmitIntegralMax(min_value, arg_value, is_signed), - is_signed); - } else { - return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); - } + return EmitElementalClamp(hlo, operand_to_generator, index); }; case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( @@ -1460,70 +1986,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kConcatenate: return [this, hlo, &operand_to_generator]( const IrArray::Index target_index) -> StatusOr { - const int64 concat_dim = hlo->dimensions(0); - auto source_index = target_index; - - llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); - - // A terminator should be present iff we're emitting code - // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), - init_block->getTerminator() == nullptr); - - llvm::BasicBlock* exit_block; - if (ir_builder_->GetInsertPoint() == init_block->end()) { - exit_block = llvm_ir::CreateBasicBlock( - /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); - } else { - exit_block = init_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); - init_block->getTerminator()->eraseFromParent(); - } - - llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); - llvm::PHINode* output = - ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), module_), - hlo->operands().size()); - auto prior_insert_point = ir_builder_->GetInsertPoint(); - - ir_builder_->SetInsertPoint(init_block); - - for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); - ++operand_idx) { - const HloInstruction* operand = hlo->operand(operand_idx); - auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), - ir_builder_); - auto false_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_not_from_operand", operand_idx), - ir_builder_); - auto concat_dim_size = - llvm::ConstantInt::get(source_index[concat_dim]->getType(), - operand->shape().dimensions(concat_dim)); - ir_builder_->CreateCondBr( - ir_builder_->CreateICmpULT(source_index[concat_dim], - concat_dim_size), - true_block, false_block); - - // Create the terminator of the true block before calling operand - // generators, because they require non-degenerate basic blocks. 
- ir_builder_->SetInsertPoint( - llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, ir_builder_->GetInsertBlock()); - - // Subtract the size of the concat dimension of the current operand - // from the source index. - ir_builder_->SetInsertPoint(false_block); - source_index[concat_dim] = - ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); - } - - ir_builder_->CreateUnreachable(); - ir_builder_->SetInsertPoint(exit_block, prior_insert_point); - return output; + return EmitElementalConcatenate(hlo, operand_to_generator, + target_index); }; case HloOpcode::kReverse: return [this, hlo, &operand_to_generator]( @@ -1559,184 +2023,19 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kDynamicSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - // Emit IR to read dynamic start indices from hlo->operand(1). - const HloInstruction* input_hlo = hlo->operand(0); - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = start_index_value; - } + return EmitElementalDynamicSlice(hlo, operand_to_generator, index); + }; - llvm_ir::IrArray::Index input_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // input_index = (start_index + offset_index) % dim_size - // Security note: this is the code that keeps the indices in-bounds. - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( - slice_start_index[i], index[i]->getType()); - input_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateAdd(start_index, index[i]), dim_size); - } - return operand_to_generator.at(input_hlo)(input_index); + case HloOpcode::kGather: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + return EmitElementalGather(hlo, operand_to_generator, index); }; case HloOpcode::kDynamicUpdateSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - const HloInstruction* input_hlo = hlo->operand(0); - const HloInstruction* update_hlo = hlo->operand(1); - const HloInstruction* start_hlo = hlo->operand(2); - // Calculate slice start/end indices. - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - llvm_ir::IrArray::Index slice_limit_index(rank); - // Slice starts at update[index - slice_start_index_adjusted], - // where adjusted value = slice_start_index when in bounds, and - // adjusted value = slice_start_index - input_dim, when wrapping. - llvm_ir::IrArray::Index slice_start_index_adjusted(rank); - - // Slice intersection gathers (ANDs) conditions on all ranks for which - // 'input' is set to 'update' - llvm::Value* slice_intersection = ir_builder_->getTrue(); - - for (int64 i = 0; i < rank; ++i) { - // Emit IR to read dynamic start indices from 'start_hlo'. 
- llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(start_hlo)(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( - start_index_value, index[i]->getType()); - - llvm::Value* input_dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - - // Generate code to handle wrapping semantics: - // slice_start_index[i] = slice_start_index[i] % input_dim_size; - // slice_limit_index[i] = slice_start_index[i] + update_dim_size. - // slice_start_index[i] is updated in place and it will now be in - // range. slice_limit_index[i] may be out of range, and it's being - // URem-ed below if so. - slice_start_index[i] = - ir_builder_->CreateURem(slice_start_index[i], input_dim_size); - slice_limit_index[i] = - ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - - // Test if slice_limit_index[i] is in bounds - llvm::Value* in_bounds = - ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); - llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - - // Handle true BB (slice_limit_index[i] <= input_dim_size). - SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] && - // index[i] < slice_limit_index[i] - llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection, - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection_in"); - slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection_in_bounds, - ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection_in"); - - // Handle false BB (slice_limit_index[i] > input_dim_size). - SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] || - // index[i] < slice_limit_index[i]%input_dim_size. - llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( - index[i], - ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); - llvm::Value* slice_intersection_or = ir_builder_->CreateOr( - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - index_wraps, "slice_intersection_out"); - llvm::Value* slice_intersection_out_of_bounds = - ir_builder_->CreateAnd(slice_intersection, slice_intersection_or, - "slice_intersection_out"); - // Create value for slice_start_index_adjusted[i] when out of bounds. - // If within out-of-bounds if. - llvm_ir::LlvmIfData if_start_needs_adjustment = - llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); - SetToFirstInsertPoint(if_start_needs_adjustment.true_block, - ir_builder_); - llvm::Value* slice_start_index_adjusted_oob = - ir_builder_->CreateSub(slice_start_index[i], input_dim_size); - SetToFirstInsertPoint(if_start_needs_adjustment.after_block, - ir_builder_); - llvm::PHINode* slice_start_index_adjusted_phi = - ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), - 2); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index_adjusted_oob, - if_start_needs_adjustment.true_block); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index[i], if_start_needs_adjustment.false_block); - // End of if within if. - - // After checking in/out of bounds. 
- SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); - llvm::PHINode* phi_slice_intersection = - ir_builder_->CreatePHI(slice_intersection->getType(), 2); - phi_slice_intersection->addIncoming(slice_intersection_in_bounds, - if_in_bounds.true_block); - phi_slice_intersection->addIncoming( - slice_intersection_out_of_bounds, - if_start_needs_adjustment.after_block); - slice_intersection = phi_slice_intersection; - - llvm::PHINode* phi_index = - ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); - phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); - phi_index->addIncoming(slice_start_index_adjusted_phi, - if_start_needs_adjustment.after_block); - slice_start_index_adjusted[i] = phi_index; - } - - // Emit: - // if (slice_intersection) -> return data from 'update'. - // else -> return data from 'input'. - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_), - "ret_value_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - slice_intersection, "slice_intersection", ir_builder_); - - // Handle true BB (return data from 'update') - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - // NOTE: Subtraction will be positive due to bounds checking above. - update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), - update_dim_size); - } - TF_ASSIGN_OR_RETURN(llvm::Value * true_value, - operand_to_generator.at(update_hlo)(update_index)); - ir_builder_->CreateStore(true_value, ret_value_addr); - - // Handle false BB (return data from 'input') - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * false_value, - operand_to_generator.at(input_hlo)(index)); - ir_builder_->CreateStore(false_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - return ir_builder_->CreateLoad(ret_value_addr); + return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator, + index); }; case HloOpcode::kBitcast: CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), @@ -1765,155 +2064,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kRng: return MakeRngElementGenerator(hlo, operand_to_generator); case HloOpcode::kPad: - return [=, &operand_to_generator]( + return [this, hlo, &operand_to_generator]( const IrArray::Index& padded_index) -> StatusOr { - auto index = padded_index; - llvm::Value* in_bounds = ir_builder_->getTrue(); - for (size_t i = 0; i < index.size(); ++i) { - auto index_typed_const = [=](int64 n) { - return llvm::ConstantInt::get(index[i]->getType(), n); - }; - const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = ir_builder_->CreateSub( - index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpEQ( - index_typed_const(0), - ir_builder_->CreateURem( - index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = ir_builder_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() + 1)); - 
in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), - "in_bounds"); - } - - // if (in_bounds) { - // ret_value = operand0[index]; // source - // } else { - // ret_value = *operand1; // padding - // } - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_), - "pad_result_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))(index)); - ir_builder_->CreateStore(operand_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); - ir_builder_->CreateStore(padding_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - // Don't create phi(operand_value, padding_value) here, because invoking - // operand_to_generator may create new basic blocks, making the parent - // of operand_value or padding_value no longer a predecessor of - // if_data.after_block. - return ir_builder_->CreateLoad(ret_value_addr); + return EmitElementalPad(hlo, operand_to_generator, padded_index); }; case HloOpcode::kDot: - return [=, &operand_to_generator](const IrArray::Index& dot_result_index) + return [this, hlo, + &operand_to_generator](const IrArray::Index& dot_result_index) -> StatusOr { - auto lhs_generator = operand_to_generator.at(hlo->operand(0)); - auto rhs_generator = operand_to_generator.at(hlo->operand(1)); - int64 contracted_dim_size = hlo->operand(0)->shape().dimensions( - hlo->operand(0)->shape().dimensions_size() - 1); - int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); - int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); - - std::unique_ptr inner_loop = - llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), ir_builder_->getInt64(0), - ir_builder_->getInt64(contracted_dim_size), - ir_builder_->getInt64(1), ir_builder_); - - SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), - ir_builder_); - PrimitiveType primitive_type = hlo->shape().element_type(); - llvm::Type* primitive_type_llvm = - llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); - llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( - primitive_type_llvm, "dot_acc", ir_builder_); - ir_builder_->CreateStore( - llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); - - SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); - - // This is the inner reduction loop for a dot operation that produces - // one element in the output. If the operands to the dot operation have - // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E]. 
- // Given an output index [a,b,c,d,e] in the result, we compute: - // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - - IrArray::Index lhs_index, rhs_index; - - for (int64 i = 0; i < lhs_dims - 1; i++) { - lhs_index.push_back(dot_result_index[i]); - } - lhs_index.push_back(inner_loop->GetIndVarValue()); - - for (int64 i = 0; i < rhs_dims - 2; i++) { - rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); - } - rhs_index.push_back(inner_loop->GetIndVarValue()); - rhs_index.push_back(dot_result_index.back()); - - llvm::Value* current_accumulator = - ir_builder_->CreateLoad(accumulator_alloca); - TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); - TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); - llvm::Value* next_accumulator; - if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); - llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value))); - next_accumulator = ir_builder_->CreateInsertValue( - current_accumulator, - ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), - product_real), - {0}); - next_accumulator = ir_builder_->CreateInsertValue( - next_accumulator, - ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), - product_imag), - {1}); - } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = ir_builder_->CreateFAdd( - current_accumulator, - ir_builder_->CreateFMul(lhs_value, rhs_value)); - } else { - next_accumulator = ir_builder_->CreateAdd( - current_accumulator, - ir_builder_->CreateMul(lhs_value, rhs_value)); - } - ir_builder_->CreateStore(next_accumulator, accumulator_alloca); - - SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accumulator_alloca); + return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; default: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index c516a826d9e382..d199473374ad39 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -105,6 +105,9 @@ class ElementalIrEmitter { virtual StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const; @@ -114,6 +117,9 @@ class ElementalIrEmitter { virtual StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const; @@ -142,6 +148,46 @@ class ElementalIrEmitter { return ir_builder_->getIntN(128, 0); } + StatusOr EmitElementalSelect( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalClamp( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const 
llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalConcatenate( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& target_index) const; + + StatusOr EmitElementalDynamicSlice( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalGather( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalDynamicUpdateSlice( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalPad( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& padded_index) const; + + StatusOr EmitElementalDot( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& dot_result_index) const; + llvm::IRBuilder<>* const ir_builder_; llvm::Module* module_; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc new file mode 100644 index 00000000000000..8980d4303353a1 --- /dev/null +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ElementalIrEmitterExecutionTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { + const string hlo_text = R"( +HloModule FusedDot + +fused_computation { + arg0 = s32[1,2,1]{2,1,0} parameter(0) + reshape.lhs = s32[2,1]{1,0} reshape(arg0) + arg1 = s32[1,2,1]{2,1,0} parameter(1) + reshape.rhs = s32[2,1]{1,0} reshape(arg1) + ROOT dot = s32[1,1]{1,0} dot(reshape.lhs, reshape.rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY main { + entry_arg0 = s32[1,2,1]{2,1,0} parameter(0) + entry_arg1 = s32[1,2,1]{2,1,0} parameter(1) + ROOT fusion = s32[1,1]{1,0} fusion(entry_arg0, entry_arg1), kind=kLoop, calls=fused_computation +} +)"; + + std::unique_ptr lhs = Literal::CreateR3({{{1}, {2}}}); + std::unique_ptr rhs = Literal::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {lhs.get(), rhs.get()}); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 021f09d310b718..6df172db8e541c 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -129,38 +129,17 @@ StatusOr Executable::ExecuteOnStreamWrapper( return return_value; } -Status Executable::DumpSessionModule() { - TF_RET_CHECK(dumping()); +Status Executable::DumpHloSnapshot() { + TF_RET_CHECK(dumping_snapshot()); + TF_RET_CHECK(hlo_snapshot_->has_hlo() && + hlo_snapshot_->hlo().has_hlo_module()); const string& directory_path = module_config().debug_options().xla_dump_executions_to(); - VersionedComputationHandle versioned_handle = entry_computation_handle(); - // This filename does not include the version number because the computation - // is only ever executed at one version. + const auto& module = hlo_snapshot_->hlo().hlo_module(); string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(), - session_module_->entry().name().c_str(), ++execution_count_); - return Executable::DumpToDirectory(directory_path, filename, - *session_module_); -} - -/* static */ Status Executable::DumpToDirectory( - const string& directory_path, string filename, - const SessionModule& session_module) { - tensorflow::Env* env = tensorflow::Env::Default(); - if (!env->IsDirectory(directory_path).ok()) { - // NB! CreateDir does not work reliably with multiple XLA threads -- two - // threads can race to observe the absence of the dump directory and - // simultaneously try to create it, causing the "losing" thread to get a - // "directory already exists" error. 
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); - } - filename = SanitizeFileName(std::move(filename)); - string file_path = tensorflow::io::JoinPath(directory_path, filename); - string result; - TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(session_module, &result)); - return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, - result); + "computation_%lld__%s__execution_%lld", module.id(), + module.entry_computation_name().c_str(), ++execution_count_); + return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } /* static */ Status Executable::DumpToDirectory( diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index f7af1ca5749297..087bd1432945ab 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" @@ -140,21 +139,17 @@ class Executable { // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. - const Shape& result_shape() const { - return hlo_module_->config().entry_computation_layout().result_shape(); + const Shape& host_result_shape() const { + return hlo_module_->config().host_entry_computation_layout().result_shape(); } // Dumping helpers. - void set_session_module(std::unique_ptr session_module) { - session_module_ = std::move(session_module); + void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { + hlo_snapshot_ = std::move(hlo_snapshot); } - bool dumping() const { return session_module_ != nullptr; } - SessionModule* session_module() const { return session_module_.get(); } - Status DumpSessionModule(); - - // Dump session_module to directory_path/filename. - static Status DumpToDirectory(const string& directory_path, string filename, - const SessionModule& session_module); + bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; } + HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } + Status DumpHloSnapshot(); // Dump hlo snapshot to directory_path/filename. static Status DumpToDirectory(const string& directory_path, string filename, @@ -171,8 +166,8 @@ class Executable { // around. const std::unique_ptr hlo_module_; - // SessionModule this was compiled from. Null if not dumping executions. - std::unique_ptr session_module_; + // HloSnapshot this was compiled from. Null if not dumping executions. + std::unique_ptr hlo_snapshot_; // Execution count, used to generate a unique filename for each dumped // execution. 
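For reference, the scalar math behind the new `EmitLog1p`/`EmitExpm1` helpers and the complex `kLog1p`/`kExpm1` cases added to the elemental IR emitter earlier in this diff can be sketched in ordinary C++. The threshold constants below are copied from the diff; the function names and the small test driver are only illustrative and are not part of the change, which emits equivalent LLVM IR instead of calling any such helpers.

```c++
// Illustrative sketch only: the small-|x| fallbacks and complex identities
// from the elemental_ir_emitter.cc changes, written as plain C++.
#include <cmath>
#include <complex>
#include <cstdio>

double Log1pSketch(double x) {
  const double kAntilogarithmIsSmallThreshold = 1e-4;
  if (std::fabs(x) < kAntilogarithmIsSmallThreshold) {
    return x - 0.5 * x * x;  // Two-term Taylor expansion of ln(1 + x).
  }
  return std::log(1.0 + x);  // The naive form is accurate once |x| is not tiny.
}

double Expm1Sketch(double x) {
  const double kExponentIsSmallThreshold = 1e-5;
  if (std::fabs(x) < kExponentIsSmallThreshold) {
    return x + 0.5 * x * x;  // Two-term Taylor expansion of e^x - 1.
  }
  return std::exp(x) - 1.0;
}

// log1p(a+bi) = 0.5*log((a+1)^2 + b^2) + i*atan2(b, a+1)
std::complex<double> ComplexLog1pSketch(std::complex<double> z) {
  const double a_plus_one = z.real() + 1.0;
  const double b = z.imag();
  return {0.5 * std::log(a_plus_one * a_plus_one + b * b),
          std::atan2(b, a_plus_one)};
}

// expm1(a+bi) = (e^a*cos(b) - 1) + i*e^a*sin(b)
std::complex<double> ComplexExpm1Sketch(std::complex<double> z) {
  const double exp_a = std::exp(z.real());
  return {exp_a * std::cos(z.imag()) - 1.0, exp_a * std::sin(z.imag())};
}

int main() {
  for (double x : {1e-8, 1e-3, 0.5}) {
    std::printf("x=%-8g log1p %.12g (std %.12g)  expm1 %.12g (std %.12g)\n", x,
                Log1pSketch(x), std::log1p(x), Expm1Sketch(x), std::expm1(x));
  }
  const std::complex<double> z(0.3, -0.7);
  const std::complex<double> l = ComplexLog1pSketch(z);
  const std::complex<double> e = ComplexExpm1Sketch(z);
  std::printf("complex log1p (%g, %g)  expm1 (%g, %g)\n", l.real(), l.imag(),
              e.real(), e.imag());
  return 0;
}
```

Near each threshold the Taylor branch and the naive branch of this sketch agree to many significant digits, which is what the select between `for_small_x` and `for_large_x` in the emitted IR relies on.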
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 2f0b9ed2bd98fb..6794cfe297b0fb 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -37,11 +37,11 @@ AsyncExecution::AsyncExecution(Backend* backend, } } -tensorflow::Status AsyncExecution::BlockUntilDone() const { +Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } - return tensorflow::Status::OK(); + return Status::OK(); } ExecutionTracker::ExecutionTracker() : next_handle_(1) {} @@ -61,7 +61,7 @@ ExecutionHandle ExecutionTracker::Register( return execution_handle; } -tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { +Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { @@ -69,7 +69,7 @@ tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { handle.handle()); } handle_to_execution_.erase(handle.handle()); - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr ExecutionTracker::Resolve( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 5b6bddf9f16a85..4458152dd9a988 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -43,7 +43,7 @@ class AsyncExecution { AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); - tensorflow::Status BlockUntilDone() const; + Status BlockUntilDone() const; const GlobalDataHandle& result() const { return result_; } @@ -77,7 +77,7 @@ class ExecutionTracker { GlobalDataHandle data); // Unregisters the execution for the given handle. - tensorflow::Status Unregister(const ExecutionHandle& handle); + Status Unregister(const ExecutionHandle& handle); // Resolves the given ExecutionHandle to an AsyncExecution. Returns an // error status if the given handle is not found, which means that the diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md similarity index 100% rename from tensorflow/compiler/xla/tools/parser/README.md rename to tensorflow/compiler/xla/service/g3doc/hlo_parser.md diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 1c72ca066502eb..020ffcd106862c 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -14,9 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -36,7 +36,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); Status status = GatherExpander{}.Run(module.get()).status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); @@ -63,7 +63,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); ASSERT_TRUE(changed); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index ddb687314ee822..5ee67ccb4ae147 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -89,7 +89,7 @@ GenericTransferManager::TransferLiteralFromDevice( } Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, + se::StreamExecutor* executor, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -115,7 +115,7 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. - const auto subliteral = LiteralView::Create(literal, index); + const auto subliteral = LiteralSlice(literal, index); std::unique_ptr relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), @@ -137,7 +137,7 @@ Status GenericTransferManager::TransferLiteralToDevice( } Status GenericTransferManager::TransferLiteralToInfeed( - se::StreamExecutor* executor, const Literal& literal) { + se::StreamExecutor* executor, const LiteralSlice& literal) { return Unimplemented("Generic transfer to Infeed"); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 0579099de40ba3..3da9570ef7eebc 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -45,11 +45,11 @@ class GenericTransferManager : public TransferManager { se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; Status TransferLiteralToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const ShapedBuffer& device_buffer) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 9a5ad2807591f0..16ab2d78c9cf45 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,8 @@ # Description: # GPU-specific components in XLA service implementation. 
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -27,6 +29,11 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +xla_proto_library( + name = "backend_configs", + srcs = ["backend_configs.proto"], +) + cc_library( name = "gpu_constants", srcs = ["gpu_constants.cc"], @@ -137,6 +144,7 @@ cc_library( "ir_emitter_unnested.h", ], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -273,6 +281,7 @@ cc_library( ] + if_cuda_is_configured(if_cuda(["nvptx_executable.h"])) + if_rocm_is_configured(if_rocm(["amdgpu_executable.h"])), deps = [ + ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", ":infeed_manager", @@ -298,6 +307,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", @@ -329,6 +339,7 @@ cc_library( srcs = ["cudnn_convolution_algorithm_picker.cc"], hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", @@ -345,6 +356,7 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -396,8 +408,10 @@ cc_library( deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service:pattern_matcher", ], ) @@ -406,10 +420,13 @@ tf_cc_test( srcs = ["instruction_fusion_test.cc"], deps = [ ":instruction_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -451,9 +468,9 @@ tf_cc_test( ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -549,6 +566,8 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -595,14 +614,18 @@ cc_library( srcs = 
["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ + ":gpu_options", ":ir_emission_utils", + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -631,6 +654,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", @@ -698,6 +722,27 @@ cc_library( ], ) +cc_library( + name = "gpu_options", + srcs = ["gpu_options.cc"], + hdrs = ["gpu_options.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "stream_executor_util", + srcs = ["stream_executor_util.cc"], + hdrs = ["stream_executor_util.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + tf_cc_test( name = "gpu_hlo_support_checker_test", srcs = ["gpu_hlo_support_checker_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc index 53471249d2dcbf..908140e3f076d3 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -100,7 +100,7 @@ namespace gpu { namespace { -using tensorflow::port::Tracing; +namespace tracing = tensorflow::tracing; // Returns the directory containing ROCm-Device-Libs files. This function is // called in AMDGPUCompiler's constructor, so can't return an error. But @@ -253,7 +253,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, { HloPassPipeline pipeline("layout_assignment"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); + hlo_module->mutable_device_entry_computation_layout(), stream_exec); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -323,7 +323,7 @@ StatusOr> AMDGPUCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("AMDGPUCompiler::RunHloPasses"); - Tracing::TraceMe annotation("HLO Transforms", module->name(), + tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( OptimizeHloModule(module.get(), stream_exec, device_allocator)); diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto new file mode 100644 index 00000000000000..640c6392b8b820 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package xla.gpu; + +// Backend configs for XLA:GPU. +// +// These are metadata that the GPU backend attaches to HloInstrucitons and later +// uses during e.g. codegen. 
+// +// Remember that proto3 doesn't give clients a way to tell the difference +// between a field not being present and a field having the default value. +// Choose your defaults carefully. +// +// No guarantee is made about the stability of these protos. +// +// See HloInstruction::backend_config() for more info. + +// Backend config for a convolution that runs through cudnn. +message CudnnConvBackendConfig { + // Opaque algorithm number of cudnn algorithm chosen for this conv. + int64 algorithm = 1; + + // Whether we may use tensor cores when running this conv. Even if this is + // true, cudnn may choose not to use tensor cores, e.g. because the GPU or + // selected algorithm doesn't support it. + bool tensor_ops_enabled = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 837f05244f7a8c..ab5149dcdb0929 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -37,11 +37,11 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index, } StatusOr> BufferAllocations::Builder::Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { - const int64 num_buffers = buffer_assignment.Allocations().size(); - auto buffer_allocations = WrapUnique( - new BufferAllocations(num_buffers, device_ordinal, memory_allocator)); + const int64 num_buffers = buffer_assignment->Allocations().size(); + auto buffer_allocations = WrapUnique(new BufferAllocations( + num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { // If buffer #i's address is already registered (e.g. external arguments or @@ -62,28 +62,28 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; if (buffer_size > 0) { - TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate( - device_ordinal, buffer_size)); - if (buffer_address == nullptr) { - return ResourceExhausted( - "Out of memory when allocating %s for buffer %lld.", - tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(), - i); - } - if (reinterpret_cast(buffer_address.opaque()) % + OwningDeviceMemory buffer; + TF_ASSIGN_OR_RETURN( + buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); + if (reinterpret_cast(buffer.opaque()) % kCudaMallocAlignBytes != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " "multiple of %llx, but was %p", - kCudaMallocAlignBytes, buffer_address.opaque()); + kCudaMallocAlignBytes, buffer.opaque()); } + // We do manual memory management within BufferAllocations. Be sure not + // to do a TF_RETURN_IF_ERROR between this line and the + // buffer_allocations->SetBuffer(buffer_address) call below! 
+ buffer_address = buffer.Forget(); } + buffer_allocations->SetBuffer(i, buffer_address); if (allocation.IsPreallocatedTempBuffer()) { if (seen_temp_buffer) { @@ -103,28 +103,42 @@ StatusOr> BufferAllocations::Builder::Build( << "B)"; } } - return std::move(buffer_allocations); } -tensorflow::Status BufferAllocations::TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment) { - // Deallocate temporary buffers. - const int64 num_buffers = buffer_assignment.Allocations().size(); +BufferAllocations::~BufferAllocations() { + if (!torn_down_) { + // Presumably if we're executing this branch, the caller is in an error + // state, otherwise it would have explicitly called TearDown so it could + // save some set of live addresses. So ignoring any errors in TearDown is + // sensible. + TearDown(/*live_addresses=*/{}).IgnoreError(); + } +} + +Status BufferAllocations::TearDown( + const std::set& live_addresses) { + // Deallocate temporary buffers, taking care to try to deallocate all of them + // even if one of the deallocations fails. + Status status; + const int64 num_buffers = buffer_assignment_->Allocations().size(); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment_->GetAllocation(i); se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index()); // Deallocate buffers marked "maybe_live_out" but aren't actually live out, // and temp buffers. if ((allocation.maybe_live_out() && !live_addresses.count(buffer_address)) || allocation.IsPreallocatedTempBuffer()) { - TF_RETURN_IF_ERROR( - memory_allocator_->Deallocate(device_ordinal_, &buffer_address)); + auto dealloc_result = + memory_allocator_->Deallocate(device_ordinal_, buffer_address); + if (!dealloc_result.ok() && status.ok()) { + status = dealloc_result; + } } } - return tensorflow::Status::OK(); + torn_down_ = true; + return status; } se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index c2fc35be4ca4bc..636623502597b3 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -48,13 +48,15 @@ class BufferAllocations { // `device_ordinal` is the number of the device this function allocates // memory on. StatusOr> Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator); private: std::map registered_buffers_; }; + ~BufferAllocations(); + BufferAllocations(const BufferAllocations&) = delete; BufferAllocations& operator=(const BufferAllocations&) = delete; @@ -76,16 +78,16 @@ class BufferAllocations { // Tears down all buffers allocated by this object that are not in // `live_addresses`. 
- tensorflow::Status TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment); + Status TearDown(const std::set& live_addresses); private: BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal, - DeviceMemoryAllocator* memory_allocator) + DeviceMemoryAllocator* memory_allocator, + const BufferAssignment* buffer_assignment) : buffers_(buffer_count), device_ordinal_(device_ordinal), - memory_allocator_(memory_allocator) {} + memory_allocator_(memory_allocator), + buffer_assignment_(buffer_assignment) {} // Sets the device address of buffer `buffer_index`. void SetBuffer(BufferAllocation::Index buffer_index, @@ -100,8 +102,9 @@ class BufferAllocations { se::DeviceMemoryBase temp_buffer_base_; int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; + const BufferAssignment* buffer_assignment_; + bool torn_down_ = false; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index dce8de2e301ecf..77a48965e03134 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -35,9 +35,10 @@ ConditionalThunk::ConditionalThunk( true_thunk_(std::move(true_thunk_sequence), hlo), false_thunk_(std::move(false_thunk_sequence), hlo) {} -Status ConditionalThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable)); - TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable)); +Status ConditionalThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor)); + TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index e40872688fdad2..ee03865d174469 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -47,7 +47,8 @@ class ConditionalThunk : public Thunk { ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream) override; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 64d3b84b8c73d8..f0881124128c9b 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -29,11 +29,6 @@ namespace xla { namespace gpu { using se::dnn::AlgorithmDesc; -using se::dnn::BatchDescriptor; -using se::dnn::ConvolutionDescriptor; -using se::dnn::DataLayout; -using se::dnn::FilterDescriptor; -using se::dnn::FilterLayout; ConvolutionThunk::ConvolutionThunk( CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index bf912fbd14de58..ee38c0318a878c 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -29,12 +29,12 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( 
destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream( +Status HostToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( @@ -46,14 +46,14 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream( +Status DeviceToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); stream->ThenMemcpy(&destination_data, source_data, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 2e7eb5f3445bc9..8b128386f61636 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -39,8 +39,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const void* source_address_; @@ -62,8 +62,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index c4c56c56928810..3dc98c4c93ea2b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -35,35 +36,22 @@ class ScratchAllocator : public se::ScratchAllocator { ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - ~ScratchAllocator() override; - int64 GetMemoryLimitInBytes(se::Stream* stream) override { return 1LL << 32; // 4GB. 
TODO(jlebar): Tune this? } int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override; + StatusOr> AllocateBytes(se::Stream* stream, + int64 byte_size) override; private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; -ScratchAllocator::~ScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -se::port::StatusOr> ScratchAllocator::AllocateBytes( +StatusOr> ScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -74,19 +62,14 @@ se::port::StatusOr> ScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } // Determines whether we can safely perform a winograd non-fused convolution for @@ -197,22 +180,42 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // We don't put any data in these buffers, because (in theory, anyway) the // speed of a conv isn't affected by the data being convolved. ScratchAllocator input_output_allocator(device_ordinal, allocator); - se::port::StatusOr input_buf = + StatusOr maybe_input_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(input_shape)); - se::port::StatusOr filter_buf = + StatusOr maybe_filter_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(filter_shape)); - se::port::StatusOr output_buf = + StatusOr maybe_output_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(output_shape)); - if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || + !maybe_output_buf.ok()) { LOG(WARNING) << "Couldn't allocate space for input/filter/output of convolution " << instr->ToString() << ". Falling back to default algorithm."; return nullopt; } + DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); + DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); + DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); + + // Although we don't have evidence this matters, zero out the buffers before + // autotuning. It's conceivable that using uninitialized memory as the inputs + // might affect performance if e.g. 
the inputs contain denormals, and this is + // easy enough. + if (!stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()) + .BlockHostUntilDone() + .ok()) { + LOG(WARNING) + << "Couldn't zero out input/filter/output buffer for convolution " + << instr->ToString() << ". Falling back to default algorithm."; + return nullopt; + } + const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; @@ -225,12 +228,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - input_buf.ValueOrDie(), filter_buf.ValueOrDie(), - output_buf.ValueOrDie(), &scratch_allocator, window, - dnums, AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + bool launch_ok = + RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, + AlgorithmConfig(alg), &stream, &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); @@ -314,21 +317,20 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( Shape new_call_shape = ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), ShapeUtil::MakeShape(U8, {scratch_bytes})}); - HloInstruction* algorithm_hlo = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); - HloInstruction* tensor_ops_enabled_hlo = - computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(tensor_ops_enabled))); + + CudnnConvBackendConfig backend_config; + backend_config.set_algorithm(algorithm); + backend_config.set_tensor_ops_enabled(tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction(HloInstruction::CreateCustomCall( new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, - tensor_ops_enabled_hlo}, + {instr->mutable_operand(0), instr->mutable_operand(1)}, instr->custom_call_target())); new_call->set_window(instr->window()); new_call->set_convolution_dimension_numbers( instr->convolution_dimension_numbers()); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely // (conv_result, u8[0]). diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 10b4c3de89989c..0645fbb3ad39f1 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -113,8 +115,17 @@ Status RunCudnnConvolution( // cuDNN's convolution APIs support the BDYX layout for activations/output and // the OIYX layout for weights. 
+ DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape.layout(), filter_shape.layout(), + output_shape.layout())); + BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) + input_descriptor.set_layout(input_dl) .set_feature_map_count( input_shape.dimensions(dnums.input_feature_dimension())) .set_count(input_shape.dimensions(dnums.input_batch_dimension())); @@ -126,7 +137,7 @@ Status RunCudnnConvolution( } FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( filter_shape.dimensions(dnums.kernel_input_feature_dimension())) .set_output_feature_map_count( @@ -149,7 +160,7 @@ Status RunCudnnConvolution( } BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) + output_descriptor.set_layout(output_dl) .set_feature_map_count( output_shape.dimensions(dnums.output_feature_dimension())) .set_count(output_shape.dimensions(dnums.output_batch_dimension())); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 2dc1fc4cd064a0..1ac1159f3afeee 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -230,6 +230,11 @@ StatusOr GpuElementalIrEmitter::EmitLog( return EmitROCDLMathCall("__ocml_log", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitLog1p( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitROCDLMathCall("__ocml_log1p", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitSin( PrimitiveType prim_type, llvm::Value* value) const { return EmitROCDLMathCall("__ocml_sin", {value}, {prim_type}, prim_type); @@ -245,6 +250,11 @@ StatusOr GpuElementalIrEmitter::EmitExp( return EmitROCDLMathCall("__ocml_exp", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitExpm1( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitROCDLMathCall("__ocml_expm1", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index ee257ebc01a942..b0d0fc5b77d8a9 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -64,6 +64,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const override; @@ -73,6 +76,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc 
b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index cc747addbd152e..e14ee6918bf148 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -31,23 +31,12 @@ FftScratchAllocator::FftScratchAllocator( int device_ordinal, DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} -FftScratchAllocator::~FftScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. return kFftScratchSize; } -se::port::StatusOr> FftScratchAllocator::AllocateBytes( +StatusOr> FftScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -58,18 +47,14 @@ se::port::StatusOr> FftScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return tensorflow::errors::ResourceExhausted( - "Failed to allocate %lld bytes on device %d.", byte_size, - device_ordinal_); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } namespace { @@ -121,8 +106,8 @@ FftThunk::FftThunk(FftType fft_type, input_shape_(input_shape), output_shape_(output_shape) {} -tensorflow::Status FftThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -222,7 +207,7 @@ tensorflow::Status FftThunk::ExecuteOnStream( LOG(FATAL) << "unsupported fft type"; } if (launch_ok) { - return tensorflow::Status::OK(); + return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, FftTypeToString(fft_type_).c_str()); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 24b1dca99865fe..b0a22564f3a09b 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -39,8 +39,6 @@ class FftScratchAllocator : public se::ScratchAllocator { FftScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator); - ~FftScratchAllocator() override; - int64 GetMemoryLimitInBytes(se::Stream* stream) override; int64 TotalAllocatedBytes() { return total_allocated_bytes_; } @@ -51,7 +49,7 @@ class FftScratchAllocator : public 
se::ScratchAllocator { private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; @@ -73,8 +71,8 @@ class FftThunk : public Thunk { FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ // Does the FFT for the thunk on "stream". - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const se::fft::Type fft_type_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 6e6966df3987ee..b36539e0cb8d0a 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -30,19 +30,20 @@ ForThunk::ForThunk(const int64 loop_limit, body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); +Status ForThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); + return Status::OK(); } -tensorflow::Status ForThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { for (int64 i = 0; i < loop_limit_; ++i) { // Invoke loop body thunk sequence. TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index c78d1c50686297..41ddfe0ceb1d05 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -36,9 +36,10 @@ class ForThunk : public Thunk { ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 2217776c7d5a5f..b22bb1d39ba177 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace gpu { @@ -40,7 +40,7 @@ class FusionMergerTest : public HloTestBase {}; // Tuple // TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule MergeSharedFusionInstruction comp.3 { @@ -104,7 +104,7 @@ ENTRY MergeSharedFusionInstruction.Computation0 { // // Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule FlopsToBytesRatioThresholdExceeded comp.2 { @@ -162,7 +162,7 @@ ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { // is merged into Fusion0 and Fusion1) would exceed the bytes transferred // threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdExeceeded comp.2 { @@ -210,7 +210,7 @@ ENTRY BytesTransferredThresholdExeceeded.Computation2 { // Fusion2 is reduced for this test which makes the merge operation into its // operand below the bytes transferred threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdNotExeceeded comp.2 { @@ -253,7 +253,7 @@ ENTRY BytesTransferredThresholdNotExeceeded.Computation2 { // Check that we're willing to merge f1_computation into f2_computation, even // though f2 is an input fusion node. TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule m f1_computation { diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 0ec12f52d8b398..79fca43d022816 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -215,14 +215,32 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { } } +DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { + if (hlo_instruction.opcode() == HloOpcode::kDot) { + return hlo_instruction.dot_dimension_numbers(); + } + CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion); + CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput); + CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(), + HloOpcode::kMultiply); + // Try to find the dot inside the output fusion node. 
+ const HloInstruction* dot = + hlo_instruction.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo_instruction.fused_expression_root()->operand(1); + } + CHECK_EQ(dot->opcode(), HloOpcode::kDot); + + return dot->dot_dimension_numbers(); +} + } // namespace GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, - bool transpose_rhs, double alpha, + const Shape& output_shape, double alpha, const HloInstruction* hlo_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), @@ -231,12 +249,10 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, lhs_shape_(lhs_shape), rhs_shape_(rhs_shape), output_shape_(output_shape), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), alpha_(alpha) {} -tensorflow::Status GemmThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -284,10 +300,12 @@ tensorflow::Status GemmThunk::ExecuteOnStream( shape.dimensions(!is_row_major)); }; - const MatrixDescriptor lhs_descriptor = - make_descriptor(lhs_data, lhs_shape_, transpose_lhs_); - const MatrixDescriptor rhs_descriptor = - make_descriptor(rhs_data, rhs_shape_, transpose_rhs_); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); + + const MatrixDescriptor lhs_descriptor = make_descriptor( + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + const MatrixDescriptor rhs_descriptor = make_descriptor( + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. @@ -350,7 +368,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (!launch_ok) { return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index a18f425bc38fd3..7a4830d64e7cae 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -35,22 +35,20 @@ namespace gpu { class GemmThunk : public Thunk { public: // Constructs a thunk that computes "output = (lhs rhs) * alpha" using - // BLAS gemm. transpose_lhs and transpose_rhs indicate whether gemm should - // transpose the lhs and rhs operand. hlo_instruction is as in Thunk. alpha is - // a constant. + // BLAS gemm. hlo_instruction is as in Thunk. alpha is a constant. GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, - double alpha, const HloInstruction* hlo_instruction); + const Shape& output_shape, double alpha, + const HloInstruction* hlo_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; // Does the gemm operation for the thunk on "stream", which must be non-null. 
- tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; // Returns true if we'll perform autotuning if run on the given stream. If // so, we want the GPU to be quiescent during autotuning, so as not to @@ -69,8 +67,6 @@ class GemmThunk : public Thunk { const Shape rhs_shape_; const Shape output_shape_; - const bool transpose_lhs_; - const bool transpose_rhs_; const double alpha_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 9db85bc788bde4..c5ccdd4a7dcec0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -78,14 +78,13 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } - } else if (IsCustomCallToDnnConvolution(*hlo)) { - // The last two arguments to a CUDNN convolution are two HLO constants for - // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } - } else if (ImplementedAsLibraryCall(*hlo)) { - // For all other library calls, materialize all the operands into memory. + } else if (ImplementedAsLibraryCall(*hlo) || + hlo->opcode() == HloOpcode::kCrossReplicaSum) { + // For all other library calls and cross-replica-sum, materialize all the + // operands into memory. (Cross-replica-sum gets its constant args + // materialized even if it's not implemented as a libcall to simplify the + // implementation. It's slower, but we can constant fold away constant + // args *anyway*, so we just need to make it work.) for (int64 i = 0; i < hlo->operand_count(); ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 97c10c9a5299f7..0196c4cf82457c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -32,12 +32,15 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { namespace { +using tensorflow::tracing::ScopedAnnotation; + // A helper class for profiling HLO in the course of GPU program execution. // All of the profiling is guarded internally, to avoid the caller needing to // have lots of conditionals sprinkled around. 
@@ -134,6 +137,7 @@ Status GpuExecutable::ExecuteThunks( CheckCompatibilityWithServiceExecutableRunOptions(run_options); se::Stream* main_stream = run_options->stream(); + se::StreamExecutor* executor = main_stream->parent(); bool do_profile = hlo_execution_profile != nullptr; if (do_profile) { @@ -145,21 +149,39 @@ Status GpuExecutable::ExecuteThunks( sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); - TF_ASSIGN_OR_RETURN( - sub_streams.back(), - run_options->BorrowStream(main_stream->parent()->device_ordinal())); + TF_ASSIGN_OR_RETURN(sub_streams.back(), + run_options->BorrowStream(executor->device_ordinal())); } HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, sub_streams, hlo_module_->entry_computation()); uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // The next event enqueued on stream N must not run until the thunk at - // last_blocking_thunk_for_stream[N] completes. - std::map last_blocking_thunk_for_stream; + // This top-level trace serves two purposes: + // 1) It marks the scope of the whole XLA module. + // 2) It tells us whether tracing is enabled. We use this to avoid the + // expensive HloInstruction::ToString() calls inside the loop below if + // tracing is disabled. + ScopedAnnotation top_level_annotation(hlo_module_->name(), "XLA GPU module"); + std::map> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { - TF_RETURN_IF_ERROR(thunk->Initialize(*this)); + // Annotate execution of this op if tracing was enabled when we started + // running this module. If tracing is enabled *while* we're running the + // module, we won't get any data, but that's probably an OK trade-off. + // + // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), + // since we expect it to be an expensive call? + tensorflow::gtl::optional op_annotation; + if (top_level_annotation.IsEnabled()) { + op_annotation.emplace( + thunk->hlo_instruction() != nullptr + ? thunk->hlo_instruction()->ToString(HloPrintOptions::Canonical()) + : "", + "XLA op"); + } + + TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); int32 stream_no = thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction()); se::Stream* stream = @@ -169,18 +191,10 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - if (last_blocking_thunk_for_stream.count(stream_no)) { - stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, - last_blocking_thunk_for_stream[stream_no]) - .get()); - last_blocking_thunk_for_stream.erase(stream_no); - } - // If this thunk requests it, wait for all currently-executing thunks to // finish. This is useful e.g. if the thunk is about to perform autotuning. 
if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); - last_blocking_thunk_for_stream.clear(); } profiler.StartOperation(); @@ -188,22 +202,11 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { + if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); - - if (thunk->ShouldBlockFutureThunks()) { - // Set last_blocking_thunk_for_stream on all streams other than this one - // so that all other streams will wait for this thunk to complete before - // executing any events that occur later in the total order. - for (int32 i = 0; i < sub_streams.size() + 1; ++i) { - if (i != stream_no) { - last_blocking_thunk_for_stream[i] = thunk; - } - } - } } profiler.FinishOperation(thunk->hlo_instruction()); } @@ -276,8 +279,8 @@ StatusOr GpuExecutable::ExecuteOnStream( se::StreamExecutor* executor = run_options->stream()->parent(); TF_ASSIGN_OR_RETURN( auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); + buffer_allocations_builder.Build( + assignment_.get(), executor->device_ordinal(), memory_allocator)); bool block_host_until_done = !memory_allocator->AllowsAsynchronousDeallocation(); @@ -319,8 +322,7 @@ StatusOr GpuExecutable::ExecuteOnStream( buffers_in_result.insert(src_base); return Status::OK(); })); - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(buffers_in_result, *assignment_)); + TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result)); return std::move(shaped_buffer); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 89f1e625884568..8bf62dde8b9948 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,31 +18,72 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -// cuDNN convolutions are called with specific layouts on the input, output, -// and filter: -// -// input: DataLayout::kBatchDepthYX -// output: DataLayout::kBatchDepthYX -// filter: FilterLayout::kOutputInputYX -// -// The order dimensions in the constant name is major-to-minor (eg, the -// most-major dimension of the input is batch, most-minor is X). The -// specific dimension numbers these named dimensions correspond to is -// determined by the ConvolutionDimensionNumbers argument. Y is spatial -// dimension 0, and X is spatial dimension 1. -// -// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. 
-static Status AddBackendConstraintsToDnnConvCustomCall( +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::FilterLayout; + +static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} + +// Returns (input, filter, output) layouts. +static std::tuple +HeuristicLayoutAssignment(const HloInstruction* instr, + stream_executor::StreamExecutor* stream_executor) { + // DataLayout and FilterLayout uses weird enum names. Translations: + // N <=> Batch or Output + // C <=> Depth or Input + // H <=> Y + // W <=> X + // + // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW. + + // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x + // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version + // changes, as well as the hardware updates. + if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && + IsVoltaOrLater(*stream_executor))) { + return std::make_tuple(DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); + // For BackwardInput that has stride, full NHWC layouts run significantly + // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + // + // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, + // NHWC). + if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); +} + +// Adds layout constraints on the cudnn custom-call instruction. The layout +// constraints are represented in terms of minor_to_major fields of both +// operands and the output shape. Depending on the underlying algorithm, one of +// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. +Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( HloInstruction* instr, LayoutConstraints* constraints) { CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); Shape input_shape; @@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall( << instr->custom_call_target(); } - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. 
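// HeuristicLayoutAssignment above chooses between NCHW-style and NHWC-style
// cuDNN layouts from three inputs: the element type, the compute capability,
// and whether the call is a strided backward-input convolution. A condensed
// standalone version of that decision table; the enum names mirror the
// StreamExecutor ones, but the boolean parameters are simplified stand-ins
// for the real shape, window, and device queries.
#include <tuple>

namespace layout_sketch {

enum class DataLayout { kBatchDepthYX /*NCHW*/, kBatchYXDepth /*NHWC*/ };
enum class FilterLayout { kOutputInputYX /*OIYX*/, kOutputYXInput /*OYXI*/ };

// Returns (input, filter, output) layouts for a cuDNN convolution.
std::tuple<DataLayout, FilterLayout, DataLayout> ChooseConvLayouts(
    bool is_f16, bool is_volta_or_later, bool is_strided_backward_input) {
  // Default: plain NCHW everywhere.
  if (!(is_f16 && is_volta_or_later)) {
    return std::make_tuple(DataLayout::kBatchDepthYX,
                           FilterLayout::kOutputInputYX,
                           DataLayout::kBatchDepthYX);
  }
  // Strided backward-input convolutions avoid the fully-NHWC combination.
  if (is_strided_backward_input) {
    return std::make_tuple(DataLayout::kBatchYXDepth,
                           FilterLayout::kOutputInputYX,
                           DataLayout::kBatchDepthYX);
  }
  // Otherwise prefer the mostly-NHWC combination for fp16 on Volta.
  return std::make_tuple(DataLayout::kBatchYXDepth,
                         FilterLayout::kOutputYXInput,
                         DataLayout::kBatchYXDepth);
}

}  // namespace layout_sketch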
- const ConvolutionDimensionNumbers& dimension_numbers = - instr->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; - --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; - --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + { + DataLayout input; + FilterLayout filter; + DataLayout output; + if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); + } else { + input = DataLayout::kBatchDepthYX; + filter = FilterLayout::kOutputInputYX; + output = DataLayout::kBatchDepthYX; + } + + TF_ASSIGN_OR_RETURN( + std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), + *output_shape.mutable_layout()), + StreamExecutorConvLayoutsToXlaLayouts( + instr->convolution_dimension_numbers(), input, filter, output)); } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); // The custom call returns a tuple of (actual_result, scratch_buffer); // call_result_buf is the logical buffer for actual_result, the thing that @@ -132,7 +159,13 @@ static Status AddBackendConstraintsToDnnConvCustomCall( Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { - for (auto* instruction : constraints->computation()->instructions()) { + // Add convolution constraints in reverse postorder that the earliest + // convolution layout propagates first. This reduces the likelihood of fusion + // nodes with copies. + auto post_order = constraints->computation()->MakeInstructionPostOrder(); + for (auto iterator = post_order.rbegin(); iterator != post_order.rend(); + ++iterator) { + HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 86a3a7111fd794..ce24af1cf88569 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -19,6 +19,7 @@ limitations under the License. 
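// The layouts being constrained here are XLA minor-to-major permutations.
// For reference, a standalone sketch of how a kBatchDepthYX (NCHW) activation
// layout maps to a minor-to-major order, equivalent to what the removed loops
// above built by hand; ConvDimNums is a simplified stand-in for
// ConvolutionDimensionNumbers.
#include <cstdint>
#include <vector>

namespace minor_to_major_sketch {

struct ConvDimNums {
  std::int64_t input_batch_dimension = 0;
  std::int64_t input_feature_dimension = 1;
  std::vector<std::int64_t> input_spatial_dimensions = {2, 3};  // {Y, X}
};

// X is most-minor, then Y, then feature (C), then batch (N) as most-major.
std::vector<std::int64_t> NchwMinorToMajor(const ConvDimNums& dnums) {
  std::vector<std::int64_t> layout;
  for (auto it = dnums.input_spatial_dimensions.rbegin();
       it != dnums.input_spatial_dimensions.rend(); ++it) {
    layout.push_back(*it);
  }
  layout.push_back(dnums.input_feature_dimension);
  layout.push_back(dnums.input_batch_dimension);
  return layout;  // {3, 2, 1, 0} for the defaults above.
}

}  // namespace minor_to_major_sketch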
#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { @@ -27,8 +28,10 @@ namespace gpu { // layout constraints for operands and results of library calls. class GpuLayoutAssignment : public LayoutAssignment { public: - explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + se::StreamExecutor* stream_executor) + : LayoutAssignment(entry_computation_layout), + stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} protected: @@ -41,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment { LayoutConstraints* constraints) override; bool CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) override; + + private: + Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints); + + se::StreamExecutor* stream_executor_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 4c45d2e94aebce..e48165c1426ea0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape_with_layout); - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { {result_shape, offset_scale_shape, offset_scale_shape})); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { {result_shape, scale_shape, scale_shape})); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc new file mode 100644 index 00000000000000..35b4b4e20b6337 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.cc @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { +namespace gpu { + +bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { + return !config.debug_options().xla_backend_extra_options().count( + "xla_gpu_experimental_conv_disable_layout_heuristic"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h new file mode 100644 index 00000000000000..498d4a94955cb2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +// Helper functions for querying options that are specific to the GPU backend. + +namespace xla { +namespace gpu { + +// Returns true if we should use heuristics to assign convolution layouts, as +// opposed to always assigning NCHW. 
+bool ConvUseLayoutHeuristic(const HloModuleConfig& config); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index b7e3d8cdf34d6e..51897b1c96a478 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -54,8 +54,8 @@ GpuTransferManager::GpuTransferManager(se::Platform::Id id) #endif .getPointerSize(0 /* default address space */)) {} -Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status GpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 1fff17f6071bc1..316fc4bf8f9503 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -37,7 +37,7 @@ class GpuTransferManager : public GenericTransferManager { ~GpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) override; + const LiteralSlice& literal) override; Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, const void* source) override; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 42c1539e86c2ab..f766f968826d96 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -20,6 +20,7 @@ limitations under the License. 
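// ConvUseLayoutHeuristic above treats the heuristic as on by default and only
// turns it off when an escape-hatch key is present in the backend extra
// options. A standalone sketch of that opt-out style of flag; a plain
// std::map stands in for the proto map field.
#include <map>
#include <string>

namespace options_sketch {

using ExtraOptions = std::map<std::string, std::string>;

bool ConvUseLayoutHeuristic(const ExtraOptions& options) {
  // Present-or-absent is all that matters; the value is ignored.
  return options.count(
             "xla_gpu_experimental_conv_disable_layout_heuristic") == 0;
}

// Usage: an empty map keeps the heuristic enabled; inserting the key with any
// value disables it.

}  // namespace options_sketch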
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" @@ -199,7 +200,7 @@ StatusOr> HloSchedule::Build( TF_ASSIGN_OR_RETURN( schedule->thunk_launch_order_, CreateMemoryMinimizingSequence( - *entry_computation, [pointer_size](const LogicalBuffer& buffer) { + *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); } else { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index ece9fa04dce3fd..e230d538cc2df8 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -42,6 +42,15 @@ class HloScheduleTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", VersionedComputationHandle(), + config); + } + HloVec RemoveHlo(const HloVec& input, const std::unordered_set& remove) { HloVec result(input); @@ -65,9 +74,9 @@ TEST_F(HloScheduleTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -193,11 +202,11 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -259,24 +268,24 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - 
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index 3ddc1c0789d746..ae310beefad0c8 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -49,13 +49,25 @@ void InfeedManager::EnqueueBuffers(const std::vector& buffers) { } InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { - tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { - cv_.wait(l); + bool became_empty = false; + InfeedBuffer* current_buffer; + { + tensorflow::mutex_lock l(mu_); + while (enqueued_buffer_.empty()) { + cv_.wait(l); + } + current_buffer = enqueued_buffer_.front(); + enqueued_buffer_.pop_front(); + dequeued_buffer_.insert(current_buffer); + if (enqueued_buffer_.empty()) { + became_empty = true; + } + } + if (became_empty) { + for (const auto& callback : on_empty_callbacks_) { + callback(); + } } - InfeedBuffer* current_buffer = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); - dequeued_buffer_.insert(current_buffer); return current_buffer; } @@ -88,6 +100,10 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { return host_to_device_stream_.get(); } +void InfeedManager::RegisterOnEmptyCallback(std::function callback) { + on_empty_callbacks_.push_back(std::move(callback)); +} + InfeedManager* GetOrCreateInfeedManager() { static InfeedManager* manager = new InfeedManager; return manager; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index d5f2216d460a45..a3fc15cfe36a49 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -21,6 +21,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -100,6 +101,10 @@ class InfeedManager { // returns null. se::Stream* GetStream(se::StreamExecutor* executor); + // Registers a callback that will be called when 'enqueued_buffer_' becomes + // empty. 
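// BlockingDequeueBuffer above waits on a condition variable, pops the buffer
// while holding the lock, and defers the on-empty callbacks until the lock
// has been released. A self-contained sketch of the same shape using the
// standard library; buffers are opaque pointers here rather than
// InfeedBuffer*.
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <utility>
#include <vector>

namespace infeed_sketch {

class BufferQueue {
 public:
  void Enqueue(void* buffer) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(buffer);
    }
    cv_.notify_one();
  }

  void* BlockingDequeue() {
    bool became_empty = false;
    void* buffer = nullptr;
    {
      std::unique_lock<std::mutex> lock(mu_);
      cv_.wait(lock, [this] { return !queue_.empty(); });
      buffer = queue_.front();
      queue_.pop_front();
      became_empty = queue_.empty();
    }
    // Run callbacks outside the lock so a callback may safely re-enqueue.
    if (became_empty) {
      for (const auto& callback : on_empty_callbacks_) callback();
    }
    return buffer;
  }

  void RegisterOnEmptyCallback(std::function<void()> callback) {
    on_empty_callbacks_.push_back(std::move(callback));
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<void*> queue_;
  std::vector<std::function<void()>> on_empty_callbacks_;
};

}  // namespace infeed_sketch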
+ void RegisterOnEmptyCallback(std::function callback); + private: // TODO(b/30467474): Revisit if this mutex becomes a point of // contention. @@ -122,6 +127,10 @@ class InfeedManager { // Executor that the host_to_device_stream belongs to. Not owned. se::StreamExecutor* host_to_device_executor_; + + // List of callbacks which will be called when 'enqueued_buffer_' becomes + // empty. + std::vector> on_empty_callbacks_; }; // Singleton creator-or-accessor: Returns the GPU infeed manager. diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 85ecbe8fdb3470..36a1b82a26d84f 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -46,41 +48,100 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { + if (constant->opcode() != HloOpcode::kConstant || + !ShapeUtil::IsScalar(constant->shape())) { + return false; + } + auto type = constant->shape().element_type(); + return type == F16 || type == F32 || type == F64; +} + } // namespace +/*static*/ bool GpuInstructionFusion::IsExpensive( + const HloInstruction& instruction) { + switch (instruction.opcode()) { + // We say that floating-point division is cheap on the GPU. + case HloOpcode::kDivide: + return !ShapeUtil::ElementIsFloating(instruction.shape()) && + InstructionFusion::IsExpensive(instruction); + + default: + return InstructionFusion::IsExpensive(instruction); + } +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (producer->opcode() == HloOpcode::kDot) { - if (consumer->opcode() == HloOpcode::kMultiply) { - CHECK_EQ(consumer->operand_count(), 2); - int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); - if (alpha->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalar(alpha->shape())) { + if (consumer->operand_count() == 2 && + (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot))) { + int64 other_operand_index = 1 - operand_index; + const HloInstruction* alpha = consumer->operand(other_operand_index); + HloInstruction* op1 = nullptr; + HloInstruction* op2 = nullptr; + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && + Match(consumer->fused_expression_root(), + match::Op() + .WithOpcode(HloOpcode::kMultiply) + .WithOperand(0, match::Op(&op1)) + .WithOperand(1, match::Op(&op2)))) { + CHECK(op1 != nullptr && op2 != nullptr); + // If 'consumer' is a fusion node, it should consist of a broadcast of a + // scalar constant fused into a multiply, but nothing more. So one operand + // should be a parameter, and the other should be a broadcast. 
+ if (op1->opcode() != HloOpcode::kParameter) { + std::swap(op1, op2); + } + if (op1->opcode() != HloOpcode::kParameter || + op2->opcode() != HloOpcode::kBroadcast) { + return false; + } + if (IsIEEEFloatingPointScalarConstant(alpha)) { + return true; + } + } else if (consumer->opcode() == HloOpcode::kMultiply) { + // Fuse if 'alpha' is a broadcast of a scalar constant. + if (alpha->opcode() == HloOpcode::kBroadcast && + alpha->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { return true; } } } - // Only allow to fuse transpose into an output fusion. + // Only allow fusing transpose or broadcast into an output fusion that is + // implemented as a Gemm call. if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) { - if (producer->opcode() != HloOpcode::kTranspose) { - return false; - } - // Check that the transpose is the operand of a dot. + consumer->fusion_kind() == HloInstruction::FusionKind::kOutput && + ImplementedAsGemm(*consumer)) { auto producer_operand_index = consumer->operand_index(producer); auto fused_parameter = consumer->fused_parameter(producer_operand_index); const std::vector& fused_parameter_users = fused_parameter->users(); - return (fused_parameter_users.size() == 1 && - fused_parameter_users[0]->opcode() == HloOpcode::kDot); + if (fused_parameter_users.size() != 1) { + return false; + } + if (producer->opcode() == HloOpcode::kTranspose) { + // Check that the transpose is an operand of a dot. + return fused_parameter_users[0]->opcode() == HloOpcode::kDot; + } + if (producer->opcode() == HloOpcode::kBroadcast) { + // Check that the broadcast is a broadcast of a scalar constant into a + // multiply. + return producer->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(producer->operand(0)) && + fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; + } } - // Output fusion is not currently supported on GPUs. + // Other output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; } @@ -116,12 +177,34 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, InstructionFusion::ShouldFuse(consumer, operand_index); } +bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + const HloInstruction* producer = consumer->operand(operand_index); + // The IR emitter has limited support for non-loop fusions with multi output + // at present. + // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion. + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { + return false; + } + // Multi-output fusion requires instructions with compatible shapes. + if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { + return false; + } + // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for + // multi-output fusion. In particular, do not check whether an instruction is + // expensive to duplicate, since this doesn't matter here. 
+ return GpuInstructionFusion::ShouldFuse(consumer, operand_index); +} + HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } - if (producer->opcode() == HloOpcode::kDot) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { return HloInstruction::FusionKind::kOutput; } if (HloOpcode::kFusion == consumer->opcode()) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index bb2990e6dfc9de..f629d9ff2c7165 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,8 +27,13 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + static bool IsExpensive(const HloInstruction& instruction); + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 4b231c449f8f10..426b1d235c3135 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -15,9 +15,12 @@ limitations under the License. 
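// ShouldFuseIntoMultiOutput above gates multi-output fusion on two
// conditions: the consumer must be a loop fusion (or not yet a fusion), and
// the producer and consumer shapes must be compatible so they share one index
// space. A standalone sketch of those checks; Shape and Instr are simplified
// stand-ins, and plain shape equality replaces ShapeUtil::Compatible.
#include <cstdint>
#include <vector>

namespace multi_output_sketch {

struct Shape {
  std::vector<std::int64_t> dimensions;
  bool operator==(const Shape& other) const {
    return dimensions == other.dimensions;
  }
};

enum class FusionKind { kLoop, kInput, kOutput };

struct Instr {
  bool is_fusion = false;
  FusionKind fusion_kind = FusionKind::kLoop;
  Shape shape;
};

bool ShouldFuseIntoMultiOutput(const Instr& producer, const Instr& consumer) {
  if (consumer.is_fusion && consumer.fusion_kind != FusionKind::kLoop) {
    return false;
  }
  return producer.shape == consumer.shape;
}

}  // namespace multi_output_sketch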
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -108,8 +111,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -125,8 +128,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -140,7 +143,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { // Tests that broadcasts fused into a fusion with a reduce root. 
TEST_F(InstructionFusionTest, BroadcastIntoReduce) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module add { @@ -169,7 +172,7 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { } TEST_F(InstructionFusionTest, BitcastIntoAdd) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -191,7 +194,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) { } TEST_F(InstructionFusionTest, AddIntoBitcast) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -213,7 +216,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) { } TEST_F(InstructionFusionTest, DontFuseGTE) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY DontFuseGTE { p0 = (f32[10], f32[10]) parameter(0) @@ -229,15 +232,16 @@ TEST_F(InstructionFusionTest, DontFuseGTE) { } TEST_F(InstructionFusionTest, DotOutputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { - constant = f32[] constant(3) + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} p0 = f32[4,3]{1,0} parameter(0) p1 = f32[4,3]{1,0} parameter(1) transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} - dot = f32[4,4]{1,0} dot(p0, transpose) - ROOT mul = f32[4,4] multiply(constant, dot) + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT mul = f32[4,4] multiply(dot, broadcast) })") .ValueOrDie(); @@ -247,10 +251,334 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); EXPECT_THAT( root->fused_expression_root(), - op::Multiply(op::Parameter(), - op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); + op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Broadcast(op::Parameter()))); +} + +// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is +// duplicated and fused into both reduces. +TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { + auto module = ParseHloString(R"( + HloModule test_module + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = f32[] constant(0) + one = f32[] constant(1) + p0 = f32[100] parameter(0) + recip = f32[100] divide(one, p0) + sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT root = (f32[], f32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())) + << module->ToString(); +} + +// Compute sum(100/p0), where p0 has type s32, twice. Check that the division +// is *not* duplicated and fused into both reduces, because we say that integer +// division is not cheap. 
+TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { + auto module = ParseHloString(R"( + HloModule test_module + Add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = s32[] constant(0) + one_hundred = s32[] constant(100) + p0 = s32[100] parameter(0) + recip = s32[100] divide(one_hundred, p0) + sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT mul = (s32[], s32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY NoOutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + d = f32[4,4]{1,0} multiply(dot, dot) + ROOT mul = f32[4,4] multiply(d, broadcast) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_THAT(root->fused_expression_root(), + op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), + op::Broadcast(op::Parameter()))); +} + +// Counts the HLO ops with a given op code in the specified module. +static int Count(const HloModule& module, HloOpcode op) { + int count = 0; + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == op) { + ++count; + } + } + } + return count; +} + +// Returns an HLO instruction from the given computation with the op code. +static StatusOr FindHloInstruction( + const HloComputation& computation, HloOpcode op) { + for (const auto* instruction : computation.instructions()) { + if (instruction->opcode() == op) { + return instruction; + } + } + return NotFound( + "Computation '%s' does not contain an instruction with op code '%s'.", + computation.name().c_str(), HloOpcodeString(op).c_str()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion) { + // sub --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. 
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT( + fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { + // tanh --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + tanh = f32[4,3]{1,0} tanh(p0) + add = f32[4,3]{1,0} add(tanh, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) + })") + .ValueOrDie(); + + // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion2) { + // sub --> add1 --\--------\ + // \----------> add2 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(sub, add1) + ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Add()), + op::Add(op::Subtract(), op::Parameter()))); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion3) { + // sub --> add1 ----\--------\ + // \ --> add2 --> add3 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + p3 = f32[4,3]{1,0} parameter(3) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(p2, sub) + add3 = f32[4,3]{1,0} add(add1, add2) + ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. 
+ EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Add(), op::Add()), + op::Add(op::Parameter(), op::Subtract()))); +} + +TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { + // sub --> mul ---\ + // \--> call --> add --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + c = f32[] constant(42) + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + sub = f32[4,3]{1,0} subtract(p0, p1) + mul = f32[4,3]{1,0} multiply(sub, c) + call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" + add = f32[4,3]{1,0} add(mul, call) + ROOT tuple = (f32[4,3]{1,0}) tuple(add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + // Visit instructions in post order to detect cycles. + // TODO(tjoerg): Add cycle detection to the HloVerifier. + class DummyVisitor : public DfsHloVisitorWithDefault { + public: + DummyVisitor() {} + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + } visitor; + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + // Accept will return a FailedPrecondition when a cycle is detected. + EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); + } +} + +TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { + // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) + // \-------------------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[2,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[2,3]{1,0} parameter(2) + sub = f32[2,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` + // have incompatible shapes, expect that no multi-output fusion happens. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { + auto module = ParseHloString(R"( + HloModule test_module + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + fused_computation { + p1 = f32[10] parameter(0) + zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[10] parameter(0) + mul = f32[10] multiply(p0, p0) + fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation + ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) + })") + .ValueOrDie(); + + // Multi-output fusion is not supported for non-loop fusions at present. Since + // `fused_computation` is a input fusion, expect no multi-output fusion to + // happen. 
+ ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 5fc0780efb254f..3cc4ca60270546 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, !ShapeUtil::HasZeroElements(lhs_shape) && !ShapeUtil::HasZeroElements(rhs_shape); } + +bool DotImplementedAsGemm(const HloInstruction& dot) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + return true; + } + return false; +} } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { @@ -69,24 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); - return true; - } - } - - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - return true; + return DotImplementedAsGemm(hlo); } if (hlo.opcode() == HloOpcode::kFusion && @@ -98,7 +100,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { dot = hlo.fused_expression_root()->operand(1); } if (dot->opcode() == HloOpcode::kDot) { - return ImplementedAsGemm(*dot); + return DotImplementedAsGemm(*dot); } } @@ -160,19 +162,8 @@ static HloInstruction* CreateCudnnConv( Shape call_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - // Our CustomCall takes four arguments: The conv lhs and rhs, the cudnn - // algorithm to use, and a boolean indicating whether to use tensor cores. - // - // It's up to a later pass to choose the algorithm and decide whether to use - // tensor cores, so to indicate that we haven't yet made a choice, we speicfy - // -1 and false for those args. 
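// DotImplementedAsGemm above accepts a dot for the pre-canned BLAS path only
// when the operand shapes are gemm-friendly, and then double-checks that the
// contracting dimensions agree, e.g. [m,k] x [k,n] -> [m,n]. A standalone
// sketch of that check; the real AreValidGemmShapes also verifies element
// types and that neither operand is empty, which is omitted here.
#include <cstddef>
#include <cstdint>
#include <vector>

namespace gemm_sketch {

struct Shape {
  std::vector<std::int64_t> dimensions;
};

struct DotDimensionNumbers {
  std::size_t lhs_contracting_dimension = 1;
  std::size_t rhs_contracting_dimension = 0;
};

bool DotImplementableAsGemm(const Shape& lhs, const Shape& rhs,
                            const DotDimensionNumbers& dnums) {
  if (lhs.dimensions.size() != 2 || rhs.dimensions.size() != 2) {
    return false;
  }
  return lhs.dimensions[dnums.lhs_contracting_dimension] ==
         rhs.dimensions[dnums.rhs_contracting_dimension];
}

}  // namespace gemm_sketch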
- HloInstruction* negative_one = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); - HloInstruction* false_constant = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); - HloInstruction* custom_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - call_shape, {lhs, rhs, negative_one, false_constant}, call_target)); + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); return custom_call; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6104dc1a83d356..bf7983f87d88a9 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -112,7 +112,10 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { << std::endl << " its type: " << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global_for_const, + llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); + bindings_.BindHloToIrValue(*constant, shape_constant); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index b0accc08d47925..e55dfc6dae844c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -120,10 +120,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } - // A convenient helper for calling BufferAssignment::GetUniqueTopLevelSlice. - BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const { + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { return ir_emitter_context_->buffer_assignment() - .GetUniqueTopLevelSlice(&hlo) + .GetUniqueSlice(&hlo, index) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 71aada080ae8df..bb47a4280541ce 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/core/lib/core/status.h" @@ -116,6 +117,26 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { + // For MOF we give the loop emitter an array for every output it should + // generate. 
+ if (hlo.IsMultiOutputFusion()) { + std::vector target_arrays; + for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; + ++i) { + target_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_) + .EmitLoop()); + + std::vector tuple_operand_ptrs; + for (const llvm_ir::IrArray& array : target_arrays) { + tuple_operand_ptrs.push_back(array.GetBasePointer()); + } + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, + module_); + return Status::OK(); + } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &ir_builder_) .EmitLoop(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index c74795456474c9..4b0d62adf50da9 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -79,6 +80,7 @@ namespace { using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::InlinedVector; using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; using tensorflow::strings::StrCat; @@ -234,8 +236,39 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( return kernel; } +namespace { +// Computes the maximum valid unroll factor for a given instruction. +int ComputeMaxUnrollFactor(const HloInstruction* hlo) { + int max_unroll_factor = hlo->GetModule() + ->config() + .debug_options() + .xla_gpu_max_kernel_unroll_factor(); + + // Find the largest possible power of two to unroll by. + // TODO(kramerb): Make this smarter. + const Shape& element_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + int64 num_elements = ShapeUtil::ElementsIn(element_shape); + for (int i = max_unroll_factor; i > 1; i /= 2) { + if (num_elements % i == 0) { + return i; + } + } + + // Cannot unroll. + return 1; +} +} // namespace + Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { - thunk_sequence_->emplace_back(BuildKernelThunk(hlo)); + int unroll_factor = 1; + // Unfused elementwise operations are usually memory bound, unroll them. 
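// ComputeMaxUnrollFactor above looks for the largest power of two, bounded by
// the xla_gpu_max_kernel_unroll_factor flag, that evenly divides the number
// of output elements; if none divides, the kernel is not unrolled. The core
// of that search as a standalone function, with the flag value passed in
// directly:
#include <cstdint>

namespace unroll_sketch {

int ComputeMaxUnrollFactor(std::int64_t num_elements, int max_unroll_factor) {
  // Try max_unroll_factor, then max_unroll_factor / 2, ... down to 2.
  for (int factor = max_unroll_factor; factor > 1; factor /= 2) {
    if (num_elements % factor == 0) {
      return factor;
    }
  }
  return 1;  // Cannot unroll.
}

// For example, with max_unroll_factor = 4: 1024 elements -> 4,
// 6 elements -> 2, 7 elements -> 1.

}  // namespace unroll_sketch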
+ if (hlo->IsElementwise()) { + unroll_factor = ComputeMaxUnrollFactor(hlo); + } + + thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor)); return IrEmitter::DefaultAction(hlo); } @@ -368,15 +401,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const HloInstruction* algorithm_inst = custom_call->operand(2); - CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); - int64 algorithm = algorithm_inst->literal().Get({}); - - const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); - CHECK(tensor_ops_enabled_inst->IsConstant()) - << tensor_ops_enabled_inst->ToString(); - bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get({}); - + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { @@ -391,7 +417,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardInput, @@ -404,7 +431,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardFilter, @@ -417,7 +445,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); @@ -445,12 +474,24 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // initializes the output array to the initial value of the reduce. if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { switch (root->opcode()) { + case HloOpcode::kTuple: case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(fusion)); std::vector> thunks; - thunks.push_back(std::move(initializer_thunk)); + ArraySlice reduces = + root->opcode() == HloOpcode::kTuple + ? root->operands() + : ArraySlice(&root, 1); + + // For multi-output fusion emit an initializer for each tuple element. + // Otherwise it's sufficient to just initialize the single output. + for (int i = 0, e = reduces.size(); i != e; ++i) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk( + fusion, reduces[i] == root ? 
ShapeIndex() : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + } thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); @@ -464,11 +505,34 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - Shape input_shape = root->operand(0)->shape(); - return EmitReductionToVector( - root, input_shape, fused_emitter.GetGenerator(root->operand(0)), - fused_emitter.GetGenerator(root->operand(1)), root->dimensions(), - root->to_apply()); + // For multi-output fusion CHECK the constraints and feed all the + // reduces into a single loop code generator. Single-output reduce + // fusion is a special case of that. + InlinedVector input_gens; + InlinedVector init_value_gens; + InlinedVector reducers; + for (const HloInstruction* reduce : reduces) { + CHECK_EQ(HloOpcode::kReduce, reduce->opcode()); + // TODO(kramerb): CHECK that layouts are equal. Currently this + // breaks multioutputfusion_test. The test has pre-fused + // instructions, but layout_assignment will not assign any layouts + // for instructions inside of a fused computation. It just removes + // the layouts instead. + CHECK(ShapeUtil::Compatible(reduces[0]->shape(), reduce->shape())); + CHECK(ShapeUtil::Compatible(reduces[0]->operand(0)->shape(), + reduce->operand(0)->shape())); + CHECK(ShapeUtil::Compatible(reduces[0]->operand(1)->shape(), + reduce->operand(1)->shape())); + CHECK(reduces[0]->dimensions() == reduce->dimensions()); + input_gens.push_back(fused_emitter.GetGenerator(reduce->operand(0))); + init_value_gens.push_back( + fused_emitter.GetGenerator(reduce->operand(1))); + reducers.push_back(reduce->to_apply()); + } + const Shape& input_shape = reduces[0]->operand(0)->shape(); + return EmitReductionToVector(reduces[0], input_shape, input_gens, + init_value_gens, reduces[0]->dimensions(), + reducers); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -514,24 +578,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - int max_unroll_factor = fusion->GetModule() - ->config() - .debug_options() - .xla_gpu_max_kernel_unroll_factor(); - - // Find the largest possible power of two to unroll by. - // TODO(kramerb): Make this smarter. - int unroll_factor = 1; - if (!fusion->IsMultiOutputFusion()) { - CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - int64 num_elements = ShapeUtil::ElementsIn(fusion->shape()); - for (int i = max_unroll_factor; i > 1; i /= 2) { - if (num_elements % i == 0) { - unroll_factor = i; - break; - } - } - } + CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); + int unroll_factor = ComputeMaxUnrollFactor(fusion); thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); return IrEmitter::HandleFusion(fusion); @@ -870,8 +918,9 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers) { // Number of elements processed by a single thread. 
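// The multi-output reduce path above keeps parallel vectors of input
// generators, init values, and reducer computations, then evaluates all of
// them inside one loop over the (compatible) input shape. The same idea on
// the host, as a standalone sketch; inputs are assumed to have equal length,
// and a std::function stands in for the reducer HloComputation.
#include <cstddef>
#include <functional>
#include <vector>

namespace multi_reduce_sketch {

using Reducer = std::function<float(float, float)>;

std::vector<float> ReduceAllAtOnce(
    const std::vector<std::vector<float>>& inputs,
    const std::vector<float>& init_values,
    const std::vector<Reducer>& reducers) {
  std::vector<float> partial_results = init_values;
  const std::size_t num_elements = inputs.empty() ? 0 : inputs[0].size();
  for (std::size_t i = 0; i < num_elements; ++i) {
    // One pass over the data updates every reduction's partial result,
    // mirroring the per-reduce result slots in the fused kernel.
    for (std::size_t r = 0; r < reducers.size(); ++r) {
      partial_results[r] = reducers[r](partial_results[r], inputs[r][i]);
    }
  }
  return partial_results;
}

}  // namespace multi_reduce_sketch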
constexpr int64 kTileSize = 16; int64 num_elems = ShapeUtil::ElementsIn(input_shape); @@ -923,16 +972,19 @@ Status IrEmitterUnnested::EmitReductionToScalar( // auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = - llvm_ir::EmitAllocaAtFunctionEntry(element_ir_type, - "partial_reduction_result", - &ir_builder_); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; @@ -965,11 +1017,16 @@ Status IrEmitterUnnested::EmitReductionToScalar( llvm_ir::IrArray::Index input_index( /*linear=*/x, input_shape, &ir_builder_); llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's @@ -1004,21 +1061,24 @@ Status IrEmitterUnnested::EmitReductionToScalar( : element_ir_type; for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_, module_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_, module_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + 
TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1034,14 +1094,25 @@ Status IrEmitterUnnested::EmitReductionToScalar( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0), - output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + + for (int i = 0; i != num_reduces; ++i) { + ShapeIndex output_shape_index; + if (output->IsMultiOutputFusion()) { + output_shape_index = {i}; + } + llvm::Value* output_address = + GetIrArray(*output, *output, output_shape_index) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + /*linear=*/ir_builder_.getInt64(0), + ShapeUtil::GetSubshape(output->shape(), + output_shape_index), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterates through all input tiles, one per thread. @@ -1061,8 +1132,9 @@ Status IrEmitterUnnested::EmitReductionToScalar( Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers) { // Divide the input matrix into tiles of size Kx1. For example, when the // input matrix is 4x4 and K=2, the tiled matrix looks like // @@ -1072,9 +1144,13 @@ Status IrEmitterUnnested::EmitColumnReduction( // 4567 // Numbers indicate tile IDs. // // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. We - // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel. - constexpr int64 kTileSize = 16; + // scalar is accumulated to the output vector using atomic operations. + // + // We choose 128 as the tile size based on empirical evidence. It's big enough + // to reduce the amount of atomic adds in the end, maximizing the memory + // bandwidth. + constexpr int64 kTileSize = 128; + // If the height is not a multiple of the tile size, we pad the bottom of the // input matrix. const int64 height_in_tiles = CeilOfRatio(height, kTileSize); @@ -1104,15 +1180,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // } auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. 
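The shuffle loop above implements the usual butterfly reduction across a warp: at each step every lane combines its partial result with the one held by the lane `shuffle_distance` above it. A host-side simulation of that pattern, assuming the CUDA warp size of 32 and addition as the reducer (an illustration of what the emitted `EmitShuffleDown` loop computes, not the IR itself):

```c++
// Host-side simulation of the shuffle-down tree reduction: after
// log2(kWarpSize) steps, lane 0 holds the reduction of the whole warp.
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  constexpr int kWarpSize = 32;
  std::vector<float> lanes(kWarpSize);
  std::iota(lanes.begin(), lanes.end(), 1.0f);  // lane i holds i + 1

  for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
       shuffle_distance /= 2) {
    std::vector<float> from_other_lane(kWarpSize);
    for (int lane = 0; lane < kWarpSize; ++lane) {
      // Shuffle-down semantics: read the value held by the lane
      // `shuffle_distance` above; out-of-range lanes keep their own value.
      int src = lane + shuffle_distance;
      from_other_lane[lane] = src < kWarpSize ? lanes[src] : lanes[lane];
    }
    for (int lane = 0; lane < kWarpSize; ++lane) {
      lanes[lane] += from_other_lane[lane];  // the "reducer" here is addition
    }
  }

  std::cout << "lane 0 holds " << lanes[0] << " (expected 528)\n";
  return 0;
}
```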
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } // Emit an inner for-loop that partially reduces the elements in the given @@ -1170,13 +1251,17 @@ Status IrEmitterUnnested::EmitColumnReduction( .SourceIndexOfTranspose(normalized_input_shape, input_shape, transpose_dimension_mapping, &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return Status::OK(); } - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); }; // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's @@ -1205,13 +1290,24 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + ShapeIndex output_shape_index; + if (output->IsMultiOutputFusion()) { + output_shape_index = {i}; + } + llvm::Value* output_address = + GetIrArray(*output, *output, output_shape_index) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + output_shape_index), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterate through all input tiles. @@ -1231,8 +1327,10 @@ Status IrEmitterUnnested::EmitColumnReduction( Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers) { // A naive algorithm is: // 1. 
Divide the input tensor into tiles of size 1x1xK. // 2. Partially reduces each tile to a scalar using one thread. @@ -1322,15 +1420,20 @@ Status IrEmitterUnnested::EmitRowReduction( auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } // Emit an inner for-loop that partially reduces the elements in the given @@ -1413,13 +1516,17 @@ Status IrEmitterUnnested::EmitRowReduction( .SourceIndexOfTranspose(normalized_input_shape, input_shape, transpose_dimension_mapping, &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return Status::OK(); } - return EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address); }; llvm::Value* tile_in_bounds = ir_builder_.CreateOr( @@ -1447,21 +1554,24 @@ Status IrEmitterUnnested::EmitRowReduction( : element_ir_type; for (int shuffle_distance = (kWarpSize / 2); shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), - &ir_builder_, module_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_, module_), + ir_builder_.CreateBitCast(result_from_other_lane, + 
shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1475,13 +1585,24 @@ Status IrEmitterUnnested::EmitRowReduction( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + ShapeIndex output_shape_index; + if (output->IsMultiOutputFusion()) { + output_shape_index = {i}; + } + llvm::Value* output_address = + GetIrArray(*output, *output, output_shape_index) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + y, + ShapeUtil::GetSubshape(output->shape(), + output_shape_index), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. @@ -1508,10 +1629,10 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer) { + tensorflow::gtl::ArraySlice reducers) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for // a fused kReduce). @@ -1546,8 +1667,8 @@ Status IrEmitterUnnested::EmitReductionToVector( // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. if (input_dims_to_keep.empty()) { - return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen, - reducer); + return EmitReductionToScalar(reduce, input_shape, input_gens, + init_value_gens, reducers); } else if (input_dims_to_keep.front() == LayoutUtil::Minor(input_shape.layout(), 0)) { // Column reduction. Treat the result of "input" as a matrix whose width @@ -1564,8 +1685,8 @@ Status IrEmitterUnnested::EmitReductionToVector( height *= input_shape.dimensions(input_dim); } } - return EmitColumnReduction(height, width, reduce, input_shape, input_gen, - init_value_gen, reducer); + return EmitColumnReduction(height, width, reduce, input_shape, input_gens, + init_value_gens, reducers); } else { // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a // 3D tensor. 
The size of dimension 1 (the height) is the size of the @@ -1591,7 +1712,7 @@ Status IrEmitterUnnested::EmitReductionToVector( } const int64 height = ShapeUtil::ElementsIn(reduce->shape()); return EmitRowReduction(depth, height, width, reduce, input_shape, - input_gen, init_value_gen, reducer); + input_gens, init_value_gens, reducers); } } @@ -1615,16 +1736,15 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( - reduce, input->shape(), - [&](const llvm_ir::IrArray::Index& index) { + reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*input, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - [&](const llvm_ir::IrArray::Index& index) { + }}, + {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - dimensions_to_reduce, reducer); + }}, + dimensions_to_reduce, {reducer}); } thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); @@ -1892,6 +2012,52 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { + if (hlo_module_config_.replica_count() != 1) { + // TODO(b/33011107): Support nontrivial cross replica sum on GPU. + return Unimplemented( + "CrossReplicaSum with >1 replica is not implemented on GPU."); + } + + // CRS with one operand and one replica is simply the identity function. + // Buffer assignment expects a copy, so that's what we do. + // + // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely + // in algebraic-simplifier, but currently on some platforms + // HloModuleConfig::num_replicas changes between when the module is compiled + // and when it's run. + if (crs->operand_count() == 1) { + CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) + << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); + thunk_sequence_->push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(0)), + /*destination_buffer=*/GetAllocationSlice(*crs), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); + return Status::OK(); + } + + // One-replica CRS with multiple operands produces a tuple of the inputs. + // Again, buffer assignment expects us to copy each. + std::vector> thunks; + std::vector tuple_element_buffers; + for (int64 i = 0; i < crs->operand_count(); ++i) { + tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(crs, {i}) + .ValueOrDie()); + thunks.push_back(MakeUnique( + /*source_address=*/GetAllocationSlice(*crs->operand(i)), + /*destination_buffer=*/tuple_element_buffers.back(), + /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs)); + } + + // Output a tuple of the buffers above. 
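Earlier in this hunk, `EmitReductionToVector` picks between the scalar, column, and row emitters depending on which input dimensions are kept and whether the first kept dimension is minormost in the layout. A standalone sketch of that classification for one concrete case (plain C++; the shape, layout, and kept dimensions below are made-up example values):

```c++
// Sketch of the EmitReductionToVector dispatch: no kept dimensions means a
// reduction to scalar; a kept dimension that is minormost means neighbouring
// elements of an output "column" are contiguous, so the column emitter is
// used; otherwise the row emitter is used.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example: a [depth=8, height=100, width=64] input, reducing dimensions
  // {0, 2} and keeping {1}; with a row-major layout the minormost dimension
  // is 2, so the kept dimension is not minormost -> row reduction.
  const std::vector<int64_t> dims = {8, 100, 64};
  const std::vector<int64_t> dims_to_keep = {1};
  const int64_t minormost_dim = 2;

  if (dims_to_keep.empty()) {
    std::cout << "reduction to scalar\n";
  } else if (dims_to_keep.front() == minormost_dim) {
    // Column reduction: width is the product of the kept dimensions, height
    // the product of the reduced ones.
    int64_t width = 1, height = 1;
    for (int64_t d = 0; d < static_cast<int64_t>(dims.size()); ++d) {
      bool kept = std::find(dims_to_keep.begin(), dims_to_keep.end(), d) !=
                  dims_to_keep.end();
      (kept ? width : height) *= dims[d];
    }
    std::cout << "column reduction, height=" << height << " width=" << width
              << "\n";
  } else {
    // Row reduction: height is the product of the kept dimensions.
    std::cout << "row reduction, height=" << dims[1] << "\n";
  }
  return 0;
}
```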
+ thunks.push_back(MakeUnique(tuple_element_buffers, + GetAllocationSlice(*crs), crs)); + thunk_sequence_->push_back( + MakeUnique(std::move(thunks), crs)); + return Status::OK(); +} + Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); return Status::OK(); @@ -2158,6 +2324,21 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( /*destination_buffer=*/GetAllocationSlice(*inst), inst); } +namespace { +double GetScalarConstantAsDouble(const Literal& literal) { + switch (literal.shape().element_type()) { + case F16: + return static_cast(literal.Get({})); + case F32: + return literal.Get({}); + case F64: + return literal.Get({}); + default: + LOG(FATAL) << "Unsupported type."; + } +} +} // namespace + std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kDot) { @@ -2170,65 +2351,48 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( lhs->shape(), // The shape of LHS. rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. - false, // Do not transpose LHS. - false, // Do not transpose RHS. 1.0, // alpha. inst); } if (inst->opcode() == HloOpcode::kFusion) { - if (inst->fusion_kind() == HloInstruction::FusionKind::kOutput) { - const HloInstruction* mul = inst->fused_expression_root(); - const HloInstruction* dot = mul->operand(0); - const HloInstruction* alpha = mul->operand(1); - if (dot->opcode() != HloOpcode::kDot) { - std::swap(dot, alpha); - } - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*mul), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Transpose RHS. - alpha->literal().Get({0}), // alpha. - inst); - } else { - const HloInstruction* dot = inst->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Transpose RHS. - 1.0, // Alpha. 
- inst); + CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput); + const HloInstruction* mul = inst->fused_expression_root(); + const HloInstruction* dot = mul->operand(0); + const HloInstruction* alpha = mul->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + alpha = inst->operand(alpha->parameter_number()); + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + inst->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + inst->operand(rhs_parameter->parameter_number()); + + return MakeUnique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + GetScalarConstantAsDouble(alpha->literal()), // alpha. + inst); } LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); @@ -2245,7 +2409,7 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( } StatusOr> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo) { + const HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; const HloInstruction* init_value = [&] { @@ -2254,6 +2418,14 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( return inst->operand(2); case HloOpcode::kReduce: return inst->operand(1); + case HloOpcode::kTuple: + CHECK(hlo->IsMultiOutputFusion()) + << ": " << hlo->ToString() << " is not a multi-output fusion."; + CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce) + << ": Found '" << inst->operand(index.back())->opcode() << "' in " + << inst->ToString() << " but expected 'reduce'."; + // For multi-output fusion look through the tuple. 
+ return inst->operand(index.back())->operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; @@ -2277,7 +2449,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ArraySlice literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {MakeUnique(GetAllocationSlice(*hlo), hlo)}; + return {MakeUnique(GetAllocationSlice(*hlo, index), hlo)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2293,8 +2465,8 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( pattern16 = literal_bytes.front(); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique(pattern32, - GetAllocationSlice(*hlo), hlo)}; + return {MakeUnique( + pattern32, GetAllocationSlice(*hlo, index), hlo)}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -2304,8 +2476,8 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique(word, GetAllocationSlice(*hlo), - hlo)}; + return {MakeUnique( + word, GetAllocationSlice(*hlo, index), hlo)}; } } @@ -2504,16 +2676,14 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( .EmitLoop(IrName(&hlo)); } - CHECK_EQ(unroll_factor, 1) - << "multi-output fusion does not support unrolling"; - // For multiple outputs fusion, we need to emit each operand and the root. std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_) + launch_dimensions, &ir_builder_, + unroll_factor) .EmitLoop(IrName(&hlo))); std::vector tuple_operand_ptrs; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b842f480c6257c..b41eaa303b0aad 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -38,7 +38,7 @@ namespace gpu { // // Examples of things that are not unnested computations: // -// - The reducer of a kReduce HLO. This is emited using IrEmitterNested. +// - The reducer of a kReduce HLO. This is emitted using IrEmitterNested. // - The body of a fusion node. IrEmitterUnenested emits the relevant code // within a kernel function using FusedIrEmitter. (FusedIrEmitter is not // really an IrEmitter, but is more an "IR generator generator".) @@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleInfeed(HloInstruction* xla_infeed) override; Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -109,28 +110,31 @@ class IrEmitterUnnested : public IrEmitter { // `EmitReductionToVector`. Note that input shape might not be // [height x width], but can be bitcast to [height x weight] with "height" // being the major dimension. 
- Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitColumnReduction( + int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a // vector of shape [height]. Other parameters have the same meaning as those // of `EmitReductionToVector`. Note that input shape might not be // [depth x height x width], but can be bitcast to [depth x height x weight] // with "depth" being the most major dimension. - Status EmitRowReduction(int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitRowReduction( + int64 depth, int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers); // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitReductionToScalar( + HloInstruction* reduce, const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers); // Figures out whether `reduce` is a row or column reduction, and which // dimensions to reduce, and calls either `EmitRowReduction` or @@ -140,13 +144,16 @@ class IrEmitterUnnested : public IrEmitter { // generate elements of the input and the initial value. Other parameters mean // the same as for `HandleReduce`. // + // Multiple reduces can be emitted in the same loop, assuming they have the + // same input and output shapes, and the same reduce dimensions. + // // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer); + tensorflow::gtl::ArraySlice reducers); // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned @@ -165,7 +172,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( - const HloInstruction* hlo); + const HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. 
std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index e499a9b0091d3f..5355d34c019c88 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -35,29 +35,41 @@ KernelThunk::KernelThunk( kernel_name_(kernel_name), unroll_factor_(unroll_factor) {} -tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { +Status KernelThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { tensorflow::mutex_lock lock(mutex_); - if (loader_spec_) { - // Already initialized by another thread. - return tensorflow::Status::OK(); - } - - loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece text = executable.text(); - // Convert tensorflow::StringPiece to se::port::StringPiece because - // StreamExecutor uses the latter. - loader_spec_->AddCudaPtxInMemory( - se::port::StringPiece(text.data(), text.size()), kernel_name_); + if (!loader_spec_) { + loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); + tensorflow::StringPiece text = executable.text(); + // Convert tensorflow::StringPiece to se::port::StringPiece because + // StreamExecutor uses the latter. + loader_spec_->AddCudaPtxInMemory( + se::port::StringPiece(text.data(), text.size()), kernel_name_); // XXX figure out how to cope with both CUDA and ROCm platforms #if GOOGLE_CUDA - if (!executable.cubin().empty()) { - loader_spec_->AddCudaCubinInMemory( - reinterpret_cast(executable.cubin().data()), kernel_name_); - } + if (!executable.cubin().empty()) { + loader_spec_->AddCudaCubinInMemory( + reinterpret_cast(executable.cubin().data()), + kernel_name_); + } #endif + } - return tensorflow::Status::OK(); + // Load the kernel into the device if necessary. + // + // We could alternatively do this within ExecuteOnStream, but doing it here + // lets the time spent loading the kernel not count towards our execution + // profiles. + auto it = kernel_cache_.find(executor); + if (kernel_cache_.end() == it) { + it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; + if (!executor->GetKernel(*loader_spec_, &it->second)) { + return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + } + } + + return Status::OK(); } void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { @@ -65,21 +77,18 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { launch_dimensions_ = launch_dims; } -tensorflow::Status KernelThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { // Load the kernel. 
se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; const se::KernelBase* kernel = nullptr; + { tensorflow::mutex_lock lock(mutex_); auto it = kernel_cache_.find(executor); - if (kernel_cache_.end() == it) { - it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; - if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); - } - } + CHECK(it != kernel_cache_.end()) + << "Initialize() not called for StreamExecutor " << executor; launch_dimensions = launch_dimensions_; kernel = &it->second; } @@ -100,7 +109,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream( *kernel_args)) { return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index b556befe66b6be..7def27e189b667 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -57,11 +57,12 @@ class KernelThunk : public Thunk { int unroll_factor() const { return unroll_factor_; } void SetLaunchDimensions(const LaunchDimensions& launch_dims); - tensorflow::Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; // Executes the kernel for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // Buffers passed to the kernel as arguments. @@ -83,7 +84,8 @@ class KernelThunk : public Thunk { mutable tensorflow::mutex mutex_; std::unique_ptr loader_spec_ GUARDED_BY(mutex_); - // Loaded kernels for each `StreamExecutor` + // Loaded kernels for each `StreamExecutor`. Requires pointer stability of + // values. 
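The pointer-stability requirement noted above exists because `KernelThunk::ExecuteOnStream` (earlier in this hunk) releases the mutex and then keeps using the kernel pointer it found in the cache that `Initialize` populated. A minimal sketch of that initialize-then-lookup pattern with standard containers (the `Kernel` and `ExecutorId` types are stand-ins for illustration, not the StreamExecutor API):

```c++
// The cache is populated once per executor under a mutex; later lookups hand
// out a pointer into the map. std::unordered_map guarantees that inserting
// new entries never invalidates references to existing values, which is the
// "pointer stability" the comment above relies on.
#include <cassert>
#include <mutex>
#include <string>
#include <unordered_map>

struct Kernel { std::string name; };   // stand-in for a loaded kernel object
using ExecutorId = int;                // stand-in for a StreamExecutor handle

class KernelCache {
 public:
  // Initialize-time path: load the kernel for this executor if not present.
  void Initialize(ExecutorId executor, const std::string& kernel_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    kernels_.emplace(executor, Kernel{kernel_name});
  }

  // Execute-time path: the kernel must already be loaded; we only look it up,
  // and the returned pointer stays valid after the lock is released.
  const Kernel* Lookup(ExecutorId executor) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = kernels_.find(executor);
    assert(it != kernels_.end() && "Initialize() not called for this executor");
    return &it->second;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<ExecutorId, Kernel> kernels_;
};

int main() {
  KernelCache cache;
  cache.Initialize(/*executor=*/0, "fusion_kernel");
  const Kernel* k = cache.Lookup(0);
  assert(k->name == "fusion_kernel");
  return 0;
}
```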
std::unordered_map kernel_cache_ GUARDED_BY(mutex_); }; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 0f229a4418a095..5fc8f1a5fe0bda 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -58,7 +58,6 @@ cc_library( "@llvm//:scalar", "@llvm//:support", "@llvm//:target", - "@llvm//:transform_utils", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/amdgpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/amdgpu_backend_lib.cc index d71ab62f8a0a07..dc6857f28a79da 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/amdgpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/amdgpu_backend_lib.cc @@ -248,7 +248,7 @@ std::vector EmitModuleToHsaco(Module* module, llvm::TargetMachine* target_ codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); - target_machine->addPassesToEmitFile(codegen_passes, pstream, + target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr, llvm::TargetMachine::CGFT_ObjectFile); codegen_passes.run(*module); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index df9d9be889ce83..a4e4e85bf3d2c1 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -77,8 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = - tensorflow::Env::Default()->GetMatchingPaths( + const Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); if (status.ok() && unified_libdevice_files.size() == 1) { @@ -273,7 +272,7 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); - target_machine->addPassesToEmitFile(codegen_passes, pstream, + target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); codegen_passes.run(*module); } @@ -311,11 +310,11 @@ bool CouldNeedLibdevice(const llvm::Module& module) { } // Links libdevice into the given module if the module needs libdevice. 
-tensorflow::Status LinkLibdeviceIfNecessary( - llvm::Module* module, std::pair compute_capability, - const string& libdevice_dir_path) { +Status LinkLibdeviceIfNecessary(llvm::Module* module, + std::pair compute_capability, + const string& libdevice_dir_path) { if (!CouldNeedLibdevice(*module)) { - return tensorflow::Status::OK(); + return Status::OK(); } llvm::Linker linker(*module); @@ -336,7 +335,7 @@ tensorflow::Status LinkLibdeviceIfNecessary( return tensorflow::errors::Internal(tensorflow::strings::StrCat( "Error linking libdevice from ", libdevice_path)); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr CompileModuleToPtx(llvm::Module* module, @@ -491,7 +490,7 @@ StatusOr CompileToPtx(llvm::Module* module, string ptx; { - tensorflow::port::Tracing::TraceMe annotation( + tensorflow::tracing::ScopedActivity activity( "Compiling IR", llvm_ir::AsString(module->getName()), /*is_expensive=*/true); XLA_SCOPED_LOGGING_TIMER("Compile module " + diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 43ba786443891e..8ecd9ceb5c49f2 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -73,6 +73,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -100,7 +102,7 @@ namespace gpu { namespace { -using tensorflow::port::Tracing; +namespace tracing = tensorflow::tracing; // Returns the directory containing nvvm libdevice files. config_cuda_data_dir // should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the @@ -128,9 +130,8 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { +Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -158,11 +159,13 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) { pass.AddPass(); } + // TODO(kramerb): Remove use_fusion once instruction fusion can create + // multi-output fusions from the unfused expander output. pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*use_fusion=*/true); // Rewrite gather ops into smaller ones. 
pass.AddPass(); @@ -175,6 +178,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -201,18 +205,28 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("layout_assignment"); + pipeline.AddPass( + hlo_module->mutable_device_entry_computation_layout(), stream_exec); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { + return true; + }); // Choose the fastest algorithm for each conv. // - // In theory doing this here is way too early: It needs to happen after - // layout assignment, because the layout of the inputs/outputs affects the - // speed of the conv. But currently we only allow only one input/output - // layout when calling cudnn, so there's no ambiguity. - // - // We pick the algorithm at this early stage so we can generate better HLO. - // After CudnnConvolutionRewriter, our convolutions are CustomCalls which - // return a tuple (conv_result, scratch_memory), and the each conv uses 0 - // bytes of scratch: + // We pick the algorithm before fusion so we can generate better HLO. After + // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of + // scratch: // // customcall = (f32[...], f32[0]) // return gte(customcall, 0) @@ -228,35 +242,15 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after layout - // assignment, fusion would already have run, and the gte(customcall, 0) - // would probably already be into a fusion node. We can't simplify across - // HloComputation boundaries, so in this case we wouldn't be able to - // simplify away the new_tuple bits. - // - // We'll need to revisit this if we ever allow multiple layouts for the - // inputs/outputs of a cudnn convolution. + // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // the gte(customcall, 0) would probably already be into a fusion node. We + // can't simplify across HloComputation boundaries, so in this case we + // wouldn't be able to simplify away the new_tuple bits. pipeline.AddPass(stream_exec, device_allocator); // Clean up new_tuple described above. pipeline.AddPass(); - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - HloPassPipeline pipeline("layout_assignment"); - pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
- pipeline.AddPass>( - /*is_layout_sensitive=*/true, - /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { - return true; - }); pipeline.AddPass(/*is_layout_sensitive=*/true); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -283,12 +277,21 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } - return tensorflow::Status::OK(); + + { + // Do an aggressive LICM pass over while loops. In particular, this hoists + // constants that were sunk by WhileLoopConstantSinking. Leaving them in + // the while loop may result in unnecessary copies. + HloPassPipeline pipeline("while-loop-licm"); + pipeline.AddPass(true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + return Status::OK(); } // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { +Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output @@ -410,7 +413,7 @@ void WarnIfBadDriverJITVersion() { // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, int cc_minor) { - Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true); + tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -481,8 +484,8 @@ StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); - Tracing::TraceMe annotation("HLO Transforms", module->name(), - /*is_expensive=*/true); + tracing::ScopedActivity activity("HLO Transforms", module->name(), + /*is_expensive=*/true); TF_RETURN_IF_ERROR( OptimizeHloModule(module.get(), stream_exec, device_allocator)); return std::move(module); @@ -692,7 +695,7 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult(const string& ptx, int cc_major, int cc_minor) { XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompilePtxOrGetCachedResult"); - Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true); + tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; decltype(compilation_cache_.begin()) iter; // Pointers into compilation_cache_ where the ptx and (optional) cubin are diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 7bda4e2fcd469b..c8f0d4185c63c5 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -370,26 +370,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return true; } -StatusOr PadInsertion::Run(HloModule* module) { +StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { bool changed = false; - for (HloInstruction* instruction : - module->entry_computation()->MakeInstructionPostOrder()) { - if (IsCustomCallToDnnConvolution(*instruction)) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); 
- } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + for (HloInstruction* instruction : convs) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } return changed; } +StatusOr PadInsertion::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 5e1c68701daa02..67e51509e4c717 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -31,6 +31,7 @@ class PadInsertion : public HloPassInterface { StatusOr Run(HloModule* module) override; private: + StatusOr RunOnComputation(HloComputation* computation); // Returns if any changes are made to the parent computation. bool CanonicalizeForwardConvolution(HloInstruction* conv); bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index c8510808f10a73..88cb10883e97ae 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -20,24 +20,24 @@ limitations under the License. 
namespace xla { namespace gpu { -SequentialThunk::SequentialThunk(std::vector>&& thunks, +SequentialThunk::SequentialThunk(std::vector> thunks, const HloInstruction* hlo) : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} -tensorflow::Status SequentialThunk::Initialize( - const GpuExecutable& executable) { +Status SequentialThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { for (auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->Initialize(executable)); + TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor)); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status SequentialThunk::ExecuteOnStream( +Status SequentialThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { for (const auto& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index df17b8d67b8032..135f79e413dfaa 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -31,16 +31,17 @@ namespace gpu { // require multiple kernel launches or library calls. class SequentialThunk : public Thunk { public: - SequentialThunk(std::vector>&& thunks, + SequentialThunk(std::vector> thunks, const HloInstruction* hlo); SequentialThunk(const SequentialThunk&) = delete; SequentialThunk& operator=(const SequentialThunk&) = delete; const std::vector>& thunks() const { return thunks_; } - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 8c98956f1a9b2a..696fa7e0194032 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -28,6 +28,15 @@ namespace gpu { class StreamAssignmentTest : public HloTestBase { protected: + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", VersionedComputationHandle(), + config); + } + // Pre-canned shapes. 
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); }; @@ -41,9 +50,9 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -60,9 +69,9 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -91,24 +100,24 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc new file mode 100644 index 00000000000000..a50ddf6ac63c7f --- /dev/null +++ 
b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" + +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { +namespace gpu { + +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::DataLayoutString; +using stream_executor::dnn::FilterLayout; +using stream_executor::dnn::FilterLayoutString; + +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + DataLayout input, FilterLayout filter, + DataLayout output) { + std::vector input_layout; + switch (input) { + case DataLayout::kBatchDepthYX: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.push_back(dnums.input_feature_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + input_layout.push_back(dnums.input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid input layout: ", + DataLayoutString(input)); + } + + std::vector filter_layout; + switch (filter) { + case FilterLayout::kOutputInputYX: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + break; + case FilterLayout::kOutputYXInput: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid filter layout: ", + FilterLayoutString(filter)); + } + + std::vector output_layout; + switch (output) { + case DataLayout::kBatchDepthYX: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.push_back(dnums.output_feature_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + output_layout.push_back(dnums.output_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid output layout: ", + DataLayoutString(output)); + } + + return 
std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(output_layout)); +} + +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output) { + Layout nchw_input, nchw_filter, nchw_output; + std::tie(nchw_input, nchw_filter, nchw_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX) + .ConsumeValueOrDie(); + + Layout nhwc_input, nhwc_filter, nhwc_output; + std::tie(nhwc_input, nhwc_filter, nhwc_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth) + .ConsumeValueOrDie(); + + DataLayout input_layout; + if (LayoutUtil::Equal(input, nchw_input)) { + input_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(input, nhwc_input)) { + input_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid input layout: ", + input.ShortDebugString()); + } + + FilterLayout filter_layout; + if (LayoutUtil::Equal(filter, nchw_filter)) { + filter_layout = FilterLayout::kOutputInputYX; + } else if (LayoutUtil::Equal(filter, nhwc_filter)) { + filter_layout = FilterLayout::kOutputYXInput; + } else { + return tensorflow::errors::Internal("Invalid filter layout: ", + filter.ShortDebugString()); + } + + DataLayout output_layout; + if (LayoutUtil::Equal(output, nchw_output)) { + output_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(output, nhwc_output)) { + output_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid output layout: ", + output.ShortDebugString()); + } + + return std::make_tuple(input_layout, filter_layout, output_layout); +} +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h new file mode 100644 index 00000000000000..8218f4fd11d397 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +// Helper functions for interacting with StreamExecutor. + +namespace xla { +namespace gpu { + +// Returns (input, filter, output) XLA Layout protos given the StreamExecutor +// layouts. 
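The two helpers in `stream_executor_util` translate between cuDNN's `DataLayout`/`FilterLayout` enums and XLA `Layout` protos, in both directions. A hedged usage sketch; `dnums` is assumed to be the convolution's `ConvolutionDimensionNumbers`, and the function name is illustrative:

```c++
// Sketch only: derive XLA layouts for an NCHW convolution and round-trip them.
using stream_executor::dnn::DataLayout;
using stream_executor::dnn::FilterLayout;

Status CheckNchwRoundTrip(const ConvolutionDimensionNumbers& dnums) {
  TF_ASSIGN_OR_RETURN(auto xla_layouts,
                      StreamExecutorConvLayoutsToXlaLayouts(
                          dnums, DataLayout::kBatchDepthYX,
                          FilterLayout::kOutputInputYX,
                          DataLayout::kBatchDepthYX));
  const Layout& input = std::get<0>(xla_layouts);
  const Layout& filter = std::get<1>(xla_layouts);
  const Layout& output = std::get<2>(xla_layouts);
  // The reverse mapping fails with an Internal error if the layouts match
  // neither the NCHW nor the NHWC convention.
  return XlaConvLayoutsToStreamExecutorLayouts(dnums, input, filter, output)
      .status();
}
```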
+StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + stream_executor::dnn::DataLayout input, + stream_executor::dnn::FilterLayout filter, + stream_executor::dnn::DataLayout output); + +// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index a0c785ed913109..931c0bffab8503 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -70,11 +70,14 @@ class Thunk { Kind kind() const { return kind_; } const HloInstruction* hlo_instruction() const { return hlo_instruction_; } - // Prepares for executing the thunk. This method is called only once over - // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a - // kernel, which is the same in every execution. - virtual tensorflow::Status Initialize(const GpuExecutable& executable) { - return tensorflow::Status::OK(); + // Prepares the thunk for execution on the given StreamExecutor. + // + // This may be called multiple times. Its main purpose is to give us a chance + // to do initialization outside of ExecuteOnStream() so that the + // time spent initializing doesn't count towards our execution profile. + virtual Status Initialize(const GpuExecutable& /*executable*/, + se::StreamExecutor* /*executor*/) { + return Status::OK(); } // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) @@ -89,21 +92,13 @@ class Thunk { return false; } - // Indicates whether thunks scheduled after this one should wait for this one - // to complete before running. For example, a convolution thunk creates a - // scratch allocator, then kicks off a convolution in cudnn via the stream - // executor. When the stream executor call returns, the scratch allocator goes - // out of scope, and the scratch memory is deallocated. In this case, the - // convolution thunk needs to return true so that future thunks wait for the - // convolution thunk to avoid reusing the deallocated memory until the - // convolution thunk is done with it. - virtual bool ShouldBlockFutureThunks() { return false; } - // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. - virtual tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) = 0; + // + // Precondition: Initialize(stream->parent()) has been called. + virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) = 0; private: Kind kind_; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index ecb54857ccc40e..97cb04c38fbf18 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -20,8 +20,8 @@ limitations under the License. 
namespace xla { namespace gpu { -tensorflow::Status TupleThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { std::vector tuple_element_buffer_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { tuple_element_buffer_addresses.push_back( @@ -40,7 +40,7 @@ tensorflow::Status TupleThunk::ExecuteOnStream( tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), sizeof(void*) * tuple_element_buffer_addresses.size()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 8b459c29a136a6..951f809b51937c 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -45,8 +45,8 @@ class TupleThunk : public Thunk { TupleThunk(const TupleThunk&) = delete; TupleThunk& operator=(const TupleThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index a9f3d619a3ffd6..30b9640c4c75da 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,9 +34,11 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -Status WhileThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); +Status WhileThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR( + condition_thunk_sequence_->Initialize(executable, executor)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index e589ca78a7ea00..22176685a92df9 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,7 +45,8 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream) override; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index e6caec8625f0d6..7749201cbceece 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -144,7 +144,7 @@ class ExprTree { TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), tagged_instructions)); } - return tensorflow::Status::OK(); + return Status::OK(); } private: @@ -169,7 +169,7 @@ class MatcherBase { // Attempts to match each ExprTree in 'expr_trees_'. 
// Returns OK on the first successful match, error status otherwise. - virtual tensorflow::Status Run() { + virtual Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { status = MatchExprTree(expr_tree); @@ -201,7 +201,7 @@ class MatcherBase { } else if (type == S64) { *const_value = literal.GetFirstElement(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr GetTaggedInstruction( @@ -315,7 +315,7 @@ class WhileConditionComputationMatcher : public MatcherBase { gte_fusion_param0->name().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; @@ -379,7 +379,7 @@ class WhileInitOperandMatcher : public MatcherBase { GetTaggedInstruction("loop_start", tagged_instructions)); TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - return tensorflow::Status::OK(); + return Status::OK(); } const HloInstruction* while_hlo_; @@ -457,8 +457,8 @@ class WhileBodyComputationMatcher : public MatcherBase { return InvalidArgument("Unexpected tuple index instruction : %s", inst->name().c_str()); } else if (tag == "loop_increment") { - // Parse the constant which represents the loop induction variable - // increment value. + // ParseHloString the constant which represents the loop induction + // variable increment value. TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); } else if (tag == "param0" && inst != computation_->parameter_instruction(0)) { @@ -477,7 +477,7 @@ class WhileBodyComputationMatcher : public MatcherBase { } } } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 05017008e2ddbe..acf661148699da 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -82,7 +82,8 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // instructions. Sets the computation as the entry to an HLO module and returns // the module. std::unique_ptr MakeBigGraph() { - auto module = MakeUnique("BigGraph"); + HloModuleConfig config; + auto module = MakeUnique("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 3dd4c4a0794e5c..06a5e0351b6327 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -32,7 +31,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); const std::vector& instruction_sequence = @@ -47,7 +46,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*module_sequence=*/nullptr); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, @@ -73,11 +72,11 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. - FlatMap> live_buffers; - FlatMap> used_buffers; + FlatMap> live_buffers; + FlatMap> used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, - const LogicalBuffer* buffer) { + const BufferValue* buffer) { if (!IgnoreBuffer(buffer)) { VLOG(4) << " Adding user " << user->name() << " to buffer " << buffer->ToString(); @@ -96,7 +95,7 @@ Status HeapSimulator::RunComputation( const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); for (const HloInstruction* user : instruction->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { - for (const LogicalBuffer* buffer : buffer_set) { + for (const BufferValue* buffer : buffer_set) { add_user_to_buffer(user, buffer); } } else { @@ -104,12 +103,12 @@ Status HeapSimulator::RunComputation( // alive. It only needs the buffers that relate to the element its // extracting, and the tuple it's extracting from, but not the buffers // for the other elements. 
- for (const LogicalBuffer* buffer : points_to.element({})) { + for (const BufferValue* buffer : points_to.element({})) { add_user_to_buffer(user, buffer); } const PointsToSet& gte_points_to = points_to_analysis.GetPointsToSet(user); - for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) { + for (const BufferValue* buffer : gte_points_to.CreateFlattenedSet()) { add_user_to_buffer(user, buffer); } } @@ -117,24 +116,25 @@ Status HeapSimulator::RunComputation( } const HloInstruction* root = computation.root_instruction(); - auto output_source_buffers = - points_to_analysis.GetPointsToSet(root).CreateFlattenedSet(); + BufferValueCompactPointerSet output_source_buffers = + ToBufferValueCompactPointerSet( + points_to_analysis.GetPointsToSet(root).CreateFlattenedSet()); - std::vector dead_buffers_to_free; - std::vector operand_buffers_to_free; + std::vector dead_buffers_to_free; + std::vector operand_buffers_to_free; for (const HloInstruction* instruction : instruction_sequence) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); VLOG(3) << "Instruction: " << instruction->ToString(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { VLOG(4) << " Defines: " << buffer->ToString() << (IgnoreBuffer(buffer) ? " (Ignored)" : ""); } dead_buffers_to_free.clear(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -161,7 +161,7 @@ Status HeapSimulator::RunComputation( // have no instructions left to visit are moved from live_buffers to // operand_buffers_to_free. operand_buffers_to_free.clear(); - for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) { + for (const BufferValue* operand_buffer : used_buffers[instruction]) { if (IgnoreBuffer(operand_buffer)) { continue; } @@ -177,7 +177,7 @@ Status HeapSimulator::RunComputation( } // Sort to get a deterministic iteration order. std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), - [](const LogicalBuffer* x, const LogicalBuffer* y) { + [](const BufferValue* x, const BufferValue* y) { return x->id() < y->id(); }); @@ -188,7 +188,7 @@ Status HeapSimulator::RunComputation( // // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer // that we should assign. - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -199,12 +199,12 @@ Status HeapSimulator::RunComputation( // we must be the last user of the buffer. 
bool shared = false; if (options_.may_reuse_operand_buffers) { - for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + for (const BufferValue* operand_buffer : operand_buffers_to_free) { if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && buffer->instruction()->opcode() != HloOpcode::kCopy && - CanShareOperandBufferWithUser( + points_to_analysis.CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { + buffer->instruction(), buffer->index())) { VLOG(3) << " Sharing: " << buffer->ToString() << " with " << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); @@ -248,11 +248,11 @@ Status HeapSimulator::RunComputation( // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. - for (const LogicalBuffer* buffer : dead_buffers_to_free) { + for (const BufferValue* buffer : dead_buffers_to_free) { VLOG(3) << " Freeing dead: " << buffer->ToString(); Free(buffer, instruction); } - for (const LogicalBuffer* buffer : operand_buffers_to_free) { + for (const BufferValue* buffer : operand_buffers_to_free) { VLOG(3) << " Freeing operand: " << buffer->ToString(); Free(buffer, instruction); } @@ -261,10 +261,10 @@ Status HeapSimulator::RunComputation( // Any remaining live buffers must be entry parameters or output source // buffers, which had a nullptr sentry added. Free them now, in a // deterministic order. - std::vector to_free; + std::vector to_free; to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { - const LogicalBuffer* buffer = buffer_pending.first; + const BufferValue* buffer = buffer_pending.first; const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; @@ -272,10 +272,10 @@ Status HeapSimulator::RunComputation( } std::sort(to_free.begin(), to_free.end(), - [](const LogicalBuffer* x, const LogicalBuffer* y) { + [](const BufferValue* x, const BufferValue* y) { return x->id() < y->id(); }); - for (const LogicalBuffer* buffer : to_free) { + for (const BufferValue* buffer : to_free) { VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); } @@ -285,7 +285,7 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), @@ -297,7 +297,7 @@ HeapSimulator::HeapSimulator( HeapSimulator::~HeapSimulator() {} -bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { +bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { // Buffers for constants are ignored unless the alloc_constants option is // set. Also ignore buffers that we're not meant to assign. // @@ -311,7 +311,7 @@ bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { } // Alloc always calls the underlying heap algorithm. 
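The sharing check above is the behavioral change in this hunk: the free function `CanShareOperandBufferWithUser` is replaced by a query on `TuplePointsToAnalysis`. Distilled into a standalone predicate for clarity (a sketch, not code from the patch):

```c++
// A newly defined buffer may reuse an operand's buffer only if its defining
// instruction actually uses that operand, is not a kCopy, and the points-to
// analysis allows the user to alias the operand's buffer. The caller must
// still ensure this is the operand buffer's last use.
bool MayShare(const TuplePointsToAnalysis& points_to_analysis,
              const BufferValue& defined, const BufferValue& operand) {
  return defined.instruction()->IsUserOf(operand.instruction()) &&
         defined.instruction()->opcode() != HloOpcode::kCopy &&
         points_to_analysis.CanShareOperandBufferWithUser(
             operand.instruction(), operand.index(), defined.instruction(),
             defined.index());
}
```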
-void HeapSimulator::Alloc(const LogicalBuffer* buffer, +void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { CHECK(allocated_buffers_.count(buffer) == 0) << "Alloc called on allocated buffer: " << *buffer; @@ -331,7 +331,7 @@ void HeapSimulator::Alloc(const LogicalBuffer* buffer, // buffers whose group liveness has expired. Shared group liveness is tracked // by maintaining a refcount; the Free call on the last buffer in the group // causes Free to be called on the underlying algorithm. -void HeapSimulator::Free(const LogicalBuffer* buffer, +void HeapSimulator::Free(const BufferValue* buffer, const HloInstruction* instruction) { auto shared_it = shared_buffers_.find(buffer); if (shared_it != shared_buffers_.end()) { @@ -362,8 +362,8 @@ void HeapSimulator::Free(const LogicalBuffer* buffer, // The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to // Alloc. The 'shared' buffer must be a previously allocated or shared buffer. // Both 'buffer' and 'shared' will be associated with the same SharedGroup. -void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, - const LogicalBuffer* shared, +void HeapSimulator::ShareBuffer(const BufferValue* buffer, + const BufferValue* shared, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; @@ -374,7 +374,7 @@ void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, CHECK(freed_buffers_.count(shared) == 0) << "ShareBuffer called on freed shared buffer: " << *shared; - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; auto shared_it = shared_buffers_.find(shared); if (shared_it != shared_buffers_.end()) { // The 'shared' buffer already has a group; it might be the canonical, but @@ -408,7 +408,7 @@ HeapSimulator::Result HeapSimulator::Finish() { // collecting statistics, e.g. NoFragmentationStatsHeap. 
if (!result.chunk_map.empty()) { for (const auto& share_pair : shared_buffers_) { - const LogicalBuffer* buffer = share_pair.first; + const BufferValue* buffer = share_pair.first; std::shared_ptr group = share_pair.second; if (buffer != group->canonical) { // The canonical must already exist in the chunk_map, since we called @@ -437,9 +437,9 @@ HeapSimulator::Result HeapSimulator::Finish() { } void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* share_with_canonical) { + const BufferValue* share_with_canonical) { HeapSimulatorTrace::Event* event = debug_trace_.add_events(); event->set_kind(kind); event->set_buffer_id(buffer->id()); @@ -453,14 +453,14 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, } } -void NoFragmentationStatsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { current_heap_size_ += size; if (current_heap_size_ > max_heap_size_) { max_heap_size_ = current_heap_size_; } } -void NoFragmentationStatsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } @@ -472,12 +472,12 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() { return result; } -void DecreasingSizeRunsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Alloc(const BufferValue* buffer, int64 size) { SetMode(kAlloc); run_.emplace_back(Op{buffer, size}); } -void DecreasingSizeRunsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Free(const BufferValue* buffer, int64 size) { CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer; SetMode(kFree); run_.emplace_back(Op{buffer, size}); @@ -518,7 +518,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { run_.clear(); } -void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Alloc(const BufferValue* buffer, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -586,7 +586,7 @@ void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size}); } -void LazyBestFitHeap::Free(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Free(const BufferValue* buffer, int64 size) { auto alloc_it = result_.chunk_map.find(buffer); CHECK(alloc_it != result_.chunk_map.end()) << "Free called on non-allocated buffer: " << *buffer; diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 636f19dd39f097..8b2b43a37a5c41 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,11 +21,12 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -43,7 +44,7 @@ class HeapAlgorithm; // don't need to return the assignment of buffer offsets until the very end. class HeapSimulator { public: - // Chunk represents a contiguous piece of memory. Each LogicalBuffer will be + // Chunk represents a contiguous piece of memory. Each BufferValue will be // associated with a chunk in the assignment result. struct Chunk { int64 offset; @@ -55,7 +56,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap chunk_map; + tensorflow::gtl::FlatMap chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -81,7 +82,7 @@ class HeapSimulator { bool alloc_constants; // If 'buffers_to_assign' is provided, only those buffers are assigned // offsets, otherwise all buffers defined by the instructions are assigned. - const tensorflow::gtl::FlatSet* buffers_to_assign; + const BufferValueFlatSet* buffers_to_assign; }; // Run the heap simulation with the given algorithm, assuming the given @@ -97,7 +98,7 @@ class HeapSimulator { std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' @@ -109,7 +110,7 @@ class HeapSimulator { const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); private: @@ -118,7 +119,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. 
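With the public API now phrased in terms of `BufferValue` rather than `LogicalBuffer`, callers mostly just swap the type in their size functions. A minimal sketch of running the simulator over a whole module, assuming `module_sequence` and `points_to_analysis` were produced by earlier passes and that 64-byte alignment and 8-byte pointers are placeholder values:

```c++
// Sketch only: report the simulated peak heap size for a module.
StatusOr<int64> SimulatedHeapSize(
    const HloModule& module,
    const SequentialHloOrdering::HloModuleSequence& module_sequence,
    const TuplePointsToAnalysis& points_to_analysis) {
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };
  TF_ASSIGN_OR_RETURN(
      HeapSimulator::Result result,
      HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
                             MakeUnique<LazyBestFitHeap>(/*alignment=*/64)),
                         module, module_sequence, points_to_analysis, size_fn));
  return result.heap_size;
}
```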
HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence); ~HeapSimulator(); @@ -127,21 +128,21 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis); - bool IgnoreBuffer(const LogicalBuffer* buffer) const; - void Alloc(const LogicalBuffer* buffer, const HloInstruction* instruction); - void Free(const LogicalBuffer* buffer, const HloInstruction* instruction); - void ShareBuffer(const LogicalBuffer* buffer, const LogicalBuffer* shared, + bool IgnoreBuffer(const BufferValue* buffer) const; + void Alloc(const BufferValue* buffer, const HloInstruction* instruction); + void Free(const BufferValue* buffer, const HloInstruction* instruction); + void ShareBuffer(const BufferValue* buffer, const BufferValue* shared, const HloInstruction* instruction); Result Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* shared_with_canonical); + const BufferValue* shared_with_canonical); const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; - const LogicalBuffer::SizeFunction size_fn_; + const BufferValue::SizeFunction size_fn_; const Options options_; const SequentialHloOrdering::HloModuleSequence* module_sequence_; @@ -160,15 +161,15 @@ class HeapSimulator { // The shared_buffers_ map associates each shared buffer (including the // canonical) to its SharedGroup control block. struct SharedGroup { - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap> + tensorflow::gtl::FlatMap> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet allocated_buffers_; - tensorflow::gtl::FlatSet freed_buffers_; + tensorflow::gtl::FlatSet allocated_buffers_; + tensorflow::gtl::FlatSet freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -186,10 +187,10 @@ class HeapAlgorithm { virtual ~HeapAlgorithm() = default; // Alloc allocates a buffer of 'size' bytes. - virtual void Alloc(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Alloc(const BufferValue* buffer, int64 size) = 0; // Free de-allocates a previously allocated buffer. - virtual void Free(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Free(const BufferValue* buffer, int64 size) = 0; // Finish collects the buffer offset assignment results. Free may only be // called once, after the Alloc and Free calls. 
@@ -205,8 +206,8 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { NoFragmentationStatsHeap() = default; ~NoFragmentationStatsHeap() override = default; - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: @@ -223,14 +224,14 @@ class DecreasingSizeRunsHeap : public HeapAlgorithm { : algorithm_(std::move(algorithm)) {} ~DecreasingSizeRunsHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: // A single Alloc or Free operation that we've buffered in run_. struct Op { - const LogicalBuffer* buffer; + const BufferValue* buffer; int64 size; }; @@ -266,8 +267,8 @@ class LazyBestFitHeap : public HeapAlgorithm { LazyBestFitHeap(int64 alignment) : alignment_(alignment) {} ~LazyBestFitHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 688a271712ac24..6271652412c297 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -20,11 +20,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -38,7 +39,7 @@ const char kFree[] = "Free"; const char kFinish[] = "Finish"; // CallSequence records a sequence of Alloc/Free/Finish calls. -using CallSequence = std::vector>; +using CallSequence = std::vector>; // HeapCallRecorder is a dummy heap algorithm that simply records its calls. class HeapCallRecorder : public HeapAlgorithm { @@ -46,7 +47,7 @@ class HeapCallRecorder : public HeapAlgorithm { explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {} ~HeapCallRecorder() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override { + void Alloc(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kAlloc, buffer); // Instead of assigning a real offset, we set the cardinality of the Alloc // call. 
This isn't a valid assignment, but allows us to easily test for @@ -54,7 +55,7 @@ class HeapCallRecorder : public HeapAlgorithm { const int64 offset = result_.chunk_map.size(); result_.chunk_map.emplace(buffer, Chunk{offset, size}); } - void Free(const LogicalBuffer* buffer, int64 size) override { + void Free(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kFree, buffer); } Result Finish() override { @@ -76,7 +77,8 @@ class HeapSimulatorTracker { HeapSimulatorTracker( const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { - module_ = MakeUnique(name); + HloModuleConfig config; + module_ = MakeUnique(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -84,7 +86,7 @@ class HeapSimulatorTracker { // size of the buffers doesn't matter, so we always return 0. We rely on // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. - auto zero_size = [](const LogicalBuffer& buffer) { return 0; }; + auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = MakeUnique( MakeUnique(&actual_calls_)); result_ = HeapSimulator::Run( @@ -94,7 +96,8 @@ class HeapSimulatorTracker { } explicit HeapSimulatorTracker(const string& name) { - module_ = MakeUnique(name); + HloModuleConfig config; + module_ = MakeUnique(name, config); } // Similar to the single entry computation constructor above, but runs the @@ -115,9 +118,9 @@ class HeapSimulatorTracker { // Hack the size_fn so that it returns a decreasing value as we step through // the sequence. This lets us ensure the Alloc calls are in the sequence - // order. The Free calls are sorted by LogicalBuffer.id, which is at least + // order. The Free calls are sorted by BufferValue.id, which is at least // deterministic. - auto size_fn = [&reverse_position](const LogicalBuffer& buffer) { + auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; auto algorithm = MakeUnique( @@ -130,8 +133,8 @@ class HeapSimulatorTracker { HloModule* module() { return module_.get(); } // Returns the buffer defined at the given instruction and index. - const LogicalBuffer* BufferAt(const HloInstruction* instruction, - const ShapeIndex& index) const { + const BufferValue* BufferAt(const HloInstruction* instruction, + const ShapeIndex& index) const { return points_to_analysis_->GetBufferDefinedAt(instruction, index) .ConsumeValueOrDie(); } @@ -147,8 +150,8 @@ class HeapSimulatorTracker { const ShapeIndex& index_a, const HloInstruction* instruction_b, const ShapeIndex& index_b) { - const LogicalBuffer* a = BufferAt(instruction_a, index_a); - const LogicalBuffer* b = BufferAt(instruction_b, index_b); + const BufferValue* a = BufferAt(instruction_a, index_a); + const BufferValue* b = BufferAt(instruction_b, index_b); EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset) << *a << ", " << *b; } @@ -522,7 +525,7 @@ TEST_F(HeapSimulatorTest, WholeModule) { // Now the final cond less-than buffer is allocated. {kAlloc, tracker.BufferAt(cond_lt, {})}, - // The order of the remaining Free calls is based on the LogicalBuffer.id, + // The order of the remaining Free calls is based on the BufferValue.id, // which is deterministic, but not obvious. 
{kFree, tracker.BufferAt(param, {})}, {kFree, tracker.BufferAt(param, {0})}, @@ -544,40 +547,40 @@ TEST_F(HeapSimulatorTest, WholeModule) { class HeapAlgorithmTestBase : public ::testing::Test { protected: HeapAlgorithmTestBase() : builder_("heap_simulator_test") { - buffer_a_ = DummyLogicalBuffer(); - buffer_b_ = DummyLogicalBuffer(); - buffer_c_ = DummyLogicalBuffer(); - buffer_d_ = DummyLogicalBuffer(); - buffer_e_ = DummyLogicalBuffer(); - buffer_f_ = DummyLogicalBuffer(); - buffer_g_ = DummyLogicalBuffer(); - buffer_h_ = DummyLogicalBuffer(); - buffer_i_ = DummyLogicalBuffer(); + buffer_a_ = DummyBufferValue(); + buffer_b_ = DummyBufferValue(); + buffer_c_ = DummyBufferValue(); + buffer_d_ = DummyBufferValue(); + buffer_e_ = DummyBufferValue(); + buffer_f_ = DummyBufferValue(); + buffer_g_ = DummyBufferValue(); + buffer_h_ = DummyBufferValue(); + buffer_i_ = DummyBufferValue(); } ~HeapAlgorithmTestBase() override {} - const LogicalBuffer* buffer_a_; - const LogicalBuffer* buffer_b_; - const LogicalBuffer* buffer_c_; - const LogicalBuffer* buffer_d_; - const LogicalBuffer* buffer_e_; - const LogicalBuffer* buffer_f_; - const LogicalBuffer* buffer_g_; - const LogicalBuffer* buffer_h_; - const LogicalBuffer* buffer_i_; + const BufferValue* buffer_a_; + const BufferValue* buffer_b_; + const BufferValue* buffer_c_; + const BufferValue* buffer_d_; + const BufferValue* buffer_e_; + const BufferValue* buffer_f_; + const BufferValue* buffer_g_; + const BufferValue* buffer_h_; + const BufferValue* buffer_i_; private: - // Create a dummy LogicalBuffer to pass to the heap algorithm. - const LogicalBuffer* DummyLogicalBuffer() { - const LogicalBuffer::Id id = buffers_.size(); + // Create a dummy BufferValue to pass to the heap algorithm. + const BufferValue* DummyBufferValue() { + const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(const0, ShapeIndex{}, id)); + buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); return buffers_.back().get(); } HloComputation::Builder builder_; - std::vector> buffers_; + std::vector> buffers_; }; class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {}; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index aa6860880b7a13..1f7c1cffd324ad 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -147,6 +147,9 @@ message HloInstructionProto { repeated int64 called_computation_ids = 38; xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + string backend_config = 43; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h new file mode 100644 index 00000000000000..b15f1f24c60771 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Casting utilitiy functions for HLO instructions. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +template +using EnableIfDerivedFromHlo = + typename std::enable_if::value>::type; + +// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like +// RTTI if it turns out to be a performance issue. +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr or runtime information does not match. +// +// Similar to LLVM's cast. +template * = nullptr> +const T* Cast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + const T* casted = dynamic_cast(instruction); + CHECK(casted != nullptr); + return casted; +} + +// Non-const overload of Cast. +template * = nullptr> +T* Cast(HloInstruction* instruction) { + return const_cast( + Cast(const_cast(instruction))); +} + +// Works just like the Cast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's cast_or_null. +template * = nullptr> +const T* CastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? Cast(instruction) : nullptr; +} + +// Non-const overload of CastOrNull. +template * = nullptr> +T* CastOrNull(HloInstruction* instruction) { + return const_cast( + CastOrNull(const_cast(instruction))); +} + +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr, returns nullptr if runtime information does not match. +// +// Similar to LLVM's dyn_cast. +template * = nullptr> +const T* DynCast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + return dynamic_cast(instruction); +} + +// Non-const overload of DynCast. +template * = nullptr> +T* DynCast(HloInstruction* instruction) { + return const_cast( + DynCast(const_cast(instruction))); +} + +// Works just like the DynCast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's dyn_cast_or_null. +template * = nullptr> +const T* DynCastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? DynCast(instruction) : nullptr; +} + +// Non-const overload of DynCastOrNull. +template * = nullptr> +T* DynCastOrNull(HloInstruction* instruction) { + return const_cast( + DynCastOrNull(const_cast(instruction))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc new file mode 100644 index 00000000000000..436a9222342dd8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
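The casting helpers above mirror LLVM's `cast`/`dyn_cast` family; they differ only in how mismatches and null pointers are handled. A small usage sketch: the `instr` parameter is assumed, and `DummyInstruction` is the test-only subclass defined in the test file that follows.

```c++
// Sketch only: prefer DynCast when the concrete type is uncertain.
void LogIfDummy(HloInstruction* instr) {
  if (DummyInstruction* dummy = DynCast<DummyInstruction>(instr)) {
    VLOG(2) << "matched: " << dummy->ToString();  // runtime type matched
  }
  // Cast<DummyInstruction>(instr) would CHECK-fail on a mismatch or on
  // nullptr; DynCastOrNull additionally tolerates instr == nullptr.
}
```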
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" + +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class DummyInstruction : public HloInstruction { + public: + DummyInstruction() + : HloInstruction(HloOpcode::kConstant, ShapeUtil::MakeShape(F32, {})) {} +}; + +class AnotherDummyInstruction : public HloInstruction { + public: + AnotherDummyInstruction() + : HloInstruction(HloOpcode::kParameter, ShapeUtil::MakeShape(F32, {})) {} +}; + +TEST(HloCastingUtilsTest, CastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(Cast(null), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastOrNullDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = CastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(DynCast(null), ""); +} + +TEST(HloCastingUtilsTest, DynCastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = DynCastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h new file mode 100644 index 00000000000000..658643b427a962 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +class HloInstruction; +class HloComputation; +class HloModule; + +// Data structure used to track the cloning of HloInstruction and HloComputation +// objects. +class HloCloneContext { + public: + // Creates a new HloCloneContext object to clone HloInstruction and + // HloComputation objects to be added to the module specified as argument. + // The suffix string will be appended to computation names. + explicit HloCloneContext(HloModule* module, const string& suffix = "") + : module_(module), suffix_(suffix) {} + + HloModule* module() const { return module_; } + + const string& suffix() const { return suffix_; } + + void MapInstruction(const HloInstruction* old_instruction, + HloInstruction* new_instruction) { + instructions_[old_instruction] = new_instruction; + } + + void MapComputation(const HloComputation* old_computation, + HloComputation* new_computation) { + computations_[old_computation] = new_computation; + } + + // Finds the new instruction mapped to its old copy, or return nullptr in case + // it is not found. + HloInstruction* FindInstruction(const HloInstruction* old_instruction) const { + return FindOrDefault(instructions_, old_instruction, nullptr); + } + + // Finds the new computation mapped to its old copy, or return nullptr in case + // it is not found. + HloComputation* FindComputation(const HloComputation* old_computation) const { + return FindOrDefault(computations_, old_computation, nullptr); + } + + // Retrieves the new instruction mapped to its old copy, or fail if not found. + HloInstruction* GetInstruction(const HloInstruction* old_instruction) const { + return FindOrDie(instructions_, old_instruction); + } + + // Retrieves the new computation mapped to its old copy, or fail if not found. + HloComputation* GetComputation(const HloComputation* old_computation) const { + return FindOrDie(computations_, old_computation); + } + + const tensorflow::gtl::FlatMap& + cloned_instructions() const { + return instructions_; + } + + const tensorflow::gtl::FlatMap& + cloned_computations() const { + return computations_; + } + + private: + HloModule* module_; + string suffix_; + tensorflow::gtl::FlatMap + instructions_; + tensorflow::gtl::FlatMap + computations_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 594413e88fb26e..b61eabbbf52624 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -347,6 +347,11 @@ std::list HloComputation::MakeEmbeddedComputationsList() // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after // construction. 
+ // + // TODO(b/78350259): This violates const-correctness, since while the original + // computation is not returned, we still retrieve non-const computations from + // a const one. Consider also avoiding const for HloComputation, or review XLA + // for const-correctness of non-HloInstruction* types like this. ComputeComputationPostOrder(const_cast(this), &visited, &post_order); @@ -360,25 +365,38 @@ std::list HloComputation::MakeEmbeddedComputationsList() string HloComputation::ToString(const HloPrintOptions& options) const { std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } - if (options.print_percent()) { - s << "%"; + + if (!options.is_in_nested_computation()) { + if (options.print_percent()) { + s << "%"; + } + s << name() << " "; } - s << name(); + if (options.print_program_shape()) { - s << " " << ShapeUtil::HumanString(ComputeProgramShape()); - } - s << " {\n"; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << ShapeUtil::HumanString(ComputeProgramShape()) << " "; + } + s << "{\n"; + { + // Print the instructions in this computation. + HloPrintOptions new_options = options; + new_options.set_indent_amount(options.indent_amount() + 1) + .set_is_in_nested_computation(true); + CanonicalNameMap name_map; + for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (int i = 0; i < new_options.indent_amount(); i++) { + s << " "; + } + s << (instruction == root_instruction_ ? "ROOT " : "") + << instruction->ToStringWithCanonicalNameMap(new_options, &name_map) + << "\n"; } - s << " " << (instruction == root_instruction_ ? "ROOT " : "") - << instruction->ToString(options) << "\n"; } + for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } s << "}"; return s.str(); @@ -402,27 +420,37 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map) { - std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { TF_ASSIGN_OR_RETURN( std::unique_ptr instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + HloInstruction::CreateFromProto(instruction_proto, instruction_map, + computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); instruction_map[instruction_proto.id()] = instruction.get(); + to_proto_id[instruction.get()] = instruction_proto.id(); instructions.push_back(std::move(instruction)); } TF_RET_CHECK(proto.root_id() != -1); TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); HloInstruction* root = instruction_map.at(proto.root_id()); + + // Sort the instructions in the proto id's order. 
+ std::sort(instructions.begin(), instructions.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + return WrapUnique(new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); @@ -723,18 +751,24 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix, - HloModule* module) { +std::unique_ptr HloComputation::Clone( + const string& suffix, HloCloneContext* context) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, suffix); + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, const string& suffix) { + HloCloneContext* context, const string& suffix) { + std::unique_ptr context_ptr; + if (context == nullptr) { + context_ptr = MakeUnique(parent(), suffix); + context = context_ptr.get(); + } + // Look up instr in the replacements map, and return either the replacement, // or instr, if the replacement isn't present. // @@ -756,24 +790,19 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } - std::unordered_map clone_map; std::vector> instructions; - std::unique_ptr new_instr = nullptr; + std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); - // If replaced_operand is null, that means 'replacements' asked us not to - // include operand in the new computation. But we can't do that, because - // operand is used by instr. CHECK_NE(replaced_operand, nullptr) << "replacements map tried to eliminate a used instruction " << operand->ToString() << ", used by " << instr->ToString(); - new_operands.push_back(FindOrDie(clone_map, replaced_operand)); + new_operands.push_back(context->GetInstruction(replaced_operand)); } new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, module); - InsertOrDie(&clone_map, instr, new_instr.get()); + instr->CloneWithNewOperands(instr->shape(), new_operands, context); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -781,27 +810,25 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); + /*root_instruction=*/context->GetInstruction( + replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(clone_map, instr); + HloInstruction* new_instr = context->GetInstruction(instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - - // successor may not be in clone_map, because it might have been + // successor may not have been remapped, because it might have been // removed by the replacements map. - if (replaced_successor == nullptr) { - continue; + if (replaced_successor != nullptr) { + TF_CHECK_OK(new_instr->AddControlDependencyTo( + context->GetInstruction(replaced_successor))); } - - TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(clone_map, replaced_successor))); } } - + context->MapComputation(this, result.get()); // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before + // destroyed. 
HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists // and cause use-after-frees. for (auto& kv : replacements) { @@ -809,7 +836,6 @@ std::unique_ptr HloComputation::CloneWithReplacements( new_instr->DetachFromOperands(); } } - return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 9d3f6e9a2c2efd..0da4a305f3d5d6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -49,9 +50,20 @@ class HloModule; // Describes a computation at the HLO level. // -// An HloComputation contains a directed acyclic graph of HLO instructions. The -// computation has a single root instruction which produces the output of the -// computation. +// You can think of an HloComputation like a function. It has some inputs +// (parameters) and returns exactly one value (the value of its root node). If +// you want to return multiple values, you can return a tuple. +// +// The instructions inside of a computation do not have an explicit total order. +// Instead, they have a partial order determined by their data and control +// dependencies. +// +// An HloModule contains one "entry computation" -- this is like main() in a C +// program. Every other computation inside of a module is attached to one or +// more HloInstructions, as a "nested computation". For example, the kMap +// instruction has a nested computation and "applies" it to every element of its +// input, elementwise. (That is, the input [x, y, z] is transformed to [f(x), +// f(y), f(z)].) class HloComputation { public: // Builder class for HloComputation. @@ -157,14 +169,12 @@ class HloComputation { // Creates a computation from the given proto. Arguments: // - // module: the module which will contain the computation. The newly created - // computation is *not* added to the module, however. // proto: the proto to convert from. // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. @@ -291,11 +301,11 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). + // If the clone context is specified, it will be populated with the cloned + // object mappings, and its module() will be used to add new computations + // into. 
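[Editor's note] The rewritten class comment above describes an `HloComputation` as a function: parameters in, exactly one root value out, with a tuple root used to "return" several values, and the module's entry computation playing the role of `main()`. A small builder sketch along those lines follows; the names and shapes are made up for illustration, and `module` is assumed to come from a helper such as the `CreateNewModule()` used by the tests later in this diff.

```c++
// Sketch: a computation (p0, p1) -> (p0 + p1, p0 * p1). The single root is a
// tuple, which is how a computation returns more than one value.
HloComputation::Builder builder("add_and_mul");
const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
HloInstruction* p0 = builder.AddInstruction(
    HloInstruction::CreateParameter(0, r0f32, "p0"));
HloInstruction* p1 = builder.AddInstruction(
    HloInstruction::CreateParameter(1, r0f32, "p1"));
HloInstruction* add = builder.AddInstruction(
    HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, p0, p1));
HloInstruction* mul = builder.AddInstruction(
    HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, p0, p1));
builder.AddInstruction(HloInstruction::CreateTuple({add, mul}));
// The entry computation plays the role of main() for the module.
HloComputation* entry = module->AddEntryComputation(builder.Build());
```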
std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr); + HloCloneContext* context = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -305,7 +315,7 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 7b7588f4ba9aa6..25469a54c48f4f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -550,6 +550,108 @@ TEST_F(HloComputationTest, Reachability) { EXPECT_FALSE(reachability->IsReachable(constant2, copy)); } +TEST_F(HloComputationTest, Stringification) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloComputationTest, StringificationIndent) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = + HloPrintOptions().set_print_metadata(false).set_indent_amount(2); + 
EXPECT_EQ(computation->ToString(options), + R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"); +} + +TEST_F(HloComputationTest, StringificationCanonical) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); + + options = HloPrintOptions().Canonical(); + EXPECT_EQ(computation->ToString(options), R"(TransposeDot { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 7b552ee5b1798c..5d05ccfc0b223d 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc 
index 44e4f75f75b275..94c9c7eabcc99d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -142,19 +142,25 @@ Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) { } Status HloCostAnalysis::HandleParameter(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -329,6 +335,7 @@ Status HloCostAnalysis::HandleSelectAndScatter( Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { // A bitcast does no computation and touches no memory. current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -555,11 +562,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - // We can't do anything sane with CustomCalls, since we don't know what they - // do, and returning an error status will stop iteration over this - // computation, which is probably also not what we want. So just punt and - // return OK. This will cause all of the properties to be reported as 0, - // which is fine. + // Mark applicable fields as "unknown", since we don't know what CustomCall + // does. This is better than returning an error, which would stop iteration, + // and therefore would prevent us from getting *any* stats for a computation + // which contains a CustomCall. + current_properties_[kOptimalSecondsKey] = -1; + current_properties_[kBytesAccessedKey] = -1; + current_properties_[kFlopsKey] = -1; current_should_compute_bottleneck_time_ = false; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 3d055b327ee920..16fdda8a8b9ade 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -20,16 +20,13 @@ limitations under the License. 
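[Editor's note] The `HandleCustomCall` change above stops reporting CustomCall as free and instead writes -1 into the flop, bytes-accessed, and optimal-seconds properties to mean "unknown". Code that aggregates per-instruction costs would then need to treat the sentinel explicitly; the sketch below assumes the usual per-instruction accessors on `HloCostAnalysis` (e.g. `flop_count(const HloInstruction&)`), which are not shown in this excerpt.

```c++
// Sketch only: skip instructions whose cost is marked unknown (-1) when
// aggregating, instead of silently adding a negative number.
int64 total_flops = 0;
for (const HloInstruction* instr : computation->instructions()) {
  const int64 flops = analysis.flop_count(*instr);
  if (flops < 0) continue;  // CustomCall reports -1, i.e. "unknown".
  total_flops += flops;
}
```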
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/logging.h" @@ -58,11 +55,10 @@ class HloCostAnalysisTest : public ::testing::Test { // whitebox accesses to the user computation built from the client, // as shown in the BuildHloGraph functions below. service_(static_cast(ClientLibrary::GetXlaService( - static_cast(client_)->platform()))), - computation_tracker_(service_->computation_tracker()) { + static_cast(client_)->platform()))) { // Create a computation for a unary user function: x => exp(x + 0.5) { - ComputationBuilder builder(client_, "add_and_exp"); + XlaBuilder builder("add_and_exp"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto half = builder.ConstantR0(0.5); builder.Exp(builder.Add(x, half)); @@ -73,7 +69,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary user function: (x, y) => x + y { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -84,7 +80,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) { - ComputationBuilder builder(client_, "sigmoid"); + XlaBuilder builder("sigmoid"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = builder.ConstantR0(1.0); builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); @@ -95,7 +91,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary max function: (x, y) => max (x, y) { - ComputationBuilder builder(client_, "max"); + XlaBuilder builder("max"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Max(x, y); @@ -106,7 +102,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary GT function: (x, y) => x > y { - ComputationBuilder builder(client_, "gt"); + XlaBuilder builder("gt"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Gt(x, y); @@ -117,35 +113,30 @@ class HloCostAnalysisTest : public ::testing::Test { } // Build HLO graph from the given builder and return the HLO module. 
- std::unique_ptr BuildHloGraph(ComputationBuilder* builder) { + std::unique_ptr BuildHloGraph(XlaBuilder* builder) { auto computation_status = builder->Build(); TF_CHECK_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto user_computation_status = - computation_tracker_.Resolve(computation.handle()); - TF_CHECK_OK(user_computation_status.status()); - auto user_computation = user_computation_status.ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - return std::move( - computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig()) - .ValueOrDie()); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); + return HloModule::CreateFromProto(computation.proto(), config) + .ConsumeValueOrDie(); } Client* client_; Service* service_; - const ComputationTracker& computation_tracker_; // User computations used for higher order operations (e.g., Map, Reduce). - Computation add_; - Computation add_and_exp_; - Computation sigmoid_; - Computation max_; - Computation gt_; + XlaComputation add_; + XlaComputation add_and_exp_; + XlaComputation sigmoid_; + XlaComputation max_; + XlaComputation gt_; }; TEST_F(HloCostAnalysisTest, MatrixMultiply) { - ComputationBuilder builder(client_, "matrix_multiply"); + XlaBuilder builder("matrix_multiply"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); auto result = builder.Dot(lhs, rhs); @@ -167,7 +158,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { } TEST_F(HloCostAnalysisTest, Map) { - ComputationBuilder builder(client_, "map"); + XlaBuilder builder("map"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); auto result = builder.Map({input}, add_and_exp_, {0}); @@ -184,7 +175,7 @@ TEST_F(HloCostAnalysisTest, Map) { } TEST_F(HloCostAnalysisTest, Convolution) { - ComputationBuilder builder(client_, "convolution"); + XlaBuilder builder("convolution"); auto input = builder.Parameter( 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, @@ -213,7 +204,7 @@ TEST_F(HloCostAnalysisTest, Convolution) { } TEST_F(HloCostAnalysisTest, Reduce) { - ComputationBuilder builder(client_, "reduce"); + XlaBuilder builder("reduce"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = @@ -231,7 +222,7 @@ TEST_F(HloCostAnalysisTest, Reduce) { } TEST_F(HloCostAnalysisTest, ReduceWindow) { - ComputationBuilder builder(client_, "reduce_window"); + XlaBuilder builder("reduce_window"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, @@ -248,7 +239,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { } TEST_F(HloCostAnalysisTest, SelectAndScatter) { - ComputationBuilder builder(client_, "select_and_scatter"); + XlaBuilder builder("select_and_scatter"); auto operand = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = @@ -269,7 +260,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { } TEST_F(HloCostAnalysisTest, Broadcast) { - ComputationBuilder b(client_, "broadcast"); + XlaBuilder b("broadcast"); b.Broadcast(b.ConstantR0(42), {10, 7}); auto hlo_module = BuildHloGraph(&b); HloCostAnalysis analysis(ShapeSize); @@ -280,7 +271,7 @@ TEST_F(HloCostAnalysisTest, Broadcast) { // Calculates the computation 
cost of a graph with more than one HLO node. TEST_F(HloCostAnalysisTest, FullyConnectedForward) { - ComputationBuilder builder(client_, "fully_connected_forward"); + XlaBuilder builder("fully_connected_forward"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); auto weight = @@ -305,7 +296,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis conv_analysis(ShapeSize); { - ComputationBuilder builder(client_, "conv_looking_matmul"); + XlaBuilder builder("conv_looking_matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), "input"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), @@ -318,7 +309,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis matmul_analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); auto rhs = @@ -370,8 +361,8 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -412,8 +403,8 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -427,7 +418,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); auto tuple = builder.Tuple({x, y}); @@ -443,7 +434,7 @@ TEST_F(HloCostAnalysisTest, TupleCost) { } TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { - ComputationBuilder builder(client_, "BaseDilatedConvolution"); + XlaBuilder builder("BaseDilatedConvolution"); auto input = builder.Parameter( 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, @@ -458,7 +449,7 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { auto result = builder.ConvGeneralDilated( input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, - ComputationBuilder::CreateDefaultConvDimensionNumbers(2)); + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); // Run HLO cost analysis. 
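[Editor's note] Each of these tests follows the same three-step pattern: build an `XlaComputation` with `XlaBuilder`, turn its proto into an `HloModule` via the fixture's `BuildHloGraph()`, then walk the graph with an `HloCostAnalysis` visitor. A condensed sketch of that pattern is below; `ASSERT_IS_OK`, `ShapeSize`, and the property accessors are assumed from the elided parts of this test file rather than copied from any single test.

```c++
// Condensed sketch of the common test pattern; assumes it runs inside
// HloCostAnalysisTest, which provides BuildHloGraph() and ShapeSize.
XlaBuilder builder("elementwise_add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {16}), "x");
builder.Add(x, x);

std::unique_ptr<HloModule> hlo_module = BuildHloGraph(&builder);
HloCostAnalysis analysis(ShapeSize);
ASSERT_IS_OK(
    hlo_module->entry_computation()->root_instruction()->Accept(&analysis));

// One flop per output element for an elementwise add over f32[16].
EXPECT_EQ(analysis.flop_count(), 16);
EXPECT_GT(analysis.bytes_accessed(), 0);
```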
auto hlo_module = BuildHloGraph(&builder); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 9a89888480b8c7..0fb65c845a6d44 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -162,6 +162,17 @@ StatusOr MakeConcatHlo(ArraySlice operands, HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); } +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN( + Shape dot_shape, + ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); + return computation->AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -269,7 +280,7 @@ StatusOr BroadcastZeros( StatusOr> CreateComputationWithSignature( ArraySlice domain, const Shape& range, tensorflow::StringPiece name) { - HloComputation::Builder b(name.ToString()); + HloComputation::Builder b{std::string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index c9a7361a6af0c2..49b1402d689a74 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -97,6 +97,11 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, StatusOr MakeConcatHlo( tensorflow::gtl::ArraySlice operands, int64 dimension); +// Creates a Dot HLO instruction and adds it to the computation containing `lhs` +// and `rhs` (both must be in the same computation). +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers); + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 6b681a5bf6f34b..7e7c4f95fed737 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -19,27 +19,32 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { using tensorflow::gtl::ArraySlice; -std::unique_ptr CreateModuleWithProgramShape( - PrimitiveType primitive_type, ArraySlice input_shape_dims, - ArraySlice output_shape_dims, HloInstruction** param, - HloComputation** entry_computation) { - Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); - Shape output_shape = ShapeUtil::MakeShape(primitive_type, output_shape_dims); - std::unique_ptr module = MakeUnique("test"); - *entry_computation = module->AddEntryComputation( - CreateComputationWithSignature({&input_shape}, output_shape, "entry") - .ValueOrDie()); - *param = (*entry_computation)->parameter_instruction(0); - return module; -} - -TEST(HloCreationUtilsTest, CollapseFirst1Dim) { +class HloCreationUtilsTest : public HloTestBase { + protected: + static std::unique_ptr CreateModuleWithProgramShape( + PrimitiveType primitive_type, ArraySlice input_shape_dims, + ArraySlice output_shape_dims, HloInstruction** param, + HloComputation** entry_computation) { + Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); + Shape output_shape = + ShapeUtil::MakeShape(primitive_type, output_shape_dims); + auto module = CreateNewModule("test"); + *entry_computation = module->AddEntryComputation( + CreateComputationWithSignature({&input_shape}, output_shape, "entry") + .ValueOrDie()); + *param = (*entry_computation)->parameter_instruction(0); + return module; + } +}; + +TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* entry_computation; @@ -59,7 +64,7 @@ TEST(HloCreationUtilsTest, CollapseFirst1Dim) { CHECK_EQ(*result_literal, *Literal::CreateR1({3, 4})); } -TEST(HloCreationUtilsTest, CollapseFirst2Dims) { +TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; @@ -84,7 +89,7 @@ TEST(HloCreationUtilsTest, CollapseFirst2Dims) { {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } -TEST(HloCreationUtilsTest, Prepend1DegenerateDim) { +TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; @@ -104,7 +109,7 @@ TEST(HloCreationUtilsTest, Prepend1DegenerateDim) { CHECK_EQ(*result_literal, *Literal::CreateR2({{9, 10}})); } -TEST(HloCreationUtilsTest, Prepend2DegenerateDims) { +TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; @@ -124,7 +129,7 @@ TEST(HloCreationUtilsTest, Prepend2DegenerateDims) { CHECK_EQ(*result_literal, *Literal::CreateR3({{{9, 10}}})); } -TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { +TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; @@ -144,7 +149,7 @@ TEST(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { CHECK_EQ(*result_literal, *Literal::CreateR2({{9}})); } -TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { +TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; @@ -166,7 +171,7 @@ TEST(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { *Literal::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } -TEST(HloCreationUtilsTest, PadVectorWithZeros) { +TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* 
param; HloComputation* entry_computation; @@ -187,7 +192,7 @@ TEST(HloCreationUtilsTest, PadVectorWithZeros) { CHECK_EQ(*result_literal, *Literal::CreateR1({0, 0, 0, 3, 4, 0})); } -TEST(HloCreationUtilsTest, BroadcastZeros_S32) { +TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; @@ -208,7 +213,7 @@ TEST(HloCreationUtilsTest, BroadcastZeros_S32) { CHECK_EQ(*result_literal, *Literal::CreateR2({{0, 0}, {0, 0}})); } -TEST(HloCreationUtilsTest, BroadcastZeros_F32) { +TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 3b22c93733af29..dab946a099fa00 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -26,12 +26,14 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -40,16 +42,16 @@ namespace { // Find and combine identical constants. Constants are identical if they have // the same type and value. -bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { - bool changed = false; - +StatusOr CombineConstants(HloComputation* computation, + bool is_layout_sensitive) { + TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); // Map from ShortDebugString of the layoutless shape of the constant to the // set of constant instructions with that shape. Layoutless shape is used to // bin possible common constants together to reduce number of constant // comparisons. If we end up having too many constant comparisons, a more // precise binning might have to be used. std::multimap constants; - + int64 combined = 0; auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { HloInstruction* instruction = *inst_it; @@ -69,7 +71,8 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (instruction->literal() == it->second->literal()) { + if (instruction->literal() == it->second->literal() && + domain_map->InSameDomain(it->second, instruction)) { match = it->second; break; } @@ -80,12 +83,27 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { // Match found, replace this instruction with the one in the multimap. TF_CHECK_OK(instruction->ReplaceAllUsesWith(match)); TF_CHECK_OK(computation->RemoveInstruction(instruction)); - changed = true; + ++combined; } } } + VLOG(4) << "Combined " << combined << " constants in " << computation->name() + << " computation"; + return combined > 0; +} - return changed; +// An instruction is considered to be equivalent to another only if they +// share the exact same set of operands. 
+int64 CseHash(const HloInstruction* instruction) { + int64 hash = std::hash()(static_cast(instruction->opcode())); + hash = tensorflow::Hash64Combine( + hash, instruction->opcode() == HloOpcode::kGetTupleElement + ? instruction->tuple_index() + : -1); + for (auto operand : instruction->operands()) { + hash = tensorflow::Hash64Combine(hash, operand->unique_id()); + } + return hash; } } // namespace @@ -95,21 +113,34 @@ StatusOr HloCSE::Run(HloModule* module) { const std::function eq_instructions = std::equal_to(); const std::function - eq_computations = std::equal_to(); + eq_computations = [](const HloComputation* lhs, + const HloComputation* rhs) { return *lhs == *rhs; }; + + auto cse_equal = [&](const HloInstruction* lhs, const HloInstruction* rhs) { + return lhs->Identical(*rhs, eq_instructions, eq_computations, + is_layout_sensitive_); + }; + for (auto* computation : module->computations()) { if (only_fusion_computations_ && !computation->IsFusionComputation()) { continue; } - changed |= CombineConstants(computation, is_layout_sensitive_); - - std::list post_order = - computation->MakeInstructionPostOrder(); - std::set removed_instructions; - for (auto instruction : post_order) { - // If the instruction has already been removed by CSE skip over it. - if (removed_instructions.count(instruction) > 0 || - instruction->operand_count() == 0) { + TF_ASSIGN_OR_RETURN(bool combined, + CombineConstants(computation, is_layout_sensitive_)); + changed |= combined; + + // HLO instructions are grouped into equivalency classes by using the + // cse_equal predicate defined above. This set holds a representative + // instruction for each class. + tensorflow::gtl::FlatSet + representatives(/*N=*/1024, &CseHash, cse_equal); + + for (auto instruction : computation->MakeInstructionPostOrder()) { + // If the instruction has zero operands (constants, parameters, etc.) skip + // over it. + if (instruction->operand_count() == 0) { continue; } @@ -118,31 +149,16 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } - // An instruction is considered to be equivalent to another only if they - // share the exact same set of operands. So to find equivalent - // instructions, we just search among instructions which share operand(0) - // of this instruction. - const HloInstruction* operand = instruction->operand(0); - - tensorflow::gtl::InlinedVector - equivalent_instructions; - for (HloInstruction* user : operand->users()) { - if (user != instruction && !user->HasSideEffect() && - user->Identical(*instruction, eq_instructions, eq_computations, - is_layout_sensitive_)) { - equivalent_instructions.push_back(user); - } - } - - // Replace all equivalent instructions with this instruction. 
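[Editor's note] The rewritten `HloCSE::Run` above replaces the per-operand scan over `operand(0)->users()` with a single hash set of "representative" instructions, keyed by `CseHash` and compared with the `Identical`-based `cse_equal` predicate: each instruction is either merged into an earlier equivalent or becomes the representative for its class. The same idea in a self-contained form, using `std::unordered_set` in place of `tensorflow::gtl::FlatSet` and a toy instruction type, is sketched below; it illustrates the data structure only and is not the XLA code itself.

```c++
#include <cstdint>
#include <functional>
#include <unordered_set>
#include <vector>

// Toy stand-in for an HLO instruction: an opcode plus operand ids.
struct Instr {
  int opcode;
  std::vector<int64_t> operand_ids;
};

// Hash combining the opcode with the operand ids, analogous to CseHash above.
struct InstrHash {
  size_t operator()(const Instr* i) const {
    size_t h = std::hash<int>()(i->opcode);
    for (int64_t id : i->operand_ids) h = h * 31 + std::hash<int64_t>()(id);
    return h;
  }
};

// Two instructions are CSE-equal if they have the same opcode and operands.
struct InstrEq {
  bool operator()(const Instr* a, const Instr* b) const {
    return a->opcode == b->opcode && a->operand_ids == b->operand_ids;
  }
};

// Returns, for each instruction in post order, its representative: either an
// earlier equivalent instruction or itself.
std::vector<const Instr*> FindRepresentatives(
    const std::vector<Instr>& post_order) {
  std::unordered_set<const Instr*, InstrHash, InstrEq> representatives;
  std::vector<const Instr*> result;
  for (const Instr& instr : post_order) {
    auto it = representatives.find(&instr);
    if (it != representatives.end()) {
      result.push_back(*it);  // Duplicate: reuse the representative.
    } else {
      representatives.insert(&instr);
      result.push_back(&instr);
    }
  }
  return result;
}
```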
- for (HloInstruction* equivalent_instruction : equivalent_instructions) { - TF_RETURN_IF_ERROR( - equivalent_instruction->ReplaceAllUsesWith(instruction)); + auto it = representatives.find(instruction); + if (it != representatives.end()) { + HloInstruction* equivalent_instruction = *it; TF_RETURN_IF_ERROR( - computation->RemoveInstruction(equivalent_instruction)); - removed_instructions.insert(equivalent_instruction); + instruction->ReplaceAllUsesWith(equivalent_instruction)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); changed = true; + continue; } + representatives.insert(instruction); } } return changed; diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index df8853f34f6a72..16db374566c727 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" @@ -72,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR0(84.0); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -104,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -134,38 +135,53 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // Test that constants with the same value but different type are *not* // commoned. 
auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + std::vector constants; + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + + const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); + for (int64 i = 0; i < constants.size(); ++i) { + constants[i] = builder.AddInstruction( + HloInstruction::CreateConvert(shape_r0, constants[i])); + } + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, constants[0], constants[1])); + for (int64 i = 2; i < constants.size(); ++i) { + root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, root, constants[i])); + } auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(7, computation->instruction_count()); + EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - EXPECT_EQ(6, computation->instruction_count()); + // CSE will remove both the second float(42.0f) and the corresponding + // convert/cast. 
+ EXPECT_EQ(18, computation->instruction_count()); } TEST_F(HloCseTest, NonscalarConstants) { @@ -469,5 +485,56 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); } +TEST_F(HloCseTest, CompareComputations) { + auto module = ParseHloString(R"( + HloModule m + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + add_computation2 { + add_lhs2 = f32[] parameter(0) + add_rhs2 = f32[] parameter(1) + ROOT add_root2 = f32[] add(add_lhs2, add_rhs2) + } + + ENTRY entry { + p = f32[10]{0} parameter(0) + c = f32[] constant(0) + r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation + r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 + ROOT f2 = (f32[],f32[]) tuple(r1, r2) + })") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0), root->operand(1)); +} + +TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { + // Test that constants with the same value but in different domains (disjoint + // in this case) are not collapsed. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(2, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(2, computation->instruction_count()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 0c37a8d75f38da..cc130a4900dc16 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -363,7 +363,7 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { bool HloDataflowAnalysis::UpdateConditionalValueSet( HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet( conditional->true_computation()->root_instruction()), &GetInstructionValueSet( @@ -538,7 +538,7 @@ bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) { bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet(xla_while->while_body()->root_instruction()), &GetInstructionValueSet(xla_while->operand(0))}; if (ssa_form_) { @@ -878,4 +878,128 @@ Status HloDataflowAnalysis::Verify() const { return Status::OK(); } +bool HloDataflowAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. 
+ HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + // Iterate through all users of all uses of the fusion parameter value. + // Return false if any uses are detected, returns true otherwise. + const HloValue& value = GetValueDefinedAt(fusion_param, index); + return value.uses().empty(); + } else { + // Return false if no value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + return false; + } + } + } + } + + return true; +} + +bool HloDataflowAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + + if (user->opcode() == HloOpcode::kFusion) { + // Get the parameter associated with 'operand'; + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + + const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); + if (value.uses().size() != 1) { + return false; + } + const HloUse& use = value.uses()[0]; + + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. 
+ std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // Get all uses of value defined by 'operand' at 'operand_index'. + const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); + // Return true iff: + // *) There exists two uses of 'operand'. + // *) One use is by 'user' (caller). + // *) One use is by root instruction of called computation (callee root). + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + const bool found_caller_use = + std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + return use.instruction == user; + }) != uses.end(); + auto* callee_root = user->to_apply()->root_instruction(); + const bool found_elementwise_callee_use = + std::find_if( + uses.begin(), uses.end(), [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); + return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; + } + // Check if 'user' is element-wise. + return user->IsElementwise(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 7b8a74b096ff48..9868746b611388 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -118,6 +118,23 @@ class HloDataflowAnalysis { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. 
+ bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + protected: HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 07f69b8e1339fe..5798326dcbf65c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1873,5 +1873,346 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); +class HloDataflowAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr dataflow_analysis_; +}; + +class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto 
gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. 
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(add_operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. 
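// Illustrative sketch, not part of this change: the DynamicUpdateSlice cases
// above rely on the fact that an in-place update only touches the window being
// written, so the output may reuse the buffer of the `data` operand, while
// `update` and `starts` are read-only inputs of different shapes and cannot be
// aliased with the output. In plain scalar terms the in-place form is:
//
//   std::vector<float> data(8, 0.f);               // the "data" operand
//   const std::vector<float> update = {2.f, 2.f, 2.f};
//   const int64_t start = 2;                       // the "starts" operand
//   for (int64_t i = 0; i < static_cast<int64_t>(update.size()); ++i) {
//     data[start + i] = update[i];                 // only the window changes
//   }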
+ auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE( + dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc new file mode 100644 index 00000000000000..78955db0da02f1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainIsolator::RunContext { + public: + RunContext(HloModule* module, HloDomainIsolator* isolator) + : module_(module), isolator_(isolator) {} + + StatusOr Run(); + + private: + // Inserts a kDomain instruction between parent and operand, in case + // the attribute (ie, sharding) values change between instruction and operand. + // Returns the newly inserted kDomain instruction, or nullptr if no kDomain + // instruction was necessary. 
+ StatusOr CreateDomain(HloInstruction* instruction, + HloInstruction* parent, + HloInstruction* operand); + + HloModule* module_; + HloDomainIsolator* isolator_; +}; + +StatusOr HloDomainIsolator::RunContext::CreateDomain( + HloInstruction* instruction, HloInstruction* parent, + HloInstruction* operand) { + HloInstruction* domain = nullptr; + std::unique_ptr domain_instruction = + isolator_->creator_(instruction, operand); + if (domain_instruction != nullptr) { + domain = operand->parent()->AddInstruction(std::move(domain_instruction)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); + } + return domain; +} + +StatusOr HloDomainIsolator::RunContext::Run() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); + + int64 added_domains = 0; + for (HloComputation* computation : module_->computations()) { + // Walk in post order and place all the required kDomain instructions. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + continue; + } + for (HloInstruction* operand : instruction->unique_operands()) { + // When applying multiple domains, we could end up stacking more than + // one in one edge, so here we want to build the effective + // (kDomain-less) instruction->operand edge. + HloInstruction* parent = instruction; + while (operand->opcode() == HloOpcode::kDomain) { + parent = operand; + operand = operand->mutable_operand(0); + } + // Check whether a kDomain is necessary between instruction and operand. + TF_ASSIGN_OR_RETURN(HloInstruction * domain, + CreateDomain(instruction, parent, operand)); + if (domain != nullptr) { + VLOG(4) << "New domain: " << domain->ToString(); + ++added_domains; + } + } + } + } + VLOG(3) << "Added " << added_domains << " kDomain instructions"; + if (added_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator"); + } + return added_domains > 0; +} + +HloDomainIsolator::HloDomainIsolator(DomainCreator creator) + : creator_(std::move(creator)) {} + +StatusOr HloDomainIsolator::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h new file mode 100644 index 00000000000000..e0c5718509dabe --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_
+
+#include
+#include
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// Domain isolation is the task of placing kDomain instructions between HLO
+// instructions having different sharding. A kDomain instruction is essentially
+// used to break an HLO graph edge connecting two instructions with different
+// sharding. If a set of connected instructions all have the same sharding, no
+// kDomain instruction will be placed.
+class HloDomainIsolator : public HloPassInterface {
+ public:
+  // Creates a new kDomain instruction for the edge between the use instruction
+  // (the first HloInstruction argument), and the operand instruction (the
+  // second HloInstruction argument).
+  // Returns nullptr in case no domain separation is necessary.
+  using DomainCreator = std::function<std::unique_ptr<HloInstruction>(
+      HloInstruction*, HloInstruction*)>;
+
+  explicit HloDomainIsolator(DomainCreator creator);
+
+  tensorflow::StringPiece name() const override { return "domain_isolator"; }
+
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  class RunContext;
+
+  DomainCreator creator_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
new file mode 100644
index 00000000000000..ebd5adb5d573ce
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -0,0 +1,176 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +/* static */ StatusOr> HloDomainMap::Create( + HloComputation* computation, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + return std::move(domain_map); +} + +/* static */ StatusOr> HloDomainMap::Create( + HloModule* module, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + for (HloComputation* computation : module->computations()) { + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + } + return std::move(domain_map); +} + +bool HloDomainMap::InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const { + int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); + int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + return domain_id1 >= 0 && domain_id1 == domain_id2; +} + +Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { + TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); + // We only check operands, so we are sure to not process the empty domain from + // both sides. + for (HloInstruction* operand : instruction->unique_operands()) { + if (IsDomainInstruction(operand)) { + auto domain = MakeUnique(); + domain->enter_domains.insert(operand); + domain->exit_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + } + return Status::OK(); +} + +Status HloDomainMap::Populate(HloComputation* computation) { + for (HloInstruction* instruction : computation->instructions()) { + if (IsDomainInstruction(instruction)) { + // If this is a kDomain of the kind we are currently processing, check + // whether this is an "empty domain". + TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction)); + continue; + } + int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); + if (domain_id >= 0) { + // We have already processed this instruction. + continue; + } + TF_ASSIGN_OR_RETURN(std::unique_ptr domain, + CreateDomain(instruction)); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + return Status::OK(); +} + +Status HloDomainMap::InsertDomain( + std::unique_ptr domain) { + int64 domain_id = instruction_domains_.size(); + instruction_domains_.push_back(std::move(domain)); + for (HloInstruction* instruction : instruction_domains_.back()->reach_set) { + instruction_to_domain_[instruction] = domain_id; + } + return Status::OK(); +} + +Status HloDomainMap::ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const { + std::vector in_queue; + in_queue.push_back(instruction); + while (!in_queue.empty()) { + HloInstruction* current_instruction = in_queue.back(); + in_queue.pop_back(); + if (domain->reach_set.insert(current_instruction).second) { + // We should not be finding instructions with assigned domain here. + // If we assigned a domain to the instruction, it means that all the + // instructions reached by it, should have a domain as well. 
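// A stripped-down, standalone sketch (illustrative only, not part of the XLA
// sources): ExpandDomain is a worklist flood fill over operand and user edges
// that stops at kDomain instructions. kDomain nodes met on the operand side
// become enter_domains (dataflow enters the region through them), and kDomain
// nodes met on the user side become exit_domains. The same traversal over a
// toy graph type looks like this:
//
//   struct Node {
//     std::vector<Node*> operands, users;
//     bool is_domain = false;
//   };
//   void Expand(Node* seed, std::set<Node*>* reach, std::set<Node*>* enter,
//               std::set<Node*>* exit) {
//     std::vector<Node*> worklist = {seed};
//     while (!worklist.empty()) {
//       Node* n = worklist.back();
//       worklist.pop_back();
//       if (!reach->insert(n).second) continue;  // already visited
//       for (Node* op : n->operands) {
//         if (op->is_domain) enter->insert(op); else worklist.push_back(op);
//       }
//       for (Node* u : n->users) {
//         if (u->is_domain) exit->insert(u); else worklist.push_back(u);
//       }
//     }
//   }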
+ int64 domain_id = + FindOrDefault(instruction_to_domain_, current_instruction, -1); + TF_RET_CHECK(domain_id < 0) + << "Instruction " << current_instruction->ToString() + << " already has domain " << domain_id; + for (HloInstruction* operand : current_instruction->operands()) { + if (IsDomainInstruction(operand)) { + // The reach set instruction is a user of the domain instruction + // (the instruction sees the kDomain as operand). + // IOW the dataflow enters the domain through the kDomain instruction. + domain->enter_domains.insert(operand); + } else { + in_queue.push_back(operand); + } + } + for (HloInstruction* user : current_instruction->users()) { + if (IsDomainInstruction(user)) { + // The reach set instruction is an operand of the domain instruction + // (the instruction sees the kDomain as user). + // IOW the dataflow exits the domain through the kDomain instruction. + domain->exit_domains.insert(user); + } else { + in_queue.push_back(user); + } + } + } + } + return Status::OK(); +} + +StatusOr> HloDomainMap::CreateDomain( + HloInstruction* instruction) const { + auto domain = MakeUnique(); + TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); + domain->instructions = MakeNonDomainInstructions(domain->reach_set); + return std::move(domain); +} + +bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { + if (instruction->opcode() != HloOpcode::kDomain) { + return false; + } + if (!domain_kind_.empty()) { + if (instruction->user_side_metadata().Kind() != domain_kind_) { + return false; + } + // Both user and operand side of the metadata must be of the same kind. + CHECK(instruction->operand_side_metadata().Kind() == domain_kind_) + << "Instruction " << instruction->ToString() + << " has mismatching metadata kinds"; + } + return true; +} + +/* static */ std::vector +HloDomainMap::MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set) { + std::vector instructions; + instructions.reserve(instruction_set.size()); + for (HloInstruction* instruction : instruction_set) { + if (instruction->opcode() != HloOpcode::kDomain) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [](HloInstruction* a, HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + return instructions; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h new file mode 100644 index 00000000000000..e62ef763fb3881 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_
+
+#include
+#include
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace xla {
+
+// The HloDomainMap splits the instructions within a module or computation into
+// different domains, separated by kDomain instructions.
+// A domain is composed of a set of instructions which can reach each other via
+// operand/user edges, without crossing a kDomain instruction of a given kind.
+// A domain never crosses computation boundaries.
+class HloDomainMap {
+ public:
+  // Creates a new HloDomainMap, creating all the domains within the input
+  // computation, of the given kind. If domain_kind is not empty, only the
+  // kDomain instructions of domain_kind will be considered as separators.
+  // Otherwise every kDomain instruction will be splitting domains.
+  static StatusOr<std::unique_ptr<HloDomainMap>> Create(
+      HloComputation* computation, string domain_kind);
+
+  // Creates a new HloDomainMap, creating all the domains within the input
+  // module, of the given kind. If domain_kind is not empty, only the
+  // kDomain instructions of domain_kind will be considered as separators.
+  // Otherwise every kDomain instruction will be splitting domains.
+  static StatusOr<std::unique_ptr<HloDomainMap>> Create(HloModule* module,
+                                                        string domain_kind);
+
+  // Retrieves all the domains the input module or computation is composed of.
+  const std::vector<std::unique_ptr<DomainMetadata::Domain>>& GetDomains()
+      const {
+    return instruction_domains_;
+  }
+
+  // Checks whether two instructions are within the same domain.
+  bool InSameDomain(HloInstruction* instruction1,
+                    HloInstruction* instruction2) const;
+
+  // Checks whether instruction is a kDomain instruction of the kind we are
+  // currently processing.
+  bool IsDomainInstruction(HloInstruction* instruction) const;
+
+ private:
+  HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
+
+  // Check if the kDomain instruction is facing (via its operand link) another
+  // kDomain instruction of the same kind, hence defining an empty domain.
+  // If that is the case, create the empty domain and call the proper
+  // normalizer.
+  Status TryProcessEmptyDomain(HloInstruction* instruction);
+
+  Status Populate(HloComputation* computation);
+
+  // Inserts the provided domain into the ones tracked by this object,
+  // creating a new domain ID.
+  Status InsertDomain(std::unique_ptr<DomainMetadata::Domain> domain);
+
+  // From the given instruction, expands operand- and user-wise the set of
+  // instructions which can be reached without crossing a kDomain instruction
+  // of the kind specified by domain_kind_.
+  // The domain data structure will be populated with all the reached
+  // instructions, and the boundaries of the domain, with the kDomain
+  // instructions encountered while expanding the reach.
+  Status ExpandDomain(HloInstruction* instruction,
+                      DomainMetadata::Domain* domain) const;
+
+  // Creates a domain data structure using the ExpandDomain() API.
+ StatusOr> CreateDomain( + HloInstruction* instruction) const; + + // Out of an instruction set, returns a vector of all the ones which are not + // a kDomain kind. + static std::vector MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set); + + string domain_kind_; + std::vector> instruction_domains_; + tensorflow::gtl::FlatMap instruction_to_domain_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h new file mode 100644 index 00000000000000..aa0308100a21f1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Cannot include hlo_instruction.h as this file is included from there. +class HloInstruction; + +// The DomainMetadata represents the base class for metadata which can be +// attached to kDomain HLO instructions. +class DomainMetadata { + public: + // A Domain data structure captures all the information about a kDomain + // bounded instruction set. + struct Domain { + // The set of instructions which are reachable from each other via + // operand/user pathways, without crossing a kDomain instruction of a given + // kind. The reach_set can contain kDomain instructions of other kinds, if + // two domains of different kind intersect each other. + tensorflow::gtl::FlatSet reach_set; + + // The same instructions in reach_set, but purged from kDomain instructions. + std::vector instructions; + + // If we consider a graph edge as an arrow oriented from the operand to the + // user, the enter_domains will contain the set of kDomain instructions + // whose dataflow enters the reach set (domain), while the exit_domains + // contains the set of kDomain instructions whose dataflow exit the reach + // set. + tensorflow::gtl::FlatSet enter_domains; + tensorflow::gtl::FlatSet exit_domains; + }; + + virtual ~DomainMetadata() = default; + + // Clones the metadata object. + virtual std::unique_ptr Clone() const = 0; + + // Returns the metadata type. A unique identifier which describes the real + // metadata type. + virtual tensorflow::StringPiece Kind() const = 0; + + // Compares the metadata object with another one and returns true if the + // two matches. + virtual bool Matches(const DomainMetadata& other) const = 0; + + // Returns a string representation of the metadata. 
+ virtual string ToString() const = 0; + + // Given a reachable set (the set of instructions which are reachable from + // each other via user/operand pathways, without crossing a kDomain + // instruciton), makes sure that all of them have metadata attributes which + // are coherent with this metadata object. + virtual Status NormalizeInstructions(const Domain& domain) const = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc new file mode 100644 index 00000000000000..1d06040b0e7c92 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainRemover::RunContext { + public: + RunContext(HloModule* module, HloDomainRemover* remover) + : module_(module), remover_(remover) {} + + StatusOr Run(); + + private: + // Verifies the consistency of the domain, and normalizes the instructions + // within it. + Status VerifyAndNormalizeDomain(const DomainMetadata::Domain& domain); + + HloModule* module_; + HloDomainRemover* remover_; +}; + +Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( + const DomainMetadata::Domain& domain) { + // Verify that the whole kDomain frontier bounding the instruction reach set, + // has matching metadata. + // A kDomain instruction has two sides of metadata, a user facing and an + // operand facing. + // A reachable instruction set can make contact with a kDomain instruction on + // a user facing side (the kDomain is operand of the instruction), or on a + // operand facing side (the kDomain is user of the instruction). + // And depending on the contact side, the proper metadata object + // (user_side_metadata() vs. operand_side_metadata()) needs to be used for + // consistency checks. 
+ const DomainMetadata* ref_metadata = nullptr; + VLOG(4) << "Reach set:"; + for (HloInstruction* instruction : domain.instructions) { + VLOG(4) << " " << instruction->name(); + } + VLOG(4) << " Domains:"; + for (HloInstruction* instruction : domain.enter_domains) { + const DomainMetadata& meta = instruction->user_side_metadata(); + VLOG(4) << " User side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + for (HloInstruction* instruction : domain.exit_domains) { + const DomainMetadata& meta = instruction->operand_side_metadata(); + VLOG(4) << " Operand side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + if (ref_metadata != nullptr) { + VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); + TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + } else { + // No kDomain instruction was present within this domain, so call the + // generic normalization functions and have them apply their heuristic. + VLOG(2) << "Applying domain-less normalization"; + TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + } + return Status::OK(); +} + +StatusOr HloDomainRemover::RunContext::Run() { + VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover"); + + int64 removed_domains = 0; + for (HloComputation* computation : module_->computations()) { + // First create the domain instruciton sets. A domain instruction set is + // the set of instructions whose edges never cross a kDomain instruction. + TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, remover_->kind_)); + // Verify and normalize every domain populated within the map. + for (auto& domain : domain_map->GetDomains()) { + TF_RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain)); + } + + // Now remove all the kDomain instructions of the kind specified by the + // remover, that are within the currently processed computation from the + // graph. 
+ for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + for (HloInstruction* operand : instruction->unique_operands()) { + if (domain_map->IsDomainInstruction(operand)) { + VLOG(5) << "Removing " << operand->name(); + TF_RETURN_IF_ERROR( + operand->ReplaceAllUsesWith(operand->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); + ++removed_domains; + } + } + } + HloInstruction* root = computation->root_instruction(); + if (root != nullptr && domain_map->IsDomainInstruction(root)) { + VLOG(5) << "Removing " << root->name(); + computation->set_root_instruction(root->mutable_operand(0)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(root)); + ++removed_domains; + } + } + VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '" + << remover_->kind_ << "' kind"; + if (removed_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover"); + } + return removed_domains > 0; +} + +StatusOr HloDomainRemover::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h new file mode 100644 index 00000000000000..0c71dd34fd4d29 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Removes all the kDomain instructions of a given kind from the input module, +// and calls the normalizer to propagate the properties on the possibly new born +// instructions. +class HloDomainRemover : public HloPassInterface { + public: + // Creates a new HloDomainRemover object tasked at removing all the kDomain + // instructions of a given kind. + // In case a reachable set (the set of instructions within a computation, + // which are mutually reachable via operand/user pathways) has all the + // instructions in it with the same attributes (ie, sharding), a normalizer + // function is tasked at applying attribute normalization on the instructions + // within such domain. 
+ HloDomainRemover( + tensorflow::StringPiece kind, + std::function normalizer) + : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + + tensorflow::StringPiece name() const override { return "domain_remover"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + string kind_; + std::function normalizer_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc new file mode 100644 index 00000000000000..5553ddb153f7f1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -0,0 +1,432 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloDomainTest : public HloTestBase { + protected: + bool FindUserViaDomainPath(HloInstruction* instruction, + HloInstruction* operand) const { + for (HloInstruction* user : operand->users()) { + if (user == instruction) { + return true; + } + if (user->opcode() == HloOpcode::kDomain && + FindUserViaDomainPath(instruction, user)) { + return true; + } + } + return false; + } + + // Checks whether there is a kDomain instruction in the edge between the + // instruction and the operand. + bool HasDomainEdge(HloModule* module, + tensorflow::StringPiece instruction_name, + tensorflow::StringPiece operand_name) { + HloInstruction* instruction = FindInstruction(module, instruction_name); + HloInstruction* operand = FindInstruction(module, operand_name); + CHECK_NE(instruction, nullptr); + CHECK_NE(operand, nullptr); + if (!instruction->IsUserOf(operand)) { + // If instruction is not an immediate user, we must find a path from + // operand to instruction anyway, otherwise there is a corruption. 
+ if (FindUserViaDomainPath(instruction, operand)) { + return true; + } + LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name + << "' and '" << operand_name << "' instructions:\n" + << module->ToString(); + } + return false; + } + + StatusOr> ParseModule( + tensorflow::StringPiece hlo_string) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return ParseHloString(hlo_string, config); + } +}; + +// Dummy DomainMetadata implementation which create kDomain boundaries around +// HLO instructions with the same metadata().op_name() values. +class OpNameMetadata : public DomainMetadata { + public: + explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} + + std::unique_ptr Clone() const override { + return MakeUnique(opname_); + } + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override { + const OpNameMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a OpNameMetadata, then it is clearly a no match. + return false; + } + return opname_ == other_ptr->opname_; + } + + string ToString() const override { return opname_; } + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override { + // For the purposes of this test, nothing to do. + return Status::OK(); + } + + static tensorflow::StringPiece KindName() { return "opname"; } + + private: + string opname_; +}; + +// Creator function for OpNameMetadata domains. +std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* operand) { + if (instruction->metadata().op_name() == operand->metadata().op_name()) { + return nullptr; + } + std::unique_ptr operand_side_metadata = + MakeUnique(operand->metadata().op_name()); + std::unique_ptr user_side_metadata = + MakeUnique(instruction->metadata().op_name()); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) { + // Nothing to do for the particular use this test make of the OpName domains. 
+ return Status::OK(); +} + +TEST_F(HloDomainTest, CheckDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1} + d = f32[4] subtract(a, b), sharding={maximal device=1} + e = f32[4] multiply(c, d), sharding={maximal device=1} + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b) + d = f32[4] subtract(a, b) + e = f32[4] multiply(c, d) + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(!isolator_changed); +} + +TEST_F(HloDomainTest, CheckDomainAroundIO) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0} + c = () send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0} + e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} + f = f32[4] add(a, e) + g = f32[4] subtract(a, e) + ROOT h = (f32[4], f32[4]) tuple(f, g) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + 
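// Editorial note on the intended pipeline (illustrative, drawn from the test
// code in this file): the isolator marks every sharding change on a graph edge
// with a kDomain pair, other passes are then free to rewrite the region in
// between, and the remover verifies that the metadata on the domain frontier
// still matches, pushes the sharding onto any newly created instructions via
// NormalizeShardingDomain, and finally deletes the kDomain markers:
//
//   HloDomainIsolator isolator(CreateShardingDomain);
//   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
//   // ... run sharding-oblivious optimization passes here ...
//   HloDomainRemover remover(ShardingMetadata::KindName(),
//                            NormalizeShardingDomain);
//   TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));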
HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} + c = f32[4] add(b, b), sharding={maximal device=-1} + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_FALSE(isolator_changed); +} + +TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} + c = f32[4] add(b, b) + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_FALSE(remover_changed); + + HloInstruction* add = FindInstruction(module.get(), "c"); + ASSERT_NE(add, nullptr); + auto device = add->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, 0); +} + +TEST_F(HloDomainTest, CheckMultiDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(a, b), sharding={maximal device=1} + d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"} + e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"} + f = f32[4] add(e, c), sharding={maximal device=1} + ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator sharding_isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, + sharding_isolator.Run(module.get())); + EXPECT_TRUE(sharding_isolator_changed); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module.get())); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + 
NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module.get())); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module.get())); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); +} + +TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + infeed = (f32[4], f32[4]) infeed(), + sharding={{maximal device=1}, {maximal device=0}} + gte0 = f32[4] get-tuple-element(infeed), index=0 + gte1 = f32[4] get-tuple-element(infeed), index=1 + copy0 = f32[4] copy(gte0) + copy1 = f32[4] copy(gte1) + ROOT add = f32[4] add(copy0, copy1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed")); + EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); + + // Inject unassigned tuple/gte within the infeed domain, to simulate the + // HLO passes adding unexpected instructions. + // + // infeed + // / \ + // GTE0 GTE1 + // / \ + // COPY0 COPY1 + // \ / + // \ / + // TUPLE + // | + // DOMAIN + HloInstruction* infeed = FindInstruction(module.get(), "infeed"); + ASSERT_NE(infeed, nullptr); + auto infeed_users = infeed->users(); + HloInstruction* new_gte0 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* new_copy0 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte0->shape(), HloOpcode::kCopy, new_gte0)); + HloInstruction* new_gte1 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1)); + HloInstruction* new_copy1 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte1->shape(), HloOpcode::kCopy, new_gte1)); + HloInstruction* new_tuple = infeed->parent()->AddInstruction( + HloInstruction::CreateTuple({new_copy0, new_copy1})); + for (HloInstruction* user : infeed_users) { + TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + struct Assignment { + HloInstruction* instruction; + int64 device; + } assignments[] = { + {new_gte0, 1}, + {new_copy0, 1}, + {new_gte1, 0}, + {new_copy1, 0}, + }; + for (auto& assignment : assignments) { + auto device = assignment.instruction->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, assignment.device); + } + EXPECT_TRUE(new_tuple->has_sharding()); + EXPECT_EQ( + new_tuple->sharding(), + HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1), + 
HloSharding::AssignDevice(0)})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index d236f83aeb9254..4ed1508d706768 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -119,6 +119,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { return false; } + HloCloneContext context(module); bool changed = false; for (auto* computation : module->computations()) { for (auto* hlo : computation->MakeInstructionPostOrder()) { @@ -140,6 +141,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kSelectAndScatter || @@ -180,7 +182,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); new_hlo = computation->AddInstruction( - hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + hlo->CloneWithNewOperands(shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); @@ -189,16 +191,16 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - new_shape, new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(new_shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); // Convert the elements of the result of `new_hlo` to produce a new // tuple with shape `old_shape`. new_hlo = ConvertTupleElements(new_hlo, old_shape); } else { - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - hlo->shape(), new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index c5e30148345fec..1e78d775c8e172 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" @@ -42,7 +43,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -52,25 +52,11 @@ namespace xla { namespace { using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::FlatSet; -using tensorflow::gtl::optional; - -template -struct is_complex_t : public std::false_type {}; - -template <> -struct is_complex_t : public std::true_type {}; - -template -struct is_complex64_t : public std::false_type {}; - -template <> -struct is_complex64_t : public std::true_type {}; template StatusOr> Compare(const Shape& shape, HloOpcode opcode, - const Literal& lhs_literal, - const Literal& rhs_literal) { + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -108,7 +94,7 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -119,8 +105,8 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, template <> StatusOr> Compare( - const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, - const Literal& rhs_literal) { + const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -138,7 +124,7 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -147,2092 +133,48 @@ StatusOr> Compare( return std::move(result); } -template -StatusOr> ElementWiseUnaryOpImpl( - HloInstruction* instruction, - const std::function& unary_op, - const Literal& operand_literal) { - const auto shape = instruction->shape(); - const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); - } - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return unary_op(operand_literal.Get(multi_index)); - })); - return std::move(result); -} - -// For one particular placement of a window in a base shape (the placement is -// represented as `window_count_index`), iterates inside the window. Translates -// the window index into base index. If the base index is within bound, call `f` -// with the base index. 
-void IterateThroughWindow( - const Shape& window_shape, const Window& window, const Shape& base_shape, - const ArraySlice& window_count_index, - const std::function&)>& f) { - const int64 rank = ShapeUtil::Rank(base_shape); - DimensionVector window_index(rank); - std::fill(window_index.begin(), window_index.end(), 0); - do { - std::vector base_index(rank); - bool out_of_bound = false; - for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); - if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { - out_of_bound = true; - break; - } - } - if (!out_of_bound) { - f(base_index); - } - } while (IndexUtil::BumpIndices(window_shape, &window_index)); -} - -// Creates a vector of multipliers which can be used to create a linear index -// into shape. -// -// Given the multidimensional index {i1, ..., iN} and -// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply -// -// LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. -// -// This lets you calculate LI given the multidimensional indices in any order. -DimensionVector MakeDimMultipliers(const Shape& shape) { - DimensionVector v(ShapeUtil::Rank(shape)); - int64 scale = 1; - for (auto dim : LayoutUtil::MinorToMajor(shape)) { - v[dim] = scale; - scale *= shape.dimensions(dim); - } - return v; -} - } // namespace -template -class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { - public: - explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} - - // The following higher-order functions convert a function with ElementwiseT - // to a function with ReturnT. - std::function ConvertUnaryFunction( - const std::function& unary_op) { - return [&unary_op](ReturnT arg) { - return static_cast(unary_op(static_cast(arg))); - }; - } - std::function ConvertBinaryFunction( - const std::function& - binary_op) { - return [&binary_op](ReturnT arg1, ReturnT arg2) { - return static_cast(binary_op(static_cast(arg1), - static_cast(arg2))); - }; - } - std::function ConvertTernaryFunction( - const std::function& ternary_op) { - return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { - return static_cast(ternary_op(static_cast(arg1), - static_cast(arg2), - static_cast(arg3))); - }; - } - - Status DefaultAction(HloInstruction* hlo_instruction) override { - return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); - } - - // TODO(b/35950897): many of the stl functions used in the handlers are not - // overloaded for every XLA primitive types. 
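The `MakeDimMultipliers` helper deleted above (it moves into the typed-visitor header) builds per-dimension multipliers so a multidimensional index can be folded into a linear index in any order. A small sketch of the same idea, assuming a plain row-major layout where the real code walks the shape's minor-to-major order:

```c++
#include <cstdint>
#include <vector>

// Multiplier (stride) per dimension for a row-major layout: the last
// dimension varies fastest, so it gets multiplier 1.
std::vector<int64_t> MakeDimMultipliers(const std::vector<int64_t>& dims) {
  std::vector<int64_t> multipliers(dims.size());
  int64_t scale = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    multipliers[i] = scale;
    scale *= dims[i];
  }
  return multipliers;
}

// LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]; the terms can be summed in
// any order, which is exactly why the convolution loop below uses this form.
int64_t LinearIndex(const std::vector<int64_t>& index,
                    const std::vector<int64_t>& multipliers) {
  int64_t linear = 0;
  for (size_t d = 0; d < index.size(); ++d) linear += index[d] * multipliers[d];
  return linear;
}
```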
- - template ::value>::type* = - nullptr> - Status HandleAbs(HloInstruction* abs) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { - return elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { - return std::abs(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs) { - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(abs->operand(0)); - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[abs], - (ElementWiseUnaryOpImpl( - abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, - operand_literal))); - - return Status::OK(); - } - - Status HandleAbs(HloInstruction* abs) override { - // If the operand is of C64 type, the return type of abs will be F32. - // However, ElementwiseT would still be the return type, F32, and thus - // specifying the ElementwiseT explicitly as C64 is needed below. - if (abs->operand(0)->shape().element_type() == C64) { - return HandleAbs(abs); - } - return HandleAbs(abs); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[round], - ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { - return std::round(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); - } - - Status HandleRound(HloInstruction* round) override { - return HandleRound(round); - } - - Status HandleBroadcast(HloInstruction* broadcast) override { - parent_->evaluated_[broadcast] = - Literal::CreateFromShape(broadcast->shape()); - auto output = parent_->evaluated_[broadcast].get(); - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. 
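`HandleAbs` above is one of several handlers that use `std::enable_if` to pick an implementation per native type: identity for unsigned integers, `std::abs` for signed and floating-point types, and a complex specialization whose result type differs from its input. A freestanding sketch of that dispatch pattern, with illustrative names:

```c++
#include <cmath>
#include <complex>
#include <cstdlib>
#include <type_traits>

// Unsigned integers: |x| == x, so no work is needed.
template <typename T,
          typename std::enable_if<std::is_unsigned<T>::value>::type* = nullptr>
T AbsValue(T x) { return x; }

// Signed integers and floating-point types: defer to std::abs.
template <typename T,
          typename std::enable_if<std::is_signed<T>::value ||
                                  std::is_floating_point<T>::value>::type* = nullptr>
T AbsValue(T x) { return std::abs(x); }

// Complex: the magnitude has a different (real) element type than the input,
// which is why the real handler special-cases C64.
template <typename T>
T AbsValue(std::complex<T> x) { return std::abs(x); }
```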
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - return output->Populate([&](ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - }); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], - ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { - return std::ceil(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); - } - - Status HandleCeil(HloInstruction* ceil) override { - return HandleCeil(ceil); - } - - Status HandleConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - parent_->GetEvaluatedLiteralFor(operand).Convert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleBitcastConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleExp(HloInstruction* exp) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], - ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { - return std::exp(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[floor], - ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { - return std::floor(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); - } - - Status HandleFloor(HloInstruction* floor) override { - return HandleFloor(floor); - } - - Status HandleLog(HloInstruction* log) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], - ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { - return std::log(elem_operand); - })); - return Status::OK(); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return ~elem_operand; - })); - return Status::OK(); - } - - template 
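`HandleBroadcast` above never materializes copies up front; for every output index it derives the corresponding operand index through the broadcast dimension numbers. The core mapping, written as a standalone helper:

```c++
#include <cstdint>
#include <vector>

// For each operand dimension i, broadcast_dimensions[i] names the output
// dimension it maps to, so the operand index is a gather of output
// coordinates along those dimensions.
std::vector<int64_t> BroadcastOperandIndex(
    const std::vector<int64_t>& output_index,
    const std::vector<int64_t>& broadcast_dimensions) {
  std::vector<int64_t> operand_index(broadcast_dimensions.size());
  for (size_t i = 0; i < broadcast_dimensions.size(); ++i) {
    operand_index[i] = output_index[broadcast_dimensions[i]];
  }
  return operand_index;
}
```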
::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template ::value>::type* = - nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); - } - - Status HandleNot(HloInstruction* not_) override { - return HandleNot(not_); - } - - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { - return NativeT(-type(elem_operand)); - })); - return Status::OK(); - } - - template ::value || - std::is_floating_point::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp( - negate, [](ElementwiseT elem_operand) { return -elem_operand; })); - return Status::OK(); - } - - Status HandleNegate(HloInstruction* negate) override { - return HandleNegate(negate); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - return (ElementwiseT(0) < elem_operand) - - (elem_operand < ElementwiseT(0)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - auto abs_val = std::abs(elem_operand); - return 0 == abs_val ? 
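`HandleNegate` above routes signed integers through `std::make_unsigned` before negating, so that negating the minimum value wraps instead of tripping signed-overflow undefined behavior. The trick in isolation, as a sketch:

```c++
#include <cstdint>
#include <type_traits>

// Negate a signed integer by casting to the unsigned counterpart first.
// Negating INT32_MIN as a signed operation overflows (undefined behavior);
// the unsigned negation is well defined, and on two's-complement targets the
// cast back yields the expected wrapped value.
template <typename T>
T NegateWithoutOverflowUB(T value) {
  using Unsigned = typename std::make_unsigned<T>::type;
  return static_cast<T>(-static_cast<Unsigned>(value));
}
```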
ElementwiseT(0) - : elem_operand / abs_val; - })); - return Status::OK(); - } - - Status HandleSign(HloInstruction* sign) override { - return HandleSign(sign); - } - - template ::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], - ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return std::atan2(lhs_elem, rhs_elem); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); - } - - Status HandleAtan2(HloInstruction* atan2) override { - return HandleAtan2(atan2); - } - - Status HandleTanh(HloInstruction* tanh) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], - ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { - return std::tanh(elem_operand); - })); - return Status::OK(); - } - - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - std::is_floating_point::value || - is_complex_t::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem * rhs_elem; - })); - return Status::OK(); - } - - Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply(multiply); - } - - Status HandleSubtract(HloInstruction* subtract) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem - rhs_elem; - })); - return Status::OK(); - } - - Status HandleAdd(HloInstruction* add) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], - ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem + rhs_elem; - })); - return Status::OK(); - } - - Status HandleDivide(HloInstruction* divide) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], - ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem / rhs_elem; - })); - return Status::OK(); - } - - template ::value>::type* = - nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return std::max(lhs, rhs); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return ((lhs >= rhs) || std::isnan(lhs)) ? 
lhs : rhs; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); - } - - Status HandleMaximum(HloInstruction* maximum) override { - return HandleMaximum(maximum); - } - - template ::value>::type* = - nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::min(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); - } - - Status HandleMinimum(HloInstruction* minimum) override { - return HandleMinimum(minimum); - } - - Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], - ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::fmod(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); - } - - Status HandleRemainder(HloInstruction* remainder) override { - return HandleRemainder(remainder); - } - - template ::value>::type* = - nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el & rhs_el; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el && rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); - } - - Status HandleAnd(HloInstruction* and_) override { - return HandleAnd(and_); - } - - template ::value>::type* = - nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el | rhs_el; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el || rhs_el; - })); - return Status::OK(); - } - - template < - typename 
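The floating-point `HandleMaximum`/`HandleMinimum` overloads in the deleted block keep NaNs alive rather than letting a comparison silently drop them. A minimal sketch of that rule:

```c++
#include <cmath>

// Propagates a NaN from either operand: if rhs is NaN the comparison fails
// and rhs is returned; if lhs is NaN the isnan check keeps lhs. std::max,
// which compares with '<', only keeps a NaN for one argument order.
float NanPropagatingMax(float lhs, float rhs) {
  return (lhs >= rhs || std::isnan(lhs)) ? lhs : rhs;
}

float NanPropagatingMin(float lhs, float rhs) {
  return (lhs <= rhs || std::isnan(lhs)) ? lhs : rhs;
}
```

With this rule, both `max(NaN, x)` and `max(x, NaN)` evaluate to NaN, which is the behavior the deleted handlers implement.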
NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); - } - - Status HandleOr(HloInstruction* or_) override { - return HandleOr(or_); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction* shl) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shl], - ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { - return IsShiftOutOfBounds(rhs_elem) ? 0 - : (lhs_elem << rhs_elem); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); - } - - Status HandleShiftLeft(HloInstruction* shl) override { - return HandleShiftLeft(shl); - } - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction* shr) { - typedef typename std::make_signed::type SignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - SignedT lhs_signed = static_cast(lhs_elem); - if (IsShiftOutOfBounds(rhs_elem)) { - return lhs_signed < 0 ? static_cast(-1) : 0; - } else { - return lhs_signed >> rhs_elem; - } - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); - } - - Status HandleShiftRightArithmetic(HloInstruction* shra) override { - return HandleShiftRightArithmetic(shra); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction* shr) { - typedef typename std::make_unsigned::type UnsignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - // If shift amount is greater than the number of bits, then return 0. 
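The shift handlers above guard against shift amounts that are out of range before doing the shift, since shifting by the full bit width or more is undefined behavior in C++. Roughly, for 32-bit operands:

```c++
#include <cstdint>

// Out-of-range shift amounts are handled explicitly: logical shifts produce 0
// and the arithmetic right shift produces all sign bits (-1 for negative
// inputs), instead of relying on undefined behavior.
int32_t ShiftRightArithmetic32(int32_t lhs, int32_t rhs) {
  if (rhs < 0 || rhs >= 32) return lhs < 0 ? -1 : 0;
  return lhs >> rhs;
}

uint32_t ShiftRightLogical32(uint32_t lhs, int32_t rhs) {
  if (rhs < 0 || rhs >= 32) return 0;
  return lhs >> static_cast<uint32_t>(rhs);
}
```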
- if (IsShiftOutOfBounds(rhs_elem)) { - return static_cast(0); - } - return static_cast(static_cast(lhs_elem) >> - rhs_elem); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); - } - - Status HandleShiftRightLogical(HloInstruction* shrl) override { - return HandleShiftRightLogical(shrl); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction* clamp) { - std::function - clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmin(high, std::fmax(value, low)); - }; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[clamp], - ElementwiseTernaryOp(clamp, - std::move(ConvertTernaryFunction(clamp_op)))); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); - } - - Status HandleClamp(HloInstruction* clamp) override { - return HandleClamp(clamp); - } - - Status HandleSelect(HloInstruction* select) override { - CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(!ShapeUtil::IsTuple(select->shape())); - std::function select_op = - [](bool pred, ReturnT on_true, ReturnT on_false) { - if (pred) { - return on_true; - } - return on_false; - }; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementwiseTernaryOp(select, std::move(select_op))); - return Status::OK(); - } - - Status HandleReverse(HloInstruction* reverse) override { - const auto result_shape = reverse->shape(); - const auto reverse_dimensions = reverse->dimensions(); - - auto operand = reverse->operand(0); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReverseShape(operand->shape(), - reverse_dimensions)); - - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = Literal::CreateFromShape(result_shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice out_index) { - std::vector from_index(out_index.begin(), out_index.end()); - for (const int64 dim : reverse_dimensions) { - from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; - } - return operand_literal.Get(from_index); - })); - - parent_->evaluated_[reverse] = std::move(result); - return Status::OK(); - } - - Status HandleConvolution(HloInstruction* conv) override { - auto lhs = conv->operand(0); - auto rhs = conv->operand(1); - const auto& window = conv->window(); - const Shape& result_shape = conv->shape(); - const Shape& lhs_shape = lhs->shape(); - const Shape& rhs_shape = rhs->shape(); - - TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); - TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); - CHECK(ShapeUtil::IsArray(lhs_shape)); - CHECK(ShapeUtil::IsArray(rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); - - const auto& dnums = conv->convolution_dimension_numbers(); - const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); - CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); - CHECK_EQ(num_spatial_dims, 
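`HandleReverse` above fills each output element by flipping the coordinate along every reversed dimension. The index arithmetic on its own:

```c++
#include <cstdint>
#include <vector>

// For each reversed dimension, the source coordinate is the mirror image of
// the output coordinate: from = extent - 1 - to.
std::vector<int64_t> ReverseSourceIndex(
    const std::vector<int64_t>& out_index, const std::vector<int64_t>& dims,
    const std::vector<int64_t>& reverse_dimensions) {
  std::vector<int64_t> from_index = out_index;
  for (int64_t dim : reverse_dimensions) {
    from_index[dim] = dims[dim] - 1 - out_index[dim];
  }
  return from_index;
}
```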
dnums.kernel_spatial_dimensions_size()); - CHECK_GE(num_spatial_dims, 0); - CHECK_EQ(window.dimensions_size(), num_spatial_dims); - - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); - - CHECK_EQ(num_spatial_dims + 2, lhs_rank); - CHECK_EQ(num_spatial_dims + 2, rhs_rank); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); - CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - std::vector window_dimension_sizes; - for (auto i : dnums.kernel_spatial_dimensions()) { - window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); - } - - const Shape& window_shape = - ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); - - DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); - DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); - - auto lhs_literal_data = lhs_literal.data(); - auto rhs_literal_data = rhs_literal.data(); - - auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, - &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data](ArraySlice out_index) { - // Dimension number applicable for input (lhs). - const int64 input_batch_dim = dnums.input_batch_dimension(); - const int64 input_z_dim = dnums.input_feature_dimension(); - // Dimension number applicable for kernel (rhs). - const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); - const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); - // Dimension number applicable for output. - const int64 output_batch_dim = dnums.output_batch_dimension(); - const int64 output_z_dim = dnums.output_feature_dimension(); - - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); - - ElementwiseT result_val = static_cast(0); - DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), - 0); - - // Convolve input feature with kernel. - do { - for (int64 iz = 0; iz < z_size; ++iz) { - int64 lhs_linear_index = 0; - lhs_linear_index += out_index[output_batch_dim] * - lhs_dim_multipliers[input_batch_dim]; - lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - - int64 rhs_linear_index = 0; - rhs_linear_index += out_index[output_z_dim] * - rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; - - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. - const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. As an - // optimization, skip this mod if there's no dilation. 
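The convolution handler above maps every (output position, kernel position) pair back to an input coordinate before accumulating. For a single spatial dimension, that mapping with stride, low padding, window (kernel) dilation, and base dilation looks roughly like this; positions that fall outside the input, or between base-dilation points, contribute nothing:

```c++
#include <cstdint>

// Returns true and sets *input_pos when the (output, kernel) pair hits a real
// input element; returns false when it lands in padding or between
// base-dilation points, mirroring the skip logic in the evaluator's loop.
bool ConvInputPosition(int64_t output_pos, int64_t kernel_pos, int64_t stride,
                       int64_t padding_low, int64_t window_dilation,
                       int64_t base_dilation, int64_t input_size,
                       int64_t* input_pos) {
  const int64_t undilated =
      output_pos * stride - padding_low + kernel_pos * window_dilation;
  if (base_dilation > 1 && undilated % base_dilation != 0) return false;
  const int64_t pos = base_dilation > 1 ? undilated / base_dilation : undilated;
  if (pos < 0 || pos >= input_size) return false;
  *input_pos = pos;
  return true;
}
```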
- if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. - int64 lhs_spatial_index; - if (window_dim.base_dilation() > 1) { - lhs_spatial_index = undilated_index / window_dim.base_dilation(); - } else { - lhs_spatial_index = undilated_index; - } - lhs_linear_index += - lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - - // Skip if input index is not in bounds. - if (!(lhs_spatial_index >= 0 && - lhs_spatial_index < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_linear_index += - (window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]) * - rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; - } - - result_val += - static_cast(lhs_literal_data[lhs_linear_index]) * - static_cast(rhs_literal_data[rhs_linear_index]); - } - cnt : {} - } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); - - return static_cast(result_val); - }; - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel(func)); - - parent_->evaluated_[conv] = std::move(result); - return Status::OK(); - } - - Status HandleDot(HloInstruction* dot) override { - auto lhs = dot->operand(0); - auto rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); - - const auto& dnums = dot->dot_dimension_numbers(); - - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); - - CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); - CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - - // There must be 1 and only 1 Contracting dimension for lhs and rhs. - CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); - const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); - // Contracted dimension sizes must be the same. 
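`HandleDot` above requires exactly one contracting dimension on each side and accumulates the product along it. Stripped of batch dimensions and multi-dimensional index bookkeeping, the inner loop is an ordinary matrix-multiply accumulation:

```c++
#include <cstdint>
#include <vector>

// C[m][n] = sum_k A[m][k] * B[k][n], with all matrices stored row-major.
// k is the single contracting dimension; its extent must match on both sides.
std::vector<float> NaiveDot(const std::vector<float>& a,
                            const std::vector<float>& b, int64_t m, int64_t k,
                            int64_t n) {
  std::vector<float> c(m * n, 0.0f);
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int64_t kk = 0; kk < k; ++kk) {
        acc += a[i * k + kk] * b[kk * n + j];
      }
      c[i * n + j] = acc;
    }
  }
  return c;
}
```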
- CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), - rhs->shape().dimensions(rhs_contracting_dimension)) - << "lhs contracted dimension: " - << lhs->shape().dimensions(lhs_contracting_dimension) - << " rhs contracted dimension: " - << rhs->shape().dimensions(rhs_contracting_dimension); - const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracting_dimension); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(dot->shape()); - - CHECK_EQ(dnums.lhs_batch_dimensions_size(), - dnums.rhs_batch_dimensions_size()); - - std::vector lhs_non_contracting_dims; - for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); - } - } - - std::vector rhs_non_batch_non_contracting_dims; - FlatSet batch_dims_set(dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); - for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); - } - } - - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice result_index) { - ElementwiseT result_val = static_cast(0); - - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. 
- for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; - } - - // Accumulates resulting product along the contracted dimension. - for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracting_dimension] = i; - rhs_index[rhs_contracting_dimension] = i; - - result_val += - static_cast(lhs_literal.Get(lhs_index)) * - static_cast(rhs_literal.Get(rhs_index)); - } - - return static_cast(result_val); - })); - - parent_->evaluated_[dot] = std::move(result); - return Status::OK(); - } - - Status HandlePad(HloInstruction* pad) override { - CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); - // Padding value must be scalar. - CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), - pad->padding_config().dimensions_size()); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferPadShape( - /*operand_shape=*/pad->operand(0)->shape(), - /*padding_value_shape=*/pad->operand(1)->shape(), - /*padding_config=*/pad->padding_config())); - CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - // Create new HLO of padded shape with padding value. - ReturnT scalar = - parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = Literal::CreateFromShape(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&scalar](ArraySlice multi_index) { return scalar; })); - - const Literal& evaluated_operand = - parent_->GetEvaluatedLiteralFor(pad->operand(0)); - - std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), - 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); - - // Loop through each element of the operand, assign them to the - // corresponding index of the resulting padded literal. - const PaddingConfig& pad_config = pad->padding_config(); - - auto func = [&](ArraySlice input_index) { - for (auto i = 0; i < input_index.size(); ++i) { - // Interior padding occurs logically before edge padding, so in the case - // of negative edge padding elements are removed from the - // interior-padded operand. - target_index[i] = - pad_config.dimensions(i).edge_padding_low() + - input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); - - // Account for negative low and high padding: skip assignment if the - // any target index is out of range. 
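`HandlePad` above pre-fills the result with the padding value and then scatters operand elements into it; interior padding stretches the spacing between source elements, and negative edge padding can push a target index out of range, in which case the element is simply dropped. The per-dimension index mapping by itself:

```c++
#include <cstdint>

// Maps an operand coordinate to its position in the padded result for one
// dimension. Returns false when negative edge padding pushes it out of range,
// in which case the element is skipped.
bool PaddedTargetIndex(int64_t input_index, int64_t edge_padding_low,
                       int64_t interior_padding, int64_t padded_size,
                       int64_t* target_index) {
  *target_index = edge_padding_low + input_index * (interior_padding + 1);
  return *target_index >= 0 && *target_index < padded_size;
}
```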
- if (!(target_index[i] >= 0 && - target_index[i] < pad->shape().dimensions(i))) { - return true; - } - } - result->Set(target_index, - evaluated_operand.Get(input_index)); - return true; - }; - - std::vector zero_base(evaluated_operand.shape().dimensions_size(), - 0); - std::vector step(evaluated_operand.shape().dimensions_size(), 1); - - ShapeUtil::ForEachIndex( - evaluated_operand.shape(), zero_base, - AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); - - parent_->evaluated_[pad] = std::move(result); - return Status::OK(); - } - - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { - auto operand = dynamic_slice->operand(0); - auto start_indices = dynamic_slice->operand(1); - auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - default: - LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { - auto operand = dynamic_update_slice->operand(0); - auto update = dynamic_update_slice->operand(1); - auto start_indices = dynamic_update_slice->operand(2); - auto result_shape = dynamic_update_slice->shape(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - 
DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - default: - LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - - template - StatusOr> MapImpl(HloInstruction* map) { - auto operands = map->operands(); - HloComputation* computation = map->to_apply(); - - auto result = Literal::CreateFromShape(map->shape()); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - std::vector> arg_literals; - arg_literals.reserve(operands.size()); - - // Construct scalar literal parameters to be passed to the map - // computation. - for (auto operand : operands) { - const Literal& arg_literal = - parent_->GetEvaluatedLiteralFor(operand); - - auto curr_val = arg_literal.Get(multi_index); - auto curr_val_literal = Literal::CreateR0(curr_val); - - arg_literals.push_back(std::move(curr_val_literal)); - } - - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) - .ConsumeValueOrDie(); - // Clear visit states so that the we can use the evaluate again on - // the same computation. 
- embedded_evaluator.ResetVisitStates(); - - return computed_result->Get({}); - })); - return std::move(result); - } - - Status HandleMap(HloInstruction* map) override { - switch (map->operand(0)->shape().element_type()) { - case PRED: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case F16: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], - MapImpl(map)); - break; - } - case F32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case F64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case C64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - default: - LOG(FATAL) << "HandleMap: unhandled primitive type for " - "input operand: " - << PrimitiveType_Name( - map->operand(0)->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleReduce(HloInstruction* reduce) override { - auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); - ArraySlice dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == - ShapeUtil::Rank(arg->shape()) - dimensions.size()); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), - /*dimensions_to_reduce=*/dimensions, - /*to_apply=*/function->ComputeProgramShape())); - TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); - const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); - VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(reduce->shape()); - - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); - std::vector arg_dim_steps(arg_dimensions.size()); - std::vector arg_dim_counts(arg_dimensions.size()); - for (const int64 dim : dimensions) { - arg_dim_steps[dim] = 1; - arg_dim_counts[dim] = arg_dimensions[dim]; - } - - // Map each dimension in the result to a dimension in arg that isn't - // being reduced. - std::vector result_to_arg_index; - for (int64 i = 0; i < arg_dimensions.size(); ++i) { - if (arg_dim_steps[i] == 0) { - result_to_arg_index.push_back(i); - } - } - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - // For each resulting dimension, calculate and assign computed value. 
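`HandleReduce` above marks the reduced dimensions (via `arg_dim_steps`) and collects the unmarked ones into `result_to_arg_index`, so each result coordinate can be placed back into the input index space. A compact sketch of building that map:

```c++
#include <cstdint>
#include <vector>

// Dimensions of the input that are not reduced become, in order, the
// dimensions of the result; marking the reduced dimensions first and then
// collecting the unmarked ones yields the result-dim -> input-dim map.
std::vector<int64_t> ResultToArgDimensions(
    int64_t arg_rank, const std::vector<int64_t>& dims_to_reduce) {
  std::vector<int> reduced(arg_rank, 0);
  for (int64_t dim : dims_to_reduce) reduced[dim] = 1;

  std::vector<int64_t> result_to_arg_index;
  for (int64_t i = 0; i < arg_rank; ++i) {
    if (!reduced[i]) result_to_arg_index.push_back(i);
  }
  return result_to_arg_index;
}
```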
- TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - ReturnT result_val = init_scalar; - - std::vector base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } - - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literal.shape()) && - IsScalarAdd(function)) { - double computed_result = 0; - auto func = [&](ArraySlice input_index) { - computed_result += arg_literal.Get(input_index); - return true; - }; - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return static_cast(computed_result); - } - auto func = [&](ArraySlice input_index) { - auto curr_val = arg_literal.Get(input_index); - - // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0(curr_val); - auto result_val_literal = Literal::CreateR0(result_val); - std::vector args = {result_val_literal.get(), - curr_val_literal.get()}; - - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - result_val = computed_result->Get({}); - return true; - }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return result_val; - })); - - parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); - } - - bool IsScalarAdd(HloComputation* computation) { - HloInstruction* instruction = computation->root_instruction(); - if (instruction->opcode() == HloOpcode::kAdd && - computation->num_parameters() == 2) { - const HloInstruction* lhs = instruction->operand(0); - const HloInstruction* rhs = instruction->operand(1); - return lhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(lhs->shape()) && - rhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; - } - return false; - } - - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { - auto operand = select_and_scatter->operand(0); - auto source = select_and_scatter->operand(1); - const Window& window = select_and_scatter->window(); - - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(select_and_scatter->shape()); - - // Initialize result array with the init value. 
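The reduce handler above special-cases float additions (`IsScalarAdd`) and accumulates them in a `double` rather than repeatedly round-tripping scalar literals through the embedded evaluator, which is both faster and more accurate. The numerical part of that fast path, as a schematic:

```c++
#include <vector>

// Summing many floats in a float accumulator loses low-order bits; widening
// the accumulator to double keeps them until the final narrowing cast.
float SumWithDoubleAccumulator(const std::vector<float>& values, float init) {
  double acc = init;
  for (float v : values) acc += v;
  return static_cast<float>(acc);
}
```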
- TF_RETURN_IF_ERROR(result->Populate( - [&](ArraySlice output_index) { return init_scalar; })); - - std::vector window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - HloComputation* select = select_and_scatter->select(); - HloComputation* scatter = select_and_scatter->scatter(); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - - int64 rank = ShapeUtil::Rank(operand_literal.shape()); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - DimensionVector source_index(rank); - - std::fill(source_index.begin(), source_index.end(), 0); - do { - // For each element in `source`, we place a window in `operand`. For each - // window placement, we iterate inside the window twice: - // - // 1. Find the selected index by applying `select` function to all - // elements. E.g., If the `select` function is GreaterEqual, the first - // iteration through the window finds the biggest value and returns its - // index. - // - // 2. Using the selected index, scatter value from `source` to result. We - // do this by iterating through the window, and compare each index with - // the selected index. - optional selected_val; - optional> selected_index; - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector& operand_index) { - auto curr_val = operand_literal.Get(operand_index); - if (!selected_val) { - selected_val = curr_val; - selected_index = operand_index; - } - const auto curr_val_literal = Literal::CreateR0(curr_val); - const auto selected_val_literal = - Literal::CreateR0(*selected_val); - - const std::vector args = { - selected_val_literal.get(), curr_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*select, args) - .ConsumeValueOrDie(); - bool selected = !computed_result->Get({}); - if (selected) { - selected_val = curr_val; - selected_index = operand_index; - } - embedded_evaluator.ResetVisitStates(); - }); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector& operand_index) { - if (std::equal(operand_index.begin(), operand_index.end(), - selected_index->begin())) { - auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - const auto source_literal = Literal::CreateR0(source); - const auto scattered_literal = - Literal::CreateR0(scattered); - - const std::vector args = { - source_literal.get(), scattered_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*scatter, args) - .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); - // Clear visit states so that the we can use the evaluator again - // on the same computation. 
- embedded_evaluator.ResetVisitStates(); - } - }); - } while (IndexUtil::BumpIndices(source->shape(), &source_index)); - - parent_->evaluated_[select_and_scatter] = std::move(result); - return Status::OK(); - } - - Status HandleReduceWindow(HloInstruction* reduce_window) override { - auto operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferReduceWindowShape( - /*operand_shape=*/reduce_window->operand(0)->shape(), - /*init_value=*/reduce_window->operand(1)->shape(), window, - /*to_apply_shape=*/function->ComputeProgramShape())); - TF_RET_CHECK( - ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) - << "return shape is set to: " - << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanStringWithLayout(inferred_return_shape); - - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); - VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); - VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(reduce_window->shape()); - - // Creates a Shape object from window, for iteration below. - std::vector window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice output_index) { - ReturnT result_val = init_scalar; - - std::fill(window_index.begin(), window_index.end(), 0); - std::fill(operand_index.begin(), operand_index.end(), 0); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), output_index, - [&](const std::vector& operand_index) { - auto curr_val = operand_literal.Get(operand_index); - - // Evaluate computation with specified literal operands. - const auto curr_val_literal = - Literal::CreateR0(curr_val); - const auto result_val_literal = - Literal::CreateR0(result_val); - const std::vector args = { - result_val_literal.get(), curr_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - - // Clear visit states so that the we can use the evaluate again - // on the same computation. 
- embedded_evaluator.ResetVisitStates(); - - result_val = computed_result->Get({}); - }); - - return result_val; - })); - - parent_->evaluated_[reduce_window] = std::move(result); - return Status::OK(); - } - - Status HandleSlice(HloInstruction* slice) override { - auto operand = slice->operand(0); - const Shape& shape = slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferSliceShape( - operand->shape(), slice->slice_starts(), - slice->slice_limits(), slice->slice_strides())); - TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const int64 rank = ShapeUtil::Rank(operand->shape()); - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](ArraySlice out_index) { - DimensionVector operand_index(rank); - for (int64 i = 0; i < rank; ++i) { - operand_index[i] = - slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); - } - return operand_literal.Get(operand_index); - }; - - auto result = Literal::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); - parent_->evaluated_[slice] = std::move(result); - return Status::OK(); - } - - // Enable CLZ only for int32 and uint32. - template < - typename NativeT, - typename std::enable_if< - (std::is_floating_point::value || - std::is_integral::value || is_complex_t::value) && - !(std::is_same::value || - std::is_same::value)>::type* = nullptr> - Status HandleClz(HloInstruction* clz) { - return InvalidArgument("Unsupported type for Clz"); - } - - template ::value || - std::is_same::value>::type* = nullptr> - Status HandleClz(HloInstruction* clz) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], - ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { - return 31 - tensorflow::Log2Floor(elem_operand); - })); - return Status::OK(); - } - - Status HandleClz(HloInstruction* clz) override { - return HandleClz(clz); - } - - template ::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], - ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { - return std::sin(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); - } - - Status HandleSin(HloInstruction* sin) override { - return HandleSin(sin); - } - - template ::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], - ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { - return std::cos(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); - } - - Status HandleCos(HloInstruction* cos) override { - return HandleCos(cos); - } - - template ::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[reduce_precision], - ElementWiseUnaryOp(reduce_precision, [reduce_precision]( - ElementwiseT elem) { - uint32_t value_as_int = tensorflow::bit_cast(elem); - const uint32_t mantissa_bits = 
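`HandleClz` above computes count-leading-zeros for 32-bit values as `31 - Log2Floor(operand)`. An equivalent portable loop, for reference; zero is the usual edge case, where every bit counts as a leading zero:

```c++
#include <cstdint>

// Count leading zeros of a 32-bit value: 31 - floor(log2(x)) for x > 0,
// and 32 when x == 0.
int CountLeadingZeros32(uint32_t x) {
  int n = 0;
  while (n < 32 && (x & 0x80000000u) == 0) {
    x <<= 1;
    ++n;
  }
  return n;
}
```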
reduce_precision->mantissa_bits(); - const uint32_t exponent_bits = reduce_precision->exponent_bits(); - - // Code is based on the CPU/GPU implementation in LLVM-emitting code. - // - // Bits in float type: - // mantissa : bits [0:22] - // exponent : bits [23:30] - // sign : bits [31] - if (mantissa_bits < 23) { - const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); - - // Compute rounding bias for round-to-nearest with ties to even. - // This is equal to a base value of 0111... plus one bit if the last - // remaining mantissa bit is 1. - const uint32_t base_rounding_bias = - (last_mantissa_bit_mask >> 1) - 1; - const uint32_t x_last_mantissa_bit = - (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); - const uint32_t x_rounding_bias = - x_last_mantissa_bit + base_rounding_bias; - - // Add rounding bias, and mask out truncated bits. Note that the - // case where adding the rounding bias overflows into the exponent - // bits is correct; the non-masked mantissa bits will all be zero, - // and the exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - value_as_int = value_as_int + x_rounding_bias; - value_as_int = value_as_int & truncation_mask; - } - if (exponent_bits < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the - // most- significant bit -- is equal to 1.0f for all exponent sizes. - // Adding 2^(n-1)-1 to this gives us the highest non-infinite - // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from - // this gives us the lowest' exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n - // is (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (exponent_bits - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; - - // Do we overflow or underflow? - const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; - const bool x_overflows = x_exponent > (reduced_max_exponent << 23); - const bool x_underflows = - x_exponent <= (reduced_min_exponent << 23); - - // Compute appropriately-signed values of zero and infinity. - const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; - const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; - - // Force to zero or infinity if overflow or underflow. (Note that - // this truncates all denormal values to zero, rather than rounding - // them.) - value_as_int = x_overflows ? x_signed_inf : value_as_int; - value_as_int = x_underflows ? x_signed_zero : value_as_int; - } - - float reduced_result = tensorflow::bit_cast(value_as_int); - if (std::isnan(elem)) { - reduced_result = mantissa_bits > 0 - ? 
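The `HandleReducePrecision` body above documents the bit manipulation in detail: round the mantissa to the nearest value representable with fewer bits, ties to even, by adding a rounding bias and masking off the truncated bits, then clamp the exponent range. Just the mantissa step, extracted into a standalone sketch (the exponent clamping and NaN fix-up from the original are omitted here):

```c++
#include <cstdint>
#include <cstring>

// Rounds an IEEE-754 float to `mantissa_bits` of mantissa (round to nearest,
// ties to even) by integer arithmetic on its bit pattern. An f32 mantissa has
// 23 bits; a carry out of the mantissa correctly bumps the exponent by one.
float ReduceMantissaBits(float value, int mantissa_bits) {
  if (mantissa_bits >= 23) return value;
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));

  const uint32_t last_kept_bit_mask = 1u << (23 - mantissa_bits);
  const uint32_t base_rounding_bias = (last_kept_bit_mask >> 1) - 1;
  const uint32_t last_kept_bit =
      (bits & last_kept_bit_mask) >> (23 - mantissa_bits);
  bits += base_rounding_bias + last_kept_bit;  // bias 0111... plus tie-break bit
  bits &= ~(last_kept_bit_mask - 1);           // truncate the discarded bits

  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}
```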
elem - : std::numeric_limits::infinity(); - } - return reduced_result; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); - } - - Status HandleReducePrecision(HloInstruction* reduce_precision) override { - return HandleReducePrecision(reduce_precision); - } - - private: - template - StatusOr> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); - - std::vector operand_indices(start.size()); - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - for (int64 i = 0; i < operand_indices.size(); ++i) { - CHECK_GE(multi_index[i] + start[i], 0); - // Mod is only used here to be consistent with the existing - // backends' behavior. - operand_indices[i] = (multi_index[i] + start[i]) % - operand_literal.shape().dimensions(i); - } - - auto result = operand_literal.Get(operand_indices); - return result; - })); - - return std::move(result); - } - - template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); - auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result->shape()); - std::vector start(rank, 0); - for (int64 i = 0; i < rank; ++i) { - // All other implementations currently wrap-around the index, so this - // should do so as well. - start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); - start[i] += (start[i] < 0) * result->shape().dimensions(i); - } - std::vector result_index(rank, 0); - - auto func = [&](ArraySlice update_index) { - std::transform(update_index.begin(), update_index.end(), start.begin(), - result_index.begin(), std::plus()); - // Same as above, wrap-around only to match other implementations' - // semantics. 
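// Worked example (hypothetical values): for an operand dimension of size 4,
// a start index of 3 and an update of size 2, the start is first reduced
// mod 4 (3 % 4 = 3); the two update elements then land at result indices 3
// and (3 + 1) % 4 = 0, i.e. the write wraps around the end of the dimension
// exactly as the other backends do.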
- std::transform(result_index.begin(), result_index.end(), - result->shape().dimensions().begin(), result_index.begin(), - std::modulus()); - result->Set(result_index, - update_literal.Get(update_index)); - return true; - }; - - std::vector base(update_literal.shape().dimensions_size(), 0); - std::vector step(update_literal.shape().dimensions_size(), 1); - ShapeUtil::ForEachIndex(update_literal.shape(), base, - AsInt64Slice(update_literal.shape().dimensions()), - step, func); - - return std::move(result); - } - - StatusOr> ElementWiseUnaryOp( - HloInstruction* instruction, - const std::function& unary_op) { - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(instruction->operand(0)); - TF_ASSIGN_OR_RETURN( - auto result_literal, - (ElementWiseUnaryOpImpl( - instruction, ConvertUnaryFunction(unary_op), operand_literal))); - - return std::move(result_literal); - } - - StatusOr> ElementWiseBinaryOp( - HloInstruction* instruction, - const std::function& - binary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return ConvertBinaryFunction(binary_op)( - lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); - return std::move(result); - } - - template - StatusOr> ElementwiseTernaryOp( - HloInstruction* instruction, - const std::function& ternary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. 
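// As in the unary and binary helpers above, the operands must already share
// the result's dimensions: a kSelect over pred[2,3], f32[2,3], f32[2,3]
// evaluates element by element, while any mismatch is rejected by the
// Unimplemented check below rather than being implicitly broadcast.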
- if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return ternary_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index), - ehs_literal.Get(multi_index)); - })); - - return std::move(result); - } - - template - static bool IsShiftOutOfBounds(NativeT rhs) { - typedef typename std::make_unsigned::type UnsignedT; - UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; - UnsignedT rhs_unsigned = static_cast(rhs); - return rhs_unsigned >= lhs_size_unsigned; - } - - HloEvaluator* parent_; -}; // class HloEvaluator::TypedVisitor HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique>(this); - typed_visitors_[U8] = MakeUnique>(this); + typed_visitors_[PRED] = MakeUnique>(this); + typed_visitors_[U8] = MakeUnique>(this); typed_visitors_[U16] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: U16."); + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); }); - typed_visitors_[U32] = MakeUnique>(this); - typed_visitors_[U64] = MakeUnique>(this); - typed_visitors_[S8] = MakeUnique>(this); + typed_visitors_[U32] = MakeUnique>(this); + typed_visitors_[U64] = MakeUnique>(this); + typed_visitors_[S8] = MakeUnique>(this); typed_visitors_[S16] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: S16."); + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); }); - typed_visitors_[S32] = MakeUnique>(this); - typed_visitors_[S64] = MakeUnique>(this); - typed_visitors_[F16] = MakeUnique>(this); - typed_visitors_[F32] = MakeUnique>(this); - typed_visitors_[F64] = MakeUnique>(this); - typed_visitors_[C64] = MakeUnique>(this); + typed_visitors_[S32] = MakeUnique>(this); + typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[F16] = + MakeUnique>(this); + typed_visitors_[F32] = MakeUnique>(this); + typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[C64] = MakeUnique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. 
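// Illustrative sketch of that split: with ReturnT = bfloat16 and
// ElementwiseT = float, the Convert*Function helpers in the typed visitor
// widen every element before the computation and narrow the result again,
// roughly
//   bfloat16 out = static_cast<bfloat16>(op(static_cast<float>(in)));
// so ops such as std::ceil and std::tanh only ever see float values.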
- typed_visitors_[BF16] = MakeUnique>(this); + typed_visitors_[BF16] = + MakeUnique>(this); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVistor: unhandled primitive type: TUPLE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: OPAQUE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); }); } @@ -2367,6 +309,35 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } +StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs) { + std::unique_ptr lhs_instr = + HloInstruction::CreateConstant(lhs.CloneToUnique()); + std::unique_ptr rhs_instr = + HloInstruction::CreateConstant(rhs.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), + rhs_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + cloned_instruction->DetachFromOperands(); + return result; +} + +StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand) { + std::unique_ptr operand_instr = + HloInstruction::CreateConstant(operand.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + cloned_instruction->DetachFromOperands(); + return result; +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -2536,6 +507,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { } break; case F16: return Unimplemented("unhandled primitive type: F16."); + case BF16: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; case F32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -2912,6 +888,28 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { return Status::OK(); } +Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { + const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); + + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(operand.shape())) + << "broadcast dimensions is of size: " << broadcast->dimensions().size() + << " and rank of operand_to_broadcast is: " + << ShapeUtil::Rank(operand.shape()); + // Checks that operand's dimensions are the same as the broadcast's + // dimensions along the dimensions to be broadcasted. + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == + operand.shape().dimensions(i)); + } + + TF_ASSIGN_OR_RETURN( + evaluated_[broadcast], + operand.Broadcast(broadcast->shape(), broadcast->dimensions())); + + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); @@ -2963,12 +961,14 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } Status HloEvaluator::HandleFusion(HloInstruction* fusion) { + HloModuleConfig config; // Attach cloned computation to an empty HLO module so the existing ones are // not modified. 
- HloModule empty_hlo_module("EmptyModuleForFusion"); + HloModule empty_hlo_module("EmptyModuleForFusion", config); + HloCloneContext context(&empty_hlo_module); auto cloned_fused_computation = fusion->fused_instructions_computation()->Clone( - /*suffix=*/"clone_with_layout", &empty_hlo_module); + /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); } @@ -3003,8 +1003,8 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* true_computation = conditional->true_computation(); auto* false_computation = conditional->false_computation(); - auto result = Literal::CreateFromShape(conditional->shape()); HloEvaluator embedded_evaluator; + std::unique_ptr result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -3028,7 +1028,7 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. // This would also handle output array of tuple types as the DefaultAction - // would go through the TypedVisitor which doesn't handle tuples. + // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { evaluated_[select] = on_true.CloneToUnique(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c0dcee0c3e382f..b53d5644de5a17 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -108,20 +109,23 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const std::unordered_map& substitutions); + StatusOr> EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs); + + StatusOr> EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand); + protected: - // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting - // literal type of each evaluated Handle* method of a TypedVisitor. - // There are however a few notable exceptions to this rule, notably: - // - HandleCompare and HandleIsFinite: where the resulting literal type is - // always boolean. - // These operations are handled outside of the parent HloEvaluator handlers - // instead of from within TypedVisitor. + // Make HloEvaluatorTypedVisitor a friend because it is logically part of this + // class. // - // Type params: - // - ReturnT: The type of input and output of each operation. - // - ElementwiseT: The type in which internal computation are done. - template - class TypedVisitor; + // A straightforward implementation would be to make it a nested class + // declared and defined in hlo_evaluator.cc. Instead HloEvaluatorTypedVisitor + // lives as a separate class with its own header because its template gets + // instantiated many times and we want to use extern templates to shard out + // the compilation of those instantiations across multiple cc files. + template + friend class HloEvaluatorTypedVisitor; // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. 
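For orientation, a minimal sketch of how the EvaluateElementwiseUnaryOp /
EvaluateElementwiseBinaryOp helpers declared above might be called. The wrapper
name and literal values are made up, and it assumes the evaluator can still be
default-constructed (i.e. max_loop_iterations keeps a default):

```cpp
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"

namespace xla {

// Hypothetical helper: fold two constant literals through the evaluator
// without first building an HloModule around them.
StatusOr<std::unique_ptr<Literal>> AddTwoConstants() {
  HloEvaluator evaluator;
  auto lhs = Literal::CreateR1<float>({1.0f, 2.0f});
  auto rhs = Literal::CreateR1<float>({3.0f, 4.0f});
  // Per the .cc hunk above, this wraps lhs/rhs in temporary constant
  // instructions, evaluates a synthetic kAdd, and detaches the temporaries.
  return evaluator.EvaluateElementwiseBinaryOp(HloOpcode::kAdd, *lhs, *rhs);
}

}  // namespace xla
```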
@@ -168,7 +172,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; - private: + Status HandleBroadcast(HloInstruction* broadcast) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. @@ -183,14 +188,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return *(it->second); } - // Map from a primitive type to its associated (templated) DfsHloVisitor. - // Note: the hash function here is only needed because current gcc std::hash - // does not specialize for enum types. This should however be fixed in the - // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 - tensorflow::gtl::FlatMap, - std::hash> - typed_visitors_; - // Tracks the HLO instruction and its evaluated literal result. // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in @@ -199,6 +196,41 @@ class HloEvaluator : public DfsHloVisitorWithDefault { tensorflow::gtl::FlatMap> evaluated_; + private: + template + static StatusOr> ElementWiseUnaryOpImpl( + HloInstruction* instruction, + const std::function& unary_op, + const Literal& operand_literal) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + auto result = MakeUnique(shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op(operand_literal.Get(multi_index)); + })); + return std::move(result); + } + + // Map from a primitive type to its associated (templated) DfsHloVisitor. + // Note: the hash function here is only needed because current gcc std::hash + // does not specialize for enum types. This should however be fixed in the + // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 + tensorflow::gtl::FlatMap, + std::hash> + typed_visitors_; + // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of // each invocation to the Evaluate* method. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index dd14dd38537a83..84b4ead2dd28ca 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -82,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto element_type = expected->shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - LiteralTestUtil::ExpectNear(*expected, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); } else { - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } } @@ -100,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } bool use_bfloat16_; @@ -129,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -150,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -175,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -262,13 +262,13 @@ TEST_P(HloEvaluatorTest, DoesCosR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesSinR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = @@ -307,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies Reshape operation is correctly evaluated. 
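A quick check that the respelled tolerances above are the same values rather
than new ones: 0x1.0P-5 = 2^-5 = 0.03125 (written 0.031250) and
0x1.0P-20 = 2^-20 = 1/1048576 = 9.5367431640625E-7, so the bfloat16 and float
error bounds are unchanged and only their notation moves from hexadecimal
floats to plain decimals.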
@@ -315,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = @@ -333,7 +333,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { result->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0x1.0P-5); + EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); }); } @@ -351,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -370,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -392,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { auto expected = Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -413,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({100, 200}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -432,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -452,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } PaddingConfig CreatePaddingConfig( @@ -490,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto expected = Literal::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -525,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = Literal::CreateR4FromArray4D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -606,7 +606,7 @@ TEST_P(HloEvaluatorTest, 
NegativeAndInteriorPadding2D) { auto expected_array = MakeUnique>(0, 9); auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -651,7 +651,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -688,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { auto expected = Literal::CreateR1({22.f, 28.f}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -737,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -785,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = Literal::CreateR3FromArray3D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -827,7 +827,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -847,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -927,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1004,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? 
expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1046,7 +1046,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1067,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1109,7 +1109,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1131,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, @@ -1180,7 +1180,7 @@ TEST_P(HloEvaluatorTest, *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1203,7 +1203,7 @@ TEST_P(HloEvaluatorTest, })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1319,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { auto expected = Literal::CreateR1({6, 18}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1370,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{6, 7}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1427,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1490,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); - LiteralTestUtil::ExpectEqual(*result_literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1523,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { 
{19}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1556,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1591,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1627,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { {5, -6, -7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1662,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { {5, 6, 7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1703,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { result_inner_literal.get(), }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1756,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1776,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1800,8 +1800,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { auto result = evaluator.EvaluateWithSubstitutions( add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1823,9 +1823,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1847,9 +1847,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), gather_indices.get()})); + *Evaluate({operand.get(), 
gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1872,10 +1872,10 @@ ENTRY main { Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 2}, {2, 1}}); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), gather_indices.get()})); + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1900,9 +1900,9 @@ ENTRY main { {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 0}, {1, 0}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, 1}, {-4, 4}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, @@ -1928,9 +1928,9 @@ ENTRY main { {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 0}, {1, 0}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-2, 2}, {-1, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1952,9 +1952,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{5}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{5}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1977,9 +1977,9 @@ ENTRY main { Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR2({{2, 1}, {1, 1}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{8}}, {{5}}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2000,9 +2000,50 @@ ENTRY main { ParseAndVerifyModule(hlo_text); std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{}, {}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{}, {}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { + const string hlo_text = R"( +HloModule GatherXd + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1} +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{0, 1}, {2, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +// Verifies that 
HloEvaluator evaluates a HLO instruction that performs +// element-wise comparison with 2 bfloat16 operands. +TEST_P(HloEvaluatorTest, DoesCompareBF16) { + // lhs >= rhs + auto lhs = Literal::CreateR2( + {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, + {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}}); + auto rhs = Literal::CreateR2( + {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)}, + {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); + auto expected = + Literal::CreateR2({{false, true, true}, {false, true, true}}); + TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), + std::move(rhs)); } INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h new file mode 100644 index 00000000000000..13f46407e33e36 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -0,0 +1,2142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// TODO(b/79274244): We'd like these type traits to live inside of +// HloEvaluatorTypedVisitor so they don't pollute namespace xla, but that +// crashes clang in the frontend. +// +// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is +// a "private" header that's not exposed outside of hlo_evaluator.cc. +template +using is_complex_t = std::is_same; +template +using is_complex64_t = std::is_same; + +// Templated DfsHloVisitor for use by HloEvaluator. +// +// Typically ReturnT here indicates the resulting literal type of each evaluated +// Handle* method of a TypedVisitor. There are however a few notable exceptions +// to this rule, notably: +// - HandleCompare and HandleIsFinite: where the resulting literal type is +// always boolean. +// These operations are handled outside of the parent HloEvaluator handlers +// instead of from within TypedVisitor. +// +// Type params: +// - ReturnT: The type of input and output of each operation. +// - ElementwiseT: The type in which internal computation are done. +// +// This a logically a private part of HloEvaluator. It lives in this header +// file rather than in hlo_evaluator.cc because we use extern templates and a +// bunch of independent cc files to speed up compiling the many instantiations +// of this class. 
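A minimal sketch of the extern-template arrangement that comment describes,
with illustrative file names and an illustrative pair of instantiations (the
actual list is set up elsewhere in the patch):

```cpp
// In hlo_evaluator_typed_visitor.h, after the class template definition:
// declare, but do not instantiate, the specializations that other translation
// units will provide.
extern template class HloEvaluatorTypedVisitor<float>;
extern template class HloEvaluatorTypedVisitor<int32>;

// In a small per-type file such as hlo_evaluator_typed_visitor_float.cc: the
// expensive instantiation is compiled once, in its own translation unit,
// instead of inside hlo_evaluator.cc.
#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h"
namespace xla {
template class HloEvaluatorTypedVisitor<float>;
}  // namespace xla
```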
+template +class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { + public: + explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} + + // The following higher-order functions convert a function with ElementwiseT + // to a function with ReturnT. + std::function ConvertUnaryFunction( + const std::function& unary_op) { + return [&unary_op](ReturnT arg) { + return static_cast(unary_op(static_cast(arg))); + }; + } + std::function ConvertBinaryFunction( + const std::function& + binary_op) { + return [&binary_op](ReturnT arg1, ReturnT arg2) { + return static_cast(binary_op(static_cast(arg1), + static_cast(arg2))); + }; + } + std::function ConvertTernaryFunction( + const std::function& ternary_op) { + return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { + return static_cast(ternary_op(static_cast(arg1), + static_cast(arg2), + static_cast(arg3))); + }; + } + + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", + HloOpcodeString(hlo_instruction->opcode()).c_str()); + } + + // TODO(b/35950897): many of the stl functions used in the handlers are not + // overloaded for every XLA primitive type. + + template ::value>::type* = + nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return std::abs(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (HloEvaluator::ElementWiseUnaryOpImpl( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would still be the return type, F32, and thus + // specifying the ElementwiseT explicitly as C64 is needed below. 
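// Concretely: std::abs(complex64{3.0f, 4.0f}) is the float 5.0f, so for a C64
// operand the evaluated literal holds F32 elements even though each
// per-element computation still consumes complex64 values; the explicit
// dispatch below keeps those two types separate.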
+ if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs(abs); + } + return HandleAbs(abs); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[round], + ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { + return std::round(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + return InvalidArgument("Unsupported type for Round"); + } + + Status HandleRound(HloInstruction* round) override { + return HandleRound(round); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], + ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { + return std::ceil(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + return InvalidArgument("Unsupported type for Ceil"); + } + + Status HandleCeil(HloInstruction* ceil) override { + return HandleCeil(ceil); + } + + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).Convert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleBitcastConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleExp(HloInstruction* exp) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], + ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { + return std::exp(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::expm1(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Expm1"); + } + + Status HandleExpm1(HloInstruction* floor) override { + return HandleExpm1(floor); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { + return 
std::floor(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Floor"); + } + + Status HandleFloor(HloInstruction* floor) override { + return HandleFloor(floor); + } + + Status HandleLog(HloInstruction* log) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], + ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { + return std::log(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::log1p(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Log1p"); + } + + Status HandleLog1p(HloInstruction* floor) override { + return HandleLog1p(floor); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return ~elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + return InvalidArgument("Unsupported type for Not"); + } + + Status HandleNot(HloInstruction* not_) override { + return HandleNot(not_); + } + + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { + return NativeT(-type(elem_operand)); + })); + return Status::OK(); + } + + template ::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp( + negate, [](ElementwiseT elem_operand) { return -elem_operand; })); + return Status::OK(); + } + + Status HandleNegate(HloInstruction* negate) override { + return HandleNegate(negate); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return (ElementwiseT(0) < elem_operand) - + (elem_operand < ElementwiseT(0)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) 
{ + auto abs_val = std::abs(elem_operand); + return 0 == abs_val ? ElementwiseT(0) + : elem_operand / abs_val; + })); + return Status::OK(); + } + + Status HandleSign(HloInstruction* sign) override { + return HandleSign(sign); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], + ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return std::atan2(lhs_elem, rhs_elem); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + return InvalidArgument("Unsupported type for Atan2"); + } + + Status HandleAtan2(HloInstruction* atan2) override { + return HandleAtan2(atan2); + } + + Status HandleTanh(HloInstruction* tanh) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], + ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { + return std::tanh(elem_operand); + })); + return Status::OK(); + } + + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + std::is_floating_point::value || + is_complex_t::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem * rhs_elem; + })); + return Status::OK(); + } + + Status HandleMultiply(HloInstruction* multiply) override { + return HandleMultiply(multiply); + } + + Status HandleSubtract(HloInstruction* subtract) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[subtract], + ElementWiseBinaryOp(subtract, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem - rhs_elem; + })); + return Status::OK(); + } + + Status HandleAdd(HloInstruction* add) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem + rhs_elem; + })); + return Status::OK(); + } + + Status HandleDivide(HloInstruction* divide) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem / rhs_elem; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return std::max(lhs, rhs); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return ((lhs >= rhs) || std::isnan(lhs)) ? 
lhs : rhs; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + return InvalidArgument("Unsupported type for Maximum"); + } + + Status HandleMaximum(HloInstruction* maximum) override { + return HandleMaximum(maximum); + } + + template ::value>::type* = + nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::min(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + return InvalidArgument("Unsupported type for Minimum"); + } + + Status HandleMinimum(HloInstruction* minimum) override { + return HandleMinimum(minimum); + } + + Status HandlePower(HloInstruction* power) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + return InvalidArgument("Unsupported type for Remainder"); + } + + Status HandleRemainder(HloInstruction* remainder) override { + return HandleRemainder(remainder); + } + + template ::value>::type* = + nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el & rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el && rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + return InvalidArgument("Unsupported type for And"); + } + + Status HandleAnd(HloInstruction* and_) override { + return HandleAnd(and_); + } + + template ::value>::type* = + nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el | rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el || rhs_el; + })); + return Status::OK(); + } + + template < + typename 
NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + return InvalidArgument("Unsupported type for Or"); + } + + Status HandleOr(HloInstruction* or_) override { + return HandleOr(or_); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return IsShiftOutOfBounds(rhs_elem) ? 0 + : (lhs_elem << rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl) override { + return HandleShiftLeft(shl); + } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr) { + typedef typename std::make_signed::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + SignedT lhs_signed = static_cast(lhs_elem); + if (IsShiftOutOfBounds(rhs_elem)) { + return lhs_signed < 0 ? static_cast(-1) : 0; + } else { + return lhs_signed >> rhs_elem; + } + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra) override { + return HandleShiftRightArithmetic(shra); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr) { + typedef typename std::make_unsigned::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + // If shift amount is greater than the number of bits, then return 0. 
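The handlers above all follow the same shape: one `enable_if`-constrained template overload per supported category of element type, an overload that returns `InvalidArgument` for unsupported types, and a virtual override that forwards to the matching template. The shift handlers additionally guard against shift amounts that equal or exceed the bit width, which would be undefined behavior in C++. The following is a minimal standalone sketch of that dispatch pattern and shift guard; the function names (`IsShiftOutOfBounds` aside, which mirrors the helper in this file) and types are illustrative only, not the XLA API.

```c++
#include <climits>
#include <cstdint>
#include <iostream>
#include <type_traits>

template <typename T>
bool IsShiftOutOfBounds(T rhs) {
  using Unsigned = typename std::make_unsigned<T>::type;
  // Shifting by >= the bit width is undefined behavior, so define it instead.
  return static_cast<Unsigned>(rhs) >= sizeof(T) * CHAR_BIT;
}

// Enabled only for integral element types: out-of-range shifts produce 0.
template <typename T,
          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
T ShiftLeftOrZero(T lhs, T rhs) {
  return IsShiftOutOfBounds(rhs) ? T(0) : T(lhs << rhs);
}

// Fallback for unsupported element types, analogous to InvalidArgument above.
template <typename T,
          typename std::enable_if<!std::is_integral<T>::value>::type* = nullptr>
T ShiftLeftOrZero(T, T) {
  static_assert(std::is_integral<T>::value, "Unsupported type for ShiftLeft");
  return T();
}

int main() {
  std::cout << ShiftLeftOrZero<int32_t>(1, 3) << "\n";   // 8
  std::cout << ShiftLeftOrZero<int32_t>(1, 40) << "\n";  // 0 rather than UB
}
```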
+ if (IsShiftOutOfBounds(rhs_elem)) { + return static_cast(0); + } + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl) override { + return HandleShiftRightLogical(shrl); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return std::fmin(high, std::fmax(value, low)); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction*) { + return InvalidArgument("Unsupported type for Clamp"); + } + + Status HandleClamp(HloInstruction* clamp) override { + return HandleClamp(clamp); + } + + Status HandleSelect(HloInstruction* select) override { + CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); + CHECK(!ShapeUtil::IsTuple(select->shape())); + std::function select_op = + [](bool pred, ReturnT on_true, ReturnT on_false) { + if (pred) { + return on_true; + } + return on_false; + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], + ElementwiseTernaryOp(select, std::move(select_op))); + return Status::OK(); + } + + Status HandleReverse(HloInstruction* reverse) override { + const auto result_shape = reverse->shape(); + const auto reverse_dimensions = reverse->dimensions(); + + auto operand = reverse->operand(0); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReverseShape(operand->shape(), + reverse_dimensions)); + + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto result = MakeUnique(result_shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice out_index) { + std::vector from_index(out_index.begin(), out_index.end()); + for (const int64 dim : reverse_dimensions) { + from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; + } + return operand_literal.Get(from_index); + })); + + parent_->evaluated_[reverse] = std::move(result); + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* conv) override { + auto lhs = conv->operand(0); + auto rhs = conv->operand(1); + const auto& window = conv->window(); + const Shape& result_shape = conv->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + + TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); + TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); + CHECK(ShapeUtil::IsArray(lhs_shape)); + CHECK(ShapeUtil::IsArray(rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); + + const auto& dnums = conv->convolution_dimension_numbers(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_spatial_dims, 
dnums.kernel_spatial_dimensions_size()); + CHECK_GE(num_spatial_dims, 0); + CHECK_EQ(window.dimensions_size(), num_spatial_dims); + + const auto lhs_rank = ShapeUtil::Rank(lhs_shape); + const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + + CHECK_EQ(num_spatial_dims + 2, lhs_rank); + CHECK_EQ(num_spatial_dims + 2, rhs_rank); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, + window, dnums)); + CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + std::vector window_dimension_sizes; + for (auto i : dnums.kernel_spatial_dimensions()) { + window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); + } + + const Shape& window_shape = + ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); + + DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); + DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); + + auto lhs_literal_data = lhs_literal.data(); + auto rhs_literal_data = rhs_literal.data(); + + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, + &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, + rhs_literal_data]( + tensorflow::gtl::ArraySlice out_index) { + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); + // Dimension number applicable for kernel (rhs). + const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); + const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. + const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); + + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + + ElementwiseT result_val = static_cast(0); + DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), + 0); + + // Convolve input feature with kernel. + do { + for (int64 iz = 0; iz < z_size; ++iz) { + int64 lhs_linear_index = 0; + lhs_linear_index += out_index[output_batch_dim] * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; + + int64 rhs_linear_index = 0; + rhs_linear_index += out_index[output_z_dim] * + rhs_dim_multipliers[kernel_output_z_dim]; + rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + + // Find corresponding spatial dimension index for input (lhs). + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. 
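The convolution loop above maps each (output index, kernel tap) pair back to an input position per spatial dimension, honoring stride, low padding, window dilation, and base (input) dilation. A minimal single-dimension sketch of that index arithmetic follows; `WindowDim` and `MapToInput` are hypothetical stand-ins, not the XLA `Window` proto or evaluator code.

```c++
#include <cstdint>
#include <iostream>

struct WindowDim {
  int64_t stride;
  int64_t padding_low;
  int64_t window_dilation;
  int64_t base_dilation;
};

// Returns true and sets *input_index if the kernel tap lands on a real input
// element; returns false if it falls in padding or in a base-dilation hole.
bool MapToInput(int64_t output_index, int64_t kernel_index, int64_t input_size,
                const WindowDim& dim, int64_t* input_index) {
  const int64_t undilated = output_index * dim.stride - dim.padding_low +
                            kernel_index * dim.window_dilation;
  if (dim.base_dilation > 1 && undilated % dim.base_dilation != 0) {
    return false;  // Position is a dilation hole, so the tap is skipped.
  }
  const int64_t idx =
      dim.base_dilation > 1 ? undilated / dim.base_dilation : undilated;
  if (idx < 0 || idx >= input_size) {
    return false;  // Position falls in the padding region.
  }
  *input_index = idx;
  return true;
}

int main() {
  WindowDim dim{/*stride=*/2, /*padding_low=*/1, /*window_dilation=*/1,
                /*base_dilation=*/1};
  int64_t idx;
  std::cout << MapToInput(0, 0, 5, dim, &idx) << "\n";             // 0: padding
  std::cout << MapToInput(1, 1, 5, dim, &idx) << " " << idx << "\n";  // 1 2
}
```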
+ if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + lhs_linear_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < + lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + rhs_linear_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + + result_val += + static_cast(lhs_literal_data[lhs_linear_index]) * + static_cast(rhs_literal_data[rhs_linear_index]); + } + cnt : {} + } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + + return static_cast(result_val); + }; + + auto result = MakeUnique(result_shape); + TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + + parent_->evaluated_[conv] = std::move(result); + return Status::OK(); + } + + Status HandleDot(HloInstruction* dot) override { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + const auto& dnums = dot->dot_dimension_numbers(); + + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. 
+ CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracting_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracting_dimension); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + + std::vector lhs_non_contracting_dims; + for (int64 i = 0; i < lhs_rank; i++) { + if (i != lhs_contracting_dimension) { + lhs_non_contracting_dims.push_back(i); + } + } + + std::vector rhs_non_batch_non_contracting_dims; + tensorflow::gtl::FlatSet batch_dims_set( + dnums.rhs_batch_dimensions().begin(), + dnums.rhs_batch_dimensions().end()); + for (int64 i = 0; i < rhs_rank; i++) { + if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { + rhs_non_batch_non_contracting_dims.push_back(i); + } + } + + const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); + const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); + + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + auto result = MakeUnique(dot->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice result_index) { + ElementwiseT result_val = static_cast(0); + + // Find the corresponding non-contracting indices for lhs and rhs. + // + // For `result_index`, its batch dimension, if exists, will be at the + // same dimension as the batch dimension of lhs and rhs. More + // specifically: + // - For lhs, the non-contracting dimensions, including the batch + // dimension have the same index as the `result_index`. + // - For rhs, the batch dimension is set seperately from other + // non-contracting dimensions, since these other non-contracting + // dimensions in rhs follow the non-contracting dimensions of lhs in + // the resulting index. + // + // As an example, for a resulting index: + // result_index [result_batch, result_x, result_y] + // the effecting lhs and rhs indices are: + // lhs [result_batch, lhs_non_contracting_dim, contracting_dim + // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] + // `result_x` is only affected by the lhs_non_contracting_dim and + // likewise `result_y` only depends on rhs_non_contracting_dim. + // + // so we can look up the lhs and rhs indices by: + // + // lhs: + // batch index is the same as `result_batch`. + // non-contracting dimension is the same as + // result_index[lhs_non_contracting_dim] + // rhs: + // batch index: the same as `result_batch`. + // non-contracting dimension index: *not* the same as + // result_index[rhs_non_contractng_dim], since the + // non-contracting dimensions of lhs are included in the + // result_index first. Instead, the non_contracting_dim of rhs must + // be calculated as following: + // lhs_non_contracting_dimensions_size + + // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 + // + // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is + // the index offset to the result_index that only depends on + // the non_batch and non-contracting dimensions of rhs. -1 at the + // end translates size to index. 
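Stripped of the batch and index-mapping bookkeeping described above, the per-element computation in HandleDot is a plain sum over the single contracting dimension. A minimal rank-2 sketch in plain C++ (hypothetical helper, not the evaluator's Literal API):

```c++
#include <cstdint>
#include <iostream>
#include <vector>

// Row-major [m, k] times row-major [k, n]: each result element folds the
// contracting dimension, mirroring the inner accumulation loop in HandleDot.
std::vector<float> NaiveDot(const std::vector<float>& lhs,
                            const std::vector<float>& rhs, int64_t m,
                            int64_t k, int64_t n) {
  std::vector<float> result(m * n, 0.0f);
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int64_t c = 0; c < k; ++c) {
        acc += lhs[i * k + c] * rhs[c * n + j];
      }
      result[i * n + j] = acc;
    }
  }
  return result;
}

int main() {
  // [[1, 2], [3, 4]] x [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
  auto r = NaiveDot({1, 2, 3, 4}, {5, 6, 7, 8}, 2, 2, 2);
  std::cout << r[0] << " " << r[1] << " " << r[2] << " " << r[3] << "\n";
}
```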
+ for (auto i : lhs_non_contracting_dims) { + lhs_index[i] = result_index[i]; + } + for (auto i : dnums.rhs_batch_dimensions()) { + rhs_index[i] = result_index[i]; + } + for (auto i : rhs_non_batch_non_contracting_dims) { + const int64 rhs_non_batch_non_contracting_dim = + lhs_non_contracting_size + (i - batch_dim_size) - 1; + rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + } + + // Accumulates resulting product along the contracted dimension. + for (int64 i = 0; i < contracted_dimension_size; ++i) { + lhs_index[lhs_contracting_dimension] = i; + rhs_index[rhs_contracting_dimension] = i; + + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); + } + + return static_cast(result_val); + })); + + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + + Status HandlePad(HloInstruction* pad) override { + CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + // Padding value must be scalar. + CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); + CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + pad->padding_config().dimensions_size()); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferPadShape( + /*operand_shape=*/pad->operand(0)->shape(), + /*padding_value_shape=*/pad->operand(1)->shape(), + /*padding_config=*/pad->padding_config())); + CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + // Create new HLO of padded shape with padding value. + ReturnT scalar = + parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); + auto result = MakeUnique(pad->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&scalar](tensorflow::gtl::ArraySlice multi_index) { + return scalar; + })); + + const Literal& evaluated_operand = + parent_->GetEvaluatedLiteralFor(pad->operand(0)); + + std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), + 0); + std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + + // Loop through each element of the operand, assign them to the + // corresponding index of the resulting padded literal. + const PaddingConfig& pad_config = pad->padding_config(); + + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + for (auto i = 0; i < input_index.size(); ++i) { + // Interior padding occurs logically before edge padding, so in the case + // of negative edge padding elements are removed from the + // interior-padded operand. + target_index[i] = + pad_config.dimensions(i).edge_padding_low() + + input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); + + // Account for negative low and high padding: skip assignment if the + // any target index is out of range. 
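HandlePad places each operand element at `edge_padding_low + input_index * (interior_padding + 1)` in the result and drops elements that negative edge padding pushes out of range. A one-dimensional sketch of that mapping (the `Pad1D` helper is illustrative, not part of the diff):

```c++
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> Pad1D(const std::vector<float>& operand, float pad_value,
                         int64_t edge_low, int64_t edge_high,
                         int64_t interior) {
  const int64_t in = static_cast<int64_t>(operand.size());
  const int64_t out =
      edge_low + edge_high + in + (in > 0 ? (in - 1) * interior : 0);
  // Start from a result filled with the padding value, as the handler does.
  std::vector<float> result(std::max<int64_t>(out, 0), pad_value);
  for (int64_t i = 0; i < in; ++i) {
    const int64_t target = edge_low + i * (interior + 1);
    if (target < 0 || target >= static_cast<int64_t>(result.size())) {
      continue;  // Removed by negative edge padding.
    }
    result[target] = operand[i];
  }
  return result;
}

int main() {
  // Pad {1, 2, 3} with low=1, high=0, interior=1, value 0 -> {0, 1, 0, 2, 0, 3}.
  for (float v : Pad1D({1, 2, 3}, 0, 1, 0, 1)) std::cout << v << " ";
  std::cout << "\n";
}
```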
+ if (!(target_index[i] >= 0 && + target_index[i] < pad->shape().dimensions(i))) { + return true; + } + } + result->Set(target_index, + evaluated_operand.Get(input_index)); + return true; + }; + + std::vector zero_base(evaluated_operand.shape().dimensions_size(), + 0); + std::vector step(evaluated_operand.shape().dimensions_size(), 1); + + ShapeUtil::ForEachIndex( + evaluated_operand.shape(), zero_base, + AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); + + parent_->evaluated_[pad] = std::move(result); + return Status::OK(); + } + + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { + auto operand = dynamic_slice->operand(0); + auto start_indices = dynamic_slice->operand(1); + auto result_shape = dynamic_slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), start_indices->shape(), + dynamic_slice->dynamic_slice_sizes())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(result_shape) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + TF_RET_CHECK( + primitive_util::IsIntegralType(start_indices->shape().element_type())); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& start_indices_literal = + parent_->GetEvaluatedLiteralFor(start_indices); + + switch (start_indices->shape().element_type()) { + case S32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case U32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + default: + LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " + "start_indices: " + << PrimitiveType_Name(start_indices->shape().element_type()); + } + + return Status::OK(); + } + + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override { + auto operand = dynamic_update_slice->operand(0); + auto update = dynamic_update_slice->operand(1); + auto start_indices = dynamic_update_slice->operand(2); + auto result_shape = dynamic_update_slice->shape(); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand->shape(), update->shape(), start_indices->shape())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(result_shape) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + TF_RET_CHECK( + primitive_util::IsIntegralType(start_indices->shape().element_type())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); + const Literal& start_indices_literal = + parent_->GetEvaluatedLiteralFor(start_indices); + + switch (start_indices->shape().element_type()) { + case S32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + 
DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case U32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + DynamicUpdateSlice(operand_literal, update_literal, + start_indices_literal)); + } break; + default: + LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " + "start_indices: " + << PrimitiveType_Name(start_indices->shape().element_type()); + } + + return Status::OK(); + } + + template + StatusOr> MapImpl(HloInstruction* map) { + auto operands = map->operands(); + HloComputation* computation = map->to_apply(); + + auto result = MakeUnique(map->shape()); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + std::vector> arg_literals; + arg_literals.reserve(operands.size()); + + // Construct scalar literal parameters to be passed to the map + // computation. + for (auto operand : operands) { + const Literal& arg_literal = + parent_->GetEvaluatedLiteralFor(operand); + + auto curr_val = arg_literal.Get(multi_index); + auto curr_val_literal = Literal::CreateR0(curr_val); + + arg_literals.push_back(std::move(curr_val_literal)); + } + + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate>(*computation, + arg_literals) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. 
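MapImpl evaluates the mapped computation once per output element: it wraps each operand element in a scalar literal, runs the embedded evaluator, and reads the scalar result back. The essence of that per-element loop, with an ordinary `std::function` standing in for the embedded HLO computation (hypothetical helper, not the evaluator API):

```c++
#include <functional>
#include <iostream>
#include <vector>

template <typename T>
std::vector<T> MapElements(const std::vector<T>& operand,
                           const std::function<T(T)>& computation) {
  std::vector<T> result;
  result.reserve(operand.size());
  for (const T& v : operand) {
    // The real evaluator builds a scalar literal per element, evaluates the
    // mapped computation on it, and resets visit states before reuse.
    result.push_back(computation(v));
  }
  return result;
}

int main() {
  auto doubled = MapElements<int>({1, 2, 3}, [](int x) { return 2 * x; });
  for (int v : doubled) std::cout << v << " ";  // 2 4 6
  std::cout << "\n";
}
```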
+ embedded_evaluator.ResetVisitStates(); + + return computed_result->Get({}); + })); + return std::move(result); + } + + Status HandleMap(HloInstruction* map) override { + switch (map->operand(0)->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], + MapImpl(map)); + break; + } + case F32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case C64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + default: + LOG(FATAL) << "HandleMap: unhandled primitive type for " + "input operand: " + << PrimitiveType_Name( + map->operand(0)->shape().element_type()); + } + + return Status::OK(); + } + + Status HandleReduce(HloInstruction* reduce) override { + auto arg = reduce->operand(0); + auto init_value = reduce->operand(1); + tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); + TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == + ShapeUtil::Rank(arg->shape()) - dimensions.size()); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReduceShape( + /*arg=*/arg->shape(), + /*init_value=*/init_value->shape(), + /*dimensions_to_reduce=*/dimensions, + /*to_apply=*/function->ComputeProgramShape())); + TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); + VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); + const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); + VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); + std::vector arg_dim_steps(arg_dimensions.size()); + std::vector arg_dim_counts(arg_dimensions.size()); + for (const int64 dim : dimensions) { + arg_dim_steps[dim] = 1; + arg_dim_counts[dim] = arg_dimensions[dim]; + } + + // Map each dimension in the result to a dimension in arg that isn't + // being reduced. + std::vector result_to_arg_index; + for (int64 i = 0; i < arg_dimensions.size(); ++i) { + if (arg_dim_steps[i] == 0) { + result_to_arg_index.push_back(i); + } + } + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce->shape()); + // For each resulting dimension, calculate and assign computed value. 
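HandleReduce keeps the non-reduced dimensions in the result and, for each result element, folds every operand element along the reduced dimensions into the init value. A rank-2, single-reduced-dimension sketch of that bookkeeping (the `Reduce2D` helper is illustrative only):

```c++
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

std::vector<float> Reduce2D(const std::vector<float>& arg,  // [rows, cols]
                            int64_t rows, int64_t cols, float init,
                            int64_t reduce_dim,  // 0 or 1
                            const std::function<float(float, float)>& fn) {
  const int64_t result_size = reduce_dim == 0 ? cols : rows;
  std::vector<float> result(result_size, init);
  for (int64_t r = 0; r < rows; ++r) {
    for (int64_t c = 0; c < cols; ++c) {
      // The surviving index picks the result element; the reduced index is
      // folded away through the reduction function.
      const int64_t out = reduce_dim == 0 ? c : r;
      result[out] = fn(result[out], arg[r * cols + c]);
    }
  }
  return result;
}

int main() {
  // Sum the columns of [[1, 2, 3], [4, 5, 6]]: expected {5, 7, 9}.
  auto sums = Reduce2D({1, 2, 3, 4, 5, 6}, 2, 3, 0.0f, 0,
                       [](float a, float b) { return a + b; });
  for (float v : sums) std::cout << v << " ";
  std::cout << "\n";
}
```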
+ TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + ReturnT result_val = init_scalar; + + std::vector base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } + + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + computed_result += arg_literal.Get(input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast(computed_result); + } + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto curr_val = arg_literal.Get(input_index); + + // Evaluate computation with specified literal operands. + auto curr_val_literal = Literal::CreateR0(curr_val); + auto result_val_literal = Literal::CreateR0(result_val); + + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) + .ConsumeValueOrDie(); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. + result_val = computed_result->Get({}); + return true; + }; + // Computes one element of the result, reducing all dimensions that + // contribute to that element. + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return result_val; + })); + + parent_->evaluated_[reduce] = std::move(result); + return Status::OK(); + } + + bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; + } + return false; + } + + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + auto operand = select_and_scatter->operand(0); + auto source = select_and_scatter->operand(1); + const Window& window = select_and_scatter->window(); + + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + auto result = MakeUnique(select_and_scatter->shape()); + + // Initialize result array with the init value. 
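The IsScalarAdd fast path in HandleReduce above accumulates float sums in a double instead of round-tripping each partial result through scalar literals. Besides being faster, the wider accumulator noticeably reduces rounding error, as this small standalone comparison suggests (function names are illustrative):

```c++
#include <iostream>
#include <vector>

float SumWithDoubleAccumulator(const std::vector<float>& values) {
  double acc = 0.0;  // Wider accumulator keeps the small addends.
  for (float v : values) acc += v;
  return static_cast<float>(acc);
}

float SumWithFloatAccumulator(const std::vector<float>& values) {
  float acc = 0.0f;
  for (float v : values) acc += v;
  return acc;
}

int main() {
  // One large value followed by ~1M tiny ones: float accumulation drops the
  // tiny contributions entirely, double accumulation keeps them.
  std::vector<float> values(1 << 20, 1e-4f);
  values.insert(values.begin(), 1e7f);
  std::cout << SumWithFloatAccumulator(values) << " vs "
            << SumWithDoubleAccumulator(values) << "\n";
}
```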
+ TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + return init_scalar; + })); + + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + HloComputation* select = select_and_scatter->select(); + HloComputation* scatter = select_and_scatter->scatter(); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); + + int64 rank = ShapeUtil::Rank(operand_literal.shape()); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + DimensionVector source_index(rank, 0); + + // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid + // dynamic memory allocations. + auto curr_val_literal = Literal::CreateR0(ReturnT()); + auto selected_val_literal = Literal::CreateR0(ReturnT()); + auto source_literal_scatter = Literal::CreateR0(ReturnT()); + auto scattered_literal = Literal::CreateR0(ReturnT()); + do { + // For each element in `source`, we place a window in `operand`. For each + // window placement, we iterate inside the window twice: + // + // 1. Find the selected index by applying `select` function to all + // elements. E.g., If the `select` function is GreaterEqual, the first + // iteration through the window finds the biggest value and returns its + // index. + // + // 2. Using the selected index, scatter value from `source` to result. We + // do this by iterating through the window, and compare each index with + // the selected index. + tensorflow::gtl::optional selected_val; + tensorflow::gtl::optional> selected_index; + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + if (!selected_val) { + selected_val = curr_val; + selected_index = operand_index; + } + curr_val_literal->Set({}, curr_val); + selected_val_literal->Set({}, *selected_val); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate( + *select, + {selected_val_literal.get(), curr_val_literal.get()}) + .ConsumeValueOrDie(); + bool selected = !computed_result->Get({}); + if (selected) { + selected_val = curr_val; + selected_index = operand_index; + } + embedded_evaluator.ResetVisitStates(); + }); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + if (std::equal(operand_index.begin(), operand_index.end(), + selected_index->begin())) { + auto source = source_literal.Get(source_index); + auto scattered = result->Get(operand_index); + source_literal_scatter->Set({}, source); + scattered_literal->Set({}, scattered); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate(*scatter, + {source_literal_scatter.get(), + scattered_literal.get()}) + .ConsumeValueOrDie(); + result->Set(operand_index, computed_result->Get({})); + // Clear visit states so that the we can use the evaluator again + // on the same computation. 
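The two-pass structure described in the comment above (select an index inside the window, then scatter the source value onto it) reduces to the following in one dimension, using argmax as the select function and addition as the scatter function. `SelectAndScatter1D` is a hypothetical helper, and it assumes every window placement starts in bounds.

```c++
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> SelectAndScatter1D(const std::vector<float>& operand,
                                      const std::vector<float>& source,
                                      int64_t window_size, int64_t stride,
                                      float init) {
  std::vector<float> result(operand.size(), init);
  for (int64_t s = 0; s < static_cast<int64_t>(source.size()); ++s) {
    // Pass 1: pick the selected operand index for this window placement.
    int64_t selected = s * stride;
    for (int64_t w = 0; w < window_size; ++w) {
      const int64_t idx = s * stride + w;
      if (idx < static_cast<int64_t>(operand.size()) &&
          operand[idx] > operand[selected]) {
        selected = idx;
      }
    }
    // Pass 2: scatter the source value onto the selected position.
    result[selected] += source[s];
  }
  return result;
}

int main() {
  // Operand {1, 3, 2, 4}, window 2, stride 2, source {10, 20}:
  // the windows select indices 1 and 3, giving {0, 10, 0, 20}.
  for (float v : SelectAndScatter1D({1, 3, 2, 4}, {10, 20}, 2, 2, 0.0f))
    std::cout << v << " ";
  std::cout << "\n";
}
```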
+ embedded_evaluator.ResetVisitStates(); + } + }); + } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + + parent_->evaluated_[select_and_scatter] = std::move(result); + return Status::OK(); + } + + Status HandleReduceWindow(HloInstruction* reduce_window) override { + auto operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + HloComputation* function = reduce_window->to_apply(); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferReduceWindowShape( + /*operand_shape=*/reduce_window->operand(0)->shape(), + /*init_value=*/reduce_window->operand(1)->shape(), window, + /*to_apply_shape=*/function->ComputeProgramShape())); + TF_RET_CHECK( + ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) + << "return shape is set to: " + << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanStringWithLayout(inferred_return_shape); + + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); + VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); + VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + // Creates a Shape object from window, for iteration below. + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + DimensionVector window_index(window.dimensions_size()); + DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce_window->shape()); + // For each resulting dimension, calculate and assign computed value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + ReturnT result_val = init_scalar; + + std::fill(window_index.begin(), window_index.end(), 0); + std::fill(operand_index.begin(), operand_index.end(), 0); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), output_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + + // Evaluate computation with specified literal operands. + const auto curr_val_literal = + Literal::CreateR0(curr_val); + const auto result_val_literal = + Literal::CreateR0(result_val); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) + .ConsumeValueOrDie(); + + // Clear visit states so that the we can use the evaluate again + // on the same computation. 
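HandleReduceWindow computes, for each output element, the fold of all operand values under one window placement into the init value. Restricted to one dimension with no padding or dilation, that is ordinary pooling; the sketch below uses a hypothetical `ReduceWindow1D` helper rather than the evaluator API.

```c++
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

std::vector<float> ReduceWindow1D(
    const std::vector<float>& operand, float init, int64_t window_size,
    int64_t stride, const std::function<float(float, float)>& fn) {
  std::vector<float> result;
  for (int64_t start = 0;
       start + window_size <= static_cast<int64_t>(operand.size());
       start += stride) {
    // Fold every element under this window placement into the init value.
    float acc = init;
    for (int64_t w = 0; w < window_size; ++w) {
      acc = fn(acc, operand[start + w]);
    }
    result.push_back(acc);
  }
  return result;
}

int main() {
  // Max over windows of size 2, stride 2, of {1, 5, 2, 4}: expected {5, 4}.
  auto pooled = ReduceWindow1D({1, 5, 2, 4}, -1e30f, 2, 2,
                               [](float a, float b) { return a > b ? a : b; });
  for (float v : pooled) std::cout << v << " ";
  std::cout << "\n";
}
```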
+ embedded_evaluator.ResetVisitStates(); + + result_val = computed_result->Get({}); + }); + + return result_val; + })); + + parent_->evaluated_[reduce_window] = std::move(result); + return Status::OK(); + } + + Status HandleSlice(HloInstruction* slice) override { + auto operand = slice->operand(0); + const Shape& shape = slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferSliceShape( + operand->shape(), slice->slice_starts(), + slice->slice_limits(), slice->slice_strides())); + TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const int64 rank = ShapeUtil::Rank(operand->shape()); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto func = [&](tensorflow::gtl::ArraySlice out_index) { + DimensionVector operand_index(rank); + for (int64 i = 0; i < rank; ++i) { + operand_index[i] = + slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); + } + return operand_literal.Get(operand_index); + }; + + auto result = Literal::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + TF_RETURN_IF_ERROR(result->Populate(func)); + parent_->evaluated_[slice] = std::move(result); + return Status::OK(); + } + + // Enable CLZ only for int32, uint32, int64 and uint64. + template < + typename NativeT, + typename std::enable_if< + (std::is_floating_point::value || + std::is_integral::value || is_complex_t::value) && + !(std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value)>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + return InvalidArgument("Unsupported type for Clz"); + } + + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 31 - tensorflow::Log2Floor(elem_operand); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 63 - tensorflow::Log2Floor64(elem_operand); + })); + return Status::OK(); + } + + Status HandleClz(HloInstruction* clz) override { + return HandleClz(clz); + } + + template ::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { + return std::sin(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + return InvalidArgument("Unsupported type for Sin"); + } + + Status HandleSin(HloInstruction* sin) override { + return HandleSin(sin); + } + + template ::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { + return std::cos(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + return InvalidArgument("Unsupported type for Cos"); + } + + Status HandleCos(HloInstruction* cos) override { + 
return HandleCos(cos); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[reduce_precision], + ElementWiseUnaryOp(reduce_precision, [reduce_precision]( + ElementwiseT elem) { + uint32_t value_as_int = tensorflow::bit_cast(elem); + const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); + const uint32_t exponent_bits = reduce_precision->exponent_bits(); + + // Code is based on the CPU/GPU implementation in LLVM-emitting code. + // + // Bits in float type: + // mantissa : bits [0:22] + // exponent : bits [23:30] + // sign : bits [31] + if (mantissa_bits < 23) { + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. + // This is equal to a base value of 0111... plus one bit if the last + // remaining mantissa bit is 1. + const uint32_t base_rounding_bias = + (last_mantissa_bit_mask >> 1) - 1; + const uint32_t x_last_mantissa_bit = + (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); + const uint32_t x_rounding_bias = + x_last_mantissa_bit + base_rounding_bias; + + // Add rounding bias, and mask out truncated bits. Note that the + // case where adding the rounding bias overflows into the exponent + // bits is correct; the non-masked mantissa bits will all be zero, + // and the exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + value_as_int = value_as_int + x_rounding_bias; + value_as_int = value_as_int & truncation_mask; + } + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the + // most- significant bit -- is equal to 1.0f for all exponent sizes. + // Adding 2^(n-1)-1 to this gives us the highest non-infinite + // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from + // this gives us the lowest' exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n + // is (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = + (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; + const bool x_overflows = x_exponent > (reduced_max_exponent << 23); + const bool x_underflows = + x_exponent <= (reduced_min_exponent << 23); + + // Compute appropriately-signed values of zero and infinity. + const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; + const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; + + // Force to zero or infinity if overflow or underflow. (Note that + // this truncates all denormal values to zero, rather than rounding + // them.) + value_as_int = x_overflows ? x_signed_inf : value_as_int; + value_as_int = x_underflows ? 
x_signed_zero : value_as_int; + } + + float reduced_result = tensorflow::bit_cast(value_as_int); + if (std::isnan(elem)) { + reduced_result = mantissa_bits > 0 + ? elem + : std::numeric_limits::infinity(); + } + return reduced_result; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Double not supported for reduce precision"); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Unsupported type for reduce precision"); + } + + Status HandleReducePrecision(HloInstruction* reduce_precision) override { + return HandleReducePrecision(reduce_precision); + } + + private: + // Creates a vector of multipliers which can be used to create a linear index + // into shape. + // + // Given the multidimensional index {i1, ..., iN} and + // M = MakeDimMultipliers(shape), the corresponding linear index LI is simply + // + // LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. + // + // This lets you calculate LI given the multidimensional indices in any order. + static DimensionVector MakeDimMultipliers(const Shape& shape) { + DimensionVector v(ShapeUtil::Rank(shape)); + int64 scale = 1; + for (auto dim : LayoutUtil::MinorToMajor(shape)) { + v[dim] = scale; + scale *= shape.dimensions(dim); + } + return v; + } + + // For one particular placement of a window in a base shape (the placement is + // represented as `window_count_index`), iterates inside the window. + // Translates the window index into base index. If the base index is within + // bound, call `f` with the base index. + static void IterateThroughWindow( + const Shape& window_shape, const Window& window, const Shape& base_shape, + const tensorflow::gtl::ArraySlice& window_count_index, + const std::function&)>& f) { + const int64 rank = ShapeUtil::Rank(base_shape); + DimensionVector window_index(rank); + std::fill(window_index.begin(), window_index.end(), 0); + do { + std::vector base_index(rank); + bool out_of_bound = false; + for (int64 i = 0; i < rank; ++i) { + base_index[i] = window_count_index[i] * window.dimensions(i).stride() + + window_index[i] - window.dimensions(i).padding_low(); + if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { + out_of_bound = true; + break; + } + } + if (!out_of_bound) { + f(base_index); + } + } while (IndexUtil::BumpIndices(window_shape, &window_index)); + } + + template + StatusOr> DynamicSlice( + const Literal& operand_literal, const Literal& start_indices_literal, + const Shape& result_shape) { + auto start_indices_typed = start_indices_literal.data(); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + + // Clamp the start indices so the slice is in-bounds w.r.t the operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. 
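The clamp described in the TODO above limits each start index to the range [0, operand_dim - slice_dim], so the requested slice always lies fully inside the operand. A minimal sketch of that per-dimension clamp (hypothetical `ClampSliceStarts` helper):

```c++
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> ClampSliceStarts(std::vector<int64_t> start,
                                      const std::vector<int64_t>& operand_dims,
                                      const std::vector<int64_t>& slice_dims) {
  for (size_t i = 0; i < start.size(); ++i) {
    // Keep the slice in bounds: 0 <= start <= operand_dim - slice_dim.
    start[i] = std::min(std::max(int64_t{0}, start[i]),
                        operand_dims[i] - slice_dims[i]);
  }
  return start;
}

int main() {
  // Slicing 3 elements out of a size-8 dimension: a requested start of 7
  // is clamped to 5 so the slice stays inside the operand.
  auto clamped = ClampSliceStarts({7}, {8}, {3});
  std::cout << clamped[0] << "\n";  // 5
}
```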
+ for (int64 i = 0; i < start.size(); ++i) { + start[i] = std::min( + std::max(int64{0}, start[i]), + operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); + } + + std::vector operand_indices(start.size()); + auto result = MakeUnique(result_shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < operand_indices.size(); ++i) { + CHECK_GE(multi_index[i] + start[i], 0); + operand_indices[i] = multi_index[i] + start[i]; + } + + auto result = operand_literal.Get(operand_indices); + return result; + })); + + return std::move(result); + } + + template + StatusOr> DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.CloneToUnique(); + auto start_indices_typed = start_indices_literal.data(); + const auto rank = ShapeUtil::Rank(result->shape()); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + // Clamp the update start indices so the slice is in-bounds w.r.t the + // operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + for (int64 i = 0; i < rank; ++i) { + start[i] = std::min( + std::max(0, start[i]), + result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + } + std::vector result_index(rank, 0); + + auto func = [&](tensorflow::gtl::ArraySlice update_index) { + std::transform(update_index.begin(), update_index.end(), start.begin(), + result_index.begin(), std::plus()); + result->Set(result_index, + update_literal.Get(update_index)); + return true; + }; + + std::vector base(update_literal.shape().dimensions_size(), 0); + std::vector step(update_literal.shape().dimensions_size(), 1); + ShapeUtil::ForEachIndex(update_literal.shape(), base, + AsInt64Slice(update_literal.shape().dimensions()), + step, func); + + return std::move(result); + } + + StatusOr> ElementWiseUnaryOp( + HloInstruction* instruction, + const std::function& unary_op) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(instruction->operand(0)); + TF_ASSIGN_OR_RETURN( + auto result_literal, + (HloEvaluator::ElementWiseUnaryOpImpl( + instruction, ConvertUnaryFunction(unary_op), operand_literal))); + + return std::move(result_literal); + } + + StatusOr> ElementWiseBinaryOp( + HloInstruction* instruction, + const std::function& + binary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast + // is removed. 
+ if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = MakeUnique(shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ConvertBinaryFunction(binary_op)( + lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + return std::move(result); + } + + template + StatusOr> ElementwiseTernaryOp( + HloInstruction* instruction, + const std::function& ternary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + const auto* ehs = instruction->operand(2); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit + // broadcast is removed. + if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str(), + ShapeUtil::HumanString(ehs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); + + auto result = MakeUnique(shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index), + ehs_literal.Get(multi_index)); + })); + + return std::move(result); + } + + template + static bool IsShiftOutOfBounds(NativeT rhs) { + typedef typename std::make_unsigned::type UnsignedT; + UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; + UnsignedT rhs_unsigned = static_cast(rhs); + return rhs_unsigned >= lhs_size_unsigned; + } + + HloEvaluator* parent_; +}; + +// These extern templates prevent users of this class from implicitly +// instantiating it. We explicitly instantiate this class in the various +// hlo_evaluator_typed_visitor*.cc files. 
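The extern-template arrangement referred to in the comment above keeps compile times down: the header declares each instantiation as extern so that including it never implicitly instantiates the large class template, and one .cc file per element type provides the single explicit instantiation definition. A compilable miniature of the pattern, collapsed into one translation unit for illustration (`Visitor` is a stand-in, not the XLA class):

```c++
#include <iostream>

// Stand-in for a large class template whose instantiations should be
// compiled once, in dedicated .cc files.
template <typename T>
struct Visitor {
  T Twice(T v) { return v + v; }
};

// In the header: suppress implicit instantiation in every includer.
extern template struct Visitor<float>;

// In the per-type .cc file: the one explicit instantiation definition.
template struct Visitor<float>;

int main() {
  Visitor<float> v;
  std::cout << v.Twice(2.5f) << "\n";  // 5
}
```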
+extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc new file mode 100644 index 00000000000000..39c352dfb966af --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc new file mode 100644 index 00000000000000..289b40fa06d37b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc new file mode 100644 index 00000000000000..9cb4eb921fd3af --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc new file mode 100644 index 00000000000000..5e6252fbf8c24a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc new file mode 100644 index 00000000000000..ee793ae77b1b43 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc new file mode 100644 index 00000000000000..038d9d39e4a588 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc new file mode 100644 index 00000000000000..b1952ca6193958 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc new file mode 100644 index 00000000000000..0cbaffb40b7128 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc new file mode 100644 index 00000000000000..6f4bf2a392b51a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc new file mode 100644 index 00000000000000..10235447e0d266 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc new file mode 100644 index 00000000000000..8abeaa6ffca440 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc new file mode 100644 index 00000000000000..6dabd1c176eabc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index a0cb28246d3be5..eba80c0f199f62 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -15,53 +15,33 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -class HloExecutionProfileTest : public HloTestBase { - protected: - static constexpr int64 kInstructionCyclesIndex = 0; - static constexpr int64 kInstructionNameIndex = 19; -}; +using tensorflow::strings::StrCat; +using ::testing::AllOf; +using ::testing::ContainsRegex; -// Splits `lines` into a sequence of lines delimited by newlines and then split -// each of those lines into a sequence of words delimited by spaces. Filter out -// empty words. 
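
The profile-test rewrite that continues below drops the old split-into-lines-and-words helper and instead matches the rendered report with regexes, which is far less sensitive to column layout. Here is a minimal, self-contained sketch of that matcher style using googletest/googlemock; `RenderReport` and the instruction names are invented stand-ins for `HloExecutionProfile::ToString()` output.

```c++
#include <string>

#include "gmock/gmock.h"
#include "gtest/gtest.h"

using ::testing::AllOf;
using ::testing::ContainsRegex;

// Stand-in for the rendered execution profile text.
std::string RenderReport() {
  return "1000000 cycles (80.0%) ... %dot\n"
         "666 cycles ( 0.1%) ... %add\n";
}

TEST(ProfileReportTest, MentionsEachInstructionWithItsCycleCount) {
  // Assert only that each cycle count appears together with the corresponding
  // instruction name, instead of indexing into split-up columns.
  EXPECT_THAT(RenderReport(),
              AllOf(ContainsRegex("1000000 .*%dot"),
                    ContainsRegex("666 .*%add")));
}
```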
-std::vector> SplitIntoLinesAndWords( - tensorflow::StringPiece lines) { - std::vector> result; - for (const string& line : tensorflow::str_util::Split(lines, '\n')) { - std::vector words; - for (const string& word : tensorflow::str_util::Split(line, ' ')) { - if (!word.empty()) { - words.push_back(word); - } - } - result.push_back(std::move(words)); - } - - return result; -} +class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - std::unique_ptr hlo_module = CreateNewModule(); - - HloComputation::Builder builder(TestName()); + auto hlo_module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + lhs = f32[30,30]{1,0} parameter(0) + rhs = f32[30,30]{1,0} parameter(1) + add = f32[30,30]{1,0} add(lhs, rhs) + ROOT dot = f32[30,30]{1,0} dot(lhs, add), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })") + .ValueOrDie(); + const HloInstruction* dot_instruction = + hlo_module->entry_computation()->root_instruction(); + const HloInstruction* add_instruction = dot_instruction->operand(1); Shape shape = ShapeUtil::MakeShape(F32, {30, 30}); - HloInstruction* param_lhs = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); - HloInstruction* param_rhs = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); - HloInstruction* add_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloInstruction* dot_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, param_lhs, add_instruction)); - - hlo_module->AddEntryComputation(builder.Build()); auto shape_size_function = [&](const Shape& shape) { const int64 pointer_size = 8; @@ -84,20 +64,12 @@ TEST_F(HloExecutionProfileTest, Basic) { execution_profile.SetCyclesTakenBy(add_instruction, add_cycles); execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles); - string rendered_profile = execution_profile.ToString( - backend().default_stream_executor()->GetDeviceDescription()); - std::vector> lines_and_words = - SplitIntoLinesAndWords(rendered_profile); - ASSERT_EQ(lines_and_words.size(), 8); - - const std::vector& line_2 = lines_and_words[2]; - const std::vector& line_3 = lines_and_words[3]; - - EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); - EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name()); - - EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); - EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name()); + EXPECT_THAT(execution_profile.ToString( + backend().default_stream_executor()->GetDeviceDescription()), + AllOf(ContainsRegex(StrCat(dot_cycles, R"(\b.*%)", + dot_instruction->name())), + ContainsRegex(StrCat(add_cycles, R"(\b.*%)", + add_instruction->name())))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 516e14b4642ae6..61612bebd1e906 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -321,12 +321,12 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - const DebugOptions& debug_options, bool show_metadata, + const DebugOptions& debug_options, bool show_backend_config, const HloExecutionProfile* profile, NodeFilter filter) : 
computation_(computation), - label_(label.ToString()), + label_(std::string(label)), debug_options_(debug_options), - show_metadata_(show_metadata), + show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -365,6 +365,7 @@ class HloDotDumper { string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); + string GetInstructionNodeBackendConfig(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); @@ -392,7 +393,7 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph const DebugOptions& debug_options_; - const bool show_metadata_; + const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -426,7 +427,8 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map sharding_colors_; + std::unordered_map + sharding_colors_; int64 next_shard_color_ = 0; }; @@ -588,15 +590,26 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr) { VLOG(2) << "Dumping subcomputation " << subcomp->name(); - const char* computation_fmt = R"(subgraph %s { -%s -label = <%s>; -labelloc = t; -tooltip = " "; -%s -} // %s + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, + // so there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); + const char* edge_fmt = + R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); + } -)"; + // Have we already dumped this subcomputation? If so, generating the edge + // linking it and parent_instr is all we want to do in this function. + if (cluster_ids_.find(subcomp) != cluster_ids_.end()) { + return ""; + } cluster_ids_[subcomp] = next_cluster_id_++; @@ -611,6 +624,10 @@ tooltip = " "; if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); } + string node_backend_config = GetInstructionNodeBackendConfig(parent_instr); + if (!node_backend_config.empty()) { + StrAppend(&subcomp_label, "
", node_backend_config); + } bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; @@ -639,25 +656,16 @@ tooltip = " "; string comp_body = DumpComputation(subcomp); - // Add an edge from the subcomputation to its parent node. If subcomp - // belongs to a fusion node, it's drawn in place of the fusion instruction, - // so there's no need to link those. - if (parent_instr->opcode() != HloOpcode::kFusion) { - const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); - VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() - << " as " << next_edge_id_; - edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = - R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( - edge_fmt, InstructionId(from), InstructionId(parent_instr), - SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); - } - - string computation = - Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + const char* computation_fmt = R"(subgraph %s { +%s +label = <%s>; +labelloc = t; +tooltip = " "; +%s +} // %s - return computation; +)"; + return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); + string node_backend_config = GetInstructionNodeBackendConfig(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -782,16 +791,16 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : - {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_backend_config, + extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" "\n", - InstructionId(instr), node_body, node_shape, + InstructionId(instr), node_body, node_shape, node_metadata, NodeColorAttributes(color)); } @@ -804,7 +813,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. - if (ShapeUtil::HasZeroElements(shape)) { + if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) { return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); } @@ -816,7 +825,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8) { + if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -876,14 +885,13 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { if (!instr->has_sharding()) { return kDashedBorder; } - string shard_str = instr->sharding().ToString(); - auto it = sharding_colors_.find(shard_str); + auto it = sharding_colors_.find(instr->sharding()); if (it != sharding_colors_.end()) { return it->second; } ColorScheme color = static_cast( kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); - sharding_colors_.emplace(shard_str, color); + sharding_colors_.emplace(instr->sharding(), color); return color; } const auto kParameterColor = kOrange; @@ -916,6 +924,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -923,6 +932,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1002,6 +1012,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: return kPurple; + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: return kGray; @@ -1057,10 +1068,6 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { - if (!show_metadata_) { - return ""; - } - std::vector lines; if (!instr->metadata().op_name().empty()) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); @@ -1078,13 +1085,23 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { return Join(lines, "
"); } +string HloDotDumper::GetInstructionNodeBackendConfig( + const HloInstruction* instr) { + if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { + return ""; + } + + return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\""); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { std::vector lines; // Get the instruction's extra attributes excluding the names of its // subcomputations, since those are drawn explicitly in the graph. for (const auto& line : instr->ExtraAttributesToString( - HloPrintOptions().set_print_subcomputation_references(false))) { + HloPrintOptions().set_print_subcomputation_mode( + HloPrintOptions::PrintSubcomputationMode::kOff))) { lines.push_back(HtmlLikeStringSanitize(line)); } @@ -1133,6 +1150,20 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { return Join(lines, "
"); } +// Gets the total number of array elements in the given shape. For tuples, this +// is the sum of all the sizes of all of the array elements recursively in the +// tuple. +static int64 TotalElementsInShape(const Shape& shape) { + int64 elems = 0; + ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + elems += ShapeUtil::ElementsIn(subshape); + } + }); + return elems; +} + void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1152,9 +1183,16 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } - const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)"; + + // We print "small" arrays using a hollow arrowhead and "large" arrays using + // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" + // means. + bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; + + const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - from->name(), to->name(), edge_label)); + (is_big_array ? "normal" : "empty"), from->name(), + to->name(), edge_label)); }; // Add edges from instr's operands to instr. Parameters within fusion @@ -1404,7 +1442,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata) { + bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1414,9 +1452,10 @@ string DumpGraph(const HloComputation& computation, const string& label, &graph)); graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = HloDotDumper(&computation, label, debug_options, show_metadata, - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) + .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1427,13 +1466,13 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata) { + bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); string graph = - HloDotDumper(node.parent(), label, debug_options, show_metadata, + HloDotDumper(node.parent(), label, debug_options, show_backend_config, /*profile=*/nullptr, filter) .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 2704aae1e3ba7f..0b11f34abb7f0d 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* 
hlo_execution_profile = nullptr, - bool show_metadata = false); + bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,7 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false); + bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1f00aa41dc783f..8e52d926d85f1c 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -64,8 +64,8 @@ TEST(HloGraphDumperTest, NestedFusion) { sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, sums[i], params[i + 2]))); } - - HloModule m(TestName()); + HloModuleConfig config; + HloModule m(TestName(), config); m.AddEntryComputation(b.Build()); HloComputation* root_computation = m.entry_computation(); @@ -122,7 +122,8 @@ TEST(HloGraphDumperTest, Constant) { auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); instruction->set_name("i_am_a_constant_root_instruction"); - HloModule m(TestName()); + HloModuleConfig config; + HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); string graph = hlo_graph_dumper::DumpGraph( *root_computation, /*label=*/"an_empty_graph", DebugOptions()); @@ -130,5 +131,23 @@ TEST(HloGraphDumperTest, Constant) { EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); } +TEST(HloGraphDumperTest, TupleConstant) { + Shape tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})}); + HloComputation::Builder b("b"); + auto constant = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(tuple_shape))); + auto gte = b.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(F32, {3, 2}), constant, 0)); + + HloModuleConfig config; + HloModule m(TestName(), config); + HloComputation* root_computation = m.AddEntryComputation(b.Build(gte)); + string graph = hlo_graph_dumper::DumpGraph( + *root_computation, /*label=*/"tuple_constant", DebugOptions()); + EXPECT_THAT(graph, HasSubstr("tuple_constant")); + EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])")); +} + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a714d0e1142450..06775d6a9ab661 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -37,9 +37,11 @@ limitations under the License. 
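
The graph-dumper change above adds `TotalElementsInShape`, which sums the array elements across a (possibly tuple) shape, and uses a filled arrowhead only for edges carrying at least 4096 elements. Below is a small standalone sketch of that rule; the `Shape` alias is a stand-in for `xla::Shape`, and the cutoff mirrors the arbitrary one chosen in the diff.

```c++
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

// A tuple-like shape reduced to its essentials: one dimension list per array.
using Shape = std::vector<std::vector<int64_t>>;

// Total number of array elements across all leaves of the shape.
int64_t TotalElements(const Shape& shape) {
  int64_t total = 0;
  for (const auto& dims : shape) {
    total += std::accumulate(dims.begin(), dims.end(), int64_t{1},
                             std::multiplies<int64_t>());
  }
  return total;
}

// "Big" arrays get a filled arrowhead, small ones a hollow one.
std::string ArrowheadFor(const Shape& shape) {
  constexpr int64_t kBigArrayCutoff = 4096;
  return TotalElements(shape) >= kBigArrayCutoff ? "normal" : "empty";
}

int main() {
  std::cout << ArrowheadFor({{30, 30}}) << "\n";        // 900 elements -> empty
  std::cout << ArrowheadFor({{64, 64}, {64}}) << "\n";  // 4160 elements -> normal
  return 0;
}
```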
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -51,7 +53,7 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); @@ -109,6 +111,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); + instruction->backend_config_ = proto.backend_config(); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -255,11 +258,14 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: + case HloOpcode::kDomain: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -417,8 +423,20 @@ HloInstruction::CreateReducePrecision(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { - return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + // TODO(b/79737069): Remove the CHECK when supported. + CHECK(replica_group_ids.empty()); + CHECK(!channel_id.has_value()); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->called_computations_.push_back(reduce_computation); + return instruction; } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -437,7 +455,7 @@ HloInstruction::CreateCrossReplicaSum( << "Outfeed shape " << shape << " must be compatible with operand shape " << operand->shape(); instruction->AppendOperand(operand); - instruction->outfeed_config_ = outfeed_config.ToString(); + instruction->outfeed_config_ = std::string(outfeed_config); instruction->outfeed_shape_ = shape; return instruction; } @@ -792,23 +810,11 @@ HloInstruction::CreateBroadcastSequence( return instruction; } -// We put the fusion kind into the instruction's name for transpose-dot fusions, -// since those fusions are really just describing a type of dot rather than -// generating a novel computation. 
-static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { - switch (fusion_kind) { - case HloInstruction::FusionKind::kTransposeDot: - return "dot_fusion"; - default: - return "fusion"; - } -} - /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); + instruction->name_ = "fusion"; instruction->set_parent(fused_root->parent()); instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); @@ -824,12 +830,21 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { instruction->AppendOperand(operand); } instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); + instruction->name_ = "fusion"; instruction->called_computations_.push_back(fusion_computation); fusion_computation->SetFusionInstruction(instruction.get()); return instruction; } +void HloInstruction::set_device_sharding(int64 device) { + HloSharding device_sharding = HloSharding::AssignDevice(device); + if (ShapeUtil::IsTuple(shape())) { + set_sharding(HloSharding::Tuple(device_sharding.GetAsShapeTree(shape()))); + } else { + set_sharding(device_sharding); + } +} + void HloInstruction::SetupDerivedInstruction( HloInstruction* derived_instruction) const { if (sharding_ != nullptr) { @@ -1123,7 +1138,7 @@ RandomDistribution HloInstruction::random_distribution() const { return distribution_; } -bool HloInstruction::HasSideEffect() const { +bool HloInstruction::HasSideEffectNoRecurse() const { switch (opcode_) { case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -1135,16 +1150,22 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kTrace: case HloOpcode::kHostCompute: return true; - default: { - // Check if any of the called computations has a side effect. - for (const auto& computation : called_computations()) { - if (computation->HasSideEffect()) { - return true; - } - } + default: return false; + } +} + +bool HloInstruction::HasSideEffect() const { + if (HasSideEffectNoRecurse()) { + return true; + } + // Check if any of the called computations has a side effect. 
+ for (const auto& computation : called_computations()) { + if (computation->HasSideEffect()) { + return true; } } + return false; } /* static */ std::unique_ptr HloInstruction::CreateCall( @@ -1167,7 +1188,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->custom_call_target_ = custom_call_target.ToString(); + instruction->custom_call_target_ = std::string(custom_call_target); return instruction; } @@ -1179,7 +1200,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->channel_name_ = channel_name.ToString(); + instruction->channel_name_ = std::string(channel_name); instruction->cost_estimate_ns_ = cost_estimate_ns; return instruction; } @@ -1228,10 +1249,21 @@ bool HloInstruction::HasSideEffect() const { return gather_dim_numbers; } +/* static */ std::unique_ptr HloInstruction::CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + instruction->operand_side_metadata_ = std::move(operand_side_metadata); + instruction->user_side_metadata_ = std::move(user_side_metadata); + instruction->AppendOperand(operand); + return instruction; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module) const { + HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { @@ -1239,7 +1271,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } std::unique_ptr clone; - // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. 
@@ -1253,10 +1284,12 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -1309,6 +1342,14 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); + if (window_ != nullptr) { + clone->window_ = MakeUnique(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + clone->convolution_dimension_numbers_ = + MakeUnique( + *convolution_dimension_numbers_); + } break; case HloOpcode::kHostCompute: clone = CreateHostCompute(shape, new_operands, channel_name_, @@ -1342,9 +1383,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); + break; case HloOpcode::kCrossReplicaSum: - clone = CreateCrossReplicaSum(shape, new_operands); + clone = CreateCrossReplicaSum(shape, new_operands, to_apply()); break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); @@ -1415,9 +1457,22 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConstant: clone = CreateConstant(literal_->CloneToUnique()); break; - case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands, module); + case HloOpcode::kFusion: { + HloModule* module = context != nullptr ? context->module() : GetModule(); + HloComputation* new_fused_computation = nullptr; + if (context != nullptr) { + new_fused_computation = + context->FindComputation(fused_instructions_computation()); + } + if (new_fused_computation == nullptr) { + new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", context)); + } + clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), + /*operands=*/new_operands, + /*fusion_computation=*/new_fused_computation); break; + } case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, name_); break; @@ -1476,20 +1531,35 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateGather(shape, new_operands[0], new_operands[1], *gather_dimension_numbers_, gather_window_bounds_); break; + case HloOpcode::kDomain: + CHECK_EQ(new_operands.size(), 1); + clone = + CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_raw_backend_config_string(backend_config_); + if (context != nullptr) { + context->MapInstruction(this, clone.get()); + clone->ReplaceCalledComputations([&](HloComputation* callee) { + return callee->parent() != context->module() + ? 
context->module()->DeepCloneComputation(callee, context) + : callee; + }); + } return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone(const string& suffix, - HloModule* module) const { +std::unique_ptr HloInstruction::Clone( + const string& suffix, HloCloneContext* context) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module); + CloneWithNewOperands(shape_, operands_, context); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1526,71 +1596,6 @@ std::unique_ptr HloInstruction::Clone(const string& suffix, return clone; } -std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(parent() != nullptr); - - auto new_instruction = - WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - // Add the operands to our new fusion instruction. - for (HloInstruction* new_operand : operands) { - new_instruction->AppendOperand(new_operand); - } - // Clone all the fused instructions for the new fusion instruction. - HloInstructionMap old_to_new; - std::list> new_fused_instructions; - // Create the list of fused parameters by mapping through the cloned, - // fused instructions. - for (HloInstruction* old_fused_parameter : - fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back( - old_fused_parameter->Clone("clone", module)); - HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); - } - for (auto old_fused_instruction : - fused_instructions_computation()->MakeInstructionPostOrder()) { - if (old_fused_instruction->opcode() == HloOpcode::kParameter) { - FindOrDie(old_to_new, old_fused_instruction); - continue; - } - std::vector new_operands; - for (int64 operand_idx = 0; - operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { - HloInstruction* old_operand = - old_fused_instruction->mutable_operand(operand_idx); - new_operands.push_back(FindOrDie(old_to_new, old_operand)); - } - new_fused_instructions.push_back( - old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands, module)); - HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent_); - InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); - } - new_instruction->fusion_kind_ = fusion_kind_; - auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", - new_instruction.get()); - // We iterated the fusion instructions in reverse post order which means - // that we must reverse our new list of fusion instructions. 
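
The cloning rewrite above drops the hand-maintained `old_to_new` map of `CloneFusionWithNewOperands` in favor of an `HloCloneContext` that remembers which originals already have clones, so shared computations and instructions are cloned exactly once and later references are remapped to the existing copy. A bare-bones sketch of that idea, with a placeholder `Node` type rather than real HLO classes:

```c++
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node*> operands;
};

// Remembers old -> new mappings so every original node is cloned exactly once,
// even when several users point at it.
class CloneContext {
 public:
  Node* FindOrClone(Node* old_node) {
    auto it = map_.find(old_node);
    if (it != map_.end()) return it->second;  // already cloned, reuse it
    auto clone = std::make_unique<Node>();
    clone->name = old_node->name + ".clone";
    Node* raw = clone.get();
    map_[old_node] = raw;  // record the mapping before walking operands
    for (Node* operand : old_node->operands) {
      raw->operands.push_back(FindOrClone(operand));
    }
    owned_.push_back(std::move(clone));
    return raw;
  }

 private:
  std::unordered_map<Node*, Node*> map_;
  std::vector<std::unique_ptr<Node>> owned_;
};

int main() {
  Node a{"a", {}};
  Node b{"b", {&a}};
  Node c{"c", {&a, &b}};  // "a" is shared by "c" and "b"
  CloneContext context;
  Node* c_clone = context.FindOrClone(&c);
  // Both paths reach the same single clone of "a".
  return c_clone->operands[0] == c_clone->operands[1]->operands[0] ? 0 : 1;
}
```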
- for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); - new_fused_instruction_iter != new_fused_instructions.rend(); - ++new_fused_instruction_iter) { - computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); - } - if (module == nullptr) { - module = GetModule(); - } - auto fused_root_ = fused_expression_root(); - new_instruction->called_computations_.push_back( - CHECK_NOTNULL(module)->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - return new_instruction; -} - std::pair HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -1619,6 +1624,8 @@ const Literal& HloInstruction::literal() const { return *literal_; } +bool HloInstruction::HasLiteral() const { return literal_ != nullptr; } + bool HloInstruction::CanHaveDimensionsField() const { return (opcode() == HloOpcode::kReverse || opcode() == HloOpcode::kConcatenate || @@ -1664,6 +1671,17 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand: " << target->ToString(); } +HloInstruction::InstructionVector HloInstruction::unique_operands() const { + InstructionVector unique; + tensorflow::gtl::FlatSet seen; + for (HloInstruction* operand : operands()) { + if (seen.insert(operand).second) { + unique.push_back(operand); + } + } + return unique; +} + Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); if (std::find(control_successors_.begin(), control_successors_.end(), @@ -1739,26 +1757,29 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const { + eq_computations) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and // operands. case HloOpcode::kAbs: case HloOpcode::kAtan2: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: case HloOpcode::kComplex: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -1766,6 +1787,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: @@ -1778,6 +1800,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kReshape: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: @@ -1789,22 +1813,26 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; + // Broadcast, Concatenate, and Transpose need the same dimensions field. 
+ case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kTranspose: + return dimensions() == other.dimensions(); + case HloOpcode::kFusion: return fusion_kind() == other.fusion_kind() && eq_computations(fused_instructions_computation(), other.fused_instructions_computation()); // These opcodes have complex or special behavior so just return false. + case HloOpcode::kDomain: case HloOpcode::kRng: case HloOpcode::kTrace: case HloOpcode::kWhile: return false; case HloOpcode::kParameter: - return parameter_number() == other.parameter_number() && - // Check the shape too because `this` and `other` may be in - // different HloComputations. - eq_shapes(shape(), other.shape()); + return parameter_number() == other.parameter_number(); case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: @@ -1816,12 +1844,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kConstant: return literal() == other.literal(); - // A convert result is determined by the primitive type that the operand is - // converted into. - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - return shape().element_type() == other.shape().element_type(); - // A reduce-precision operation is determined by the bit sizes. case HloOpcode::kReducePrecision: return exponent_bits() == other.exponent_bits() && @@ -1864,22 +1886,8 @@ bool HloInstruction::IdenticalSlowPath( eq_computations(scatter(), other.scatter()) && protobuf_util::ProtobufEquals(window(), other.window()); - case HloOpcode::kReshape: - return eq_shapes(shape(), other.shape()); - - // Transpose result is determined by the final shape and the permutation. - case HloOpcode::kTranspose: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); // Remaining instructions with special values. 
- case HloOpcode::kBitcast: - return eq_shapes(shape(), other.shape()); - case HloOpcode::kBroadcast: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - case HloOpcode::kConcatenate: - return dimensions() == other.dimensions(); case HloOpcode::kGetTupleElement: return tuple_index() == other.tuple_index(); case HloOpcode::kPad: @@ -1889,15 +1897,24 @@ bool HloInstruction::IdenticalSlowPath( return slice_starts_ == other.slice_starts_ && slice_limits_ == other.slice_limits_ && slice_strides_ == other.slice_strides_; - case HloOpcode::kDynamicSlice: - return eq_shapes(shape(), other.shape()) && - dynamic_slice_sizes_ == other.dynamic_slice_sizes_; - case HloOpcode::kDynamicUpdateSlice: - return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); case HloOpcode::kCustomCall: + if ((window_ == nullptr) != (other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(window(), other.window()))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + other.convolution_dimension_numbers()))) { + return false; + } return custom_call_target_ == other.custom_call_target_; case HloOpcode::kReverse: return dimensions() == other.dimensions(); @@ -2029,6 +2046,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -2160,28 +2178,68 @@ string PrintName(const string& name, const HloPrintOptions& options) { } // namespace string HloInstruction::ToString(const HloPrintOptions& options) const { - string result = - StrCat(PrintName(name(), options), " = ", - ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", OperandsToString(options), ")"); + CanonicalNameMap new_map; + return ToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string result = ""; + + // Logic to print the instruction name (e.g. "%foo = "). + if (options.canonicalize_instruction_names()) { + if (options.is_in_nested_computation()) { + // If we are canonicalizing instruction names and this is a top-level + // HloInstruction::ToString() call, don't print an instruction name. + StrAppend(&result, + PrintName(canonical_name_map->LookupOrInsert(name()), options), + " = "); + } + } else { + StrAppend(&result, PrintName(name(), options), " = "); + } + + // Print opcode, operand(s) and shape. + StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", + OperandsToStringWithCanonicalNameMap(options, canonical_name_map), + ")"); + + // Print additional attributes. If an instruction contains a subcomputation, + // the subcomputation is also printed here. 
for (const string& extra : ExtraAttributesToString(options)) { StrAppend(&result, ", ", extra); } + if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } + if (options.print_backend_config() && !backend_config_.empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\""); + } return result; } string HloInstruction::OperandsToString(const HloPrintOptions& options) const { + CanonicalNameMap new_map; + return OperandsToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. - if ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants()) { + // + // In HloInstruction, sometimes a constant literal is not constructed due + // to its size. Skip the printing in this case. + if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) && + ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); @@ -2215,7 +2273,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const { if (options.print_operand_shape()) { str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - if (!options.compact_operands()) { + + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. 
+ if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, Join(str, " ")); @@ -2269,7 +2334,9 @@ std::vector HloInstruction::ExtraAttributesToString( } if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(ConvolutionDimensionNumbersToString()); + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); @@ -2284,7 +2351,8 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); } - if (options.print_subcomputation_references()) { + if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { extra.push_back( StrCat("condition=", PrintName(while_condition()->name(), options))); @@ -2301,7 +2369,8 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(false_computation()->name(), options))); } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || - opcode() == HloOpcode::kReduce) { + opcode() == HloOpcode::kReduce || + opcode() == HloOpcode::kCrossReplicaSum) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2312,8 +2381,45 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(computation->name(), options)); }))); } + } else if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kFullBodies) { + HloPrintOptions new_options = options; + new_options.set_is_in_nested_computation(true); + switch (opcode()) { + case HloOpcode::kWhile: + extra.push_back( + StrCat("condition=\n", while_condition()->ToString(new_options))); + extra.push_back(StrCat("body=\n", while_body()->ToString(new_options))); + break; + case HloOpcode::kSelectAndScatter: + extra.push_back(StrCat("select=\n", select()->ToString(new_options))); + extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options))); + break; + case HloOpcode::kConditional: + extra.push_back(StrCat("true_computation=\n", + true_computation()->ToString(new_options))); + extra.push_back(StrCat("false_computation=\n", + false_computation()->ToString(new_options))); + break; + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduceWindow: + case HloOpcode::kReduce: + extra.push_back( + StrCat("to_apply=\n", to_apply()->ToString(new_options))); + break; + default: + if (!called_computations().empty()) { + extra.push_back( + StrCat("calls=\n", + Join(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); + } + break; + } } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { extra.push_back(StrCat("channel_id=", channel_id_)); @@ -2349,14 +2455,19 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("exponent_bits=", exponent_bits_)); extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); } - + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + 
extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", operand_side_metadata_->ToString(), + ", exit=", user_side_metadata_->ToString(), "}")); + } // By contract, we print the custom call target even if - // !options.print_subcomputation_references(), because the call target is not + // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. if (opcode() == HloOpcode::kCustomCall) { extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); } + return extra; } @@ -2386,6 +2497,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; + proto.set_backend_config(backend_config_); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -2451,6 +2563,10 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_fft_length(fft_len); } + if (has_sharding()) { + *proto.mutable_sharding() = sharding().ToProto(); + } + proto.set_channel_name(channel_name_); proto.set_cost_estimate_ns(cost_estimate_ns_); @@ -2487,8 +2603,6 @@ string HloInstruction::ToCategory() const { return "input fusion"; case FusionKind::kOutput: return "output fusion"; - case FusionKind::kTransposeDot: - return "dot"; case FusionKind::kCustom: return "custom fusion"; } @@ -2522,6 +2636,7 @@ bool HloInstruction::IsFusable() const { } // Some kinds of instructions don't make sense to fuse. switch (opcode_) { + case HloOpcode::kDomain: case HloOpcode::kParameter: return false; // Side effecting instrutions cannot be fused. @@ -2534,7 +2649,9 @@ HloComputation* HloInstruction::fused_instructions_computation() const { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(!called_computations_.empty()); auto* fused_instructions_computation = called_computations_.front(); - CHECK(fused_instructions_computation->IsFusionComputation()); + CHECK(fused_instructions_computation->IsFusionComputation()) + << "Computation " << fused_instructions_computation->name() + << " is not a fusion kind"; return fused_instructions_computation; } @@ -2673,6 +2790,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleNegate(this); case HloOpcode::kExp: return visitor->HandleExp(this); + case HloOpcode::kExpm1: + return visitor->HandleExpm1(this); case HloOpcode::kFloor: return visitor->HandleFloor(this); case HloOpcode::kCeil: @@ -2681,6 +2800,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleClz(this); case HloOpcode::kLog: return visitor->HandleLog(this); + case HloOpcode::kLog1p: + return visitor->HandleLog1p(this); case HloOpcode::kTanh: return visitor->HandleTanh(this); case HloOpcode::kCos: @@ -2745,6 +2866,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kDomain: + return visitor->HandleDomain(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2971,6 +3094,7 @@ Status HloInstruction::AcceptOrdered( continue; } + // TODO(b/78350259): Eliminate const laundering. 
HloInstruction* instruction = const_cast(const_instruction); @@ -3026,10 +3150,12 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -3094,7 +3220,7 @@ bool HloInstruction::IsElementwise() const { bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { CHECK(IsElementwise()); - return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape()); + return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } namespace { @@ -3270,8 +3396,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kInput"; case HloInstruction::FusionKind::kOutput: return "kOutput"; - case HloInstruction::FusionKind::kTransposeDot: - return "kTransposeDot"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -3288,9 +3412,6 @@ StatusOr StringToFusionKind( if (kind_name == "kOutput") { return HloInstruction::FusionKind::kOutput; } - if (kind_name == "kTransposeDot") { - return HloInstruction::FusionKind::kTransposeDot; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } @@ -3334,42 +3455,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); } -StatusOr StringToRandomDistribution(const string& name) { - static std::unordered_map* map = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { - if (RandomDistribution_IsValid(i)) { - auto value = static_cast(i); - (*map)[RandomDistributionToString(value)] = value; - } - } - return map; - }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); - if (found == map->end()) { - return InvalidArgument("Unknown distribution"); - } - return found->second; -} - -std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { - return os << ToString(kind); -} - -string HloInstruction::ConvolutionDimensionNumbersToString() const { - string result; - if (convolution_dimension_numbers_ == nullptr) { - return result; - } - const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_; - // Show the given dimension labels in order of major to minor based on the - // shape's layout. - const auto append_dims = [&](const std::vector& dims, - const Shape& shape) { - CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - StrAppend(&result, Join(dims, "")); - }; - +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums) { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); @@ -3393,19 +3480,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - result += "dim_labels="; - append_dims(lhs_dims, operand(0)->shape()); - result += "_"; - append_dims(rhs_dims, operand(1)->shape()); - result += "->"; - - // A convolution can be represented as a kConvolution HLO or as a CustomCall - // that returns a tuple, the first element of which is the result of the - // convolution. - Shape this_shape = - ShapeUtil::IsTuple(shape()) ? 
shape().tuple_shapes(0) : shape(); - append_dims(output_dims, this_shape); - return result; + return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", + Join(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -3431,6 +3507,28 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + string HloInstruction::GatherDimensionNumbersToString() const { CHECK_NE(gather_dimension_numbers_.get(), nullptr); string output_window_dims = @@ -3462,6 +3560,31 @@ bool HloInstruction::CouldBeBitcast() const { } } +Status HloInstruction::GetBackendConfigInternal( + tensorflow::protobuf::Message* proto) const { + proto->Clear(); + + // Empty string does not parse as valid JSON, but it's a valid backend config, + // corresponding to the empty proto. + if (backend_config_.empty()) { + return Status::OK(); + } + return tensorflow::HumanReadableJsonToProto(backend_config_, proto); +} + +Status HloInstruction::set_backend_config( + const tensorflow::protobuf::Message& proto) { + TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto)); + return Status::OK(); +} + +/* static */ StatusOr HloInstruction::BackendConfigToRawString( + const tensorflow::protobuf::Message& proto) { + string ret; + TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret)); + return ret; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a5e9aecb9e7f52..ef55c6668f2baa 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -50,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -60,51 +63,75 @@ class HloModule; // A bunch of switches that control how the hlo text should be printed. class HloPrintOptions { public: + enum class PrintSubcomputationMode { + kOff, // Do not print anything about subcomputations. + kNameOnly, // Only print the name of subcomputations. + kFullBodies, // Print the full bodies of subcomputations. 
+ }; + // Constructs the default print options: don't print large constants, don't // compact operands, no indentation. HloPrintOptions() : print_large_constants_(false), - print_subcomputation_references_(true), + print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), print_metadata_(true), + print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), print_percent_(true), - indent_amount_(0) {} + canonicalize_instruction_names_(false), + indent_amount_(0), + is_in_nested_computation_(false) {} static HloPrintOptions ShortParsable() { return HloPrintOptions() .set_print_large_constants(true) - .set_print_subcomputation_references(true) + .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) .set_print_metadata(false) + .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) .set_print_percent(false); } + // Options to produce the canonical string representing an isomorphic + // computation graph. + static HloPrintOptions Canonical() { + return HloPrintOptions() + .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) + .set_print_metadata(false) + .set_compact_operands(true) + .set_print_operand_shape(true) + .set_print_program_shape(false) + .set_print_percent(false) + .set_canonicalize_instruction_names(true); + } + // If true, large constants will be printed out. HloPrintOptions& set_print_large_constants(bool value) { print_large_constants_ = value; return *this; } - // If true, the names of subcomputations (e.g. a fusion node's fused - // computation) won't be printed. This makes the resulting text not parsable. - // - // A CustomCall's call target is printed even if - // print_subcomputation_references is false, because the call target isn't an - // HloComputation. - HloPrintOptions& set_print_subcomputation_references(bool value) { - print_subcomputation_references_ = value; + HloPrintOptions& set_print_subcomputation_mode( + PrintSubcomputationMode value) { + print_subcomputation_mode_ = value; return *this; } - // If true, metatdata will be printed. + // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; return *this; } + // If true, backend_config will be printed. + HloPrintOptions& set_print_backend_config(bool value) { + print_backend_config_ = value; + return *this; + } + // If true, operands' shapes will be printed. HloPrintOptions& set_print_operand_shape(bool value) { print_operand_shape_ = value; @@ -130,54 +157,175 @@ class HloPrintOptions { return *this; } + // If true, canonicalizes instructions' name. Instead of using "%foo.1" as + // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. + HloPrintOptions& set_canonicalize_instruction_names(bool value) { + canonicalize_instruction_names_ = value; + return *this; + } + // The indent of the hlo text block. HloPrintOptions& set_indent_amount(int value) { indent_amount_ = value; return *this; } + // If true, indicates the instruction being printed is inside a nested + // computation. 
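+ // This flag is consumed together with canonicalize_instruction_names():
+ // operand names are rewritten to "tmp_<N>" only when both are set, which is
+ // how PrintSubcomputationMode::kFullBodies prints nested computation bodies.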
+ HloPrintOptions& set_is_in_nested_computation(bool value) { + is_in_nested_computation_ = value; + return *this; + } + bool print_large_constants() const { return print_large_constants_; } - bool print_subcomputation_references() const { - return print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode() const { + return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } + bool print_backend_config() const { return print_metadata_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool canonicalize_instruction_names() const { + return canonicalize_instruction_names_; + } int indent_amount() const { return indent_amount_; } + int is_in_nested_computation() const { return is_in_nested_computation_; } private: bool print_large_constants_; - bool print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode_; bool print_metadata_; + bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool canonicalize_instruction_names_; int indent_amount_; + bool is_in_nested_computation_; +}; + +// For canonical string output, we need to have a canonical way to rename +// each instruction and its operands. Each operand is renamed as "tmp_", +// where is an index starting from 0. +class CanonicalNameMap { + public: + CanonicalNameMap() : index(0) {} + + string LookupOrInsert(const string& old_name) { + auto iter = canonical_name_map.find(old_name); + if (iter != canonical_name_map.end()) { + return iter->second; + } + + string new_name = tensorflow::strings::StrCat("tmp_", index++); + canonical_name_map[old_name] = new_name; + return new_name; + } + void Clear() { + canonical_name_map.clear(); + index = 0; + } + + private: + int64 index; + tensorflow::gtl::FlatMap canonical_name_map; }; -// HLO instructions are the IR used by the high-level compiler. +// HLO instructions are the atomic unit of the high-level compiler's IR. +// +// HloInstructions live inside of an HloComputation, which is analogous to a +// function in other programming languages. Nodes have no total order within +// their computation. Instead, they have a partial ordering determined by their +// data and control dependencies. +// +// HLO does not have basic blocks or explicit "branch" instructions. Instead, +// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode +// control flow. For example, the kConditional HLO executes one of two possible +// computations, depending on the runtime value of a predicate. +// +// HLO is pure (mostly). It has no concept of mutable state. Instead, data +// values are produced by one HLO and flow into consumers across dependency +// edges. class HloInstruction { public: + // A fusion node computes the same value a call to its fusion computation + // would compute. However, the choice of fusion kind dictates codegen + // strategy for the backend. + // + // To generate code for a kFusion HloInstruction, most backends do something + // like the following: + // + // 1) Identify the "primary" HloInstruction of the fused computation. + // 2) Emit code that does the work of the primary node, creating its inputs + // and transforming its outputs as specified by the fused computation. 
+ // + // In step (2), the code emitted is usually similar to the code that would be + // emitted for an *unfused* version of the primary node, except that + // + // - when the primary node reads an element of one of its operands, instead + // of loading the value from memory, it *computes* the value based on the + // contents of the fused computation. + // - when the primary node outputs a value, instead of storing it to memory, + // it forwards the value to its users, which then perform additional + // computations before the value is finally stored to memory at the root of + // the fusion node. + // + // An HloInstruction's FusionKind helps us find the kFusion instruction's + // primary node, and can also affect how we generate code in step (2). + // + // - kInput: The primary node is the root of the fused instruction. + // + // - kOutput: The primary node is not the root of the fused instruction. + // This fusion kind requires that one operand buffer of the fusion + // instruction be able to alias the output buffer. This constraint is + // usually enough to let backends find the primary node unambiguously. + // + // - kLoop: The primary node is the root of the fused computation, but, + // unlike in input fusion, we prescribe a specific implementation for + // codegen. Rather than generating code that looks like the code we'd emit + // for an unfused version of the primary/root node, we emit code that + // generates one element of the root at a time. + // + // - kCustom: Custom category for backend-specific fusions that don't fit + // into the above patterns. + // + // Not all backends support all fusion kinds, and given a particular fused + // computation, it's not in general safe to change its fusion kind. Creation + // of fusion nodes is always backend-specific. + // + // For elementwise ops (e.g. kAdd), most backends would emit a + // one-element-at-a-time implementation for the unfused version, so loop + // fusion and input fusion are probably equivalent if the root node is + // elementwise. They're not necessarily equivalent e.g. for kReduce, where an + // implementation might emit something more sophisticated for an unfused or + // input-fusion reduce, but will emit the naive code that reduces one element + // at a time for loop fusion with a reduce as the root. + // + // Another way to think of loop fusion is that it's equivalent to input + // fusion, but where the root node is an implicit identity node, whose + // unfused implementation is "read one element, write one element". + // + // TODO(b/79869434): This categorization scheme is not great. For one thing, + // input and loop fusion are basically the same thing: There is no reason for + // the HLO to encode backend-specific decisions about how e.g. a reduce that's + // the root of a fusion should be lowered. In addition, this scheme as + // written doesn't work for multi-output fusion, where the primary node is + // never actually the root (which is a kTuple instruction that gathers the + // multiple outputs of the fusion). enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. 
+ kLoop, + kInput, + kOutput, + kCustom, }; - ~HloInstruction(); + virtual ~HloInstruction(); // Creates an instruction from the given proto. Arguments: // - // module: the module which will contain the instruction. The newly created - // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. @@ -185,7 +333,7 @@ class HloInstruction { // must contain all computations which the newly constructed instruction // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map); @@ -278,10 +426,26 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); - // Creates a cross replica sum op. + // Creates a cross replica reduction op. + // + // `reduction_computation`: the reduction function. + // + // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. + // + // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, - tensorflow::gtl::ArraySlice operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -452,6 +616,13 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Creates a kDomain instruction which delimits an HLO domain which have + // the provided user and operand side metadata. + static std::unique_ptr CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -503,6 +674,10 @@ class HloInstruction { // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } + // Returns true if this instruction has a side effect, irrespective of whether + // any called computations may contain an instruction with side effects. + bool HasSideEffectNoRecurse() const; + // Returns true if this instruction has a side effect. An instruction has a // side effect if it uses certain opcodes or calls a computation with a side // effect. @@ -527,6 +702,10 @@ class HloInstruction { using InstructionVector = tensorflow::gtl::InlinedVector; const InstructionVector& operands() const { return operands_; } + // Returns the vector of unique operands, in the same order they are found + // within the operand vector. 
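+ // For example, an add(x, x) instruction has operands {x, x} but a single
+ // unique operand {x}.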
+ InstructionVector unique_operands() const; + // Returns the index of 'target' in the operands sequence. // Precondition: target must be an operand (or a fatal error will occur). int64 operand_index(const HloInstruction* target) const; @@ -597,10 +776,8 @@ class HloInstruction { if (opcode() != other.opcode()) { return false; } - using EqShapeFuncType = bool (*)(const Shape&, const Shape&); - EqShapeFuncType eq_shapes = - layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; - if (!eq_shapes(shape(), other.shape())) { + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { return false; } if (operands().size() != other.operands().size()) { @@ -615,7 +792,11 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations, eq_shapes); + if (backend_config_ != other.backend_config_) { + return false; + } + + return IdenticalSlowPath(other, eq_computations); } // Returns whether the instruction has a constant operand. @@ -643,6 +824,8 @@ class HloInstruction { // Detaches an instruction from its operands. That is, remove the instruction // from each operand's user set. This should only be called prior to // deallocating the instruction. + // + // TODO(b/78305363): Make this automatic when deleting an instruction. void DetachFromOperands(); // Performs a postorder DFS visit using this node as the root. If @@ -695,6 +878,9 @@ class HloInstruction { // Note: only constant and parameter opcodes have an associated literal. const Literal& literal() const; + // Returns whether there is literal associated with this instruction. + bool HasLiteral() const; + // Returns the parameter number associated with this instruction. // // Note: only parameter opcodes have an associated parameter number. @@ -942,20 +1128,41 @@ class HloInstruction { } // Returns the sharding unique device, if any. tensorflow::gtl::optional sharding_unique_device() const { - if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { + if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - return sharding_->UniqueDevice().ValueOrDie(); + auto device = sharding_->UniqueDevice(); + return device.ok() ? device.ValueOrDie() + : tensorflow::gtl::optional(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { sharding_ = MakeUnique(sharding); } + // Sets a sharding that assigns the current instruction to device. + void set_device_sharding(int64 device); // Remove any sharding from this operator. void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // Checks whether the instruction has compatible sharding with the other + // instruction. + bool has_compatible_sharding(const HloInstruction* other) const { + if (!has_sharding()) { + return !other->has_sharding(); + } + return other->has_sharding() ? sharding() == other->sharding() : false; + } + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. 
+ const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain @@ -1127,9 +1334,6 @@ class HloInstruction { return fft_length_; } - // Returns the dump string of the convolution dimension numbers. - string ConvolutionDimensionNumbersToString() const; - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1160,20 +1364,15 @@ class HloInstruction { // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. If the module - // pointer is not nullptr, it will be the module where the cloned computations - // will be added to (in order to support deep cloning). Ignores the control - // predecessors and successors of this HLO instruction. - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr) const; - - // Clones the HLO instruction as above but with new shape and operands. If - // the module pointer is not nullptr, it will be the module where the cloned - // computations will be added to (in order to support deep cloning). Ignores - // the control predecessors and successors of this HLO instruction. + // the instruction to form the name of the cloned instruction. + // Ignores the control predecessors and successors of this HLO instruction. + std::unique_ptr Clone( + const string& suffix = "clone", HloCloneContext* context = nullptr) const; + + // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1245,7 +1444,7 @@ class HloInstruction { // Gets/sets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } + void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1262,6 +1461,40 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } + // Returns the backend-specific configuration for how a backend should compile + // this HLO. The meaning of the field is backend specific. Not for use before + // or during general HLO optimization, since HLO optimizations do not preserve + // this field and they cannot interpret it due to its meaning being backend + // specific. + // + // ConfigProto should be a protobuf Message type. + template + StatusOr backend_config() const { + ConfigProto proto; + TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); + return std::move(proto); + } + Status set_backend_config(const tensorflow::protobuf::Message& proto); + + // Getter/setter for raw JSON-encoded backend config. Prefer the + // functions above that deal in proto Messages where possible. 
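+ //
+ // Sketch of the typical proto-based usage (MyBackendConfig and its field are
+ // hypothetical, backend-defined names):
+ //   MyBackendConfig config;
+ //   config.set_unroll_factor(4);
+ //   TF_RETURN_IF_ERROR(instr->set_backend_config(config));
+ //   TF_ASSIGN_OR_RETURN(MyBackendConfig parsed,
+ //                       instr->backend_config<MyBackendConfig>());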
+ const string& raw_backend_config_string() const { return backend_config_; } + void set_raw_backend_config_string(string config_str) { + backend_config_ = std::move(config_str); + } + + // Returns a string representation of a proto in the format used by + // raw_backend_config_string. + // + // This is morally equivalent to: + // + // HloInstruction instr; + // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); + // return instr.raw_backend_config_string(); + // + static StatusOr BackendConfigToRawString( + const tensorflow::protobuf::Message& proto); + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1283,6 +1516,7 @@ class HloInstruction { // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. + // // TODO(b/62783254) Replace these methods with a more general way to // annotate HLOs with backend-specific information. const std::vector& outer_dimension_partitions() const { @@ -1297,21 +1531,40 @@ class HloInstruction { void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); + protected: + // Internal constructor for a given opcode/shape, other fields must be filled + // by factory methods. + HloInstruction(HloOpcode opcode, const Shape& shape); + private: + // Prints an instruction to a string. + // + // The canonical string representation needs to name operands and instruction + // names in a consistent way. This is implemented through the + // canonical_name_map. + string ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Prints an operand to a string. + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and + // OperandsToStringWithCanonicalNameMap() functions. + friend class HloComputation; + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; // Helper class for computing OperandElementUse for kFusion. class FusionReusesParamElements; // See comments on Identical(). - // eq_shapes() is used to check shapes for equality, and would normally be - // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on - // whether we want a layout-sensitive check or not. bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const; + eq_computations) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( @@ -1328,10 +1581,6 @@ class HloInstruction { // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Internal constructor for a given opcode/shape, other fields must be filled - // by factory methods. - HloInstruction(HloOpcode opcode, const Shape& shape); - // Fuses the given instruction into this fusion instruction. When add_output // is false (which is the default), instruction_to_fuse is cloned and the // clone is placed in the fusion instruction. instruction_to_fuse is @@ -1358,7 +1607,7 @@ class HloInstruction { // Clones a fusion instruction with a new shape and operands. 
std::unique_ptr CloneFusionWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + HloCloneContext* context = nullptr) const; // Returns true if this instruction can legally have the dimensions field // set. Used for checking precondition of dimensions field accessors. @@ -1367,6 +1616,10 @@ class HloInstruction { // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Helper for implementing backend_config(). Parses backend_config_ into the + // given proto. + Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. @@ -1451,6 +1704,10 @@ class HloInstruction { // The sharding, if one exists. std::unique_ptr sharding_; + // Fields used by the kDomain instruction. + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; + // For parameter instructions this field holds the parameter number. int64 parameter_number_ = 0; @@ -1510,6 +1767,10 @@ class HloInstruction { // The string representation of the infeed configuration. string infeed_config_; + // The backend-specific configuration for how a backend should compile this + // HLO. See the documentation on backend_config(). + string backend_config_; + // String identifier for instruction. string name_; @@ -1532,6 +1793,9 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums); + StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); @@ -1540,13 +1804,20 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // an HloInstruction* or a const HloInstruction*. // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of -// the hlo. +// the hlo. Exception: null pointer values compare less than non-null. // // Note that this cannot be used for HLO instructions across multiple modules // since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. + return false; + } + if (lhs == nullptr) { + return true; + } return lhs->unique_id() < rhs->unique_id(); } }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index f2980d309d01fd..313033ddadce6a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,11 +24,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace { @@ -149,8 +151,8 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); EXPECT_THAT(foo->users(), UnorderedElementsAre(add)); @@ -186,8 +188,8 @@ TEST_F(HloInstructionTest, MultipleUsers) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -219,8 +221,8 @@ TEST_F(HloInstructionTest, RepeatedUser) { builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(1, foo->user_count()); @@ -254,8 +256,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); auto addtotal = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(addtotal->Accept(&visitor)); @@ -303,8 +305,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); auto neg2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(neg2->Accept(&visitor)); @@ -325,7 +327,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); - HloModule module(TestName()); + auto module = CreateNewModule(); // Builds an x+1.0 computation to use in a Map. 
auto embedded_builder = HloComputation::Builder("f32+1"); @@ -335,7 +337,7 @@ TEST_F(HloInstructionTest, TrivialMap) { HloInstruction::CreateConstant(Literal::CreateR0(1.0))); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); - auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); @@ -343,7 +345,7 @@ TEST_F(HloInstructionTest, TrivialMap) { HloInstruction::CreateParameter(0, f32a100x10, "")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); - module.AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(map->Accept(&visitor)); @@ -373,8 +375,8 @@ TEST_F(HloInstructionTest, TrivialReduce) { HloInstruction::CreateParameter(1, r0f32, "y")); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); - HloModule module(TestName()); - auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); + auto module = CreateNewModule(); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and an initial value and feeds them to the reduce. HloComputation::Builder builder(TestName()); @@ -387,7 +389,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { auto reduce = builder.AddInstruction( HloInstruction::CreateReduce(f32v100, param0, const0, /*dimensions_to_reduce=*/{1}, add_f32)); - module.AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(reduce->Accept(&visitor)); @@ -414,8 +416,8 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -449,8 +451,8 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); auto add_foobar = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar)); @@ -477,8 +479,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto log = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log)); @@ -514,8 +516,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - HloModule 
module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -544,8 +546,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); EXPECT_EQ(2, bar->user_count()); @@ -609,8 +611,8 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); NodeCollectorAndPostProcessor visitor; ASSERT_IS_OK(add->Accept(&visitor)); @@ -627,8 +629,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); @@ -645,8 +647,8 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloInstruction::CreateConstant(Literal::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add}, HloInstruction::FusionKind::kLoop); @@ -667,8 +669,8 @@ TEST_F(HloInstructionTest, ChainFusionOp) { auto exp3 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -690,8 +692,8 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { exp1->set_metadata(metadata); exp2->set_metadata(metadata); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -746,13 +748,13 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. 
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - HloModule module(TestName()); + auto module = CreateNewModule(); auto make_map_computation = [&]() { auto builder = HloComputation::Builder("FusionMap"); builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape, "param")); - return module.AddEmbeddedComputation(builder.Build()); + return module->AddEmbeddedComputation(builder.Build()); }; HloComputation* computation_x = make_map_computation(); @@ -767,7 +769,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{})); auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); - auto* computation = module.AddEntryComputation(builder.Build()); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {map_3_y}, HloInstruction::FusionKind::kLoop); @@ -814,8 +816,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -940,8 +942,8 @@ TEST_F(HloInstructionTest, FunctionVisitor) { HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); int visit_num = 0; std::unordered_map visit_order; @@ -969,8 +971,8 @@ TEST_F(HloInstructionTest, FullyElementwise) { builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_TRUE(add->IsElementwise()); for (int i = 0; i < add->operand_count(); ++i) { @@ -1013,8 +1015,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -1056,8 +1058,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -1099,10 +1101,10 @@ 
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); + {dot, reshape}, HloInstruction::FusionKind::kLoop); auto fusion2 = fusion->Clone(); const HloInstruction* root = fusion->fused_expression_root(); @@ -1118,7 +1120,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { } TEST_F(HloInstructionTest, FusionEquality) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create two fusion instructions containing a single unary operation. @@ -1128,7 +1130,7 @@ TEST_F(HloInstructionTest, FusionEquality) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter)); auto neg = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter)); - auto* computation = module.AddEntryComputation(builder.Build()); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); auto* fusion2 = computation->CreateFusionInstruction( @@ -1140,7 +1142,7 @@ TEST_F(HloInstructionTest, FusionEquality) { } TEST_F(HloInstructionTest, NestedFusionEquality) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Build a nested fusion computation. @@ -1166,10 +1168,10 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { data_shape, HloOpcode::kSubtract, dot, add_operand)); builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub)); - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); auto nested_fusion = computation->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + {dot, b_t}, HloInstruction::FusionKind::kLoop); auto fusion = computation->CreateFusionInstruction( {add, nested_fusion}, HloInstruction::FusionKind::kOutput); @@ -1244,15 +1246,8 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); - HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - - EXPECT_EQ( - fusion->ToString(options), - "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); @@ -1295,8 +1290,8 @@ TEST_F(HloInstructionTest, StringifyGather_0) { /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), "%gather = 
f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " @@ -1331,8 +1326,8 @@ TEST_F(HloInstructionTest, StringifyGather_1) { /*index_vector_dim=*/2), /*window_bounds=*/{30, 29, 28, 27, 26})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " @@ -1343,5 +1338,275 @@ TEST_F(HloInstructionTest, StringifyGather_1) { "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } +TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto options = HloPrintOptions().Canonical(); + + EXPECT_EQ(dot->ToString(options), + "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_EQ( + fusion->ToString(options), + R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { + // Tests stringification of a simple op, fusion, while, and conditional. 
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + HloInstruction* loop = builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ(loop->ToString(options), + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { + // Tests stringification of a simple op, fusion, while, and conditional. 
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ( + conditional->ToString(options), + R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, false_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + +TEST_F(HloInstructionTest, CheckDeepClone) { + const char* const hlo_string = R"( +HloModule Module + +addy (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT zadd = s32[] add(lhs, rhs) +} + +calla (x: s32[]) -> s32[] { + x = s32[] parameter(0) + reduce = s32[] reduce-window(x, x), to_apply=addy + ROOT xadd = s32[] add(x, reduce) +} + +body (bparam: s32[]) -> s32[] { + constant = s32[] constant(1) + bparam = s32[] parameter(0) + v = s32[] call(bparam), to_apply=calla + ROOT add = s32[] add(constant, bparam) +} + +condition (cparam: s32[]) -> pred[] { + xconstant = s32[] constant(5) + cparam = s32[] parameter(0) + ROOT greater-than = pred[] greater-than(xconstant, cparam) +} + +ENTRY entry (param: s32[]) -> s32[] { + eparam = s32[] parameter(0) + ROOT while = s32[] while(eparam), condition=condition, body=body + } +)"; + // Check that deep clones really deep clones every instruction and + // computations, without leaving dangling pointers to the old module. 
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_string));
+  std::unique_ptr<HloModule> clone = module->Clone();
+  for (HloComputation* computation : clone->computations()) {
+    EXPECT_EQ(computation->parent(), clone.get());
+    for (HloInstruction* instruction : computation->instructions()) {
+      EXPECT_EQ(instruction->parent()->parent(), clone.get());
+    }
+  }
+}
+
+TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) {
+  const Shape shape = ShapeUtil::MakeShape(F32, {42});
+  HloComputation::Builder builder("test");
+  HloInstruction* p =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
+
+  HloInstruction* add1 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p));
+  HloInstruction* add2 = builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p));
+
+  EXPECT_TRUE(add1->Identical(*add2));
+  add1->set_raw_backend_config_string("abc");
+  EXPECT_FALSE(add1->Identical(*add2));
+}
+
+TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) {
+  auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+                                                 /*operands=*/{},
+                                                 /*custom_call_target=*/"foo");
+  auto instr2 = instr1->Clone();
+  EXPECT_TRUE(instr1->Identical(*instr2));
+
+  Window w = window_util::MakeWindow({1, 2, 3});
+  instr1->set_window(w);
+  EXPECT_FALSE(instr1->Identical(*instr2));
+}
+
+TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) {
+  auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+                                                 /*operands=*/{},
+                                                 /*custom_call_target=*/"foo");
+  auto instr2 = instr1->Clone();
+  EXPECT_TRUE(instr1->Identical(*instr2));
+
+  ConvolutionDimensionNumbers dnums;
+  dnums.set_output_batch_dimension(42);
+  instr1->set_convolution_dimension_numbers(dnums);
+  EXPECT_FALSE(instr1->Identical(*instr2));
+}
+
+TEST_F(HloInstructionTest, CloneWindowOnCustomCall) {
+  auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+                                                /*operands=*/{},
+                                                /*custom_call_target=*/"foo");
+  Window w = window_util::MakeWindow({1, 2, 3});
+  instr->set_window(w);
+  auto clone = instr->Clone();
+  EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w))
+      << clone->window().DebugString();
+}
+
+TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
+  auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+                                                /*operands=*/{},
+                                                /*custom_call_target=*/"foo");
+  ConvolutionDimensionNumbers dnums;
+  dnums.set_output_batch_dimension(42);
+  instr->set_convolution_dimension_numbers(dnums);
+  auto clone = instr->Clone();
+  EXPECT_TRUE(protobuf_util::ProtobufEquals(
+      clone->convolution_dimension_numbers(), dnums))
+      << clone->convolution_dimension_numbers().DebugString();
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
similarity index 94%
rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
rename to tensorflow/compiler/xla/service/hlo_lexer.cc
index fc0e4444521247..f0d9fdbc8f86da 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
+#include "tensorflow/compiler/xla/service/hlo_lexer.h"
 #include
@@ -26,9 +26,8 @@ limitations under the License.
#include "tensorflow/core/platform/regexp.h" namespace xla { -namespace tools { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; namespace { @@ -67,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -StringPiece HloLexer::StringPieceFromPointers(const char* begin, - const char* end) const { +tensorflow::StringPiece HloLexer::StringPieceFromPointers( + const char* begin, const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return StringPiece(begin, end - begin); + return tensorflow::StringPiece(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -197,7 +196,8 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + tensorflow::StringPiece identifier = + StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. #define KEYWORD(STR) \ @@ -230,7 +230,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = identifier.ToString(); + str_val_ = std::string(identifier); return TokKind::kIdent; } @@ -332,23 +332,24 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == StringPiece::npos) { + if (line_offset == tensorflow::StringPiece::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -StringPiece HloLexer::GetLine(LocTy loc) const { +tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == StringPiece::npos + const char* start = line_start == tensorflow::StringPiece::npos ? buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); - const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + const char* end = + line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -370,7 +371,7 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - StringPiece raw = + tensorflow::StringPiece raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { @@ -453,5 +454,4 @@ string TokKindToString(TokKind kind) { } } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h similarity index 90% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.h rename to tensorflow/compiler/xla/service/hlo_lexer.h index 27880b9b8afbfa..ceb674f25e94ac 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ #include -#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -27,9 +27,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Lexer for the HloModule::ToString() format text. +// +// This class is meant to be used by hlo_parser.cc. You shouldn't need to use +// it directly. class HloLexer { public: explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { @@ -57,7 +59,7 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - int64 GetInt64Val() const { + tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; } @@ -114,7 +116,7 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - int64 int64_val_; + tensorflow::int64 int64_val_; double decimal_val_; struct LineNoCacheTy { @@ -125,7 +127,6 @@ class HloLexer { mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc new file mode 100644 index 00000000000000..43c41ece6efc4f --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -0,0 +1,306 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using Worklist = std::deque; +using Workset = std::unordered_set; + +namespace { + +void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, + Workset* workset) { + if (workset->count(instruction) == 0) { + worklist->push_back(instruction); + workset->insert(instruction); + VLOG(3) << "ADD instruction: " << instruction->name(); + } +} + +using VisitorFunction = std::function; + +void ForEachLiveIndex(const ShapeTree& index_tree, + const VisitorFunction& func) { + index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) { + if (live) { + func(shape_index); + } + }); +} + +// Marks 'instruction' output live at 'shape_index'. +// Adds to 'worklist' iff: +// *) 'instruction' is not already on worklist. +// *) 'shape_index' has not yet been visited. +void MarkLiveAtIndex(const HloInstruction* instruction, + const ShapeIndex& shape_index, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + auto it_added = live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + if (it->second.element(shape_index) == false) { + AddToWorklist(instruction, worklist, workset); + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } +} + +// Marks 'instruction' live at all shape indices in its output. +void MarkLiveAtAllIndices(const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + bool add_to_worklist = false; + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/true)); + add_to_worklist = true; + } else { + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& sub_shape, const ShapeIndex& shape_index) { + if (it->second.element(shape_index) == false) { + add_to_worklist = true; + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } + }); + } + if (add_to_worklist) { + AddToWorklist(instruction, worklist, workset); + } +} + +// Propagates liveness through Tuple instructions. 
+// *) For each tuple operand: +// *) For tuple output shape index associated with operand: +// *) Propgate live shape indices to tuple operand at the associated +// shape index in the operands output, and add to worklist. +void PropagateLivenessThroughTuple( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kTuple); + for (int64 operand_index = 0; operand_index < instruction->operand_count(); + ++operand_index) { + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + if (shape_index.empty() || shape_index[0] != operand_index) { + return; + } + // Mark top-level index of operand at 'operand_index'. + MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map, + worklist, workset); + // Mark sub-shape index of operand at 'operand_index'. + ShapeIndex operand_shape_index; + for (int i = 1; i < shape_index.size(); ++i) { + operand_shape_index.push_back(shape_index[i]); + } + MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index, + live_index_map, worklist, workset); + }); + } +} + +// Propagates liveness through GetTupleElement instructions. +// *) For each live index in GetTupleElement output, mark output of GTE operand +// at associated shape index in its output, and add to worklist. +void PropagateLivenessThroughGTE( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement); + // Mark operand top-level index. + MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist, + workset); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + // Propagate live shape indices along GTE -> Tuple edge. + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + ShapeIndex operand_shape_index(shape_index); + operand_shape_index.push_front(instruction->tuple_index()); + MarkLiveAtIndex(instruction->operand(0), operand_shape_index, + live_index_map, worklist, workset); + }); +} + +// Propagates liveness through While instructions. +// *) For each live index in While output, mark shape index of while.body.root +// and while.operand (adding each to worklist). +// *) Mark while.cond.root and add to worklist. +void PropagateLivenessThroughWhile( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset) { + CHECK_EQ(instruction->opcode(), HloOpcode::kWhile); + const ShapeTree& index_tree = FindOrDie(*live_index_map, instruction); + + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while body computation root instruction. + MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index, + live_index_map, worklist, workset); + // Propagate liveness to tuple-shaped operand. + MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map, + worklist, workset); + }); + + // Propagate liveness to while condition computation root instruction. + MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); +} + +// Propagates liveness out of Parameter instructions to callers and aliasing +// positions. 
This can occur if liveness propagates to a parameter in the +// while.condition computation, requiring liveness to propagate out to caller +// callsite while (and while.body.root). +void PropagateLivenessToParameterCallers( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + CHECK_EQ(instruction->opcode(), HloOpcode::kParameter); + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + auto* xla_while = callsite.instruction(); + const ShapeTree& index_tree = + FindOrDie(*live_index_map, instruction); + ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) { + // Propagate liveness to while result{shape_index} + MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist, + workset); + // Propagate liveness to while body root{shape_index}. + MarkLiveAtIndex(xla_while->while_body()->root_instruction(), + shape_index, live_index_map, worklist, workset); + // Propagate liveness to operand(0){shape_index}. + MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map, + worklist, workset); + }); + } + } + } +} + +} // namespace + +HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) + : module_(module), call_graph_(CallGraph::Build(&module)) {} + +// Runs liveness analysis on 'module_'. +// Initializes worklist with entry root instruction (and any instruction with +// side-effects), marking all of their output shape indices live. +// Visits elements on worklist, propagating liveness from an instructions +// live output shape indices to its called computations and operands. +void HloLivenessAnalysis::RunAnalysis() { + Worklist worklist; + Workset workset; + // Add entry compuation root instruction. + MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(), + &live_index_map_, &worklist, &workset); + for (auto* computation : module_.computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->HasSideEffectNoRecurse()) { + // Add instructions with side effects. + MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist, + &workset); + } + } + } + + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop_front(); + workset.erase(workset.find(instruction)); + VLOG(1) << "VISIT instruction: " << instruction->name(); + + if (instruction->opcode() == HloOpcode::kTuple) { + PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { + PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kWhile && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, + &workset); + } else if (instruction->opcode() == HloOpcode::kParameter && + ShapeUtil::IsTuple(instruction->shape())) { + PropagateLivenessToParameterCallers(instruction, &live_index_map_, + &worklist, &workset, + call_graph_.get()); + } else { + // Propagate liveness to called computations. 
+ for (auto* called_computation : instruction->called_computations()) { + MarkLiveAtAllIndices(called_computation->root_instruction(), + &live_index_map_, &worklist, &workset); + } + // Propagate liveness to operands. + for (HloInstruction* operand : instruction->operands()) { + MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); + } + } + } +} + +bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const { + if (ContainsKey(live_index_map_, instruction)) { + return FindOrDie(live_index_map_, instruction).element(shape_index); + } + return false; +} + +/* static */ +StatusOr> HloLivenessAnalysis::Run( + const HloModule& module) { + VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); + + auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + + liveness_analysis->RunAnalysis(); + + return std::move(liveness_analysis); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h new file mode 100644 index 00000000000000..fe55a8070a42a3 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in +// an HLO module. +// +// HloLivenessAnalysis marks the shape index of each live output of each +// instruction in the module, by propagating live shape index information +// from an instruction to its called computations and operands. +class HloLivenessAnalysis { + public: + // Maps from an HloInstruction to its live/dead output shape indices. + using HloIndexMap = + std::unordered_map>; + + // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object + // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. + static StatusOr> Run( + const HloModule& module); + + // Returns true if output of 'instruction' at 'shape_index' is live. + // Returns false otherwise. 
+ bool IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const; + + private: + HloLivenessAnalysis(const HloModule& module); + + void RunAnalysis(); + + const HloModule& module_; + std::unique_ptr call_graph_; + HloIndexMap live_index_map_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc new file mode 100644 index 00000000000000..0275294a1a86ce --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -0,0 +1,402 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class HloLivenessAnalysisTest : public HloTestBase { + protected: + HloLivenessAnalysisTest() {} + + // Run liveness analysis on the member module. For convenience returns a + // reference to the generated analysis stored in analysis_. + const HloLivenessAnalysis& RunLiveness(HloModule* module) { + liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie(); + return *liveness_; + } + + HloInstruction* GetInstruction(HloModule* module, const string& name) { + HloInstruction* to_return = nullptr; + for (auto* comp : module->computations()) { + for (auto* inst : comp->instructions()) { + if (inst->name() == name) { + to_return = inst; + break; + } + } + } + return CHECK_NOTNULL(to_return); + } + + std::unique_ptr liveness_; +}; + +// Test that add instruction at entry root is live at all output shape indices. +TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT add = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Test that a dead add instruction is marked as dead by analysis. 
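In the `DeadAdd` test below, `add.1` has no users and is not the entry root, so no liveness ever reaches it. A consumer of the analysis would typically walk the module and collect such instructions; the following is a minimal sketch (the function name is hypothetical) that uses only calls appearing elsewhere in this change, namely `computations()`, `instructions()`, and `IsLive`:

```c++
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {

// Sketch: gather instructions whose top-level output is dead according to the
// analysis. `liveness` is the analysis produced by
// HloLivenessAnalysis::Run(module), as in the tests in this file.
std::vector<const HloInstruction*> CollectDeadInstructions(
    const HloModule& module, const HloLivenessAnalysis& liveness) {
  std::vector<const HloInstruction*> dead;
  for (const HloComputation* computation : module.computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (!liveness.IsLive(instruction, /*shape_index=*/{})) {
        dead.push_back(instruction);
      }
    }
  }
  return dead;
}

}  // namespace xla
```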
+TEST_F(HloLivenessAnalysisTest, DeadAdd) {
+  auto module = ParseHloString(R"(
+  HloModule SimpleModule
+  ENTRY SimpleComputation {
+    constant.1 = s32[] constant(0)
+    constant.2 = s32[] constant(1)
+    add.1 = s32[] add(constant.1, constant.2)
+    ROOT add.2 = s32[] add(constant.1, constant.2)
+  })")
+                    .ValueOrDie();
+  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+  EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {}));
+}
+
+// Test that all output shape indices of the entry root tuple (and the
+// instructions defining its elements) are marked live.
+TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
+  auto module = ParseHloString(R"(
+  HloModule SimpleModule
+  ENTRY SimpleComputation {
+    constant.1 = s32[] constant(0)
+    constant.2 = s32[] constant(1)
+    ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2)
+  })")
+                    .ValueOrDie();
+  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+}
+
+// Tests that all outputs of a nested tuple at the entry root (and the
+// instructions defining the values appearing in its output) are marked live.
+TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
+  auto module = ParseHloString(R"(
+  HloModule SimpleModule
+  ENTRY SimpleComputation {
+    constant.1 = s32[] constant(1)
+    constant.2 = s32[] constant(2)
+    constant.3 = s32[] constant(3)
+    tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
+    ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
+  })")
+                    .ValueOrDie();
+  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
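The `{0}`, `{1}`, `{1, 0}`-style arguments in these expectations are `ShapeIndex` values addressing elements of (possibly nested) tuple shapes. For reference, a small sketch of how such an index resolves against a nested tuple shape, using `ShapeUtil` calls that appear elsewhere in this change (the function and variable names are illustrative only):

```c++
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void ShapeIndexExample() {
  // A nested tuple shape (s32[], (s32[], s32[])), like tuple.2 in these tests.
  const Shape nested = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}),
       ShapeUtil::MakeTupleShape(
           {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})})});

  // {} addresses the whole tuple, {1} the inner tuple, and {1, 0} the first
  // element of the inner tuple; liveness is tracked per such index.
  const Shape& inner = ShapeUtil::GetTupleElementShape(nested, 1);
  const Shape& leaf = ShapeUtil::GetSubshape(nested, ShapeIndex({1, 0}));
  CHECK(ShapeUtil::IsTuple(inner));
  CHECK(!ShapeUtil::IsTuple(leaf));
}

}  // namespace xla
```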
+// Tests that a GTE of a tuple at the entry root only propagates liveness to
+// the live elements of the tuple.
+TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
+  auto module = ParseHloString(R"(
+  HloModule SimpleModule
+  ENTRY SimpleComputation {
+    constant.1 = s32[] constant(0)
+    constant.2 = s32[] constant(1)
+    tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2)
+    ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0
+  })")
+                    .ValueOrDie();
+  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+  EXPECT_TRUE(
+      liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+  EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+  EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+}
+
+// Tests that a GTE of a nested tuple at the entry root only propagates
+// liveness to the live elements of the tuple.
+TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
+  auto module = ParseHloString(R"(
+  HloModule SimpleModule
+  ENTRY SimpleComputation {
+    constant.1 = s32[] constant(0)
+    constant.2 = s32[] constant(1)
+    constant.3 = s32[] constant(2)
+    tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3)
+    tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1)
+    ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1
+  })")
+                    .ValueOrDie();
+  const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+  EXPECT_TRUE(
+      liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(
+      GetInstruction(module.get(), "get-tuple-element.1"), {0}));
+  EXPECT_TRUE(liveness.IsLive(
+      GetInstruction(module.get(), "get-tuple-element.1"), {1}));
+
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {}));
+  EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1}));
+
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1}));
+
+  EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {}));
+  EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+// Tests that a GTE of a GTE (at the entry root) of a nested tuple only
+// propagates liveness to the live elements of the tuple.
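The GTE-of-GTE case described above chains `PropagateLivenessThroughGTE` twice: each `get-tuple-element` prepends its `tuple_index()` to the live index before marking its operand. A sketch of that index arithmetic for the module in the next test, wrapped in a hypothetical helper (instruction names refer to the HLO text of that test):

```c++
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Index arithmetic performed for the chained GTEs in the next test:
// get-tuple-element.2 (index=0) is live at {}; its operand
// get-tuple-element.1 (index=1 into tuple.2) becomes live at {0}; and
// tuple.2 becomes live at {1, 0}, matching the expectations below.
ShapeIndex ChainedGteOperandIndex() {
  ShapeIndex index;     // live index of get-tuple-element.2: {}
  index.push_front(0);  // as an index into get-tuple-element.1: {0}
  index.push_front(1);  // as an index into tuple.2: {1, 0}
  return index;
}

}  // namespace xla
```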
+TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { + auto module = ParseHloString(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.2"), {})); + + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_FALSE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_FALSE( + liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Test that live/dead while tuple elements are marked live/dead correctly. +TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.4"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. 
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a tuple element live in while.cond computation, propagates +// liveness to while.body.root/while.result/while.operand (where it is unused). +TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 + add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(add.1, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.4"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a use of while.result{0} propagates liveness to +// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. 
+TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2 + multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3) + ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.1 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + } + ENTRY SimpleLoop { + constant.2 = s32[] constant(0) + constant.3 = s32[] constant(1) + constant.4 = s32[] constant(2) + tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4) + while.1 = (s32[], s32[], s32[]) while(tuple.2), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0 + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {2})); + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {2})); + // While body root. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {2})); + // While body param. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 69deac263ee58f..7e4b8834357d39 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -17,10 +17,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { +using ::tensorflow::str_util::Join; + bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -195,6 +198,41 @@ void HloShardingMatcher::DescribeTo(std::ostream* os) const { } } +bool HloDotWithContractingDimsMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + + const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); + if (dim_nums.lhs_contracting_dimensions_size() != 1 || + dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong lhs_contracting_dimensions (got {" + << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" + << lhs_contracting_dim_ << "})"; + return false; + } + + if (dim_nums.rhs_contracting_dimensions_size() != 1 || + dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong rhs_contracting_dimensions (got {" + << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" + << rhs_contracting_dim_ << "})"; + return false; + } + + return true; +} + +void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const { + HloMatcher::DescribeTo(os); + *os << " with lhs_contracting_dims={" << lhs_contracting_dim_ + << "} and rhs_contracting_dims={" << rhs_contracting_dim_ << "}"; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index f2ab9b5d9b6e00..c570b420c21fed 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -131,6 +132,27 @@ class HloShardingMatcher tensorflow::gtl::optional sharding_; }; +// Matches a Dot HLO instruction with specific LHS and RHS contracting +// dimensions. +class HloDotWithContractingDimsMatcher : public HloMatcher { + public: + explicit HloDotWithContractingDimsMatcher( + ::testing::Matcher lhs, + ::testing::Matcher rhs, int64 lhs_contracting_dim, + int64 rhs_contracting_dim) + : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}), + lhs_contracting_dim_(lhs_contracting_dim), + rhs_contracting_dim_(rhs_contracting_dim) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + int64 lhs_contracting_dim_; + int64 rhs_contracting_dim_; +}; + // HloInstruction* matchers for opcode and operands. 
Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -158,7 +180,6 @@ HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); HLO_MATCHER(Divide); -HLO_MATCHER(Dot); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -282,11 +303,21 @@ inline ::testing::Matcher Shape( const class Shape& shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); } +inline ::testing::Matcher Shape( + tensorflow::StringPiece shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( + ShapeUtil::ParseShapeString(shape).ValueOrDie())); +} inline ::testing::Matcher ShapeWithLayout( const class Shape& shape) { return ::testing::MakeMatcher( new ::xla::testing::HloShapeAndLayoutMatcher(shape)); } +inline ::testing::Matcher ShapeWithLayout( + tensorflow::StringPiece shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( + ShapeUtil::ParseShapeString(shape).ValueOrDie())); +} // Verifies the value of the HloSharing against the provided sharding object. inline ::testing::Matcher Sharding( @@ -294,12 +325,42 @@ inline ::testing::Matcher Sharding( return ::testing::MakeMatcher( new ::xla::testing::HloShardingMatcher(sharding)); } +// Matcher for Sharding from sharding string +inline ::testing::Matcher Sharding( + tensorflow::StringPiece sharding) { + return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( + ParseSharding(sharding).ValueOrDie())); +} // Verifies that no HloSharding is set for an HLO instruction. inline ::testing::Matcher NoSharding() { return ::testing::MakeMatcher( new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); } +inline ::testing::Matcher Dot( + ::testing::Matcher lhs_matcher, + ::testing::Matcher rhs_matcher) { + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher})); +} + +// Matches a Dot HLO instruction if it has exactly one lhs contracting dimension +// equal to `lhs_contracting_dim` and exactly one rhs contracting dimension +// equal to `rhs_contracting_dim`. +// +// Currently the HLO verifier rejects Dot operations with more than one +// contracting dimension (even though we can represent these in the +// DotDimensionNumbers proto) so there is no need to generalize this to support +// multiple contracting dimensions. +inline ::testing::Matcher Dot( + ::testing::Matcher lhs_matcher, + ::testing::Matcher rhs_matcher, + int64 lhs_contracting_dim, int64 rhs_contracting_dim) { + return ::testing::MakeMatcher( + new ::xla::testing::HloDotWithContractingDimsMatcher( + lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim)); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index c6373b2e46af7d..9a3010cf1ff75e 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" namespace op = xla::testing::opcode_matchers; @@ -105,21 +106,28 @@ TEST(HloMatchersTest, ShapeMatcher) { 0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param"); EXPECT_THAT(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {5, 7}))); + EXPECT_THAT(p0.get(), op::Shape("f32[5,7]")); EXPECT_THAT( p0.get(), ::testing::Not(op::ShapeWithLayout(ShapeUtil::MakeShape(F32, {5, 7})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[5,7]"))); EXPECT_THAT(p0.get(), ::testing::Not(op::Shape(ShapeUtil::MakeShape(F32, {7, 5})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::Shape("f32[7,5]"))); EXPECT_THAT( p0.get(), ::testing::Not(op::ShapeWithLayout(ShapeUtil::MakeShape(F32, {7, 5})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[7,5]"))); EXPECT_THAT(p0.get(), op::Shape(ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}))); + EXPECT_THAT(p0.get(), op::Shape("f32[5,7]{0,1}")); EXPECT_THAT(p0.get(), op::ShapeWithLayout(ShapeUtil::MakeShapeWithLayout( F32, {5, 7}, {0, 1}))); + EXPECT_THAT(p0.get(), op::ShapeWithLayout("f32[5,7]{0,1}")); EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout( ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {1, 0})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[5,7]{1,0}"))); EXPECT_THAT(Explain(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {7, 5}))), "%param = f32[5,7]{0,1} parameter(0) has incorrect shape " @@ -139,6 +147,18 @@ TEST(HloMatchersTest, ShardingMatcher) { "param.1"); p1->set_sharding(HloSharding::AssignDevice(1)); + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}), + ShapeUtil::MakeShape(F32, {11})}); + auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2"); + Array assignment({2}); + assignment.SetValues({0, 1}); + auto sharding = HloSharding::Tuple( + tuple_shape, + {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), + HloSharding::AssignDevice(1), HloSharding::Replicate()}); + p2->set_sharding(sharding); + EXPECT_THAT(p0.get(), op::NoSharding()); EXPECT_THAT(p0.get(), ::testing::Not(op::Sharding(HloSharding::AssignDevice(1)))); @@ -147,6 +167,11 @@ TEST(HloMatchersTest, ShardingMatcher) { ::testing::Not(op::Sharding(HloSharding::AssignDevice(0)))); EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1))); + EXPECT_THAT( + p2.get(), + op::Sharding( + "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " "{maximal device=1})"); @@ -158,5 +183,41 @@ TEST(HloMatchersTest, ShardingMatcher) { "has incorrect sharding (expected: {maximal device=0})"); } +TEST(HloMatchersTest, DotMatcher) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[1,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + HloInstruction* root = module->entry_computation()->root_instruction(); + + EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, 
+ /*rhs_contracting_dim=*/0)); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/0, + /*rhs_contracting_dim=*/0)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "lhs_contracting_dimensions (got {1} want {0})"); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "rhs_contracting_dimensions (got {0} want {1})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 08b9a29aeda2ee..e63424c2dfb6c7 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -41,14 +41,23 @@ HloModule::HloModule(const string& name, entry_computation_handle_(entry_computation_handle), unique_id_(next_unique_module_id_++) {} -HloModule::HloModule(const string& name) - : name_(NameUniquer::GetSanitizedName(name)), - unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), unique_id_(next_unique_module_id_++) {} +StatusOr HloModule::LaunderConstInstructionFromModule( + const HloInstruction* hlo) { + if (hlo == nullptr) { + return nullptr; + } + + TF_RET_CHECK(hlo->GetModule() == this); + + // TODO(b/78350259): Eliminate const laundering. + return const_cast(hlo); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_names) { @@ -58,7 +67,7 @@ HloComputation* HloModule::AddComputationInternal( // If the module configuration has no entry layout computation set, create a // default one based on the program shape. - if (!config_.has_entry_computation_layout()) { + if (!config_.has_host_entry_computation_layout()) { config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } @@ -232,11 +241,14 @@ StatusOr> HloModule::CreateFromProto( TF_RET_CHECK(proto.has_program_shape()) << "No program shape found in the proto"; const auto& expected_program_shape = proto.program_shape(); - TF_RET_CHECK(expected_program_shape.parameters_size() == - module_config.entry_computation_layout().parameter_count()); + TF_RET_CHECK( + expected_program_shape.parameters_size() == + module_config.device_entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = - module_config.entry_computation_layout().parameter_layout(i).shape(); + module_config.device_entry_computation_layout() + .parameter_layout(i) + .shape(); TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), parameter_shape)) << "HloModuleConfig has different shape for parameter " << i @@ -246,7 +258,7 @@ StatusOr> HloModule::CreateFromProto( << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); } const Shape& result_shape = - module_config.entry_computation_layout().result_layout().shape(); + module_config.device_entry_computation_layout().result_layout().shape(); TF_RET_CHECK( ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. 
" @@ -254,24 +266,44 @@ StatusOr> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> computations; + HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto(computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); int64 computation_id = computation_proto.id(); TF_RET_CHECK(computation_id != -1); TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); + computation_map[computation_id] = computation.get(); + to_proto_id[computation.get()] = computation_id; + if (computation_id == proto.entry_computation_id()) { + entry = computation.get(); + } + computations.push_back(std::move(computation)); + } + TF_RET_CHECK(entry != nullptr); + + auto module = MakeUnique(proto.name(), entry_computation_handle, + module_config); + + // Sort the computations in the proto id's order. + std::sort(computations.begin(), computations.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + + // Add sorted computations to the module. + for (auto& computation : computations) { + bool is_entry = computation.get() == entry; // Don't uniquify names because we want names to be stable across // serialization and deserialization. - computation_map[computation_id] = module->AddComputationInternal( - std::move(computation), - /*is_entry=*/proto.entry_computation_id() == computation_id, - /*uniquify_names=*/false); + module->AddComputationInternal(std::move(computation), is_entry, + /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); @@ -306,7 +338,7 @@ StatusOr HloModule::CreateModuleConfigFromProto( // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. 
ComputationLayout* entry_layout = - module_config.mutable_entry_computation_layout(); + module_config.mutable_host_entry_computation_layout(); for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { TF_RETURN_IF_ERROR( entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -314,6 +346,8 @@ StatusOr HloModule::CreateModuleConfigFromProto( } TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( program_shape.result())); + *module_config.mutable_device_entry_computation_layout() = + module_config.host_entry_computation_layout(); return module_config; } @@ -462,7 +496,18 @@ std::list HloModule::MakeComputationPostOrder() const { added_computations.insert(computation.get()); } } - CHECK_EQ(post_order.size(), computations_.size()); + if (post_order.size() != computations_.size()) { + for (HloComputation* computation : post_order) { + LOG(ERROR) << "Post Order: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + for (auto& computation : computations_) { + LOG(ERROR) << "Computations: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size() + << " computation_count=" << computations_.size(); + } return post_order; } @@ -479,59 +524,29 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique(name_ + "-" + suffix); - module->config_ = config_; + auto module = MakeUnique(name_ + "-" + suffix, config_); module->entry_computation_handle_ = entry_computation_handle_; module->has_entry_computation_handle_ = has_entry_computation_handle_; - std::unordered_map clone_map; - for (auto& computation : computations_) { - if (computation->IsFusionComputation()) { - // Cloning of a fused computation is handled by its fusion instruction. - continue; - } - - // When cloning a computation, pass in the new module, so that for any - // fusion instruction in this computation, the fused computation will be - // deep cloned to the new module. - auto cloned_computation = computation->Clone(suffix, module.get()); - InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); - - if (entry_computation_ == computation.get()) { - module->AddEntryComputation(std::move(cloned_computation)); - } else { - module->AddEmbeddedComputation(std::move(cloned_computation)); - } - } - - for (auto& cloned_computation : module->computations_) { - for (auto* instruction : cloned_computation->instructions()) { - // Rewrite instruction's called_computation to point to the cloned - // computations. - instruction->ReplaceCalledComputations([&](HloComputation* hlo) { - if (hlo->IsFusionComputation()) { - // Cloning of a fused computation has already been handled when its - // fusion instruction is cloned. So this hlo computation is already - // the cloned one. 
- return hlo; - } - return FindOrDie(clone_map, hlo); - }); - } - } + HloCloneContext context(module.get(), suffix); + auto cloned_computation = entry_computation_->Clone(suffix, &context); + module->AddEntryComputation(std::move(cloned_computation)); return module; } -HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { - HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); - TF_CHECK_OK( - clone->root_instruction()->Accept([this](HloInstruction* instruction) { - instruction->ReplaceCalledComputations([this](HloComputation* callee) { - return DeepCloneComputation(callee); - }); - return Status::OK(); - })); - return clone; +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation, + HloCloneContext* context) { + HloComputation* new_computation; + if (context != nullptr) { + if ((new_computation = context->FindComputation(computation)) != nullptr) { + return new_computation; + } + new_computation = + AddEmbeddedComputation(computation->Clone(context->suffix(), context)); + } else { + new_computation = AddEmbeddedComputation(computation->Clone("")); + } + return new_computation; } uint64 HloModule::RandomNew64() const { @@ -539,6 +554,14 @@ uint64 HloModule::RandomNew64() const { return rng_(); } +HloComputation* HloModule::GetComputationWithName( + tensorflow::StringPiece name) { + auto it = c_find_if(computations(), [&](HloComputation* computation) { + return computation->name() == name; + }); + return it == computations().end() ? nullptr : *it; +} + /* static */ std::atomic HloModule::next_unique_module_id_(0); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 9f7f25202ba42b..c93c74d34a95cf 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -26,12 +26,14 @@ limitations under the License. #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -41,10 +43,18 @@ namespace xla { // Describes a compilation unit at the HLO level. // -// A HLO module contains one or more HLO computations. The module contains one -// "entry" computation which produces the result. The module also includes any -// embedded computations used by instructions such as "map" and "reduce". All -// computations are owned by the module. +// HloModule is the top-level unit in the HLO IR. It corresponds to a whole +// "program". Running a module, from beginning to end, is the only way to run +// an XLA program. +// +// A module contains one "entry computation"; this HloComputation is like main() +// in a C program. The result of running the module is the result of running +// this computation. +// +// A module also contains some number of "nested computations". Each nested +// computation is attached to an HloInstruction within some other computation. 
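`Clone()` now routes everything through an `HloCloneContext`, and `DeepCloneComputation` consults the context first so a computation reachable from several callers is cloned only once. The sketch below shows that memoized-clone pattern with toy `Node`/`CloneContext` types; it illustrates the idea rather than the actual XLA API:

```c++
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Toy "computation" that may call other computations.
struct Node {
  std::string name;
  std::vector<Node*> callees;
};

// Hypothetical stand-in for HloCloneContext: remembers old -> new mappings.
class CloneContext {
 public:
  Node* Find(const Node* original) const {
    auto it = map_.find(original);
    return it == map_.end() ? nullptr : it->second;
  }
  void Map(const Node* original, Node* clone) { map_[original] = clone; }

 private:
  std::unordered_map<const Node*, Node*> map_;
};

// Deep-clones `node`, reusing clones already recorded in `context`.
Node* DeepClone(const Node* node, CloneContext* context,
                std::vector<std::unique_ptr<Node>>* storage) {
  if (Node* existing = context->Find(node)) {
    return existing;  // Already cloned via some other caller.
  }
  storage->push_back(std::make_unique<Node>());
  Node* clone = storage->back().get();
  clone->name = node->name + ".clone";
  context->Map(node, clone);
  for (const Node* callee : node->callees) {
    clone->callees.push_back(DeepClone(callee, context, storage));
  }
  return clone;
}

int main() {
  Node shared{"shared", {}};
  Node a{"a", {&shared}};
  Node b{"b", {&shared}};
  Node entry{"entry", {&a, &b}};

  std::vector<std::unique_ptr<Node>> storage;
  CloneContext context;
  Node* entry_clone = DeepClone(&entry, &context, &storage);

  // `shared` is cloned exactly once even though it is reachable twice.
  std::cout << (entry_clone->callees[0]->callees[0] ==
                entry_clone->callees[1]->callees[0])
            << "\n";
}
```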
+// The meaning of the nested computation depends on the instruction it's +// attached to. class HloModule { public: HloModule(const string& name, @@ -55,7 +65,6 @@ class HloModule { // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation // cache. A default configuration is created for this module. - explicit HloModule(const string& name); explicit HloModule(const string& name, const HloModuleConfig& config); // Adds an entry computation to the module. A module can only have one entry @@ -86,8 +95,10 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all - // the called computations as well. - HloComputation* DeepCloneComputation(HloComputation* computation); + // the called computations as well. If the clone context is specified, it + // will be populated with the cloned object mappings. + HloComputation* DeepCloneComputation(HloComputation* computation, + HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { @@ -99,12 +110,20 @@ class HloModule { return entry_computation_; } - ComputationLayout* mutable_entry_computation_layout() { - return config_.mutable_entry_computation_layout(); + ComputationLayout* mutable_host_entry_computation_layout() { + return config_.mutable_host_entry_computation_layout(); + } + + const ComputationLayout& host_entry_computation_layout() const { + return config_.host_entry_computation_layout(); + } + + ComputationLayout* mutable_device_entry_computation_layout() { + return config_.mutable_device_entry_computation_layout(); } - const ComputationLayout& entry_computation_layout() const { - return config_.entry_computation_layout(); + const ComputationLayout& device_entry_computation_layout() const { + return config_.device_entry_computation_layout(); } const VersionedComputationHandle& entry_computation_handle() const { @@ -131,6 +150,10 @@ class HloModule { MakeUnwrappingIterator(computations_.end())}; } + // Returns the computation in this module that has the name `name`. Returns + // null if there is no such computation. + HloComputation* GetComputationWithName(tensorflow::StringPiece name); + // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } @@ -205,6 +228,25 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } + // Returns a non-const version of the passed-in const HloInstruction*. This is + // safe on the argument that if you have a non-const module, then you can + // access all instructions in the module as non-const. + // + // Returns an error if the passed-in instruction is not from this module, + // except that it is allowed to pass in a null pointer. + // + // TODO(b/78350259): Eliminate const laundering. The argument above is not + // reliable since at any time someone could add or discover a way for a + // non-const module to transitively contain a const HloInstruction. The + // reliable way to do this would be to create a const laundering map from a + // module, mapping each encountered HloInstruction to its non-const version + // and then look up each instruction in need of laundering in that map, but + // this is much more expensive and complicated. 
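The comment above is the whole argument behind `LaunderConstInstructionFromModule`: given mutable access to the owning module, an instruction the module owns may be handed back as non-const. A small self-contained illustration of that ownership-checked `const_cast` pattern, with hypothetical `Module`/`Instruction` types standing in for the real classes:

```c++
#include <cassert>
#include <iostream>
#include <memory>
#include <vector>

struct Instruction {
  int id;
};

class Module {
 public:
  const Instruction* Add(int id) {
    instructions_.push_back(std::make_unique<Instruction>(Instruction{id}));
    return instructions_.back().get();
  }

  // Ownership-checked "const laundering": only instructions owned by this
  // module may be converted back to non-const.
  Instruction* Launder(const Instruction* instruction) {
    assert(Owns(instruction) && "instruction is not from this module");
    return const_cast<Instruction*>(instruction);
  }

 private:
  bool Owns(const Instruction* instruction) const {
    for (const auto& owned : instructions_) {
      if (owned.get() == instruction) return true;
    }
    return false;
  }

  std::vector<std::unique_ptr<Instruction>> instructions_;
};

int main() {
  Module module;
  const Instruction* read_only = module.Add(42);

  // Given mutable access to the owning module, recover a mutable pointer.
  Instruction* mutable_view = module.Launder(read_only);
  mutable_view->id = 43;
  std::cout << read_only->id << "\n";  // 43
}
```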
This returns a Status instead + // of doing a CHECK-failure in part to make it strongly apparent that this is + // something that can fail. + StatusOr LaunderConstInstructionFromModule( + const HloInstruction* hlo); + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 4205b0402cb8b2..dae5578a3158fe 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -31,11 +31,13 @@ using tensorflow::strings::StrAppend; HloModuleConfig::HloModuleConfig() {} HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) - : entry_computation_layout_(program_shape) {} + : host_entry_computation_layout_(program_shape), + device_entry_computation_layout_(program_shape) {} void HloModuleConfig::SetDefaultComputationLayout( const ProgramShape& program_shape) { - entry_computation_layout_ = ComputationLayout(program_shape); + host_entry_computation_layout_ = ComputationLayout(program_shape); + device_entry_computation_layout_ = ComputationLayout(program_shape); } string HloModuleConfig::compilation_cache_key() const { @@ -44,11 +46,18 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : - entry_computation_layout_->parameter_layouts()) { + host_entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", - entry_computation_layout_->result_shape().SerializeAsString()); + host_entry_computation_layout_->result_shape().SerializeAsString()); + for (const ShapeLayout& param_layout : + device_entry_computation_layout_->parameter_layouts()) { + params.push_back(param_layout.shape().DebugString()); + } + StrAppend( + &key, tensorflow::str_util::Join(params, ", "), ") => ", + device_entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 586a03d412681c..cdb0b29a2399b3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -41,26 +41,44 @@ class HloModuleConfig { explicit HloModuleConfig(const ProgramShape& program_shape); // Checks if this config has an entry computation layout already. - bool has_entry_computation_layout() const { - return entry_computation_layout_.has_value(); + bool has_host_entry_computation_layout() const { + return host_entry_computation_layout_.has_value(); + } + + bool has_device_entry_computation_layout() const { + return device_entry_computation_layout_.has_value(); } // Sets the entry computation layout for this config. If the entry computation // layout already exists, it is silently replaced. void SetDefaultComputationLayout(const ProgramShape& program_shape); - // Returns a constant reference to the layout of the entry computation. + // Returns a constant reference to the on-host layout of the entry + // computation. Assumes the layout was set. 
+ const ComputationLayout& host_entry_computation_layout() const { + CHECK(host_entry_computation_layout_.has_value()); + return *host_entry_computation_layout_; + } + + // Returns a mutable pointer to the layout of the on-host entry computation. // Assumes the layout was set. - const ComputationLayout& entry_computation_layout() const { - CHECK(entry_computation_layout_.has_value()); - return *entry_computation_layout_; + ComputationLayout* mutable_host_entry_computation_layout() { + CHECK(host_entry_computation_layout_.has_value()); + return &(*host_entry_computation_layout_); } - // Returns a mutable pointer to the layout of the entry computation. Assumes - // the layout was set. - ComputationLayout* mutable_entry_computation_layout() { - CHECK(entry_computation_layout_.has_value()); - return &(*entry_computation_layout_); + // Returns a constant reference to the on-device layout of the entry + // computation. Assumes the layout was set. + const ComputationLayout& device_entry_computation_layout() const { + CHECK(device_entry_computation_layout_.has_value()); + return *device_entry_computation_layout_; + } + + // Returns a mutable pointer to the layout of the on-device entry computation. + // Assumes the layout was set. + ComputationLayout* mutable_device_entry_computation_layout() { + CHECK(device_entry_computation_layout_.has_value()); + return &(*device_entry_computation_layout_); } // Returns whether to enable HLO-level profiling. @@ -109,7 +127,8 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional entry_computation_layout_; + tensorflow::gtl::optional host_entry_computation_layout_; + tensorflow::gtl::optional device_entry_computation_layout_; // Whether this is a 'host module'. bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc new file mode 100644 index 00000000000000..98d20315e399c6 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +bool HasSendRecv(HloComputation* computation) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecv || + instruction->opcode() == HloOpcode::kRecvDone) { + return true; + } + for (auto* sub_computation : instruction->called_computations()) { + if (HasSendRecv(sub_computation)) { + return true; + } + } + } + return false; +} + +StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + + const auto* xla_while = instruction; + auto* while_body_comp = xla_while->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + + if (!ShapeUtil::IsTuple(xla_while->shape()) || + while_body_root->opcode() != HloOpcode::kTuple || + HasSendRecv(while_body_comp)) { + // Only run DCE on tuple-shaped while loops where body root is Tuple, + // with no send/recv instructions. + VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); + continue; + } + + // Remove dead tuple elements. + const int64 tuple_element_count = + ShapeUtil::TupleElementCount(xla_while->shape()); + for (int64 i = 0; i < tuple_element_count; ++i) { + if (liveness->IsLive(xla_while, {i})) { + continue; + } + VLOG(1) << "WhileDCE Dead while tuple element." + << " while: " << xla_while->name() << " tuple_index: " << i; + // Transform while.body computation to make tuple element at + // 'shape_index' as simple pass-through parameter (which candidate + // be removed later by simplification pass). + HloInstruction* pass_thru_gte = while_body_comp->AddInstruction( + HloInstruction::CreateGetTupleElement( + while_body_param->shape().tuple_shapes(i), while_body_param, + i)); + // Replace while.body.root Tuple operand at 'tuple_index' with + // 'pass_thru_gte', making prior operand a dead root (to be cleaned + // up with a subsequent DCE pass). 
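Note that `RunWhileDCE` never deletes the dead element itself: it only rewires the body root to forward the parameter's value at that index, leaving the old producer dead so the follow-up `HloDCE` run can remove it. A toy value-type model of that rewiring (plain strings standing in for HLO instructions, not the real IR) might look like:

```c++
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for a while body: each root element is either computed from
// the parameter or passed straight through.
struct WhileBody {
  // root_exprs[i] describes what the body returns for tuple element i.
  std::vector<std::string> root_exprs;
};

// Mirrors the RunWhileDCE transformation above: a dead element's root operand
// is replaced by a get-tuple-element of the parameter at the same index.
void MakePassThrough(WhileBody* body, int i) {
  body->root_exprs[i] = "get-tuple-element(param), index=" + std::to_string(i);
}

int main() {
  WhileBody body;
  body.root_exprs = {"add(get-tuple-element(param, 0), 1)",
                     "multiply(get-tuple-element(param, 1), "
                     "get-tuple-element(param, 1))"};

  // Liveness analysis (not modeled here) says element 1 is dead.
  MakePassThrough(&body, 1);

  for (int i = 0; i < static_cast<int>(body.root_exprs.size()); ++i) {
    std::cout << "root[" << i << "] = " << body.root_exprs[i] << "\n";
  }
  // The multiply is now unused and a later HloDCE run can delete it.
}
```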
+ TF_RETURN_IF_ERROR( + while_body_root->ReplaceOperandWith(i, pass_thru_gte)); + changed = true; + } + } + } + return changed; +} + +} // namespace + +StatusOr HloModuleDCE::Run(HloModule* module) { + VLOG(2) << "Before HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + std::unique_ptr liveness; + TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module)); + + // Sweep through while instructions, transforming dead while tuple element + // computations to pass through tuple values (creating dead roots in while + // body computation in the process). + TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed, + RunWhileDCE(module, liveness.get())); + + // Run HloDCE to clean up any dead code created during HloModuleDCE. + HloDCE hlo_dce; + TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module)); + + VLOG(2) << "After HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + return hlo_module_dce_changed | hlo_dce_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h new file mode 100644 index 00000000000000..29024085c10389 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes dead code from computations in the module using +// HloModule-scoped analysis (HloLivenessAnalysis). +// +// Sweeps through live instructions which cross computation boundaries (kWhile), +// and removes code at dead shape indices. +// +class HloModuleDCE : public HloPassInterface { + public: + ~HloModuleDCE() override {} + tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + + // Run the pass on the given module. Returns whether the module was changed + // (instructions were removed). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc new file mode 100644 index 00000000000000..363862e4905fc1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -0,0 +1,371 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloModuleDceTest : public HloTestBase { + protected: + HloModuleDceTest() {} + + // Returns whether the given instruction exists in the given computation. + bool HasInstruction(const HloComputation& computation, + const HloInstruction* instruction) { + return std::find(computation.instructions().begin(), + computation.instructions().end(), + instruction) != computation.instructions().end(); + } + + // Returns whether the while instruction with name 'while_name' in + // 'computation' passes through its tuple element at 'tuple_index' from + // parameter to root instruction. + bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation, + const string& while_name, + const int64 tuple_index) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile && + instruction->name() == while_name) { + auto* while_body_comp = instruction->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + auto* operand = while_body_root->operand(tuple_index); + if (operand->opcode() == HloOpcode::kGetTupleElement && + operand->tuple_index() == tuple_index && + operand->operand(0) == while_body_param) { + return true; + } + return false; + } + } + return false; + } +}; + +// Tests that a while with all outputs live is unmodified. 
+TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests a while loop with one unused output (which is used in the while loop +// body by an instruction with side-effects: rng) is unmodified. +TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], f32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1 + constant.2 = f32[] constant(1.0) + rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform + add.1 = s32[] add(get-tuple-element.2, constant.2) + ROOT tuple = (s32[], f32[]) tuple(add, add.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], f32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.3 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + } + ENTRY SimpleLoop { + constant.4 = s32[] constant(0) + constant.5 = f32[] constant(0.0) + tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5) + while = (s32[], f32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a while loop with one dead tuple element at {1} has its while +// loop body modified to make that tuple element pass-through the while body. 
+TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} should now be pass-through after ModuleDCE. + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a tuple element {1} used by condition computation (which appears +// dead in while.body{1} and at while.result{1}) propgates liveness of this +// tuple element to while.body{1} and at while.result{1}. +TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[]) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4) + while = (s32[], s32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} still be pass-through after ModuleDCE. 
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at index {1} between +// two dependent while loops. +TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6) + while.1 = (s32[], s32[3]{0}) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1 and while.2 should not have pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1 and while.2 should have pass-thru elements, + // after being modified to pass through unused tuple element {1}. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and +// while.2{1}, between two dependent while loops. 
+TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { + auto module = ParseHloString(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5) + while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 54c34ce1166516..4f1715e4cafd1a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include #include #include @@ -47,13 +48,16 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { case ComputationKind::kConditionalFalse: repr += ":CONDITIONAL_FALSE"; break; + case ComputationKind::kCallFunction: + repr += ":CALL"; + break; } return repr; } /* static */ StatusOr> HloModuleGroupMetadata::Build(const std::vector& modules) { - auto metadata = absl::make_unique(modules); + auto metadata = MakeUnique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -83,6 +87,7 @@ Status HloModuleGroupMetadata::Build() { << "Peer instruction does not match the computation kind"; TF_RETURN_IF_ERROR( AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); } // Add the parents of companion instructions (they must be all of the same @@ -107,6 +112,42 @@ Status HloModuleGroupMetadata::Build() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + TF_RETURN_IF_ERROR(VerifyCompanionSets()); + if (VLOG_IS_ON(4)) { + DumpCollectedStats(); + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::VerifyCompanionSets() const { + for (const auto& companions : companion_sets_) { + // A companion set must be composed at most of an instruction per + // device/module. + std::unordered_set devices; + for (HloInstruction* instruction : *companions) { + // Go through all the communicating instructions (send, recv) of the given + // companion, and record their device. + std::unordered_set comm_devices; + for (HloInstruction* comm_instruction : + tracked_instructions_comms_.at(instruction)) { + auto device = GetInstructionDevice(*comm_instruction); + TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString() + << " does not have a device"; + comm_devices.insert(*device); + } + for (int64 device : comm_devices) { + if (!devices.insert(device).second) { + std::stringstream ss; + ss << "Companion set:" << std::endl; + for (HloInstruction* hlo : *companions) { + ss << " " << hlo->name() << std::endl; + } + ss << "has multiple instructions on the same device"; + return FailedPrecondition("%s", ss.str().c_str()); + } + } + } + } return Status::OK(); } @@ -194,6 +235,28 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } +tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( + const HloInstruction& instruction) const { + // The module group metadata can be created in both "single module, multiple + // devices" and "multiple modules, no explicit devices" fashions. + // The API returns an optional even though the current implementation always + // returns a device, to account for cases where we cannot guess a device. + // In such cases the VerifyChannelInstructions() will return proper errors. 
+ tensorflow::gtl::optional device = + instruction.sharding_unique_device(); + if (!device) { + device = GetModuleId(instruction.parent()->parent()); + } + return device; +} + +int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { + return std::count_if(modules_.begin(), modules_.end(), + [](const HloModule* module) { + return !module->config().is_host_module(); + }); +} + Status HloModuleGroupMetadata::RecordInstructions() { const auto visitor = [this](HloInstruction* hlo) -> Status { if (hlo->opcode() == HloOpcode::kWhile) { @@ -206,6 +269,9 @@ Status HloModuleGroupMetadata::RecordInstructions() { TrackedInstruction(hlo, ComputationKind::kConditionalTrue); tracked_instructions_[hlo->false_computation()] = TrackedInstruction(hlo, ComputationKind::kConditionalFalse); + } else if (hlo->opcode() == HloOpcode::kCall) { + tracked_instructions_[hlo->to_apply()] = + TrackedInstruction(hlo, ComputationKind::kCallFunction); } if (!IsChannelInstruction(hlo)) { return Status::OK(); @@ -252,20 +318,22 @@ Status HloModuleGroupMetadata::RecordInstructions() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + VLOG(2) << "Created " << channels_.size() << " channels"; return Status::OK(); } Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, HloInstruction* instruction2) { TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile || - instruction1->opcode() == HloOpcode::kConditional); + instruction1->opcode() == HloOpcode::kConditional || + instruction1->opcode() == HloOpcode::kCall); VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " << instruction2->ToString(); if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - absl::make_unique>()); + tensorflow::MakeUnique>()); auto companion_set = companion_sets_.back().get(); companion_set->insert(instruction1); companion_set->insert(instruction2); @@ -313,44 +381,46 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } - const HloModule* send_module = channel.send->parent()->parent(); - const HloModule* send_done_module = channel.send_done->parent()->parent(); - if (send_module != send_done_module) { + auto send_device = GetInstructionDevice(*channel.send); + auto send_done_device = GetInstructionDevice(*channel.send_done); + if (!send_device) { + return FailedPrecondition("send instruction must have a device: %s", + channel.send->ToString().c_str()); + } + if (!send_done_device) { + return FailedPrecondition("send_done instruction must have a device: %s", + channel.send_done->ToString().c_str()); + } + if (*send_device != *send_done_device) { return FailedPrecondition( "send and send-done (channel=%lld) must be on the same device: %lld " "vs. 
%lld", - channel.id, GetModuleId(send_module), GetModuleId(send_done_module)); + channel.id, *send_device, *send_done_device); + } + auto recv_device = GetInstructionDevice(*channel.recv); + auto recv_done_device = GetInstructionDevice(*channel.recv_done); + if (!recv_done_device) { + return FailedPrecondition("recv_done instruction must have a device: %s", + channel.recv_done->ToString().c_str()); } - const HloModule* recv_module = channel.recv->parent()->parent(); - const HloModule* recv_done_module = channel.recv_done->parent()->parent(); - if (recv_module != recv_done_module) { + if (*recv_device != *recv_done_device) { return FailedPrecondition( "recv and recv-done (channel=%lld) must be on the same device: %lld " "vs. %lld", - channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module)); + channel.id, *recv_device, *recv_done_device); } - if (send_module == recv_module) { + if (*send_device == *recv_device) { return FailedPrecondition( "send and recv (channel=%lld) must be on different devices: %lld", - channel.id, GetModuleId(send_module)); + channel.id, *send_device); } } - // Check if channel instructions are used only in allowed computations. - const auto allowed = [this](HloInstruction* hlo) { - HloComputation* computation = hlo->parent(); - const HloModule* module = computation->parent(); - if (module->entry_computation() == computation || - tracked_instructions_.count(computation) > 0) { - return true; - } - return false; - }; for (const Channel& channel : channels_) { - if (!allowed(channel.send) || !allowed(channel.send_done) || - !allowed(channel.recv) || !allowed(channel.recv_done)) { - return FailedPrecondition("channel is used in disallowed computation"); - } + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done)); } // Check if the nest levels match for each channel. 
for (const Channel& channel : channels_) { @@ -368,4 +438,47 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { return Status::OK(); } +Status HloModuleGroupMetadata::CheckCommunicatingInstruction( + HloInstruction* instruction) const { + HloComputation* computation = instruction->parent(); + const HloModule* module = computation->parent(); + if (module->entry_computation() == computation || + tracked_instructions_.count(computation) > 0) { + return Status::OK(); + } + return FailedPrecondition("channel is used in disallowed computation"); +} + +void HloModuleGroupMetadata::DumpCollectedStats() const { + std::map, int64> communication_histogram; + for (auto& channel : channels_) { + auto from_device = GetInstructionDevice(*channel.send); + auto to_device = GetInstructionDevice(*channel.recv); + LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device + << " to_device=" << *to_device << " send=" << channel.send->name() + << " send_done=" << channel.send_done->name() + << " recv=" << channel.recv->name() + << " recv_done=" << channel.recv_done->name(); + communication_histogram[std::pair(*from_device, + *to_device)] += 1; + } + for (auto& fromto_count : communication_histogram) { + LOG(INFO) << "From " << fromto_count.first.first << " to " + << fromto_count.first.second << ": " << fromto_count.second; + } + for (auto& companion_set : companion_sets_) { + LOG(INFO) << "Companion set:"; + for (HloInstruction* instruction : *companion_set) { + LOG(INFO) << " " << instruction->name(); + } + } + for (auto& instruction_comm : tracked_instructions_comms_) { + LOG(INFO) << "Communicating instruction " << instruction_comm.first->name(); + for (HloInstruction* instruction : instruction_comm.second) { + auto device = GetInstructionDevice(*instruction); + LOG(INFO) << " " << instruction->name() << " on device " << *device; + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index c48a7ab0b59269..ffde3a332dfc14 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -60,6 +61,7 @@ class HloModuleGroupMetadata { kWhileBody, kConditionalTrue, kConditionalFalse, + kCallFunction, }; // Tracks the instruction mapped to a given computation, and the computation @@ -147,6 +149,15 @@ class HloModuleGroupMetadata { // the module in the module vector. int64 GetModuleId(const HloModule* module) const; + // Retrieves the device an instruction is assigned to. Either from the + // sharding information, or from the ordinal of the module the instruction + // is in. + tensorflow::gtl::optional GetInstructionDevice( + const HloInstruction& instruction) const; + + // Returns the number of modules for devices (excluding the host module). + int64 GetDeviceModulesCount() const; + // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. 
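`DumpCollectedStats` above folds the channels into a per `(from_device, to_device)` histogram before logging it. The counting pattern on its own, with plain pairs of device ids rather than the real `Channel` structs:

```c++
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

int main() {
  // Each entry is a channel's (send device, recv device), as recovered by
  // GetInstructionDevice in the code above.
  std::vector<std::pair<int64_t, int64_t>> channels = {
      {0, 1}, {0, 1}, {1, 0}, {0, 2}};

  // Histogram of how many channels go from each device to each other device.
  std::map<std::pair<int64_t, int64_t>, int64_t> histogram;
  for (const auto& channel : channels) {
    histogram[channel] += 1;
  }

  for (const auto& fromto_count : histogram) {
    std::cout << "From " << fromto_count.first.first << " to "
              << fromto_count.first.second << ": " << fromto_count.second
              << "\n";
  }
}
```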
@@ -202,6 +213,15 @@ class HloModuleGroupMetadata { Status AddCompanion(HloInstruction* instruction1, HloInstruction* instruction2); + // Checks whether a communicating instruction is placed in a valid position + // within the graph. + Status CheckCommunicatingInstruction(HloInstruction* instruction) const; + + // Performs a consistency check on the companion sets built for the input + // modules. Check that a companion set does not include instructions from the + // same module/device. + Status VerifyCompanionSets() const; + // Retrieves a pointer to the stored TrackedInstruction associated with a // tracked computation, or nullptr in case such computation is not tracked. const TrackedInstruction* GetTrackedInstruction( @@ -210,6 +230,9 @@ class HloModuleGroupMetadata { return it != tracked_instructions_.end() ? &it->second : nullptr; } + // Dump all the collected module group statistics to the logs. + void DumpCollectedStats() const; + // List of all companion instructions sets in the module. std::vector>> companion_sets_; @@ -221,6 +244,11 @@ class HloModuleGroupMetadata { tensorflow::gtl::FlatMap tracked_instructions_; + // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of + // communicating instructions within the proper called computation(s). + tensorflow::gtl::FlatMap> + tracked_instructions_comms_; + // All channels in the module. std::vector channels_; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 289c96b0a7b90c..5a0d1e264eb509 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -289,7 +290,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = absl::make_unique(post_order); + auto reachability = MakeUnique(post_order); for (HloInstruction* hlo : post_order) { reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index ca763076a16af1..1fe06ee0c0d142 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -69,11 +69,13 @@ namespace xla { V(kCrossReplicaSum, "cross-replica-sum") \ V(kCustomCall, "custom-call") \ V(kDivide, "divide") \ + V(kDomain, "domain") \ V(kDot, "dot") \ V(kDynamicSlice, "dynamic-slice") \ V(kDynamicUpdateSlice, "dynamic-update-slice") \ V(kEq, "equal-to", kHloOpcodeIsComparison) \ V(kExp, "exponential") \ + V(kExpm1, "exponential-minus-one") \ V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ @@ -87,6 +89,7 @@ namespace xla { V(kIsFinite, "is-finite") \ V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ V(kLog, "log") \ + V(kLog1p, "log-plus-one") \ V(kAnd, "and") \ V(kNot, "not") \ V(kOr, "or") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index e89d94bede6c43..dcd4725fe78e8b 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -19,7 +19,6 @@ 
limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -170,10 +169,10 @@ bool HloOrdering::UseIsBeforeValueDefinition( // is before the def if the instruction allows buffer sharing (in place // computation). if (use.instruction == value.defining_instruction() && - CanShareOperandBufferWithUser( + dataflow.CanShareOperandBufferWithUser( use.instruction->mutable_operand(use.operand_number), use.operand_index, value.defining_instruction(), - value.defining_index(), dataflow)) { + value.defining_index())) { VLOG(4) << " use is value def, and instruction can share use buffer"; return true; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 37a7fbad97cea2..cfe5dace05ac03 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -310,7 +310,7 @@ ENTRY while.v11 { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); DependencyHloOrdering ordering(module.get()); ordering.ToString(); // Shouldn't crash. } @@ -347,7 +347,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN(auto dataflow, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc similarity index 85% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.cc rename to tensorflow/compiler/xla/service/hlo_parser.cc index fdbfc0210ea63a..3eadedfe1f8b0a 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -24,17 +26,17 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -namespace tools { namespace { -using tensorflow::StringPiece; -using tensorflow::gtl::optional; -using tensorflow::str_util::Split; -using tensorflow::str_util::SplitAndParseAsInts; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using ::tensorflow::StringPiece; +using ::tensorflow::gtl::optional; +using ::tensorflow::str_util::Join; +using ::tensorflow::str_util::Split; +using ::tensorflow::str_util::SplitAndParseAsInts; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; const double kF16max = 65504; @@ -53,7 +55,12 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + string GetError() const { return Join(error_, "\n"); } + + // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShardingOnly(); + StatusOr ParseWindowOnly(); + StatusOr ParseConvolutionDimensionNumbersOnly(); private: // ParseXXX returns false if an error occurred. @@ -77,11 +84,15 @@ class HloParser { // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. - bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal); + bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + Literal* literal); + bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + Literal* literal); template - bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + bool SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal); bool ParseOperands(std::vector* operands); @@ -93,9 +104,15 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; + }; + + // The data parsed for the kDomain instruction. + struct DomainData { + std::unique_ptr entry_metadata; + std::unique_ptr exit_metadata; }; // Types of attributes. @@ -116,6 +133,7 @@ class HloParser { kMetadata, kFusionKind, kDistribution, + kDomain, }; struct AttrConfig { @@ -163,21 +181,27 @@ class HloParser { bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. bool ParseInstructionNames(std::vector* instructions); - bool ParseWindow(Window* window); + // Pass expect_outer_curlies == true when parsing a Window in the context of a + // larger computation. Pass false when parsing a stand-alone Window string. + bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + // Parses the metadata behind a kDOmain instruction. 
+ bool ParseDomain(DomainData* domain); + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, std::vector* result); + const TokKind delim, + std::vector* result); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -189,7 +213,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParseInt64(int64* result); + bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -242,10 +266,10 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + error_.push_back(Join(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } @@ -303,12 +327,18 @@ bool HloParser::ParseComputations() { // set the layouts to what the hlo text says. for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module_->mutable_host_entry_computation_layout() + ->mutable_parameter_layout(p) + ->CopyLayoutFromShape(param_shape)); + TF_CHECK_OK(module_->mutable_device_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module_->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->CopyLayoutFromShape(result_shape)); + TF_CHECK_OK(module_->mutable_device_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } @@ -377,6 +407,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } + instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } @@ -433,10 +464,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + optional backend_config; + attrs["backend_config"] = {/*required=*/false, AttrTy::kString, + &backend_config}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - int64 parameter_number; + tensorflow::int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number) || @@ -470,10 +505,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: 
+ case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -550,11 +587,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands)); + HloInstruction::CreateCrossReplicaSum(shape, operands, *to_apply)); break; } case HloOpcode::kReshape: { @@ -589,7 +629,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { @@ -600,7 +640,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -614,7 +654,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -625,7 +665,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -639,7 +679,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -697,7 +737,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, &fft_length}; @@ -710,7 +750,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -722,7 +762,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -748,7 +788,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, 
&dimensions_to_reduce}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -761,7 +801,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -805,7 +845,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -829,7 +869,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -843,7 +883,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -859,7 +899,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -876,7 +916,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -947,8 +987,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -993,7 +1033,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kHostCompute: { optional channel_name; - optional cost_estimate_ns; + optional cost_estimate_ns; attrs["channel_name"] = {/*required=*/true, AttrTy::kString, &channel_name}; attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, @@ -1006,16 +1046,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, 
&lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; @@ -1047,20 +1087,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> output_window_dims; + optional> output_window_dims; attrs["output_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional> elided_window_dims; + optional> elided_window_dims; attrs["elided_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional> gather_dims_to_operand_dims; + optional> gather_dims_to_operand_dims; attrs["gather_dims_to_operand_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &gather_dims_to_operand_dims}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> window_bounds; + optional> window_bounds; attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, &window_bounds}; @@ -1080,6 +1120,18 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dim_numbers, *window_bounds)); break; } + case HloOpcode::kDomain: { + DomainData domain; + attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDomain( + shape, operands[0], std::move(domain.entry_metadata), + std::move(domain.exit_metadata))); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); @@ -1087,8 +1139,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_name(name); - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. + // Add shared attributes like metadata to the instruction, if they were seen. 
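The `domain={...}` attribute handled by the new kDomain case accepts the textual form emitted by `HloInstruction::ToString()`; the `DomainParsing` test case added to hlo_parser_test.cc later in this change uses exactly this input. A minimal sketch of driving the parser on it, assuming only the public `ParseHloString` entry point introduced by this change (the helper name is illustrative):

```c++
#include <memory>

#include "tensorflow/compiler/xla/service/hlo_parser.h"

// Sketch only: the HLO text is copied from the DomainParsing test case added
// to hlo_parser_test.cc in this same change.
xla::StatusOr<std::unique_ptr<xla::HloModule>> ParseDomainExample() {
  const char* kHloText = R"(HloModule DomainParsing_module

ENTRY %DomainParsing (v1: f32[]) -> f32[] {
  %v1 = f32[] parameter(0)
  ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
}
)";
  // The resulting kDomain instruction carries ShardingMetadata for both the
  // entry and exit sides of the domain, as assembled by ParseDomain.
  return xla::ParseHloString(kHloText);
}
```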
if (sharding) { instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); @@ -1105,6 +1156,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (backend_config) { + instruction->set_raw_backend_config_string(std::move(*backend_config)); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1154,8 +1208,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { @@ -1182,7 +1236,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - int64 dim; + tensorflow::int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1194,7 +1248,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - int64 device; + tensorflow::int64 device; if (!ParseInt64(&device)) { return false; } @@ -1253,10 +1307,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); *sharding->mutable_tile_shape() = tile_shape; - for (int64 dim : tile_assignment_dimensions) { + for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (int64 device : devices) { + for (tensorflow::int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1265,6 +1319,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' +// 'exit=' exit_sharding '}' +bool HloParser::ParseDomain(DomainData* domain) { + std::unordered_map attrs; + optional kind; + optional entry_sharding; + optional exit_sharding; + attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; + attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; + attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (*kind == ShardingMetadata::KindName()) { + auto entry_sharding_ptr = MakeUnique( + HloSharding::FromProto(*entry_sharding).ValueOrDie()); + auto exit_sharding_ptr = MakeUnique( + HloSharding::FromProto(*exit_sharding).ValueOrDie()); + domain->entry_metadata = + MakeUnique(std::move(entry_sharding_ptr)); + domain->exit_metadata = + MakeUnique(std::move(exit_sharding_ptr)); + } else { + return TokenError(StrCat("unsupported domain kind: ", *kind)); + } + return true; +} + // '{' name+ '}' bool HloParser::ParseInstructionNames( std::vector* instructions) { @@ -1291,40 +1373,50 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, +bool HloParser::SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S32: - return SetValueInLiteralHelper(value, linear_index, literal); + 
return SetValueInLiteralHelper(value, linear_index, + literal); case S64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, int64 linear_index, +bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1335,7 +1427,7 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, } } -bool HloParser::SetValueInLiteral(bool value, int64 linear_index, +bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -1348,7 +1440,8 @@ bool HloParser::SetValueInLiteral(bool value, int64 linear_index, } template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, +bool HloParser::SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal) { // Check that linear_index is in range. if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { @@ -1460,7 +1553,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape) { - const int64 rank = ShapeUtil::Rank(shape); + const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; } @@ -1468,8 +1561,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // Create a literal with the given shape in default layout. *literal = Literal::CreateFromDimensions(shape.element_type(), AsInt64Slice(shape.dimensions())); - int64 nest_level = 0; - int64 linear_index = 0; + tensorflow::int64 nest_level = 0; + tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -1477,16 +1570,15 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. 
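The mapping from the nested text form to the linear indices passed to SetValueInLiteral can be made concrete; this is an illustrative aside, assuming the default layout produced by `Literal::CreateFromDimensions`:

```c++
// Illustrative only: the running example f32[2,3] {{1, 2, 3}, {4, 5, 6}} is
// written into the literal through SetValueInLiteral as
//
//   linear_index : 0  1  2  3  4  5
//   value        : 1  2  3  4  5  6
//
// and at each closing '}' the parser checks that the innermost dimension saw
// exactly shape.dimensions(1) == 3 elements, i.e. elems_seen_per_dim[1] == 3.
```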
- std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), - elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim( + elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - tensorflow::str_util::Join( - elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); - }), + Join(elems_seen_until_dim, ",", + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1561,7 +1653,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { LocTy loc = lexer_.GetLoc(); - int64 value; + tensorflow::int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); @@ -1601,29 +1693,29 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, switch (shape.element_type()) { case PRED: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F32: return ParseSparseLiteralHelper(literal, shape); case BF16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F64: return ParseSparseLiteralHelper(literal, shape); default: @@ -1636,9 +1728,9 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, template bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, const Shape& shape) { - std::vector index; + std::vector index; - int64 rank = ShapeUtil::Rank(shape); + tensorflow::int64 rank = ShapeUtil::Rank(shape); *literal = MakeUnique(shape); @@ -1656,7 +1748,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, LocTy index_loc = lexer_.GetLoc(); index.clear(); if (lexer_.GetKind() == TokKind::kInt) { - int64 single_index = lexer_.GetInt64Val(); + tensorflow::int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); if (rank != 1) { return Error( @@ -1674,7 +1766,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", tensorflow::str_util::Join(index, ", "), "]")); + ": [", Join(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1689,7 +1781,7 @@ bool 
HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, value = static_cast(lexer_.GetKind() == TokKind::kw_true); lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - int64 value_s64; + tensorflow::int64 value_s64; if (!ParseInt64(&value_s64)) { return Error(value_loc, StrCat("expects integer for primitive type: ", @@ -1842,7 +1934,19 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return Error(loc, Printf("unexpected attribute %s", name.c_str())); + string allowed_attrs; + if (attrs.empty()) { + allowed_attrs = "No attributes are allowed here."; + } else { + allowed_attrs = StrCat( + "Allowed attributes: ", + Join(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); + } + return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), + allowed_attrs.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; @@ -1850,23 +1954,24 @@ bool HloParser::ParseAttributeHelper( LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr) + ->emplace(result); return true; } case AttrTy::kInt32: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if (result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -1900,7 +2005,7 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kWindow: { Window result; - if (!ParseWindow(&result)) { + if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { return false; } static_cast*>(attr_out_ptr)->emplace(result); @@ -1942,12 +2047,12 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } @@ -1992,6 +2097,9 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kDomain: { + return ParseDomain(static_cast(attr_out_ptr)); + } } }(); if (!success) { @@ -2018,9 +2126,10 @@ bool HloParser::ParseComputationName(HloComputation** value) { // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' // The subattributes can appear in any order. 'size=' is required, others are // optional. -bool HloParser::ParseWindow(Window* window) { +bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { LocTy loc = lexer_.GetLoc(); - if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + if (expect_outer_curlies && + !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -2030,7 +2139,9 @@ bool HloParser::ParseWindow(Window* window) { std::vector lhs_dilate; std::vector rhs_dilate; std::vector rhs_reversal; - while (lexer_.GetKind() != TokKind::kRbrace) { + const auto end_token = + expect_outer_curlies ? 
TokKind::kRbrace : TokKind::kEof; + while (lexer_.GetKind() != end_token) { LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { @@ -2094,7 +2205,8 @@ bool HloParser::ParseWindow(Window* window) { window->mutable_dimensions(i)->set_window_reversal( rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } - return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); + return !expect_outer_curlies || + ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. @@ -2118,7 +2230,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( << str; } - const int64 rank = lhs_rhs_out[0].length(); + const tensorflow::int64 rank = lhs_rhs_out[0].length(); if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -2232,7 +2344,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -2266,7 +2378,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + std::vector* result) { if (!ParseToken(start, StrCat("expects an int64 list starting with ", TokKindToString(start)))) { return false; @@ -2275,7 +2387,7 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, // empty } else { do { - int64 i; + tensorflow::int64 i; if (!ParseInt64(&i)) { return false; } @@ -2392,7 +2504,8 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, std::vector* result) { +bool HloParser::ParseDxD(const string& name, + std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, @@ -2400,7 +2513,7 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } // 1D if (lexer_.GetKind() == TokKind::kInt) { - int64 number; + tensorflow::int64 number; if (!ParseInt64(&number)) { return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } @@ -2420,7 +2533,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad(std::vector>* pad) { +bool HloParser::ParseWindowPad( + std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -2431,7 +2545,7 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (int i = 0; i < padding_str.size(); i++) { - std::vector low_high; + std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -2455,7 +2569,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -2477,7 +2591,7 @@ bool 
HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -2564,7 +2678,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParseInt64(int64* result) { +bool HloParser::ParseInt64(tensorflow::int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -2647,10 +2761,48 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShardingOnly() { + lexer_.Lex(); + OpSharding op_sharding; + if (!ParseSharding(&op_sharding)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after sharding"); + } + return HloSharding::FromProto(op_sharding); +} + +StatusOr HloParser::ParseWindowOnly() { + lexer_.Lex(); + Window window; + if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after window"); + } + return window; +} + +StatusOr +HloParser::ParseConvolutionDimensionNumbersOnly() { + lexer_.Lex(); + ConvolutionDimensionNumbers dnums; + if (!ParseConvolutionDimensionNumbers(&dnums)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after convolution dnums"); + } + return dnums; +} + } // namespace -StatusOr> Parse(StringPiece str, - const HloModuleConfig& config) { +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); @@ -2658,10 +2810,29 @@ StatusOr> Parse(StringPiece str, return parser.ConsumeHloModule(); } -StatusOr> Parse(StringPiece str) { +StatusOr> ParseHloString( + tensorflow::StringPiece str) { + HloModuleConfig config; + return ParseHloString(str, config); +} + +StatusOr ParseSharding(tensorflow::StringPiece str) { HloModuleConfig config; - return Parse(str, config); + HloParser parser(str, config); + return parser.ParseShardingOnly(); +} + +StatusOr ParseWindow(tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseWindowOnly(); +} + +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseConvolutionDimensionNumbersOnly(); } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h similarity index 52% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.h rename to tensorflow/compiler/xla/service/hlo_parser.h index 2f97a2b9b19d0c..3f3a51215e34bb 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -13,30 +13,47 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace tools { + +// For details about the syntax accepted by this parser, see +// g3doc/hlo_parser.md. // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. -StatusOr> Parse(tensorflow::StringPiece str, - const HloModuleConfig& config); +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr> Parse(tensorflow::StringPiece str); +StatusOr> ParseHloString( + tensorflow::StringPiece str); + +// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +StatusOr ParseSharding(tensorflow::StringPiece str); + +// Parses the result of window_util::ToString(const Window&). +StatusOr ParseWindow(tensorflow::StringPiece str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str); + +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string. +StatusOr ParseSharding(tensorflow::StringPiece str); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc similarity index 86% rename from tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc rename to tensorflow/compiler/xla/service/hlo_parser_test.cc index adc8b1d620eb65..08068dc5042d58 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -13,19 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { -namespace tools { + namespace { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; struct TestData { string test_name; @@ -65,7 +66,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar" } )" @@ -81,13 +82,14 @@ ENTRY %constant_s32 () -> s32[] { )" }, -// f32 constant, but the value is not a decimal +// f32 constant, but the value is not a decimal and there is a backend +// configuration { "ConstantF32", R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { - ROOT %constant = f32[] constant(42) + ROOT %constant = f32[] constant(42), backend_config="this is a configuration" } )" @@ -232,6 +234,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} } +)" +}, +{ +"DomainParsing", +R"(HloModule DomainParsing_module + +ENTRY %DomainParsing (v1: f32[]) -> f32[] { + %v1 = f32[] parameter(0) + ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} +} + )" }, // int32 result = 0; @@ -885,6 +898,24 @@ ENTRY Gather { ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} } +)" +}, +// cross-replica-sum +{ +"CrossReplicaSum", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add +} + )" }, }); @@ -899,12 +930,12 @@ class HloParserTest : public ::testing::Test, << "'" << s << "' does not contain '" << expected << "'"; } - // Expects "ToString(Parse(string)) == string", that is, parses the string, - // asserts that it succeeded, stringifies the parsed module, and checks that - // the it equals the original string. + // Expects "ToString(ParseHloString(string)) == string", that is, parses the + // string, asserts that it succeeded, stringifies the parsed module, and + // checks that the it equals the original string. 
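The same parse-then-print round-trip contract extends to the new stand-alone helpers. A condensed sketch mirroring the ParseSharding, ParseWindow and ParseConvolutionDimensionNumbers tests added at the end of this file (the function name and the use of `CHECK_EQ` are illustrative; the real tests pass `HloPrintOptions` to control constant printing):

```c++
#include <string>

#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/core/platform/logging.h"

// Sketch only: condensed version of the round-trip checks in this test file.
void RoundTripSketch(const std::string& module_text) {
  // Parse-then-print should reproduce the original module text.
  auto module = xla::ParseHloString(module_text).ConsumeValueOrDie();
  CHECK_EQ(module_text, module->ToString());

  // The stand-alone parsers introduced by this change round-trip the textual
  // forms produced by the corresponding ToString helpers.
  xla::HloSharding sharding =
      xla::ParseSharding("{maximal device=42}").ValueOrDie();
  CHECK_EQ("{maximal device=42}", sharding.ToString());

  xla::ConvolutionDimensionNumbers dnums =
      xla::ParseConvolutionDimensionNumbers("b0f_0io->b0f").ValueOrDie();
  CHECK_EQ("b0f_0io->b0f", xla::ConvolutionDimensionNumbersToString(dnums));
}
```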
void ExpectEqual() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString( HloPrintOptions().set_print_large_constants(true))); @@ -915,7 +946,7 @@ class HloParserShortTest : public HloParserTest { protected: void ExpectEqualShort() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); @@ -936,14 +967,14 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOpcode) { @@ -956,8 +987,8 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongShape) { @@ -968,8 +999,8 @@ ENTRY %blabla (x: g32[]) -> g32[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOperandsSize) { @@ -981,8 +1012,8 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, OperandNotFound) { @@ -992,8 +1023,8 @@ ENTRY %blabla (x: f32[]) -> pred[] { %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, MoreConstants) { @@ -1007,12 +1038,25 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // Constant instructions have no name. The string will be parsed successfully // but the constant names will not be exactly the same. 
} +TEST_F(HloParserTest, ConfigurationField) { + const string original = R"(HloModule AModule +ENTRY %configuration_test() -> s32[] { + %constant = s32[] constant(42), backend_config="foo bar" +})"; + auto result = ParseHloString(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ("foo bar", result.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->raw_backend_config_string()); +} + TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { const string original = R"(HloModule some_2_module @@ -1021,8 +1065,8 @@ ENTRY %some_2 () -> f32[2] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); } @@ -1035,8 +1079,8 @@ ENTRY %some_2x3 () -> f32[2,3] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); } @@ -1049,8 +1093,8 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); } @@ -1064,8 +1108,8 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } )"; - auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + auto result = ParseHloString(original); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); } @@ -1078,7 +1122,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // The string will be parsed successfully but the output strings are not // exactly the same, because "3e2" is parsed into value 300 and will be @@ -1092,11 +1136,11 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, InvalidDimLabels) { @@ -1112,17 +1156,18 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; + ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + ExpectHasSubstr( - Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), - "expects dim labels pattern"); - - ExpectHasSubstr(Parse(tensorflow::strings::StrCat( - prefix, 
",dim_labels=010_1100->010", suffix)) - .status() - .error_message(), - "must have the same rank"); + "must have the same rank"); } TEST_F(HloParserTest, UnexpectedAttribute) { @@ -1137,8 +1182,8 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), - "unexpected attribute calls"); + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "unexpected attribute \"calls\""); } TEST_F(HloParserTest, MissingAttribute) { @@ -1153,7 +1198,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "attribute channel_id is expected but not seen"); } @@ -1169,7 +1214,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "'done' is not defined"); } @@ -1182,7 +1227,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { @@ -1196,7 +1241,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects padding_low and padding_high separated by '_'"); } @@ -1208,7 +1253,7 @@ ENTRY %test_comma.v4 () -> f32[] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { @@ -1218,7 +1263,7 @@ ENTRY %CustomCall () -> f32[1] { %constant = f32[1]{0} constant({12345}) ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "Shape of computation CustomCall, f32[1], is not compatible " "with that of its root instruction foo, f32[1,2,3]"); } @@ -1237,9 +1282,9 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); - auto program_layout = module.ValueOrDie()->entry_computation_layout(); + auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); auto param_layout = program_layout.parameter_layout(0).layout(); auto result_layout = program_layout.result_layout().layout(); @@ -1260,7 +1305,7 @@ c1 { c2 { const2 = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); } @@ -1271,7 +1316,7 @@ ENTRY consts { first = f32[1]{0} constant({12345}) last = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ( module.ValueOrDie()->entry_computation()->root_instruction()->name(), @@ -1286,7 +1331,7 @@ ENTRY c1 { ENTRY c2 { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), 
"expects only one ENTRY"); } @@ -1296,25 +1341,10 @@ ENTRY consts { ROOT const1 = f32[1]{0} constant({12345}) ROOT const2 = f32[1]{0} constant({12345}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "one computation should have only one ROOT"); } -TEST_F(HloParserTest, InstructionExists) { - const string original = R"(HloModule comp_exists -c1 { - instr = f32[1]{0} constant({12345}) -} -c2 { - instr = f32[1]{0} constant({67890}) -})"; - - ExpectHasSubstr(Parse(original).status().error_message(), - R"(was parsing 3:3: error: instruction previously defined here - instr = f32[1]{0} constant({12345}) - ^)"); -} - TEST_F(HloParserTest, ComputationExists) { const string original = R"(HloModule comp_exists comp { @@ -1323,12 +1353,52 @@ comp { comp { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), R"(was parsing 2:1: error: computation previously defined here comp { ^)"); } +TEST_F(HloParserTest, CrossComputationLookup) { + const string original = R"(HloModule cross_computation_lookup: +tcalla (a: (s32[], s32[])) -> (s32[], s32[]) { + ROOT aparam = (s32[], s32[]) parameter(0) +} + +tcallb (b: (s32[], s32[])) -> s32[] { + rparam = (s32[], s32[]) parameter(0) + ROOT gte0 = s32[] get-tuple-element(aparam), index=0 +} + +ENTRY entry { + param = (s32[], s32[]) parameter(0) + call0 = (s32[], s32[]) call(param), to_apply=tcalla + ROOT call1 = s32[] call(param), to_apply=tcallb +})"; + ExpectHasSubstr( + ParseHloString(original).status().error_message(), + "was parsing 8:39: error: instruction does not exist: aparam"); +} + +TEST_F(HloParserTest, ParseSharding) { + const string original = "{maximal device=42}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); +} + +TEST_F(HloParserTest, ParseWindow) { + Window original = window_util::MakeWindow({1, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(Window parsed, + ParseWindow(window_util::ToString(original))) + EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed)); +} + +TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { + const string original = "b0f_0io->b0f"; + TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums, + ParseConvolutionDimensionNumbers(original)); + EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); +} + } // namespace -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 5120775737bfa3..d8f1ab916b5c5c 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -90,7 +90,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { return Status::OK(); }; - string prefix = name().ToString() + ": pipeline start"; + string prefix = std::string(name()) + ": pipeline start"; bool changed = false; string message; TF_RETURN_IF_ERROR( @@ -98,12 +98,12 @@ StatusOr HloPassPipeline::Run(HloModule* module) { const string xla_dump_per_pass_hlo_proto_to = module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(), - "pipeline_start"); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, + std::string(name()), "pipeline_start"); } for (auto& 
pass : passes_) { - if (disabled_passes.count(pass->name().ToString()) > 0) { + if (disabled_passes.count(std::string(pass->name())) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; @@ -121,7 +121,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { run_invariant_checkers(StrCat("after running pass: ", pass->name()))); if (!xla_dump_per_pass_hlo_proto_to.empty()) { DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - name().ToString(), pass->name().ToString()); + std::string(name()), std::string(pass->name())); } changed |= changed_this_pass; diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 8e167633bb1347..4738e46f8aeb96 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -33,17 +33,27 @@ bool HloReachabilityMap::SetReachabilityToUnion( const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; + SetReachabilityToUnionHelper(inputs, instruction, &bit_vector); + return bit_vector != tmp_bit_vector_; +} +void HloReachabilityMap::FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction) { + SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); +} + +void HloReachabilityMap::SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { - bit_vector.SetToZero(); + bit_vector->SetToZero(); } - bit_vector.Set(GetIndex(instruction)); + bit_vector->Set(GetIndex(instruction)); for (const HloInstruction* input : inputs) { - bit_vector.OrWith(GetBitVector(input)); + bit_vector->OrWith(GetBitVector(input)); } - - return bit_vector != tmp_bit_vector_; } void HloReachabilityMap::SetReachable(const HloInstruction* a, diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 553ec11f6f9a29..69bb2b3cee6daf 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -57,6 +57,11 @@ class HloReachabilityMap { tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); + // As above, but faster because it does not check if the reachability changed. + void FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction); + // Sets entry so that IsReachable(a, b) will return true // // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency @@ -133,6 +138,11 @@ class HloReachabilityMap { return bit_vectors_[GetIndex(instruction)]; } + // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. + void SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector); + // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. 
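The only difference between SetReachabilityToUnion and the new FastSetReachabilityToUnion above is whether the previous bit vector is copied so the caller can learn whether anything changed. A hedged usage sketch (all names other than the two methods are placeholders):

```c++
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

// Sketch only: 'reachability', 'operands' and 'instr' stand in for whatever
// the caller already has on hand.
void UpdateReachability(
    xla::HloReachabilityMap* reachability,
    tensorflow::gtl::ArraySlice<const xla::HloInstruction*> operands,
    const xla::HloInstruction* instr) {
  // Needs the "did the set change?" answer, e.g. to drive a fixed-point loop:
  // pays for a copy of the bit vector plus a comparison.
  bool changed = reachability->SetReachabilityToUnion(operands, instr);
  (void)changed;

  // Does not need the answer: the fast variant skips the copy and comparison.
  reachability->FastSetReachabilityToUnion(operands, instr);
}
```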
int GetIndex(const HloInstruction* instruction) const { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index b0632448933df4..39b85de0f12024 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -273,9 +273,8 @@ ItemList GetUsers(const InstructionList& instruction_list, for (const BufferAlias& buffer_alias : points_to_analysis.GetBufferAliases(*logical_buffer)) { for (const HloInstruction* user : buffer_alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(buffer_alias.instruction(), - buffer_alias.index(), user, - points_to_analysis)) { + if (points_to_analysis.DoesNotUseOperandBuffer( + buffer_alias.instruction(), buffer_alias.index(), user)) { // The alias may be an operand of 'user', but the LogicalBuffer cannot // possibly be used by the instruction so ignore 'user'. This is the // case, for example, for the tuple element buffers in a GetTupleElement @@ -1216,7 +1215,7 @@ StatusOr HloRematerialization::Run( // Create initial sequence of HLO instructions. TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( *module, - [this](const LogicalBuffer& buffer) { + [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, scheduler_algorithm_)); diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 81c43db292a75d..e1f9d8efd49740 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,13 +19,12 @@ limitations under the License. 
#include #include -#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -37,7 +36,7 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } namespace { @@ -81,7 +80,7 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename, filename, &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } HloRunner::HloRunner(se::Platform* platform) { @@ -93,53 +92,108 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - CreateExecutable(std::move(module), run_hlo_passes)); - se::Stream stream(backend().default_stream_executor()); - stream.Init(); +StatusOr HloRunner::TransferLiteralToDevice( + const Literal& literal) { + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), backend().memory_allocator(), + backend().default_device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + backend().default_stream_executor(), literal, buffer)); + return std::move(buffer); +} - ServiceExecutableRunOptions service_run_options(GetServiceRunOptionsForDevice( - backend().default_device_ordinal(), &stream, nullptr)); - const ExecutableRunOptions& run_options = service_run_options.run_options(); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals) { + std::vector buffers; + for (const Literal* literal : literals) { + CHECK(literal != nullptr); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + TransferLiteralToDevice(*literal)); + buffers.push_back(std::move(buffer)); + } + return std::move(buffers); +} - // Copy arguments to device. 
- std::vector argument_buffers; - for (Literal* argument : arguments) { - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer argument_buffer, - backend().transfer_manager()->AllocateScopedShapedBuffer( - argument->shape(), run_options.allocator(), - run_options.device_ordinal())); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - stream.parent(), *argument, argument_buffer)); - argument_buffers.push_back(std::move(argument_buffer)); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals) { + std::vector literal_pointers; + literal_pointers.reserve(literals.size()); + for (const auto& literal : literals) { + literal_pointers.push_back(literal.get()); } + return TransferLiteralsToDevice(literal_pointers); +} + +StatusOr> HloRunner::TransferLiteralFromDevice( + const ShapedBuffer& buffer) { + return backend().transfer_manager()->TransferLiteralFromDevice( + backend().default_stream_executor(), buffer); +} - std::vector argument_buffer_ptrs; - argument_buffer_ptrs.reserve(argument_buffers.size()); - for (const auto& buf : argument_buffers) { - argument_buffer_ptrs.push_back(&buf); +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_buffers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); +} + +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice> arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(argument.get()); } + return Execute( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); +} - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result, - executable->ExecuteOnStreamWrapper( - &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs)); - - auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( - stream.parent(), result); - if (result_literal.ok()) { - VLOG(4) << "Executed binary and got result: " - << result_literal.ValueOrDie()->ToString(); - } else { - VLOG(4) << "Executed binary and got status: " - << result_literal.status().ToString(); +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Get service run options. 
+ se::Stream stream(backend().default_stream_executor()); + stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(module), run_hlo_passes)); + return executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); } - return result_literal; + return ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); } StatusOr>> HloRunner::ExecuteReplicated( @@ -171,7 +225,7 @@ StatusOr>> HloRunner::ExecuteReplicated( int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(absl::make_unique(executor)); + streams.push_back(MakeUnique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -198,7 +252,7 @@ StatusOr>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = absl::make_unique( + pool = MakeUnique( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -229,7 +283,7 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique(); + auto literal = MakeUnique(); TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, options.outfeed_shape, literal.get())); if (options.outfeed_values != nullptr) { @@ -278,14 +332,14 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( run_options.set_device_ordinal(device); run_options.set_stream(stream); run_options.set_allocator(backend().memory_allocator()); - run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); run_options.set_intra_op_thread_pool( backend().eigen_intra_op_thread_pool_device()); if (device_assignment != nullptr) { run_options.set_device_assignment(device_assignment); } - return ServiceExecutableRunOptions(run_options, backend().StreamBorrower(), - backend().inter_op_thread_pool()); + return ServiceExecutableRunOptions( + run_options, backend().StreamBorrower(), + /*xla_intra_op_thread_pool=*/backend().eigen_intra_op_thread_pool()); } Backend& HloRunner::backend() { @@ -296,4 +350,8 @@ Backend& HloRunner::backend() { return *backend_; } +const Backend& HloRunner::backend() const { + return const_cast(this)->backend(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 53f7c6fe4a0911..65537f07f56e74 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -102,6 +102,15 @@ class HloRunner { static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); + // Transfers data between the host and device. 
+ StatusOr TransferLiteralToDevice(const Literal& literal); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals); + StatusOr> TransferLiteralFromDevice( + const ShapedBuffer& buffer); + // Executes the given module with given literals as input and returns the // result as a Literal. // @@ -109,20 +118,25 @@ class HloRunner { // optimization. StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true); + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr> Execute( std::unique_ptr module, const tensorflow::gtl::ArraySlice> arguments, - bool run_hlo_passes = true) { - // Construct a vector of plain pointers for the arguments. - std::vector argument_pointers; - c_transform( - arguments, std::back_inserter(argument_pointers), - [](const std::unique_ptr& literal) { return literal.get(); }); - return Execute(std::move(module), argument_pointers, run_hlo_passes); - } + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + // As Execute(), but accepts and returns device buffers instead of host + // buffers. + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as @@ -137,6 +151,7 @@ class HloRunner { // This creates the backend lazily so it's possible to instantiate an // HloRunner in a program without any backends linked in. Backend& backend(); + const Backend& backend() const; private: // Creates an executable object given an HLO module. If run_hlo_passes is diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 1a767628f6e2d3..68b2cde83a2eb4 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -62,7 +62,34 @@ StatusOr MinimumMemoryForSequence( namespace { // Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage. +// sequence which minimizes memory usage by preferring to schedule the node that +// frees bigger buffer and defines smaller outputs. +// +// Note that list scheduler is a greedy algorithm which cannot guarantee a +// global optimal solution. As a counterexample, considering the following +// graph: +// +// +--> B ===> C -------+ +// A -> | | +// | v +// +--> D ---> F=======>G +// | ^ +// | | +// +--> E -----+ +// +// --> : Buffer with size 1 +// ==> : Buffer with size 2 +// +// The list scheduler will always try to defer scheduling B in a greedy way +// since its output buffer is bigger than input. The sequence it creates will +// be: +// A D E F B C G +// , which has a maximum memory usage of 6 (B is alive while F is executing). +// +// An optimal way to shedule the previous graph is: +// A B C D E F G +// , which has a maximum memory usage of 5 (when F is executing). 
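The counterexample graph in the comment above motivates comparing candidate schedules by their peak live memory. Below is a small standalone sketch (toy types, not XLA code) of that measurement under a simple model: each instruction defines one output buffer, a buffer stays live until its last user has run, and an instruction's operands and output are all live while it executes. XLA's own heap simulator accounts for aliasing and tuple buffers, so the exact figures quoted in the comment come from that machinery rather than this sketch.

```cpp
#include <algorithm>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>

struct Instruction {
  std::string name;
  std::vector<std::string> operands;  // names of producing instructions
  std::int64_t output_bytes = 0;      // size of this instruction's output
};

// Peak live memory of running `schedule` (a topological order of `graph`).
std::int64_t PeakMemory(const std::vector<Instruction>& graph,
                        const std::vector<std::string>& schedule) {
  std::map<std::string, const Instruction*> by_name;
  std::map<std::string, int> last_use;  // schedule position of last user
  for (const Instruction& instr : graph) by_name[instr.name] = &instr;
  for (int pos = 0; pos < static_cast<int>(schedule.size()); ++pos) {
    for (const std::string& operand : by_name.at(schedule[pos])->operands) {
      last_use[operand] = pos;
    }
  }
  std::int64_t live = 0, peak = 0;
  std::set<std::string> live_buffers;
  for (int pos = 0; pos < static_cast<int>(schedule.size()); ++pos) {
    const Instruction& instr = *by_name.at(schedule[pos]);
    live += instr.output_bytes;   // output allocated for the op
    live_buffers.insert(instr.name);
    peak = std::max(peak, live);  // operands are still live at this point
    for (const std::string& operand : instr.operands) {
      if (last_use[operand] == pos && live_buffers.erase(operand)) {
        live -= by_name.at(operand)->output_bytes;  // last use: free it
      }
    }
  }
  return peak;
}
```

Feeding this the A..G graph from the comment with the two orders shown there illustrates why the greedy list order can end up worse than the optimal one.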
+// class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions @@ -70,8 +97,11 @@ class ListScheduler { static StatusOr> Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - ListScheduler scheduler(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + ListScheduler scheduler(computation, points_to_analysis, size_function, + memory_by_computation); return scheduler.CreateSchedule(); } @@ -92,10 +122,13 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), - size_function_(size_function) { + size_function_(size_function), + memory_by_computation_(memory_by_computation) { // Create a map containing the LogicalBuffer uses for each HLO // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by @@ -185,6 +218,12 @@ class ListScheduler { } // Returns the number of bytes freed if the HLO instruction is scheduled. + // If the instruction calls subcomputations, we count the memory used by the + // subcomputations as memory "defined" by the instruction. This is not + // entirely accurate, because subcomputation memory will be freed after the + // instruction finishes. But it is more accurate than not taking + // subcomputations into account at all. In the future, we may improve + // accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -194,7 +233,19 @@ class ListScheduler { freed_bytes += size_function_(*buffer); } } - return freed_bytes - entry.bytes_defined; + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. + int64 max_subcomputation_bytes = 0; + for (const auto* c : entry.instruction->called_computations()) { + auto it = memory_by_computation_.find(c); + if (it != memory_by_computation_.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; } // Constructs the scheduling priority of the given instruction. @@ -248,6 +299,8 @@ class ListScheduler { auto best_it = ready_queue.end(); --best_it; const HloInstruction* best = best_it->second.instruction; + VLOG(2) << "Schedule instruction: " << best->ToShortString() + << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); ready_instructions.erase(best); schedule.push_back(best); @@ -315,6 +368,11 @@ class ListScheduler { const HloComputation& computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; + // Computations are analyzed in post-order. When scheduling an instruction + // that includes subcomputations, such as a while loop, we use this map to + // look up the memory needed by subcomputations. 
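As a standalone illustration of the accounting described above for `BytesFreedIfScheduled` (toy types, not XLA's): bytes freed from operand buffers whose last unscheduled use this is, minus bytes newly defined, minus the memory of the largest called subcomputation, since subcomputations never run in parallel.

```cpp
#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct ReadyListEntry {
  std::int64_t bytes_defined = 0;  // new buffers this instruction creates
  // Buffer size paired with the number of not-yet-scheduled uses left.
  std::vector<std::pair<std::int64_t, int>> used_buffer_unscheduled_use_counts;
  std::vector<std::string> called_computations;
};

std::int64_t BytesFreedIfScheduled(
    const ReadyListEntry& entry,
    const std::map<std::string, std::int64_t>& memory_by_computation) {
  std::int64_t freed_bytes = 0;
  for (const auto& size_and_uses : entry.used_buffer_unscheduled_use_counts) {
    // If this is the last unscheduled use, the buffer can be freed afterwards.
    if (size_and_uses.second == 1) freed_bytes += size_and_uses.first;
  }
  // Count only the largest subcomputation: they do not execute in parallel.
  std::int64_t max_subcomputation_bytes = 0;
  for (const std::string& c : entry.called_computations) {
    auto it = memory_by_computation.find(c);
    if (it != memory_by_computation.end()) {
      max_subcomputation_bytes = std::max(max_subcomputation_bytes, it->second);
    }
  }
  return freed_bytes - entry.bytes_defined - max_subcomputation_bytes;
}
```

The list scheduler then prefers the ready instruction with the largest value of this quantity, which is why instructions that call expensive subcomputations get deferred.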
+ const tensorflow::gtl::FlatMap& + memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. tensorflow::gtl::FlatMap> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function, + memory_by_computation); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation); +} + +} // namespace + StatusOr MinimumMemoryForComputation( const HloComputation& computation, const std::vector& sequence, @@ -352,24 +428,12 @@ StatusOr MinimumMemoryForComputation( return result.heap_size; } -StatusOr> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm) { - return algorithm(computation, points_to_analysis, size_function); - } - return DefaultMemoryScheduler(computation, points_to_analysis, size_function); -} - -} // namespace - StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { // This ordering is based on DFS post-order, with a heuristic to decide which // operand to visit first. The heuristic is based on 'extra_users', which is // simply users-1 for each instruction. By subtracting 1, we're saying that @@ -395,6 +459,13 @@ StatusOr> DFSMemoryScheduler( extra_users[hlo] += extra_users[operand]; total_sizes[hlo] += total_sizes[operand]; } + // total_sizes[hlo] transitively includes the sizes of all nodes that + // lead to it. But computation is a DAG, so we are double-counting nodes, + // which can lead to overflows for large programs. + // cumulative_total_size caps the size to prevent overflows. + // NOTE(dimvar): this is quite ugly and should be changed. It's unclear + // why we care about transitive sizes; when scheduling a node, its input + // and output buffers should be all that matters, not its "history". 
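A standalone sketch of the accumulation pattern used by the DFS heuristic above (toy graph types, not XLA's): `extra_users` and `total_sizes` are summed over operands in post order, and `total_sizes` is clamped to the running cumulative size so that double counting over a DAG cannot overflow, exactly the concern raised in the note above.

```cpp
#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::string> operands;
  int user_count = 0;             // number of users of this node
  std::int64_t output_bytes = 0;  // size of this node's output buffer
};

// `post_order` lists every node after all of its operands.
void AccumulateDfsStats(const std::vector<Node>& post_order,
                        std::map<std::string, std::int64_t>* extra_users,
                        std::map<std::string, std::int64_t>* total_sizes) {
  std::int64_t cumulative_total_size = 0;
  for (const Node& node : post_order) {
    (*extra_users)[node.name] = node.user_count == 0 ? 0 : node.user_count - 1;
    (*total_sizes)[node.name] = node.output_bytes;
    cumulative_total_size += node.output_bytes;
    for (const std::string& operand : node.operands) {
      (*extra_users)[node.name] += (*extra_users)[operand];
      (*total_sizes)[node.name] += (*total_sizes)[operand];
    }
    // A DAG reaches shared operands through multiple paths, so the sum can
    // blow up; cap it at the cumulative size seen so far, as the code above
    // does, to keep the heuristic value bounded.
    (*total_sizes)[node.name] =
        std::min((*total_sizes)[node.name], cumulative_total_size);
  }
}
```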
total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); } CHECK_EQ(extra_users.size(), computation.instruction_count()); @@ -421,52 +492,87 @@ StatusOr> DFSMemoryScheduler( })); CHECK_EQ(sequence.size(), computation.instruction_count()); return sequence; -} +} // namespace xla StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - return ListScheduler::Run(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + return ListScheduler::Run(computation, points_to_analysis, size_function, + memory_by_computation); +} + +StatusOr> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + const auto& post_order = computation.MakeInstructionPostOrder(); + return std::vector{post_order.begin(), + post_order.end()}; } StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // We try both a list-scheduler based ordering and a DFS based ordering, and - // choose whichever returns a lower min-memory, not accounting for - // fragmentation. - // - // Note that this is just a heuristic. One obvious inaccuracy is that the - // memory required for sub-computations might be different when considered - // within the caller's context. But it's good enough for now. + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We try a few schedulers and choose whichever returns a lower min-memory, + // not accounting for fragmentation. + // - List is a scheduler that uses greedy heuristics. + // - DFS visits HLOs in postorder, with a heuristic to decide the order of + // children. + // - Postorder does not use any heuristics. + // List wins for most of our benchmarks; postorder-based schedulers win for + // some RNNs. 
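The selection logic described above (run the list, DFS, and post-order schedulers, keep whichever sequence needs the least memory) reduces to a pick-the-minimum pattern. A minimal standalone sketch, with a generic estimator standing in for `MinimumMemoryForComputation`; all names here are illustrative:

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

using Sequence = std::vector<std::string>;
using MemoryEstimator = std::function<std::int64_t(const Sequence&)>;

// Returns the candidate sequence with the smallest estimated peak memory.
Sequence ChooseMinMemorySequence(
    const std::vector<std::pair<std::string, Sequence>>& candidates,
    const MemoryEstimator& minimum_memory_for_sequence) {
  const Sequence* best = nullptr;
  std::int64_t best_memory = 0;
  for (const auto& named_candidate : candidates) {
    std::int64_t memory = minimum_memory_for_sequence(named_candidate.second);
    if (best == nullptr || memory < best_memory) {
      best = &named_candidate.second;
      best_memory = memory;
    }
  }
  return best != nullptr ? *best : Sequence{};
}
```

In the default scheduler the candidate set would be the list, DFS, and post-order sequences, matching the three branches in the code that follows.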
TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListMemoryScheduler(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN( - std::vector dfs_sequence, - DFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + DFSMemoryScheduler(computation, points_to_analysis, + size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); - if (list_memory <= dfs_memory) { + TF_ASSIGN_OR_RETURN( + std::vector post_order_sequence, + PostOrderMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); + TF_ASSIGN_OR_RETURN( + const int64 post_order_memory, + MinimumMemoryForComputation(computation, post_order_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory post order sequence: " + << HumanReadableNumBytes(post_order_memory); + + auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); + + if (min_memory == list_memory) { VLOG(2) << "Chose min-memory list sequence: " << HumanReadableNumBytes(list_memory); return list_sequence; - } else { + } else if (min_memory == dfs_memory) { VLOG(2) << "Chose min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); return dfs_sequence; + } else { + VLOG(2) << "Chose min-memory post_order sequence: " + << HumanReadableNumBytes(post_order_memory); + return post_order_sequence; } } @@ -477,24 +583,32 @@ CreateMemoryMinimizingSequence(const HloModule& module, SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - for (const auto* computation : module.MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN( - sequence[computation], - CreateMemoryMinimizingSequence(*computation, *points_to_analysis, - size_function, algorithm)); + tensorflow::gtl::FlatMap memory_by_computation; + for (const auto* computation : module.MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + CreateMemoryMinimizingSequence( + *computation, *points_to_analysis, size_function, + algorithm, memory_by_computation)); + memory_by_computation[computation] = + MinimumMemoryForComputation(*computation, one_computation_sequence, + *points_to_analysis, size_function) + .ValueOrDie(); + sequence[computation] = std::move(one_computation_sequence); + } } return sequence; } StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { + const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); + tensorflow::gtl::FlatMap empty_map; return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, algorithm); + size_function, nullptr, empty_map); } } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 068e68383deb17..49b927eefd24f4 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -34,26 +34,47 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); +// Returns the minimum memory required to compute the given computation, +// assuming no fragmentation. +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + // A memory scheduler computes an execution sequence for the HLO instructions in // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. typedef std::function>( const HloComputation&, const TuplePointsToAnalysis&, - const LogicalBuffer::SizeFunction&)> + const LogicalBuffer::SizeFunction&, + const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // DFS-order scheduler StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// Naive Post Order scheduler +StatusOr> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, @@ -61,7 +82,9 @@ StatusOr> DFSMemoryScheduler( StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes @@ -72,10 +95,10 @@ CreateMemoryMinimizingSequence(const HloModule& module, const MemorySchedulerAlgorithm& algorithm = {}); // Overload of above that computes the sequence for a single computation. +// Currently only used by the GPU backend. 
StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); + const LogicalBuffer::SizeFunction& size_function); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 74544c4a67a819..db7ef6f0d4bd96 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -77,7 +77,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - auto size_fn = [](const LogicalBuffer& buffer) { + auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; @@ -124,7 +124,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { + CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. 
@@ -158,9 +158,9 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); - auto size_fn = [](const LogicalBuffer& buffer) { + auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( @@ -190,5 +190,199 @@ ENTRY root { instructions_by_name.at("e"))); } +TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) + // ROOT %not-equal-to = pred[] not-equal-to( + // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) + // } + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // ROOT %subtract = f32[4]{0} subtract( + // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) + // } + // %SubcomputationsNotAccounted () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant( + // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) + // %transpose = f32[2,4]{1,0} transpose( + // f32[2,4]{1,0} %constant.3), dimensions={0,1} + // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), + // condition=%WhileCond, + // body=%WhileBody + // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} + // ROOT %add = f32[2,4]{1,0} add( + // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 
4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }, + ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // This schedule is an example of List's greedy heuristics being suboptimal. + // The while_loop is more expensive than transpose, so it would have been + // better to schedule it first, instead of during the busy time. + EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); + EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); +} + +TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { + auto builder = HloComputation::Builder(TestName()); + const auto TUPLE_SIZE = 1; + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); + + // Wrap lit in abs because constants are considered free by + // IgnoreInstruction, and it skews the accounting. + auto lit = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1, 1}))); + auto abs_const = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); + + auto abs_abs1 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( + tensorflow::gtl::ArraySlice({abs_abs1}))); + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto abs_abs2 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, + tuple_elm, abs_abs2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(*module, + [&TUPLE_SIZE](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // tuple allocates the tuple buffer and doesn't free anything. + // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. + // abs_abs2 should be scheduled before tuple by List. 
+ EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); +} + +TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); + HloComputation::Builder builder(TestName()); + + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1}))); + auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 2, 3, 4, 5}))); + auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0, 2, 4, 6, 8}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); + + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); + + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto fusion = computation->CreateFusionInstruction( + {tuple, mul, add}, HloInstruction::FusionKind::kLoop); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // fusion allocates memory for the tuple elements and doesn't free anything, + // so it's more expensive than exp. + EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 994de441237493..58224ef870096a 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -49,9 +49,6 @@ string HloSharding::ToString() const { return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); } - string result = StrCat("{", (replicated_ ? " replicated" : ""), - (maximal_ ? 
" maximal" : "")); - if (replicated_) { return "{replicated}"; } else if (maximal_) { @@ -126,6 +123,24 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { return index; } +StatusOr> HloSharding::AsShapeTree( + const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + int64 num_leaves = result.leaf_count(); + TF_RET_CHECK(num_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << num_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + return std::move(result); + } else { + return ShapeTree(shape, *this); + } +} + StatusOr HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { @@ -367,10 +382,11 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); - ShapeTree sub_shape_tree(ShapeUtil::GetSubshape(shape, index), - Replicate()); + Shape sub_shape = ShapeUtil::GetSubshape(shape, index); + ShapeTree sub_shape_tree(sub_shape, Replicate()); sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - return Tuple(sub_shape_tree); + return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) + : sub_shape_tree.element(ShapeIndex({})); } std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 2b8e757f42991f..f4a0fb626f2c3e 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -99,6 +99,9 @@ class HloSharding { static bool IsReservedDevice(int64 device) { return device < 0; } OpSharding ToProto() const; + + // Note that this string canonically has outer curly braces, e.g. + // "{replicated}". string ToString() const; // Validate that this sharding can be applied to a tensor with shape `shape`. @@ -160,19 +163,9 @@ class HloSharding { // tuple, if IsTuple, or a ShapeTree with a single element containing this // sharding. Only the leaf elements are populated. This creates a new // ShapeTree object so is not cheap. + StatusOr> AsShapeTree(const Shape& shape) const; ShapeTree GetAsShapeTree(const Shape& shape) const { - if (IsTuple()) { - ShapeTree result(shape, HloSharding::Replicate()); - CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), - tuple_elements_.size()); - auto it = tuple_elements_.begin(); - for (auto& index_to_sharding : result.leaves()) { - index_to_sharding.second = *it++; - } - return result; - } else { - return ShapeTree(shape, *this); - } + return AsShapeTree(shape).ValueOrDie(); } // Retrieves the sub sharding at a given index, out of a tuple sharding. @@ -208,6 +201,12 @@ class HloSharding { return h; } + struct Hasher { + size_t operator()(const HloSharding& sharding) const { + return sharding.Hash(); + } + }; + // Gets the tile shape. // REQUIRES: !IsTileMaximal() && !IsTuple() const Shape& tile_shape() const { return tile_shape_; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc new file mode 100644 index 00000000000000..82cff2a4b7146c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -0,0 +1,401 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +namespace { + +struct PassThrough { + PassThrough(HloInstruction* user, HloInstruction* operand) + : user(user), operand(operand) {} + + HloInstruction* user = nullptr; + HloInstruction* operand = nullptr; +}; + +void SetDeviceSharding(HloInstruction* instruction, int64 device) { + VLOG(4) << " " << instruction->name() << " to device " << device; + instruction->set_device_sharding(device); +} + +tensorflow::gtl::optional ShardingUniqueDevice( + const HloSharding& sharding) { + if (sharding.IsTileMaximal()) { + auto device = sharding.UniqueDevice(); + if (device.ok()) { + return device.ValueOrDie(); + } + } + return tensorflow::gtl::optional(); +} + +bool ShardingMatches(const HloSharding& sharding1, + const HloSharding& sharding2) { + auto device1 = ShardingUniqueDevice(sharding1); + if (device1) { + auto device2 = ShardingUniqueDevice(sharding2); + if (device2) { + return *device1 == *device2; + } + } + // Anything which is not tile maximal with unique device, gets a full sharding + // compare. + return sharding1 == sharding2; +} + +// When we create domains, they are never "empty", where with empty we mean +// that a kDomain instruction has as operand another kDomain instruction of the +// same kind. +// But when the HLO optimizations are run, empty domains can be created. +// For example: +// +// Domain(device=None, device=0) -> +// Tuple(device=0) -> +// GTE(device=0) -> +// Domain(device=0, device=None) +// +// In that case the tuple simplifier could create something like: +// +// Domain(device=None, device=0) -> Domain(device=0, device=None) +// +// Which is a so called empty domain. +// In the case above, crossing an empty domain which was transiting through +// device 0, requires the normalization phase to fixup the empty domain by +// adding back a Tuple+GTE pair with the proper device. +// One particular case where this can create problems is the result of the +// entry computation, where the GTE assignments are used by TF to tell the +// XLA where the results should be sent. 
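The `ShardingMatches` helper introduced above compares two shardings loosely: if both are tile-maximal with a unique assigned device, only the devices need to agree; anything else falls back to full equality. A standalone sketch of that rule with a toy sharding type (not XLA's `HloSharding`):

```cpp
#include <cstdint>
#include <optional>

struct ToySharding {
  bool tile_maximal = false;
  // Set when the sharding assigns everything to one device.
  std::optional<std::int64_t> unique_device;
  int tile_id = 0;  // stands in for the full tiling description

  bool operator==(const ToySharding& other) const {
    return tile_maximal == other.tile_maximal &&
           unique_device == other.unique_device && tile_id == other.tile_id;
  }
};

std::optional<std::int64_t> ShardingUniqueDevice(const ToySharding& sharding) {
  return sharding.tile_maximal ? sharding.unique_device : std::nullopt;
}

bool ShardingMatches(const ToySharding& a, const ToySharding& b) {
  auto device_a = ShardingUniqueDevice(a);
  auto device_b = ShardingUniqueDevice(b);
  if (device_a && device_b) {
    return *device_a == *device_b;  // single-device shardings: compare devices
  }
  // Anything which is not tile maximal with a unique device gets a full
  // comparison.
  return a == b;
}
```

This is what keeps domain creation from inserting kDomain instructions between two instructions that are effectively pinned to the same device.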
+std::vector LocatePassThroughDomainLinks( + const DomainMetadata::Domain& domain) { + std::vector pass_through; + for (HloInstruction* instruction : domain.enter_domains) { + CHECK(instruction->opcode() == HloOpcode::kDomain) + << "Instruction is not a kDomain: " << instruction->ToString(); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(user) != 0) { + pass_through.emplace_back(user, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " " << user->ToString(); + VLOG(2) << " " << instruction->ToString(); + } + } + } + return pass_through; +} + +Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + for (auto& pass_through : LocatePassThroughDomainLinks(domain)) { + HloInstruction* tuple = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateTuple({pass_through.operand})); + HloInstruction* gte = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), + tuple, 0)); + gte->set_sharding(sharding); + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } + return Status::OK(); +} + +std::unique_ptr CloneShardingForDomain( + const HloSharding& sharding) { + auto device = ShardingUniqueDevice(sharding); + if (!device) { + return MakeUnique(sharding); + } + return MakeUnique(HloSharding::AssignDevice(*device)); +} + +Status ApplyDomainDeviceSharding(const DomainMetadata::Domain& domain, + int64 device) { + VLOG(4) << "Applying device " << device << " sharding"; + for (HloInstruction* instruction : domain.instructions) { + // We only change instructions without sharding, since otherwise we might + // mess up with eventual HLO passes which has knowledge of it. + if (!instruction->has_sharding()) { + SetDeviceSharding(instruction, device); + } else { + VLOG(4) << " " << instruction->name() << " already has sharding " + << instruction->sharding(); + } + } + return Status::OK(); +} + +// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. +// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() +// sharding will be returned. +ShapeTree GetTupleSharding(HloInstruction* tuple) { + if (tuple->has_sharding()) { + return tuple->sharding().GetAsShapeTree(tuple->shape()); + } + return ShapeTree(tuple->shape(), HloSharding::Replicate()); +} + +// Retrieves the sharding of operand, asked from a user instruction which is +// within domain. If operand is a kDomain, it means that sharding argument is +// the operand sharding, otherwise the operand's own sharding will be returned. +const HloSharding* GetOperandSharding(const HloInstruction* operand, + const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); + // Here the user of operand is within the domain instruction set, and since it + // is user of operand, we need to look into the enter_domains set. If this is + // not a kDomain within the user domains set, then return the operand + // sharding, if any. + if (operand->opcode() != HloOpcode::kDomain || + domain.enter_domains.count(const_cast(operand)) == 0) { + return operand->has_sharding() ? &operand->sharding() : nullptr; + } + // At this point operand is a kDomain of the currently processed domain, so we + // can refer to sharding as the domain sharding. 
+ return &sharding; +} + +// Tries to propagate the sharding information into the instructions that are +// part of the domain, in a post order manner (operand propagate to user). +StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + int64 assigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (instruction->has_sharding()) { + continue; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + HloInstruction* tuple = instruction->mutable_operand(0); + const HloSharding* tuple_sharding = + GetOperandSharding(tuple, domain, sharding); + if (tuple_sharding != nullptr) { + TF_RET_CHECK(tuple_sharding->IsTuple()) << tuple->ToString(); + HloSharding sub_sharding = tuple_sharding->GetSubSharding( + tuple->shape(), {instruction->tuple_index()}); + VLOG(4) << " " << instruction->name() << " to sharding " + << sub_sharding; + instruction->set_sharding(sub_sharding); + ++assigned; + } + } else if (instruction->opcode() == HloOpcode::kTuple) { + int64 tuple_assigned = 0; + ShapeTree shape_tree = GetTupleSharding(instruction); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + const HloSharding* operand_sharding = + GetOperandSharding(instruction->operand(i), domain, sharding); + if (operand_sharding != nullptr && + shape_tree.element({i}) != *operand_sharding) { + *shape_tree.mutable_element({i}) = *operand_sharding; + ++tuple_assigned; + } + } + if (tuple_assigned > 0) { + HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); + VLOG(4) << " " << instruction->name() << " to sharding " + << tuple_sharding; + instruction->set_sharding(tuple_sharding); + ++assigned; + } + } else { + // If all the operand of the given instruction has the same single device + // assignment, assign that device to this instruction as well. + const HloSharding* common_sharding = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + const HloSharding* operand_sharding = + GetOperandSharding(operand, domain, sharding); + if (operand_sharding != nullptr) { + if (common_sharding != nullptr && + *common_sharding != *operand_sharding) { + common_sharding = nullptr; + break; + } + common_sharding = operand_sharding; + } + } + if (common_sharding != nullptr) { + VLOG(4) << " " << instruction->name() << " to sharding " + << *common_sharding; + instruction->set_sharding(*common_sharding); + ++assigned; + } + } + } + return assigned; +} + +Status ApplyDomainSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + auto device = ShardingUniqueDevice(sharding); + if (device) { + // Shortcut the simple case. We have a unique device sharding, so we call + // the ApplyDomainDeviceSharding() API which will apply array or tuple + // shaped device sharding to the domain instructions. + return ApplyDomainDeviceSharding(domain, *device); + } + VLOG(1) << "Assigning non-trivial sharding " << sharding; + for (;;) { + TF_ASSIGN_OR_RETURN(int64 assigned, + ApplyDomainShardingPass(domain, sharding)); + if (assigned == 0) { + break; + } + } + int64 unassigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (!instruction->has_sharding()) { + LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); + ++unassigned; + } + } + // Should we error out if unassigned > 0? + return Status::OK(); +} + +// Creates a kDomain instruction to be placed between instruction and operand. 
+// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +std::unique_ptr CreateDomain(HloInstruction* instruction, + HloInstruction* operand) { + const HloSharding* instruction_sharding = + instruction->has_sharding() ? &instruction->sharding() : nullptr; + const HloSharding* operand_sharding = + operand->has_sharding() ? &operand->sharding() : nullptr; + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && operand_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && operand_sharding != nullptr && + ShardingMatches(*instruction_sharding, *operand_sharding)) { + return nullptr; + } + std::unique_ptr real_instruction_sharding; + std::unique_ptr real_operand_sharding; + if (instruction_sharding != nullptr) { + real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); + } + if (operand_sharding != nullptr) { + real_operand_sharding = CloneShardingForDomain(*operand_sharding); + } + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (real_instruction_sharding != nullptr + ? real_instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (real_operand_sharding != nullptr + ? real_operand_sharding->ToString() + : "None"); + + std::unique_ptr operand_side_metadata = + MakeUnique(std::move(real_operand_sharding)); + std::unique_ptr user_side_metadata = + MakeUnique(std::move(real_instruction_sharding)); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +StatusOr> ExtractOriginalCommonSharding( + tensorflow::gtl::ArraySlice instructions) { + // If we are here, all the instructions being passed had the same sharding + // (or no sharding), by the means of the ShardingMatches() API. + // As such, no kDomain was inserted, and here we are asked to extract the + // original common sharding. + // All the instructions passed to this API are part of the same computation. + const HloSharding* sharding = nullptr; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + if (sharding == nullptr) { + sharding = &instruction->sharding(); + } else { + TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) + << "Sharding " << *sharding << " does not match the one in " + << instruction->ToString(); + } + } + } + if (sharding == nullptr) { + return std::unique_ptr(); + } + VLOG(4) << "Extracted sharding is " << *sharding; + return CloneShardingForDomain(*sharding); +} + +} // namespace + +std::unique_ptr ShardingMetadata::Clone() const { + std::unique_ptr sharding; + if (sharding_ != nullptr) { + sharding = MakeUnique(*sharding_); + } + return MakeUnique(std::move(sharding)); +} + +bool ShardingMetadata::Matches(const DomainMetadata& other) const { + const ShardingMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a ShardingMetadata, then it is clearly a no match. + return false; + } + if (sharding_ == nullptr) { + return other_ptr->sharding_ == nullptr; + } + return other_ptr->sharding_ != nullptr + ? ShardingMatches(*sharding_, *other_ptr->sharding_) + : false; +} + +string ShardingMetadata::ToString() const { + return sharding_ != nullptr ? 
sharding_->ToString() : "None"; +} + +Status ShardingMetadata::NormalizeInstructions( + const DomainMetadata::Domain& domain) const { + if (sharding_ != nullptr) { + VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_)); + TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_)); + } + return Status::OK(); +} + +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) { + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + ExtractOriginalCommonSharding(domain.instructions)); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString() + << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + } else { + VLOG(1) << "Unable to find common sharding"; + } + return Status::OK(); +} + +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand) { + return CreateDomain(instruction, operand); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h new file mode 100644 index 00000000000000..ec162c34904ee2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// A DomainMetadata implementation that internally wraps a sharding attribute. +class ShardingMetadata : public DomainMetadata { + public: + explicit ShardingMetadata(std::unique_ptr sharding) + : sharding_(std::move(sharding)) {} + + std::unique_ptr Clone() const override; + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override; + + string ToString() const override; + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override; + + static tensorflow::StringPiece KindName() { return "sharding"; } + + private: + std::unique_ptr sharding_; +}; + +// Within a set of instructions which had common sharding attributes before +// entring the HLO passes pipeline, apply sharding heuristics and normalize the +// instructions whose sharding deviates from the one which is inferred as to be +// the original one. +// Policy wise, HLO passes are allowed to create new unassigned instructions, +// but if they do create assigned ones, they have to conform to the ones around. 
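`NormalizeShardingDomain`, declared just below, applies the policy spelled out above: HLO passes may leave new instructions unassigned, and normalization back-fills them with the sharding recovered from the domain, while instructions that already carry a sharding are left alone. A minimal standalone sketch of that back-fill step, with toy types and a plain string standing in for the sharding; the real code also handles tuple shardings and pass-through domains:

```cpp
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct ToyInstruction {
  std::string name;
  std::optional<std::string> sharding;  // e.g. "{maximal device=0}"
};

// Assigns `domain_sharding` to every instruction in the domain that has no
// sharding yet; instructions with an explicit sharding are left untouched.
int NormalizeDomain(std::vector<ToyInstruction*>& domain_instructions,
                    const std::string& domain_sharding) {
  int assigned = 0;
  for (ToyInstruction* instruction : domain_instructions) {
    if (!instruction->sharding.has_value()) {
      instruction->sharding = domain_sharding;
      ++assigned;
    } else if (*instruction->sharding != domain_sharding) {
      // Policy check: an instruction a pass assigned itself is expected to
      // conform to the sharding of the domain it sits in.
      std::cerr << instruction->name << " deviates from the domain sharding\n";
    }
  }
  return assigned;
}
```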
+Status NormalizeShardingDomain(const DomainMetadata::Domain& domain); + +// Given an HLO graph edge between instruction and one of its operands, creates +// a ShardingMetadata based kDomain instruction if the sharding between +// instruction and operand changes. Returns nullptr if there is no need for a +// domain separation. +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 3bf0d25efb7fad..ee7133689b1534 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_sharding.h" - #include #include #include #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -312,5 +311,48 @@ TEST_F(HloShardingTest, OstreamTest) { EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); } +TEST_F(HloShardingTest, ParseHloString) { + auto check = [](const HloSharding& sharding) { + TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding, + ParseSharding(sharding.ToString())); + EXPECT_EQ(sharding, parsed_sharding); + }; + check(HloSharding::Replicate()); + check(HloSharding::AssignDevice(2)); + check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}}))); + // Empty tuple. + check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), {})); + { + // Non-nested tuple. + auto tuple_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})}); + check(HloSharding::Tuple( + tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)})); + } + { + // Nested tuple. + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})})}); + std::vector leaf_shardings = { + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)}; + ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); + // Assign leaf_shardings to sharding_tree leaves. + auto it = leaf_shardings.begin(); + for (auto& index_to_sharding : sharding_tree.leaves()) { + index_to_sharding.second = *it++; + } + check(HloSharding::Tuple(sharding_tree)); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index f8d98f06785967..be156d765dc10d 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,6 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h similarity index 84% rename from tensorflow/compiler/xla/tools/parser/hlo_token.h rename to tensorflow/compiler/xla/service/hlo_token.h index 7928bee5c2097f..533429608bc2e1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ #include @@ -22,9 +22,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Defines different kinds of tokens in a hlo module string. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. enum class TokKind { // Markers kEof, @@ -72,7 +74,6 @@ enum class TokKind { string TokKindToString(TokKind kind); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 05b7dce3d1ecf9..7b27dbfec376b8 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -29,9 +29,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -69,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const HloUse& use) { HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi) - : id_(id), is_phi_(is_phi) { + : BufferValue(instruction, index, id), is_phi_(is_phi) { // The defining position is always the first element in the positions_ vector. positions_.push_back(HloPosition{instruction, index}); } @@ -90,8 +92,8 @@ string HloValue::ToShortString() const { string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) ? defining_index().ToString() : ""; - return StrCat(id_, " ", is_phi_ ? "PHI " : "", defining_instruction()->name(), - index_str); + return StrCat(id(), " ", is_phi_ ? 
"PHI " : "", + defining_instruction()->name(), index_str); } string HloValue::ToString(int indent) const { diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 2a711e8b42590c..a1151f65e07dff 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -16,16 +16,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ -#include +#include #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -80,30 +84,9 @@ struct HloUse { std::ostream& operator<<(std::ostream& out, const HloUse& use); -// Class describing a value used by the dataflow analysis. XLA arrays are -// trivially a single HloValue. Tuples are made up of more than one HloValue: an -// HloValue for the pointer vector, and an HloValue for each child element. -// -// Every HloValue is defined by a particular instruction and most instructions -// define only a single HloValue. Instructions which define a single HloValue -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single HloValue -// which is a vector of pointers to the values containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple values only the top-level HloValue (the vector of pointers) is -// defined by the Tuple instruction. The values containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one HloValue. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. These tuple-shaped instructions do -// not assemble a tuple from existing HloValues like the Tuple instruction does, -// but rather define all the HloValues in the tuple. -class HloValue { +// HloDataflowAnalysis uses this subclass of BufferValue. +class HloValue : public BufferValue { public: - using Id = int64; - // Predicate comparing HloValues by increasing id, useful for std::sort. static bool IdLessThan(const HloValue* a, const HloValue* b) { return a->id() < b->id(); @@ -120,6 +103,7 @@ class HloValue { // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); + ~HloValue() override {} // Sets the positions in the module at which the HloValue appears. Updates // uses. Should be called once and only once. The defining position should not @@ -127,10 +111,6 @@ class HloValue { void SetPositionsAndComputeUses( tensorflow::gtl::ArraySlice positions); - // Return a unique identifier for this HloValue. This value is used for stable - // sorting and iteration - Id id() const { return id_; } - // Returns whether this value is a phi value. 
bool is_phi() const { return is_phi_; } @@ -142,12 +122,18 @@ class HloValue { return defining_position().instruction; } + HloInstruction* instruction() const override { + return defining_instruction(); + } + // Return the shape index at which this HloValue is defined in the output of // its defining instruction. const ShapeIndex& defining_index() const { return defining_position().index; } + const ShapeIndex& index() const override { return defining_index(); } + // Return the shape of this HloValue. - const Shape& shape() const { return defining_position().shape(); } + const Shape& shape() const override { return defining_position().shape(); } // Return all positions of the HloValue in the module. const std::vector& positions() const { return positions_; } @@ -164,12 +150,11 @@ class HloValue { // Return a single-line string representation of the value. string ToShortString() const; - string ToString(int indent = 0) const; + string ToString(int indent) const; - private: - // Unique identifier for this HloValue. Used for stable sorting and iteration. - const Id id_; + string ToString() const override { return ToString(0); } + private: // Whether this instruction is a phi value. const bool is_phi_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8a30cbf9cd622f..9cfd8a9bf74bc6 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -106,9 +106,7 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the @@ -116,7 +114,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // produces no HLO value in the graph. 
if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InvalidArgument( + return InternalError( "Expected outfeed to have shape compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), @@ -127,12 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return tensorflow::Status::OK(); + return Status::OK(); } -Status ShapeVerifier::HandleRng(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -164,7 +160,7 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { @@ -183,7 +179,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { operand_shape.dimensions(operand_dimension)) << broadcast->ToString() << " operand shape " << operand_shape; } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { @@ -191,7 +187,7 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { @@ -200,22 +196,18 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { transpose->operand(0)->shape(), transpose->dimensions())); } -Status ShapeVerifier::HandleParameter(HloInstruction*) { - return tensorflow::Status::OK(); +Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleCall(HloInstruction* call) { // The shape of kCall should match the shape of the computation it calls. 
return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { return CheckShape(slice, @@ -384,6 +376,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: @@ -410,7 +403,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return FailedPrecondition( + return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString().c_str()); @@ -490,14 +483,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } } if (!compatible) { - return InvalidArgument( + return InternalError( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", ShapeUtil::HumanString(inferred_shape).c_str(), ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -541,13 +534,13 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( + return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%lld), %s (%lld)", instr1->ToString().c_str(), instr1->channel_id(), instr2->ToString().c_str(), instr2->channel_id()); } - return tensorflow::Status::OK(); + return Status::OK(); } string ComputationsToString( @@ -571,22 +564,22 @@ string ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return FailedPrecondition("Computation %s has a null parent pointer", - computation->name().c_str()); + return InternalError("Computation %s has a null parent pointer", + computation->name().c_str()); } if (computation->parent() != module) { - return FailedPrecondition( + return InternalError( "Computation %s parent() does not point to parent module", computation->name().c_str()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return FailedPrecondition("Instruction %s has a null parent pointer", - instruction->name().c_str()); + return InternalError("Instruction %s has a null parent pointer", + instruction->name().c_str()); } if (instruction->parent() != computation) { - return FailedPrecondition( + return InternalError( "Instruction %s parent() does not point to parent computation", instruction->name().c_str()); } @@ -602,7 +595,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return FailedPrecondition( + return InternalError( "Operand %d (%s) of instruction 
%s is in a different " "computation: %s vs %s", i, operand->name().c_str(), instruction->name().c_str(), @@ -612,14 +605,14 @@ Status VerifyHloStructure(HloModule* module) { } } } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return FailedPrecondition( + return InternalError( "Instruction of fused computation does not match expected instruction " "%s.", fusion->ToString().c_str()); @@ -635,37 +628,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return FailedPrecondition("Root appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Root appears more than once in %s.", + fusion->ToString().c_str()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return FailedPrecondition("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Parameter appears more than once in %s.", + fusion->ToString().c_str()); } parameter_owned[i] = true; } } } if (!root_owned) { - return FailedPrecondition("Root not found in computation of %s.", - fusion->ToString().c_str()); + return InternalError("Root not found in computation of %s.", + fusion->ToString().c_str()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return FailedPrecondition("Parameter %d not found in computation of %s.", - i, fusion->ToString().c_str()); + return InternalError("Parameter %d not found in computation of %s.", i, + fusion->ToString().c_str()); } } // Fused root must have no users. 
if (fused_root->user_count() != 0) { - return FailedPrecondition("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", + fusion->ToString().c_str()); } // All uses of fused instructions must be in the fusion computation, and every @@ -674,13 +667,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return FailedPrecondition( - "Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + return InternalError("Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), + fusion->ToString().c_str()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return FailedPrecondition( + return InternalError( "Non-root instruction %s in %s may not have external users.", instruction->ToString().c_str(), fusion->ToString().c_str()); } @@ -695,41 +688,40 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return FailedPrecondition( - "Unexpected negative parameter number %lld in %s.", param_no, - fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %lld in %s.", + param_no, fusion->ToString().c_str()); } if (param_no >= fused_parameters.size()) { - return FailedPrecondition( + return InternalError( "Unexpected parameter number %lld in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString().c_str(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return FailedPrecondition( + return InternalError( "Did not expect parameter number %lld more than once in %s.", param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { - return FailedPrecondition( + return InternalError( "Shape mismatch between parameter number %lld and its operand in %s.", param_no, fusion->ToString().c_str()); } } - // Make sure all the parameter_numbers entries were seen + // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return FailedPrecondition("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + return InternalError("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); } } // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. 
- return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { @@ -778,7 +770,7 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { "init: %s, body: %s", init->ToString().c_str(), body_root->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { @@ -796,7 +788,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { ShapeUtil::HumanString(operand_shape).c_str()); } } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr HloVerifier::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 6208887547a14d..1392a78097aa02 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -82,9 +82,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); - } + Status FinishVisit(HloInstruction*) override { return Status::OK(); } protected: // Check the instruction's shape against the shape given by ShapeInference diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 13e4557317f74b..d7458c338e9f1d 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -27,6 +27,7 @@ using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; @@ -35,20 +36,26 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto append_op = [&](const OpInfo& op) { + auto print_op = [&](const OpInfo& op) { + // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that + // were expected to be free and are actually free -- things like (on most + // backends) kParameter or kConstant HLOs. There's no need to clutter the + // profile with these. 
+ if (op.optimal_seconds == 0 && op.cycles == 0) { + return; + } + string bytes_per_sec; string bytes_per_cycle; - if (op.cycles <= 0 || op.bytes_accessed < 0) { - bytes_per_sec = ""; - bytes_per_cycle = ""; - } else { - bytes_per_sec = - HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)); + if (op.cycles > 0 && op.bytes_accessed >= 0) { + bytes_per_sec = StrCat( + HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)), + "/s"); + double bpc = static_cast(op.bytes_accessed) / op.cycles; if (op.bytes_accessed > op.cycles) { - bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = - Printf("%.3fB", static_cast(op.bytes_accessed) / op.cycles); + bytes_per_cycle = Printf("%.3fB/cycle", bpc); } } @@ -59,14 +66,16 @@ string HumanReadableProfileBuilder::ToString() const { double nsecs = op.cycles / clock_rate_ghz_; Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec (%12.1f optimal) :: %18s " - ":: %18s :: %12s/s :: %12s/cycle :: %s\n", + "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " + ":: %18s :: %14s :: %16s :: %s\n", op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds * 1e6, + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), op.flop_count <= 0 - ? "" + ? "" : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" + op.transcendental_count <= 0 ? "" : HumanReadableNumTranscendentalOps( op.transcendental_count, nsecs) .c_str(), @@ -78,24 +87,26 @@ string HumanReadableProfileBuilder::ToString() const { int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { - optimal_seconds_sum += op.optimal_seconds; - total_flops += op.flop_count; - total_transcendentals += op.transcendental_count; - total_bytes += op.bytes_accessed; + if (op.optimal_seconds > 0) { + optimal_seconds_sum += op.optimal_seconds; + } + total_flops += std::max(op.flop_count, int64{0}); + total_transcendentals += std::max(op.transcendental_count, int64{0}); + total_bytes += std::max(op.bytes_accessed, int64{0}); } VLOG(1) << "Total floating point ops: " << total_flops; - append_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, + total_transcendentals, total_bytes, optimal_seconds_sum}); - // Sort ops in decreasing order of cycles. + // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); std::sort( sorted_ops.begin(), sorted_ops.end(), [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); for (const auto& op : sorted_ops) { - append_op(op); + print_op(op); } if (total_cycles_ <= 0) { @@ -109,8 +120,20 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds above estimated optimum"); table.SetEntryName("ops"); table.SetShowCategoryTable(); + table.SetShowAllEntries(); float total_discrepancy_in_microseconds = 0.0f; - for (const auto& op : sorted_ops) { + for (const auto& op : op_infos_) { + // Skip ops with < 0 optimal seconds. These are ops for which we don't + // know the optimal time. + if (op.optimal_seconds < 0) { + continue; + } + // Also skip ops with 0 actual cycles. 
These ops were free; there's no + // need to clutter the "above estimated optimum" table with them, + // because they can't be optimized further. + if (op.cycles == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; @@ -128,7 +151,14 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds"); table.SetEntryName("ops"); table.SetShowCategoryTable(); - for (const auto& op : sorted_ops) { + table.SetShowAllEntries(); + for (const auto& op : op_infos_) { + // Skip ops with 0 optimal seconds and 0 actual cycles. As in + // print_op(), these are uninteresting because they're expected to be + // free, and they were actually free. + if (op.cycles == 0 && op.optimal_seconds == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; @@ -139,6 +169,23 @@ string HumanReadableProfileBuilder::ToString() const { StrAppend(&s, table.MakeReport(CyclesToMicroseconds(total_cycles_))); } } + + if (total_bytes > 0) { + MetricTableReport table; + table.SetMetricName("MiB read+written"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& op : op_infos_) { + MetricTableReport::Entry entry; + entry.text = op.name; + entry.short_text = op.short_name; + entry.category_text = op.category; + entry.metric = static_cast(op.bytes_accessed) / (1 << 20); + table.AddEntry(std::move(entry)); + } + StrAppend(&s, + table.MakeReport(static_cast(total_bytes) / (1 << 20))); + } return s; } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index fc24acd2713f4c..6f56c3aa82e9d1 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -32,7 +32,7 @@ class HumanReadableProfileBuilder { explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(computation_name.ToString()), + : computation_name_(std::string(computation_name)), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -41,15 +41,17 @@ class HumanReadableProfileBuilder { int64 total_cycles() const { return total_cycles_; } // Adds an operation to the profile. If you don't know the number of - // floating-point ops or bytes touched by the op, pass -1 for that param. + // floating-point ops or bytes touched by the op, or if you don't know how + // fast it would run optimally, pass -1 for that param. void AddOp(tensorflow::StringPiece op_name, tensorflow::StringPiece short_name, tensorflow::StringPiece category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back( - {op_name.ToString(), short_name.ToString(), category.ToString(), cycles, - flop_count, transcendental_count, bytes_accessed, optimal_seconds}); + op_infos_.push_back({std::string(op_name), std::string(short_name), + std::string(category), cycles, flop_count, + transcendental_count, bytes_accessed, + optimal_seconds}); } // Gets the human-readable profile. 
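The -1 sentinels described above are part of the `AddOp` contract, so callers pass them explicitly when a quantity is unknown. A minimal usage sketch follows; the computation name, op string, and cycle counts are made up for illustration, and only the constructor, `AddOp`, and `ToString` signatures come from this header:

```c++
// Illustrative only: report one op whose bytes accessed and optimal runtime
// are unknown (-1 by the convention documented above). Assumes
// human_readable_profile_builder.h is included and we are in namespace xla.
HumanReadableProfileBuilder builder(/*computation_name=*/"cluster_0",
                                    /*total_cycles=*/1000000,
                                    /*clock_rate_ghz=*/1.5);
builder.AddOp(/*op_name=*/"%fusion.3 = f32[128,256] fusion(...)",
              /*short_name=*/"fusion.3", /*category=*/"fusion",
              /*cycles=*/250000, /*flop_count=*/-1,
              /*transcendental_count=*/-1, /*bytes_accessed=*/-1,
              /*optimal_seconds=*/-1);
string report = builder.ToString();
```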
@@ -61,10 +63,10 @@ class HumanReadableProfileBuilder { string short_name; string category; int64 cycles; - int64 flop_count; + int64 flop_count; // -1 if unknown int64 transcendental_count; - int64 bytes_accessed; - float optimal_seconds; + int64 bytes_accessed; // -1 if unknown + float optimal_seconds; // -1 if unknown }; double CyclesToSeconds(int64 cycles) const { diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc new file mode 100644 index 00000000000000..8b3fa6c1572cf0 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -0,0 +1,733 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gtl = ::tensorflow::gtl; + +namespace { +using Analysis = IndexedArrayAnalysis; +using UnknownArray = Analysis::UnknownArray; +using ConstantArray = Analysis::ConstantArray; +using ScalarIndexedArray = Analysis::ScalarIndexedArray; +using tensorflow::gtl::ArraySlice; +using tensorflow::str_util::Join; +} // namespace + +string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { + switch (root->kind()) { + case Array::kUnknown: { + auto* unknown_tensor = root->as(); + return tensorflow::strings::StrCat("%", + unknown_tensor->instruction().name()); + } + + case Array::kConstant: { + if (print_constants) { + string contents = root->as()->literal()->ToString(); + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, + ")"); + } + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + } + + case Array::kScalarIndexedConstant: + case Array::kScalarIndexed: { + auto* indexed_array = root->as(); + string name = root->kind() == Array::kScalarIndexedConstant + ? 
"scalar-indexed-const" + : "scalar-indexed"; + return tensorflow::strings::StrCat( + "(", name, " ", ToString(indexed_array->source(), print_constants), + " ", ToString(indexed_array->indices(), print_constants), " ", + indexed_array->source_dim(), "->[", + Join(indexed_array->output_dims(), ","), "])"); + } + } +} + +StatusOr IndexedArrayAnalysis::GetArrayFor( + const HloInstruction* instr) { + auto it = cache_.find(instr); + if (it != cache_.end()) { + return it->second; + } + + TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr)); + return FindOrDie(cache_, instr); +} + +Status IndexedArrayAnalysis::TraverseAndPopulateCache( + const HloInstruction* root) { + // Depth first search over the DAG, invoking ComputeArrayFor in post order. + // The HLO instructions already in the cache are considered leaves. + + gtl::InlinedVector stack; + + enum DfsState { kDiscovered, kVisited }; + gtl::FlatMap dfs_state_map; + + stack.push_back(root); + InsertOrDie(&dfs_state_map, root, kDiscovered); + + do { + const HloInstruction* instr = stack.back(); + if (cache_.count(instr)) { + stack.pop_back(); + continue; + } + + switch (FindOrDie(dfs_state_map, instr)) { + case kDiscovered: { + for (const HloInstruction* operand : instr->operands()) { + if (!cache_.count(operand)) { + stack.push_back(operand); + CHECK(!dfs_state_map.count(operand) || + dfs_state_map[operand] == kDiscovered); + dfs_state_map[operand] = kDiscovered; + } + } + dfs_state_map[instr] = kVisited; + break; + } + + case kVisited: + stack.pop_back(); + TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr)); + InsertOrDie(&cache_, instr, array); + break; + } + } while (!stack.empty()); + + return Status::OK(); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayFor( + const HloInstruction* instr) { + Array* computed_array; + if (instr->IsElementwise() && instr->operand_count() == 1) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseUnaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)))); + } else if (instr->IsElementwise() && instr->operand_count() == 2) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseBinaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(computed_array, + ComputeArrayForConstant(instr->literal())); + } else if (instr->opcode() == HloOpcode::kGather) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), + instr->gather_window_bounds(), + FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kReshape) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForReshape(instr->shape(), + FindOrDie(cache_, instr->operand(0)))); + } else { + computed_array = nullptr; + } + + if (!computed_array) { + computed_array = Construct(instr); + } + + return computed_array; +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( + const Literal& literal) { + return Construct(&literal); +} + +StatusOr IndexedArrayAnalysis::FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape) { + // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). + // `source` is the inner Gather(A, X). 
+ + Array* a = source->source(); + Array* x = source->indices(); + Array* y = indices; + + // This bit is slightly tricky, so we do a naive "simulation" of the two + // consecutive gather operations to infer what the composed gather should look + // like. + + enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond }; + + std::vector simulated_index(a->shape().dimensions_size(), + IndexComponent::Ungathered); + + // Simulate the first gather. + EraseAt(&simulated_index, source->source_dim()); + for (int64 gather_dim : source->output_dims()) { + simulated_index.insert(simulated_index.begin() + gather_dim, + IndexComponent::GatheredFirst); + } + + // Simulate the second gather. + EraseAt(&simulated_index, source_dim); + for (int64 output_dim : output_dims) { + simulated_index.insert(simulated_index.begin() + output_dim, + IndexComponent::GatheredSecond); + } + + int64 source_dim_for_index_array = + FindIndex(source->output_dims(), source_dim); + CHECK_NE(source_dim_for_index_array, source->output_dims().size()); + + std::vector output_dims_for_index_array; + int64 gathered_index_components_seen = 0; + for (IndexComponent simulation_dim : simulated_index) { + if (simulation_dim == IndexComponent::GatheredSecond) { + output_dims_for_index_array.push_back(gathered_index_components_seen); + } + if (simulation_dim != IndexComponent::Ungathered) { + gathered_index_components_seen++; + } + } + + std::vector dim_sizes_for_composed_index; + std::vector output_dims_for_new_gather; + for (int64 i = 0, e = simulated_index.size(); i < e; i++) { + if (simulated_index[i] != IndexComponent::Ungathered) { + dim_sizes_for_composed_index.push_back(shape.dimensions(i)); + output_dims_for_new_gather.push_back(i); + } + } + + Array* inner_indices = ConstructScalarIndexedArray( + x, y, source_dim_for_index_array, output_dims_for_index_array, + ShapeUtil::MakeShape(x->shape().element_type(), + dim_sizes_for_composed_index)); + return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(), + output_dims_for_new_gather, + std::move(shape)); +} + +StatusOr IndexedArrayAnalysis::ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices) { + if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + return nullptr; + } + + CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + if (!c_binary_search(dim_numbers.elided_window_dims(), + dim_numbers.gather_dims_to_operand_dims(0))) { + return nullptr; + } + + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + std::vector output_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + output_dims.push_back(i); + } + } + + if (auto* indexed = dynamic_cast(source)) { + auto it = c_find(indexed->output_dims(), source_dim); + if (it != indexed->output_dims().end()) { + return FoldGatherOfGather(indexed, indices, source_dim, output_dims, + shape); + } + } else if (auto* constant = dynamic_cast(source)) { + return Construct(constant, indices, source_dim, + output_dims, shape); + } + + return Construct(source, indices, source_dim, output_dims, + shape); +} + +namespace { +// Returns an index into `values` such that the product of the range +// [values.begin()+index, values.end()) is equal to `product`. If there is no +// such index, return -1. All integers in `values` must be positive. 
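+// For example (values chosen purely for illustration), with values = {6, 4, 5}:
+// product = 20 returns 1, product = 120 returns 0, and product = 7 returns -1.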
+int64 FindSuffixWithProduct(ArraySlice values, int64 product) { + DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); + + int64 current_product = 1; + int64 i; + for (i = values.size() - 1; i >= 0 && product > current_product; --i) { + current_product *= values[i]; + } + + if (product == current_product) { + return i + 1; + } + + return -1; +} + +struct ReshapePassthroughDimPair { + int64 result_dim; + int64 operand_dim; +}; + +// Returns a set of dimension pairs such for all (result_dim, operand_dim) in +// the set: +// +// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim] +// +// The returned vector of pairs is sorted in both the result_dim and the +// operand_dim components. +std::vector ComputeReshapePassthroughDimPairs( + ArraySlice operand_shape, ArraySlice result_shape) { + // A reshape can be seen as an index mapping from output index to input index: + // + // (i_0, ..., i_n) = f(o_0, ..., o_m) + // + // This function returns the pairs (j, k) for which the following invariant + // holds for all indices in the shape: + // + // o_j == i_k + // + // And this occurs when: + // + // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m + // + // (where O_x are the sizes of the output shape and I_x are the sizes of the + // input shape) and the size of the dimension j of the result is the same as + // the size of dimension k in the operand. + // + // These conditions are sufficient because the Reshape HLO is spec'ed such + // that the rightmost dimensions are always minor in the flattening and refine + // operation. + + std::vector result; + int64 result_subarray_size = 1; + for (int64 result_dim = result_shape.size() - 1; result_dim >= 0; + --result_dim) { + int64 candidate_operand_dim = + FindSuffixWithProduct(operand_shape, result_subarray_size); + + // result_subarray_size does not include the elements in the current + // `result_dim` dimension (we multiply in result_shape[result_dim] at the + // end of loop body) so candidate_operand_dim can never be zero. + CHECK_NE(candidate_operand_dim, 0); + + if (candidate_operand_dim != -1 && + result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { + result.push_back({/*result_dim=*/result_dim, + /*operand_dim=*/candidate_operand_dim - 1}); + } + result_subarray_size *= result_shape[result_dim]; + } + + c_reverse(result); + + if (VLOG_IS_ON(3)) { + std::vector result_strings; + c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return tensorflow::strings::StrCat(value.result_dim, "->", + value.operand_dim); + }); + VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" + << Join(result_shape, ",") << "] passthrough indices are [" + << Join(result_strings, ",") << "]"; + } + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.result_dim < rhs.result_dim; + })); + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.operand_dim < rhs.operand_dim; + })); + + return result; +} + +// Return true if `dim` is stated as an passthrough operand dim in +// `passthrough_dims`. 
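+// For example (an illustrative value, not derived from a real reshape), with
+// passthrough_dims = {{/*result_dim=*/0, /*operand_dim=*/0},
+//                     {/*result_dim=*/2, /*operand_dim=*/1}}
+// this returns true for dims 0 and 1, and false for every other dim.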
+bool IsReshapePassthroughOperandDim( + ArraySlice passthrough_dims, int64 dim) { + return c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); +} + +// Maps `operand_dim` which must be an passthrough operand dimension to its +// corresponding passthrough result dimension based on `passthrough_dims`. +int64 MapPassthroughOperandDimToResultDim( + ArraySlice passthrough_dims, int64 operand_dim) { + auto it = c_find_if(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); + CHECK(it != passthrough_dims.end()); + return it->result_dim; +} + +int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, + ArraySlice result_shape, + int64 source_passthrough_dim) { + int64 indexed_source_subarray_size = + std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, + operand_shape.end(), 1, std::multiplies()); + + return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); +} + +}; // namespace + +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( + const Shape& shape, Array* operand) { + auto* scalar_indexed = dynamic_cast(operand); + if (!scalar_indexed) { + return nullptr; + } + + // Try to fold Reshape(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(Const', Indices) + // + // We can view the reshape and the scalar-indexed operations as functions that + // map an output index (i.e. an index into the result) to an input index + // (i.e. an index into the operand). The key idea used here is that the + // output-to-input mapping for some reshape operations may "pass through" some + // output dimensions into the input space unchanged -- i.e. there may exist + // output dimension "O" and input dimension "I" such that OutputIndex[O] is + // always == InputIndexForReshape(OutputIndex)[I]. If these pass-through + // dimensions in the input space of the reshape happen to be include all the + // output dimensions for the scalar-indexed node then, roughly, the following + // holds: + // + // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx)) + // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs)) + // + // Where Ps are the set of the pass-through components of Idx that are + // also the output dims of the scalar-indexed node, and Qs are the rest. + // For brevity, we're playing fast and loose with the notation here -- we + // don't literally require Idx to be a concatenation of Ps and Qs, as + // suggested by the "++". + // + // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs)) + // + // Again, we're playing fast and loose with the notation around "++". + // Generally this ++ will be a different function that the ++ in the + // previous step. + // + // If the scalar-indexed node has a constant as the source then the + // SourceIndexOfReshape function can be "folded into" the constant itself by + // reshaping it, leaving us with: + // + // == SourceIndexOfScalarIndexed(Ps ++ Qs) + // == SourceIndexOfScalarIndexed(Idx) + // + // which is just a scalar-indexed node (with parameters different from the + // scalar-indexed node we started with) with a reshaped constant as the + // source. 
+ // + // We can't fold SourceIndexOfReshape into the constant without introducing + // another precondition: since the new scalar-indexed node will have a + // reshaped (constant) array as its source it will, in general, have a + // different source dimension than the original scalar-indexed node. This + // source dimension will have to be a passthrough dimension of the + // SourceIndexOfReshape indexing function that is folded into the source. And + // such a dimension need not exist so this is a non-trivial precondition. + + std::vector reshape_passthrough_dims = + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()), + /*result_shape=*/AsInt64Slice(shape.dimensions())); + + auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { + return IsReshapePassthroughOperandDim(reshape_passthrough_dims, + operand_dim); + }; + + if (!c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { + return nullptr; + } + + // To compute the shape of the source for the new scalar-indexed node we're + // going to create, we first "undo" the scalar-indexed operation. + std::vector new_scalar_indexed_source_shape(shape.dimensions().begin(), + shape.dimensions().end()); + for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) { + int64 output_dim = scalar_indexed->output_dims()[i]; + int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim( + reshape_passthrough_dims, output_dim); + EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape); + } + + // After this, we need to add in the dimension that will be the source + // dimension for the new scalar-indexed node. A scalar-indexed node "removes" + // the source dimensions and "adds" the output dimensions, so to get back to + // the shape for the *source* of the scalar-indexed node we need to remove the + // output dims (which we did above) and then add back the source dim (which we + // are about to do below): + + const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape(); + + int64 source_dim_for_new_scalar_indexed_node = + FindSourcePositionForPassthroughResultDim( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape, + scalar_indexed->source_dim()); + + // We may not be able to find a source dim for the new scalar-indexed node. + // For instance consider: + // + // operand = s32[3,5,2] constant({...}) + // indices = s32[7] parameter(0) + // gather = s32[3,2,7] gather(operand, indices), + // output_window_dims={0,1}, + // elided_window_dims={1}, + // gather_dims_to_operand_dims={1}, + // index_vector_dim=1, + // window_bounds={3,1,2} + // reshape = s32[6,7] reshape(gather) + // + // In this case the gather maps to: + // (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2]) + // + // and the reshape passes through dimension 2 from its input into dimension 1 + // in its output. However, we can't rewrite the reshape as a scalar-indexed + // node because then we'd have to reshape the [3,5,2] `operand` array to + // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently + // (a.k.a. isn't pass-through) than the [3,5,2] array. 
+ + if (source_dim_for_new_scalar_indexed_node == -1) { + return nullptr; + } + + InsertAt( + &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, + scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); + + CHECK(IsReshapePassthroughOperandDim( + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape), + scalar_indexed->source_dim())); + + auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) { + return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims, + result_dim); + }; + + std::vector output_dims_for_new_scalar_indexed_node; + c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); + + TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, + TakeOwnership(scalar_indexed->literal().Reshape( + new_scalar_indexed_source_shape))); + TF_ASSIGN_OR_RETURN( + Array * new_scalar_indexed_source, + ComputeArrayForConstant(*new_scalar_indexed_source_literal)); + + return ConstructScalarIndexedArray( + new_scalar_indexed_source, scalar_indexed->indices(), + source_dim_for_new_scalar_indexed_node, + output_dims_for_new_scalar_indexed_node, shape); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, + Array* rhs) { + // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices)) + // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices) + // + // We can do this if every output dimension from the scalar-indexed node is a + // broadcasted dimension for the broadcast node. Informally, the precondition + // means Broadcast(Const0)[IDX] is solely a function of the components of IDX + // that are not output-dims for the scalar-indexed node. In other words, for + // every assignment to the non-output dims in IDX we have a "constant" LHS to + // the BinaryOp. This transform propagates this "constant" to the source for + // the scalar-indexed node. + + ScalarIndexedConstantArray* lhs_scalar_indexed_const = + dynamic_cast(lhs); + ScalarIndexedConstantArray* rhs_scalar_indexed_const = + dynamic_cast(rhs); + + bool lhs_is_indexed; + + // One of the operands must be scalar-indexed and the other must be a + // broadcast of a constant. + if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) { + lhs_is_indexed = true; + } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) { + lhs_is_indexed = false; + } else { + return nullptr; + } + + ScalarIndexedConstantArray* scalar_indexed_const = + lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const; + UnknownArray* candidate_broadcast_array = + dynamic_cast(lhs_is_indexed ? rhs : lhs); + if (!candidate_broadcast_array || + candidate_broadcast_array->instruction().opcode() != + HloOpcode::kBroadcast) { + return nullptr; + } + + const HloInstruction* broadcast_instr = + &candidate_broadcast_array->instruction(); + const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0); + if (broadcast_const_operand->opcode() != HloOpcode::kConstant) { + return nullptr; + } + + ArraySlice broadcast_dims = broadcast_instr->dimensions(); + auto is_broadcasted_dim = [&](int64 output_dim) { + return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + }; + + // All of the output dims must be "broadcasted" dims for the other operand. 
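+  // For instance (hypothetical shapes): a broadcast with dimensions={0,2}
+  // creates output dim 1 out of thin air, so a scalar-indexed operand with
+  // output_dims() == {1} passes this check, while output_dims() == {0} would
+  // not.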
+ if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + return nullptr; + } + + // To figure out the broadcast dimensions for the (constant) source for the + // scalar-indexed node, we "simulate" the index transformation done by the + // existing broadcsat: + enum class IndexComponent { Broadcasted, NotBroadcasted }; + std::vector simulated_index( + broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted); + for (int64 broadcast_dim : broadcast_dims) { + simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted; + } + + // The scalar-indexed node "removes" the source dim and "inserts" the output + // dims. We do the opposite here to undo the scalar-indexed operation. + ArraySlice output_dims = scalar_indexed_const->output_dims(); + for (int64 i = output_dims.size() - 1; i >= 0; --i) { + CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); + EraseAt(&simulated_index, output_dims[i]); + } + + InsertAt(&simulated_index, scalar_indexed_const->source_dim(), + IndexComponent::Broadcasted); + + // new_inner_broadcast_dims holds the broadcast dimensions for the inner + // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to + // new_inner_broadcast_dims. + std::vector new_inner_broadcast_dims; + for (int64 i = 0; i < simulated_index.size(); i++) { + if (simulated_index[i] == IndexComponent::NotBroadcasted) { + new_inner_broadcast_dims.push_back(i); + } + } + + // inner_broadcast_result is the Broadcast'(Const0) bit in + // BinaryOp(Broadcast'(Const0), Const1) + TF_ASSIGN_OR_RETURN( + std::unique_ptr inner_broadcast_result, + broadcast_const_operand->literal().Broadcast( + scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); + + // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1) + const Literal* literal_for_new_source; + if (lhs_is_indexed) { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + } else { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + } + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand) { + auto* scalar_indexed_const = + dynamic_cast(operand); + if (scalar_indexed_const == nullptr) { + return nullptr; + } + + // Fold UnaryOp(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(UnaryOp(Const), Indices) + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp( + opcode, scalar_indexed_const->literal()))); + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { + return "indexed-array-analysis-printer-pass"; +} + +StatusOr IndexedArrayAnalysisPrinterPass::Run(HloModule* 
module) { + if (!VLOG_IS_ON(2)) { + return false; + } + + IndexedArrayAnalysis analysis; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instr : computation->instructions()) { + TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr)); + if (!dynamic_cast(t) && !dynamic_cast(t)) { + VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); + } + } + } + + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h new file mode 100644 index 00000000000000..ce92fd2919c90f --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -0,0 +1,326 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { + +// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a +// gather from another array. It does this by mapping HLO instructions to +// instances of IndexedArrayAnalysis::Array, which can be inspected to discover +// whether said HLO is equivalent to a gather. +class IndexedArrayAnalysis { + public: + // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array. + // Array really just a sum type of the classes that inherit from it. The + // meaning of each of the subtypes is documented on the subtype declaration. + // + // Array instances are immutable once created. + class Array { + public: + enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed }; + + virtual Kind kind() const = 0; + virtual const Shape& shape() const = 0; + + // Does a checked downcast from `Array` to `T` which must be one of its + // subtypes. + template + T* as() { + static_assert((std::is_base_of::value), + "target type not derived from source type"); + // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + CHECK_NE(dynamic_cast(this), nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast(this); + } + + virtual ~Array() = default; + + Array& operator=(const Array& other) = delete; + }; + + // Represents an HLO instruction that was not analyzable by this + // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing + // HloInstruction. 
+ class UnknownArray : public Array { + public: + Kind kind() const override { return kUnknown; } + const Shape& shape() const override { return instruction().shape(); } + const HloInstruction& instruction() const { return instruction_; } + + private: + explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} + + const HloInstruction& instruction_; + + friend class IndexedArrayAnalysis; + }; + + // Represents a constant value. This constant value may be present in the HLO + // module being analyzed, or it could have been created on the fly by the + // analysis. + class ConstantArray : public Array { + public: + Kind kind() const override { return kConstant; } + const Shape& shape() const override { return literal()->shape(); } + const Literal* literal() const { return literal_; } + + private: + explicit ConstantArray(const Literal* literal) : literal_(literal) {} + const Literal* literal_; + + friend class IndexedArrayAnalysis; + }; + + // --------------------------------------------------------------------------- + // Indexed Array Overview + // --------------------------------------------------------------------------- + // + // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this + // analysis. ScalarIndexedConstantArray is just a specialization of + // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this + // overview. + // + // A ScalarIndexedArray represents an array that can be computed by indexing + // into a "source" array using an "indices" tensor. A simple example is a + // gather operation gathering 12 rows out of a [100,100] matrix -- such an + // operation will be represented by an instance of a ScalarIndexedArray with + // the [100,100] matrix as the "source" array and the [12]-shaped indices + // array as the "indices" tensor. The ScalarIndexedArray operation itself + // will be of shape [12,100] (assuming we were gathering with axis=0). + // + // Gather operations are not the only operation that maps to + // ScalarIndexedArray instances (if that were true there would be little point + // in having a separate analysis). We can often infer ScalarIndexedArrays for + // other operations too. For instance, consider: + // + // %source = f32[100,100] constant + // %indices = s32[12] ... + // %gather = f32[12,100] ... gather from %source using %indices at axis 0 + // %dot = dot(%gather, other_constant) [canonical contracting dims] + // + // The dot operation itself is also a ScalarIndexedArray with source = + // dot(constant, other_constant) and indices = %indices. A reshape of %gather + // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately + // reshaped constant and indices = %indices. + + // Represents the result of a gather operation. This gather operation may + // explicitly be present in the HLO module being analyzed, or it could have + // been created on the fly by the analysis. 
+ // + // An instance of ScalarIndexedArray represents a array whose I'th element can + // be mapped to the J'th element of the `source` array (where I and J are + // multidimensional indices) in this way: + // + // I' = remove components at positions `output_dims` from I + // G' = remove components not at positions `output_dims` from I + // T = indices[G'] + // J = I' with T inserted at position `source_dim` + // + // For example, if source is of shape [11,13,17,19], indices is of shape + // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of + // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the + // input index [B,D,indices[A,C],E]. + class ScalarIndexedArray : public Array { + public: + Kind kind() const override { return kScalarIndexed; } + const Shape& shape() const override { return shape_; } + + Array* source() const { return source_; } + Array* indices() const { return indices_; } + + // `source_dim` is the dimension in the source array that is being indexed + // over using indices from the `indices` array. See the class documentation + // and the overview for more details. + int64 source_dim() const { return source_dim_; } + + // `output_dims` are the dimensions in the output array that are being used + // to compute an index into the `indices` array. See the class + // documentation and the overview for more details. + tensorflow::gtl::ArraySlice output_dims() const { + return output_dims_; + } + + private: + explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) + : source_(source), + indices_(indices), + source_dim_(source_dim), + output_dims_(std::move(output_dims)), + shape_(std::move(shape)) {} + + Array* source_; + Array* indices_; + int64 source_dim_; + std::vector output_dims_; + Shape shape_; + + friend class IndexedArrayAnalysis; + }; + + // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to + // have a ConstantArray instance as the source. This is an ergonomic + // concession -- in theory it is possible to just keep ScalarIndexedArray and + // check source()->kind(). + class ScalarIndexedConstantArray : public ScalarIndexedArray { + public: + Kind kind() const override { return kScalarIndexedConstant; } + + const Literal& literal() const { + return *source()->as()->literal(); + } + + private: + explicit ScalarIndexedConstantArray(Array* source, Array* indices, + int64 source_dim, + std::vector output_dims, + Shape shape) + : ScalarIndexedArray(source, indices, source_dim, + std::move(output_dims), std::move(shape)) { + CHECK(dynamic_cast(source)); + } + + friend class IndexedArrayAnalysis; + }; + + // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance + // keeps ownership of the returned Array instance. + // + // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO + // instructions to IndexedArrayAnalysis::Array instances. This entire cache + // becomes stale and may cause the analysis to return incorrect results if any + // transitive operand (stopping at the containing computation) is modified for + // any HLO instruction on which GetArrayFor has been invoked. + // + // NB! By inspecting the implementation, you may be able to infer a stronger + // caching guarantee than what is mentioned above. Nevertheless, what is + // stated above is the contract. + StatusOr GetArrayFor(const HloInstruction* instr); + + // Pretty-prints the expression rooted at `root`. 
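// [Illustrative sketch, not part of this patch.] A standalone version of the
// output-index -> source-index mapping documented for ScalarIndexedArray
// above. Only the mapping rule comes from that comment: drop the
// `output_dims` components from the output index, look the dropped components
// up in `indices`, and insert the result at `source_dim`. The helper names
// and the stand-in `indices` lookup are made up.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

static std::vector<int64_t> MapOutputIndexToSourceIndex(
    const std::vector<int64_t>& output_index,
    const std::vector<int64_t>& output_dims, int64_t source_dim,
    const std::function<int64_t(const std::vector<int64_t>&)>& indices) {
  std::vector<int64_t> source_index;  // I': output index minus output_dims.
  std::vector<int64_t> gather_index;  // G': just the output_dims components.
  for (int64_t i = 0; i < static_cast<int64_t>(output_index.size()); ++i) {
    bool is_output_dim =
        std::find(output_dims.begin(), output_dims.end(), i) !=
        output_dims.end();
    (is_output_dim ? gather_index : source_index).push_back(output_index[i]);
  }
  int64_t t = indices(gather_index);                           // T = indices[G'].
  source_index.insert(source_index.begin() + source_dim, t);   // J.
  return source_index;
}

int main() {
  // Worked example from the comment above: output_dims = {0,2} and
  // source_dim = 2, so output index [A,B,C,D,E] maps to [B,D,indices[A,C],E].
  auto indices = [](const std::vector<int64_t>& g) -> int64_t {
    return 100 * g[0] + g[1];  // stand-in for the real indices[A,C] lookup
  };
  std::vector<int64_t> j =
      MapOutputIndexToSourceIndex({7, 1, 9, 2, 3}, {0, 2}, 2, indices);
  for (int64_t d : j) std::cout << d << " ";  // prints: 1 2 709 3
  std::cout << "\n";
  return 0;
}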
+ string ToString(Array* root, bool print_constants = false); + + private: + // Helper function that ensures that every HLO instruction that is + // transitively used by `root` has an entry in `cache_`. + Status TraverseAndPopulateCache(const HloInstruction* root); + + // Creates an Array instance for `instr` under the assumption that all + // operations of `instr` are present in `cache_`. + StatusOr ComputeArrayFor(const HloInstruction* instr); + + StatusOr ComputeArrayForConstant(const Literal& literal); + + StatusOr ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices); + + // This tries to fold a ScalarIndexedArray which has another + // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a + // ScalarIndexedArray as indices. If `source` happened to be a + // ScalarIndexedConstantArray this can result in an expression that is more + // canonical. + // + // As an example, consider a gather operation, G0, gathering 7 elements from + // an array "Arr" of shape [100] resulting in an array of shape [7], and a + // second gather operation, G1, which gathers 3 elements out of the result of + // G0 resulting in an array of shape [3]. Let the indices uses by G0 be I0 + // (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can + // instead rewrite G1 to gather directly from "Arr" with the three indices + // from I0 as per I1. In other words, we can rewrite: + // + // G0 = [Arr[i] for i in I0] + // G1 = [G0[i] for i in I1] + // + // into + // + // I2 = [I0[i] for i in I1] + // G1 = [Arr[i] for i in I2] + StatusOr FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape); + + StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); + + StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, Array* rhs); + StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand); + + template + T* Construct(Args&&... args) { + T* new_tensor = new T(std::forward(args)...); + owned_tensors_.push_back(std::unique_ptr(new_tensor)); + return new_tensor; + } + + ScalarIndexedArray* ConstructScalarIndexedArray( + Array* source, Array* indices, int64 source_dim, + std::vector output_dims, Shape shape) { + if (source->kind() == Array::kConstant) { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } else { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } + } + + Literal* TakeOwnership(std::unique_ptr literal) { + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + StatusOr TakeOwnership( + StatusOr> literal_or_error) { + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + std::move(literal_or_error)); + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + std::vector> owned_tensors_; + std::vector> owned_literals_; + tensorflow::gtl::FlatMap cache_; +}; + +// A pass that prints all non-trivial results returned by IndexedArrayAnalysis. +// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to +// unconditionally add to the regular HLO pass pipeline. 
+class IndexedArrayAnalysisPrinterPass : public HloPassInterface { + public: + tensorflow::StringPiece name() const override; + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc new file mode 100644 index 00000000000000..373556ebeba883 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -0,0 +1,504 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace xla { +namespace { +class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + protected: + void AssertArrayForRootExpressionIs(const string& hlo_text, + const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/false); + } + + void AssertArrayWithConstantsForRootExpressionIs( + const string& hlo_text, const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/true); + } + + private: + void AssertArrayForRootExpressionIsImpl(const string& hlo_text, + const string& root_expression, + bool print_constants) { + IndexedArrayAnalysis indexed_tensor_analysis; + ParseAndVerifyModule(hlo_text); + + TF_ASSERT_OK_AND_ASSIGN( + IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( + module().entry_computation()->root_instruction())); + string string_result = + indexed_tensor_analysis.ToString(array_result, print_constants); + LOG(INFO) << string_result; + ASSERT_EQ(string_result, root_expression); + } +}; + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5] parameter(0) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); +} + 
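// [Illustrative sketch, not part of this patch.] The GatherOfGather tests
// below exercise the FoldGatherOfGather rewrite documented in the header:
//   G0 = [Arr[i] for i in I0],  G1 = [G0[i] for i in I1]
//   ==>  I2 = [I0[i] for i in I1],  G1 = [Arr[i] for i in I2]
// This one-dimensional sketch (helper name made up) just checks that the two
// formulations agree on toy data.
#include <cassert>
#include <vector>

static std::vector<int> Gather1D(const std::vector<int>& source,
                                 const std::vector<int>& indices) {
  std::vector<int> result;
  result.reserve(indices.size());
  for (int i : indices) result.push_back(source[i]);
  return result;
}

int main() {
  std::vector<int> arr = {10, 20, 30, 40, 50};
  std::vector<int> i0 = {4, 2, 0};  // first gather
  std::vector<int> i1 = {2, 2, 1};  // gather of the gather

  std::vector<int> g1 = Gather1D(Gather1D(arr, i0), i1);
  std::vector<int> i2 = Gather1D(i0, i1);          // fold the index arrays
  std::vector<int> g1_folded = Gather1D(arr, i2);  // single gather from arr

  assert(g1 == g1_folded);  // {10, 10, 30} either way
  return 0;
}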
+TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices_a = s32[5] parameter(0) + indices_b = s32[2] parameter(1) + gather_a = s32[5,3] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT gather_b = s32[2,3] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a " + "%indices_b 0->[0]) 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[2] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 1->[1]) 1->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,6] parameter(0) + indices_a = s32[2] parameter(1) + indices_b = s32[5,7] parameter(2) + gather_a = s32[2,6] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 0->[0,1]) 0->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[4,8] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=2, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b " + "1->[0,2]) 1->[0,1,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5] parameter(0) + gather = s32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT reshape = s32[5,2,2] reshape(gather) +} 
+)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,7] parameter(0) + gather = s32[5,4,7] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,2,6] constant(s32[3,2,6]{ + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[5,7] parameter(0) + gather = s32[5,2,6,7] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,2,6} + ROOT reshape = s32[5,3,4,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,2,3] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%reshape"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,5,2] constant(s32[3,5,2]{ + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}}) + indices = s32[7] parameter(0) + gather = s32[3,2,7] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1,2} + ROOT reshape = s32[6,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%reshape"); +} + +TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { + string hlo_text = R"( +HloModule UnaryOpOfGather + +ENTRY main { + operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + indices = s32[5] parameter(0) + gather = f32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT tanh = f32[5,4] tanh(gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant f32[3,4] f32[3,4] { + { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, + { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, + { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted 
= s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 6, 7, 8, 9 }, + { 6, 8, 7, 9 }, + { 9, 8, 7, 6 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsLhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { -4, -3, -2, -1 }, + { -4, -2, -3, -1 }, + { -1, -2, -3, -4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsRhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(constant_broadcasted, gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 4, 3, 2, 1 }, + { 4, 2, 3, 1 }, + { 1, 2, 3, 4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[4] constant({10,11,12,13}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 11, 13, 15, 17 }, + { 11, 14, 14, 17 }, + { 14, 14, 14, 14 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[5] constant({10,11,12,13,14}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + 
elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input = f32[100] parameter(0) + ROOT tanh = f32[100] tanh(input) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%tanh"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input0 = f32[100] parameter(0) + input1 = f32[100] parameter(1) + ROOT add = f32[100] add(input0, input1) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 7aa1c7c8358318..d2af261008f40e 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({4, 3, 3, 4}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } // Test that `constant` function is changed to `broadcast`. @@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({3, 1, -1, -3}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index b9ccfeddb565b7..429c8503432b79 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -28,6 +28,25 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { +// These nodes can always be duplicated into consumers, even if +// InstructionFusion::may_duplicate_ is false. +// +// In general these should be nodes that get *cheaper* the more they're +// duplicated (and fused into consumers). +// +// TODO(jlebar): Duplicating instructions when we have a variable called "may +// duplicate" that's equal to false is not pretty. +bool IsAlwaysDuplicable(const HloInstruction& instruction) { + // We are always willing to duplicate a widening type-conversion instruction + // if it means we can fuse the convert into a consumer. This allows the + // consumer to read less memory, which is almost always a performance win. 
+ return instruction.opcode() == HloOpcode::kConvert && + ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction.shape()); +} +} // namespace + /*static*/ bool InstructionFusion::IsExpensive( const HloInstruction& instruction) { switch (instruction.opcode()) { @@ -99,13 +118,16 @@ namespace xla { case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kDivide: + case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: case HloOpcode::kHostCompute: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: @@ -128,11 +150,11 @@ namespace xla { return false; } -// An "effectively unary" operation is one that has one "large" +// An "effectively at most unary" operation is one that has at most one "large" // input with the others being negligible in terms of memory usage. // We use "has a smaller true rank than the output" as a heuristic // for "negligible" memory usage. -bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { +bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { int64 output_rank = 0; ShapeUtil::ForEachSubshape( hlo->shape(), @@ -156,66 +178,89 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { } bool InstructionFusion::CanFuseOnAllPaths( - const HloReachabilityMap& reachability_map, HloInstruction* producer, - HloInstruction* consumer, DoNotFuseSet* do_not_fuse) { - auto could_fuse_on_all_paths = [&] { - // First check to see if we have already marked this producer as infeasible - // to fuse into consumer. - if (do_not_fuse->count(producer) > 0) { + HloInstruction* producer, HloInstruction* consumer, + const HloInstructionSet& do_not_duplicate) { + if (consumer == producer) { + return true; + } + if (!consumer->IsFusable()) { + return false; + } + for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { + auto* consumer_operand = consumer->mutable_operand(i); + // If the operand is not on a path to the producer, it doesn't matter + // whether it's fusable. + if (!reachability_->IsReachable(producer, consumer_operand)) { + continue; + } + if (do_not_duplicate.count(consumer_operand) > 0 || + !ShouldFuse(consumer, i)) { return false; } - // Make sure it is possible for producer and consumer to exist in a fusion - // node. - if (!producer->IsFusable() || !consumer->IsFusable()) { + // The producer is reachable from consumer_operand which means we need + // to be able to fuse consumer_operand into consumer in order for + // producer to be fusable into consumer on all paths. + // Perform the recursive step: make sure producer can be fused into + // consumer_operand on all paths. + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { return false; } - // We do an upward walk of the graph from consumer towards all paths which - // lead to producer to find any unfusable paths. - for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { - auto* consumer_operand = consumer->mutable_operand(i); - if (consumer_operand == producer) { - // This is the base case: our upward crawl ends but we need to make sure - // that fusion from consumer can happen. - if (!ShouldFuse(consumer, i)) { - return false; - } - } else if (reachability_map.IsReachable(producer, consumer_operand)) { - // The reachability map told us that consumer_operand is a node on the - // path to producer. 
We need to further investigate from - // consumer_operand. - - // First check if we have already ruled out fusing producer into - // consumer_operand. - if (do_not_fuse->count(consumer_operand) > 0) { - return false; - } - // Make sure it is possible for consumer_operand to exist in a fusion - // node. - if (!consumer_operand->IsFusable()) { - return false; - } - // The producer is reachable from consumer_operand which means we need - // to be able to fuse consumer_operand into consumer in order for - // producer to be fusable into consumer on all paths. - if (!ShouldFuse(consumer, i)) { - return false; - } - // Perform the recursive step: make sure producer can be fused into - // consumer_operand on all paths. - if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand, - do_not_fuse)) { - return false; - } + } + return true; +} + +InstructionFusion::HloInstructionSet +InstructionFusion::ComputeGloballyUnfusable( + tensorflow::gtl::ArraySlice post_order) { + // Forbid fusion of producers that: + // a) Need to be duplicated, unless they can be fused into all consumers + // via all paths. + // b) Are more than unary, that is, fusing them would likely lead to an + // increase in memory bandwidth use. + // + // Note that if we allow fusion by these global rules, we may still forbid + // fusing operations that require duplication later depending on + // is_expensive_(). + HloInstructionSet do_not_duplicate; + for (HloInstruction* consumer : post_order) { + for (HloInstruction* producer : consumer->operands()) { + if (do_not_duplicate.count(producer) > 0) { + continue; } + + // If the producer is effectively not more than unary, duplicating it + // will not increase the number of relevant inputs read, as the fusion + // node will only need to read at most 1 relevant input (the input of + // the producer). In that case, we do not forbid fusion of the operation + // here. + if (EffectivelyAtMostUnary(producer)) { + continue; + } + // Otherwise we will forbid fusing the op unless we can fuse it into + // all of its consumers on all paths. + // + // That means, that for: + // A --> B (fusable) + // \-> C (non-fusable) + // A will be not allowed to be fused into B, as it cannot be fused into C. + // + // Similarly, for: + // A -------------> B + // \-> C -> D -/ + // If: + // - A is fusable into B and C, and D is fusable into B + // - C is *not* fusable into D + // A will be not allowed to be fused into B, as it cannot be fused via + // all paths. + if (producer->IsFusable() && + CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { + continue; + } + do_not_duplicate.insert(producer); } - return true; - }; - if (could_fuse_on_all_paths()) { - return true; } - // We couldn't fuse on all paths, record this result. - do_not_fuse->insert(producer); - return false; + + return do_not_duplicate; } StatusOr InstructionFusion::Run(HloModule* module) { @@ -227,6 +272,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; + reachability_ = computation_->ComputeReachability(); // We want to be able to remove arbitrary instructions from the post order // and also compare positions of instructions in the post order. 
To make @@ -244,36 +290,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { InsertOrDie(&post_order_index, post_order[i], i); } - DoNotFuseSet do_not_fuse; - auto reachability = computation->ComputeReachability(); - - auto cheap_to_duplicate = [this](HloInstruction* producer) { - if (producer->opcode() == HloOpcode::kBroadcast) { - return true; - } - if (producer->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(producer->shape())) { - return true; - } - if (EffectivelyUnary(producer)) { - return true; - } - return false; - }; - - for (HloInstruction* consumer : post_order) { - for (HloInstruction* producer : consumer->operands()) { - if (cheap_to_duplicate(producer)) { - continue; - } - if (CanFuseOnAllPaths(*reachability, producer, consumer, - &do_not_fuse)) { - CHECK_EQ(do_not_fuse.count(producer), 0); - } else { - CHECK_GT(do_not_fuse.count(producer), 0); - } - } - } + HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -341,9 +358,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { // ensures that B will be considered before A. // // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers(instruction->operands().size()); - std::iota(std::begin(sorted_operand_numbers), - std::end(sorted_operand_numbers), 0); + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the post-order index. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (post_order_index.find(instruction->mutable_operand(i)) == + post_order_index.end()) { + continue; + } + sorted_operand_numbers.push_back(i); + } std::sort( sorted_operand_numbers.begin(), sorted_operand_numbers.end(), [&](int64 i, int64 j) { @@ -360,13 +388,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { if (!operand->IsFusable()) { continue; } - if (!ShouldFuse(instruction, i)) { - continue; - } - if (do_not_fuse.count(operand) > 0) { + + HloInstruction* fusion_instruction; + // Try "regular" fusion if the operand may be duplicated. Otherwise, + // perform multi-output fusion, unless this creates a cycle. + // TODO(tjoerg): Consider making multi-output fusion the default. + if (ShouldFuse(instruction, i) && + do_not_duplicate.count(operand) == 0) { + fusion_instruction = Fuse(operand, instruction); + } else if (ShouldFuseIntoMultiOutput(instruction, i) && + !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_instruction = FuseIntoMultiOutput(operand, instruction); + } else { continue; } - HloInstruction* fusion_instruction = Fuse(operand, instruction); // Fusing an instruction into a fusion instruction can change the // operand set of the fusion instruction. 
For simplicity just push the @@ -397,12 +432,9 @@ StatusOr InstructionFusion::Run(HloModule* module) { return changed; } -HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, - HloInstruction* consumer) { +HloInstruction* InstructionFusion::AddFusionInstruction( + HloInstruction* producer, HloInstruction* consumer) { HloInstruction* fusion_instruction; - - VLOG(2) << "Fusing " << producer->ToString() << " into " - << consumer->ToString(); auto kind = ChooseKind(producer, consumer); if (consumer->opcode() == HloOpcode::kFusion) { fusion_instruction = consumer; @@ -414,17 +446,48 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); } + return fusion_instruction; +} +HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, + HloInstruction* consumer) { + VLOG(2) << "Fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); fusion_instruction->FuseInstruction(producer); return fusion_instruction; } +HloInstruction* InstructionFusion::FuseIntoMultiOutput( + HloInstruction* producer, HloInstruction* consumer) { + VLOG(2) << "Multi-output fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); + fusion_instruction->FuseInstructionIntoMultiOutput(producer); + return fusion_instruction; +} + +bool InstructionFusion::MultiOutputFusionCreatesCycle( + HloInstruction* producer, HloInstruction* consumer) { + return c_any_of( + consumer->operands(), [&](const HloInstruction* consumer_operand) { + // The fusion algorithm traverses the HLO graph in reverse post order. + // Thus `cosumers` is visited before its operands (including + // `producer`). Therefore, consumer operands cannot have been fused yet. + // It is thus safe to use the pre-computed reachability map. + return consumer_operand != producer && + reachability_->IsReachable(producer, consumer_operand); + }); +} + bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Cost condition: don't duplicate expensive instructions. if (FusionWouldDuplicate(*producer, *consumer) && - (is_expensive_(*producer) || !may_duplicate_)) { + (!may_duplicate_ || is_expensive_(*producer)) && + !IsAlwaysDuplicable(*producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 152d0886ee9eda..f73ca9adf768ed 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -61,6 +61,14 @@ class InstructionFusion : public HloPassInterface { // Subtypes can override this with target-specific heuristics. virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index); + // Returns whether multi-output fusion can be applied to fuse `producer` into + // `consumer`. In contrast to "regular" fusion, the `producer` is not + // duplicated by multi-output fusion. + virtual bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + return false; + } + // Chooses a fusion kind for `producer` and `consumer`. // Default method chooses `kLoop`. 
virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, @@ -70,11 +78,18 @@ class InstructionFusion : public HloPassInterface { virtual HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); - // An "effectively unary" operation is one that has one "large" + // Creates a new fusion instruction containing `producer` and `consumer`. A + // tuple is added as the fusion instruction's root, which consumes from both, + // `producer` and `consumer`. This style of fusion is referred to as + // multi-output fusion. + virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer); + + // An "effectively unary" operation is one that has at most one "large" // input with the others being negligible in terms of memory usage. // We use "has a smaller true rank than the output" as a heuristic // for "negligible" memory usage. - bool EffectivelyUnary(HloInstruction* hlo); + bool EffectivelyAtMostUnary(HloInstruction* hlo); // Returns true if fusing producer into consumer would cause producer to be // duplicated. This is the case if producer has uses other than consumer. @@ -90,21 +105,34 @@ class InstructionFusion : public HloPassInterface { // Current HloComputation instance the loop fuser is traversing. HloComputation* computation_; HloModule* module_; + // Reachability information for the current computation. + std::unique_ptr reachability_; private: // The set of producers whose consumers we cannot fuse into. - using DoNotFuseSet = std::unordered_set; + using HloInstructionSet = std::unordered_set; - // Whether or not we can fuse consumer into original_producer on all paths + HloInstruction* AddFusionInstruction(HloInstruction* producer, + HloInstruction* consumer); + + // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. - bool CanFuseOnAllPaths(const HloReachabilityMap& reachability_map, - HloInstruction* producer, HloInstruction* consumer, - DoNotFuseSet* do_not_fuse); + bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, + const HloInstructionSet& do_not_fuse); + + // Computes the set of nodes that we do not want to fuse into any of their + // consumers based on a global analysis of the HLO graph. + HloInstructionSet ComputeGloballyUnfusable( + tensorflow::gtl::ArraySlice post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. std::function is_expensive_; + // Whether multi-output fusion would introduce a cycle into the HLO graph. + bool MultiOutputFusionCreatesCycle(HloInstruction* producer, + HloInstruction* consumer); + // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 0fa2c95fb458f8..21db2338995960 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -16,12 +16,100 @@ limitations under the License. 
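// [Illustrative sketch, not part of this patch.] MultiOutputFusionCreatesCycle
// above rejects multi-output fusion of `producer` into `consumer` when some
// *other* operand of the consumer is reachable from the producer: after
// fusion that operand would both depend on and feed the fused node. The toy
// reachability relation and helper names below are made up; only the shape of
// the std::any_of check mirrors the code above.
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

using Edge = std::pair<std::string, std::string>;  // (from, to): "from reaches to"

static bool IsReachable(const std::set<Edge>& reachability,
                        const std::string& from, const std::string& to) {
  return from == to || reachability.count({from, to}) > 0;
}

static bool MultiOutputFusionWouldCreateCycle(
    const std::set<Edge>& reachability, const std::string& producer,
    const std::string& consumer,
    const std::vector<std::string>& consumer_operands) {
  return std::any_of(consumer_operands.begin(), consumer_operands.end(),
                     [&](const std::string& operand) {
                       return operand != producer &&
                              IsReachable(reachability, producer, operand);
                     });
}

int main() {
  // producer -> other -> consumer, plus producer -> consumer directly:
  // fusing producer and consumer into one node would make `other` both a
  // user and an operand of that node, i.e. a cycle.
  std::set<Edge> reachability = {{"producer", "other"},
                                 {"producer", "consumer"},
                                 {"other", "consumer"}};
  std::cout << MultiOutputFusionWouldCreateCycle(reachability, "producer",
                                                 "consumer",
                                                 {"producer", "other"})
            << "\n";  // prints 1
  // Without the producer -> other path there is no cycle.
  std::cout << MultiOutputFusionWouldCreateCycle({{"producer", "consumer"}},
                                                 "producer", "consumer",
                                                 {"producer", "other"})
            << "\n";  // prints 0
  return 0;
}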
#include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { +namespace op = xla::testing::opcode_matchers; + using InstructionFusionTest = HloTestBase; +// Subclass of InstructionFusion exposing the protected methods Fuse and +// FuseIntoMultiOutput for testing. +class InstructionFusionForTesting : public InstructionFusion { + public: + explicit InstructionFusionForTesting(HloModule* module) + : InstructionFusion(InstructionFusion::IsExpensive) { + module_ = module; + computation_ = module->entry_computation(); + } + + HloInstruction* Fuse(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::Fuse(producer, consumer); + } + + HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::FuseIntoMultiOutput(producer, consumer); + } +}; + +TEST_F(InstructionFusionTest, FuseInstructions) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + ROOT sub = f32[4,3]{1,0} subtract(add, p0) + })") + .ValueOrDie(); + HloInstruction* sub = module->entry_computation()->root_instruction(); + HloInstruction* add = sub->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(add, sub); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), + op::Subtract(op::Add(), op::Parameter())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { + auto module = ParseHloString(R"( + HloModule test_module + fused_computation { + p1 = f32[4,3] parameter(0) + add = f32[4,3] add(p1, p1) + } + ENTRY entry_computation { + p0 = f32[4,3] parameter(0) + abs = f32[4,3] abs(p0) + ROOT fusion = f32[4,3] fusion(abs), kind=kLoop, calls=fused_computation + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(abs, root); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Add(op::Abs(), op::Abs())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + abs = f32[4,3]{1,0} abs(p0) + tanh = f32[4,3]{1,0} tanh(abs) + ROOT add = f32[4,3]{1,0} add(abs, tanh) + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* tanh = root->mutable_operand(1); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).FuseIntoMultiOutput(abs, tanh); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Tuple(op::Tanh(), op::Abs())) + << module->ToString(); +} + TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( @@ -89,7 +177,172 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, 
/*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); +} + +// Counts the number of HLO ops with a given op code in the specified module. +static int Count(const HloModule& module, HloOpcode op) { + int count = 0; + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == op) { + ++count; + } + } + } + return count; +} + +TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + ROOT root = f32[4,3]{1,0} subtract(add, add) + })") + .ValueOrDie(); + // Expect the add and subtraction to be fused. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); +} + +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { + // Make sure we do not duplicate the add, as we cannot fuse through the rng. + // + // p0 -> add -------------------------> sub + // \-> abs1 -> rng -> abs2 -/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + abs1 = f32[4,3]{1,0} abs(add) + rng = f32[4,3]{1,0} rng(abs1), distribution=rng_uniform + abs2 = f32[4,3]{1,0} abs(rng) + ROOT root = f32[4,3]{1,0} subtract(abs2, add) + })") + .ValueOrDie(); + // We expect abs2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Subtract(op::Abs(op::Parameter()), op::Parameter())) + << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); + + // Use a log node with a second consumer to break the fusion. + // + // p0 -> add -------------------------> sub + // \-> abs1 -> log -> abs2 -/ + // \-> send + module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + abs1 = f32[4,3]{1,0} abs(add) + log = f32[4,3]{1,0} log(abs1) + send = f32[4,3]{1,0} send(log), channel_id=0 + abs2 = f32[4,3]{1,0} abs(log) + ROOT root = f32[4,3]{1,0} subtract(abs2, add) + })") + .ValueOrDie(); + + // We expect abs2 to be fused into root and abs1 to be fused into log. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 2) << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); + + // Make sure we still fuse ops where one operand in the chain to the producer + // can't be fused. 
+ // + // p0 ---> add1 -----------> sub + // \ \-> add2 -/ + // \-> log -/ + // \-> send + module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add1 = f32[4,3]{1,0} add(p0, p0) + log = f32[4,3]{1,0} log(p0) + send = f32[4,3]{1,0} send(log), channel_id=0 + add2 = f32[4,3]{1,0} add(log, add1) + ROOT root = f32[4,3]{1,0} subtract(add1, add2) + })") + .ValueOrDie(); + + // Expect the add1 and add2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + + // Make sure we didn't duplicate any adds. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); + + // A variant of the above that allows the algorithm to put add2 into the set + // of unfusable ops to short-circuit the decision whether add1 should be fused + // into sub2. + // + // /---------------\ + // p0 ---> add1 ---> add2 ------> sub2 + // \------> sub1 + // log -/ + // \-> send + module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add1 = f32[4,3]{1,0} add(p0, p0) + add2 = f32[4,3]{1,0} add(add1, add1) + log = f32[4,3]{1,0} log(add2) + send = f32[4,3]{1,0} send(log), channel_id=0 + sub1 = f32[4,3]{1,0} subtract(log, add2) + sub2 = f32[4,3]{1,0} subtract(add2, add1) + ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) + })") + .ValueOrDie(); + + // Expect sub1 and sub2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Tuple(op::Subtract(op::Parameter(), op::Parameter()), + op::Subtract(op::Parameter(), op::Parameter()))) + << module->ToString(); + + // Make sure we didn't duplicate any adds. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); } TEST_F(InstructionFusionTest, AllowUnaryDuplication) { @@ -135,4 +388,29 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, + WideningConvertsAreAlwaysDuplicableIntoConsumers) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY Test { + p0 = f16[100] parameter(0) + c = f32[100] convert(p0) + add = f32[100] add(c, c) + ROOT mul = f32[100] multiply(c, c) + })") + .ValueOrDie(); + + // The convert should be fused into the add and mul, even though may_duplicate + // is false, because it's always beneficial to fuse/duplicate widening + // converts into consumers. 
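// [Illustrative sketch, not part of this patch.] The "always duplicable" rule
// exercised by this test boils down to comparing the byte sizes of the
// convert's operand and result (see IsAlwaysDuplicable earlier in this
// change). The element sizes used below are the standard ones (f16 = 2 bytes,
// f32 = 4 bytes); the helper name is made up.
#include <cstdint>
#include <iostream>

static bool IsWideningConvert(int64_t operand_bytes_per_element,
                              int64_t result_bytes_per_element,
                              int64_t num_elements) {
  return operand_bytes_per_element * num_elements <
         result_bytes_per_element * num_elements;
}

int main() {
  // f16[100] -> f32[100]: 200 bytes read vs. 400 bytes produced, widening,
  // so duplicating the convert into each consumer reduces memory traffic.
  std::cout << IsWideningConvert(2, 4, 100) << "\n";  // 1
  // f32[100] -> f16[100]: narrowing, so not always duplicable.
  std::cout << IsWideningConvert(4, 2, 100) << "\n";  // 0
  return 0;
}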
+ EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion(op::Parameter())); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 45505484951abf..524d3234eb4eff 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -18,7 +18,6 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -117,6 +116,5 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/README.md b/tensorflow/compiler/xla/service/interpreter/README.md index 4c19a1b916d421..0b21b251c3f663 100644 --- a/tensorflow/compiler/xla/service/interpreter/README.md +++ b/tensorflow/compiler/xla/service/interpreter/README.md @@ -5,7 +5,7 @@ evaluating the result of the HLO graph directly with HloEvaluator, without lowering it further (to LLVM IR for example) before execution as other backends (CPU and GPU for example) do. -Its key componenets are: +Its key components are: * [`InterpreterCompiler`] despite the inherited naming of "compiler", all `InterpreterCompiler` really does is the following: diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 76b3ecad26fe92..c1666530687f2f 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/types.h" @@ -45,8 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); - + hlo_module->mutable_device_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } @@ -71,7 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. 
std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module)); + xla::MakeUnique(std::move(hlo_module), + xla::MakeUnique()); return std::move(executable); } @@ -101,17 +100,14 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() return InterpreterExecutable::ShapeSizeBytes; } -static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); -} - static bool InitModule() { xla::Compiler::RegisterCompilerFactory( se::interpreter::kXlaInterpreterPlatformId, []() { return xla::MakeUnique(); }); xla::ComputationPlacer::RegisterComputationPlacer( - se::interpreter::kXlaInterpreterPlatformId, &CreateComputationPlacer); + se::interpreter::kXlaInterpreterPlatformId, + []() { return xla::MakeUnique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 61f199bc9e8f4f..029e71058a7373 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -32,16 +31,17 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace interpreter { InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module) + std::unique_ptr hlo_module, + std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, - /*hlo_profile_index_map=*/nullptr) {} + /*hlo_profile_index_map=*/nullptr), + evaluator_(std::move(evaluator)) {} InterpreterExecutable::~InterpreterExecutable() {} @@ -82,10 +82,13 @@ StatusOr InterpreterExecutable::ExecuteOnStream( } // Execute the graph using the HloEvaluator. - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, - evaluator.Evaluate>(*computation, arg_literals)); + std::unique_ptr result_literal; + { + tensorflow::mutex_lock lock(evaluator_lock_); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate>( + *computation, arg_literals)); + } // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index b0b797ca7d6f44..91d8148d26dc8e 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -30,6 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -40,13 +42,15 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module); + InterpreterExecutable(std::unique_ptr hlo_module, + std::unique_ptr evaluator); ~InterpreterExecutable() override; StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) override; + HloExecutionProfile* hlo_execution_profile) override + LOCKS_EXCLUDED(evaluator_lock_); StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, @@ -54,6 +58,11 @@ class InterpreterExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); + protected: + // The interpreter interprets executables with an HloEvaluator. + std::unique_ptr evaluator_ PT_GUARDED_BY(evaluator_lock_); + mutable tensorflow::mutex evaluator_lock_; + private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 92e069a8c67c1d..42c2c28997d5f3 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/interpreter/executor.h" -#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" @@ -31,13 +30,13 @@ limitations under the License. namespace stream_executor { namespace interpreter { -XlaInterpreterPlatform::XlaInterpreterPlatform() : name_("Interpreter") {} +XlaInterpreterPlatform::XlaInterpreterPlatform(const string& name, + const Platform::Id& id) + : name_(name), id_(id) {} XlaInterpreterPlatform::~XlaInterpreterPlatform() {} -Platform::Id XlaInterpreterPlatform::id() const { - return kXlaInterpreterPlatformId; -} +Platform::Id XlaInterpreterPlatform::id() const { return id_; } int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } @@ -106,8 +105,6 @@ REGISTER_MODULE_INITIALIZER( interpreter_platform, stream_executor::interpreter::InitializeXlaInterpreterPlatform()); -DECLARE_MODULE_INITIALIZER(multi_platform_manager); - // Note that module initialization sequencing is not supported in the // open-source project, so this will be a no-op there. REGISTER_MODULE_INITIALIZER_SEQUENCE(interpreter_platform, diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index d68c5aa20dda7a..0187f6d473b19f 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -28,7 +29,8 @@ namespace interpreter { class XlaInterpreterPlatform : public Platform { public: - XlaInterpreterPlatform(); + XlaInterpreterPlatform(const string& name = "Interpreter", + const Platform::Id& id = kXlaInterpreterPlatformId); ~XlaInterpreterPlatform() override; Platform::Id id() const override; @@ -55,6 +57,8 @@ class XlaInterpreterPlatform : public Platform { private: // This platform's name. string name_; + // This platform's id. + Platform::Id id_; // Cache of created StreamExecutors. ExecutorCache executor_cache_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 2494569db53f26..7067b6f86a0fb2 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -31,10 +31,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -400,9 +402,9 @@ string LayoutConstraints::ToString() const { } Status LayoutAssignment::AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints) { + const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, HloComputation* computation, + LayoutConstraints* constraints) { VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); @@ -424,11 +426,16 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { - // Parameter layouts must match the respective layout in - // ComputationLayout. - shape_with_layout = - &computation_layout.parameter_layout(instruction->parameter_number()) - .shape(); + if (computation_layout != nullptr) { + const ShapeLayout& parameter_layout = + computation_layout->parameter_layout( + instruction->parameter_number()); + if (parameter_layout.LayoutIsSet()) { + // Parameter layouts must match the respective layout in + // ComputationLayout, if there is one. 
+ shape_with_layout = &parameter_layout.shape(); + } + } } if (shape_with_layout != nullptr) { TF_RETURN_IF_ERROR( @@ -493,9 +500,8 @@ Status LayoutAssignment::AddMandatoryConstraints( HloComputation* body = instruction->while_body(); HloComputation* condition = instruction->while_condition(); const HloInstruction* init = instruction->operand(0); - const ComputationLayout& body_layout = - FindOrDie(computation_layouts_, body); - const ComputationLayout& condition_layout = + ComputationLayout& body_layout = FindOrDie(computation_layouts_, body); + ComputationLayout& condition_layout = FindOrDie(computation_layouts_, condition); // Check a few invariants irrespective of layout. @@ -508,26 +514,19 @@ Status LayoutAssignment::AddMandatoryConstraints( condition_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape())); - // Return error if earlier layout assignment of the embedded computations - // has produced conflicting layouts. - if (!ShapeUtil::Equal(body_layout.result_shape(), - body_layout.parameter_shape(0))) { - return InternalError( - "Parameter and result of body computation %s of while instruction " - "%s have different layouts: %s vs %s", - body->name().c_str(), instruction->name().c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str(), - ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str()); + if (body_layout.result_layout() != body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while body parameter layout: body=" << body->name() + << " while=" << instruction->name() + << " shape=" << body_layout.result_layout().ToString(); + *body_layout.mutable_parameter_layout(0) = body_layout.result_layout(); } - if (!ShapeUtil::Equal(body->root_instruction()->shape(), - condition->parameter_instruction(0)->shape())) { - return InternalError( - "Parameter of condition computation %s of while instruction " - "%s does not match body computation %s result: %s vs %s", - condition->name().c_str(), instruction->name().c_str(), - body->name().c_str(), - ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str()); + if (condition_layout.parameter_layout(0) != + body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while condition parameter layout: cond=" + << condition->name() << " while=" << instruction->name() + << " shape=" << body_layout.parameter_layout(0).ToString(); + *condition_layout.mutable_parameter_layout(0) = + body_layout.parameter_layout(0); } // Constrain the output and the operand of the while instruction to match @@ -557,7 +556,20 @@ Status LayoutAssignment::AddMandatoryConstraints( true_computation_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible( false_operand->shape(), false_computation_layout.parameter_shape(0))); - + if (true_computation_layout.result_layout() != + false_computation_layout.result_layout()) { + // We assign layouts in DFS fashion, so the true and false computations + // might have negotiated a different layout. But from the conditional + // instruction's point of view the layout must match, so we run again on the false + // computation, this time with the proper computation layout.
+ VLOG(2) << "Reset %conditional false computation result layout: " + "false_computation=" + << false_computation->name() + << " conditional=" << instruction->name() << " shape=" + << true_computation_layout.result_layout().ToString(); + *false_computation_layout.mutable_result_layout() = + true_computation_layout.result_layout(); + } TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( true_computation_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( @@ -593,10 +605,14 @@ Status LayoutAssignment::AddMandatoryConstraints( } } } - - // Finally set the result layout to match ComputationLayout. - return constraints->SetResultLayout( - computation_layout.result_layout().shape()); + // Finally set the result layout to match ComputationLayout, if there is one. + if (computation_layout != nullptr) { + const ShapeLayout& result_layout = computation_layout->result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape())); + } + } + return Status::OK(); } namespace { @@ -760,6 +776,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( HloInstruction* copy = instruction->parent()->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction)); + RegisterAddedCopy(copy); SetupCopiedInstruction(*instruction, copy, {}); LayoutUtil::ClearLayout(copy->mutable_shape()); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( @@ -783,13 +800,19 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer( TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + VLOG(5) << "Operand " << operand->ToString() << " layout matches in " + << instruction->ToString(); // Operand layout already matches our constraint. Nothing to do. return Status::OK(); } + VLOG(4) << "Operand " << operand->ToString() << " layout does not match " + << operand_layout.ToString() << " in " << instruction->ToString(); TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, CreateCopyWithNewLayout(operand_layout.shape(), operand)); + VLOG(4) << "New copy of " << operand->ToString() << " is " + << operand_copy->ToString(); return instruction->ReplaceOperandWith(operand_no, operand_copy); } @@ -896,15 +919,16 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { } } } - - // Finally verify the result layout matches the layout of the entry + // Finally verify the result layout, if set, matches the layout of the entry // computation root. - TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), + const ShapeLayout& result_layout = FindOrDie(computation_layouts_, module->entry_computation()) - .result_layout() - .shape())); - + .result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RET_CHECK(ShapeUtil::Equal( + module->entry_computation()->root_instruction()->shape(), + result_layout.shape())); + } return Status::OK(); } @@ -913,18 +937,13 @@ LayoutAssignment::LayoutAssignment( ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), channel_layout_constraints_(channel_constraints) { - VLOG(1) << "entry computation layout given to layout assignment: " + VLOG(1) << "Entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); // Layouts of all parameter instructions must be set. 
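As an illustrative aside (not part of the patch): together with the constructor hunk just below, which drops the old default-result-layout fallback, the relaxed checks above mean an entry result layout may now stay unset and be chosen by the pass itself. A rough usage sketch under stated assumptions: the helper name is hypothetical, the module's parameter shapes are assumed to already carry layouts (the constructor still CHECKs those), and `ComputationLayout(program_shape, /*ignore_layouts=*/false)` is assumed to keep only the layouts present in the given `ProgramShape`, as the `CalculateComputationLayout()` hunk later in this file suggests.

```c++
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical helper: run layout assignment with the entry result layout left
// open, letting PropagateComputationLayouts() fill it in from the root.
void AssignLayoutsWithOpenResultLayout(xla::HloModule* module) {
  xla::ProgramShape program_shape =
      module->entry_computation()->ComputeProgramShape();
  // Drop the result layout so the pass is free to choose it.
  xla::LayoutUtil::ClearLayout(program_shape.mutable_result());
  xla::ComputationLayout entry_layout(program_shape, /*ignore_layouts=*/false);

  xla::LayoutAssignment pass(&entry_layout);
  TF_CHECK_OK(pass.Run(module).status());

  // After Run(), the result layout has been propagated back from the root.
  CHECK(entry_layout.result_layout().LayoutIsSet());
}
```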
for (const ShapeLayout& parameter_layout : entry_computation_layout_->parameter_layouts()) { CHECK(parameter_layout.LayoutIsSet()); } - // If the result layout is not set, then choose the default. - // TODO(b/29118294): Choose a better layout in this case. - if (!entry_computation_layout_->result_layout().LayoutIsSet()) { - entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout(); - } } std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1484,16 +1503,60 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, return Status::OK(); } +Status LayoutAssignment::CalculateComputationLayout( + HloComputation* computation) { + ComputationLayout computation_layout(computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + InsertOrDie(&computation_layouts_, computation, computation_layout); + VLOG(2) << " Calculated ComputationLayout = " + << computation_layout.ToString(); + return Status::OK(); +} + +Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { + // Clear existing layouts of the instructions. All layouts must be assigned + // by the LayoutAssignment pass, except for those on infeeds, parameters, + // and the computation result. The latter two are specified in + // computation_layout, so we only need to keep the existing layouts for + // infeeds. Clearing the layouts here avoids hiding potential bugs in the + // layout assignment pass that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString().c_str()); + } + if (instruction->opcode() != HloOpcode::kInfeed) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + } + return Status::OK(); +} + Status LayoutAssignment::RunOnComputation( - const ComputationLayout& computation_layout, + ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints) { - DCHECK(computation_layout.LayoutIsSet()); - InsertOrDie(&computation_layouts_, computation, computation_layout); VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; - VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + TF_RETURN_IF_ERROR(ClearComputationLayouts(computation)); + if (computation_layout != nullptr) { + auto it = computation_layouts_.find(computation); + if (it == computation_layouts_.end()) { + VLOG(2) << " New ComputationLayout = " << computation_layout->ToString(); + computation_layouts_.emplace(computation, *computation_layout); + } else { + TF_RET_CHECK(computation_layout == &it->second || + computation_layout == entry_computation_layout_); + VLOG(2) << " Existing ComputationLayout = " + << computation_layout->ToString(); + } + } else { + VLOG(2) << " No ComputationLayout specified (will be calculated)"; + } // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); @@ -1536,12 +1599,19 @@ Status LayoutAssignment::RunOnComputation( CHECK_LT(constraints.unconstrained_buffer_ids().size(), unconstrained_count); } - // All logical buffers should have constraints at this point. 
All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers. TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation)); + // If the computation layout wasn't specified, now it is the time to compute + // it according to the parameters and root instruction layouts. + // This allows the first pass through this API to record the best flowing + // layout to parameters and root instruction. + if (computation_layout == nullptr) { + TF_RETURN_IF_ERROR(CalculateComputationLayout(computation)); + } + // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. for (HloInstruction* instruction : computation->instructions()) { @@ -1556,6 +1626,34 @@ Status LayoutAssignment::RunOnComputation( return Status::OK(); } +Status LayoutAssignment::PropagateComputationLayouts( + HloComputation* computation, ComputationLayout* computation_layout) { + ComputationLayout computed_computation_layout( + computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) { + ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i); + if (!param_layout->LayoutIsSet()) { + VLOG(4) << "Assigning layout to parameter " << i << " of computation " + << computation->name() << ": " + << computed_computation_layout.parameter_layout(i).ToString(); + *param_layout = computed_computation_layout.parameter_layout(i); + } else { + TF_RET_CHECK(computed_computation_layout.parameter_layout(i) == + *param_layout); + } + } + ShapeLayout* result_layout = computation_layout->mutable_result_layout(); + if (!result_layout->LayoutIsSet()) { + VLOG(4) << "Assigning result layout of computation " << computation->name() + << ": " << computed_computation_layout.result_layout().ToString(); + *result_layout = computed_computation_layout.result_layout(); + } else { + TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); + } + return Status::OK(); +} + StatusOr LayoutAssignment::Run(HloModule* module) { VLOG(2) << "Running layout assignment on module " << module->name(); XLA_VLOG_LINES(3, module->ToString()); @@ -1564,52 +1662,45 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "before layout assignment", module->config().debug_options()); } - - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Assign layouts to computations in an order such that a callee computation - // is handled before its caller computation. This ensures that the layout of - // all callers of a computation will agree. - std::list computation_post_order = - module->MakeComputationPostOrder(); - for (auto* computation : module->MakeComputationPostOrder()) { - if (computation->IsFusionComputation()) { - continue; - } - // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBitcast) { - // bitcasts are inherently layout sensitive and so a bitcast instruction - // present in the IR before layout assignment is a bug. 
- return InternalError( - "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + TF_RETURN_IF_ERROR(Init()); + + // We do two passes. In the first one we pass a nullptr ComputationLayout to + // the RunOnComputation() calls (for non-entry computations), and we register + // the ComputationLayouts which naturally flow in DFS fashion to the + // parameters and root instruction. + // Walking in DFS mode, though, means that we can end up with incorrect layouts + // when seen from an outer instruction, which has across-computation + // constraints to impose. + // For example, the kWhile instruction needs to enforce the same layouts for + // the parameters and root of the body, as well as the condition parameters. + // Similarly, the kConditional instruction needs to enforce the same layouts + // for the root of the true and false computations. + // So in the first pass, while allowing the layouts to flow to parameters and + // root, we also fix up any inconsistent ComputationLayouts, which will then + // be made mandatory by the second pass. + for (int64 i = 0; i < 2; ++i) { + TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + for (auto* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; } - if (instruction->opcode() != HloOpcode::kInfeed) { - LayoutUtil::ClearLayout(instruction->mutable_shape()); + if (computation == module->entry_computation()) { + TF_RETURN_IF_ERROR(RunOnComputation( + entry_computation_layout_, *points_to_analysis, + module->entry_computation(), channel_layout_constraints_)); + } else { + ComputationLayout* computation_layout = + (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation); + TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, + *points_to_analysis, computation, + channel_layout_constraints_)); } } - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation( - *entry_computation_layout_, *points_to_analysis, - module->entry_computation(), channel_layout_constraints_)); - } else { - ComputationLayout computation_layout(computation->ComputeProgramShape()); - // Setting all embedded computations to the default layout is potentially - // suboptimal. - computation_layout.SetToDefaultLayout(); - TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation, - channel_layout_constraints_)); - } } - + TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), + entry_computation_layout_)); TF_RETURN_IF_ERROR(CheckLayouts(module)); VLOG(3) << "After layout assignment:"; @@ -1619,9 +1710,54 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "after layout assignment", module->config().debug_options()); } - // All layouts are reset then reassigned by this pass. return true; } +Status LayoutAssignment::Init() { + computation_layouts_.clear(); + return Status::OK(); +} + +Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + // Clear all the copies which have been added, and all the related + // instructions (like GTE and tuples).
+ int64 removed_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kCopy && + added_copies_.count(instruction) > 0) { + VLOG(5) << "Removing added copy: " << instruction->ToString(); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + ++removed_copies; + } + } + } + added_copies_.clear(); + if (removed_copies > 0) { + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return Status::OK(); +} + +Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction, + int64 operand_number) { + HloInstruction* operand = instruction->mutable_operand(operand_number); + if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + SetupCopiedInstruction(*operand, copy, {}); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy)); + } + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index ae4986d6ad9bc3..c287cca0c54ba1 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -281,8 +282,8 @@ class LayoutAssignment : public HloPassInterface { // the case that no particular layout is requested. // // channel_constraints is both an input and output. Any sends or recvs that - // are present in channel_constraints will be layed out as constrained. Any - // unconstrained sends or recvs will be layed out as locally optimal and their + // are present in channel_constraints will be laid out as constrained. Any + // unconstrained sends or recvs will be laid out as locally optimal and their // layout will be added as a constraint to channel_constraints. // // If channel_constraints is nullptr, no kSend or kRecvs must be contained @@ -362,12 +363,15 @@ class LayoutAssignment : public HloPassInterface { int64 operand_no); private: + // Initializes the layout assignment object for a new Run() call. + Status Init(); + // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. - Status AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints); + Status AddMandatoryConstraints(const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, + LayoutConstraints* constraints); // This method can be overridden to add backend-specific constraints to the // layout of the instructions of a computation. 
This method is called after @@ -378,10 +382,12 @@ class LayoutAssignment : public HloPassInterface { } // Construct contraints and assign layouts to all instructions in the - // computation satisfying the given ComputationLayout. Layouts constraints are - // added, then propagated until all LogicalBuffers in the computation are - // constrained. - Status RunOnComputation(const ComputationLayout& computation_layout, + // computation satisfying the given ComputationLayout, if not nullptr. + // Otherwise the ComputationLayout will be calculated by propagating the + // computation instruction constraints. + // Layout constraints are added, then propagated until all LogicalBuffers in + // the computation are constrained. + Status RunOnComputation(ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints); @@ -402,6 +408,25 @@ class LayoutAssignment : public HloPassInterface { // necessary conditions. Status CheckLayouts(HloModule* module); + // Computes the ComputationLayout of the given computation based on the + // layouts assigned to parameters and root instruction, and inserts it into + // the computation_layouts_ map. + Status CalculateComputationLayout(HloComputation* computation); + + // Clears all the layouts which can be cleared within a computation. + Status ClearComputationLayouts(HloComputation* computation); + + // Clears the side effects of a previous pass, like added copy instructions. + Status ClearPreviousPassSideEffects(HloModule* module); + + // Propagates the layouts computed by the layout assignment pass on the given + // computation to the computation layout passed in to this API. + // This API propagates missing layouts, and also checks that the layouts the + // caller specified have been respected, by comparing them with the layouts of + // the parameters and the root instruction. + Status PropagateComputationLayouts(HloComputation* computation, + ComputationLayout* computation_layout); + ComputationLayout* entry_computation_layout_; protected: @@ -418,21 +443,37 @@ class LayoutAssignment : public HloPassInterface { // Creates and returns a copy of the given instruction with a different // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple // instruction producing the copy is returned. - static StatusOr CreateCopyWithNewLayout( + StatusOr CreateCopyWithNewLayout( const Shape& shape_with_layout, HloInstruction* instruction); // Creates a copy of the given operand if the operand's layout does not match // the given layout. This copy replaces the use in the given instruction. // Tuple operands will be deep-copied. - static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no); + Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); + + // Registers a copy instruction added by the layout assignment pass. + void RegisterAddedCopy(HloInstruction* copy) { + CHECK_EQ(copy->opcode(), HloOpcode::kCopy); + added_copies_.insert(copy); + } + + // Adds a copy for the operand of an instruction, unless such operand is + // already a copy, and has a single user (which is necessarily the instruction + // itself). + Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); // Map containing the layouts of all computations assigned so // far.
Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map computation_layouts_; + + // Every copy added to the module by the layout assignment pass is registered + // here. + tensorflow::gtl::FlatSet added_copies_; + ChannelLayoutConstraints* channel_layout_constraints_; }; diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 4b1c9bad41de80..bf0448a67674f2 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -29,13 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -651,7 +651,7 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); module = backend() @@ -660,13 +660,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - EXPECT_EQ( - ::tensorflow::Status::OK(), - backend() - .compiler() - ->RunBackend(std::move(module), backend().default_stream_executor(), - /*device_allocator=*/nullptr) - .status()); + EXPECT_EQ(Status::OK(), backend() + .compiler() + ->RunBackend(std::move(module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); } // A GTE inside of a fusion node inherits the layout of its operand (which @@ -692,7 +691,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc deleted file mode 100644 index 68c99256a246ed..00000000000000 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ /dev/null @@ -1,379 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { - // GetTupleElement instructions only access the top-level buffer of their - // operand. - return true; - } else if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { - return user->operand(fused_param->parameter_number()) == operand; - }); - CHECK(it != user->fused_parameters().end()); - // Iterate through all users of all buffer aliases of the buffer in the - // points-to set of fusion parameter at 'index'. - // Return false if any uses are detected at 'index', returns true otherwise. - const LogicalBuffer* buffer = - points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. - return false; - } - } - // Return true: found no uses of 'operand' at 'index' in 'user'. - return true; - } - return false; -} - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : - dataflow.GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; - } - } - } - } - - return true; -} - -namespace { - -// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. 
-// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'instruction' at 'index', and -// 'operand_index' is the operand index at which the alias appears in the -// operand list of 'user'. -std::vector> GetAllUsesOfInstructionAtIndex( - HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector> uses; - const PointsToSet::BufferList& points_to = - points_to_analysis.GetPointsToSet(instruction).element(index); - for (const LogicalBuffer* buffer : points_to) { - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { - uses.emplace_back(alias_user, op_idx); - } - } - } - } - return uses; -} - -// Returns true if there is exactly one use of 'operand' at 'operand_index' -// in 'fusion.fused_instructions', where the singleton use is the fused -// root at operand index 'use_operand_index'. Returns false otherwise. -// -// REQUIRES: 'fusion' opcode is a kFusion instruction. -bool HasUniqueFusedUseOfOperandAt( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* fusion, const int64 use_operand_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - // Check that 'operand' is unique in the operand list of 'fusion'. - if (fusion->OperandIndices(operand).size() > 1) { - return false; - } - // Find fusion parameter associated with 'operand'. - const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { - return fusion->operand(fused_param->parameter_number()) == operand; - }); - if (fused_param_it == fused_params.end()) { - return false; - } - auto* fused_param = *fused_param_it; - // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. - auto fused_param_uses = GetAllUsesOfInstructionAtIndex( - fused_param, operand_index, points_to_analysis); - // Return true iff there is exactly one use of 'operand' at 'index', and - // this singleton use is the fused root (at index in 'use_operand_indices'). - return fused_param_uses.size() == 1 && - fused_param_uses[0].first == fusion->fused_expression_root() && - fused_param_uses[0].second == use_operand_index; -} - -} // namespace - -// User and operand can share buffers iff both instructions emit the same shape -// and layout, and 'user' meets one of the following qualifications: -// -// (1) Is element-wise. Or... -// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. Or... -// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion -// instruction where the only use of 'operand' at 'index' in the set -// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... -// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index -// 0. -// -// (2) and (3) can only be determined if points-to analysis is available. 
-bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - points_to_analysis); - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. - auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, - other_add_operand_index, - points_to_analysis); - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // TODO(b/62548313): Remove when buffer assignment is module scoped and - // does not assign buffers to calls. - // Find called computation parameter associated with 'operand'. - const std::vector operand_indices = user->OperandIndices(operand); - if (operand_indices.size() > 1) { - return false; - } - CHECK_EQ(1, operand_indices.size()); - auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); - // Get all uses of 'operand' at 'index' in called computation. - auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, - points_to_analysis); - - // Return true iff: - // *) There exists exactly one use of 'operand' in called computation. 
- // *) The unique use is by the root instruction of called computation. - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. - auto* callee_root = user->to_apply()->root_instruction(); - return param_uses.size() == 1 && param_uses[0].first == callee_root && - callee_root->IsElementwiseOnOperand(param_uses[0].second); - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - - if (user->opcode() == HloOpcode::kFusion) { - // Get the parameter associated with 'operand'; - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - - const HloValue& value = - dataflow.GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - return false; - } - const HloUse& use = value.uses()[0]; - - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. - auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. 
- std::vector operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // Get all uses of value defined by 'operand' at 'operand_index'. - const auto& uses = - dataflow.GetValueDefinedAt(operand, operand_index).uses(); - // Return true iff: - // *) There exists two uses of 'operand'. - // *) One use is by 'user' (caller). - // *) One use is by root instruction of called computation (callee root). - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. - const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { - return use.instruction == user; - }) != uses.end(); - auto* callee_root = user->to_apply()->root_instruction(); - const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); - return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h deleted file mode 100644 index 28ef991880039d..00000000000000 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// A collection of utilities on the HLO graph. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ - -#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { - -// Returns true if 'user' cannot possibly use the buffer at 'index' in -// 'operand'. Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis); -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow); - -// Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). 
Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. -bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis); -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc deleted file mode 100644 index f8b309488eeb53..00000000000000 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ /dev/null @@ -1,463 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { -namespace { - -class PointsToAnalysisTestBase : public HloTestBase { - protected: - void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); - computation_ = module_->AddEntryComputation(std::move(computation)); - } - - void RunAnalysis() { - CHECK_NOTNULL(module_.get()); - points_to_analysis_ = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); - } - - void BuildModuleAndRunAnalysis(std::unique_ptr computation) { - BuildModule(std::move(computation)); - RunAnalysis(); - } - - std::unique_ptr module_; - HloComputation* computation_ = nullptr; - std::unique_ptr points_to_analysis_; - std::unique_ptr dataflow_analysis_; -}; - -class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; - -TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { - auto builder = HloComputation::Builder(TestName()); - - Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); - builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // GetTupleElement instructions only access the top-level buffer of their - // operand. 
- EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_)); -} - -TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction never uses tuple element 0, but does use element 1. 
- EXPECT_TRUE( - DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_)); -} - -class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto log = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape in_shape = ShapeUtil::MakeShape(F32, {8}); - Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, in_shape, "param0")); - auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *points_to_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *dataflow_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - 
HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction can share with tuple element 1. - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *dataflow_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto update = builder.AddInstruction( - HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); - auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // The DynamicUpdateSlice instruction can share with the data operand, but not - // with update or starts. 
- EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, dot}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused dot add should be able to share buffer with 'add_operand'. - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - auto b_t = builder.AddInstruction( - HloInstruction::CreateTranspose(data_shape, b, {1, 0})); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - - auto nested_fusion = computation_->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); - - auto fusion = computation_->CreateFusionInstruction( - {add, nested_fusion}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused transpose-dot-add should be share buffer with 'add_operand'. 
- EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto reverse = builder.AddInstruction( - HloInstruction::CreateReverse(data_shape, operand, {0, 1})); - - auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, two, reverse}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused operand->reverse->add cannot alias operand buffer 'operand'. - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - - auto make_cond = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Cond"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); - return builder.Build(); - }; - - auto make_body = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Body"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); - return builder.Build(); - }; - - module_ = CreateNewModule(); - HloComputation* cond_computation = - module_->AddEmbeddedComputation(make_cond()); - HloComputation* body_computation = - module_->AddEmbeddedComputation(make_body()); - - auto builder = HloComputation::Builder(TestName()); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto whil = builder.AddInstruction(HloInstruction::CreateWhile( - data_shape, cond_computation, body_computation, data)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - // The While instruction can share with the data operand. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); -} - -// Tests that Call can alias operand buffer if the only use of the operand -// in the called computation is an elementwise instruction. -TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { - Shape shape = ShapeUtil::MakeShape(F32, {8}); - // Build sub-computation with fusion root. 
- auto sub_builder = HloComputation::Builder(TestName() + "_sub"); - auto sub_param = sub_builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "sub_param")); - auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto ones = sub_builder.AddInstruction( - HloInstruction::CreateBroadcast(shape, one, {1})); - auto add = sub_builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - - module_ = CreateNewModule(); - auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); - sub_computation->CreateFusionInstruction({add, ones}, - HloInstruction::FusionKind::kLoop); - - // Build entry-computation with kCall which calls 'sub_computation'. - auto builder = HloComputation::Builder(TestName()); - - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto reverse = - builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); - auto call = builder.AddInstruction( - HloInstruction::CreateCall(shape, {reverse}, sub_computation)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *dataflow_analysis_)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index bc683a1880b010..d909845a3a21fc 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { - return IrArray(global, constant->shape()) + return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, ir_builder_); }; @@ -151,7 +153,7 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { Status FusedIrEmitter::FinishVisit(HloInstruction* root) { fused_root_ = root; - return tensorflow::Status::OK(); + return Status::OK(); } FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 3312a888443233..7323abeb207715 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -333,18 +333,7 @@ llvm::Value* IrArray::EmitArrayElementAddress( } CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); - std::vector actual_index; - bool is_implicit_broadcast = false; - // We perform broadcasting when the operand shape has dimension(s) of size - // 1. In this case we fix the index value for that dimension to zero. This - // effectively broadcasts along this dimension. - for (int64 i = 0; i < index.size(); ++i) { - auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? 
ir_builder->getInt64(0) : index[i]); - is_implicit_broadcast |= dim == 1; - } - - if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { + if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); return ir_builder->CreateInBoundsGEP( @@ -354,6 +343,15 @@ llvm::Value* IrArray::EmitArrayElementAddress( {index.linear()}, llvm_ir::AsStringRef(name)); } + std::vector actual_index; + for (int64 i = 0; i < index.size(); ++i) { + // When dimension i is of size 1, LLVM optimization is able to replace + // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to + // produce better code in some cases. + auto dim = shape_->dimensions(i); + actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + } + // "base_ptr_" has the type of "*" // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element // should be computed by diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 06cfb2a36c56c5..4c3195c29c859c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -97,6 +97,10 @@ class IrArray { llvm::Value*& operator[](size_t i) { return multidim()[i]; } void push_back(llvm::Value* value) { multidim().push_back(value); } + void InsertAt(int64 index, llvm::Value* value) { + CHECK_LE(index, size()); + multidim().insert(multidim().begin() + index, value); + } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 1c00b2aabd182d..64b935bbf1fb90 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -100,6 +100,15 @@ class KernelSupportLibrary { [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); } + void For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + For(name, start, end, ir_builder_->getInt64(step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + } + void For( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 7b227ce294176c..497b48ff227d7d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -36,8 +36,8 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, bool prevent_unrolling, bool prevent_vectorization) - : prefix_(prefix.ToString()), - suffix_(suffix.ToString()), + : prefix_(std::string(prefix)), + suffix_(std::string(suffix)), start_index_(start_index), end_index_(end_index), step_(step), diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 20069ce5a28184..d915f95db13491 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -174,7 +174,7 @@ class ForLoopNest { : ForLoopNest(/*name=*/"", ir_builder) {} 
ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) - : name_(name.ToString()), + : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ec04239b4f9112..ff64da87e9c9ac 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -87,18 +87,10 @@ llvm::Value* EmitCallToIntrinsic( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, llvm::IRBuilder<>* ir_builder) { - std::vector types; - for (auto type : overloaded_types) { - types.push_back(type); - } llvm::Module* module = ModuleFromIRBuilder(ir_builder); - llvm::Function* intrinsic = - llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); - std::vector operands_vec; - for (auto operand : operands) { - operands_vec.push_back(operand); - } - return ir_builder->CreateCall(intrinsic, operands_vec); + llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + module, intrinsic_id, AsArrayRef(overloaded_types)); + return ir_builder->CreateCall(intrinsic, AsArrayRef(operands)); } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, @@ -368,15 +360,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, return llvm::ConstantArray::get(aggregate_type, elements); } +template +llvm::Constant* GetConstantDataArray(const Literal& literal, + llvm::Module* module) { + const T* data = static_cast(literal.untyped_data()); + int64 num_elements = literal.size_bytes() / sizeof(T); + return llvm::ConstantDataArray::get(module->getContext(), + llvm::makeArrayRef(data, num_elements)); +} + } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); - llvm::Constant* value = LiteralToConstant( - literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, module); - return value; + const Shape& shape = literal.shape(); + // TODO(b/29904935): We can get rid of this switch by exposing a + // ConstantDataArray factory method that takes a llvm::Type and a StringRef. + switch (shape.element_type()) { + case U64: + return GetConstantDataArray(literal, module); + case U32: + return GetConstantDataArray(literal, module); + case U8: + return GetConstantDataArray(literal, module); + case S64: + return GetConstantDataArray(literal, module); + case S32: + return GetConstantDataArray(literal, module); + case F64: + return GetConstantDataArray(literal, module); + case F32: + return GetConstantDataArray(literal, module); + case BF16: + case F16: + return GetConstantDataArray(literal, module); + case PRED: + return GetConstantDataArray(literal, module); + // TODO(b/29904935): Also use ConstantDataArray for complex numbers. 
+ case C64: { + int64 dimensions = ShapeUtil::Rank(shape); + std::vector multi_index(dimensions, 0); + return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, + &multi_index, module); + } + default: + LOG(FATAL) << "unsupported type " << shape.element_type(); + } } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 3978acc132f34b..dc2934a34c23f8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -39,14 +39,13 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { // Convert target_element_generator to a BodyEmitter. TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); target_array.EmitWriteArrayElement(array_index, target_element, ir_builder); - return tensorflow::Status::OK(); + return Status::OK(); }), shape_(target_array.GetShape()), ir_builder_(ir_builder) {} @@ -84,7 +83,9 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, // Sanity check: In multi-output fusion, all shapes produced must have the // same dimensions. for (const IrArray& array : target_arrays) { - CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())) + << ": '" << shape_.ShortDebugString() << "' does not match '" + << array.GetShape().ShortDebugString() << "'"; } } @@ -124,7 +125,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { for (const IrArray::Index& array_index : EmitIndexAndSetExitBasicBlock(loop_name)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); @@ -135,7 +136,7 @@ tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { if (exit_bb_ != nullptr) { ir_builder_->SetInsertPoint(exit_bb_); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 9ff497aecd0bc9..b70d28ecd3033e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -38,8 +38,7 @@ using ElementGenerator = // Emits a loop for every element in the given shape. class LoopEmitter { public: - using BodyEmitter = - std::function; + using BodyEmitter = std::function; LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* ir_builder); @@ -72,7 +71,7 @@ class LoopEmitter { tensorflow::StringPiece loop_name); // Emits a complete loop nest for every element in the given shape. - tensorflow::Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = ""); protected: // An IR emitter that generates the loop body. 
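The llvm_util.cc hunk above swaps the element-by-element LiteralToConstant path for llvm::ConstantDataArray on the common primitive types. As a minimal standalone sketch of what the new GetConstantDataArray helper boils down to — assuming only a flat host buffer of a supported element type, and using the templated llvm::ConstantDataArray::get overload; the function name MakeConstantDataArray is illustrative, not an XLA symbol:

```c++
// Sketch: pack a flat host array into a single compact LLVM constant,
// mirroring the per-type switch added to ConvertLiteralToIrConstant above.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LLVMContext.h"

template <typename T>
llvm::Constant* MakeConstantDataArray(llvm::LLVMContext& context,
                                      const T* data, size_t num_elements) {
  // ConstantDataArray stores the raw element bytes in one constant instead
  // of building an llvm::Constant node per element.
  return llvm::ConstantDataArray::get(context,
                                      llvm::makeArrayRef(data, num_elements));
}
```

Complex (C64) literals keep the old recursive LiteralToConstant path, since ConstantDataArray has no overload for them — hence the TODO retained in the switch.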
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index 34899b7400464e..dacc54742c0897 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -49,22 +49,41 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({ir_builder->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), update_shape.dimensions(i)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. + llvm::Value* max_bound = + ir_builder->CreateSub(output_dim_size, update_dim_size); + llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]), + zero, start_index[i]); + + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound, + start_index[i]), + max_bound, start_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { // Calculate output_index, where we'll write the value from update. For // each dimension, // - // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. + // output_index[dim] = start_index[dim] + update_index[dim] // IrArray::Index output_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* dim_size = llvm::ConstantInt::get( - update_index[i]->getType(), output_shape.dimensions(i)); - llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateURem( - ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 3a21eda35757aa..5fc08aab916e37 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -24,15 +24,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module) { +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module) { CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = @@ -47,30 +46,27 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - std::vector element_index = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; + llvm::Value* const element_index[] = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; llvm::Value* on_true_element_address = ir_builder->CreateInBoundsGEP(on_true, element_index); llvm::Value* on_true_element = ir_builder->CreateLoad( - on_true_element_address, - tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + on_true_element_address, "on_true_element_" + llvm::Twine(i)); llvm::Value* on_false_element_address = ir_builder->CreateInBoundsGEP(on_false, element_index); llvm::Value* on_false_element = ir_builder->CreateLoad( - on_false_element_address, - tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + on_false_element_address, "on_false_element_" + llvm::Twine(i)); llvm::Value* output_element_address = ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); ir_builder->CreateStore( - ir_builder->CreateSelect( - pred_cond, on_true_element, on_false_element, - tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element, + "select_output_element_" + llvm::Twine(i)), output_element_address); } } -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index dbf9a140068b60..352d34ebf839c6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -59,13 +59,13 @@ namespace llvm_ir { // of the address from the corresponding element in either // tuple_on_true or tuple_on_false: // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module); +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. 
-void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 0fa4061738612d..1d9c9e0678765a 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -24,14 +24,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -110,6 +108,11 @@ ExecutionOptions CreateExecutionOptions( ->set_xla_dump_optimized_hlo_proto_to( build_options.dump_optimized_hlo_proto_to().value()); } + if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_unoptimized_hlo_proto_to( + build_options.dump_unoptimized_hlo_proto_to().value()); + } if (build_options.dump_per_pass_hlo_proto_to().has_value()) { execution_options.mutable_debug_options() ->set_xla_dump_per_pass_hlo_proto_to( @@ -124,75 +127,17 @@ ExecutionOptions CreateExecutionOptions( LayoutUtil::SetToDefaultLayout( execution_options.mutable_shape_with_output_layout()); } - return execution_options; -} - -} // namespace - -StatusOr> LocalService::CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Validate incoming layouts. 
- if (argument_layouts.size() != program_shape->parameters_size()) { - return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", - program_shape->parameters_size(), argument_layouts.size()); + for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { + execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( + disabled_pass); } - for (int i = 0; i < argument_layouts.size(); ++i) { - const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); - if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { - tensorflow::gtl::optional metadata = - user_computation->ParameterMetadata(i); - auto metadata_string = [&metadata]() -> string { - if (!metadata.has_value()) { - return ""; - } - CHECK(metadata.value() != nullptr); - const OpMetadata& m = *metadata.value(); - if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); - } - return ""; - }; - return InvalidArgument( - "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); - } - } - if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape->result())); - } - - ExecutionOptions execution_options = - CreateExecutionOptions(build_options, program_shape.get()); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, - &execution_options, user_computation)); - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(build_options.device_ordinal())); - - return BuildExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), executor, - build_options.device_allocator()); + return execution_options; } +} // namespace + StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -260,4 +205,15 @@ StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { /*computation_count=*/1); } +StatusOr LocalService::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); + if (replica_number >= buffers.size()) { + return InvalidArgument( + "replica_number %d out of range; must be less than num_replicas = %zu.", + replica_number, buffers.size()); + } + return buffers[replica_number]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 06567cabd6eb28..39d6734c3fc06d 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -41,23 +41,11 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // Builds an Executable with the given argument layouts and options. If - // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. If device_allocator is non-null, then the - // compiler may use it to allocate temp space on the device. The compiler is - // responsible for freeing any memory it allocates this way. 
- StatusOr> CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Builds an Executable with the given XlaComputation, argument layouts and // options. If result_layout is non-null, then the executable is compiled to // produce a result of the given layout. If device_allocator is non-null, // then the compiler may use it to allocate temp space on the device. The // compiler is responsible for freeing any memory it allocates this way. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -70,6 +58,11 @@ class LocalService : public Service { // the "easy" case where a single replica is a single device. StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index 68553bed121917..c742d35a7bcafa 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,9 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include -#include - #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -28,43 +25,20 @@ namespace xla { LogicalBuffer::LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id) - : instruction_(instruction), id_(id), color_(kInvalidColor), index_(index) { - const auto& s = shape(); - is_array_ = ShapeUtil::IsArray(s); - is_tuple_ = ShapeUtil::IsTuple(s); -} + : BufferValue(instruction, index, id), + instruction_(instruction), + index_(index) {} + +LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { + string color_string; + if (has_color()) { + color_string = tensorflow::strings::StrCat(" @", color().value()); + } return tensorflow::strings::StrCat(instruction_->name(), "[", tensorflow::str_util::Join(index_, ","), - "](#", id_, " @", color_.value(), ")"); -} - -std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer) { - out << buffer.ToString(); - return out; -} - -/*static*/ LogicalBufferProto::Location LogicalBuffer::ToLocationProto( - const HloInstruction& instruction, const ShapeIndex& index) { - LogicalBufferProto::Location proto; - proto.set_computation_name(instruction.parent()->name()); - proto.set_instruction_name(instruction.name()); - for (const int64 index_entry : index) { - proto.add_shape_index(index_entry); - } - return proto; -} - -LogicalBufferProto LogicalBuffer::ToProto(const SizeFunction& size_fn) const { - LogicalBufferProto proto; - proto.set_id(id_); - proto.set_size(size_fn(*this)); - LogicalBufferProto::Location proto_location = - ToLocationProto(*instruction_, index_); - proto.mutable_defined_at()->Swap(&proto_location); - proto.set_color(color_.value()); - return proto; + "](#", id(), color_string, ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h 
index 67b205e289e626..f9ba5a554740c9 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -16,11 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ -#include -#include #include -#include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -33,133 +31,30 @@ limitations under the License. namespace xla { -// Class describing a contiguous sequence of elements (ie, C array) which form -// the components of Shaped values in XLA. XLA arrays are trivially a -// single LogicalBuffer. Tuple values are made up of more than one -// LogicalBuffer: a LogicalBuffer for the pointers to elements, and a -// LogicalBuffer for each child element. -// -// Every buffer is defined by a particular instruction and most instructions -// define only a single buffer. Instructions which define a single buffer -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single buffer -// which is a vector of pointers to the buffers containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple buffers only the top-level buffer (the vector of pointers) is -// defined by the Tuple instruction. The buffers containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one buffer. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. The tuple-shaped instructions do -// not assemble a tuple from existing buffers like the Tuple instruction does, -// but rather define the entire tuple. -// -// Some instructions, such as Bitcast, define no buffers. These instructions -// simply forward buffers from their operands. -// -// The LogicalBuffer object describes which HLO instruction defines a buffer and -// where within that instruction's output shape the buffer is defined. The -// location within the output shape is indicated by LogicalBuffer::index() which -// is defined identically to the index used in -// ShapeUtil::GetSubshape(). Examples: -// -// %add = Add(%foo, %bar) -// %tuple_constant = Constant({1, {42, 43}}) -// -// %add defines a single array-shaped buffer LogicalBuffer(%add, {}) which holds -// the array result of the add operation. The nested-tuple-shaped -// %tuple_constant defines 5 buffers described by the following LogicalBuffer -// objects: -// -// LogicalBuffer(%tuple_constant, {}) // "Top-level" buffer: vector of -// // pointers to LogicalBuffers at -// // indices {0} and {1} -// LogicalBuffer(%tuple_constant, {0}) // Holds value "1" -// LogicalBuffer(%tuple_constant, {1}) // Holds nested tuple: vector of -// // pointers to LogicalBuffers at -// // indices {1, 0} and {1, 1} -// LogicalBuffer(%tuple_constant, {1, 0}) // Holds value "42" -// LogicalBuffer(%tuple_constant, {1, 1}) // Holds value "43" -class LogicalBuffer { +// TuplePointsToAnalysis uses this subclass of BufferValue. 
+class LogicalBuffer : public BufferValue { public: - TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); - - // Id is a unique identifier for the LogicalBuffer to facilitate efficient - // collections of LogicalBuffers with stable iteration order. - // LogicalBuffers are typically created and accessed through - // TuplePointsToAnalysis, and points-to analysis assigns each LogicalBuffer a - // unique value. - using Id = int64; - - // Functions which return the size and alignment of a logical buffer in bytes. - using SizeFunction = std::function; - using AlignmentFunction = std::function; - LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id); - - Id id() const { return id_; } + ~LogicalBuffer() override; // Return the instruction that defines the buffer. - HloInstruction* instruction() const { return instruction_; } + HloInstruction* instruction() const override { return instruction_; } // Return the index within the output of the instruction where the buffer is // defined. Index used defined as in ShapeUtil::GetSubshape() - const ShapeIndex& index() const { return index_; } - - // Return the color of the logical buffer. Differently colored buffers can - // not be parts of the same allocation. - Color color() const { - CHECK_NE(color_, kInvalidColor) - << "Should not query the color of a buffer that was never colored"; - return color_; - } - - void set_color(Color color) { - CHECK_NE(color, kInvalidColor) - << "Should not set the color of a buffer to the invalid color"; - color_ = color; - } - - bool has_color() const { return color_ != kInvalidColor; } + const ShapeIndex& index() const override { return index_; } // Return the shape of the buffer. This reference points into the shape field // of the instruction defining the buffer. Therefore, the returned shape will // contain the layout of instruction, if any. - const Shape& shape() const { + const Shape& shape() const override { return ShapeUtil::GetSubshape(instruction_->shape(), index_); } - // Returns true if this buffer is the top-level output buffer of the defining - // HLO instruction. This is equivalent to index == {}. - bool IsTopLevel() const { return index_.empty(); } - - // Whether this buffer contains a tuple. - bool IsTuple() const { return is_tuple_; } - - // Whether this buffer contains an array. - bool IsArray() const { return is_array_; } - - // operator< is required for std::set. - bool operator<(const LogicalBuffer& other) const { return id_ < other.id_; } - - string ToString() const; - LogicalBufferProto ToProto(const SizeFunction& size_fn) const; - - // Returns the LogicalBufferProto::Location that serializes the given - // instruction and index. 
- static LogicalBufferProto::Location ToLocationProto( - const HloInstruction& instruction, const ShapeIndex& index); - - const Color kInvalidColor = Color(-1); + string ToString() const override; private: HloInstruction* instruction_; - Id id_ : 62; - bool is_array_ : 1; - bool is_tuple_ : 1; - Color color_; ShapeIndex index_; // Similar to HLO constructs (HloInstruction, etc), pointers are used for @@ -167,8 +62,6 @@ class LogicalBuffer { TF_DISALLOW_COPY_AND_ASSIGN(LogicalBuffer); }; -std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 6aca6ba38572c5..f410921b4b5337 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -125,6 +125,12 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand. + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { // RecvDone doesn't create a new buffer but rather aliases its input (Recv) // tuple element at {0} to its output. diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index f4c63dd86b4d8a..b5ef3967875a58 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -59,6 +59,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f74bcb0b79355c..3a6a7c25f4b727 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -53,7 +53,7 @@ NameUniquer::NameUniquer(const string& separator) { } string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = GetSanitizedName(prefix.empty() ? "name" : prefix.ToString()); + string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/owning_device_memory.cc b/tensorflow/compiler/xla/service/owning_device_memory.cc new file mode 100644 index 00000000000000..c115bc097f3b1d --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/owning_device_memory.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" + +namespace xla { + +void OwningDeviceMemory::Free() { + CHECK(allocator_ != nullptr) + << "Can't call Free() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + auto status = allocator_->Deallocate(device_ordinal_, mem_); + if (!status.ok()) { + LOG(WARNING) << "Deallocating buffer " << mem_.opaque() << " failed."; + } + + allocator_ = nullptr; + mem_ = se::DeviceMemoryBase(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h new file mode 100644 index 00000000000000..9cf071f0d9d09d --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.h @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Break circular dependency between this file and device_memory_allocator.h. +class DeviceMemoryAllocator; + +// Owning pointer for memory on a device. +// +// OwningDeviceMemory is an owning pointer like std::unique_ptr, but it can +// point to memory that resides on a "device" (e.g. a GPU). When an +// OwningDeviceMemory goes out of scope, it frees the memory it owns. +// +// We say that an instance of OwningDeviceMemory is "active" if it currently +// owns a (possibly empty) slice of memory on the device. Moving, Forget()'ing, +// Free()'ing, and other actions can deactivate an active object. +// +// Note that we can't simply use stream_executor::ScopedDeviceMemory instead of +// OwningDeviceMemory, because ScopedDeviceMemory frees its pointer via a +// StreamExecutor. This class needs to free via a xla::DeviceMemoryAllocator.
+class OwningDeviceMemory { + public: + OwningDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {} + + explicit OwningDeviceMemory(se::DeviceMemoryBase mem, int device_ordinal, + DeviceMemoryAllocator* allocator) + : mem_(mem), device_ordinal_(device_ordinal), allocator_(allocator) { + CHECK(allocator != nullptr) << "allocator cannot be null."; + } + + OwningDeviceMemory(OwningDeviceMemory&& other) + : mem_(other.mem_), + device_ordinal_(other.device_ordinal_), + allocator_(other.allocator_) { + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + } + + OwningDeviceMemory& operator=(OwningDeviceMemory&& other) { + if (allocator_ != nullptr) { + Free(); + } + mem_ = other.mem_; + device_ordinal_ = other.device_ordinal_; + allocator_ = other.allocator_; + + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + return *this; + } + + // Deactivates this instance if it's active. Nop if it's not active. + OwningDeviceMemory& operator=(std::nullptr_t) { + if (allocator_ != nullptr) { + Free(); + } + return *this; + } + + ~OwningDeviceMemory() { + if (allocator_ != nullptr) { + Free(); + } + } + + // The returned allocator is nonnull iff this object is active. + DeviceMemoryAllocator* allocator() const { return allocator_; } + + int device_ordinal() const { return device_ordinal_; } + + // Gets the device memory pointer. + const void* opaque() const { return mem_.opaque(); } + void* opaque() { return mem_.opaque(); } + + uint64 size() const { return mem_.size(); } + + // Determines whether this wraps a null pointer. + // + // !is_null() is sufficient but not necessary to imply `this` is active. + bool is_null() const { return mem_.is_null(); } + + se::DeviceMemoryBase AsDeviceMemoryBase() { + return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false); + } + + // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates + // this object. Precondition: `this` is active. + TF_MUST_USE_RESULT se::DeviceMemoryBase Forget() { + CHECK(allocator_ != nullptr) + << "Can't call Forget() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + allocator_ = nullptr; + se::DeviceMemoryBase mem(mem_); + mem_ = se::DeviceMemoryBase(); + return mem; + } + + // Frees the wrapped DeviceMemoryBase and deactivates this object. + // Precondition: `this` is active. + void Free(); + + private: + se::DeviceMemoryBase mem_; + int device_ordinal_; + DeviceMemoryAllocator* allocator_; // Null if this object is inactive. +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 586f6ef7a9c4f1..2515222cf2db3d 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -204,7 +204,7 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. constexpr LayoutPattern> EqualTo( - const Layout* layout) const { + const ::xla::Layout* layout) const { return LayoutPattern>( LayoutPatternEqualImpl(impl_, layout), matched_layout_); } @@ -702,6 +702,30 @@ class HloInstructionPatternOperandImpl { HloInstructionPattern operand_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// is a fusion node with a particular kind. 
+template +class HloInstructionPatternFusionKindImpl { + public: + explicit constexpr HloInstructionPatternFusionKindImpl( + const Previous& previous, ::xla::HloInstruction::FusionKind kind) + : previous_(previous), kind_(kind) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + private: + Previous previous_; + ::xla::HloInstruction::FusionKind kind_; +}; + // A pattern that matches HloInstructions. template class HloInstructionPattern { @@ -807,6 +831,16 @@ class HloInstructionPattern { matched_inst_); } + // Modifies the pattern to match only if the instruction is a fusion node with + // the given kind. + constexpr HloInstructionPattern> + WithFusionKind(HloInstruction::FusionKind kind) const { + return HloInstructionPattern>( + HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); + } + private: Impl impl_; HloInstructionType** matched_inst_; diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index c88157c312524f..fef3c132b0f346 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -29,7 +29,7 @@ TEST(PatternMatcherTest, AddOp) { ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); const HloInstruction* matched_inst; HloInstruction* matched_operand; @@ -170,5 +170,28 @@ TEST(PatternMatcherTest, TupleShape) { Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); } +TEST(PatternMatcherTest, FusionKind) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + fused_computation { + ROOT fp0 = f32[] parameter(0) + } + + ENTRY while.v11 { + p0 = f32[] parameter(0) + ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop))); + EXPECT_FALSE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kInput))); + EXPECT_FALSE(Match(root->operand(0), match::Op().WithFusionKind( + HloInstruction::FusionKind::kLoop))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index e2c07e38271df8..688cceff0cd10d 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -75,7 +75,7 @@ StatusOr ReducePrecisionInsertion::insert_after( return false; } - // Check that we haven't already inserted an equivalant 
reduce-precision + // Check that we haven't already inserted an equivalent reduce-precision // operation after this instruction. (The zero-user case occurs when this is // the root instruction.) if (instruction->user_count() > 0) { diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 086bd61dd04aa1..d01c35b9923131 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" @@ -62,12 +61,11 @@ namespace xla { namespace { -// Records the arguments used to invoke a computation in a SessionModule -// proto. -tensorflow::Status RecordArguments( +// Records the arguments used to invoke a computation in an HloSnapshot proto. +Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, se::StreamExecutor* executor, TransferManager* transfer_manager, - SessionModule* module) { + HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( @@ -75,20 +73,18 @@ tensorflow::Status RecordArguments( transfer_manager->TransferLiteralFromDevice(executor, *argument)); *module->add_arguments() = literal->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } -// Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - SessionModule* module) { +// Records the result of a computation in a HloSnapshot proto. 
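// Usage note: when an executable is dumping a snapshot (i.e. the
// xla_dump_executions_to debug option was set at build time), ExecuteGraph
// below calls RecordArguments() before running the computation and this
// function afterwards, then persists the populated HloSnapshot via
// Executable::DumpHloSnapshot().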
+Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, transfer_manager->TransferLiteralFromDevice(executor, result)); *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace @@ -171,35 +167,20 @@ Service::Service(const ServiceOptions& options, } } -tensorflow::Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { - if (arg->name().empty()) { - return InvalidArgument("computation request needs a name"); - } - - *result->mutable_computation() = - computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p, name %s", - result->computation().ShortDebugString().c_str(), this, - arg->name().c_str()); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { +Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { +Status Service::Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) { return allocation_tracker_.Unregister(arg->data()); } // Deconstructs a previously-allocated global handle. -tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { +Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { TF_ASSIGN_OR_RETURN( std::vector elements, allocation_tracker_.DeconstructTuple(arg->tuple_handle())); @@ -207,11 +188,11 @@ tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, for (auto& element : elements) { *result->add_element_handles() = element; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const { +Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const { if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " @@ -265,11 +246,12 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation) { + const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); - auto* computation_layout = config->mutable_entry_computation_layout(); - + ComputationLayout* host_computation_layout = + config->mutable_host_entry_computation_layout(); + ComputationLayout* device_computation_layout = + config->mutable_device_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { return InvalidArgument("computation takes %d parameters, but %zu given", program_shape.parameters_size(), @@ -280,23 +262,16 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. 
if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - if (user_computation == nullptr) { - return InvalidArgument( - "Argument does not match shape of computation parameter %d: want " - "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); - } - return InvalidParameterArgument( - *user_computation->ParameterMetadata(i).value(), - "Argument does not match shape of computation parameter %d: want %s, " - "got %s", + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } - TF_RETURN_IF_ERROR( - computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( - *argument_shapes[i])); + TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i) + ->CopyLayoutFromShape(*argument_shapes[i])); + TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i) + ->CopyLayoutFromShape(*argument_shapes[i])); } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { @@ -305,10 +280,20 @@ StatusOr> Service::CreateModuleConfig( TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( + host_computation_layout->mutable_result_layout()->CopyLayoutFromShape( + shape_with_output_layout)); + TF_RETURN_IF_ERROR( + device_computation_layout->mutable_result_layout()->CopyLayoutFromShape( shape_with_output_layout)); } else { - computation_layout->mutable_result_layout()->Clear(); + // If the result layout is not set, then choose the default. + // TODO(b/29118294): Allow the compiler to choose a better layout in this + // case. + // TODO(b/78356948): We are forcing the default layout here. We should fix + // clients which expect a default layout, to be explicit about it, by + // passing the proper ExecutionOptions with shape_with_output_layout set. + host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); + device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); } config->set_replica_count(options_.number_of_replicas()); @@ -330,85 +315,43 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation) { + const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options, - user_computation); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } StatusOr>> Service::BuildExecutables( - std::vector versioned_handles, + const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator) { VLOG(1) << Printf("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. 
- std::vector> session_modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { + std::vector> hlo_snapshots; + for (int64 i = 0; i < module_protos.size(); ++i) { const string& directory_path = module_configs[i]->debug_options().xla_dump_computations_to(); - const string& other_directory_path = + const string& execution_directory_path = module_configs[i]->debug_options().xla_dump_executions_to(); - if (directory_path.empty() && other_directory_path.empty()) { + if (directory_path.empty() && execution_directory_path.empty()) { continue; } - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handles[i].handle)); + auto hlo_snapshot = MakeUnique(); + *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handles[i].handle.handle(), - session_module->entry().name().c_str(), - versioned_handles[i].version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - session_modules.push_back(std::move(session_module)); - } - } - - VLOG(1) << "Computation handles:"; - for (const VersionedComputationHandle& versioned_handle : versioned_handles) { - VLOG(1) << versioned_handle; - } - - CHECK_EQ(versioned_handles.size(), module_configs.size()); - std::vector> modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const VersionedComputationHandle& versioned_handle = versioned_handles[i]; - const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, - computation_tracker_.BuildHloModule( - versioned_handle, config, - /*include_unreachable_instructions=*/true)); - modules.push_back(std::move(module)); - } - - TF_ASSIGN_OR_RETURN( - std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors), - device_allocator)); - - for (size_t i = 0; i < versioned_handles.size(); ++i) { - if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { - executables[i]->set_session_module(std::move(session_modules[i])); + string filename = + Printf("computation_%lld__%s", module_protos[i]->id(), + module_protos[i]->entry_computation_name().c_str()); + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); + hlo_snapshots.push_back(std::move(hlo_snapshot)); } } - return std::move(executables); -} - -StatusOr>> Service::BuildExecutables( - const std::vector& module_protos, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); - VLOG(1) << "Computations:"; for (const HloModuleProto* proto : module_protos) { VLOG(1) << proto->name(); @@ -429,97 +372,29 @@ StatusOr>> Service::BuildExecutables( backend->compiler()->Compile(std::move(modules), std::move(executors), device_allocator)); - return std::move(executables); -} - -StatusOr> Service::BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, - versioned_handle.ToString().c_str()); - - // Dump computation proto state if flag is set. 
- std::unique_ptr session_module; - const string& directory_path = - module_config->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_config->debug_options().xla_dump_executions_to(); - if (!directory_path.empty() || !other_directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handle.handle.handle(), - session_module->entry().name().c_str(), - versioned_handle.version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); + for (size_t i = 0; i < module_protos.size(); ++i) { + if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { + executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i])); } } - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - true)); - - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); - - TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor, - device_allocator)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - backend->compiler()->RunBackend( - std::move(module), executor, device_allocator)); - - if (!other_directory_path.empty()) { - executable->set_session_module(std::move(session_module)); - } - - return std::move(executable); + return std::move(executables); } -StatusOr> Service::BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator) { - std::shared_ptr executable = - compilation_cache_.LookUp(versioned_handle, *module_config); - - if (executable != nullptr) { - // Executable found in the computation cache. - if (profile != nullptr) { - profile->set_compilation_cache_hit(true); - } - return executable; - } - - uint64 start_micros = - // Avoid reading the clock if we don't want timing info - (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0; - - // Take a copy of the module config, as compilation introduces layouts where - // layouts were optional before. - HloModuleConfig original_module_config = *module_config; - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), backend, - executor, device_allocator)); - - if (profile != nullptr) { - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - uint64 milliseconds = (end_micros - start_micros) / 1000; - profile->set_compilation_cache_hit(false); - profile->set_compile_time_ms(milliseconds); - } - - // Insert executable into the cache. 
- return compilation_cache_.Insert(std::move(executable_unique_ptr), - original_module_config); +Status Service::ValidateEntryComputationLayout(HloModule* module) { + const ComputationLayout& on_device = + module->device_entry_computation_layout(); + for (int64 i = 0; i < on_device.parameter_count(); ++i) { + TF_RET_CHECK(ShapeUtil::Equal( + on_device.parameter_shape(i), + execute_backend_->transfer_manager()->HostShapeToDeviceShape( + module->host_entry_computation_layout().parameter_shape(i)))); + } + TF_RET_CHECK(ShapeUtil::Equal( + module->device_entry_computation_layout().result_shape(), + execute_backend_->transfer_manager()->HostShapeToDeviceShape( + module->host_entry_computation_layout().result_shape()))); + return Status::OK(); } StatusOr> @@ -542,9 +417,16 @@ Service::ExecuteParallelAndRegisterResult( // profiled. std::map index_to_profiled_streams; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), executables.size())); + // Build DeviceAssignment for all cores based on the provided device handles. + DeviceAssignment device_assignment(options_.number_of_replicas(), + executables.size()); + for (int64 i = 0; i < executables.size(); i++) { + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, i) = replicas[replica]->device_ordinal(); + } + } for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. @@ -574,7 +456,6 @@ Service::ExecuteParallelAndRegisterResult( ExecutableRunOptions options; options.set_stream(streams.back().get()); options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); @@ -688,12 +569,12 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_stream(stream.get()); options.set_device_ordinal(stream->parent()->device_ordinal()); options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - run_options.emplace_back(options, backend->StreamBorrower(), - backend->inter_op_thread_pool()); + run_options.emplace_back( + options, backend->StreamBorrower(), + /*xla_intra_op_thread_pool=*/backend->eigen_intra_op_thread_pool()); } if (options_.number_of_replicas() == 1) { @@ -718,13 +599,6 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - return computation->SetReturnValue(arg->operand()); -} - StatusOr> Service::GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, int64 request_index) const { @@ -766,119 +640,8 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { - VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - - std::vector>> all_arguments; - std::vector> 
all_executors; - std::vector versioned_handles; - std::vector> module_configs; - std::vector computation_names; - std::vector device_handles; - - int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteRequest& r) -> int { - return a + r.execution_options().device_handles_size(); - }); - if (num_requested_devices * options_.number_of_replicas() > - execute_backend_->device_count()) { - return FailedPrecondition( - "there are not enough stream executors to execute %d computations", - num_requested_devices); - } - - for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor for the i'th computation. This stream executor - // is one of the executors to run the replicated computation. - const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - - // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); - - // Resolve the UserComputation object associated with the requested - // computation and compute the program shape. - const ExecuteRequest& request = arg->requests(i); - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(request.computation())); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Get the replicated arguments. - TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); - - // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. Here, we care only about the - // shapes of the arguments, so, it is sufficient to use the arguments of - // replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - request.execution_options(), user_computation)); - VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); - - // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(replicated_arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); - versioned_handles.push_back(versioned_handle); - module_configs.push_back(std::move(module_config)); - computation_names.insert(computation_names.end(), executors.size(), - user_computation->name()); - all_executors.push_back(executors); - device_handles.insert(device_handles.end(), - execution_options.device_handles().begin(), - execution_options.device_handles().end()); - } - - // Build the user computations into HloModules and compile to generate the - // executables. - // - // TODO(jlebar): There's currently no way to pass a device allocator to - // ExecuteParallel, so we have to pass a null device_allocator below. 
- TF_ASSIGN_OR_RETURN( - std::vector> executables, - BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors, - /*device_allocator=*/nullptr)); - std::vector executable_ptrs; - executable_ptrs.reserve(executables.size()); - for (const auto& executable : executables) { - executable_ptrs.push_back(executable.get()); - } - - // Execute the generated executables in parallel and return the device - // handles for each computation's output. - ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); - for (const GlobalDataHandle& output : outputs) { - ExecuteResponse response; - *response.mutable_output() = output; - *response.mutable_profile() = profile; - *result->add_responses() = response; - } - - VLOG(1) << "successfully completed 'execute-parallel' request"; - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { +Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; std::vector>> all_arguments; @@ -926,11 +689,10 @@ tensorflow::Status Service::ExecuteGraphParallel( std::unique_ptr module_config, CreateModuleConfig(request.computation().program_shape(), replicated_arguments.front(), - request.execution_options(), - /*user_computation=*/nullptr)); + request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); + << module_config->host_entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. 
all_arguments.push_back(replicated_arguments); @@ -975,11 +737,11 @@ tensorflow::Status Service::ExecuteGraphParallel( } VLOG(1) << "successfully completed 'execute-graph-parallel' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) { +Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { const int64 available_device_count = execute_backend_->device_count(); const int64 replica_count = options_.number_of_replicas(); if (replica_count <= 0) { @@ -999,20 +761,11 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, *result->add_device_handles() = device_handle; } - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - return PickParallelResponse(parallel_result, result); + return Status::OK(); } -tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; @@ -1020,7 +773,7 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, return PickParallelResponse(parallel_result, result); } -tensorflow::Status Service::PickParallelResponse( +Status Service::PickParallelResponse( const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { // The "result device" selection is a bit hacky, but better than assuming it // is device 0. We have b/76035356 for restructuring the client API to clean @@ -1043,81 +796,6 @@ tensorflow::Status Service::PickParallelResponse( return Status::OK(); } -tensorflow::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { - VLOG(1) << "running execute request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - // If we received multiple device handles, we must partition the module. - if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - // Since we care only about the shapes of the arguments, it is sufficient to - // use the arguments of replica 0. 
- TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "Execute created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), - execute_backend_->default_stream_executor(), - result->mutable_profile())); - - if (executable->dumping()) { - executable->session_module()->set_execution_platform( - execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - } - - TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + user_computation->name(), result->mutable_profile())); - - if (executable->dumping()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - TF_RETURN_IF_ERROR(executable->DumpSessionModule()); - } - - VLOG(1) << "successfully completed 'execute' request"; - return tensorflow::Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -1126,6 +804,22 @@ StatusOr> Service::BuildExecutable( "BuildExecutable on service %p with serialized module proto: %s", this, module_proto.name().c_str()); + // Dump computation proto state if flag is set. + auto hlo_snapshot = MakeUnique(); + const string& directory_path = + module_config->debug_options().xla_dump_computations_to(); + const string& execution_directory_path = + module_config->debug_options().xla_dump_executions_to(); + if (!directory_path.empty() || !execution_directory_path.empty()) { + *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; + if (!directory_path.empty()) { + string filename = Printf("computation_%lld__%s", module_proto.id(), + module_proto.entry_computation_name().c_str()); + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); + } + } + TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(module_proto, *module_config)); @@ -1134,6 +828,8 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, device_allocator)); + // Check that on-host and on-device shapes are consistent. 
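  // (Specifically, ValidateEntryComputationLayout, defined earlier in this
  // file, requires each device-side entry parameter shape and the result shape
  // to equal TransferManager::HostShapeToDeviceShape() applied to the
  // corresponding host-side shape.)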
+ TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, backend->compiler()->RunBackend( @@ -1142,8 +838,8 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { VLOG(1) << "running execute-graph request"; if (!arg->has_computation()) { @@ -1176,99 +872,37 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + if (executable->dumping_snapshot()) { + executable->hlo_snapshot()->set_execution_platform( + execute_backend_->platform()->Name()); + TF_RETURN_IF_ERROR(RecordArguments( + replicated_arguments.front(), + execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->hlo_snapshot())); + } + TF_ASSIGN_OR_RETURN( *result->mutable_output(), ExecuteAndRegisterResult( executable.get(), replicated_arguments, execute_backend_.get(), "result of " + arg->computation().name(), result->mutable_profile())); - VLOG(1) << "successfully completed 'execute-graph' request"; - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); - - ExecutionProfile profile; - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable( - versioned_handle, std::move(module_config), execute_backend_.get(), - execute_backend_->default_stream_executor(), &profile)); - - // Set up streams. 
- std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - execute_backend_->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - std::vector result_buffers; - for (size_t i = 0; i < streams.size(); ++i) { - const auto& stream = streams[i]; - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(execute_backend_->memory_allocator()); - options.set_inter_op_thread_pool(execute_backend_->inter_op_thread_pool()); - options.set_intra_op_thread_pool( - execute_backend_->eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_options( - options, execute_backend_->StreamBorrower()); - - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer, - executable->ExecuteAsyncOnStream( - &service_options, replicated_arguments[i])); - - result_buffers.emplace_back(std::move(this_result_buffer)); + if (executable->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN( + const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(result->output(), 0)); + TF_RETURN_IF_ERROR(RecordResult( + *result_buffer, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } - TF_ASSIGN_OR_RETURN( - GlobalDataHandle output, - allocation_tracker_.RegisterReplicatedBuffers( - std::move(result_buffers), "result of " + user_computation->name())); - - *result->mutable_execution() = execution_tracker_.Register( - execute_backend_.get(), std::move(streams), profile, output); - streams.clear(); - - VLOG(1) << "successfully completed 'execute-async' request"; - return tensorflow::Status::OK(); + VLOG(1) << "successfully completed 'execute-graph' request"; + return Status::OK(); } -tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) { +Status Service::WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, execution_tracker_.Resolve(arg->execution())); @@ -1279,11 +913,11 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution())); VLOG(1) << "successfully completed 'wait-for-execution' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result) { +Status Service::TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); @@ -1313,7 +947,7 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, *result->mutable_literal() = result_literal->Relayout(*return_shape)->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } namespace { @@ -1331,8 +965,8 @@ std::unique_ptr CloneShapedBufferOnDevice( } // namespace -tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { +Status Service::TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(arg->literal())); const Shape& shape = literal->shape(); @@ -1365,11 +999,11 @@ tensorflow::Status Service::TransferToServer(const 
TransferToServerRequest* arg, StrCat("TransferToServer literal of shape ", ShapeUtil::HumanString(shape)))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1398,9 +1032,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor, *literal); } -tensorflow::Status Service::TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1426,127 +1059,16 @@ tensorflow::Status Service::TransferFromOutfeed( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, arg->shape_with_layout(), &literal)); *result->mutable_literal() = literal.ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { +Status Service::ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { return execute_backend_->ResetDevices(); } -tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->num_parameters())); - - result->set_is_constant(is_constant); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->parameters_size())); - if (!is_constant) { - StatusOr op_request_status = - user_computation->LookUpRequestForErrorReporting(arg->operand()); - string op_request_string = ""; - if (op_request_status.ok()) { - op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); - } - return InvalidArgument( - "Operand to ComputeConstant depends on a parameter.\n\n" - " op requested for constant evaluation: %s\n\n" - "This is an internal error that typically happens when the XLA user " - "(e.g. TensorFlow) is attempting to determine a value that must be a " - "compile-time constant (e.g. 
an array dimension) but it is not capable " - "of being evaluated at XLA compile time.\n\n" - "Please file a usability bug with the framework being used (e.g. " - "TensorFlow).", - op_request_string.c_str()); - } - - // We can't use ComputeProgramShape because it checks that all parameter - // instructions are present and contiguous. Instead construct ProgramShape - // directly. - ProgramShape program_shape; - TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(), - user_computation->GetShape(arg->operand())); - - TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - - ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions(); - execution_options.mutable_debug_options()->set_xla_enable_fast_math(false); - execution_options.mutable_debug_options() - ->set_xla_eliminate_hlo_implicit_broadcast(true); - *execution_options.mutable_shape_with_output_layout() = - program_shape.result(); - - Shape shape_with_output_layout(program_shape.result()); - if (arg->has_output_layout()) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), execution_options.shape_with_output_layout())); - *execution_options.mutable_shape_with_output_layout()->mutable_layout() = - arg->output_layout(); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options, - user_computation)); - - // Exclude dead parameter instructions for the purpose of computing constants. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - false)); - - std::vector> parameters(arg->parameters_size()); - for (int64 i = 0; i < arg->parameters_size(); ++i) { - TF_ASSIGN_OR_RETURN(parameters[i], - Literal::CreateFromProto(arg->parameters(i))); - } - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - auto result_literal, - evaluator.Evaluate>(*module, parameters)); - - // Since the shape_with_output_layout option in ExecutionOption is - // non-effective to the Evaluator results, explicit relayout here. - // - // TODO(b/77824332): Make HloEvaluator take care of the re-layout. 
- if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); - } - *result->mutable_literal() = result_literal->ToProto(); - - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { +Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) { if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } @@ -1584,73 +1106,17 @@ tensorflow::Status Service::ComputeConstantGraph( } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); *result->mutable_shape() = buffer->on_host_shape(); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( - versioned_handle.version)); - *result->mutable_program_shape() = *program_shape; - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_shape(), - computation->GetShape(arg->operand())); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - HloModuleConfig config; - config.set_debug_options(arg->debug_options()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, config)); - - hlo_graph_dumper::MaybeDumpHloModule(*module, - "computation statistics subject"); - - // Run HLO analysis to get the computation statistics. 
- HloCostAnalysis analysis( - execute_backend_->compiler()->ShapeSizeBytesFunction()); - - TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); - - ComputationStats stats; - stats.set_flop_count(analysis.flop_count()); - stats.set_transcendental_count(analysis.transcendental_count()); - *result->mutable_stats() = stats; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationGraphStats( +Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { return InvalidArgument("Computations may not be empty."); @@ -1677,264 +1143,7 @@ tensorflow::Status Service::GetComputationGraphStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); -} - -template -tensorflow::Status Service::AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - StatusOr handle_status; - - switch (arg->op_case()) { - case OpRequest::kBatchNormTrainingRequest: - handle_status = computation->AddBatchNormTrainingInstruction( - arg->batch_norm_training_request()); - break; - case OpRequest::kBatchNormInferenceRequest: - handle_status = computation->AddBatchNormInferenceInstruction( - arg->batch_norm_inference_request()); - break; - case OpRequest::kBatchNormGradRequest: - handle_status = computation->AddBatchNormGradInstruction( - arg->batch_norm_grad_request()); - break; - case OpRequest::kBinaryOpRequest: - handle_status = - computation->AddBinaryInstruction(arg->binary_op_request()); - break; - case OpRequest::kBroadcastRequest: - handle_status = - computation->AddBroadcastInstruction(arg->broadcast_request()); - break; - case OpRequest::kCallRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->call_request().to_apply())); - handle_status = - computation->AddCallInstruction(arg->call_request(), *to_apply); - break; - } - case OpRequest::kConcatenateRequest: - handle_status = - computation->AddConcatenateInstruction(arg->concatenate_request()); - break; - case OpRequest::kConditionalRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * true_computation, - computation_tracker_.Resolve( - arg->conditional_request().true_computation())); - TF_ASSIGN_OR_RETURN(UserComputation * false_computation, - computation_tracker_.Resolve( - arg->conditional_request().false_computation())); - handle_status = computation->AddConditionalInstruction( - arg->conditional_request(), *true_computation, *false_computation); - break; - } - case OpRequest::kConstantRequest: - handle_status = - computation->AddConstantInstruction(arg->constant_request()); - break; - case OpRequest::kConvertRequest: - handle_status = - computation->AddConvertInstruction(arg->convert_request()); - break; - case OpRequest::kBitcastConvertRequest: - handle_status = computation->AddBitcastConvertInstruction( - arg->bitcast_convert_request()); - break; - case OpRequest::kConvolveRequest: - handle_status = - 
computation->AddConvolveInstruction(arg->convolve_request()); - break; - case OpRequest::kCrossReplicaSumRequest: - handle_status = computation->AddCrossReplicaSumInstruction( - arg->cross_replica_sum_request()); - break; - case OpRequest::kCustomCallRequest: - handle_status = - computation->AddCustomCallInstruction(arg->custom_call_request()); - break; - case OpRequest::kDotRequest: - handle_status = computation->AddDotInstruction(arg->dot_request()); - break; - case OpRequest::kDynamicSliceRequest: - handle_status = - computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); - break; - case OpRequest::kDynamicUpdateSliceRequest: - handle_status = computation->AddDynamicUpdateSliceInstruction( - arg->dynamic_update_slice_request()); - break; - case OpRequest::kFftRequest: - handle_status = computation->AddFftInstruction(arg->fft_request()); - break; - case OpRequest::kGatherRequest: - handle_status = computation->AddGatherInstruction(arg->gather_request()); - break; - case OpRequest::kGetTupleElementRequest: - handle_status = computation->AddGetTupleElementInstruction( - arg->get_tuple_element_request()); - break; - case OpRequest::kInfeedRequest: - handle_status = computation->AddInfeedInstruction(arg->infeed_request()); - break; - case OpRequest::kOutfeedRequest: - handle_status = - computation->AddOutfeedInstruction(arg->outfeed_request()); - break; - case OpRequest::kHostComputeRequest: - handle_status = - computation->AddHostComputeInstruction(arg->host_compute_request()); - break; - case OpRequest::kMapRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->map_request().to_apply())); - handle_status = - computation->AddMapInstruction(arg->map_request(), *to_apply); - break; - } - case OpRequest::kPadRequest: - handle_status = computation->AddPadInstruction(arg->pad_request()); - break; - case OpRequest::kParameterRequest: - handle_status = - computation->AddParameterInstruction(arg->parameter_request()); - break; - case OpRequest::kReduceRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->reduce_request().to_apply())); - handle_status = - computation->AddReduceInstruction(arg->reduce_request(), *to_apply); - break; - } - case OpRequest::kReducePrecisionRequest: { - handle_status = computation->AddReducePrecisionInstruction( - arg->reduce_precision_request()); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * to_apply, - computation_tracker_.Resolve( - arg->reduce_window_request().to_apply())); - handle_status = computation->AddReduceWindowInstruction( - arg->reduce_window_request(), *to_apply); - break; - } - case OpRequest::kReshapeRequest: - handle_status = - computation->AddReshapeInstruction(arg->reshape_request()); - break; - case OpRequest::kReverseRequest: - handle_status = - computation->AddReverseInstruction(arg->reverse_request()); - break; - case OpRequest::kRngRequest: - handle_status = computation->AddRngInstruction(arg->rng_request()); - break; - case OpRequest::kSelectAndScatterRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * select, - computation_tracker_.Resolve( - arg->select_and_scatter_request().select())); - TF_ASSIGN_OR_RETURN(UserComputation * scatter, - computation_tracker_.Resolve( - arg->select_and_scatter_request().scatter())); - handle_status = computation->AddSelectAndScatterInstruction( - arg->select_and_scatter_request(), *select, *scatter); - break; - } - case OpRequest::kSliceRequest: - handle_status 
= computation->AddSliceInstruction(arg->slice_request()); - break; - case OpRequest::kTernaryOpRequest: - handle_status = - computation->AddTernaryInstruction(arg->ternary_op_request()); - break; - case OpRequest::kTraceRequest: - return computation->AddTraceInstruction(arg->trace_request()); - case OpRequest::kTransposeRequest: - handle_status = - computation->AddTransposeInstruction(arg->transpose_request()); - break; - case OpRequest::kUnaryOpRequest: - handle_status = computation->AddUnaryInstruction(arg->unary_op_request()); - break; - case OpRequest::kVariadicOpRequest: - handle_status = - computation->AddVariadicInstruction(arg->variadic_op_request()); - break; - case OpRequest::kWhileRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * condition, - computation_tracker_.Resolve(arg->while_request().condition())); - TF_ASSIGN_OR_RETURN( - UserComputation * body, - computation_tracker_.Resolve(arg->while_request().body())); - handle_status = computation->AddWhileInstruction(arg->while_request(), - *condition, *body); - break; - } - case OpRequest::kSendRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - // Send does not return a value, but we need a handle to be able to - // set OpMetadata and OpSharding (device assignment). - handle_status = computation->AddSendInstruction(arg->send_request()); - break; - } - case OpRequest::kRecvRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterRecv(arg->recv_request().channel_handle())); - handle_status = computation->AddRecvInstruction(arg->recv_request()); - break; - } - case OpRequest::OP_NOT_SET: - return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); - default: - return InvalidArgument("Unsupported operation in XLA service"); - } - TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); - - // We set the debug metadata here, because we slice off part of the OpRequest - // proto in the above switch statement. - TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status); - TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata())); - if (arg->has_sharding()) { - TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.SnapshotComputation(arg->computation())); - - result->set_allocated_module(module.release()); - - return tensorflow::Status::OK(); -} - -tensorflow::Status Service::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_computation(), - computation_tracker_.LoadSessionModule(arg->module())); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceHandle Service::SingleComputationDeviceHandle() const { diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 476bd0597de735..d64b2b4d0afa15 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -27,15 +27,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" #include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -83,57 +80,29 @@ class Service : public ServiceInterface { static StatusOr> NewService( const ServiceOptions& options); - // Creates a new computation with the given name. - // A unique ComputationHandle is returned. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; - - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - // Executes a computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - - // Executes one or more computations in parallel with the provided global data - // passed as immutable arguments. Returns global data output for each - // computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
- tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; // Requests one or more device handles from the target. // @@ -143,49 +112,33 @@ class Service : public ServiceInterface { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; - - // Asynchronously executes a computation with provided arguments. Invokes - // the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - // - // (Note: The corresponding function in xla::Client was removed as part of - // b/64116060, in an attempt to simplify our API. We're keeping this around - // for now in case we want to expose this to clients in a different way.) - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the // first call. - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; // Requests that global data be transferred to the client in literal form. - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; // Transfers data from a literal provided by the client, into device memory. - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; // Transfers data from the Outfeed othe device to the literal provided by the // client. - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -196,77 +149,25 @@ class Service : public ServiceInterface { // ResetDevice should be called before an Execution that expect the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. 
- tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; - - // Tests if an expression is a compile-time constant. - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; - // Computes the value of a constant expression. - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Returns the shape (with layout) of an array associated with a given data // handle. - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; - - // Returns the program shape of the computation associated with the given - // handle. - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ///// - // Computation-oriented methods. - - // Enqueues an Op on the computation. - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; - - // Retrieves the inferred shape for a value within a computation. - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; // Retrieves the statistics of a computation. - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - // Retrieves the statistics of a computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* arg, - ComputationStatsResponse* result) override; - - // Snapshots the current state of a computation handle into a serializable - // protocol buffer form, so it can be loaded via - // LoadComputationSnapshot. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - // Loads a computation from a serialized protocol buffer created via - // SnapshotComputation. - tensorflow::Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) override; // Creates a unique channel handle that can be used for Send/Recv // instructions. - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; - - // Returns the ComputationTracker of the current service instance. - // Only used in unit tests to access user computations from client. - const ComputationTracker& computation_tracker() { - return computation_tracker_; - } + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Returns the backend used to execute computations. 
const Backend& backend() const { return *execute_backend_; } @@ -278,8 +179,7 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, @@ -295,6 +195,9 @@ class Service : public ServiceInterface { const ExecutionOptions& execution_options, tensorflow::gtl::ArraySlice arguments); + // Assert that host- and device-shapes are in a consistent state. + Status ValidateEntryComputationLayout(HloModule* module); + protected: friend class LocalExecutable; @@ -317,23 +220,13 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. // // If device_allocator is not null, the compiler may use it to allocate temp // buffers, which the compiler is responsible for freeing. The allocator // given here need not match the allocator used when running the executable. - StatusOr> BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, - DeviceMemoryAllocator* device_allocator = nullptr); - - // Builds an Executable for the given HLO module proto. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -342,26 +235,12 @@ class Service : public ServiceInterface { // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. - StatusOr>> BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator); StatusOr>> BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); - // Similar to BuildExecutable, but look in the compilation cache for the - // executable first. If the executable is not in the cache, it is built and - // inserted into the cache. - StatusOr> BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator = nullptr); - // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is // returned. If the parameter "profile" is not null, it points to an @@ -384,26 +263,16 @@ class Service : public ServiceInterface { tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile); - // Convenience function for adding a function to a user computation. - template - tensorflow::Status AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder); - // Executes a single computation which has more than one target device. 
// The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. - tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result); - tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result); + Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - tensorflow::Status ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const; + Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that @@ -419,9 +288,6 @@ class Service : public ServiceInterface { ServiceOptions options_; - // Tracks computations built via the API. - ComputationTracker computation_tracker_; - // Tracks channels created via the API. ChannelTracker channel_tracker_; diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto deleted file mode 100644 index bb8d1cd2a106ea..00000000000000 --- a/tensorflow/compiler/xla/service/session.proto +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This proto file defines messages which store the state of XLA -// computations within the XLA service. A computation is stored as a record -// of the operation requests used to build it. -syntax = "proto3"; - -import "tensorflow/compiler/xla/xla_data.proto"; - -package xla; - -// Describes a single operation request. -message OperationRequest { - ComputationDataHandle output_handle = 1; - Shape output_shape = 2; - - // For operations which call embedded computations such as "Map", these are - // the version(s) that the embedded computation should be called at. A version - // value of a computation is the ComputationDataHandle of the root of the - // computation at the point in time. - // - // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single - // embedded computation so this field will have a single value for those - // operations. - // - // "While" operation takes two; index 0 is the "condition" version and index 1 - // is the "body" version. - repeated int64 embedded_computation_versions = 3; - - // The actual request, which in itself is a tagged union of all possible - // operation request types. - OpRequest request = 4; -} - -// Describes a sequence of operation requests which define an XLA -// computation. -message SessionComputation { - string name = 1; - - // The ComputationHandle used to refer to this computation in the XLA - // service. 
- ComputationHandle computation_handle = 2; - - // Map from ComputationDataHandle value to operation request. The highest - // ComputationDataHandle value corresponds to the root of the computation. - map requests = 3; -} - -// Describes a group of SessionComputations with an "entry point" computation -// that may refer to the other non-entry (AKA embedded) computations. -// -// This message is used to serialize a computation that has been built via the -// XLA service API, along with its dependencies, for purposes such as -// analysis/replay/file-storage. -message SessionModule { - // The entry computation, which was requested for serialization. This may have - // referred to embedded computations, which are reflected below. - SessionComputation entry = 1; - - // Embedded computations that are transitively referred to by the entry - // computation. - repeated SessionComputation embedded_computations = 2; - - // The arguments passed to the computation. - repeated LiteralProto arguments = 3; - - // The result of the computation. - LiteralProto result = 4; - - // The name of the platform used to run the computation. - string execution_platform = 5; -} diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 48b2922e77b787..d624f548b1ba65 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -58,6 +58,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_COS; case HloOpcode::kExp: return UNOP_EXP; + case HloOpcode::kExpm1: + return UNOP_EXPM1; case HloOpcode::kFloor: return UNOP_FLOOR; case HloOpcode::kImag: @@ -66,6 +68,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_IS_FINITE; case HloOpcode::kLog: return UNOP_LOG; + case HloOpcode::kLog1p: + return UNOP_LOG1P; case HloOpcode::kNot: return UNOP_NOT; case HloOpcode::kNegate: @@ -168,24 +172,24 @@ bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { +Status ExpectNotTupleOrOpaque(const Shape& shape, + tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("Expected non-tuple argument for %s, but got %s.", - op_type.ToString().c_str(), + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else if (ShapeUtil::IsOpaque(shape)) { return InvalidArgument("Expected non-opaque argument for %s, but got %s.", - op_type.ToString().c_str(), + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else { - return tensorflow::Status::OK(); + return Status::OK(); } } -tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + const Shape& init_value_shape, + const PrimitiveType& input_element_type) { if (reducer_shape.parameters_size() != 2) { return InvalidArgument( "Reduction function must take 2 parameters, but " @@ -245,7 +249,7 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, ShapeUtil::HumanString(accumulator_shape).c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr InferWindowOutputShape(const Shape& base_shape, @@ -312,7 +316,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr 
ShapeInference::InferUnaryOpShape( HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. - if (opcode == HloOpcode::kCopy) { + // A domain shape is the same as the input one. + if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) { return shape; } @@ -337,7 +342,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_COS: case UNOP_SIN: case UNOP_EXP: + case UNOP_EXPM1: case UNOP_LOG: + case UNOP_LOG1P: case UNOP_TANH: if (!ShapeUtil::ElementIsFloating(arg) && !ShapeUtil::ElementIsComplex(arg)) { @@ -1212,11 +1219,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( @@ -1318,15 +1325,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index fb3b5f06dad67b..7d7dcac10b6593 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include #include #include @@ -25,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -123,6 +123,8 @@ ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) } ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { + Deallocate(); + *static_cast(this) = std::move(static_cast(s)); allocator_ = s.allocator_; // Null out s.allocator_ so it doesn't try to free anything in its destructor. @@ -130,7 +132,15 @@ ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { return *this; } -ScopedShapedBuffer::~ScopedShapedBuffer() { +ScopedShapedBuffer::~ScopedShapedBuffer() { Deallocate(); } + +ShapedBuffer ScopedShapedBuffer::release() { + ShapedBuffer shaped_buffer(static_cast(*this)); + buffers_ = ShapeTree(); + return shaped_buffer; +} + +void ScopedShapedBuffer::Deallocate() { // allocator_ will be null if we were moved-from. 
if (allocator_ == nullptr) { return; @@ -138,22 +148,14 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - std::set deallocated_opaques; + tensorflow::gtl::FlatSet deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && - deallocated_opaques.count(memory_base.opaque()) == 0) { - deallocated_opaques.insert(memory_base.opaque()); - TF_CHECK_OK( - this->allocator_->Deallocate(this->device_ordinal(), &memory_base)); + deallocated_ptrs.insert(memory_base.opaque()).second) { + TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base)); } } } -ShapedBuffer ScopedShapedBuffer::release() { - ShapedBuffer shaped_buffer(static_cast(*this)); - buffers_ = ShapeTree(); - return shaped_buffer; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e10fca9e9466c0..905a7e82e621f2 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -148,13 +148,29 @@ class ScopedShapedBuffer : public ShapedBuffer { // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } - // Releases all device memory owned by this ScopedShapedBuffer and returns the - // device memory pointers in the form of a ShapedBuffer. The returned - // ShapedBuffer takes over the memory from the ScopedShapedBuffer. The - // resulting ScopedShapedBuffer can only be destroyed. - ShapedBuffer release(); + // Sets the device memory buffer at the given index. + // + // If the given buffer's device memory is non-null, its device_ordinal and + // allocator must match those in `this`. + void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) { + if (!buffer.is_null()) { + CHECK_EQ(buffer.device_ordinal(), device_ordinal()); + CHECK_EQ(buffer.allocator(), allocator_); + *buffers_.mutable_element(index) = buffer.Forget(); + } else { + *buffers_.mutable_element(index) = se::DeviceMemoryBase(); + } + } + + // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from + // this ScopedShapedBuffer, without freeing any of the associated memory. + // + // It's the caller's job to ensure that the memory contained therein is freed. + TF_MUST_USE_RESULT ShapedBuffer release(); protected: + void Deallocate(); + DeviceMemoryAllocator* allocator_; }; diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc new file mode 100644 index 00000000000000..0fc24366791165 --- /dev/null +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/shaped_buffer.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { +namespace { + +TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { + TF_ASSERT_OK_AND_ASSIGN(auto platforms, + xla::PlatformUtil::GetSupportedPlatforms()); + ASSERT_FALSE(platforms.empty()); + auto* platform = platforms[0]; + TF_ASSERT_OK_AND_ASSIGN(auto executors, + xla::PlatformUtil::GetStreamExecutors(platform)); + xla::StreamExecutorMemoryAllocator allocator(platform, executors); + const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); + const int kDeviceOrdinal = 0; + auto scoped_buffer = tensorflow::MakeUnique( + shape, shape, &allocator, kDeviceOrdinal); + std::unique_ptr buffer = std::move(scoped_buffer); + buffer = nullptr; +} + +class TestAllocator : public DeviceMemoryAllocator { + public: + TestAllocator() + : DeviceMemoryAllocator(PlatformUtil::GetDefaultPlatform().ValueOrDie()) { + } + + ~TestAllocator() override { + if (!allocations_.empty()) { + ADD_FAILURE() << "Some allocations not freed!"; + } + } + + // Pull in two-arg overload of Allocate. + using DeviceMemoryAllocator::Allocate; + + StatusOr Allocate(int device_ordinal, uint64 size, + bool /*retry_on_failure*/) override { + // By contract, we must return null if size == 0. + if (size == 0) { + return OwningDeviceMemory(); + } + void* buf = malloc(size); + allocations_.insert({device_ordinal, buf}); + return OwningDeviceMemory(se::DeviceMemoryBase(buf, size), device_ordinal, + this); + } + + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override { + if (mem.is_null()) { + return Status::OK(); + } + + auto it = allocations_.find({device_ordinal, mem.opaque()}); + if (it == allocations_.end()) { + ADD_FAILURE() << "Allocation not found (double free?)"; + } else { + free(mem.opaque()); + allocations_.erase(it); + } + return Status::OK(); + } + + bool AllowsAsynchronousDeallocation() const override { return false; } + + private: + std::set> allocations_; +}; + +TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) { + Shape s = ShapeUtil::MakeShape(F32, {1}); + TestAllocator allocator; + ScopedShapedBuffer sb1(s, s, &allocator, /*device_ordinal=*/0); + sb1.set_buffer( + allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(), + /*index=*/{}); + + ScopedShapedBuffer sb2(s, s, &allocator, /*device_ordinal=*/1); + sb2.set_buffer( + allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(), + /*index=*/{}); + + sb1 = std::move(sb2); + + // TestAllocator's destructor checks that all memory was freed. +} + +} // anonymous namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h index a776d745f4e56c..18e2651abb1600 100644 --- a/tensorflow/compiler/xla/service/source_map_util.h +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { namespace source_map_util { -// Creates an INVALID_ARUGMENT status with the given format string. +// Creates an INVALID_ARGUMENT status with the given format string. 
// // Also, attempts to extract the OpMetadata for parameter_number on executable // and append it to the status message for source mapping to user code. diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 8b71a415091f02..c4d01562c4e322 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -37,7 +37,7 @@ TransferManager::GetPlatformTransferManagers() { } Status TransferManager::TransferArrayToDevice( - se::StreamExecutor* executor, const Literal& literal, + se::StreamExecutor* executor, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) @@ -196,9 +196,11 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( const ShapeIndex& index = pair.first; se::DeviceMemoryBase& memory_base = pair.second; const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); - TF_ASSIGN_OR_RETURN(memory_base, + TF_ASSIGN_OR_RETURN(auto memory, allocator->Allocate(shaped_buffer.device_ordinal(), GetByteSizeRequirement(subshape))); + // Move the allocated buffer into the ScopedShapedBuffer, which owns it. + memory_base = memory.Forget(); } return std::move(shaped_buffer); diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index d82b4f0f81b5da..43a8092b06fba0 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -65,14 +65,14 @@ class TransferManager { // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, // but need not have the same layout virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. Status TransferArrayToDevice(se::StreamExecutor* executor, - const Literal& literal, + const LiteralSlice& literal, const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( se::StreamExecutor* executor, const Shape& shape, @@ -81,7 +81,7 @@ class TransferManager { // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) = 0; + const LiteralSlice& literal) = 0; // Transfers the given literal from the Outfeed interface of the device, // using the given executor. 
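The transfer_manager hunks above widen several transfer entry points from `const Literal&` to `const LiteralSlice&`, so a caller can hand over either an owning literal or a borrowed view of one without copying the payload. The following is a minimal stand-alone sketch of that pattern using toy stand-ins, not the real `xla::Literal` / `xla::LiteralSlice` classes:

```cpp
// Toy stand-ins, not the real xla::Literal / xla::LiteralSlice classes.
#include <cstddef>
#include <iostream>
#include <vector>

struct Literal {  // owning payload
  std::vector<float> data;
};

struct LiteralSlice {  // non-owning view over a Literal's payload
  const float* ptr = nullptr;
  std::size_t size = 0;
  LiteralSlice(const Literal& l) : ptr(l.data.data()), size(l.data.size()) {}
};

// A read-only "transfer" only needs a view, so it takes the slice type and
// accepts both owning literals and borrowed views at the call site.
void TransferToDevice(const LiteralSlice& literal) {
  for (std::size_t i = 0; i < literal.size; ++i) {
    std::cout << literal.ptr[i] << (i + 1 < literal.size ? ' ' : '\n');
  }
}

int main() {
  Literal owning{{1.0f, 2.0f, 3.0f}};
  TransferToDevice(owning);  // implicit Literal -> LiteralSlice, no deep copy
  return 0;
}
```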
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 3efd38ce0daa3e..ba16dc640e2d29 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -35,7 +35,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, const TransposeFolding::TransposableGemmOperandsFn& transposable_gemm_operands) { - if (HloOpcode::kDot != dot.opcode()) { + if (HloOpcode::kDot != dot.opcode() || + dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) { return {}; } @@ -44,6 +45,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); + } else if (ShapeUtil::Rank(operand.shape()) != 2) { + return {}; } } @@ -74,23 +77,39 @@ using InstructionOperandsPair = // Folds the operands of `dot` that are foldable transposes. `computation` is // the parent HLO computation of `dot`. -// -// Returns whether the module is changed. -bool FoldTransposeIntoDot(InstructionOperandsPair pair) { - auto* dot = pair.first; - std::vector instructions_to_fuse(1, dot); - for (const int64 operand_index : pair.second) { - instructions_to_fuse.push_back(dot->mutable_operand(operand_index)); - } - - // Early-exit if no operands are foldable. - if (instructions_to_fuse.size() == 1) { - return false; +Status FoldTransposeIntoDot(InstructionOperandsPair pair) { + HloInstruction* dot = pair.first; + + DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers(); + HloInstruction* new_lhs = dot->mutable_operand(0); + HloInstruction* new_rhs = dot->mutable_operand(1); + + CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1); + + for (int64 operand_index : pair.second) { + // We've checked that there aren't any batch dimensions and that the inputs + // are rank 2, and shape inference guarantees that there is exactly one + // contracting dimension. + if (operand_index == 0) { + CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_lhs_contracting_dimensions( + 0, 1 - new_dim_numbers.lhs_contracting_dimensions(0)); + new_lhs = new_lhs->mutable_operand(0); + } else { + CHECK_EQ(operand_index, 1); + CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, 1 - new_dim_numbers.rhs_contracting_dimensions(0)); + new_rhs = new_rhs->mutable_operand(0); + } } - dot->parent()->CreateFusionInstruction( - instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); - return true; + std::unique_ptr new_dot = HloInstruction::CreateDot( + dot->shape(), new_lhs, new_rhs, new_dim_numbers); + return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } // Folds the operands of `convolution` that are foldable transposes. 
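The rewritten `FoldTransposeIntoDot` above no longer builds a `kTransposeDot` fusion; for rank-2, non-batched dots it flips the folded operand's contracting dimension (`1 - contracting_dim`) and wires the dot directly to the transpose's input. A small stand-alone check, in plain C++ rather than XLA, that contracting `y` along its other dimension gives the same result as dotting with an explicitly materialized `transpose(y)`:

```cpp
// Demonstrates the equivalence behind FoldTransposeIntoDot for rank-2 dots:
// dot(x, transpose(y)) with rhs contracting dim 0 equals dot(x, y) with
// rhs contracting dim 1 - 0 = 1.
#include <cassert>

int main() {
  const int M = 2, K = 3, N = 2;
  float x[M][K] = {{1, 2, 3}, {4, 5, 6}};    // lhs, contracting dim 1
  float y[N][K] = {{7, 8, 9}, {10, 11, 12}};  // rhs before the transpose

  // Form 1: materialize y^T (shape [K][N]) and contract its dimension 0.
  float yt[K][N];
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n) yt[k][n] = y[n][k];

  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float via_transpose = 0, folded = 0;
      for (int k = 0; k < K; ++k) {
        via_transpose += x[m][k] * yt[k][n];  // dot(x, transpose(y))
        folded += x[m][k] * y[n][k];          // dot(x, y), rhs contracting dim 1
      }
      assert(via_transpose == folded);
    }
  }
  return 0;
}
```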
@@ -196,7 +215,7 @@ StatusOr TransposeFolding::Run(HloModule* module) { std::make_pair(instruction, operand_indices)); } } - return tensorflow::Status::OK(); + return Status::OK(); }; for (auto* comp : module->MakeNonfusionComputations()) { @@ -205,7 +224,8 @@ StatusOr TransposeFolding::Run(HloModule* module) { bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - changed |= FoldTransposeIntoDot(pair); + TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair)); + changed = true; } for (InstructionOperandsPair& pair : foldable_convolutions) { changed |= FoldTransposeIntoConvolution(pair); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index caa1a111ad880b..3139801ea31303 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,13 +19,15 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -34,6 +36,8 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -54,83 +58,102 @@ class TransposeFoldingTest : public HloTestBase { }; TEST_F(TransposeFoldingTest, FoldDotTranspose) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); - - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - FoldTranspose(&module); + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[2,3]{1,0} parameter(0) + y = f32[2,3]{1,0} parameter(1) + transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0} + ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); - // Instructions after folding: x, y, and the fusion. 
- std::unordered_set instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.size()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = *instruction_set.begin(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + FoldTranspose(module.get()); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); +TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) { + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[2,3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) { + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={0}, lhs_contracting_dims={0}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); } TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { - auto builder = HloComputation::Builder("entry_computation"); - // 2x1 - HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1}, {2}}))); - // 3x2 - HloInstruction* const1 = - builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); - HloInstruction* transpose0 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1, 3}), - /*lhs=*/transpose0, 
/*rhs=*/transpose1, dot_dnums)); - - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - FoldTranspose(&module); - - for (auto* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK_EQ(2, instruction->operand_count()); - EXPECT_EQ(const0, instruction->operand(0)); - EXPECT_EQ(const1, instruction->operand(1)); - } - } + string hlo_string = R"( +HloModule FoldDotTransposeConstant + +ENTRY entry_computation { + constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} + constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} + ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + FoldTranspose(module.get()); - // The created fusion instruction should contain two parameters, two - // transposes (one for each parameter) and one dot. - EXPECT_EQ(5, - entry_computation->root_instruction()->fused_instruction_count()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Constant(), op::Constant(), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/1)); } TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { @@ -149,10 +172,10 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - HloModule module("fuse_with_constant_operands"); + auto module = CreateNewModule("fuse_with_constant_operands"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(mul)); - HloInstruction* call = module.OutlineExpressionFromComputation( + module->AddEntryComputation(builder.Build(mul)); + HloInstruction* call = module->OutlineExpressionFromComputation( {add, sub, mul}, "", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); @@ -164,50 +187,32 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { EXPECT_EQ(6, callee_computation->instruction_count()); } -TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); - - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - - HloInstruction* call = module.OutlineExpressionFromComputation( - {transpose_y, dot}, "outlined", entry_computation); +TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) { + string hlo_string = R"( 
+HloModule FoldDotTransposeInCall - FoldTranspose(&module); - - // Instructions after folding: x, y, and the fusion. - std::unordered_set instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(call)) - << "call is not in entry_computation."; - CHECK(instruction_set.empty()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = - call->called_computations().front()->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); +callee { + name.0 = f32[2,3]{1,0} parameter(0) + name.1 = f32[2,3]{1,0} parameter(1) + transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0} + ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); +ENTRY entry_computation { + y = f32[2,3]{1,0} parameter(1) + x = f32[2,3]{1,0} parameter(0) + ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + FoldTranspose(module.get()); + + const HloComputation* callee = module->GetComputationWithName("callee"); + ASSERT_NE(callee, nullptr); + EXPECT_THAT(callee->root_instruction(), + op::Dot(op::Parameter(1), op::Parameter(0), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); } // Test that a two dimension swap of the kernel gets folded into convolution. @@ -222,7 +227,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -240,10 +245,10 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. 
std::unordered_set instruction_set( @@ -275,7 +280,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -293,10 +298,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -334,7 +339,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -351,10 +356,10 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -398,7 +403,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -415,10 +420,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. 
std::unordered_set instruction_set( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 657a8fe09ae9df..bb634e6573ffce 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -273,6 +273,14 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand, so just copy the operands points-to + // set. + CreateCopiedPointsToSet(domain, domain->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { // A kSlice instruction aliases its operand if the backend lowers it to an // in-place implementation. @@ -588,4 +596,201 @@ void TuplePointsToAnalysis::InstructionToString( }); } +bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { + // GetTupleElement instructions only access the top-level buffer of their + // operand. + return true; + } else if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + auto it = std::find_if( + user->fused_parameters().begin(), user->fused_parameters().end(), + [=](HloInstruction* fused_param) { + return user->operand(fused_param->parameter_number()) == operand; + }); + CHECK(it != user->fused_parameters().end()); + // Iterate through all users of all buffer aliases of the buffer in the + // points-to set of fusion parameter at 'index'. + // Return false if any uses are detected at 'index', returns true otherwise. + const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie(); + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. + return false; + } + } + // Return true: found no uses of 'operand' at 'index' in 'user'. + return true; + } + return false; +} + +// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. +// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) +// where 'user' is a user of an alias of 'instruction' at 'index', and +// 'operand_index' is the operand index at which the alias appears in the +// operand list of 'user'. 
+std::vector<std::pair<HloInstruction*, int64>>
+TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex(
+    HloInstruction* instruction, const ShapeIndex& index) const {
+  std::vector<std::pair<HloInstruction*, int64>> uses;
+  const PointsToSet::BufferList& points_to =
+      GetPointsToSet(instruction).element(index);
+  for (const LogicalBuffer* buffer : points_to) {
+    for (const BufferAlias& alias : GetBufferAliases(*buffer)) {
+      for (HloInstruction* alias_user : alias.instruction()->users()) {
+        if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
+                                    alias_user)) {
+          continue;
+        }
+        for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) {
+          uses.emplace_back(alias_user, op_idx);
+        }
+      }
+    }
+  }
+  return uses;
+}
+
+// Returns true if there is exactly one use of 'operand' at 'operand_index'
+// in 'fusion.fused_instructions', where the singleton use is the fused
+// root at operand index 'use_operand_index'. Returns false otherwise.
+//
+// REQUIRES: 'fusion' opcode is a kFusion instruction.
+bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
+    HloInstruction* operand, const ShapeIndex& operand_index,
+    HloInstruction* fusion, const int64 use_operand_index) const {
+  CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
+  // Check that 'operand' is unique in the operand list of 'fusion'.
+  if (fusion->OperandIndices(operand).size() > 1) {
+    return false;
+  }
+  // Find fusion parameter associated with 'operand'.
+  const auto& fused_params = fusion->fused_parameters();
+  auto fused_param_it = std::find_if(
+      fused_params.begin(), fused_params.end(),
+      [&](HloInstruction* fused_param) {
+        return fusion->operand(fused_param->parameter_number()) == operand;
+      });
+  if (fused_param_it == fused_params.end()) {
+    return false;
+  }
+  auto* fused_param = *fused_param_it;
+  // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'.
+  auto fused_param_uses =
+      GetAllUsesOfInstructionAtIndex(fused_param, operand_index);
+  // Return true iff there is exactly one use of 'operand' at 'index', and
+  // this singleton use is the fused root (at operand index 'use_operand_index').
+  return fused_param_uses.size() == 1 &&
+         fused_param_uses[0].first == fusion->fused_expression_root() &&
+         fused_param_uses[0].second == use_operand_index;
+}
+
+// User and operand can share buffers iff both instructions emit the same shape
+// and layout, and 'user' meets one of the following qualifications:
+//
+// (1) Is element-wise. Or...
+// (2) Is a loop fusion instruction where the only use of 'operand' at 'index'
+//     in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
+//     at operand 0. Or...
+// (3) Is a kDot -> kAdd output fusion instruction where the only use of
+//     'operand' at 'index' in the set 'user.fused_instructions' is a kAdd
+//     fused root at operand 0 or 1. Or...
+// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index
+//     0.
+//
+// (2) and (3) can only be determined if points-to analysis is available.
+bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
+    HloInstruction* operand, const ShapeIndex& operand_index,
+    HloInstruction* user, const ShapeIndex& user_index) const {
+  CHECK(user->IsUserOf(operand))
+      << "user: " << user->ToString() << " operand: " << operand->ToString();
+  const Shape& operand_subshape =
+      ShapeUtil::GetSubshape(operand->shape(), operand_index);
+  const Shape& user_subshape =
+      ShapeUtil::GetSubshape(user->shape(), user_index);
+  // Check that operand and user emit the same shape and layout.
+  if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
+    return false;
+  }
+  if (user->opcode() == HloOpcode::kFusion) {
+    if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+        user->fused_expression_root()->opcode() ==
+            HloOpcode::kDynamicUpdateSlice) {
+      // Loop fusion with kDynamicUpdateSlice fused root.
+      //
+      // Returns true iff there is exactly one use of 'operand' at shape index
+      // 'operand_index', and this singleton use is the fused root at operand
+      // index 0.
+      return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
+    } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
+               user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
+      // Output fusion with kAdd fused root.
+
+      // Check if one operand of kAdd fused root is kDot or kConvolution.
+      auto* add = user->fused_expression_root();
+      auto add_operand_it =
+          std::find_if(add->operands().begin(), add->operands().end(),
+                       [&](HloInstruction* operand) {
+                         return operand->opcode() == HloOpcode::kConvolution ||
+                                operand->opcode() == HloOpcode::kDot;
+                       });
+      if (add_operand_it == add->operands().end()) {
+        return false;
+      }
+      auto* matched_add_operand = *add_operand_it;
+      // Calculate operand index of 'add' operand which was not matched above.
+      const int64 other_add_operand_index =
+          matched_add_operand == add->operand(0) ? 1 : 0;
+      // Returns true iff there is exactly one use of 'operand' at shape index
+      // 'operand_index', and this singleton use is the fused root (at operand
+      // index 'other_add_operand_index').
+      return HasUniqueFusedUseOfOperandAt(operand, operand_index, user,
+                                          other_add_operand_index);
+    }
+  }
+  if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
+      user->opcode() == HloOpcode::kWhile) {
+    // We eliminated other users in BufferLiveness::live_range_strictly_before,
+    // so here we just need to check that the use is at operand index 0.
+    std::vector<int64> operand_indices = user->OperandIndices(operand);
+    return operand_indices.size() == 1 && operand_indices[0] == 0;
+  }
+  if (user->opcode() == HloOpcode::kCall) {
+    // TODO(b/62548313): Remove when buffer assignment is module scoped and
+    // does not assign buffers to calls.
+    // Find called computation parameter associated with 'operand'.
+    const std::vector<int64> operand_indices = user->OperandIndices(operand);
+    if (operand_indices.size() > 1) {
+      return false;
+    }
+    CHECK_EQ(1, operand_indices.size());
+    auto* param = user->to_apply()->parameter_instruction(operand_indices[0]);
+    // Get all uses of 'operand' at 'index' in called computation.
+    auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index);
+
+    // Return true iff:
+    // *) There exists exactly one use of 'operand' in called computation.
+    // *) The unique use is by the root instruction of called computation.
+    //    (Note: we check the root of the called computation, because the
+    //    root result buffer is required to alias with the Call result buffer).
+    // *) The root instruction of the called computation is element-wise on
+    //    'operand'.
+    auto* callee_root = user->to_apply()->root_instruction();
+    return param_uses.size() == 1 && param_uses[0].first == callee_root &&
+           callee_root->IsElementwiseOnOperand(param_uses[0].second);
+  }
+  // Check if 'user' is element-wise.
+  return user->IsElementwise();
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index c3743b150168eb..c0d82414806d9a 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -248,6 +248,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
   Status HandleTuple(HloInstruction* tuple) override;
   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
   Status HandleBitcast(HloInstruction* bitcast) override;
+  Status HandleDomain(HloInstruction* domain) override;
   Status HandleSlice(HloInstruction* slice) override;
   Status HandleCopy(HloInstruction* copy) override;
   Status HandleRecvDone(HloInstruction* recv_done) override;
@@ -256,6 +257,23 @@

   string ToString() const;

+  // Returns true if 'user' cannot possibly use the buffer at 'index' in
+  // 'operand'. Returns false otherwise.
+  //
+  // REQUIRES: 'operand' is an operand of 'user'.
+  bool DoesNotUseOperandBuffer(const HloInstruction* operand,
+                               const ShapeIndex& index,
+                               const HloInstruction* user) const;
+
+  // Returns true if 'user' (at 'user_index') can share a buffer with its
+  // operand 'operand' (at 'operand_index'). Returns false otherwise.
+  //
+  // REQUIRES: 'operand' is an operand of 'user'.
+  bool CanShareOperandBufferWithUser(HloInstruction* operand,
+                                     const ShapeIndex& operand_index,
+                                     HloInstruction* user,
+                                     const ShapeIndex& user_index) const;
+
  private:
   explicit TuplePointsToAnalysis(
       const HloModule* module,
@@ -310,6 +328,13 @@
     return &per_instruction_[id];
   }

+  std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
+      HloInstruction* instruction, const ShapeIndex& index) const;
+  bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand,
+                                    const ShapeIndex& operand_index,
+                                    HloInstruction* fusion,
+                                    const int64 use_operand_index) const;
+
   // The module this analysis is performed on.
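(Editorial aside, not part of the patch: the two queries declared above are easiest to read alongside a usage sketch. The helper below, `OperandIsSafeToReuse`, is a hypothetical name; the sketch only assumes what this change itself shows, namely that `TuplePointsToAnalysis::Run` builds the analysis for a module and that `user` must already be a user of `operand`.)

```c++
// Minimal sketch (not part of this change) of how a client pass might consult
// the new buffer queries. 'OperandIsSafeToReuse' is a hypothetical helper.
#include <memory>

#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

namespace xla {

bool OperandIsSafeToReuse(const HloModule* module, HloInstruction* operand,
                          HloInstruction* user) {
  // Build the points-to analysis for the whole module, as the tests added in
  // this change do.
  std::unique_ptr<TuplePointsToAnalysis> analysis =
      TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();

  // The empty ShapeIndex {} names the top-level buffer; tuple elements are
  // addressed as {0}, {1}, and so on.
  if (analysis->DoesNotUseOperandBuffer(operand, /*index=*/{}, user)) {
    return true;  // 'user' never touches the top-level buffer of 'operand'.
  }
  // Otherwise reuse is only safe if 'user' may share the buffer in place.
  return analysis->CanShareOperandBufferWithUser(operand, /*operand_index=*/{},
                                                 user, /*user_index=*/{});
}

}  // namespace xla
```

Whether a caller falls back from the first query to the second is a policy choice of the calling pass; the tests added below exercise each predicate separately.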
const HloModule* module_; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index dec446d4dac650..f558316b05b168 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -805,5 +805,348 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { Run(/*add_additional_gte0_user=*/true); } +class PointsToAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr points_to_analysis_; +}; + +class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. 
+ EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE( + points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. 
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + add_operand, {}, fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. 
+ auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(reverse, {}, + call, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index 113c2e2bd9f73a..d668855084a884 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -69,6 +69,7 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // Tuple // HloInstruction* top_tuple = nullptr; + HloInstruction* first_gte = nullptr; bool can_simplify = true; for (int64 operand_number = 0; operand_number < instruction->operand_count(); ++operand_number) { @@ -78,11 +79,17 @@ StatusOr TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - + if (first_gte == nullptr) { + first_gte = operand; + } else if (!first_gte->has_compatible_sharding(operand)) { + can_simplify = false; + break; + } if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), - instruction->shape())) { + instruction->shape()) || + !instruction->has_compatible_sharding(top_tuple)) { can_simplify = false; break; } @@ -108,15 +115,17 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // | // GTE if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { - changed = true; HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); - for (HloInstruction* user : element_source->users()) { - if (user->opcode() == HloOpcode::kTuple || - user->opcode() == HloOpcode::kGetTupleElement) { - worklist.push(user); + if (instruction->has_compatible_sharding(element_source)) { + changed = true; + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); + for (HloInstruction* user : element_source->users()) { + if (user->opcode() == HloOpcode::kTuple || + user->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(user); + } } } } diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc index 754fd8ef169231..d33d5bb8f30c85 100644 --- a/tensorflow/compiler/xla/service/tuple_util_test.cc +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -16,8 +16,8 @@ 
limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -37,7 +37,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc deleted file mode 100644 index 0f16a592b68e20..00000000000000 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ /dev/null @@ -1,3553 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { -namespace { - -HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { - switch (unop) { - case UNOP_ABS: - return HloOpcode::kAbs; - case UNOP_CEIL: - return HloOpcode::kCeil; - case UNOP_CLZ: - return HloOpcode::kClz; - case UNOP_COS: - return HloOpcode::kCos; - case UNOP_EXP: - return HloOpcode::kExp; - case UNOP_FLOOR: - return HloOpcode::kFloor; - case UNOP_IMAG: - return HloOpcode::kImag; - case UNOP_IS_FINITE: - return HloOpcode::kIsFinite; - case UNOP_LOG: - return HloOpcode::kLog; - case UNOP_NOT: - return HloOpcode::kNot; - case UNOP_NEGATE: - return HloOpcode::kNegate; - case UNOP_REAL: - return HloOpcode::kReal; - case UNOP_ROUND_NEAREST_AFZ: - return HloOpcode::kRoundNearestAfz; - case UNOP_SIGN: - return HloOpcode::kSign; - case UNOP_SIN: - return HloOpcode::kSin; - case UNOP_SORT: - return HloOpcode::kSort; - case UNOP_TANH: - return HloOpcode::kTanh; - default: - LOG(FATAL) << "unhandled operation " << unop; - } -} - -HloOpcode 
BinaryOperationToHloOpcode(BinaryOperation binop) { - switch (binop) { - case BINOP_ATAN2: - return HloOpcode::kAtan2; - case BINOP_COMPLEX: - return HloOpcode::kComplex; - case BINOP_MUL: - return HloOpcode::kMultiply; - case BINOP_ADD: - return HloOpcode::kAdd; - case BINOP_SUB: - return HloOpcode::kSubtract; - case BINOP_DIV: - return HloOpcode::kDivide; - case BINOP_EQ: - return HloOpcode::kEq; - case BINOP_GE: - return HloOpcode::kGe; - case BINOP_GT: - return HloOpcode::kGt; - case BINOP_LE: - return HloOpcode::kLe; - case BINOP_LT: - return HloOpcode::kLt; - case BINOP_NE: - return HloOpcode::kNe; - case BINOP_MAX: - return HloOpcode::kMaximum; - case BINOP_MIN: - return HloOpcode::kMinimum; - case BINOP_POW: - return HloOpcode::kPower; - case BINOP_REM: - return HloOpcode::kRemainder; - case BINOP_OR: - return HloOpcode::kOr; - case BINOP_AND: - return HloOpcode::kAnd; - case BINOP_SHIFT_LEFT: - return HloOpcode::kShiftLeft; - case BINOP_SHIFT_RIGHT_ARITHMETIC: - return HloOpcode::kShiftRightArithmetic; - case BINOP_SHIFT_RIGHT_LOGICAL: - return HloOpcode::kShiftRightLogical; - default: - LOG(FATAL) << "unhandled operation " << binop; - } -} - -HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { - switch (triop) { - case TRIOP_CLAMP: - return HloOpcode::kClamp; - case TRIOP_SELECT: - return HloOpcode::kSelect; - default: - LOG(FATAL) << "unhandled operation " << triop; - } -} - -HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) { - switch (varop) { - case VAROP_TUPLE: - return HloOpcode::kTuple; - default: - LOG(FATAL) << "unhandled operation " << varop; - } -} - -} // namespace - -/* static */ StatusOr> -UserComputation::MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new) { - auto user_computation = - MakeUnique(session_computation.name(), handle); - { - tensorflow::mutex_lock lock(user_computation->mutex_); - user_computation->session_computation_ = session_computation; - user_computation->next_handle_value_ = - std::max_element(session_computation.requests().begin(), - session_computation.requests().end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return lhs.first < rhs.first; - }) - ->first + - 1; - TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new)); - } - - return std::move(user_computation); -} - -UserComputation::UserComputation(const string& name, - const ComputationHandle& handle) - : name_(name), next_handle_value_(1) { - *session_computation_.mutable_computation_handle() = handle; - session_computation_.set_name(name); - - VLOG(1) << "New UserComputation \"" << name - << "\", handle: " << handle.handle(); -} - -ComputationDataHandle UserComputation::CreateComputationDataHandle() { - ComputationDataHandle handle; - handle.set_handle(next_handle_value_); - // Handles are used as Version values and *must* be assigned consecutively for - // computation versioning to work. 
- next_handle_value_++; - return handle; -} - -StatusOr UserComputation::AddParameterInstruction( - const ParameterRequest& parameter_request) { - tensorflow::mutex_lock lock(mutex_); - - int64 parameter_number = parameter_request.parameter(); - if (parameters_.count(parameter_number) != 0) { - return InvalidArgument("parameter %lld already registered", - parameter_number); - } - ComputationDataHandle handle = CreateComputationDataHandle(); - - const Shape& validated_shape = parameter_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_parameter_request() = parameter_request; - - parameters_[parameter_number] = &request; - - VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << parameter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSendInstruction( - const SendRequest& send_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check if the operand of the instruction is valid. - TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status()); - - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_send_request() = send_request; - - VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << send_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddRecvInstruction( - const RecvRequest& recv_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = recv_request.shape(); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_recv_request() = recv_request; - - VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << recv_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddPadInstruction( - const PadRequest& pad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(pad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value, - LookUpRequest(pad_request.padding_value())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape( - operand->output_shape(), - padding_value->output_shape(), - pad_request.padding_config())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_pad_request() = pad_request; - - VLOG(1) << 
"AddPadInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << pad_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConstantInstruction( - const ConstantRequest& constant_request) { - const Shape& validated_shape = constant_request.literal().shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - tensorflow::mutex_lock lock(mutex_); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_constant_request() = constant_request; - - VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle(); - return handle; -} - -StatusOr UserComputation::AddGatherInstruction( - const GatherRequest& gather_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, - LookUpRequest(gather_request.input())); - TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, - LookUpRequest(gather_request.gather_indices())); - - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferGatherShape( - input_request->output_shape(), gather_indices_request->output_shape(), - gather_request.dimension_numbers(), - AsInt64Slice(gather_request.window_bounds()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_gather_request() = gather_request; - - VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << gather_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(get_tuple_element_request.operand())); - if (!ShapeUtil::IsTuple(operand->output_shape())) { - return InvalidArgument( - "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(operand->output_shape()).c_str()); - } - Shape element_shape = ShapeUtil::GetTupleElementShape( - operand->output_shape(), get_tuple_element_request.index()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = element_shape; - *request.mutable_request()->mutable_get_tuple_element_request() = - get_tuple_element_request; - - VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << get_tuple_element_request.ShortDebugString(); - return handle; -} - -Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the operand index is valid. 
- TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_trace_request() = trace_request; - - VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << trace_request.ShortDebugString(); - return Status::OK(); -} - -StatusOr UserComputation::AddRngInstruction( - const RngRequest& rng_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check the number of parameters per RNG distribution. - switch (rng_request.distribution()) { - case RandomDistribution::RNG_NORMAL: - case RandomDistribution::RNG_UNIFORM: - if (rng_request.parameter_size() != 2) { - return InvalidArgument( - "RNG distribution (%s) expects 2 parameters, but got %d", - RandomDistribution_Name(rng_request.distribution()).c_str(), - rng_request.parameter_size()); - } - break; - default: - LOG(FATAL) << "unhandled distribution " << rng_request.distribution(); - } - - // Verify that the parameter indices are valid; - for (const ComputationDataHandle& param : rng_request.parameter()) { - TF_RETURN_IF_ERROR(LookUpRequest(param).status()); - } - const Shape& validated_shape = rng_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_rng_request() = rng_request; - - VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << rng_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : map_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape, - AsInt64Slice(map_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_map_request() = map_request; - - VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << map_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation) { - 
tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceShape( - operand->output_shape(), init_value->output_shape(), - AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_request() = reduce_request; - - VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_training_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_training_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_training_request.offset())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormTrainingShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), batch_norm_training_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_training_request() = - batch_norm_training_request; - - VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_training_request.ShortDebugString(); - - return handle; -} - -StatusOr -UserComputation::AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_inference_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_inference_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_inference_request.offset())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_inference_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_inference_request.variance())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBatchNormInferenceShape( - 
operand->output_shape(), scale->output_shape(), - offset->output_shape(), mean->output_shape(), - variance->output_shape(), - batch_norm_inference_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_inference_request() = - batch_norm_inference_request; - - VLOG(1) << "AddBatchNormInferenceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << batch_norm_inference_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_grad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_grad_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_grad_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_grad_request.variance())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output, - LookUpRequest(batch_norm_grad_request.grad_output())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormGradShape( - operand->output_shape(), scale->output_shape(), mean->output_shape(), - variance->output_shape(), grad_output->output_shape(), - batch_norm_grad_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_grad_request() = - batch_norm_grad_request; - - VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_grad_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_window_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_window_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceWindowShape( - operand->output_shape(), init_value->output_shape(), - reduce_window_request.window(), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_window_request() = - reduce_window_request; - - VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_window_request.ShortDebugString(); - 
return handle; -} - -StatusOr UserComputation::AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(select_and_scatter_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* source, - LookUpRequest(select_and_scatter_request.source())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(select_and_scatter_request.init_value())); - - VersionedComputationHandle::Version select_version = - select_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr select_program_shape, - select_computation.ComputeProgramShape(select_version)); - VersionedComputationHandle::Version scatter_version = - scatter_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr scatter_program_shape, - scatter_computation.ComputeProgramShape(scatter_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferSelectAndScatterShape( - operand->output_shape(), *select_program_shape, - select_and_scatter_request.window(), source->output_shape(), - init_value->output_shape(), *scatter_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(select_version); - request.add_embedded_computation_versions(scatter_version); - *request.mutable_request()->mutable_select_and_scatter_request() = - select_and_scatter_request; - - VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << select_and_scatter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReverseInstruction( - const ReverseRequest& reverse_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reverse_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReverseShape( - operand->output_shape(), AsInt64Slice(reverse_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reverse_request() = reverse_request; - VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reverse_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* init, - LookUpRequest(while_request.init())); - - VersionedComputationHandle::Version condition_version = - condition_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr condition_program_shape, - condition_computation.ComputeProgramShape(condition_version)); - - VersionedComputationHandle::Version body_version = body_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr 
body_program_shape, - body_computation.ComputeProgramShape(body_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferWhileShape( - *condition_program_shape, *body_program_shape, init->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(condition_version); - request.add_embedded_computation_versions(body_version); - *request.mutable_request()->mutable_while_request() = while_request; - - VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << while_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* pred, - LookUpRequest(conditional_request.predicate())); - TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, - LookUpRequest(conditional_request.true_operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, - LookUpRequest(conditional_request.false_operand())); - - VersionedComputationHandle::Version true_computation_version = - true_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr true_computation_shape, - true_computation.ComputeProgramShape(true_computation_version)); - - VersionedComputationHandle::Version false_computation_version = - false_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr false_computation_shape, - false_computation.ComputeProgramShape(false_computation_version)); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferConditionalShape( - pred->output_shape(), true_operand->output_shape(), - false_operand->output_shape(), - *true_computation_shape, *false_computation_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(true_computation_version); - request.add_embedded_computation_versions(false_computation_version); - *request.mutable_request()->mutable_conditional_request() = - conditional_request; - - VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << conditional_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBroadcastInstruction( - const BroadcastRequest& broadcast_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. 
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(broadcast_request.operand())); - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBroadcastShape( - operand->output_shape(), - AsInt64Slice(broadcast_request.broadcast_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_broadcast_request() = broadcast_request; - - VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << broadcast_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReshapeInstruction( - const ReshapeRequest& reshape_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reshape_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReshapeShape( - operand->output_shape(), AsInt64Slice(reshape_request.dimensions()), - AsInt64Slice(reshape_request.new_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reshape_request() = reshape_request; - - VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reshape_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTransposeInstruction( - const TransposeRequest& transpose_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. 
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(transpose_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferTransposeShape( - operand->output_shape(), - AsInt64Slice(transpose_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_transpose_request() = transpose_request; - - VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << transpose_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSliceInstruction( - const SliceRequest& slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(slice_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferSliceShape( - operand->output_shape(), AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_slice_request() = slice_request; - - VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices, - LookUpRequest(dynamic_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferDynamicSliceShape( - operand->output_shape(), start_indices->output_shape(), - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_slice_request() = - dynamic_slice_request; - - VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dynamic_slice_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_update_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* update, - LookUpRequest(dynamic_update_slice_request.update())); - - TF_ASSIGN_OR_RETURN( - const OperationRequest* start_indices, - LookUpRequest(dynamic_update_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->output_shape(), update->output_shape(), - 
start_indices->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_update_slice_request() = - dynamic_update_slice_request; - - VLOG(1) << "AddDynamicUpdateSliceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << dynamic_update_slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : concatenate_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferConcatOpShape( - operand_shapes, concatenate_request.dimension())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_concatenate_request() = - concatenate_request; - - VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << concatenate_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_convert_request() = convert_request; - - VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBitcastConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_bitcast_convert_request() = - convert_request; - - VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReducePrecisionInstruction( - const ReducePrecisionRequest& 
reduce_precision_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_precision_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferReducePrecisionShape( - operand->output_shape(), reduce_precision_request.exponent_bits(), - reduce_precision_request.mantissa_bits())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_reduce_precision_request() = - reduce_precision_request; - - VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_precision_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvolveInstruction( - const ConvolveRequest& convolve_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(convolve_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(convolve_request.rhs())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( - lhs->output_shape(), rhs->output_shape(), - convolve_request.window(), - convolve_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_convolve_request() = convolve_request; - - VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convolve_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddFftInstruction( - const FftRequest& fft_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(fft_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferFftShape( - operand->output_shape(), fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_fft_request() = fft_request; - - VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << fft_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(cross_replica_sum_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - {&operand->output_shape()})); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_cross_replica_sum_request() = - 
      cross_replica_sum_request;
-
-  VLOG(1) << "AddCrossReplicaSumInstruction (" << GetVersionedHandleInternal()
-          << "), data handle " << handle.handle() << ": "
-          << cross_replica_sum_request.ShortDebugString();
-  return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddInfeedInstruction(
-    const InfeedRequest& infeed_request) {
-  tensorflow::mutex_lock lock(mutex_);
-
-  const Shape& shape = infeed_request.shape();
-  if (!LayoutUtil::HasLayout(shape)) {
-    return InvalidArgument("Given shape to Infeed must have a layout");
-  }
-
-  const ComputationDataHandle handle = CreateComputationDataHandle();
-
-  OperationRequest& request =
-      (*session_computation_.mutable_requests())[handle.handle()];
-  *request.mutable_output_handle() = handle;
-  *request.mutable_output_shape() = shape;
-  *request.mutable_request()->mutable_infeed_request() = infeed_request;
-
-  VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal()
-          << "), data handle " << handle.handle() << ": "
-          << infeed_request.ShortDebugString();
-  return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddOutfeedInstruction(
-    const OutfeedRequest& outfeed_request) {
-  tensorflow::mutex_lock lock(mutex_);
-
-  const Shape& shape = outfeed_request.shape();
-  if (!LayoutUtil::HasLayout(shape)) {
-    return InvalidArgument("Given shape to Outfeed must have a layout");
-  }
-
-  // Verify that operand is valid.
-  TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status());
-
-  ComputationDataHandle handle = CreateComputationDataHandle();
-  OperationRequest& request =
-      (*session_computation_.mutable_requests())[handle.handle()];
-  *request.mutable_output_handle() = handle;
-  *request.mutable_output_shape() = shape;
-  *request.mutable_request()->mutable_outfeed_request() = outfeed_request;
-
-  VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal()
-          << "), data handle " << handle.handle() << ": "
-          << outfeed_request.ShortDebugString();
-  return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddCallInstruction(
-    const CallRequest& call_request,
-    const UserComputation& to_apply_computation) {
-  tensorflow::mutex_lock lock(mutex_);
-
-  std::vector<const Shape*> operand_shapes;
-  for (const ComputationDataHandle& handle : call_request.operands()) {
-    TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
-    operand_shapes.push_back(&operand->output_shape());
-  }
-
-  VersionedComputationHandle::Version to_apply_version =
-      to_apply_computation.version();
-  TF_ASSIGN_OR_RETURN(
-      std::shared_ptr<const ProgramShape> to_apply_program_shape,
-      to_apply_computation.ComputeProgramShape(to_apply_version));
-  TF_ASSIGN_OR_RETURN(
-      Shape inferred_shape,
-      ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape));
-
-  ComputationDataHandle handle = CreateComputationDataHandle();
-
-  OperationRequest& request =
-      (*session_computation_.mutable_requests())[handle.handle()];
-  *request.mutable_output_handle() = handle;
-  *request.mutable_output_shape() = inferred_shape;
-  request.add_embedded_computation_versions(to_apply_version);
-  *request.mutable_request()->mutable_call_request() = call_request;
-
-  VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal()
-          << "), data handle " << handle.handle() << ": "
-          << call_request.ShortDebugString();
-  return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction(
-    const CustomCallRequest& custom_call_request) {
-  tensorflow::mutex_lock lock(mutex_);
-
-  for (const ComputationDataHandle& handle : custom_call_request.operands()) {
-    TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
-  }
- - if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(), - "$")) { - return InvalidArgument( - "Invalid custom_call_target \"%s\": Call targets that start with '$' " - "are reserved for internal use.", - custom_call_request.call_target_name().c_str()); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = custom_call_request.shape(); - *request.mutable_request()->mutable_custom_call_request() = - custom_call_request; - - VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << custom_call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddHostComputeInstruction( - const HostComputeRequest& host_compute_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : host_compute_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = host_compute_request.shape(); - *request.mutable_request()->mutable_host_compute_request() = - host_compute_request; - - VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << host_compute_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDotInstruction( - const DotRequest& dot_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(dot_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(dot_request.rhs())); - - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( - lhs->output_shape(), rhs->output_shape(), - dot_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_dot_request() = dot_request; - - VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dot_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddUnaryInstruction( - const UnaryOpRequest& unary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(unary_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(), - operand->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_unary_op_request() = unary_request; - - VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << unary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBinaryInstruction( - const BinaryOpRequest& 
binary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(binary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(binary_request.rhs())); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferBinaryOpShape( - binary_request.binop(), lhs->output_shape(), rhs->output_shape(), - AsInt64Slice(binary_request.broadcast_dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_binary_op_request() = binary_request; - - VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << binary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTernaryInstruction( - const TernaryOpRequest& ternary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(ternary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(ternary_request.rhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* ehs, - LookUpRequest(ternary_request.ehs())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferTernaryOpShape( - ternary_request.triop(), lhs->output_shape(), - rhs->output_shape(), ehs->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_ternary_op_request() = ternary_request; - - VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << ternary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddVariadicInstruction( - const VariadicOpRequest& variadic_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : variadic_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferVariadicOpShape( - variadic_request.varop(), operand_shapes)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_variadic_op_request() = variadic_request; - - VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << variadic_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::GetShape(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - return operand->output_shape(); -} - -Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { 
- return InvalidArgument("Invalid handle in SetOpMetadata (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_metadata() = metadata; - return Status::OK(); -} - -Status UserComputation::SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpSharding (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_sharding() = sharding; - return Status::OK(); -} - -Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) { - return InvalidArgument("Invalid handle in SetReturnValue"); - } - - handle_to_return_ = handle; - - VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to " - << GetVersionedHandleInternal(); - - return Status::OK(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandle() const { - tensorflow::mutex_lock lock(mutex_); - return GetVersionedHandleInternal(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const { - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - - if (handle_to_return_.handle() > 0) { - // A specific handle has been requested for the result of the computation. - versioned_handle.version = handle_to_return_.handle(); - } else { - // A version value is simply the most recently assigned - // ComputationDataHandle value, ie the handle value of the root of the - // computation. - versioned_handle.version = next_handle_value_ - 1; - } - - return versioned_handle; -} - -VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const { - tensorflow::mutex_lock lock(mutex_); - - // The version at which an operation was added is simply the handle value of - // the ComputationDataHandle. - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - versioned_handle.version = operation.handle(); - return versioned_handle; -} - -VersionedComputationHandle::Version UserComputation::version() const { - return GetVersionedHandle().version; -} - -namespace { - -// Returns true if the operation type corresponding to the given opcase can be -// the root of the computation. -bool CanBeRoot(const OpRequest::OpCase& op_case) { - switch (op_case) { - case OpRequest::kTraceRequest: - case OpRequest::kSendRequest: - case OpRequest::kOutfeedRequest: - return false; - default: - return true; - } -} - -// Returns a pointer to the operation with the given data handle value in the -// given SessionComputation. -StatusOr LookUpRequest( - int64 handle_value, const SessionComputation& session_computation) { - if (session_computation.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation.requests().at(handle_value); -} - -// Returns the OperationRequest corresponding to the root (result) of the -// session computation. 
-StatusOr<const OperationRequest*> GetRoot(
-    VersionedComputationHandle::Version version,
-    const SessionComputation& session_computation) {
-  TF_RET_CHECK(version > 0);
-  // Not all instructions can be roots. Walk backwards from the operation
-  // indicated by this version until a valid root is found.
-  const OperationRequest* root_request = nullptr;
-  while (version > 0) {
-    TF_ASSIGN_OR_RETURN(root_request,
-                        LookUpRequest(version, session_computation));
-    if (CanBeRoot(root_request->request().op_case())) {
-      break;
-    }
-    version--;
-  }
-  if (version == 0) {
-    return InternalError("Computation contains no root operation");
-  }
-  return root_request;
-}
-
-}  // namespace
-
-StatusOr<std::shared_ptr<const ProgramShape>>
-UserComputation::ComputeProgramShape(
-    VersionedComputationHandle::Version version) const {
-  tensorflow::mutex_lock lock(mutex_);
-
-  TF_RET_CHECK(version > 0 && version < next_handle_value_);
-
-  if (program_shape_ == nullptr || program_shape_version_ != version) {
-    // ProgramShape has not been computed yet, or is for a different
-    // version. Compute it now.
-    TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version));
-
-    auto program_shape = MakeUnique<ProgramShape>();
-    for (int64 request_num = 1; request_num <= version; ++request_num) {
-      const OperationRequest& request =
-          session_computation_.requests().at(request_num);
-      if (request.request().op_case() == OpRequest::kParameterRequest) {
-        const ParameterRequest& parameter_request =
-            request.request().parameter_request();
-        int64 param_no = parameter_request.parameter();
-        // Parameters may be out of order so expand ProgramShape parameters
-        // until it is at least large enough to hold the current parameter
-        // number.
-        while (program_shape->parameters_size() <= param_no) {
-          program_shape->add_parameters();
-          program_shape->add_parameter_names();
-        }
-        *program_shape->mutable_parameters(param_no) = request.output_shape();
-        *program_shape->mutable_parameter_names(param_no) =
-            parameter_request.name();
-      }
-    }
-
-    // The root determines the output shape.
-    TF_ASSIGN_OR_RETURN(const OperationRequest* root_request,
-                        GetRoot(version, session_computation_));
-    *program_shape->mutable_result() = root_request->output_shape();
-    if (ShapeUtil::IsOpaque(program_shape->result())) {
-      return Unimplemented("Computation results cannot be opaque");
-    }
-
-    program_shape_ = std::move(program_shape);
-    program_shape_version_ = version;
-  }
-
-  return program_shape_;
-}
-
-namespace {
-
-// A visitor which checks whether an operation is purely functional, meaning
-// that it doesn't depend on any parameter with an index higher than
-// num_parameters. The visitor walks the computation starting at a given
-// operation and sets is_functional to false iff a parameter or RNG operation
-// is encountered.
-void PureFunctionalVisitor(const SessionComputation& session_computation,
-                           const ComputationDataHandle& handle,
-                           int64 num_parameters, std::set<int64>* visited,
-                           bool* is_functional) {
-  if (visited->count(handle.handle()) != 0 || !*is_functional) {
-    return;
-  }
-
-  const OperationRequest& request =
-      session_computation.requests().at(handle.handle());
-  switch (request.request().op_case()) {
-    case OpRequest::kRngRequest:
-      *is_functional = false;
-      break;
-
-    case OpRequest::kConstantRequest:
-      break;
-
-    case OpRequest::kGetTupleElementRequest: {
-      const GetTupleElementRequest& get_tuple_element_request =
-          request.request().get_tuple_element_request();
-      PureFunctionalVisitor(session_computation,
-                            get_tuple_element_request.operand(),
-                            num_parameters, visited, is_functional);
-      break;
-    }
-
-    case OpRequest::kSliceRequest: {
-      const SliceRequest& slice_request = request.request().slice_request();
-      PureFunctionalVisitor(session_computation, slice_request.operand(),
-                            num_parameters, visited, is_functional);
-      break;
-    }
-
-    case OpRequest::kDynamicSliceRequest: {
-      const DynamicSliceRequest& dynamic_slice_request =
-          request.request().dynamic_slice_request();
-      PureFunctionalVisitor(session_computation,
-                            dynamic_slice_request.operand(), num_parameters,
-                            visited, is_functional);
-      PureFunctionalVisitor(session_computation,
-                            dynamic_slice_request.start_indices(),
-                            num_parameters, visited, is_functional);
-      break;
-    }
-
-    case OpRequest::kDynamicUpdateSliceRequest: {
-      const DynamicUpdateSliceRequest& dynamic_update_slice_request =
-          request.request().dynamic_update_slice_request();
-      PureFunctionalVisitor(session_computation,
-                            dynamic_update_slice_request.operand(),
-                            num_parameters, visited, is_functional);
-      PureFunctionalVisitor(session_computation,
-                            dynamic_update_slice_request.update(),
-                            num_parameters, visited, is_functional);
-      PureFunctionalVisitor(session_computation,
-                            dynamic_update_slice_request.start_indices(),
-                            num_parameters, visited, is_functional);
-      break;
-    }
-
-    case OpRequest::kConcatenateRequest: {
-      const ConcatenateRequest& concatenate_request =
-          request.request().concatenate_request();
-      for (const ComputationDataHandle& handle :
-           concatenate_request.operands()) {
-        PureFunctionalVisitor(session_computation, handle, num_parameters,
-                              visited, is_functional);
-      }
-      break;
-    }
-
-    case OpRequest::kConvolveRequest: {
-      const ConvolveRequest& convolve_request =
-          request.request().convolve_request();
-      PureFunctionalVisitor(session_computation, convolve_request.lhs(),
-                            num_parameters, visited, is_functional);
-      PureFunctionalVisitor(session_computation, convolve_request.rhs(),
-                            num_parameters, visited, is_functional);
-      break;
-    }
-
-    case OpRequest::kFftRequest: {
-      const FftRequest& fft_request = request.request().fft_request();
-      PureFunctionalVisitor(session_computation, fft_request.operand(),
-                            num_parameters, visited, is_functional);
-      break;
-    }
-
-    case OpRequest::kCrossReplicaSumRequest: {
-      // TODO(b/33009255): Implement constant folding for cross replica sum.
- *is_functional = false; - break; - } - - case OpRequest::kInfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kOutfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kHostComputeRequest: { - *is_functional = false; - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself, - // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_functional=false in other similar - // cases since we're already relying on IsConstant to return true. - *is_functional = false; - break; - } - - case OpRequest::kCustomCallRequest: { - *is_functional = false; - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - PureFunctionalVisitor(session_computation, dot_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, dot_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kSendRequest: { - *is_functional = false; - break; - } - - case OpRequest::kRecvRequest: { - *is_functional = false; - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - PureFunctionalVisitor(session_computation, reduce_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, reduce_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - PureFunctionalVisitor(session_computation, - reduce_window_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - reduce_window_request.init_value(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.source(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the select and scatter - // computations themselves. 
- break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - PureFunctionalVisitor(session_computation, broadcast_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - PureFunctionalVisitor(session_computation, reshape_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - PureFunctionalVisitor(session_computation, reverse_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - PureFunctionalVisitor(session_computation, pad_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, pad_request.padding_value(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - if (parameter_request.parameter() >= num_parameters) { - *is_functional = false; - } - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - PureFunctionalVisitor(session_computation, while_request.init(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the condition and body - // computations themselves. - *is_functional = false; - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - PureFunctionalVisitor(session_computation, - conditional_request.predicate(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.true_operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.false_operand(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the true and false computations - // themselves. 
- break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - PureFunctionalVisitor(session_computation, transpose_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - PureFunctionalVisitor(session_computation, unary_op_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.offset(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.scale(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.offset(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.mean(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.variance(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.variance(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.grad_output(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBinaryOpRequest: { - const 
BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - PureFunctionalVisitor(session_computation, binary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, binary_op_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kGatherRequest: { - PureFunctionalVisitor(session_computation, - request.request().gather_request().input(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - request.request().gather_request().gather_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - if (!*is_functional) { - VLOG(1) << "Non-functional: " << request.request().DebugString(); - } - visited->insert(handle.handle()); -} - -} // namespace - -StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, - int64 num_parameters) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the handle is valid. - auto operation_status = LookUpRequest(handle); - if (!operation_status.ok()) { - return operation_status.status(); - } - - bool is_constant = true; - std::set visited; - PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, - &is_constant); - - return is_constant; -} - -std::vector -UserComputation::GetEmbeddedComputations( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(1) - << "GetEmbeddedComputations(" << name() << " " - << VersionedComputationHandle{session_computation_.computation_handle(), - version} - << ")"; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - std::vector computations; - std::vector sorted_handles; - for (const auto& handle_request : session_computation_.requests()) { - sorted_handles.push_back(handle_request.first); - } - std::sort(sorted_handles.begin(), sorted_handles.end()); - for (int64 handle : sorted_handles) { - const auto& handle_request = session_computation_.requests().find(handle); - CHECK(handle_request != session_computation_.requests().end()); - int64 handle_value = handle_request->first; - if (handle_value <= version) { - const OperationRequest& request = handle_request->second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const CallRequest& call_request = request.request().call_request(); - const VersionedComputationHandle versioned_handle = { - call_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kMapRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const MapRequest& map_request = request.request().map_request(); - const VersionedComputationHandle versioned_handle = { - map_request.to_apply(), request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceRequest& reduce_request = - request.request().reduce_request(); - const VersionedComputationHandle versioned_handle = { - reduce_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceWindowRequest: { - CHECK_EQ(1, 
request.embedded_computation_versions_size()); - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - const VersionedComputationHandle versioned_handle = { - reduce_window_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - const VersionedComputationHandle select_versioned_handle = { - select_and_scatter_request.select(), - request.embedded_computation_versions(0)}; - computations.push_back(select_versioned_handle); - const VersionedComputationHandle scatter_versioned_handle = { - select_and_scatter_request.scatter(), - request.embedded_computation_versions(1)}; - computations.push_back(scatter_versioned_handle); - break; - } - - case OpRequest::kWhileRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const WhileRequest& while_request = request.request().while_request(); - const VersionedComputationHandle condition_versioned_handle = { - while_request.condition(), - request.embedded_computation_versions(0)}; - computations.push_back(condition_versioned_handle); - const VersionedComputationHandle body_versioned_handle = { - while_request.body(), request.embedded_computation_versions(1)}; - computations.push_back(body_versioned_handle); - break; - } - - case OpRequest::kConditionalRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - const VersionedComputationHandle true_computation_versioned_handle = { - conditional_request.true_computation(), - request.embedded_computation_versions(0)}; - computations.push_back(true_computation_versioned_handle); - const VersionedComputationHandle false_computation_versioned_handle = - {conditional_request.false_computation(), - request.embedded_computation_versions(1)}; - computations.push_back(false_computation_versioned_handle); - break; - } - - default: - // No embedded computation. 
- break; - } - } - } - VLOG(2) << "Embedded computations: " - << tensorflow::str_util::Join( - computations, ", ", - [](string* out, const VersionedComputationHandle& h) { - out->append(h.ToString()); - }); - return computations; -} - -StatusOr -UserComputation::LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const { - tensorflow::mutex_lock lock(mutex_); - return LookUpRequest(handle); -} - -tensorflow::gtl::optional UserComputation::ParameterMetadata( - int parameter_number) const { - tensorflow::mutex_lock lock(mutex_); - auto it = parameters_.find(parameter_number); - if (it == parameters_.end()) { - return tensorflow::gtl::nullopt; - } - OperationRequest* op = it->second; - return &op->request().metadata(); -} - -Status UserComputation::RemapEmbeddedComputations( - const std::map& old_to_new) { - auto update = [&old_to_new](ComputationHandle* to_update) -> Status { - int64 old = to_update->handle(); - auto it = old_to_new.find(old); - if (it == old_to_new.end()) { - string mapping = tensorflow::str_util::Join( - old_to_new, ", ", - [](string* out, std::pair element) { - tensorflow::strings::Appendf(out, "%lld:%lld", element.first, - element.second.handle()); - }); - return NotFound( - "could not find referenced (old) computation handle in mapping: " - "%lld; mapping: {%s}", - old, mapping.c_str()); - } - VLOG(2) << "remapping " << old << " to " << it->second.handle(); - *to_update = it->second; - return Status::OK(); - }; - TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle())); - for (auto& handle_request : *session_computation_.mutable_requests()) { - OperationRequest& request = handle_request.second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - CallRequest* call_request = - request.mutable_request()->mutable_call_request(); - TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply())); - break; - } - case OpRequest::kMapRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - MapRequest* map_request = - request.mutable_request()->mutable_map_request(); - TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceRequest* reduce_request = - request.mutable_request()->mutable_reduce_request(); - TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceWindowRequest* reduce_window_request = - request.mutable_request()->mutable_reduce_window_request(); - TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply())); - break; - } - case OpRequest::kSelectAndScatterRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - SelectAndScatterRequest* select_and_scatter_request = - request.mutable_request()->mutable_select_and_scatter_request(); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_select())); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_scatter())); - break; - } - case OpRequest::kWhileRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - WhileRequest* while_request = - request.mutable_request()->mutable_while_request(); - TF_RETURN_IF_ERROR(update(while_request->mutable_condition())); - TF_RETURN_IF_ERROR(update(while_request->mutable_body())); - break; - } - 
case OpRequest::kConditionalRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - ConditionalRequest* conditional_request = - request.mutable_request()->mutable_conditional_request(); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_true_computation())); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_false_computation())); - break; - } - default: - // No embedded computation. - TF_RET_CHECK(0 == request.embedded_computation_versions_size()); - break; - } - } - return Status::OK(); -} - -SessionComputation UserComputation::CloneSessionComputation( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - SessionComputation result = session_computation_; - // Erase all the requests that exceed the version specified. - // There's no lower_bound method on tensorflow::protobuf::Map so we iterate - // all the elements. - auto it = result.mutable_requests()->begin(); - while (it != result.mutable_requests()->end()) { - if (it->first > version) { - it = result.mutable_requests()->erase(it); - } else { - ++it; - } - } - return result; -} - -StatusOr UserComputation::LookUpRequest( - const ComputationDataHandle& handle) const { - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation_.requests().at(handle_value); -} - -Status UserComputation::CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const { - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - // Determine number of parameter inputs at the given version. - std::map parameter_requests; - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - // Duplicate parameters should be checked when parameter requests are - // added. - TF_RET_CHECK(0 == - parameter_requests.count(parameter_request.parameter())); - parameter_requests[parameter_request.parameter()] = ¶meter_request; - } - } - - for (int64 i = 0; i < parameter_requests.size(); ++i) { - auto it = parameter_requests.find(i); - if (it == parameter_requests.end()) { - return FailedPrecondition( - "computation %s does not have all its parameters populated " - "sequentially, missing parameter %lld", - name_.c_str(), i); - } - } - - return Status::OK(); -} - -namespace { - -// Helper class which builds an HLO computation from a SessionComputation. To -// construct the HLO computation, the SessionComputation graph is walked in -// DFS order lowering each OperationRequest to an HLO instruction. 
-class ComputationLowerer {
- public:
-  static StatusOr<std::unique_ptr<HloComputation>> Lower(
-      const string& computation_name,
-      const SessionComputation& session_computation,
-      VersionedComputationHandle::Version version,
-      UserComputation::HloComputationResolver hlo_resolver,
-      const DebugOptions& debug_options,
-      bool include_unreachable_instructions) {
-    ComputationLowerer lowerer(computation_name, session_computation, version,
-                               std::move(hlo_resolver), debug_options,
-                               include_unreachable_instructions);
-    return lowerer.Lower();
-  }
-
- private:
-  ComputationLowerer(const string& computation_name,
-                     const SessionComputation& session_computation,
-                     VersionedComputationHandle::Version version,
-                     UserComputation::HloComputationResolver hlo_resolver,
-                     const DebugOptions& debug_options,
-                     bool include_unreachable_instructions)
-      : hlo_builder_(computation_name),
-        session_computation_(session_computation),
-        version_(version),
-        hlo_resolver_(std::move(hlo_resolver)),
-        debug_options_(debug_options),
-        include_unreachable_instructions_(include_unreachable_instructions) {}
-
-  // Build an HLO computation from the SessionComputation at the given
-  // version.
-  StatusOr<std::unique_ptr<HloComputation>> Lower();
-
- private:
-  // Traverses the computation 'root' using a DFS, calling 'visit' in
-  // postorder.
-  void TraversePostorder(
-      const ComputationDataHandle& root,
-      std::unordered_map<int64, HloInstruction*>* visited,
-      const std::function<void(const ComputationDataHandle&)>& visit);
-
-  // DFS visitor of the UserComputation operations which lowers the operations
-  // to HLO instructions.
-  void Visit(const ComputationDataHandle& handle,
-             std::unordered_map<int64, HloInstruction*>* instructions);
-
-  // Resolves a ComputationHandle and Version to a previously lowered
-  // HloComputation using the hlo_resolver_ function.
-  HloComputation* ResolveComputation(
-      const ComputationHandle& handle,
-      VersionedComputationHandle::Version version);
-
-  // This function takes an input value which is being implicitly broadcast
-  // into an output shape and figures out the right kBroadcast instruction(s)
-  // necessary to replicate the implicit broadcast semantics explicitly.
-  HloInstruction* ImplicitBroadcastToExplicitBroadcast(
-      HloInstruction* operand, const Shape& output_shape);
-
-  HloComputation::Builder hlo_builder_;
-  const SessionComputation& session_computation_;
-  const VersionedComputationHandle::Version version_;
-  const UserComputation::HloComputationResolver hlo_resolver_;
-  const DebugOptions& debug_options_;
-  const bool include_unreachable_instructions_;
-};
-
-// Calls 'apply' on each operand of 'request'.
-static void ForEachOperand( - const OperationRequest& request, - const std::function& apply) { - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - for (const ComputationDataHandle& param : rng_request.parameter()) { - apply(param); - } - break; - } - - case OpRequest::kConstantRequest: - break; - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - apply(get_tuple_element_request.operand()); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - apply(slice_request.operand()); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - apply(dynamic_slice_request.operand()); - apply(dynamic_slice_request.start_indices()); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - apply(dynamic_update_slice_request.operand()); - apply(dynamic_update_slice_request.update()); - apply(dynamic_update_slice_request.start_indices()); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - apply(convolve_request.lhs()); - apply(convolve_request.rhs()); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - apply(fft_request.operand()); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - - apply(batch_norm_training_request.operand()); - apply(batch_norm_training_request.scale()); - apply(batch_norm_training_request.offset()); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - - apply(batch_norm_inference_request.operand()); - apply(batch_norm_inference_request.scale()); - apply(batch_norm_inference_request.offset()); - apply(batch_norm_inference_request.mean()); - apply(batch_norm_inference_request.variance()); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - apply(batch_norm_grad_request.operand()); - apply(batch_norm_grad_request.scale()); - apply(batch_norm_grad_request.mean()); - apply(batch_norm_grad_request.variance()); - apply(batch_norm_grad_request.grad_output()); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - apply(cross_replica_sum_request.operand()); - break; - } - - case OpRequest::kInfeedRequest: - break; - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - apply(outfeed_request.operand()); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& 
map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - apply(reduce_request.operand()); - apply(reduce_request.init_value()); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - apply(reduce_window_request.operand()); - apply(reduce_window_request.init_value()); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - apply(select_and_scatter_request.operand()); - apply(select_and_scatter_request.source()); - apply(select_and_scatter_request.init_value()); - - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - apply(broadcast_request.operand()); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - apply(reshape_request.operand()); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - apply(transpose_request.operand()); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - apply(reverse_request.operand()); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - apply(pad_request.operand()); - apply(pad_request.padding_value()); - break; - } - - case OpRequest::kRecvRequest: - case OpRequest::kParameterRequest: - break; - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - apply(while_request.init()); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - apply(conditional_request.predicate()); - apply(conditional_request.true_operand()); - apply(conditional_request.false_operand()); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - apply(ternary_op_request.lhs()); - apply(ternary_op_request.rhs()); - apply(ternary_op_request.ehs()); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - for (const ComputationDataHandle& operand : 
cc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& hc_request = - request.request().host_compute_request(); - for (const ComputationDataHandle& operand : hc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - apply(dot_request.rhs()); - apply(dot_request.lhs()); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - apply(unary_op_request.operand()); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - apply(binary_op_request.rhs()); - apply(binary_op_request.lhs()); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - apply(reduce_precision_request.operand()); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - apply(trace_request.operand()); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - apply(send_request.operand()); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - apply(gather_request.input()); - apply(gather_request.gather_indices()); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } -} - -void ComputationLowerer::TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit) { - // Stack containing {handle, enter} pairs. The 'enter' value describes whether - // we are entering or leaving 'handle'. - std::stack> work; - work.push({root, true}); - while (!work.empty()) { - ComputationDataHandle handle; - bool enter; - std::tie(handle, enter) = work.top(); - work.pop(); - - if (enter) { - // We are entering 'handle'. The first time we enter 'handle', we add it - // to 'visited' with a nullptr value. If 'handle' is already in 'visited', - // we do not visit it again. This algorithm only uses the presence of - // a handle in 'visited', but we use a map so we can use the same data - // structure to store the HloInstruction outputs. - if (visited->emplace(handle.handle(), nullptr).second) { - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - // Push the corresponding 'leave' action onto the stack, followed by - // the operands. - work.push({handle, false}); - ForEachOperand(request, [&work](const ComputationDataHandle& child) { - work.push({child, true}); - }); - } - } else { - // We are leaving 'handle'. We have visited the operands of 'handle', and - // now can visit the 'handle' itself. - visit(handle); - } - } -} - -StatusOr> ComputationLowerer::Lower() { - // Map from ComputationDataHandle to HLO instruction. Serves as a record of - // which operations have been visited as well as a cache for looking up - // ComputationDataHandles as HloInstructions. 
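The `TraversePostorder` routine above avoids recursion by pushing each handle twice: once as an "enter" marker (which pushes its operands) and once as a "leave" marker (which runs the visit). As a generic illustration of that same pattern — using a hypothetical `Node` type and `operands` field rather than the XLA types — a stack-based post-order walk can be sketched like this (C++17):

```cpp
#include <stack>
#include <unordered_set>
#include <utility>
#include <vector>

struct Node {
  int id;
  std::vector<Node*> operands;
};

// Iterative post-order walk: every node is pushed twice, first as an "enter"
// marker (push its operands), later as a "leave" marker (emit the node after
// all of its operands have been emitted).
template <typename Visit>
void PostOrder(Node* root, Visit visit) {
  std::stack<std::pair<Node*, bool>> work;  // {node, enter?}
  std::unordered_set<int> seen;             // nodes already entered
  work.push({root, true});
  while (!work.empty()) {
    auto [node, enter] = work.top();
    work.pop();
    if (enter) {
      if (seen.insert(node->id).second) {
        work.push({node, false});  // the matching "leave" action
        for (Node* operand : node->operands) work.push({operand, true});
      }
    } else {
      visit(node);  // operands were already visited
    }
  }
}
```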
- std::unordered_map instructions; - - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version_, session_computation_)); - - auto visit = [&](const ComputationDataHandle& handle) { - Visit(handle, &instructions); - }; - TraversePostorder(root_request->output_handle(), &instructions, visit); - HloInstruction* hlo_root = - instructions.at(root_request->output_handle().handle()); - - if (include_unreachable_instructions_) { - // Iterate through all computation data handles, and visit any unvisited - // operations. - for (int64 request_num = 1; request_num <= version_; ++request_num) { - TF_ASSIGN_OR_RETURN(const OperationRequest* request, - LookUpRequest(request_num, session_computation_)); - TraversePostorder(request->output_handle(), &instructions, visit); - } - } - - return hlo_builder_.Build(hlo_root); -} - -HloComputation* ComputationLowerer::ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version) { - const VersionedComputationHandle checked_handle = {handle, version}; - return hlo_resolver_(checked_handle); -} - -HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape) { - auto fadd = [this](std::unique_ptr x) { - return hlo_builder_.AddInstruction(std::move(x)); - }; - return fadd( - HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); -} - -void ComputationLowerer::Visit( - const ComputationDataHandle& handle, - std::unordered_map* instructions) { - CHECK_LE(handle.handle(), version_); - CHECK(instructions->at(handle.handle()) == nullptr); - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - auto add_instruction = [&](std::unique_ptr instruction) { - HloInstruction* hlo_instruction = - hlo_builder_.AddInstruction(std::move(instruction)); - hlo_instruction->set_metadata(request.request().metadata()); - if (request.request().has_sharding()) { - OpSharding op_sharding = request.request().sharding(); - hlo_instruction->set_sharding( - HloSharding::FromProto(op_sharding).ValueOrDie()); - } - return hlo_instruction; - }; - auto lookup_instruction = [&](const ComputationDataHandle& handle) { - return instructions->at(handle.handle()); - }; - HloInstruction* hlo_instruction; - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - std::vector parameters; - for (const ComputationDataHandle& param : rng_request.parameter()) { - parameters.push_back(lookup_instruction(param)); - } - hlo_instruction = add_instruction(HloInstruction::CreateRng( - request.output_shape(), rng_request.distribution(), parameters)); - break; - } - - case OpRequest::kConstantRequest: { - const ConstantRequest& constant_request = - request.request().constant_request(); - hlo_instruction = add_instruction(HloInstruction::CreateConstant( - Literal::CreateFromProto(constant_request.literal()) - .ConsumeValueOrDie())); - break; - } - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - HloInstruction* operand = - lookup_instruction(get_tuple_element_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, get_tuple_element_request.index())); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - 
HloInstruction* operand = lookup_instruction(slice_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSlice( - request.output_shape(), operand, - AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_slice_request.operand()); - HloInstruction* start_indices = - lookup_instruction(dynamic_slice_request.start_indices()); - - hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_update_slice_request.operand()); - HloInstruction* update = - lookup_instruction(dynamic_update_slice_request.update()); - HloInstruction* start_indices = - lookup_instruction(dynamic_update_slice_request.start_indices()); - hlo_instruction = - add_instruction(HloInstruction::CreateDynamicUpdateSlice( - request.output_shape(), operand, update, start_indices)); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( - request.output_shape(), operands, concatenate_request.dimension())); - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - HloInstruction* lhs = lookup_instruction(convolve_request.lhs()); - HloInstruction* rhs = lookup_instruction(convolve_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - HloInstruction* operand = lookup_instruction(fft_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateFft( - request.output_shape(), operand, fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - HloInstruction* lhs = lookup_instruction(dot_request.lhs()); - HloInstruction* rhs = lookup_instruction(dot_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateDot( - request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - HloInstruction* operand = - lookup_instruction(cross_replica_sum_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), {operand})); - break; - } - - case OpRequest::kInfeedRequest: { - const InfeedRequest& infeed_request = 
request.request().infeed_request(); - hlo_instruction = add_instruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); - break; - } - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - HloInstruction* operand = lookup_instruction(outfeed_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( - outfeed_request.shape(), operand, outfeed_request.outfeed_config())); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - std::vector operands; - for (const ComputationDataHandle& handle : map_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version map_version = - request.embedded_computation_versions(0); - HloComputation* map_computation = - ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = add_instruction(HloInstruction::CreateMap( - request.output_shape(), operands, map_computation)); - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - HloInstruction* operand = lookup_instruction(reduce_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_version = - request.embedded_computation_versions(0); - HloComputation* reduce_computation = - ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - HloInstruction* operand = - lookup_instruction(reduce_window_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_window_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_window_version = - request.embedded_computation_versions(0); - HloComputation* reduce_window_computation = ResolveComputation( - reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - HloInstruction* operand = - lookup_instruction(select_and_scatter_request.operand()); - HloInstruction* source = - lookup_instruction(select_and_scatter_request.source()); - HloInstruction* init_value = - lookup_instruction(select_and_scatter_request.init_value()); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version select_version = - request.embedded_computation_versions(0); - VersionedComputationHandle::Version scatter_version = - request.embedded_computation_versions(1); - HloComputation* select_computation = ResolveComputation( - select_and_scatter_request.select(), select_version); - HloComputation* 
scatter_computation = ResolveComputation( - select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( - request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_training_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_training_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_training_request.offset()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( - request.output_shape(), operand, scale, offset, - batch_norm_training_request.epsilon(), - batch_norm_training_request.feature_index())); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_inference_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_inference_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_inference_request.offset()); - HloInstruction* mean = - lookup_instruction(batch_norm_inference_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_inference_request.variance()); - - hlo_instruction = - add_instruction(HloInstruction::CreateBatchNormInference( - request.output_shape(), operand, scale, offset, mean, variance, - batch_norm_inference_request.epsilon(), - batch_norm_inference_request.feature_index())); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - HloInstruction* operand = - lookup_instruction(batch_norm_grad_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_grad_request.scale()); - HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_grad_request.variance()); - HloInstruction* grad_output = - lookup_instruction(batch_norm_grad_request.grad_output()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad( - request.output_shape(), operand, scale, mean, variance, grad_output, - batch_norm_grad_request.epsilon(), - batch_norm_grad_request.feature_index())); - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - HloInstruction* operand = lookup_instruction(broadcast_request.operand()); - std::vector broadcast_dimensions; - // The client-level broadcast instruction just appends dimensions on the - // left (adds lowest numbered dimensions). The HLO broadcast op is more - // flexible and can add new dimensions anywhere. The broadcast_dimensions - // maps operand dimensions to dimensions in the broadcast output, so - // to append dimensions on the left the broadcast_dimensions should just - // be the n highest dimension numbers of the output shape where n is - // the number of input dimensions. 
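As a concrete reading of that comment (shapes chosen purely for illustration): broadcasting a rank-2 operand into a rank-4 output appends two new major dimensions on the left, so the operand's dimensions map to the two highest output dimension numbers, i.e. `broadcast_dimensions = {2, 3}`. A minimal helper computing that mapping the same way as the loop below might look like:

```cpp
#include <cstdint>
#include <vector>

// Illustration only: compute the HLO broadcast_dimensions for a client-level
// broadcast that appends new dimensions on the left of the operand shape.
std::vector<int64_t> AppendOnLeftBroadcastDims(int64_t operand_rank,
                                               int64_t output_rank) {
  std::vector<int64_t> dims;
  dims.reserve(operand_rank);
  for (int64_t i = 0; i < operand_rank; ++i) {
    dims.push_back(i + output_rank - operand_rank);
  }
  return dims;  // e.g. operand_rank = 2, output_rank = 4 -> {2, 3}
}
```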
- broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { - broadcast_dimensions.push_back(i + - ShapeUtil::Rank(request.output_shape()) - - ShapeUtil::Rank(operand->shape())); - } - hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - HloInstruction* operand = lookup_instruction(reshape_request.operand()); - HloInstruction* transposed; - if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { - transposed = operand; - } else { - transposed = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))); - } - hlo_instruction = add_instruction( - HloInstruction::CreateReshape(request.output_shape(), transposed)); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - HloInstruction* operand = lookup_instruction(transpose_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(transpose_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(transpose_request.dimensions()))); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - HloInstruction* operand = lookup_instruction(reverse_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - HloInstruction* operand = lookup_instruction(pad_request.operand()); - HloInstruction* padding_value = - lookup_instruction(pad_request.padding_value()); - hlo_instruction = add_instruction(HloInstruction::CreatePad( - request.output_shape(), operand, padding_value, - pad_request.padding_config())); - break; - } - - case OpRequest::kRecvRequest: { - const RecvRequest& recv_request = request.request().recv_request(); - HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( - request.output_shape(), recv_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - hlo_instruction = add_instruction(HloInstruction::CreateParameter( - parameter_request.parameter(), request.output_shape(), - parameter_request.name())); - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateConvert(request.output_shape(), operand)); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = 
add_instruction(HloInstruction::CreateBitcastConvert( - request.output_shape(), operand)); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version condition_version = - request.embedded_computation_versions(0); - HloComputation* condition = - ResolveComputation(while_request.condition(), condition_version); - VersionedComputationHandle::Version body_version = - request.embedded_computation_versions(1); - HloComputation* body = - ResolveComputation(while_request.body(), body_version); - HloInstruction* init = lookup_instruction(while_request.init()); - hlo_instruction = add_instruction(HloInstruction::CreateWhile( - request.output_shape(), condition, body, init)); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version true_computation_version = - request.embedded_computation_versions(0); - HloComputation* true_computation = ResolveComputation( - conditional_request.true_computation(), true_computation_version); - VersionedComputationHandle::Version false_computation_version = - request.embedded_computation_versions(1); - HloComputation* false_computation = ResolveComputation( - conditional_request.false_computation(), false_computation_version); - HloInstruction* predicate = - lookup_instruction(conditional_request.predicate()); - HloInstruction* true_operand = - lookup_instruction(conditional_request.true_operand()); - HloInstruction* false_operand = - lookup_instruction(conditional_request.false_operand()); - hlo_instruction = add_instruction(HloInstruction::CreateConditional( - request.output_shape(), predicate, true_operand, true_computation, - false_operand, false_computation)); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); - HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); - auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - !ShapeUtil::IsTuple(request.output_shape())) { - if (!ShapeUtil::IsTuple(lhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. 
- lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(rhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(ehs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { - ehs = - ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); - } - } - - hlo_instruction = add_instruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - auto hlo_opcode = - VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = add_instruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - std::vector operands; - for (const ComputationDataHandle& handle : call_request.operands()) { - operands.push_back(lookup_instruction(handle)); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version call_version = - request.embedded_computation_versions(0); - HloComputation* call_computation = - ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = add_instruction(HloInstruction::CreateCall( - request.output_shape(), operands, call_computation)); - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - std::vector operands; - for (const ComputationDataHandle& operand : cc_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& host_compute_request = - request.request().host_compute_request(); - std::vector operands; - for (const ComputationDataHandle& operand : - host_compute_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - auto output_shape = host_compute_request.shape(); - auto channel_name = host_compute_request.channel_name(); - auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); - hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( - output_shape, operands, channel_name, cost_estimate_ns)); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - HloInstruction* operand = lookup_instruction(unary_op_request.operand()); - auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = add_instruction(HloInstruction::CreateUnary( - request.output_shape(), hlo_opcode, operand)); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - HloInstruction* lhs = lookup_instruction(binary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(binary_op_request.rhs()); - auto hlo_opcode = 
BinaryOperationToHloOpcode(binary_op_request.binop()); - if (binary_op_request.broadcast_dimensions_size() > 0 && - ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) { - // Emit a broadcast instruction to perform the "broadcast in dimension" - // operation. - HloInstruction* operand_to_broadcast = - ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs - : rhs; - CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()), - binary_op_request.broadcast_dimensions().size()); - - // Construct the bounds of the shape of the kBroadcast instruction - // responsible for the in-dimension broadcast. - std::vector output_dimensions; - for (int64 size : request.output_shape().dimensions()) { - output_dimensions.push_back(size); - } - for (int64 operand_dim = 0; - operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape()); - ++operand_dim) { - int64 output_dim = - binary_op_request.broadcast_dimensions()[operand_dim]; - output_dimensions[output_dim] = - operand_to_broadcast->shape().dimensions(operand_dim); - } - - Shape broadcast_shape = ShapeUtil::MakeShape( - operand_to_broadcast->shape().element_type(), output_dimensions); - - // The broadcast semantics of a client-level binary op broadcast is - // identical to the HLO broadcast semantics so the broadcast_dimensions - // field can just be passed to the instruction builder. - HloInstruction* broadcasted_operand = - add_instruction(HloInstruction::CreateBroadcast( - broadcast_shape, operand_to_broadcast, - AsInt64Slice(binary_op_request.broadcast_dimensions()))); - - lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; - rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; - } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. 
- lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - } - hlo_instruction = add_instruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - HloInstruction* operand = - lookup_instruction(reduce_precision_request.operand()); - auto exponent_bits = reduce_precision_request.exponent_bits(); - auto mantissa_bits = reduce_precision_request.mantissa_bits(); - hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( - request.output_shape(), operand, exponent_bits, mantissa_bits)); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - HloInstruction* operand = lookup_instruction(trace_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateTrace(trace_request.tag(), operand)); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - HloInstruction* operand = lookup_instruction(send_request.operand()); - HloInstruction* send = add_instruction(HloInstruction::CreateSend( - operand, send_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - HloInstruction* input_operand = - lookup_instruction(gather_request.input()); - HloInstruction* gather_indices_operand = - lookup_instruction(gather_request.gather_indices()); - std::vector window_bounds; - c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); - hlo_instruction = add_instruction(HloInstruction::CreateGather( - request.output_shape(), input_operand, gather_indices_operand, - gather_request.dimension_numbers(), window_bounds)); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - (*instructions)[handle.handle()] = hlo_instruction; -} // NOLINT(readability/fn_size) - -} // namespace - -StatusOr> UserComputation::BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(2) << "Building HloComputation from UserComputation " << name_ - << " at version " << version; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - ComputationLowerer::Lower( - tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), debug_options, - include_unreachable_instructions)); - - return std::move(hlo_computation); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h deleted file mode 100644 index 5544c868fe905c..00000000000000 --- a/tensorflow/compiler/xla/service/user_computation.h +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// A UserComputation is the built-up computation that users create via the -// XLA Service interface. -// -// The XLA service adds instructions to a user computation via this -// interface. The state of the computation is stored as a SessionComputation -// proto which holds a record of all operation-building requests received by the -// XLA service. -// -// UserComputations are lowered to HloComputations which are passed to the high -// level compiler interface. -class UserComputation { - public: - // Factory used when restoring a computation from serialized session - // computation (computation snapshot) data. Remaps any references to - // computation handle via the old_to_new mapping. - // - // An error will occur if the old_to_new mapping cannot resolve a reference to - // a computation that is present in session_computation. - static StatusOr> MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new); - - // Creates an empty computation with the given name and computation handle. - explicit UserComputation(const string& name, const ComputationHandle& handle); - - // Enqueues a parameter-retrieving instruction onto this user computation. - // Returns an error status if the parameter number is already registered with - // different values. - StatusOr AddParameterInstruction( - const ParameterRequest& parameter_request); - - // Enqueues a pad instruction onto this user computation. - StatusOr AddPadInstruction( - const PadRequest& pad_request); - - // Enqueues a tracing instruction onto this user computation. - // Returns an error status if the operand cannot be resolved. - Status AddTraceInstruction(const TraceRequest& trace_request); - - // Enqueues a random number generation instruction onto this user computation. - StatusOr AddRngInstruction( - const RngRequest& rng_request); - - // Enqueues a unary instruction onto this user computation. - // Returns an error status if the operand index is out of bounds. 
- StatusOr AddUnaryInstruction( - const UnaryOpRequest& unary_request); - - // Enqueues a batch norm training instruction onto this user computation. - StatusOr AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request); - - // Enqueues a batch norm inference instruction onto this user computation. - StatusOr AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request); - - // Enqueues a batch norm grad instruction onto this user computation. - StatusOr AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request); - - // Enqueues a binary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddBinaryInstruction( - const BinaryOpRequest& binary_request); - - // Enqueues a ternary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddTernaryInstruction( - const TernaryOpRequest& ternary_request); - - // Enqueues a variadic instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddVariadicInstruction( - const VariadicOpRequest& variadic_request); - - // Enqueues a constant instruction onto this user computation. - StatusOr AddConstantInstruction( - const ConstantRequest& constant_request); - - // Enqueues a get tuple element instruction onto this user computation. - StatusOr AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request); - - // Enqueues a map instruction onto this user computation. - StatusOr AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation); - - // Enqueues a reduce-precision instruction onto this user computation. - StatusOr AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request); - - // Enqueues a convolution instruction onto this user computation. - StatusOr AddConvolveInstruction( - const ConvolveRequest& convolve_request); - - // Enqueues an FFT instruction onto this user computation. - StatusOr AddFftInstruction( - const FftRequest& fft_request); - - // Enqueues a cross replica sum instruction onto this user computation. - StatusOr AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request); - - // Enqueues an infeed instruction onto this user computation. - StatusOr AddInfeedInstruction( - const InfeedRequest& infeed_request); - - // Enqueues an outfeed instruction onto this user computation. - StatusOr AddOutfeedInstruction( - const OutfeedRequest& outfeed_request); - - // Enqueues a host compute instruction onto this user computation. - StatusOr AddHostComputeInstruction( - const HostComputeRequest& host_compute_request); - - // Enqueues a call instruction onto this user computation. - StatusOr AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation); - - // Enqueues a custom call instruction onto this user computation. - StatusOr AddCustomCallInstruction( - const CustomCallRequest& custom_call_request); - - // Enqueues a dot instruction onto this user computation. - StatusOr AddDotInstruction( - const DotRequest& dot_request); - - // Enqueues a broadcast instruction onto this user computation. - StatusOr AddBroadcastInstruction( - const BroadcastRequest& broadcast_request); - - // Enqueues a reshape instruction onto this user computation. 
- StatusOr AddReshapeInstruction( - const ReshapeRequest& reshape_request); - - // Enqueues a transpose instruction onto this user computation. - StatusOr AddTransposeInstruction( - const TransposeRequest& transpose_request); - - // Enqueues a slice instruction onto this user computation. - StatusOr AddSliceInstruction( - const SliceRequest& slice_request); - - // Enqueues a dynamic slice instruction onto this user computation. - StatusOr AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request); - - // Enqueues a dynamic update slice instruction onto this user computation. - StatusOr AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request); - - // Enqueues a concatenate instruction onto this user computation. - StatusOr AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request); - - // Enqueues a convert instruction onto this user computation. - StatusOr AddConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a bitcast element instruction onto this user computation. - StatusOr AddBitcastConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a reduce instruction onto this user computation. - StatusOr AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation); - - // Enqueues a windowed reduce instruction onto this user computation. - StatusOr AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation); - - // Enqueues a select-and-scatter instruction onto this user - // computation. - StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation); - - // Enqueues a reverse instruction onto this user computation. - StatusOr AddReverseInstruction( - const ReverseRequest& reverse_request); - - // Enqueues a while instruction onto this user computation. - StatusOr AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation); - - // Enqueues a conditional instruction on this user computation. - StatusOr AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation); - - // Enqueues a Send instruction onto this user computation. - StatusOr AddSendInstruction( - const SendRequest& send_request); - - // Enqueues a Recv instruction onto this user computation. - StatusOr AddRecvInstruction( - const RecvRequest& recv_request); - - // Enqueues a Gather instruction onto this user computation. - StatusOr AddGatherInstruction( - const GatherRequest& gather_request); - - // Returns the user-provided name of this user computation, which is provided - // via the XLA computation-building API. - const string& name() const { return name_; } - - // Subsequent executions of this computation will compute the value - // represented by handle, rather than the last expression enqueued - // on the computation. - Status SetReturnValue(const ComputationDataHandle& handle); - - // Return a versioned handle for this computation. - VersionedComputationHandle GetVersionedHandle() const; - - // Return a versioned handle for this computation with a version equal to the - // point at which given operation was added to the computation. 
- VersionedComputationHandle GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const; - - // Return a version value representing the current state of the - // computation. - VersionedComputationHandle::Version version() const; - - // Computes and returns the program shape for the user computation -- gathers - // parameters and result type into a single proto. A shared_ptr is used - // because the returned pointer refers to an internally cached value which may - // be discarded by the UserComputation object. This avoid unnecessary copies. - // - // If the parameter space is not dense (i.e. there are holes in the parameter - // numbers provided) then an error status is returned. - StatusOr> ComputeProgramShape( - VersionedComputationHandle::Version version) const; - - // Returns true if the given data handle does not depend on any parameter with - // index higher then num_parameters. That is, the value can be computed at - // compile time if we know the first num_parameters arguments. - StatusOr IsConstant(const ComputationDataHandle& handle, - int64 num_parameters); - - // Returns the output shape of the operation indicated by the given handle. - StatusOr GetShape(const ComputationDataHandle& handle); - - // Sets metadata on the Hlo instruction referenced by the given handle. - Status SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata); - - // Sets the device assignment on the Hlo instruction referenced by 'handle'. - Status SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding); - - // Builds a HLO computation from the UserComputation. The parameter "resolver" - // is a function which returns a pointer to the HloComputation corresponding - // to the given ComputationHandle at the given version. The resolver is used - // for operations, such as map, which call other computations and need a - // pointer to the called HloComputation to construct the respective HLO - // instructions. If include_unreachable_instructions is true, then - // instructions which are not reachable from the root are lowered into - // HloInstructions. - using HloComputationResolver = - std::function; - StatusOr> BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions = true) const; - - // Return a vector containing the embedded computations used by this - // UserComputation. Only embedded computations which are called directly by - // this UserComputation are included. That is, the transitive closure of - // embedded computations is not included. - std::vector GetEmbeddedComputations( - VersionedComputationHandle::Version version) const; - - // Returns the number of OperationRequest objects in this UserComputation. - // The 'version' of a computation is identical to the number of - // OperationRequests in the UserComputation. - int64 request_count(VersionedComputationHandle::Version version) const { - return version; - } - - // Returns a copy of the internal session state for this computation -- this - // is useful for serializing the guts of a user computation, though references - // to other handles (e.g. referred-to computations) must be handled with care - // in the serialization / de-serialization process. 
- SessionComputation CloneSessionComputation( - VersionedComputationHandle::Version version) const; - - // Warning: typically we don't want to look up computation data handles until - // the computation is finished being built, for consistency purposes. We - // expose this routine for error reporting purposes so that we can provide - // more meaningful error messages from the XLA service layer. - // - // Returns the operation request that the handle comes from. - StatusOr LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const; - - // Retrieves the parameter metadata for the given parameter number. - // - // If the parameter number is invalid for this computation, nullopt is - // returned. When the return value has_value(), nullptr will never be - // the held value. - tensorflow::gtl::optional ParameterMetadata( - int parameter_number) const; - - private: - // Warning: dangerous mutating operation that doesn't respect versioning. - // This is only used at initialization time when constructing from a - // SessionComputation a la MakeWithRemapping. - // - // Remaps references to old computations (with handle values in the keys of - // old_to_new) to the computation handle given in the values. This is useful - // when loading computations from snapshots, to finish initialization, before - // the user computation is released into the wild. - Status RemapEmbeddedComputations( - const std::map& old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Returns the OperationRequest corresponding to the given handle. - StatusOr LookUpRequest( - const ComputationDataHandle& handle) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Creates a new ComputationDataHandle with the next available handle value. - ComputationDataHandle CreateComputationDataHandle() - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Checks whether the parameter numbers of the parameter operations are - // contiguous starting from zero. Returns appropriate error status if not. - Status CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - VersionedComputationHandle GetVersionedHandleInternal() const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Name of the computation. - string name_; - - mutable tensorflow::mutex mutex_; - - // State of the computation as a record of all operation-building requests. - SessionComputation session_computation_ GUARDED_BY(mutex_); - - // Mapping from parameter number to operation request containing the - // respective ParameterRequest. - std::map parameters_ GUARDED_BY(mutex_); - - // The next ComputationDataHandle value to assign. Handle values are assigned - // sequentially. - int64 next_handle_value_ GUARDED_BY(mutex_); - - // If handle_to_return_.has_handle() then an Execution of this Computation - // will compute the value represented by handle_to_return_, otherwise it will - // compute the value of (next_handle_value_ - 1). - ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); - - // Memoized ProgramShape and its version. A shared_ptr is used because - // references to this object are returned by ComputeProgramShape. 
- mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; - mutable std::shared_ptr program_shape_ GUARDED_BY(mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc deleted file mode 100644 index 2fa163953f638c..00000000000000 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -using UserComputationTest = ::testing::Test; - -TEST_F(UserComputationTest, SimpleComputation) { - const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); - const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2}); - - // Build a simple three operation computatation: - // - // %constant = Constant({123, 42}) - // %param = Param(0) - // %outfeed = Outfeed(%constant) - // - // Build the computation at two different versions and check invariants. - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest constant_request; - *constant_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, - computation.AddConstantInstruction(constant_request)); - - ParameterRequest param_request; - *param_request.mutable_shape() = kScalarShape; - param_request.set_parameter(0); - param_request.set_name("param0"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, - computation.AddParameterInstruction(param_request)); - OpMetadata metadata; - metadata.set_op_name("meta"); - TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); - - OutfeedRequest outfeed_request; - *outfeed_request.mutable_operand() = constant_handle; - *outfeed_request.mutable_shape() = kVectorShape; - outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, - computation.AddOutfeedInstruction(outfeed_request)); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - { - // Test the computation at the latest version. 
In this case, the most - // recently added operation is an outfeed. However, the outfeed is not the - // root because outfeeds cannot be the root of a computation. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Program shape should have a single scalar parameter and scalar - // result. The outfeed instruction should not affect the program shape. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(latest_version.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - DebugOptions())); - // There should be one HloInstruction per UserComputation operation. - EXPECT_EQ(3, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - - { - // Test the computation at the version right after the parameter instruction - // is added. - VersionedComputationHandle version_at_param = - computation.GetVersionedHandleAtOperation(param_handle); - - // Program shape should have a single scalar parameter, and scalar result. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(version_at_param.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // There should be two instructions, one for the constant and one for the - // parameter. The outfeed instruction should not be included. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(version_at_param.version, hlo_resolver, - DebugOptions())); - EXPECT_EQ(2, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - { - // Test the computation at the latest version, but lowered with - // include_unreachable_instructions set to false. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, DebugOptions(), - /*include_unreachable_instructions=*/false)); - // There is only one reachable instruction, the parameter. - EXPECT_EQ(1, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), - "meta"); - } -} - -TEST_F(UserComputationTest, EliminateScalarBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with scalar broadcast. 
- // - // %a = Constant({123, 42}) - // %b = Constant(1) - // %add = Add(%a, %b) - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest a_request; - *a_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddConstantInstruction(a_request)); - - ConstantRequest b_request; - *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddConstantInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - // The binary operation has implicit scalar broadcast, should be converted - // to an explicit broadcast intruction and a binary instruction. - EXPECT_EQ(4, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - LOG(INFO) << hlo_computation->root_instruction()->ToString(); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast || - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with degenerate broadcast. - // - // %a = Param({1, 2, 3}); - // %b = Param({1, 2, 1}); - // %add = Add(%a, %b, {}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - const int64 kDevice = 7; - OpSharding sharding; - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_dimensions(1); - sharding.add_tile_assignment_devices(kDevice); - - TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. 
- TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // b a - // | | - // reshape | - // | | - // broadcast | - // \ / - // add - EXPECT_EQ(5, hlo_computation->instruction_count()); - ASSERT_THAT( - hlo_computation->root_instruction(), - op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter())))); - - const HloInstruction* broadcast = - hlo_computation->root_instruction()->operand(1); - EXPECT_TRUE(broadcast->has_sharding()); - - const HloInstruction* reshape = broadcast->operand(0); - EXPECT_TRUE(reshape->has_sharding()); -} - -TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with in-dim broadcast and degenerate broadcast. - // - // %a = Param({2, 3}); - // %b = Param({2, 1, 4}); - // %add = Add(%a, %b, {0, 1}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - add.add_broadcast_dimensions(0); - add.add_broadcast_dimensions(1); - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // The binary operation has in-dim broadcast and degenerate broadcast, should - // first do the in-dim broadcast then convert the degnerate broadcast into a - // reshape and a broadcast. - // - // b a - // | | - // broadcast reshape - // | | - // | broadcast - // \ / - // add - EXPECT_EQ(6, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast && - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc new file mode 100644 index 00000000000000..10fc4958fae064 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +// Replaces all uses of old_instr with new_instr except the use at +// `while_body_root` (which must be a tuple instruction) at index `tuple_index`. +// This utility helps us replace an instruction in the while body with a +// constant while still keeping it trivially loop invariant. +static Status ReplaceUsesWhileKeepingLoopInvariance( + HloInstruction* old_instr, HloInstruction* new_instr, + HloInstruction* while_body_root, int64 tuple_index) { + CHECK_EQ(while_body_root->opcode(), HloOpcode::kTuple); + + std::vector<HloInstruction*> users; + users.reserve(old_instr->user_count()); + c_copy(old_instr->users(), std::back_inserter(users)); + + for (auto* user : users) { + for (int64 i = 0, e = user->operand_count(); i < e; i++) { + if (user->operand(i) == old_instr && + !(user == while_body_root && i == tuple_index)) { + TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr)); + } + } + } + + return Status::OK(); +} + +StatusOr<bool> WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( + HloInstruction* while_instr) { + HloComputation* while_body = while_instr->while_body(); + + const HloInstruction& init_value = *while_instr->operand(0); + if (init_value.opcode() != HloOpcode::kTuple) { + return false; + } + + bool changed = false; + + for (HloInstruction* invariant_gte : + WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { + int64 index = invariant_gte->tuple_index(); + const HloInstruction& invariant_value = *init_value.operand(index); + if (invariant_value.opcode() == HloOpcode::kConstant) { + auto* constant_instr = + while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); + TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( + invariant_gte, constant_instr, while_body->root_instruction(), + index)); + changed = true; + } + } + + return changed; +} + +StatusOr<bool> WhileLoopConstantSinking::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + std::vector<HloInstruction*> while_instrs; + for (auto* comp : module->MakeNonfusionComputations()) { + // Right now we don't particularly care about optimizing while-of-while + // patterns. If/When we do, we'll want to visit the outer while (while_0) + // before we visit the inner while (while_1): + // + // while_1_body(state) { + // val = gte(state, 0) // Loop invariant + // use(val) + // } + // + // while_0_body(state) { + // val = gte(state, 0) // Loop invariant + // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) + // ... + // } + // + // main { + // while_0 = while(init=(constant, ...), body=while_0_body, ...) + // } + // + // This will let us sink the constant into the outer while first and then + // into the inner while in a single run of this pass.
+ c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + } + + for (HloInstruction* while_instr : while_instrs) { + // We only sink into while loop bodies, but this can be extended to + // transform conditions as well. + TF_ASSIGN_OR_RETURN(bool result, + TrySinkingConstantsIntoWhileBody(while_instr)); + changed |= result; + } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h new file mode 100644 index 00000000000000..21fb8568a84985 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Sinks while loop invariant values that happen to be constants into the while +// loop body. This is probably not a win in isolation but may unlock further +// optimizations like constant folding. +// +// state = (..., const, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(v) +// state = (..., v, ...) +// } +// +// => +// +// state = (..., const, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(const) +// state = (..., v, ...) +// } +// +// Note that it leaves the `v` in place to keep that component of the state +// tuple trivially loop invariant. WhileLoopSimplifier will later get rid of +// `v`. +// +// We only sink into while loop bodies, but this can be extended to transform +// conditions as well. +// +// TODO(b/79121449): We should also sink broadcasts of constants. 
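The header comment above spells out why sinking alone is "probably not a win": it only pays off once later passes consume the sunk constants. As a rough illustration only, and not part of this change, the sketch below shows one way the pass could be scheduled ahead of the passes the comment names; it assumes the usual `HloPassPipeline` API and the existing `HloConstantFolding` and `WhileLoopSimplifier` passes, and `RunConstantSinkingExample` is a made-up name.

```c++
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"

namespace xla {

// Hypothetical helper: sink loop-invariant constants first, then let constant
// folding and while-loop simplification clean up inside the loop bodies.
StatusOr<bool> RunConstantSinkingExample(HloModule* module) {
  HloPassPipeline pipeline("constant-sinking-example");
  pipeline.AddPass<WhileLoopConstantSinking>();
  pipeline.AddPass<HloConstantFolding>();
  pipeline.AddPass<WhileLoopSimplifier>();
  return pipeline.Run(module);
}

}  // namespace xla
```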
+class WhileLoopConstantSinking : public HloPassInterface { + public: + ~WhileLoopConstantSinking() override = default; + + tensorflow::StringPiece name() const override { + return "while-loop-constant-sinking"; + } + + StatusOr<bool> Run(HloModule* module) override; + + private: + StatusOr<bool> TrySinkingConstantsIntoWhileBody(HloInstruction* while_instr); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc new file mode 100644 index 00000000000000..393e75803888d8 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -0,0 +1,200 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +class WhileLoopConstantSinkingTest : public ::testing::Test {}; + +TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + add.0 = f32[2] add(p_body.0, p_body.1) + ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + ROOT while = (f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(_, op::Constant()), _)); +} + +TEST_F(WhileLoopConstantSinkingTest, KeepConstantsLoopInvariant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=1 + p_body.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=2 + + add.0 = f32[2] add(p_body.1, p_body.2) + ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_body.1, p_body.2) +} +
+condition { + p_cond = (f32[2],f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + const_2 = f32[2] constant({3, 1}) + while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(op::Constant(), op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0)))); +} + +TEST_F(WhileLoopConstantSinkingTest, TupleShapedConstants) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[2],(f32[2],f32[2])) parameter(0) + p_b.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=0 + p_b.1 = (f32[2],f32[2]) get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=1 + + p_b.1.1 = f32[2] get-tuple-element(p_b.1), index=0 + + ROOT root = (f32[2],f32[2],f32[2]) tuple(p_b.1.1, p_b.1) +} + +condition { + p_cond = (f32[2],(f32[2],f32[2])) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) + ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(op::Constant(), 0), + op::GetTupleElement(op::Parameter(0)))); +} + +TEST_F(WhileLoopConstantSinkingTest, DuplicateGTEs) { + // This test shows that the pass fails to optimize non-canonical IR. + // + // Even though the input IR has a constant value for p_b.2.dup, + // WhileLoopConstantSinking doesn't try to detect this. Instead, it relies on + // prior runs of HLO CSE to have commoned these identical GTE instructions. 
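The comment above makes the pass's reliance on canonical IR explicit: duplicate GTEs are expected to have been commoned by HLO CSE before this pass runs. As a hedged sketch of that ordering, and not code from this change, one could run `HloCSE` immediately before the sinking pass; this assumes the existing `HloCSE` pass with its `is_layout_sensitive` constructor argument, and `CanonicalizeThenSink` is a made-up name.

```c++
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {

// Hypothetical helper: common duplicate get-tuple-elements first so the
// sinking pass sees canonical IR, then sink the loop-invariant constants.
StatusOr<bool> CanonicalizeThenSink(HloModule* module) {
  TF_ASSIGN_OR_RETURN(bool cse_changed,
                      HloCSE(/*is_layout_sensitive=*/false).Run(module));
  TF_ASSIGN_OR_RETURN(bool sink_changed,
                      WhileLoopConstantSinking().Run(module));
  return cse_changed || sink_changed;
}

}  // namespace xla
```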
+ + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[2],f32[2],f32[2]) parameter(0) + + p_b.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=1 + p_b.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2 + p_b.2.dup = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2 + + add.0 = f32[2] add(p_b.1, p_b.2.dup) + ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_b.1, p_b.2) +} + +condition { + p_cond = (f32[2],f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + const_2 = f32[2] constant({3, 1}) + while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(op::Constant(), ::testing::Not(op::Constant())), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0)))); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 3ef0cdff675125..09ddcffb22c218 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy( // Returns true if `instruction` is worth hoisting only if it lets us hoist some // instruction using it. The rationale is that hoisting these instructions will // prevent simplification and fusion in the while body. -static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { +bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( + const HloInstruction& instruction) { switch (instruction.opcode()) { default: return false; + case HloOpcode::kConstant: + return !hoist_constants_; + case HloOpcode::kBitcast: case HloOpcode::kBroadcast: - case HloOpcode::kConstant: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -115,26 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { } } -// Populates `gte_set` with the GetTupleElement instructions in `while_body` -// that access elements in the parameter tuple that don't change across -// iterations. Assumes `while_body` is the body computation of the while loop -// in question. 
-static void GatherInvariantGTEs(HloComputation* while_body, - FlatSet<HloInstruction*>* gte_set) { - const HloInstruction::InstructionVector root_operands = - while_body->root_instruction()->operands(); - for (int i = 0; i < root_operands.size(); i++) { - HloInstruction* instr = root_operands[i]; - if (instr->opcode() == HloOpcode::kGetTupleElement && - instr->tuple_index() == i && - instr->operand(0) == while_body->parameter_instruction(0) && - ShapeUtil::IsArray(instr->shape())) { - InsertOrDie(gte_set, instr); - } - } -} - -static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( +StatusOr<bool> +WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -172,14 +157,24 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( // unhoisted_invariant_instructions -- they can be legally hoisted, but there // is no benefit to hoisting them unless something that uses it is also // hoisted. - GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions); + for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { + if (ShapeUtil::IsArray(instr->shape())) { + // TODO(b/79147885): We should try to generalize this to tuples for + // uniformity's sake, if nothing else. + InsertOrDie(&unhoisted_invariant_instructions, instr); + } + } - if (unhoisted_invariant_instructions.empty()) { + if (unhoisted_invariant_instructions.empty() && !hoist_constants_) { // There are no obviously loop invariant elements in the state being // threaded through the while loop so give up. In theory this precondition // is too strong -- we could have code that e.g. permutes the elements in // the while state but uses a select to pick the same value on every // iteration. + // + // If we were asked to hoist constants, we need to scan the while body for + // constants even if we didn't find any loop invariant values in the while + // state tuple. return false; } @@ -256,6 +251,9 @@ static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( } StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopInvariantCodeMotion:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; std::vector<HloInstruction*> while_instrs; for (auto* comp : module->computations()) { @@ -283,6 +281,14 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { TryHoistingInvariantInstructionsFromWhileBody(while_instr)); changed |= result; } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopInvariantCodeMotion:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopInvariantCodeMotion"; + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8c4b765b0003c4..8e6cc8787576e4 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -27,12 +27,28 @@ namespace xla { class WhileLoopInvariantCodeMotion : public HloPassInterface { public: + // If `hoist_constants` is true then constants are always hoisted out of while + // loop bodies. Otherwise they are only hoisted out if they enable other + // non-trivial computations to be hoisted out.
+ // + // Setting `hoist_constants` to false can help if LICM is run in the + // mid-level HLO pipeline because hoisting constants out of while loop bodies + // can break optimizations like constant folding. + explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) + : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; tensorflow::StringPiece name() const override { return "while-loop-invariant-code-motion"; } StatusOr<bool> Run(HloModule* module) override; + + private: + bool NotWorthHoistingIndividually(const HloInstruction& instruction); + StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr); + + bool hoist_constants_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 799340fda905fb..8831c513eee66e 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -438,5 +439,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { EXPECT_FALSE(simplified_loop); } +const char* const kConstantHoistingTestCase = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2]{0}) parameter(0) + p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0 + const = f32[2]{0} constant({3, 4}) + add.0 = f32[2]{0} add(p_body.1, const) + ROOT root = (f32[2]{0}) tuple(add.0) +} + +condition { + p_cond = (f32[2]{0}) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2]{0} constant({1, 2}) + while_init = (f32[2]{0}) tuple(const_0) + ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = module().GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + + // We expect the while body to be the equivalent of: + // + // wide.body { + // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0) + // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0 + // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1) + // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0 + // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7) + // tuple.3 = (f32[2]{0}) tuple(add.1) + // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0 + // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8, + // get-tuple-element.9) + // } + + auto wide_param_1 = op::Parameter(0); + auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0); + auto tuple_1 = op::Tuple(get_tuple_element_1); + auto
get_tuple_element_4 = op::GetTupleElement(tuple_1, 0); + auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1); + auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7); + auto tuple_3 = op::Tuple(add_1); + auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0); + auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1); + auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9); + + EXPECT_THAT(while_body->root_instruction(), tuple_4); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index bd0794184328b7..473eab2ea84eb8 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -117,9 +117,13 @@ WhileUtil::MakeInstructionsLiveIn( HloInstruction* new_while = containing_computation->AddInstruction( HloInstruction::CreateWhile(new_while_shape, new_while_condition, new_while_body, new_while_init)); - TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( - while_instr, TupleUtil::ExtractPrefix( - new_while, while_instr->shape().tuple_shapes_size()))); + + // We want to get rid of the old while instruction even if it has side + // effecting operations so we do a manual HloComputation::RemoveInstruction + // instead of relying on HloComputation::ReplaceInstruction. + TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); HloInstruction* while_body_param = new_while_body->parameter_instruction(0); std::vector live_in_instructions; @@ -244,4 +248,21 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { } return result; } + +/*static*/ std::vector WhileUtil::GetInvariantGTEsForWhileBody( + const HloComputation& while_body) { + std::vector result; + const HloInstruction::InstructionVector root_operands = + while_body.root_instruction()->operands(); + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && + instr->operand(0) == while_body.parameter_instruction(0)) { + result.push_back(instr); + } + } + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 1688d4674269c3..e67636d80f4b68 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -38,17 +38,21 @@ class WhileUtil { }; // Replaces `while_instr` with a new while instruction that is equivalent to - // `while_instr`, except that it has all of the HLO instructions in + // `while_instr` except that it has all of the HLO instructions in // `instructions` as live-in, loop invariant values. These new live in values // are represented as new elements appended to the parameter of the while // loop, which must be of tuple shape. GetTupleElement instructions computing // each new live in value is returned in the `while_body_live_in_values` // vector. // - // Precondition: `while_instr` must have a tuple shaped state. 
+ // Deletes `while_instr` after replacing it. // - // Every instruction in `instructions` must be contained in the computation - // that contains `while_instr`. + // Preconditions: + // + // `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, tensorflow::gtl::ArraySlice instructions); @@ -74,6 +78,12 @@ class WhileUtil { HloComputation* computation, int32 trip_count, const LoopStateTy& init_values, const LoopBodyGeneratorTy& loop_body_generator); + + // Returns the GetTupleElement instructions in `while_body` that access + // elements in the parameter tuple that don't change across iterations. + // Assumes `while_body` is the body computation of the while loop in question. + static std::vector GetInvariantGTEsForWhileBody( + const HloComputation& while_body); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index cf0d0db99bd92b..d79d3297213e83 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -16,8 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace { @@ -49,7 +50,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); @@ -126,5 +127,84 @@ TEST(WhileUtilTest, MakeTwoInstructionsLive) { op::GetTupleElement(op::Parameter(0), 3))); } +TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloComputation* while_body = module->GetComputationWithName("body"); + + ASSERT_NE(while_body, nullptr) + << "Expected exactly one while_body computation"; + + std::vector gte_list = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + ASSERT_EQ(gte_list.size(), 1); + EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); +} + +TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) { + const char* const hlo_string = R"( +HloModule WhileWithSideEffects + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT condition = pred[] infeed() +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + to_make_live_in = f32[100] parameter(1) + ROOT while 
= (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloComputation* main = module->GetComputationWithName("main"); + HloInstruction* while_instr = main->root_instruction(); + HloInstruction* to_make_live_in = main->parameter_instruction(1); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{to_make_live_in})); + + auto is_while = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }; + EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index 4f8cdc1e0e73cd..f5331280ee9f25 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -46,9 +45,9 @@ class ZeroSizedHloEliminationTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} StatusOr RunZeroSizedElimination() { - HloModule module("zero_sized_elimination_test_module"); - module.AddEntryComputation(builder_.Build()); - return ZeroSizedHloElimination{}.Run(&module); + auto module = CreateNewModule("zero_sized_elimination_test_module"); + module->AddEntryComputation(builder_.Build()); + return ZeroSizedHloElimination{}.Run(module.get()); } HloComputation::Builder builder_; diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 5b44c26b7c7b08..14c35e7b84f07b 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -16,8 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -31,99 +32,52 @@ class ServiceInterface { virtual ~ServiceInterface() = default; // TODO(b/31824348): Convert to use StatusOr. 
- virtual tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, TransferToClientResponse* result) = 0; + virtual Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) = 0; - virtual tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, TransferToServerResponse* result) = 0; + virtual Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) = 0; - virtual tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0; + virtual Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) = 0; - virtual tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) = 0; + virtual Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) = 0; - virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) = 0; + virtual Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) = 0; - virtual tensorflow::Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) = 0; + virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; + virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) = 0; - virtual tensorflow::Status ExecuteParallel( - const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + virtual Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) = 0; - virtual tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) = 0; - - virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; - - virtual tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; - - virtual tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; - - virtual tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; - - virtual tensorflow::Status GetComputationGraphStats( + virtual Status GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) = 0; - virtual tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; - - virtual tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) = 0; - - virtual tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) = 0; - - virtual tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; - - // Methods used by ComputationBuilder. 
- virtual tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; - - virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0; - - virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; - - virtual tensorflow::Status SetReturnValue( - const SetReturnValueRequest* arg, SetReturnValueResponse* results) = 0; - - virtual tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; + virtual Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) = 0; - virtual tensorflow::Status ComputeConstant( - const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0; + virtual Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) = 0; - virtual tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) = 0; + virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) = 0; - // Methods used by Computation. - virtual tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; + virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) = 0; // Methods used by GlobalData. - virtual tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) = 0; + virtual Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 789eba5780d37e..7ee366b27a82bd 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -22,24 +22,24 @@ limitations under the License. namespace xla { -tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { +Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(other_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } shape_ = other_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { +Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } *to_shape = shape_; - return tensorflow::Status::OK(); + return Status::OK(); } void ShapeLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index a1dce758cd3ab3..36806da599cc9b 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -40,7 +40,7 @@ class ShapeLayout { // Assigns the layouts in this ShapeLayout to the Layout fields of the given // shape. 'to_shape' and the shape of the ShapeLayout object must be // compatible. - tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; + Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. 
If the given shape is not compatible @@ -49,7 +49,7 @@ class ShapeLayout { // Copies the layout from the given shape into this ShapeLayout. 'other_shape' // must be compatible with the ShapeLayout's shape. - tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); + Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index ffaa40c2d673a2..5b14953ebb243d 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -42,36 +42,20 @@ namespace internal { template struct ShapeTreeNode { // Data corresponding to this node. - T data; + std::pair data; - // Children of this node. - std::vector> children; + // Children of this node, as indices into the container's nodes_ array. + std::vector children; - ShapeTreeNode() = default; - explicit ShapeTreeNode(const T& data) : data(data) {} - - ShapeTreeNode(const ShapeTreeNode& other) - : data(other.data), children(other.children.size()) { - for (size_t i = 0; i < children.size(); ++i) { - children[i] = ::xla::MakeUnique(*other.children[i]); - } - } - - ShapeTreeNode& operator=(const ShapeTreeNode& other) { - if (this != &other) { - data = other.data; - children.resize(other.children.size()); - for (size_t i = 0; i < children.size(); ++i) { - children[i] = ::xla::MakeUnique(*other.children[i]); - } - } - return *this; - } + explicit ShapeTreeNode(ShapeIndex index) + : ShapeTreeNode(std::move(index), T()) {} + ShapeTreeNode(ShapeIndex index, T data) + : data(std::move(index), std::move(data)) {} }; } // namespace internal -template +template class ShapeTreeIterator; // A ShapeTree is a recursive data structure which mirrors the structure of a @@ -95,10 +79,9 @@ class ShapeTreeIterator; // before its ShapeTree goes away. template class ShapeTree { - friend class ShapeTreeIterator; - friend class ShapeTreeIterator; - public: + using Node = internal::ShapeTreeNode; + // Default constructor creates a tree with a nil shape (i.e. an empty tuple). ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {} @@ -110,30 +93,12 @@ class ShapeTree { // alive longer than this ShapeTree. explicit ShapeTree(Shape shape); explicit ShapeTree(const Shape* shape); + explicit ShapeTree(const std::shared_ptr& shape); // Create ShapeTree with the given shape, and init_value for all nodes. ShapeTree(Shape shape, const T& init_value); ShapeTree(const Shape* shape, const T& init_value); - - ShapeTree(const ShapeTree& other) { *this = other; } - ShapeTree(ShapeTree&&) = default; - - ShapeTree& operator=(const ShapeTree& other) { - root_ = other.root_; - - // Fix up internal pointer if necessary. - if (other.shape_storage_) { - CHECK_EQ(other.shape_, other.shape_storage_.get()); - shape_storage_.reset(new Shape(*other.shape_)); - shape_ = shape_storage_.get(); - } else { - shape_ = other.shape_; - } - - return *this; - } - - ShapeTree& operator=(ShapeTree&& other) = default; + ShapeTree(const std::shared_ptr& shape, const T& init_value); // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). 
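The ShapeTree changes above replace the old recursive node representation with a single `nodes_` vector laid out in pre-order, where each node carries its `ShapeIndex` alongside its payload and refers to its children by index into that vector. Purely as an illustration of what that layout implies, and not code from this change, the sketch below walks a small nested tuple shape; it only uses `ShapeUtil` helpers and the public `ShapeTree` interface shown in this diff.

```c++
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

void ShapeTreeLayoutSketch() {
  // shape = ((f32[2], s32[]), f32[3])
  xla::Shape shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeTupleShape(
           {xla::ShapeUtil::MakeShape(xla::F32, {2}),
            xla::ShapeUtil::MakeShape(xla::S32, {})}),
       xla::ShapeUtil::MakeShape(xla::F32, {3})});
  xla::ShapeTree<int> tree(shape, /*init_value=*/0);

  // nodes_ is stored in pre-order, so plain iteration visits the ShapeIndexes
  // {}, {0}, {0,0}, {0,1}, {1} in that order; leaf iteration skips the two
  // tuple nodes and visits only {0,0}, {0,1}, {1}.
  for (const auto& index_and_data : tree) {
    (void)index_and_data;  // .first is the ShapeIndex, .second the payload.
  }
  *tree.mutable_element({0, 1}) = 42;  // payload attached to subshape {0,1}
}
```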
@@ -161,63 +126,70 @@ class ShapeTree { return Lookup(index)->children.empty(); } - // iterator implements a forward_iterator with value_type = - // std::pair - using iterator = ShapeTreeIterator; - using const_iterator = ShapeTreeIterator; + ShapeTree(const ShapeTree&) = default; + ShapeTree& operator=(const ShapeTree&) = default; + ShapeTree(ShapeTree&&) = default; + ShapeTree& operator=(ShapeTree&& other) = default; + + // iterator implements a bidirectional_iterator with + // value_type = std::pair. + // + // The iteration order is guaranteed to be a pre-order walk of the ShapeTree. + using iterator = + ShapeTreeIterator, typename std::vector::iterator, + std::pair>; + using const_iterator = + ShapeTreeIterator, + typename std::vector::const_iterator, + const std::pair>; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; // begin/end for iterating over all nodes. iterator begin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } iterator end() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } const_iterator begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/false); } const_iterator end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/false); } // rbegin/rend for iterating over all nodes in reverse. - iterator rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - iterator rend() { - return iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + reverse_iterator rbegin() { return reverse_iterator(end()); } + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); } - const_iterator rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false, - /*reverse=*/true); - } - const_iterator rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false, - /*reverse=*/true); + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). iterator leaf_begin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); + return iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } iterator leaf_end() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } const_iterator leaf_begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.begin(), + /*iterate_leaves_only=*/true); } const_iterator leaf_end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/false); + return const_iterator(&nodes_, nodes_.end(), + /*iterate_leaves_only=*/true); } // range-based iterator for leaf_begin()/leaf_end(). 
tensorflow::gtl::iterator_range leaves() { @@ -227,22 +199,32 @@ class ShapeTree { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } - iterator leaf_rbegin() { - return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); + reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); } + reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); } + const_reverse_iterator leaf_rbegin() const { + return const_reverse_iterator(leaf_end()); } - iterator leaf_rend() { - return iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_reverse_iterator leaf_rend() const { + return const_reverse_iterator(leaf_begin()); } - const_iterator leaf_rbegin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true, - /*reverse=*/true); + + // Returns an iterator pointing to the given ShapeIndex. + // REQUIRES: index must exist in the ShapeTree. + iterator find(const ShapeIndex& index) { + Node* element = Lookup(index); + return iterator(&nodes_, typename std::vector::iterator(element), + /*iterate_leaves_only=*/false); } - const_iterator leaf_rend() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true, - /*reverse=*/true); + const_iterator find(const ShapeIndex& index) const { + Node* element = Lookup(index); + return iterator(&nodes_, + typename std::vector::const_iterator(element), + /*iterate_leaves_only=*/false); } + // Returns the number of leaf nodes in the tree. + int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // @@ -282,8 +264,6 @@ class ShapeTree { bool operator!=(const ShapeTree& other) const { return !(*this == other); } private: - using Node = internal::ShapeTreeNode; - // Initialize node->children based on 'shape'. All children are assigned the // the given 'init_value'. void InitChildren(const Shape& shape, const T& init_value, Node* node); @@ -292,136 +272,57 @@ class ShapeTree { // default-constructed data values. void InitChildren(const Shape& shape, Node* node); + // Returns the number of subshapes, including interior nodes, in shape. + int64 CountSubshapes(const Shape& shape); + // Helpers for traversing the shape via ForEachElement. The helpers // recursively traverse the subtree rooted at "index" (defined as in // ShapeUtil::GetSubshape). template - static Status ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index); + static Status ForEachHelper(const Fn& func, const std::vector& nodes); template - static Status ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index); + static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. Node* Lookup(const ShapeIndex& index); const Node* Lookup(const ShapeIndex& index) const; - // The root node, which contains all other nodes. - Node root_; + // The nodes in this shape tree. + std::vector nodes_; // If we own our Shape, this field contains it, and shape_ is a pointer into // here. Otherwise if we don't own our shape, this is nullptr. - std::unique_ptr shape_storage_; + std::shared_ptr shape_storage_; // The XLA shape mirrored in this ShapeTree. This is either // shape_storage_.get() or the Shape pointer passed to our constructor. const Shape* shape_; }; -// Internal iterator that performs a pre-order walk. This is copyable, but -// contains a vector so isn't cheap to copy. 
This also means post-increment is -// expensive. The iterator value_type is equivalent to a std::pair, similar to std::map. The non-const iterator's T& type can be mutated -// in-place. -template -class ShapeTreeIterator : public std::iterator> { +// Internal iterator that performs a pre-order walk. This is cheap to copy. +// The iterator value_type is equivalent to a +// std::pair&, similar to std::map. +template +class ShapeTreeIterator + : public std::iterator { public: - using value_type = - typename std::conditional, - std::pair>::type; - using NodeType = - typename std::conditional::Node, - typename ShapeTree::Node>::type; - - // Construct an iterator pointing at node. Node must either be the tree root - // or nullptr (which is equivalent to end() and should not be dereferenced or - // incremented). If iterate_leaves_only is true, the iterator will not include - // interior tree nodes, only leaves. If reverse is true, the iterator will - // visit nodes in the reverse of pre-order traversal. - ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) - : node_(node), - iterate_leaves_only_(iterate_leaves_only), - reverse_(reverse) { - if (node_) { - if (reverse_) { - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - } else { - if (!node_->children.empty() && iterate_leaves_only) { - ++*this; - } - } + ShapeTreeIterator(ContainerType* nodes, IteratorType node, + bool iterate_leaves_only) + : nodes_(nodes), + node_(std::move(node)), + iterate_leaves_only_(iterate_leaves_only) { + while (iterate_leaves_only && node_ != nodes_->end() && + !node_->children.empty()) { + ++node_; } } - ShapeTreeIterator(const ShapeTreeIterator& other) - : node_(other.node_), - stack_(other.stack_), - iterate_leaves_only_(other.iterate_leaves_only_), - reverse_(other.reverse_) {} ShapeTreeIterator& operator++() { - CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; - if (reverse_) { - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second - 1; - stack_.pop_back(); - if (next_child_index < 0) { - if (!iterate_leaves_only_) { - // All children are visited, yield . - return *this; - } - } else { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - while (!node_->children.empty()) { - const int child_index = node_->children.size() - 1; - stack_.push_back({node_, child_index}); - node_ = node_->children[child_index].get(); - } - return *this; - } - } - } else { - // We're doing a pre-order walk, so if our current node has children take - // the first child. - if (!node_->children.empty()) { - stack_.push_back({node_, /*child-index=*/0}); - node_ = node_->children[0].get(); - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); - } - } - // Otherwise we are currently at a leaf. Walk back up until a node - // contains a child we haven't visited yet. - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second + 1; - stack_.pop_back(); - if (node_->children.size() > next_child_index) { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. 
- return ++(*this); - } - } - } + ++node_; + while (iterate_leaves_only_ && node_ != nodes_->end() && + !node_->children.empty()) { + ++node_; } - // We've walked off the end of the tree. Set node_ to nullptr to signify - // end(). - node_ = nullptr; - current_.reset(); return *this; } ShapeTreeIterator operator++(int) { @@ -429,52 +330,62 @@ class ShapeTreeIterator : public std::iterator nodes_->begin() && + !node_->children.empty()) { + --node_; + } + return *this; + } + ShapeTreeIterator operator--(int) { + auto i = *this; + --(*this); + return i; + } + bool operator==(const ShapeTreeIterator& other) const { return node_ == other.node_; } bool operator!=(const ShapeTreeIterator& other) const { return node_ != other.node_; } - value_type& operator*() { return UpdateCurrent(); } - value_type* operator->() { return &UpdateCurrent(); } + ValueType& operator*() { return node_->data; } + ValueType* operator->() { return &node_->data; } private: - // Updates the current_ member to reflect the current state. - value_type& UpdateCurrent() { - ShapeIndex index; - for (auto& node_and_index : stack_) { - index.push_back(node_and_index.second); - } - current_ = ::xla::MakeUnique(index, node_->data); - return *current_; - } - - // The node to which this iterator is pointing. This is the source of truth in - // the iterator - the stack only exists to facilitate walking back from - // children to parents. - NodeType* node_; - // Stack of {node, child-index} pairs of the path taken from the root to get - // to node_. This allows us to backtrack and know where to go next. - std::vector> stack_; + ContainerType* nodes_; + IteratorType node_; // True if we should not include interior nodes in our walk. bool iterate_leaves_only_; - // True if we should yield the reverse of the pre-order traversal. - bool reverse_; - // Placeholder for the current value. Ideally this wouldn't exist and would - // just be an rvalue, but operator -> needs to return a pointer to something. - // We cannot just use a plain old value_type as it contains a reference so - // cannot be default-constructed. 
- std::unique_ptr current_; }; +template +int64 ShapeTree::CountSubshapes(const Shape& shape) { + int64 current_count = 1; + if (ShapeUtil::IsTuple(shape)) { + int64 count = ShapeUtil::TupleElementCount(shape); + for (int i = 0; i < count; ++i) { + current_count += CountSubshapes(shape.tuple_shapes(i)); + } + } + return current_count; +} + template void ShapeTree::InitChildren(const Shape& shape, const T& init_value, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node(init_value)); - InitChildren(shape.tuple_shapes(i), init_value, - node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index, init_value); + InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back()); } } } @@ -482,63 +393,92 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, template void ShapeTree::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - node->children.emplace_back(new Node()); - InitChildren(shape.tuple_shapes(i), node->children.back().get()); + const int64 size = ShapeUtil::TupleElementCount(shape); + node->children.reserve(size); + ShapeIndex shape_index = node->data.first; + shape_index.push_back(0); + for (int i = 0; i < size; ++i) { + shape_index[shape_index.size() - 1] = i; + node->children.push_back(nodes_.size()); + nodes_.emplace_back(shape_index); + InitChildren(shape.tuple_shapes(i), &nodes_.back()); } } } template ShapeTree::ShapeTree(Shape shape) - : root_(), - shape_storage_(::xla::MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const Shape* shape) : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template -ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { - InitChildren(*shape_, &root_); +ShapeTree::ShapeTree(const std::shared_ptr& shape) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}); + InitChildren(*shape_, &nodes_[0]); } template ShapeTree::ShapeTree(Shape shape, const T& init_value) - : root_(init_value), - shape_storage_(::xla::MakeUnique(std::move(shape))), + : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. 
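The constructors in this hunk reserve `nodes_` with `CountSubshapes` before `InitChildren` starts appending, which appears to be what keeps the parent `Node*` passed into `InitChildren` valid while children are pushed. A self-contained sketch of the count-then-reserve pattern, using a toy `ToyShape` type rather than `xla::Shape`:

```c++
#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for xla::Shape: a node is either a leaf or a tuple of shapes.
struct ToyShape {
  std::vector<ToyShape> tuple_shapes;  // empty => leaf (array-like) shape
};

// Counts the shape itself plus all nested subshapes, mirroring the
// CountSubshapes helper added in the diff.
int64_t CountSubshapes(const ToyShape& shape) {
  int64_t count = 1;
  for (const ToyShape& sub : shape.tuple_shapes) {
    count += CountSubshapes(sub);
  }
  return count;
}

int main() {
  // (leaf, (leaf, leaf)) => 5 subshapes including the root.
  ToyShape shape{{ToyShape{}, ToyShape{{ToyShape{}, ToyShape{}}}}};
  std::vector<int> nodes;
  // Reserving once up front prevents reallocation while entries are appended,
  // which is why the real code can hold a raw pointer to the parent entry.
  nodes.reserve(CountSubshapes(shape));
  std::cout << "subshapes: " << CountSubshapes(shape) << "\n";  // prints 5
}
```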
LayoutUtil::ClearLayout(shape_storage_.get()); - InitChildren(*shape_, init_value, &root_); + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template ShapeTree::ShapeTree(const Shape* shape, const T& init_value) - : root_(init_value), shape_(shape) { - InitChildren(*shape_, init_value, &root_); + : shape_(shape) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); +} + +template +ShapeTree::ShapeTree(const std::shared_ptr& shape, + const T& init_value) + : shape_storage_(shape), shape_(shape_storage_.get()) { + nodes_.reserve(CountSubshapes(*shape_)); + nodes_.emplace_back(ShapeIndex{}, init_value); + InitChildren(*shape_, init_value, &nodes_[0]); } template const T& ShapeTree::element(const ShapeIndex& index) const { - return Lookup(index)->data; + return Lookup(index)->data.second; } template T* ShapeTree::mutable_element(const ShapeIndex& index) { - return &Lookup(index)->data; + return &Lookup(index)->data.second; } template internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { - Node* node = &root_; + Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); CHECK_LT(i, node->children.size()); - node = node->children[i].get(); + node = &nodes_[node->children[i]]; } return node; } @@ -552,13 +492,10 @@ const internal::ShapeTreeNode* ShapeTree::Lookup( /* static */ template template -Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, node.data)); - for (int64 i = 0; i < node.children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR(ForEachHelper(func, *node.children[i], index)); - index->pop_back(); +Status ShapeTree::ForEachHelper(const Fn& func, + const std::vector& nodes) { + for (const auto& node : nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, node.data.second)); } return Status::OK(); } @@ -566,14 +503,10 @@ Status ShapeTree::ForEachHelper(const Fn& func, const Node& node, /* static */ template template -Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, - ShapeIndex* index) { - TF_RETURN_IF_ERROR(func(*index, &node->data)); - for (int64 i = 0; i < node->children.size(); ++i) { - index->push_back(i); - TF_RETURN_IF_ERROR( - ForEachMutableHelper(func, node->children[i].get(), index)); - index->pop_back(); +Status ShapeTree::ForEachMutableHelper(const Fn& func, + std::vector* nodes) { + for (auto& node : *nodes) { + TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second)); } return Status::OK(); } @@ -581,40 +514,36 @@ Status ShapeTree::ForEachMutableHelper(const Fn& func, Node* node, template template Status ShapeTree::ForEachElementWithStatus(const Fn& func) const { - ShapeIndex index; - return ForEachHelper(func, root_, &index); + return ForEachHelper(func, nodes_); } template template Status ShapeTree::ForEachMutableElementWithStatus(const Fn& func) { - ShapeIndex index; - return ForEachMutableHelper(func, &root_, &index); + return ForEachMutableHelper(func, &nodes_); } template template void ShapeTree::ForEachElement(const Fn& func) const { - ShapeIndex index; return ForEachHelper( [&func](const ShapeIndex& index, const T& data) { func(index, data); return Status::OK(); }, - root_, &index) + nodes_) .IgnoreError(); } template template void ShapeTree::ForEachMutableElement(const Fn& func) { - ShapeIndex index; return ForEachMutableHelper( [&func](const ShapeIndex& 
index, T* data) { func(index, data); return Status::OK(); }, - &root_, &index) + &nodes_) .IgnoreError(); } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 4b6ab772811f4a..dc5facf1581c07 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace xla { namespace { @@ -421,8 +422,8 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - t.begin()->second = 78; - EXPECT_EQ(78, t.begin()->second); + (*t.begin()).second = 78; + EXPECT_EQ(78, (*t.begin()).second); i = 0; for (auto& index_to_data : t) { if (i == 0) { @@ -434,14 +435,14 @@ TEST_F(ShapeTreeTest, IterateAndMutate) { } ++i; } - EXPECT_EQ(78, t.begin()->second); - EXPECT_EQ(98, std::next(t.begin())->second); + EXPECT_EQ(78, (*t.begin()).second); + EXPECT_EQ(98, (*std::next(t.begin())).second); } TEST_F(ShapeTreeTest, IterateOrder) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t) { + for (auto index_to_data : t) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{{}, @@ -479,7 +480,7 @@ TEST_F(ShapeTreeTest, ReverseIterateOrder) { TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; - for (auto& index_to_data : t.leaves()) { + for (auto index_to_data : t.leaves()) { v.push_back(index_to_data.first); } EXPECT_EQ(v, (std::vector{ @@ -502,5 +503,106 @@ TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { })); } +void BM_Construct(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(shape); + } +} + +void BM_ConstructUnowned(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + for (int i = 0; i < iters; ++i) { + ShapeTree shape_tree(&shape); + } +} + +void BM_Copy(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = shape_tree; + tensorflow::testing::DoNotOptimize(copy); + } +} + +void BM_Move(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + ShapeTree copy = std::move(shape_tree); + shape_tree = std::move(copy); + } +} + +void BM_ForEach(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = 
ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + shape_tree.ForEachMutableElement([](const ShapeIndex& index, int* data) { + tensorflow::testing::DoNotOptimize(index); + }); + } +} + +void BM_Iterate(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + Shape shape = ShapeUtil::MakeShape(F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = ShapeUtil::MakeTupleShape(shapes); + } + tensorflow::testing::StartTiming(); + + ShapeTree shape_tree(shape); + for (int i = 0; i < iters; ++i) { + for (auto& iter : shape_tree) { + tensorflow::testing::DoNotOptimize(iter.second); + } + } +} + +BENCHMARK(BM_Construct)->ArgPair(2, 8); +BENCHMARK(BM_ConstructUnowned)->ArgPair(2, 8); +BENCHMARK(BM_Copy)->ArgPair(2, 8); +BENCHMARK(BM_Move)->ArgPair(2, 8); +BENCHMARK(BM_ForEach)->ArgPair(2, 8); +BENCHMARK(BM_Iterate)->ArgPair(2, 8); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ac7e201bfdceab..ce4d0079ee5eb2 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -27,11 +27,11 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -41,17 +41,35 @@ limitations under the License. namespace xla { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + string ShapeIndex::ToString() const { - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); } string ShapeIndexView::ToString() const { - return tensorflow::strings::StrCat( - "{", - tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_), - ","), - "}"); + return StrCat("{", + tensorflow::str_util::Join( + tensorflow::gtl::make_range(begin_, end_), ","), + "}"); +} + +bool ShapeIndexView::operator==(const ShapeIndexView& other) const { + if (size() != other.size()) { + return false; + } + for (auto it = begin(), other_it = other.begin(); it != end(); + ++it, ++other_it) { + if (*it != *other_it) { + return false; + } + } + return true; +} + +bool ShapeIndexView::operator!=(const ShapeIndexView& other) const { + return !(*this == other); } std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { @@ -66,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { namespace { +// Returns whether the given primitive type corresponds to an array shape. +bool IsArrayPrimitiveType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + // Recursive helper for comparing the equality of two shapes. 
Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { - return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + if (!ShapeUtil::SameElementType(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + + if (ShapeUtil::IsTuple(lhs)) { + return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { return CompareShapes(l, r, compare_layouts); }); - } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { - return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); + } else if (!ShapeUtil::IsArray(lhs)) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. + return true; } if (compare_layouts) { @@ -107,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; return false; } - if (!ShapeUtil::SameElementType(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } return true; } @@ -153,8 +179,8 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(!ShapeUtil::IsTuple(shape)) - << "Tuples do not have a rank, shape: " << shape; + CHECK(ShapeUtil::IsArray(shape)) + << "Non-arrays do not have a rank, shape: " << shape; return shape.dimensions_size(); } @@ -181,8 +207,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); return result; @@ -205,8 +230,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, int64 max_sparse_elements) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); @@ -253,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return result; } +/* static */ Shape ShapeUtil::MakeTokenShape() { + Shape result; + result.set_element_type(TOKEN); + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); + return result; +} + /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, Shape* tuple_shape) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); @@ -276,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { + if (!IsArray(shape)) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -302,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case C64: case TUPLE: case OPAQUE: + case TOKEN: return false; default: @@ -317,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return 
primitive_util::IsFloatingPointType(shape.element_type()); } +/* static */ bool ShapeUtil::IsArray(const Shape& shape) { + return IsArrayPrimitiveType(shape.element_type()); +} + /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), IsTuple); @@ -370,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); + CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -385,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.element_type() == F32 && Rank(shape) == 0; } -/* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { - string text = "("; - const char* prefix = ""; - for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape)); - prefix = ", "; - } - text += ")"; - return text; - } else { - return tensorflow::strings::StrCat( - tensorflow::str_util::Lowercase( - PrimitiveType_Name(shape.element_type())), - "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]"); - } -} namespace { @@ -452,48 +471,56 @@ StatusOr StringToPrimitiveType(const string& name) { } // namespace -/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { +/* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, - HumanStringWithLayout(elem_shape)); + StrAppend(&text, prefix, HumanString(elem_shape)); prefix = ", "; } text += ")"; return text; - } else { - string result = tensorflow::strings::StrCat( - LowercasePrimitiveTypeName(shape.element_type()), "["); - for (int i = 0; i < shape.dimensions().size(); i++) { - tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "", - shape.dimensions(i)); + } + return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", + tensorflow::str_util::Join(shape.dimensions(), ","), "]"); +} + +/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { + if (IsTuple(shape)) { + string text = "("; + const char* prefix = ""; + for (const Shape& elem_shape : shape.tuple_shapes()) { + StrAppend(&text, prefix, HumanStringWithLayout(elem_shape)); + prefix = ", "; } - result += "]"; - if (!IsScalar(shape) && !IsOpaque(shape)) { - if (LayoutUtil::HasLayout(shape)) { - tensorflow::strings::StrAppend(&result, - LayoutUtil::HumanString(shape.layout())); - } + text += ")"; + return text; + } + string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + for (int i = 0; i < shape.dimensions().size(); i++) { + StrAppend(&result, (i > 0) ? 
"," : "", shape.dimensions(i)); + } + result += "]"; + if (!IsScalar(shape) && IsArray(shape)) { + if (LayoutUtil::HasLayout(shape)) { + StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } - return result; } + return result; } /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { std::vector parameters; for (auto& shape : program_shape.parameters()) { const int i = parameters.size(); - parameters.push_back( - tensorflow::strings::StrCat(i < program_shape.parameter_names_size() - ? program_shape.parameter_names(i) - : "(unknown)", - ": ", HumanString(shape))); + parameters.push_back(StrCat(i < program_shape.parameter_names_size() + ? program_shape.parameter_names(i) + : "(unknown)", + ": ", HumanString(shape))); } - return tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(parameters, ", "), ") -> ", - HumanString(program_shape.result())); + return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + HumanString(program_shape.result())); } namespace { @@ -510,7 +537,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { break; } else if (must_end) { return InvalidArgument("Expected end of tuple; got: \"%s\"", - s->ToString().c_str()); + std::string(*s).c_str()); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); @@ -540,7 +567,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { return InvalidArgument( "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - input.c_str(), s->ToString().c_str()); + input.c_str(), std::string(*s).c_str()); } return element; }; @@ -563,14 +590,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the primitive element type. TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || - primitive_type == OPAQUE) { + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } Shape result; - if (format_string.empty() && layout_string.empty()) { + if (primitive_type == OPAQUE) { + result = ShapeUtil::MakeOpaqueShape(); + } else if (primitive_type == TOKEN) { + result = ShapeUtil::MakeTokenShape(); + } else if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. 
result = ShapeUtil::MakeShape(primitive_type, dimensions); } else if (format_string == "sparse") { @@ -593,7 +623,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } return InvalidArgument("Invalid shape string to parse: \"%s\"", - s->ToString().c_str()); + std::string(*s).c_str()); } } // namespace @@ -602,7 +632,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { return InvalidArgument("Invalid shape string to parse: \"%s\"", - s.ToString().c_str()); + std::string(s).c_str()); } return shape; } @@ -615,43 +645,44 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringElementType); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && + CompatibleIgnoringElementType(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringFpPrecision); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return CompatibleIgnoringElementType(lhs, rhs); - } - return false; } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -673,10 +704,6 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { switch (primitive_type) { case PRED: return sizeof(int8); - case TUPLE: - LOG(FATAL) << "tuples have no definitive size"; - case OPAQUE: - LOG(FATAL) << "opaque have no definitive size"; case S8: return sizeof(int8); case S16: @@ -703,6 +730,13 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(double); case C64: return sizeof(complex64); + case TOKEN: + // Tokens require no space. 
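The restructured Compatible family above follows a single three-way dispatch: arrays compare dimensions (and, per variant, element type), tuples recurse element-wise, and remaining kinds such as opaque and token are treated as vacuously compatible. A simplified sketch of that dispatch on a toy type (slightly stricter than the snippet above, and not a drop-in reproduction):

```c++
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class Kind { kArray, kTuple, kOpaque, kToken };

struct MiniShape {
  Kind kind = Kind::kArray;
  int element_type = 0;  // stand-in for PrimitiveType
  std::vector<int64_t> dimensions;
  std::vector<MiniShape> tuple_shapes;
};

// Arrays need matching element type and dimensions, tuples recurse,
// other kinds carry no structure to compare.
bool Compatible(const MiniShape& lhs, const MiniShape& rhs) {
  if (lhs.kind == Kind::kArray) {
    return rhs.kind == Kind::kArray &&
           lhs.element_type == rhs.element_type &&
           lhs.dimensions == rhs.dimensions;
  }
  if (lhs.kind == Kind::kTuple) {
    if (rhs.kind != Kind::kTuple ||
        lhs.tuple_shapes.size() != rhs.tuple_shapes.size()) {
      return false;
    }
    for (size_t i = 0; i < lhs.tuple_shapes.size(); ++i) {
      if (!Compatible(lhs.tuple_shapes[i], rhs.tuple_shapes[i])) return false;
    }
    return true;
  }
  return true;  // opaque, token, ...
}

int main() {
  MiniShape a{Kind::kArray, 1, {2, 3}, {}};
  MiniShape b{Kind::kArray, 1, {2, 3}, {}};
  MiniShape t{Kind::kToken, 0, {}, {}};
  std::cout << Compatible(a, b) << " " << Compatible(t, t) << "\n";  // 1 1
}
```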
+ return 0; + case TUPLE: + case OPAQUE: + LOG(FATAL) << PrimitiveType_Name(primitive_type) + << " primitive type has no definitive size"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } @@ -711,28 +745,32 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); + } else if (IsArray(shape)) { + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; + } else if (shape.element_type() == TOKEN) { + return 0; } - int64 byte_size = ByteSizeOfElements(shape); - if (LayoutUtil::IsSparseArray(shape)) { - byte_size += ByteSizeOfSparseIndices(shape); - } - return byte_size; + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " primitive type has no definitive size"; } /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_EQ(TUPLE, shape.element_type()); CHECK_GT(pointer_size, 0); return pointer_size * shape.tuple_shapes_size(); } /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(ShapeUtil::IsArray(shape)); + CHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -757,13 +795,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(LayoutUtil::IsSparseArray(shape)); + CHECK(LayoutUtil::IsSparseArray(shape)); return LayoutUtil::MaxSparseElements(shape.layout()) * ShapeUtil::Rank(shape) * sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("shape has invalid element type: %s", + shape.ShortDebugString().c_str()); + } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); @@ -779,10 +821,24 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.tuple_shapes_size() > 0) { return InvalidArgument("non-tuple shape has tuple_shapes field"); } - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { - return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + + // Tokens and opaques can should not have layout or dimensions. 
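For orientation, the ByteSizeOf rework in this hunk costs tuples as a pointer table, arrays as element storage (plus sparse indices where applicable), and tokens as zero bytes. Below is a compact standalone sketch of that dispatch; the `MiniShape` type and the fixed per-element size are illustrative assumptions.

```c++
#include <cstdint>
#include <iostream>
#include <vector>

enum class Kind { kArray, kTuple, kToken };

struct MiniShape {
  Kind kind = Kind::kArray;
  int64_t element_size = 4;  // bytes per element, e.g. 4 for f32
  std::vector<int64_t> dimensions;
  std::vector<MiniShape> tuple_shapes;
};

int64_t ByteSizeOf(const MiniShape& shape, int64_t pointer_size) {
  switch (shape.kind) {
    case Kind::kTuple:
      // A tuple's top-level buffer is just a table of pointers to elements.
      return pointer_size * static_cast<int64_t>(shape.tuple_shapes.size());
    case Kind::kArray: {
      int64_t elements = 1;
      for (int64_t dim : shape.dimensions) elements *= dim;
      return elements * shape.element_size;
    }
    case Kind::kToken:
      return 0;  // Tokens only order side effects; they occupy no storage.
  }
  return 0;
}

int main() {
  MiniShape arr{Kind::kArray, 4, {10, 20}, {}};
  MiniShape tup{Kind::kTuple, 0, {}, {arr, arr}};
  MiniShape tok{Kind::kToken, 0, {}, {}};
  std::cout << ByteSizeOf(arr, 8) << " "    // 800
            << ByteSizeOf(tup, 8) << " "    // 16
            << ByteSizeOf(tok, 8) << "\n";  // 0
}
```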
+ if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) { + if (shape.dimensions_size() != 0) { + return InvalidArgument( + "shape has %s element type, but has dimensions field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + if (shape.has_layout()) { + return InvalidArgument( + "shape has %s element type, but has layout field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + return Status::OK(); } + if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( "shape's rank is mismatched with dimension count; rank=%lld " @@ -862,7 +918,30 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return !IsTuple(GetSubshape(shape, index)); } +/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + int64 count = 0; + ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + ++count; + } + }); + return count; +} + +/* static */ std::vector ShapeUtil::GetLeafShapes( + const Shape& shape) { + std::vector leaves; + ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + leaves.emplace_back(index, sub_shape); + } + }); + return leaves; +} + /* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { + CHECK(IsArray(shape)); + std::vector dimension_sizes; std::vector degenerate_dimensions; for (int64 i = 0; i < shape.dimensions_size(); ++i) { @@ -905,10 +984,17 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { std::is_permutation(minor_to_major.begin(), minor_to_major.end(), dims.begin())); } - Shape stripped_shape = - shape.has_layout() ? MakeShapeWithLayout(shape.element_type(), - dimension_sizes, minor_to_major) - : MakeShape(shape.element_type(), dimension_sizes); + Shape stripped_shape; + if (LayoutUtil::IsDenseArray(shape)) { + stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes, + minor_to_major); + } else if (LayoutUtil::IsSparseArray(shape)) { + stripped_shape = + MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes, + shape.layout().max_sparse_elements()); + } else { + stripped_shape = MakeShape(shape.element_type(), dimension_sizes); + } VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); @@ -1020,6 +1106,9 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { + CHECK(IsArray(shape_pre)); + CHECK(IsArray(shape_post)); + auto nil = std::make_tuple(false, std::vector(), std::vector()); std::vector deleted_indices; @@ -1077,6 +1166,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + // Unmodified dimensions are merely common factors of rank 1. 
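GetLeafCount and GetLeafShapes above are thin wrappers around a pre-order subshape walk that keeps only leaf indices (pairing them with their shapes via the new IndexedShape). A standalone sketch of that walk on a toy nested-tuple type; the index-as-vector representation is a stand-in for ShapeIndex.

```c++
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct ToyShape {
  std::vector<ToyShape> tuple_shapes;  // empty => leaf
};

// Pre-order walk over all subshapes, recording the index path of each leaf,
// in the spirit of ForEachSubshape + IsLeafIndex.
void CollectLeaves(const ToyShape& shape, std::vector<int64_t>* index,
                   std::vector<std::vector<int64_t>>* leaves) {
  if (shape.tuple_shapes.empty()) {
    leaves->push_back(*index);
    return;
  }
  for (int64_t i = 0; i < static_cast<int64_t>(shape.tuple_shapes.size()); ++i) {
    index->push_back(i);
    CollectLeaves(shape.tuple_shapes[i], index, leaves);
    index->pop_back();
  }
}

int main() {
  // ((leaf, leaf), leaf)
  ToyShape shape{{ToyShape{{ToyShape{}, ToyShape{}}}, ToyShape{}}};
  std::vector<int64_t> index;
  std::vector<std::vector<int64_t>> leaves;
  CollectLeaves(shape, &index, &leaves);
  std::cout << "leaf count: " << leaves.size() << "\n";  // 3
  for (const auto& leaf : leaves) {
    std::string s = "{";
    for (size_t i = 0; i < leaf.size(); ++i)
      s += (i ? "," : "") + std::to_string(leaf[i]);
    std::cout << s << "}\n";  // {0,0} {0,1} {1}
  }
}
```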
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); @@ -1130,8 +1222,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(LayoutUtil::HasLayout(input_shape) && - LayoutUtil::HasLayout(output_shape)); + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + CHECK(LayoutUtil::HasLayout(input_shape)); + CHECK(LayoutUtil::HasLayout(output_shape)); if (!SameElementType(input_shape, output_shape)) { return false; @@ -1293,6 +1387,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + int64 input_rank = Rank(input_shape); int64 output_rank = Rank(output_shape); @@ -1427,6 +1524,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { + CHECK(IsArray(shape)); shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); @@ -1448,6 +1546,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { + CHECK(IsArray(shape)); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { @@ -1465,4 +1564,26 @@ std::ostream& operator<<(std::ostream& out, const Shape& shape) { return out; } +/*static*/ size_t ShapeUtil::Hash(const Shape& shape) { + using tensorflow::hash; + using tensorflow::Hash64Combine; + + size_t hash_value = hash()(shape.element_type()); + + if (shape.tuple_shapes().empty()) { + for (int64 dim : shape.dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(dim)); + } + + hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout())); + } else { + hash_value = 0; + for (const Shape& subshape : shape.tuple_shapes()) { + hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape)); + } + } + + return hash_value; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 5fa728e7c2fa5f..3853ada6ba65db 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -132,6 +133,9 @@ class ShapeIndexView { return ShapeIndexView(new_begin, end_); } + bool operator==(const ShapeIndexView& other) const; + bool operator!=(const ShapeIndexView& other) const; + string ToString() const; private: @@ -150,12 +154,22 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); // properties, which do invariant checks before / after the operation. class ShapeUtil { public: + // Data structure which describes the coordinates and the shape, of a tuple + // shaped sub-shape. 
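The new ShapeUtil::Hash above hashes the element type plus dimensions (and layout) for non-tuple shapes and folds subshape hashes together for tuples. Here is a self-contained analogue using std::hash, with a boost-style mixer standing in for tensorflow::Hash64Combine; it mirrors the structure of the diff but is not bit-for-bit the same function (in particular, it omits the layout term).

```c++
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

struct MiniShape {
  int element_type = 0;                 // stand-in for PrimitiveType
  std::vector<int64_t> dimensions;
  std::vector<MiniShape> tuple_shapes;  // non-empty => tuple
};

// Stand-in for tensorflow::Hash64Combine (boost-style mixing).
size_t HashCombine(size_t seed, size_t value) {
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

size_t HashShape(const MiniShape& shape) {
  if (!shape.tuple_shapes.empty()) {
    // Tuple: fold the subshape hashes together.
    size_t h = 0;
    for (const MiniShape& sub : shape.tuple_shapes) {
      h = HashCombine(h, HashShape(sub));
    }
    return h;
  }
  // Non-tuple: seed with the element type, then mix in each dimension.
  // (The real implementation also folds in LayoutUtil::Hash(shape.layout()).)
  size_t h = std::hash<int>()(shape.element_type);
  for (int64_t dim : shape.dimensions) {
    h = HashCombine(h, std::hash<int64_t>()(dim));
  }
  return h;
}

int main() {
  MiniShape f32{1, {3, 4}, {}};
  MiniShape tuple{0, {}, {f32, f32}};
  std::cout << HashShape(f32) << "\n" << HashShape(tuple) << "\n";
}
```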
+ struct IndexedShape { + IndexedShape() = default; + IndexedShape(ShapeIndex index, Shape shape) + : index(std::move(index)), shape(std::move(shape)) {} + ShapeIndex index; + Shape shape; + }; + // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: !IsTuple(shape) + // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); // Returns true if 'shape' has zero elements. @@ -166,13 +180,11 @@ class ShapeUtil { // shapes. This includes only the size of the top-level buffer. For example, a // tuple is stored as an array of pointers to other buffers. In this case, // this method only returns the size of the pointer array. - // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) && - // !ShapeUtil::IsOpaque(shape) static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1); // Returns the number of bytes used to store the primitive_type. // - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) + // Precondition: ShapeUtil::IsArray(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -279,10 +291,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0; + return IsArray(shape) && Rank(shape) == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0; + return IsArray(shape) && TrueRank(shape) == 0; } static bool IsScalarF32(const Shape& shape); @@ -311,6 +323,10 @@ class ShapeUtil { // into a custom operation. static Shape MakeOpaqueShape(); + // Creates a token shape. Values of this shape are used for ordering + // side-effecting operations. + static Shape MakeTokenShape(); + // Appends a shape to the given tuple. static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape); @@ -410,11 +426,15 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } + // Returns whether the shape is an token value used for ordering + // side-effecting operations. + static bool IsToken(const Shape& shape) { + return shape.element_type() == TOKEN; + } + // Returns whether the shape is an array. Note that scalars are considered // arrays. - static bool IsArray(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape); - } + static bool IsArray(const Shape& shape); // Returns whether the shape is a tuple with at least one element which is // also a tuple. @@ -461,6 +481,13 @@ class ShapeUtil { // shape. static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index); + // Returns the number of leaves in the shape. + static int64 GetLeafCount(const Shape& shape); + + // Retrieves all the leaf shapes and their indexes, in the order walked by + // the ForEachSubshape() API. + static std::vector GetLeafShapes(const Shape& shape); + // Calls the given visitor function for each subshape of the given shape. // Subshapes are visited in DFS pre-order starting with the entire shape // (index {}). 
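The convenience ForEachIndex overloads added in the next hunk visit every element index of an array shape using a zero base and unit increments. A standalone sketch of such an index "odometer" loop is below; the visiting order here is simply last-dimension-fastest and does not attempt to reproduce the layout-aware order of the real helpers, and the void-returning visitor is a simplification.

```c++
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Visits every index of a dense shape with the given dimensions, starting at
// a zero base with unit increments in every dimension.
void ForEachIndex(const std::vector<int64_t>& dimensions,
                  const std::function<void(const std::vector<int64_t>&)>& visitor) {
  std::vector<int64_t> index(dimensions.size(), 0);
  if (dimensions.empty()) {  // scalar: a single empty index
    visitor(index);
    return;
  }
  for (int64_t dim : dimensions) {
    if (dim == 0) return;  // zero-element shape: nothing to visit
  }
  while (true) {
    visitor(index);
    // Odometer increment: bump the last dimension, carry leftwards.
    int64_t d = static_cast<int64_t>(index.size()) - 1;
    for (; d >= 0; --d) {
      if (++index[d] < dimensions[d]) break;
      index[d] = 0;
    }
    if (d < 0) return;  // wrapped past the first dimension: done
  }
}

int main() {
  ForEachIndex({2, 3}, [](const std::vector<int64_t>& idx) {
    std::cout << "{" << idx[0] << "," << idx[1] << "}\n";
  });  // {0,0} {0,1} {0,2} {1,0} {1,1} {1,2}
}
```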
@@ -626,6 +653,28 @@ class ShapeUtil { .IgnoreError(); } + // These convenience wrappers don't take `base`, `count` and `incr` + // explicitly, but iterate over every element in `shape` instead. + + template + static Status ForEachIndexWithStatus(const Shape& shape, + const FnType& visitor_function) { + std::vector base(shape.dimensions_size()); + std::vector incr(shape.dimensions_size(), 1); + return ForEachIndexWithStatus(shape, base, + /*count=*/AsInt64Slice(shape.dimensions()), + incr, visitor_function); + } + + template + static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { + ForEachIndexWithStatus(shape, + [&](tensorflow::gtl::ArraySlice indices) { + return StatusOr(visitor_function(indices)); + }) + .IgnoreError(); + } + // A parallel version of ForEachIndex(WithStatus). This can only be used if // the visitor_function is thread-safe and the order of iteration does not // matter. @@ -650,6 +699,9 @@ class ShapeUtil { .ok()); } + // Compute a hash for `shape`. + static size_t Hash(const Shape& shape); + private: // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 13582a2a267854..ecdb6532f1d743 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { } TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2]), f32[3])"; + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeShape(F32, {3}), }); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) @@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST(ShapeUtilTest, ParseInvalidShapeString) { string shape_strings[] = { "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", @@ -295,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); + + EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN)); + EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { @@ -449,19 +471,21 @@ 
TEST(ShapeUtilTest, IsLeafIndex) { TEST(ShapeUtilTest, HumanString) { Shape opaque = ShapeUtil::MakeOpaqueShape(); + Shape token = ShapeUtil::MakeTokenShape(); Shape scalar = ShapeUtil::MakeShape(F32, {}); Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); + Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); + EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(nested_tuple)); EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); @@ -470,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) { EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})", - ShapeUtil::HumanStringWithLayout(nested_tuple)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + ShapeUtil::HumanStringWithLayout(nested_tuple)); ProgramShape prog = ShapeUtil::MakeProgramShape( {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); @@ -481,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) { "(unknown): u32[1,2], " "(unknown): s32[3,4], " "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); prog.add_parameter_names("arg0"); @@ -497,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) { "matrix: u32[1,2], " "matrix2: s32[3,4], " "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); } @@ -713,6 +742,16 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(ShapeUtilTest, StripDegenerateDimensions) { + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions( + ShapeUtil::MakeShape(F32, {3, 1, 2})), + ShapeUtil::MakeShape(F32, {3, 2}))); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::StripDegenerateDimensions( + ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)), + ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10))); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index 
4eb3bf3766412d..69abb51852ac09 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -21,7 +21,7 @@ limitations under the License. namespace xla { -using tensorflow::Status; +using tensorflow::Status; // TENSORFLOW_STATUS_OK } // namespace xla diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index cccbce5fc83af8..0e1387c93938fa 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. // // The primary use-case for StatusOr is as the return value of a // function which may fail. diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f9d25945bc6175..377a618ffbd993 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same::element_type, char>(), ""); } +TEST(StatusOr, NullPointerStatusOr) { + // As a very special case, null-plain-pointer StatusOr used to be an + // error. Test that it no longer is. + StatusOr null_status(nullptr); + EXPECT_TRUE(null_status.ok()); + EXPECT_EQ(null_status.ValueOrDie(), nullptr); +} + TEST(StatusOr, TestNoDefaultConstructorInitialization) { // Explicitly initialize it with an error code. StatusOr statusor(tensorflow::errors::Cancelled("")); @@ -405,7 +413,7 @@ TEST(StatusOr, TestPointerValueConst) { EXPECT_EQ(&kI, thing.ValueOrDie()); } -// NOTE(tucker): tensorflow::StatusOr does not support this kind +// NOTE(tucker): StatusOr does not support this kind // of resize op. // TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) { // using EvilType = std::vector>; diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 17bae2e4f61126..8918350135fbb8 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -40,13 +40,10 @@ class Literal; namespace testing { namespace internal_status { -inline const ::tensorflow::Status& GetStatus( - const ::tensorflow::Status& status) { - return status; -} +inline const Status& GetStatus(const Status& status) { return status; } template -inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { +inline const Status& GetStatus(const StatusOr& status) { return status.status(); } } // namespace internal_status @@ -57,21 +54,17 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { // The following macros are similar to macros in gmock, but deliberately named // differently in order to avoid conflicts in files which include both. 
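Given the rewritten statusor.h comment and the new null-pointer test above, a short usage sketch may help. The pattern below only uses members that already appear in this diff (ok(), ValueOrDie(), status()) and assumes the TensorFlow/XLA build environment; the ParsePositive function itself is hypothetical.

```c++
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

// A function that either produces a value or explains why it could not.
StatusOr<int> ParsePositive(int raw) {
  if (raw <= 0) {
    return tensorflow::errors::InvalidArgument("expected a positive value");
  }
  return raw;
}

void Demo() {
  StatusOr<int> result = ParsePositive(3);
  if (result.ok()) {
    int value = result.ValueOrDie();  // Safe only after checking ok().
    (void)value;
  } else {
    // result.status() carries the error for logging or propagation.
  }
}

}  // namespace xla
```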
-// Macros for testing the results of functions that return tensorflow::Status or +// Macros for testing the results of functions that return Status or // StatusOr (for any type T). -#define EXPECT_IS_OK(expression) \ - EXPECT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) -#define EXPECT_IS_NOT_OK(expression) \ - EXPECT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_NOT_OK(expression) \ + EXPECT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_OK -#define ASSERT_IS_OK(expression) \ - ASSERT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_NOT_OK -#define ASSERT_IS_NOT_OK(expression) \ - ASSERT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_NOT_OK(expression) \ + ASSERT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index c28d14ba8ac3a0..7f6bbe6f879fd9 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -87,12 +87,12 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -117,11 +117,11 @@ cc_library( "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -138,8 +138,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -152,7 +152,6 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", @@ -188,8 +187,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -288,8 +285,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -313,7 +308,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -335,7 +329,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -378,7 +371,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -398,7 +390,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -422,8 +413,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -450,8 +439,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -472,7 +459,6 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -491,7 +477,6 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", 
"//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -528,7 +513,6 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -552,7 +536,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -572,8 +555,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -598,8 +579,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -626,12 +605,12 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -640,6 +619,7 @@ xla_test( xla_test( name = "exhaustive_f32_elementwise_op_test", + size = "enormous", srcs = ["exhaustive_f32_elementwise_op_test.cc"], backends = [ "cpu", @@ -647,13 +627,13 @@ xla_test( ], shard_count = 48, tags = [ - "enormous", "manual", "notap", ], deps = [ ":client_library_test_base", ":literal_test_util", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -695,7 +675,6 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -718,8 +697,8 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -739,7 +718,6 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", 
"//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -764,7 +742,6 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -788,7 +765,6 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -800,30 +776,42 @@ xla_test( ], ) +CONVOLUTION_TEST_DEPS = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", +] + xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], + deps = CONVOLUTION_TEST_DEPS, +) + +xla_test( + name = "convolution_test_gpu_alternative_layout", + timeout = "long", + srcs = ["convolution_test.cc"], + backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, + backends = ["gpu"], + shard_count = 25, + deps = CONVOLUTION_TEST_DEPS, ) xla_test( @@ -841,7 +829,6 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -866,7 +853,6 @@ xla_test( 
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -928,8 +914,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -958,8 +942,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1000,7 +982,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1053,8 +1034,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1076,7 +1055,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1106,8 +1084,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -1201,6 +1177,7 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1218,9 +1195,9 @@ xla_test( ], deps = [ ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1237,8 +1214,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -1278,7 +1253,6 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1301,7 +1275,6 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1341,7 +1314,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1359,7 +1331,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1385,8 +1356,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1408,7 +1377,6 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1480,8 +1448,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -1529,7 +1495,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1542,6 +1507,30 @@ xla_test( ], ) +xla_test( + name = "cross_replica_sum_test", + srcs = ["cross_replica_sum_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + 
"//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], @@ -1571,8 +1560,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1593,7 +1580,6 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1617,8 +1603,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1639,7 +1623,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1658,7 +1641,6 @@ xla_test( srcs = ["execution_profile_test.cc"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1673,7 +1655,6 @@ xla_test( args = ["--xla_hlo_profile"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1779,8 +1760,6 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1808,8 +1787,6 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1847,8 +1824,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1864,7 +1839,10 @@ xla_test( xla_test( name = "local_client_execute_test", + # TODO(b/79375911): Test times out in LLVM at normal size. + size = "large", srcs = ["local_client_execute_test.cc"], + shard_count = 30, tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:literal_util", @@ -1874,8 +1852,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1930,24 +1906,6 @@ xla_test( ], ) -xla_test( - name = "set_return_value_test", - srcs = ["set_return_value_test.cc"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - xla_test( name = "reshape_motion_test", srcs = ["reshape_motion_test.cc"], @@ -1961,8 +1919,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1979,6 +1935,7 @@ xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], deps = [ + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2062,7 +2019,6 @@ xla_test( ":local_client_test_base", ":test_utils", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index e8a5efe796a920..36a706496918ac 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ 
b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2225,6 +2225,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { + XlaBuilder builder(TestName()); + auto a = + builder.ConstantR1({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); + builder.Clz(a); + + ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 4e65cf11f3f1a0..ca337e78840e77 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 6ebbf7191833ef..51b9f0d3e330e7 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(42.0), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0(42.0), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, - error_spec_); + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { @@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralView::Create(*result, {0}), error_spec_); + LiteralSlice(*result, {0}), error_spec_)); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralView::Create(*result, {1}), error_spec_); + LiteralSlice(*result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - 
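// Note (illustrative reconstruction, not part of the original patch): in this copy of the patch the
// angle-bracketed template arguments on the builder helpers appear to have been stripped. Assuming
// the elided type argument is int64, the new ClzS64s test above likely reads:
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
  XlaBuilder builder(TestName());
  auto a = builder.ConstantR1<int64>(
      {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1});
  builder.Clz(a);  // count-leading-zeros, elementwise

  ComputeAndCompareR1<int64>(&builder, {64, 63, 32, 1, 0}, {});
}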
LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { @@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_); + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 5e42365ae38dcc..5fd33b50c94356 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,16 +32,16 @@ namespace { class CallOpTest : public ClientLibraryTestBase { protected: - Computation CreateR0F32IdentityComputation() { - ComputationBuilder builder(client_, "Identity"); + XlaComputation CreateR0F32IdentityComputation() { + XlaBuilder builder("Identity"); builder.Parameter(0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); } - Computation CreateR1S0F32AdditionComputation() { - ComputationBuilder builder(client_, "Addition"); + XlaComputation CreateR1S0F32AdditionComputation() { + XlaBuilder builder("Addition"); auto x = builder.Parameter(0, r1s0f32_, "x"); auto y = builder.Parameter(1, r1s0f32_, "y"); builder.Add(x, y); @@ -50,8 +50,8 @@ class CallOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR1S2F32AdditionComputation() { - ComputationBuilder builder(client_, "Addition"); + XlaComputation CreateR1S2F32AdditionComputation() { + XlaBuilder builder("Addition"); auto x = builder.Parameter(0, r1s2f32_, "x"); auto y = builder.Parameter(1, r1s2f32_, "y"); builder.Add(x, y); @@ -60,8 +60,8 @@ class CallOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0F32TupleComputation() { - ComputationBuilder builder(client_, "Tuple"); + XlaComputation CreateR0F32TupleComputation() { + XlaBuilder builder("Tuple"); builder.Tuple({builder.Parameter(0, r0f32_, "x")}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); @@ -74,8 +74,8 @@ class CallOpTest : public ClientLibraryTestBase { }; XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR0F32IdentityComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR0F32IdentityComputation(); auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); builder.Call(callee, {constant}); @@ -83,8 +83,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { } XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR1S0F32AdditionComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR1S0F32AdditionComputation(); auto x = builder.ConstantLiteral(*Literal::CreateR1({})); auto y = builder.ConstantLiteral(*Literal::CreateR1({})); builder.Call(callee, {x, y}); @@ -93,8 +93,8 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { } XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR1S2F32AdditionComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); builder.Call(callee, {x, y}); @@ -103,23 +103,23 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { } XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { - ComputationBuilder builder(client_, "inner"); + XlaBuilder 
builder("inner"); { auto x = builder.Parameter(0, r0f32_, "x"); builder.Add(x, builder.ConstantR0(1.0)); } - TF_ASSERT_OK_AND_ASSIGN(Computation inner, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, builder.Build()); - ComputationBuilder builder2(client_, "outer"); + XlaBuilder builder2("outer"); { auto x = builder2.Parameter(0, r0f32_, "x"); x = builder2.Call(inner, {x}); x = builder2.Call(inner, {x}); x = builder2.Call(inner, {x}); } - TF_ASSERT_OK_AND_ASSIGN(Computation outer, builder2.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation outer, builder2.Build()); - ComputationBuilder builder3(client_, "outermost"); + XlaBuilder builder3("outermost"); { auto x = builder3.Parameter(0, r0f32_, "x"); x = builder3.Call(outer, {x}); @@ -134,8 +134,8 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { } XLA_TEST_F(CallOpTest, CallR0F32Tuple) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR0F32TupleComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR0F32TupleComputation(); auto elem = Literal::CreateR0(42.0); auto tuple = Literal::MakeTuple({elem.get()}); builder.Call(callee, {builder.ConstantLiteral(*elem)}); diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index f594cc10ac6496..660ff0cad56662 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -35,7 +35,7 @@ using ::testing::ContainsRegex; class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { - ComputationBuilder builder(client_, "add_two_params"); + XlaBuilder builder("add_two_params"); auto param_literal = Literal::CreateR1({1.1f, 2.2f}); auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); @@ -75,7 +75,7 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { } XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { - ComputationBuilder builder(client_, "add_two_params"); + XlaBuilder builder("add_two_params"); auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1"); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 22660c35dcaa0e..bf8ed4d9fb0bc6 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -94,27 +93,13 @@ string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } -template StatusOr> ClientLibraryTestBase::Execute( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout) { - ExecutionOptions execution_options = execution_options_; - if (shape_with_output_layout != nullptr) { - *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; - } - return client_->ExecuteAndTransfer(computation, arguments, - &execution_options); -} - StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -128,17 +113,6 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -template <> -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout) { - // Build the computation, as a convenience. - TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); -} - -template <> StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout) { @@ -162,18 +136,6 @@ ClientLibraryTestBase::ExecuteAndTransferReference( &execution_options); } -std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - return Execute(builder, arguments).ConsumeValueOrDie(); -} - -std::unique_ptr ClientLibraryTestBase::ExecuteAndTransferOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); -} - string ClientLibraryTestBase::ExecuteToString( XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { auto computation_status = builder->Build(); @@ -191,32 +153,6 @@ string ClientLibraryTestBase::ExecuteToString( } } -string ClientLibraryTestBase::ExecuteToString( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - auto computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status().ToString(); - } - auto computation = computation_status.ConsumeValueOrDie(); - - auto result = - client_->ExecuteAndTransfer(computation, arguments, &execution_options_); - if (!result.ok()) { - return result.status().ToString(); - } else { - return result.ValueOrDie()->ToString(); - } -} - -void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = Literal::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, 
*expected_literal, - arguments); -} - void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { @@ -225,27 +161,24 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, error, shape_with_layout)); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::Computation& computation, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output) { @@ -266,12 +199,11 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( "Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( - const xla::Computation& computation, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& computation, const Literal& /*expected*/, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output, @@ -281,8 +213,8 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( // This is a recursive function. It's an std::function instead of a lambda // because it needs to capture itself. The index is the index of the argument // to try all layouts for. - std::function choose; - choose = [&, this](int64 index) -> tensorflow::Status { + std::function choose; + choose = [&, this](int64 index) -> Status { if (index < arguments.size()) { // Try out all layouts for the operand. TF_ASSIGN_OR_RETURN(auto literal, @@ -295,7 +227,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); layout_strings.pop_back(); - return tensorflow::Status::OK(); + return Status::OK(); } std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); @@ -313,7 +245,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( layout_strings.pop_back(); } while ( std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } // Every argument has an assigned layout. 
@@ -328,34 +260,14 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( tensorflow::strings::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); - return tensorflow::Status::OK(); + return Status::OK(); }; return choose(0); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice /*arguments*/, - const std::function& /*verify_output*/) { - return Unimplemented("not yet implemented for XlaComputation"); -} - -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( - const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice /*arguments*/, - const std::function& /*verify_output*/, - const Shape* /*output_with_layout*/) { - return Unimplemented("not yet implemented for XlaComputation"); -} - -template -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -382,7 +294,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -396,7 +308,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -408,13 +320,12 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + return Status::OK(); } -template -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -435,7 +346,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -449,7 +360,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_near = [&](const Literal& 
actual, const string& error_message) { - LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)) + << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -461,8 +373,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + return Status::OK(); } void ClientLibraryTestBase::ComputeAndCompareR1U8( @@ -484,9 +396,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( EXPECT_EQ(expected, actual->GetR1U8AsString()); } -template void ClientLibraryTestBase::ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -494,12 +405,11 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); } -template void ClientLibraryTestBase::ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -507,61 +417,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(expected, *actual, error); -} - -void ClientLibraryTestBase::ComputeAndCompare( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments) { - auto status_or_data = ComputeValueAndReference(builder, operand, arguments); - EXPECT_IS_OK(status_or_data); - if (!status_or_data.ok()) { - return; - } - std::unique_ptr reference, result; - std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*reference, *result); -} - -void ClientLibraryTestBase::ComputeAndCompare( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { - auto status_or_data = ComputeValueAndReference(builder, operand, arguments); - EXPECT_IS_OK(status_or_data); - if (!status_or_data.ok()) { - return; - } - std::unique_ptr reference, result; - std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*reference, *result, error); -} - -StatusOr, std::unique_ptr>> -ClientLibraryTestBase::ComputeValueAndReference( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments) { - // Transfer the arguments to the executor service. We put the unique_ptr's - // into a vector to keep the data alive on the service until the end of this - // function. - std::vector> argument_data; - for (const auto& arg : arguments) { - TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg)); - argument_data.push_back(std::move(data)); - } - - // Create raw pointers to the GlobalData for the rest of the call stack. 
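// Note (illustrative, not part of the original patch): throughout broadcast_test.cc and
// client_library_test_base.cc the patch replaces LiteralTestUtil::ExpectEqual/ExpectNear with gtest
// assertions over LiteralTestUtil::Equal/Near, so the comparison helpers act as predicates and the
// assertion happens at the call site. A minimal sketch of that call-site pattern:
void CheckLiteralsNear(const Literal& expected, const Literal& actual,
                       ErrorSpec error, const string& error_message) {
  // Near is used as a predicate; the diagnostic message is attached by the caller.
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error)) << error_message;
}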
- std::vector argument_data_ptr; - std::transform( - argument_data.begin(), argument_data.end(), - std::back_inserter(argument_data_ptr), - [](const std::unique_ptr& data) { return data.get(); }); - - TF_ASSIGN_OR_RETURN( - auto reference, - builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments)); - TF_ASSIGN_OR_RETURN(auto result, - ExecuteAndTransfer(builder, argument_data_ptr)); - return std::make_pair(std::move(reference), std::move(result)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -573,7 +429,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*reference, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -586,7 +442,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*reference, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); } StatusOr, std::unique_ptr>> @@ -651,8 +507,8 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { return computation_status.ConsumeValueOrDie(); } -Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { - ComputationBuilder builder(client_, "relu_sensitivity"); +XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { + XlaBuilder builder("relu_sensitivity"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); auto activation = builder.Parameter(0, shape, "activation"); auto backprop = builder.Parameter(1, shape, "backprop"); @@ -693,14 +549,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } -ComputationDataHandle ClientLibraryTestBase::AddParam( - const Literal& argument, ComputationBuilder* builder) { - ComputationDataHandle data_handle; - arguments_.push_back(CreateParameterAndTransferLiteral( - arguments_.size(), argument, "", builder, &data_handle)); - return data_handle; -} - XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { XlaOp data_handle; @@ -709,59 +557,39 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, return data_handle; } -ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( - const Literal& literal, ComputationBuilder* builder) { - return builder->ConstantLiteral( - use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); -} - XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return builder->ConstantLiteral( - use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); + use_bfloat16_ ? 
*Literal::ConvertF32ToBF16(literal) : literal); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = Literal::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - -template StatusOr> ClientLibraryTestBase::Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - -template StatusOr> ClientLibraryTestBase::Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); - } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 32eea7c2f3a65d..0499fec5898a42 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -25,10 +25,9 @@ limitations under the License. 
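// Note (illustrative, not part of the original patch): the CreateParameterAndTransferLiteral
// overloads added above convert an F32 literal to BF16 when use_bfloat16_ is set, transfer the
// (possibly converted) literal to the service, and declare a matching Parameter. A hedged usage
// sketch from inside a test body (the value and name are made up for the example):
XlaBuilder builder(TestName());
XlaOp param;
std::unique_ptr<GlobalData> data = CreateParameterAndTransferLiteral(
    /*parameter_number=*/0, *Literal::CreateR0<float>(42.0f), "param0", &builder, &param);
builder.Add(param, param);  // the returned GlobalData is later passed as the execution argument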
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -91,21 +90,11 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience methods for building and running a computation with the member // execution options. Modify execution_options_ in your test if you want to // customize the options. - template StatusOr> Execute( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); - // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once - // the migration to XlaBuilder is complete. - - template StatusOr> ExecuteAndTransfer( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout = nullptr); - - StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); StatusOr> ExecuteAndTransfer( @@ -121,101 +110,90 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); - // Convenience OrDie variants of above methods. - std::unique_ptr ExecuteOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - std::unique_ptr ExecuteAndTransferOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. string ExecuteToString(XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); - string ExecuteToString(ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are // templated on the native host type which maps to specific XLA types (See - // ComputationBuilder/XlaBuilder for details). For each rank, two forms are + // XlaBuilder for details). For each rank, two forms are // provided: one for floating point types with an ErrorSpec parameter, and one // for integral types without the ErrorSpec parameter. 
- template - void ComputeAndCompareR0(BuilderT* builder, NativeT expected, + template + void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR0(BuilderT* builder, NativeT expected, + template + void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR1(BuilderT* builder, + template + void ComputeAndCompareR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR1(BuilderT* builder, + template + void ComputeAndCompareR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // As above, but uses a bitmap to hold the predicate vector to avoid // deficiencies of vector. - void ComputeAndCompareR1(ComputationBuilder* builder, - const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments); void ComputeAndCompareR1(XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, + template + void ComputeAndCompareR2(XlaBuilder* builder, + const Array2D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, + template + void ComputeAndCompareR2(XlaBuilder* builder, + const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, + template + void ComputeAndCompareR3(XlaBuilder* builder, + const Array3D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, + template + void ComputeAndCompareR3(XlaBuilder* builder, + const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, + template + void ComputeAndCompareR4(XlaBuilder* builder, + const Array4D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, + template + void ComputeAndCompareR4(XlaBuilder* builder, + const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. - template void ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - template void ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. 
- template - tensorflow::Status ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, + Status ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - template - tensorflow::Status ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, + Status ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -227,25 +205,13 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. - template void ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments); - template void ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - // Convenience method for running a built computation and comparing the result - // with the HloEvaluator. - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments); - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments, - ErrorSpec error); - // Convenience method for running a built computation and comparing the result // with the reference result. void ComputeAndCompare(XlaBuilder* builder, @@ -257,7 +223,7 @@ class ClientLibraryTestBase : public ::testing::Test { // Create scalar operations for use in reductions. XlaComputation CreateScalarRelu(); XlaComputation CreateScalarMax(); - Computation CreateScalarReluSensitivity(); + XlaComputation CreateScalarReluSensitivity(); // Special case convenience functions for creating filled arrays. @@ -297,34 +263,25 @@ class ClientLibraryTestBase : public ::testing::Test { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. - template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - BuilderT* builder, HandleT* data_handle); + XlaBuilder* builder, XlaOp* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. - template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, BuilderT* builder, - HandleT* data_handle); + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle); // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. This function must be used for all parameters // or none and no parameters must be passed when invoking the computation if // using this mechanism. If using this mechanism, then each parameter must be // set exactly once. The first added parameter gets index 0, then 1 and so on. 
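A hedged illustration of the parameter-adding mechanism described above: every parameter goes through `AddParam`, indices are assigned consecutively starting at 0, and no arguments are passed when the computation is invoked. The literals and expected values are invented for the example.

```c++
// Sketch only: both operands are harness-managed parameters, so the argument
// list at ComputeAndCompare time stays empty.
XlaBuilder builder(TestName());
auto lhs = AddParam(*Literal::CreateR1<float>({1.0f, 2.0f}), &builder);  // index 0
auto rhs = AddParam(*Literal::CreateR1<float>({3.0f, 4.0f}), &builder);  // index 1
builder.Add(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {4.0f, 6.0f}, /*arguments=*/{},
                           ErrorSpec(0.0001));
```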
- ComputationDataHandle AddParam(const Literal& argument, - ComputationBuilder* builder); XlaOp AddParam(const Literal& argument, XlaBuilder* builder); - template - ComputationDataHandle AddParam(const Array& argument, - ComputationBuilder* builder) { - return AddParam(*Literal::CreateFromArray(argument), builder); - } template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { return AddParam(*Literal::CreateFromArray(argument), builder); @@ -333,18 +290,11 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a constant instruction with the given literal. When the // use_bfloat16 flag is set but the literal has F32 elements, the elements // will be converted to BF16s. - ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, - ComputationBuilder* builder); XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the use_bfloat16 // flag is set but the array has float elements, the elements will be // converted to bfloat16s. - template - ComputationDataHandle CreateConstantFromArray(const Array& array, - ComputationBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); - } template XlaOp CreateConstantFromArray(const Array& array, @@ -353,13 +303,6 @@ class ClientLibraryTestBase : public ::testing::Test { } // Same as CreateConstantFromArray, but for scalars. - template - ComputationDataHandle CreateConstantFromScalar(NativeT value, - ComputationBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateR0(value), - builder); - } - template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { return CreateConstantFromLiteral(*Literal::CreateR0(value), @@ -374,12 +317,12 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR0Parameter(NativeT value, int64 parameter_number, const string& name, - BuilderT* builder, - HandleT* data_handle); + XlaBuilder* builder, + XlaOp* data_handle); // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. @@ -389,10 +332,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that @@ -403,10 +346,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. 
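The explicit parameter helpers above return the server-side `GlobalData` and hand back the builder-side `XlaOp` through `data_handle`. A hedged usage sketch (values and the scalar type are illustrative), mirroring the pattern used by tests later in this change:

```c++
// Sketch: create a scalar parameter, transfer its value to the server, then
// pass the returned GlobalData handle when running the computation.
XlaBuilder builder(TestName());
XlaOp x;
std::unique_ptr<GlobalData> x_data =
    CreateR0Parameter<float>(3.0f, /*parameter_number=*/0, "x", &builder, &x);
builder.Add(x, builder.ConstantR0<float>(1.5f));
ComputeAndCompareR0<float>(&builder, 4.5f, {x_data.get()}, ErrorSpec(0.0001));
```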
- template + template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that @@ -417,10 +360,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Getter and setter for the use_bfloat16 flag, which indicates whether to run // tests with all float-type input/output converted to bfloat16. @@ -435,40 +378,18 @@ class ClientLibraryTestBase : public ::testing::Test { ExecutionOptions execution_options_; private: - // Build and run the computation with all permutations of output layouts. - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::Computation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const std::function& verify_output); - // Build and run the computation with all permutations of layouts of all input - // arguments. - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( - const xla::Computation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const std::function& verify_output, - const Shape* output_with_layout = nullptr); - - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output); - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the HloEvaluator. Returns two literals in the order of (expected, actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments); - // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). 
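A hedged sketch of the use_bfloat16 flag described above; the `set_use_bfloat16` accessor name and the tolerance are assumptions made for this example, not part of the diff.

```c++
// Sketch: rerun a float test with inputs/outputs converted to bfloat16.
// set_use_bfloat16 is the assumed setter for the flag described above.
set_use_bfloat16(true);
XlaBuilder builder(TestName());
XlaOp p;
auto p_data = CreateR1Parameter<float>({1.0f, 2.0f}, /*parameter_number=*/0,
                                       "p", &builder, &p);
builder.Add(p, p);
// bfloat16 has fewer mantissa bits, so a looser ErrorSpec is appropriate.
ComputeAndCompareR1<float>(&builder, {2.0f, 4.0f}, {p_data.get()},
                           ErrorSpec(0.01));
```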
@@ -484,9 +405,9 @@ class ClientLibraryTestBase : public ::testing::Test { std::vector> arguments_; }; -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - BuilderT* builder, NativeT expected, + XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR0(expected); @@ -494,9 +415,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - BuilderT* builder, NativeT expected, + XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -510,9 +431,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - BuilderT* builder, tensorflow::gtl::ArraySlice expected, + XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); @@ -520,9 +441,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - BuilderT* builder, tensorflow::gtl::ArraySlice expected, + XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -536,9 +457,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - BuilderT* builder, const Array2D& expected, + XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); @@ -546,9 +467,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - BuilderT* builder, const Array2D& expected, + XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -562,9 +483,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - BuilderT* builder, const Array3D& expected, + XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); @@ -572,9 +493,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - BuilderT* builder, const Array3D& expected, + XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -588,9 +509,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - BuilderT* builder, const Array4D& expected, + XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); @@ -598,9 +519,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - BuilderT* builder, const Array4D& expected, + XlaBuilder* builder, const Array4D& 
expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -614,13 +535,13 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments, error); } -template +template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, - BuilderT* builder, HandleT* data_handle) { + XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -628,13 +549,13 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -642,13 +563,13 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -656,13 +577,13 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -695,37 +616,6 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } -template -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, - const Literal& literal, - const string& name, - BuilderT* builder, - HandleT* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -template -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, BuilderT* builder, - HandleT* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = 
LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 1e544717967731..08671cf6244582 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -39,7 +38,7 @@ namespace { class ClientTest : public ClientLibraryTestBase {}; XLA_TEST_F(ClientTest, ExecuteWithLayout) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector> layouts = {{0, 1}, {1, 0}}; for (const std::vector& execute_layout : layouts) { @@ -63,15 +62,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, client_->Transfer(*data, &expected_literal->shape())); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})}); @@ -92,9 +91,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralView::Create(*result, {0})); + LiteralSlice(*result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralView::Create(*result, {1})); + LiteralSlice(*result, {1})); EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); @@ -143,7 +142,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { auto result_literal, client_->Transfer(*results[0], &expected_result->shape())); - LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 0f780fa87ef98f..50a006964869b3 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -17,10 +17,10 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -39,7 +39,7 @@ namespace { class CompilationCacheTest : public ClientLibraryTestBase { public: void ExecuteComputationR0F32( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; @@ -49,13 +49,13 @@ class CompilationCacheTest : public ClientLibraryTestBase { /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR0(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } void ExecuteComputationR2F32( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, std::initializer_list> expected_result, bool expect_cache_hit) { @@ -66,25 +66,28 @@ class CompilationCacheTest : public ClientLibraryTestBase { .ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR2(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } ErrorSpec error_spec_{0.0001}; }; -XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) { - ComputationBuilder builder(client_, TestName()); +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(42.0)); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { +// TODO(b/74197823): Disabled because there is no cache in the new design. 
+XLA_TEST_F(CompilationCacheTest, + DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = client_->TransferToServer(*Literal::CreateR0(42.0f)) .ConsumeValueOrDie(); @@ -95,9 +98,9 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { client_->TransferToServer(*Literal::CreateR0(456.0f)) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, /*expect_cache_hit=*/false); @@ -109,19 +112,20 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, MultipleComputations) { - ComputationBuilder builder_neg(client_, TestName() + "_neg"); +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) { + XlaBuilder builder_neg(TestName() + "_neg"); builder_neg.Neg(builder_neg.ConstantR0(42.0)); - Computation computation_neg = builder_neg.Build().ConsumeValueOrDie(); + XlaComputation computation_neg = builder_neg.Build().ConsumeValueOrDie(); - ComputationBuilder builder_exp(client_, TestName() + "_exp"); + XlaBuilder builder_exp(TestName() + "_exp"); builder_exp.Exp(builder_exp.ConstantR0(1.0)); - Computation computation_exp = builder_exp.Build().ConsumeValueOrDie(); + XlaComputation computation_exp = builder_exp.Build().ConsumeValueOrDie(); - ComputationBuilder builder_add(client_, TestName() + "_add"); + XlaBuilder builder_add(TestName() + "_add"); builder_add.Add(builder_add.ConstantR0(2.0), builder_add.ConstantR0(3.0)); - Computation computation_add = builder_add.Build().ConsumeValueOrDie(); + XlaComputation computation_add = builder_add.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation_neg, {}, -42.0, /*expect_cache_hit=*/false); @@ -133,7 +137,8 @@ XLA_TEST_F(CompilationCacheTest, MultipleComputations) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { // Create two GlobalData arrays with the same shape but different // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache @@ -148,9 +153,9 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR2F32(computation, {colmaj_handle.get()}, {{1.0f, 2.0f}, {3.0f, 4.0f}}, @@ -169,32 +174,5 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, MutatedComputation) { - // Build a computation, execute it, then mutate it. The mutated computation - // should not be in the cache until it is run once. 
This must be done through - // the stub interface because Computations built from ComputationBuilder are - // immutable. - ComputationBuilder builder(client_, TestName()); - auto neg = builder.Neg(builder.ConstantR0(42.0)); - Computation computation = builder.Build().ConsumeValueOrDie(); - - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); - - BinaryOpRequest request; - request.set_binop(BINOP_ADD); - *request.mutable_lhs() = neg; - *request.mutable_rhs() = neg; - OpRequest op_request; - *op_request.mutable_computation() = computation.handle(); - *op_request.mutable_binary_op_request() = request; - OpResponse response; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - ASSERT_TRUE(s.ok()); - - ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 7ea82a791f72ea..ba22530f1cfee5 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -18,8 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -88,17 +86,6 @@ class ComputeConstantTest : public ::testing::Test { return literal->Get({}); } - template - StatusOr ComputeConstantScalar( - Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice parameters = {}) { - TF_ASSIGN_OR_RETURN(auto literal, - builder->ComputeConstant( - operand, /*output_layout=*/nullptr, parameters)); - return literal->Get({}); - } - bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { StatusOr result = builder->IsConstant(operand); EXPECT_TRUE(result.ok()) << result.status(); @@ -150,26 +137,6 @@ TEST_F(ComputeConstantTest, ScalarRng) { } } -TEST_F(ComputeConstantTest, Param) { - for (ClientType client_type : client_types) { - Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); - auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"); - auto computation = b.Add(param, b.ConstantR0(1.5f)); - - std::vector arguments; - arguments.push_back(std::move(*Literal::CreateR0(42.5f))); - TF_ASSERT_OK_AND_ASSIGN(bool is_constant, - b.IsConstant(computation, arguments.size())); - EXPECT_TRUE(is_constant); - - TF_ASSERT_OK_AND_ASSIGN( - auto value, - ComputeConstantScalar(client, computation, &b, arguments)); - EXPECT_EQ(value, 44.0f); - } -} - TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); @@ -241,7 +208,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -255,7 +222,7 @@ 
TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -277,9 +244,9 @@ XLA_TEST_F(ComputeConstantTest, Layout) { std::unique_ptr expected_literal = Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 35aa3f6d696297..916ffadbc798ec 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,12 +21,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,7 +38,7 @@ class ConstantsTest : public ClientLibraryTestBase { }; TEST_F(ConstantsTest, ZeroCellF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({}); ComputeAndCompareR1(&builder, {}, {}, error_spec_); @@ -48,7 +47,7 @@ TEST_F(ConstantsTest, ZeroCellF32) { TEST_F(ConstantsTest, OneCellF32) { std::vector constant = {2.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); @@ -57,7 +56,7 @@ TEST_F(ConstantsTest, OneCellF32) { TEST_F(ConstantsTest, OneCellS32) { std::vector constant = {2}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}); @@ -66,7 +65,7 @@ TEST_F(ConstantsTest, OneCellS32) { TEST_F(ConstantsTest, OneCellU32) { std::vector constant = {2}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}); @@ -75,7 +74,7 @@ TEST_F(ConstantsTest, OneCellU32) { TEST_F(ConstantsTest, EightCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); @@ -85,14 +84,14 @@ TEST_F(ConstantsTest, SixteenCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); 
ComputeAndCompareR1(&builder, constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR2FromArray2D(Array2D(0, 2)); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); @@ -102,14 +101,14 @@ TEST_F(ConstantsTest, Small_2x2) { std::unique_ptr> constant = MakeLinspaceArray2D(100.0, 200.0, 2, 2); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR2FromArray2D(*constant); ComputeAndCompareR2(&builder, *constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_3x0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto constant = builder.ConstantLiteral( *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); @@ -117,7 +116,7 @@ TEST_F(ConstantsTest, Empty_3x0x2) { } TEST_F(ConstantsTest, Small_2x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D array3d({ // x0 x1 {{1.f, 2.f}, // y0 @@ -145,13 +144,13 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { Literal::CreateR4FromArray4D(input_array); { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantLiteral(*input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR4FromArray4D(input_array); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -159,17 +158,18 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantLiteral( *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), Literal::CreateR1({2.0, 42}).get()})); - std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); + std::unique_ptr result = + ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_); + {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_); + {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index e67a30d76c2fac..722d882471a41a 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -18,8 +18,8 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -44,7 +44,7 @@ class ConvertTest : public ClientLibraryTestBase { }; TEST_F(ConvertTest, ConvertR1S32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42, 64}); builder.ConvertElementType(a, S32); @@ -53,7 +53,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) { } TEST_F(ConvertTest, ConvertR1F32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0f, 64.0f}); builder.ConvertElementType(a, F32); @@ -62,7 +62,7 @@ TEST_F(ConvertTest, ConvertR1F32ToR1F32) { } TEST_F(ConvertTest, ConvertR1S32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42, 64}); builder.ConvertElementType(a, F32); @@ -71,7 +71,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { } TEST_F(ConvertTest, ConvertR1PREDToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false, true}); builder.ConvertElementType(a, S32); @@ -80,7 +80,7 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) { } TEST_F(ConvertTest, ConvertR1PREDToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false, true}); builder.ConvertElementType(a, F32); @@ -89,7 +89,7 @@ TEST_F(ConvertTest, ConvertR1PREDToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); builder.ConvertElementType(a, F32); @@ -98,7 +98,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { } TEST_F(ConvertTest, ConvertR1F32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.6, 64.4}); builder.ConvertElementType(a, S32); @@ -107,7 +107,7 @@ TEST_F(ConvertTest, ConvertR1F32ToR1S32) { } XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{ -9223371216516022272, -2, @@ -160,7 +160,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; @@ -179,7 +179,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); @@ -197,7 +197,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { } XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); auto 
arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); @@ -214,7 +214,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { } XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); @@ -231,7 +231,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Test cases from compiler_rt library. std::vector arg{0.0f, 0.5f, @@ -249,10 +249,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { -1.99f, -2.0f, -2.01f, - 0x1.FFFFFEp+62F, - 0x1.FFFFFCp+62F, - -0x1.FFFFFEp+62F, - -0x1.FFFFFCp+62F}; + 9223371487098961920.f, + 9223370937343148032.f, + -9223371487098961920.f, + -9223370937343148032.f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = @@ -268,7 +268,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, F32); @@ -277,7 +277,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, S32); @@ -286,7 +286,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, U32); @@ -295,7 +295,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32.0f, 64.0f}); builder.ConvertElementType(a, F64); @@ -304,7 +304,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { } XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32.0, 64.0}); builder.ConvertElementType(a, F32); @@ -313,7 +313,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { } TEST_F(ConvertTest, ConvertS32Extremes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {std::numeric_limits::min(), std::numeric_limits::max()}); builder.ConvertElementType(a, F32); @@ -325,7 +325,7 @@ TEST_F(ConvertTest, ConvertS32Extremes) { } TEST_F(ConvertTest, ConvertMapToS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); b->ConvertElementType(param, S32); @@ -337,7 +337,7 @@ TEST_F(ConvertTest, ConvertMapToS32) { } TEST_F(ConvertTest, ConvertMapToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); b->ConvertElementType(param, F32); @@ -354,7 +354,7 @@ 
TEST_F(ConvertTest, ConvertMapToF32) { // input -> convert -> reshape // the new convert should have the same element type as the old convert. TEST_F(ConvertTest, ConvertReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR1({42}); auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); builder.ConvertElementType(reshape, F32); @@ -393,7 +393,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { std::unique_ptr dot_lhs_handle, client_->TransferToServer(*Literal::CreateR1(input))); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConvertElementType( builder.Parameter( 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), @@ -413,7 +413,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { std::unique_ptr dot_lhs_handle, client_->TransferToServer(*Literal::CreateR1(input))); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConvertElementType( builder.Parameter( 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), @@ -424,28 +424,28 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { } XLA_TEST_F(ConvertTest, ConvertC64ToC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector x = {{42.0f, 64.0f}}; builder.ConvertElementType(builder.ConstantR1(x), C64); ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConvertTest, ConvertS64S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector x = {{-42, 64}}; builder.ConvertElementType(builder.ConstantR1(x), S64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64U64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector x = {{42, 64}}; builder.ConvertElementType(builder.ConstantR1(x), U64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector unsigned_x = {{42, UINT64_MAX}}; builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); std::vector signed_x = {{42, -1}}; @@ -453,7 +453,7 @@ XLA_TEST_F(ConvertTest, ConvertU64S64) { } XLA_TEST_F(ConvertTest, ConvertS64U64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector signed_x = {{42, -1, INT64_MIN}}; builder.ConvertElementType(builder.ConstantR1(signed_x), U64); std::vector unsigned_x = { diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 50d6e25d868c49..fea850dc135e33 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index ece7c3b05e7faf..2b3390ca98cb29 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -48,7 +49,7 @@ class CopyOpTest : public HloTestBase { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectEqual(literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); @@ -246,13 +247,13 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); auto empty = Literal::CreateFromShape(in_shape); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto param0 = builder.Parameter(0, in_shape, "input"); auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*empty, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc new file mode 100644 index 00000000000000..b151187c4b8f01 --- /dev/null +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class TrivialCrossReplicaSumTest : public HloTestBase {}; + +// Currently the CPU and GPU backends only support CrossReplicaSum with one +// replica. But we can at least check this. 
+ +XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p = f32[3] parameter(0) + ROOT crs = f32[3] cross-replica-sum(p), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal = Literal::CreateR1({1, 2, 3}); + EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); +} + +XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] parameter(1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ( + *Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); +} + +// On the GPU backend, constants get special handling. Someone might pass a +// constant to CRS to e.g. count the number of replicas -- we need to make sure +// it works. +XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { + const char* module_str = R"( + HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + + ENTRY test_computation { + p0 = f32[3] parameter(0) + p1 = f32[2] constant({10, 20}) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add + })"; + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto literal0 = Literal::CreateR1({1, 2, 3}); + auto literal1 = Literal::CreateR1({10, 20}); + EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), + *ExecuteAndTransfer(std::move(module), {literal0.get()})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index fe5621e8dc209d..bfe688e20d182d 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -36,9 +36,8 @@ class DeallocationTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. 
std::unique_ptr ExecuteAndCheckTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - Computation computation = builder->Build().ConsumeValueOrDie(); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) .ConsumeValueOrDie(); @@ -48,7 +47,7 @@ class DeallocationTest : public ClientLibraryTestBase { }; TEST_F(DeallocationTest, DeallocateScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -66,7 +65,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { } TEST_F(DeallocationTest, DeallocateVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -79,7 +78,7 @@ TEST_F(DeallocationTest, DeallocateVector) { } TEST_F(DeallocationTest, DeallocateEmptyVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -92,7 +91,7 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) { } XLA_TEST_F(DeallocationTest, DeallocateTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tuple({builder.ConstantR0(42.0), builder.ConstantR1({1.0, 2.0, 3.0})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -106,7 +105,7 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) { } XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto element = builder.ConstantR0(42.0); auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), element}); builder.Tuple({element, inner_tuple, element}); @@ -121,7 +120,7 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { } XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), builder.ConstantR1({1.0, 2.0, 3.0})}); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 3ab0ea4ad48c00..12789fe66530fe 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -42,9 +42,8 @@ class DeconstructTupleTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. 
std::unique_ptr ExecuteAndCheckTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - Computation computation = builder->Build().ConsumeValueOrDie(); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) .ConsumeValueOrDie(); @@ -54,7 +53,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase { }; TEST_F(DeconstructTupleTest, DeconstructTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2}); @@ -73,7 +72,7 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2}); @@ -103,7 +102,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2, const2, const1}); @@ -129,7 +128,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2, const1}); @@ -159,7 +158,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -170,7 +169,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = @@ -186,7 +185,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { } XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({builder.Tuple({const1, const2}), const1}); diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 1da7a96fe2388e..085a5105aca1c1 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" namespace xla { @@ -22,12 +23,12 @@ TEST_F(ClientLibraryTestBase, DeepGraph) { // intended to track, we need to set kDepth to 20000. // Unfortunately, setting it that high causes the test to time out. const int kDepth = 200; - ComputationBuilder b(client_, TestName()); - ComputationDataHandle x; - ComputationDataHandle y; + XlaBuilder b(TestName()); + XlaOp x; + XlaOp y; auto x_data = CreateR0Parameter(3, 0, "x", &b, &x); auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); - ComputationDataHandle z = x; + XlaOp z = x; for (int i = 0; i < kDepth; ++i) { z = b.Add(z, y); } diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index c4031dfee593a1..0fd846cef8095a 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -51,21 +51,20 @@ using TypesF16F32F64 = ::testing::Types; using TypesF16F32F64CF64 = ::testing::Types; #elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ - defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \ defined(XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX) using TypesF16F32 = ::testing::Types; using TypesF16F32F64 = ::testing::Types; -using TypesF16F32F64CF64 = - ::testing::Types; +using TypesF16F32F64CF64 = ::testing::Types; #else #error "Situation not handled yet" #endif // Check that we can safely pass an input tuple's elements to a dot operation. 
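The hunks above follow one mechanical migration: `ComputationBuilder` becomes `XlaBuilder` (and no longer takes the `Client*`), `Computation` becomes `XlaComputation`, and `ComputationDataHandle` becomes `XlaOp`. A minimal sketch of the post-migration test body, assuming the `ClientLibraryTestBase` members `client_` and `execution_options_` that these tests use (a fragment for illustration, not a standalone program):

```c++
// Post-migration pattern used throughout this diff (illustrative fragment).
XlaBuilder builder(TestName());             // the client is no longer passed in
XlaOp x = builder.ConstantR0<float>(42.0);  // ops are plain XlaOp handles
XlaOp y = builder.Add(x, x);
XlaComputation computation = builder.Build().ConsumeValueOrDie();
// Execution still goes through the client, exactly as before:
auto global_data =
    client_->Execute(computation, /*arguments=*/{}, &execution_options_)
        .ConsumeValueOrDie();
```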
-TEST_F(DotOperationTest, DotOfInputTupleElem) { - ComputationBuilder builder(client_, TestName()); +XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { + XlaBuilder builder(TestName()); - ComputationDataHandle param; + XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), @@ -86,7 +85,7 @@ TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64); XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); @@ -102,7 +101,7 @@ TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64); XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D({{3.0f, 4.0f}}); auto rhs = builder.ConstantFromArray({3.0f, 4.0f}); auto result = builder.Dot(lhs, rhs); @@ -113,7 +112,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR1({static_cast(2.0f)}); auto rhs = builder.ConstantR1({static_cast(3.0f)}); auto result = builder.Dot(lhs, rhs); @@ -124,7 +123,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantFromArray({1.0f, 2.5f, 42.0f}); auto rhs = builder.ConstantFromArray({11.0f, -1.0f, 0.5f}); auto result = builder.Dot(lhs, rhs); @@ -139,7 +138,7 @@ std::vector MinorToMajorForIsRowMajor(bool row_major) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto result = builder.Dot(lhs, rhs); @@ -150,7 +149,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); auto rhs = builder.ConstantR2FromArray2D( {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); @@ -162,7 +161,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D( {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); @@ -174,7 +173,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto rhs = 
builder.ConstantR2FromArray2D(Array2D(0, 2)); auto result = builder.Dot(lhs, rhs); @@ -185,7 +184,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto param0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); auto param1 = @@ -230,7 +229,7 @@ class SquareMatrixDot : public DotOperationTest { LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), @@ -315,7 +314,7 @@ void ParametricDotTest::TestImpl() { addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, @@ -491,7 +490,7 @@ class NonsquareMatrixDot : public DotOperationTest { MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), @@ -523,7 +522,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), @@ -538,7 +537,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto matrix1 = builder.ConstantR2FromArray2D({{1.0f, 2.0f}, {3.0f, 4.0f}}); auto matrix2 = builder.ConstantR2FromArray2D({{5.0f, 6.0f}, {7.0f, 8.0f}}); auto matrix12 = builder.Dot(matrix1, matrix2); @@ -559,7 +558,7 @@ TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64); // sync-dependent on bitcasts' operands. XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "x"); auto y = @@ -569,7 +568,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); // Slice batches into individual matrices and multiply them. - std::vector out_slices; + std::vector out_slices; for (int i = 0; i < 4; ++i) { // Slice off individual matrices and reshape to 2D tensors. 
auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); @@ -615,7 +614,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); auto y = @@ -677,7 +676,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto lhs_arg = builder.Parameter( 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), @@ -713,7 +712,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, new Array2D({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); @@ -761,7 +760,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, {4.0f, 3.0f}, {2.0f, 1.0f}})); - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2}), "lhs_arg_0"); @@ -799,5 +798,250 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{96.0, 105.0, 114.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; 
+ dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{105.0}, {105.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSReverseMM)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{105.0, 105.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSReverseMM)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{96.0}, {105.0}, {114.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
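As a standalone cross-check of the classic-matmul cases above (plain C++, no XLA dependency; the numbers come straight from the constants in `DotOfGatherOptimizationWithConstRHSClassicMM`): the `{1, 6}` dynamic slice starting at `{1, 0}` selects row 1 of the LHS, and contracting it against the 6x3 RHS reproduces the expected `{96, 105, 114}`.

```c++
#include <cassert>
#include <cstdio>

// Verifies the expected literal of DotOfGatherOptimizationWithConstRHSClassicMM:
// dot(DynamicSlice(lhs, {1, 0}, {1, 6}), rhs) == {{96, 105, 114}}.
int main() {
  const double lhs[2][6] = {{1, 2, 3, 4, 5, 6}, {6, 5, 4, 3, 2, 1}};
  const double rhs[6][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9},
                            {9, 8, 7}, {6, 5, 4}, {3, 2, 1}};
  const int slice_row = 1;  // start {1, 0}, slice size {1, 6}
  double out[3] = {0, 0, 0};
  for (int j = 0; j < 3; ++j) {
    for (int k = 0; k < 6; ++k) {
      out[j] += lhs[slice_row][k] * rhs[k][j];  // contract lhs dim 1 with rhs dim 0
    }
  }
  std::printf("{%g, %g, %g}\n", out[0], out[1], out[2]);
  assert(out[0] == 96 && out[1] == 105 && out[2] == 114);
  return 0;
}
```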
+XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{126.0, 129.0, 132.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{129.0}, {129.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
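The same kind of plain-C++ cross-check for the general-contracting-dimension case `DotOfGatherOptimizationWithConstRHSRows` above: the `{6, 1}` slice starting at `{0, 1}` extracts column 1 of the 6x2 LHS, and contracting dimension 0 of both operands against the 6x3 RHS reproduces the expected `{126, 129, 132}`.

```c++
#include <cassert>

// Verifies the expected literal of DotOfGatherOptimizationWithConstRHSRows:
// DotGeneral with lhs/rhs contracting dimension 0 over column 1 of the LHS.
int main() {
  const double lhs[6][2] = {{1, 2}, {3, 4}, {5, 6}, {6, 5}, {4, 3}, {2, 1}};
  const double rhs[6][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9},
                            {9, 8, 7}, {6, 5, 4}, {3, 2, 1}};
  const int slice_col = 1;  // start {0, 1}, slice size {6, 1}
  double out[3] = {0, 0, 0};
  for (int j = 0; j < 3; ++j) {
    for (int k = 0; k < 6; ++k) {
      out[j] += lhs[k][slice_col] * rhs[k][j];  // contract dim 0 with dim 0
    }
  }
  assert(out[0] == 126 && out[1] == 129 && out[2] == 132);
  return 0;
}
```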
+XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{56.0, 168.0, 91.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{168.0}, {168.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index ff53a84588fc04..49f3a10d227f2f 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -53,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR1Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); + void TestR1OOB() { + // Slice at dimension boundaries, but with out of bounds indices. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7}); } template @@ -78,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR2Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR2OOB() { + // Slice at dimension boundaries, but with out of bounds indices. 
RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, - {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); } template @@ -106,11 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR3Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR3OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, - {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); + {2, 1, 2}, {{{5, 6}}, {{11, 12}}}); } template @@ -199,19 +199,19 @@ class DynamicSliceTest : public ClientLibraryTestBase { XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } @@ -332,17 +332,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void TestWrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestOOB() { + // // Slice at dimension boundaries, but with out of bounds indices. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, - {10, 1, 2, 3, 4, 5, 8, 9}); + {0, 1, 2, 3, 4, 8, 9, 10}); // R2 Shape: [3, 3] RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, - {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}}); + {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); // R3 Shape: [2, 3, 2] RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, - {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); } template @@ -361,9 +361,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -476,20 +476,19 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { Array3D input_values(kSeq, kBatch, kDim); Array3D update_values(size, kBatch, kDim); Array3D expected_values(kSeq, kBatch, kDim); + index = std::min(std::max(0, index), kSeq - size); input_values.FillIota(static_cast(0)); T value = static_cast(10); update_values.FillIota(static_cast(value)); // TODO(b/34128753) Expected values may vary depending on backend when - // the update wraps. 
According to documentation, the results are technically - implementation specific where the update is out of bounds, and hence - we don't really know what to pass into ComputeAndCompareR3. + // the indices are out of bounds. expected_values.FillIota(static_cast(0)); for (int i = 0; i < size; i++) { for (int j = 0; j < kBatch; j++) { for (int k = 0; k < kDim; k++) { - expected_values((index + i) % kSeq, j, k) = value++; + expected_values(index + i, j, k) = value++; } } } @@ -547,12 +546,10 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int32WrapBF16) { - TestWrap(); -} -XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -615,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { // Tests for simple R3 case where the update is contiguous (i.e. the minor // two dimensions are not sliced). XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { - // Single element, no wrap. + // Single element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { - // Single element, no wrap. + // Single element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { - // Multiple element, no wrap. + // Multiple elements, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { - // Multiple element, no wrap. + // Multiple elements, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) { + // Multiple elements, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) { + // Multiple elements, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index c8cc8e40aa3210..a6ba6db5d3bf86 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License.
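The `Wrap`-to-`OOB` renames and the updated expected literals above encode a semantics change: instead of wrapping, an out-of-bounds start index is clamped so the whole slice or update stays inside the operand, which is exactly what the `std::min(std::max(0, index), kSeq - size)` line added to `RunR3Contiguous` does. A self-contained plain-C++ sketch of that clamping rule, assuming clamping is the intended behavior:

```c++
#include <algorithm>
#include <cassert>
#include <vector>

// 1-D DynamicSlice with a clamped start index: reproduces the TestR1OOB
// expectation {0..7} sliced with start 6 and size 4 -> {4, 5, 6, 7}.
std::vector<int> DynamicSlice1D(const std::vector<int>& operand, int start,
                                int size) {
  start = std::min(std::max(0, start),
                   static_cast<int>(operand.size()) - size);  // keep slice in bounds
  return std::vector<int>(operand.begin() + start,
                          operand.begin() + start + size);
}

int main() {
  std::vector<int> result =
      DynamicSlice1D({0, 1, 2, 3, 4, 5, 6, 7}, /*start=*/6, /*size=*/4);
  assert((result == std::vector<int>{4, 5, 6, 7}));
  return 0;
}
```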
==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" @@ -32,9 +33,9 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { client_->TransferToServer( *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); - ComputationBuilder b(client_, TestName() + ".add"); + XlaBuilder b(TestName() + ".add"); b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1")); - TF_ASSERT_OK_AND_ASSIGN(Computation dot_product, b.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation dot_product, b.Build()); ExecutionProfile execution_profile; TF_ASSERT_OK_AND_ASSIGN( diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index b28fe0c15a89a1..0a37e4d4236201 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -35,7 +36,7 @@ class ExhaustiveF32ElementwiseOpTest int64 input_size = end - begin; LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateFromDimensions(F32, {input_size}); @@ -78,9 +79,7 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #endif ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Log(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Log(input); }, std::log, known_incorrect_range); } @@ -96,17 +95,13 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { #endif ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Exp(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Exp(input); }, std::exp, known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Tanh(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Tanh(input); }, std::tanh, /*known_incorrect_range=*/{0, 0}); } diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index a5f6872c46c780..93d1c921c4a138 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -38,7 +38,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, pattern_path, pattern)); // Invoke FileCheck to check whether input matches `pattern`. 
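The `exhaustive_f32_elementwise_op_test` hunks above only swap the builder types in the op callbacks; the tests themselves sweep ranges of f32 inputs and compare `Log`, `Exp`, and `Tanh` against the `std::` reference functions. A toy standalone sweep over sampled f32 bit patterns, purely to illustrate the enumeration idea (it is not the test's actual harness, and the stride is an assumption to keep it fast):

```c++
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

int main() {
  // Walk the 32-bit pattern space with a coarse stride (a full sweep is 2^32).
  for (uint64_t bits = 0; bits < (1ull << 32); bits += (1u << 16)) {
    uint32_t b = static_cast<uint32_t>(bits);
    float x;
    std::memcpy(&x, &b, sizeof x);  // reinterpret the bit pattern as a float
    float y = std::tanh(x);
    // Reference property checked here: tanh maps every float into [-1, 1]
    // (or NaN for NaN inputs); the real test compares against the backend.
    assert(std::isnan(y) || (y >= -1.0f && y <= 1.0f));
  }
  return 0;
}
```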
- const char* file_check_path_suffix = "external/llvm/FileCheck"; + const char* file_check_path_suffix = "org_tensorflow/external/llvm/FileCheck"; string file_check_path; if (const char* test_srcdir = getenv("TEST_SRCDIR")) { file_check_path = JoinPath(test_srcdir, file_check_path_suffix); @@ -66,6 +66,11 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { // the error message generated by FileCheck and the inputs. bool succeeded = (exit_status == 0); if (!succeeded) { + LOG(WARNING) << "Tried to execute FileCheck at " << file_check_path; + if (!env->FileExists(file_check_path).ok()) { + LOG(WARNING) << "NOTE: FileCheck binary does not exist!"; + } + LOG(WARNING) << "FileCheck error: " << standard_error; LOG(WARNING) << "FileCheck input was:"; XLA_LOG_LINES(tensorflow::WARNING, input); diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index e75a41acacc3aa..71eb914a8e5eae 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -41,7 +41,7 @@ class FloorCeilTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice expected, Function f) { LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") << "}"; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto c = builder.ConstantR1(input); if (f == kCeil) { builder.Ceil(c); @@ -54,7 +54,7 @@ class FloorCeilTest : public ClientLibraryTestBase { void TestR0F32(float input, float expected, Function f) { LOG(INFO) << "input: " << expected; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto c = builder.ConstantR0(input); if (f == kCeil) { builder.Ceil(c); diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index f2aaf6621c1f0d..73f029b59bc56a 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/test.h" @@ -27,7 +27,7 @@ namespace { class FmaxSimpleTest : public ClientLibraryTestBase {}; TEST_F(FmaxSimpleTest, FmaxTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); auto y = builder.ConstantR1( diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 6f89e9164c8d44..e6f79b5ac55ddd 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -25,8 +25,7 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -119,9 +118,9 @@ class FusionTest : public HloTestBase { auto expected = Literal::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); } else { - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } } @@ -222,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) { const4, reshape3, add2, const1, const0}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{0.5}, {2.72}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. @@ -248,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{-1.0, 0.0, 1.0}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -308,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -323,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(5), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -337,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -352,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { 
hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -367,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -381,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{7}}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -395,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -409,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -424,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -439,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -455,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + 
LiteralTestUtil::Equal(*Literal::CreateR1({3, 2, 1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -472,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-3, -2, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -489,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -506,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -527,8 +534,9 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { /*instructions_to_fuse=*/{negate3, dynamic_slice2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-2, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -544,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -2}, {-3, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // TODO(b/64070202): Investigate failure. 
@@ -562,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -3}, {-2, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -592,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -613,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(-15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -662,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -675,21 +687,20 @@ XLA_TEST_F(FusionTest, SharedConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(Literal::CreateR1({0}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); hlo_module->AddEntryComputation(builder.Build()) - ->CreateFusionInstruction( - {add4, add3, add2, add1, const1}, - HloInstruction::FusionKind::kLoop); + ->CreateFusionInstruction({add4, add3, add2, add1, const1}, + HloInstruction::FusionKind::kLoop); HloComputation* entry_comp = hlo_module->entry_computation(); @@ -699,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) { // fused instruction contains the 
constant(2), the parameter, and 4 adds EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({8}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } @@ -779,7 +791,7 @@ void BM_ParallelFusion(int num_iters) { const int64 param2_dim1 = 1024; // Create computation. - ComputationBuilder builder(client, "ParallelFusion"); + XlaBuilder builder("ParallelFusion"); Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); auto param0 = builder.Parameter(0, shape0, "param0"); Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 4dd3acd9af1621..143ffbdeb409d9 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" // NB! TODO(b/74360564): These tests do not test out of bounds behavior since // that hasn't been specced yet. 
@@ -41,7 +41,7 @@ class GatherOperationTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text, config)); + ParseHloString(hlo_text, config)); EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); } }; @@ -399,6 +399,184 @@ ENTRY main { RunTest(hlo_text, operand.get(), gather_indices.get()); } +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[3,2] broadcast(one), dimensions={} + ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} + ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); 
+ RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, + FusedTensorFlowGatherNdNonDefaultIndexVectorDim) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { + const char* hlo_text = R"( +HloModule FusedDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[1,1] broadcast(one), dimensions={} + ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { + const string hlo_text = R"( +HloModule FusedBatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} + ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + class GatherClientLibraryTest : public ClientLibraryTestBase {}; XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { @@ -451,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { client_->ExecuteParallel(computation_instances)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectEqual( - *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index ec2f49d43bd8ce..76bf47845ca045 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -16,8 +16,7 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -39,7 +38,7 @@ class HalfTestBase : public ClientLibraryTestBase { }; using UnaryBuildFuncTy = - std::function; + std::function; struct UnaryOpTestParam { std::function compute_func; @@ -51,8 +50,8 @@ class UnaryOpTest : public HalfTestBase, XLA_TEST_P(UnaryOpTest, Ops) { std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); @@ -79,30 +78,21 @@ half round_imp(half value) { INSTANTIATE_TEST_CASE_P( half, UnaryOpTest, - ::testing::Values(UnaryOpTestParam{[](half x) { return abs(x); }, - &ComputationBuilder::Abs}, - UnaryOpTestParam{[](half x) { return round_imp(x); }, - &ComputationBuilder::Round}, - UnaryOpTestParam{[](half x) { return ceil(x); }, - &ComputationBuilder::Ceil}, - UnaryOpTestParam{[](half x) { return cos(x); }, - &ComputationBuilder::Cos}, - UnaryOpTestParam{[](half x) { return exp(x); }, - &ComputationBuilder::Exp}, - UnaryOpTestParam{[](half x) { return floor(x); }, - &ComputationBuilder::Floor}, - UnaryOpTestParam{[](half x) { return log(x); }, - &ComputationBuilder::Log}, - UnaryOpTestParam{[](half x) { return -x; }, - &ComputationBuilder::Neg}, - UnaryOpTestParam{[](half x) { return sign_imp(x); }, - &ComputationBuilder::Sign}, - UnaryOpTestParam{[](half x) { return sin(x); }, - &ComputationBuilder::Sin}, - UnaryOpTestParam{[](half x) { return tanh(x); }, - &ComputationBuilder::Tanh} + ::testing::Values( + UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, + &XlaBuilder::Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log}, + UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh} - )); + )); struct UnaryPredTestParam { std::function compute_func; @@ -115,8 +105,8 @@ class UnaryPredTest : public HalfTestBase, XLA_TEST_P(UnaryPredTest, Ops) { std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); @@ -136,11 +126,11 @@ XLA_TEST_P(UnaryPredTest, Ops) { INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ [](half x) { return isfinite(x); }, - &ComputationBuilder::IsFinite})); + &XlaBuilder::IsFinite})); using BinaryBuildFuncTy = std::function)>; + xla::XlaBuilder*, const xla::XlaOp& x, const xla::XlaOp& y, + tensorflow::gtl::ArraySlice)>; struct BinaryOpTestParam { 
std::function compute_func; @@ -153,12 +143,12 @@ class BinaryOpTest : public HalfTestBase, XLA_TEST_P(BinaryOpTest, Ops) { std::vector x({half(1.0), half(2.0), half(3.0), half(-4.0)}); std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); - ComputationDataHandle y_opnd; + XlaOp y_opnd; auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", &builder, &y_opnd); @@ -184,21 +174,21 @@ INSTANTIATE_TEST_CASE_P( half, BinaryOpTest, ::testing::Values( BinaryOpTestParam{[](half x, half y) { return x + y; }, - &ComputationBuilder::Add}, + &XlaBuilder::Add}, BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, - &ComputationBuilder::Atan2}, + &XlaBuilder::Atan2}, BinaryOpTestParam{[](half x, half y) { return x / y; }, - &ComputationBuilder::Div}, + &XlaBuilder::Div}, BinaryOpTestParam{[](half x, half y) { return max(x, y); }, - &ComputationBuilder::Max}, + &XlaBuilder::Max}, BinaryOpTestParam{[](half x, half y) { return min(x, y); }, - &ComputationBuilder::Min}, + &XlaBuilder::Min}, BinaryOpTestParam{[](half x, half y) { return x * y; }, - &ComputationBuilder::Mul}, + &XlaBuilder::Mul}, BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, - &ComputationBuilder::Pow}, + &XlaBuilder::Pow}, BinaryOpTestParam{[](half x, half y) { return x - y; }, - &ComputationBuilder::Sub} + &XlaBuilder::Sub} )); @@ -214,12 +204,12 @@ class BinaryPredTest XLA_TEST_P(BinaryPredTest, Ops) { std::vector x({half(1.0), half(2.0), half(0.2), half(-4.0)}); std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); - ComputationDataHandle y_opnd; + XlaOp y_opnd; auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", &builder, &y_opnd); @@ -239,17 +229,17 @@ XLA_TEST_P(BinaryPredTest, Ops) { INSTANTIATE_TEST_CASE_P( half, BinaryPredTest, ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, - &ComputationBuilder::Eq}, + &XlaBuilder::Eq}, BinaryPredTestParam{[](half x, half y) { return x != y; }, - &ComputationBuilder::Ne}, + &XlaBuilder::Ne}, BinaryPredTestParam{[](half x, half y) { return x >= y; }, - &ComputationBuilder::Ge}, + &XlaBuilder::Ge}, BinaryPredTestParam{[](half x, half y) { return x > y; }, - &ComputationBuilder::Gt}, + &XlaBuilder::Gt}, BinaryPredTestParam{[](half x, half y) { return x <= y; }, - &ComputationBuilder::Le}, + &XlaBuilder::Le}, BinaryPredTestParam{[](half x, half y) { return x < y; }, - &ComputationBuilder::Lt} + &XlaBuilder::Lt} )); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 9984aba089be89..08ed826c80823e 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,11 +23,11 @@ limitations under the License. 
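// Illustrative sketch only: the hlo_test_base changes below (a) move the HLO
// text parser include from tools/parser/ to service/hlo_parser.h, (b) let
// CreateNewModule() take an explicit module name, and (c) add
// GetModuleConfigForTest(). A hedged usage sketch built only from the
// signatures shown in this diff; the module name is a placeholder.
//
//   std::unique_ptr<HloModule> module = CreateNewModule();         // name defaults to TestName()
//   std::unique_ptr<HloModule> named = CreateNewModule("my_run");  // explicit name
//   HloModuleConfig config = GetModuleConfigForTest();             // debug options already applied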
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -93,17 +93,16 @@ HloTestBase::HloTestBase(se::Platform* test_platform, } /* static */ -std::unique_ptr HloTestBase::CreateNewModule() { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); +std::unique_ptr HloTestBase::CreateNewModule(const string& name) { + return MakeUnique(name, VersionedComputationHandle(), + GetModuleConfigForTest()); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); return debug_options; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 79fcea9403e6e2..eb3a2ea76a667a 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -85,13 +85,21 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - static std::unique_ptr CreateNewModule(); + static std::unique_ptr CreateNewModule( + const string& name = TestName()); // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. static DebugOptions GetDebugOptionsForTest(); + // Gets an HloModuleConfig with options appropriate for tests. + static HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return config; + } + // Executes the given module and return the result as a Literal. StatusOr> Execute( std::unique_ptr module, @@ -176,9 +184,13 @@ class HloTestBase : public ::testing::Test { // 'layout'. void ForceParameterLayout(HloModule* module, int64 param_no, const Layout& layout) { - ASSERT_LT(param_no, - module->mutable_entry_computation_layout()->parameter_count()); - module->mutable_entry_computation_layout() + ASSERT_LT( + param_no, + module->mutable_host_entry_computation_layout()->parameter_count()); + module->mutable_host_entry_computation_layout() + ->mutable_parameter_layout(param_no) + ->ResetLayout(layout); + module->mutable_device_entry_computation_layout() ->mutable_parameter_layout(param_no) ->ResetLayout(layout); } @@ -186,7 +198,10 @@ class HloTestBase : public ::testing::Test { // Convenience method to force the layout of the computation result in a // module. The result layout of 'module' is set to 'layout'. 
void ForceResultLayout(HloModule* module, const Layout& layout) { - module->mutable_entry_computation_layout() + module->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout); + module->mutable_device_entry_computation_layout() ->mutable_result_layout() ->ResetLayout(layout); } @@ -194,7 +209,10 @@ class HloTestBase : public ::testing::Test { // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { - module->mutable_entry_computation_layout() + module->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->Clear(); + module->mutable_device_entry_computation_layout() ->mutable_result_layout() ->Clear(); } diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index da4cf4ae0c31bc..c8a05c2e9e971d 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -67,7 +67,7 @@ HloModule& HloVerifiedTestBase::module() { void HloVerifiedTestBase::ParseAndVerifyModule( tensorflow::StringPiece hlo_text) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); VerifyModule(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index c28f79ae386670..cde1dcd9cd10c8 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,978 +15,93 @@ limitations under the License. 
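// Illustrative sketch only: the literal_test_util.cc hunk below removes the
// hand-rolled shape and element comparison machinery and re-implements
// Equal/Near/NearOrEqual as thin wrappers over the literal_comparison
// library, translating its Status into ::testing::AssertionResult. A small
// sketch of how a test consumes the wrapped API (the literal values are
// placeholders, not part of the patch):
//
//   auto expected = Literal::CreateR1<float>({1.0f, 2.0f});
//   auto actual = Literal::CreateR1<float>({1.0f, 2.0001f});
//   EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *expected));
//   EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(0.001)));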
#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include -#include -#include - -#include "tensorflow/compiler/xla/index_util.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - -/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( - const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return ::testing::AssertionFailure() - << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) - << " got: " << ShapeUtil::HumanString(actual); - } - if (ShapeUtil::IsTuple(expected)) { - if (ShapeUtil::TupleElementCount(expected) != - ShapeUtil::TupleElementCount(actual)) { - return ::testing::AssertionFailure() - << "want tuple element count: " - << ShapeUtil::TupleElementCount(expected) - << " got tuple element count: " - << ShapeUtil::TupleElementCount(actual); - } - for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) - << "mismatch in tuple index " << i; - if (!result) { - return result; - } - } - } else { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { - return ::testing::AssertionFailure() - << "want rank of: " << ShapeUtil::HumanString(expected) - << " got rank of: " << ShapeUtil::HumanString(actual); - } - if (expected.element_type() != actual.element_type()) { - return ::testing::AssertionFailure() - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - } - if (expected.dimensions_size() != actual.dimensions_size()) { - return ::testing::AssertionFailure() - << "want dimensions_size " << expected.dimensions_size() - << " got dimensions_size " << actual.dimensions_size(); - } - for (int i = 0; i < expected.dimensions_size(); ++i) { - if (expected.dimensions(i) != actual.dimensions(i)) { - return ::testing::AssertionFailure() - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } - } - } - return ::testing::AssertionSuccess(); -} - -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_TRUE(EqualShapes(expected, actual)); -} - -/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( - const Shape& expected, const Shape& actual) { - ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); -} - -namespace { - -// Return a literal with all arrays of type FromNativeT converted to type -// ToNativeT in the given literal. 
-template -std::unique_ptr ConvertType(const Literal& literal) { - // First construct shape of the result. - Shape result_shape(literal.shape()); - ShapeUtil::ForEachMutableSubshape( - &result_shape, [](Shape* subshape, const ShapeIndex&) { - if (subshape->element_type() == - primitive_util::NativeToPrimitiveType()) { - subshape->set_element_type( - primitive_util::NativeToPrimitiveType()); - } - }); - auto result = MakeUnique(result_shape); - - // Then copy over the data from 'literal' converting FromNativeT values to - // ToNativeT values as necessary. - ShapeUtil::ForEachSubshape( - literal.shape(), - [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { - if (subshape.element_type() == - primitive_util::NativeToPrimitiveType()) { - auto src = literal.data(shape_index); - auto dest = result->data(shape_index); - for (int64 i = 0; i < src.size(); ++i) { - dest[i] = static_cast(src[i]); - } - } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); - } - } - }); - return result; -} - -} // namespace - -/* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( - const Literal& literal) { - return ConvertType(literal); -} - -/* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( - const Literal& literal) { - return ConvertType(literal); -} - namespace { -string Hostname() { - char hostname[1024]; - gethostname(hostname, sizeof hostname); - hostname[sizeof hostname - 1] = 0; - return string(hostname); -} - -// Helper function for comparing a floating point type, FloatT, bitwise equal -// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT -// -- on miscompare, a nice error message is given in the AssertionFailure. -template -::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { - auto ulhs = tensorflow::bit_cast(lhs); - auto urhs = tensorflow::bit_cast(rhs); - auto lhs_double = static_cast(lhs); - auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { - return ::testing::AssertionFailure() << Printf( - "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a", - StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, - lhs_double, StrCat(tensorflow::strings::Hex(urhs)).c_str(), - rhs_double, rhs_double); - } - return ::testing::AssertionSuccess(); -} - -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). -template -::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) { - if (lhs == rhs) { +// Writes the given literal to a file in the test temporary directory. +void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { + auto get_hostname = [] { + char hostname[1024]; + gethostname(hostname, sizeof hostname); + hostname[sizeof hostname - 1] = 0; + return string(hostname); + }; + int64 now_usec = tensorflow::Env::Default()->NowMicros(); + string filename = tensorflow::io::JoinPath( + tensorflow::testing::TmpDir(), + tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), + now_usec, name.c_str())); + TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, + literal.ToProto())); + LOG(ERROR) << "wrote to " << name << " file: " << filename; +} + +// Callback helper that dumps literals to temporary files in the event of a +// miscomparison. 
+void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, + const LiteralSlice& mismatches) { + LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " " + << literal_comparison::ToStringTruncated(expected); + LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " " + << literal_comparison::ToStringTruncated(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(mismatches, "mismatches"); +} + +::testing::AssertionResult StatusToAssertion(const Status& s) { + if (s.ok()) { return ::testing::AssertionSuccess(); } - ::testing::Message msg; - msg << "Expected equality of these values:"; - msg << "\n " << lhs; - msg << "\n " << rhs; - - return ::testing::AssertionFailure() << msg; -} - -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. -template <> -::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(Eigen::half lhs, - Eigen::half rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(float lhs, float rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(double lhs, double rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(complex64 lhs, - complex64 rhs) { - auto res = CompareEqual(lhs.real(), rhs.real()); - if (!res) { - return res; - } - return CompareEqual(lhs.imag(), rhs.imag()); -} - -// A recursive function which iterates through every index of expected and -// actual literal and compares their values elementwise. Returns true if all -// elements are equal. -template -bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { - if (dimension == expected.shape().dimensions_size()) { - NativeT expected_value = expected.Get(multi_index); - NativeT actual_value = actual.Get(multi_index); - ::testing::AssertionResult result = - CompareEqual(expected_value, actual_value); - return result; // Defines implicit coersion to bool. - } - - bool all_match = true; - for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { - multi_index[dimension] = i; - all_match = all_match && ExpectLiteralsEqual( - expected, actual, multi_index, dimension + 1); - } - return all_match; + return ::testing::AssertionFailure() << s.error_message(); } } // namespace -/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, - const Literal& actual, - const string& message) { - EXPECT_TRUE(Equal(expected, actual)) - << "expected:\n" - << expected.ToString() << "\n\tvs actual:\n" - << actual.ToString() - << (message.empty() ? 
"" : StrCat("\nmessage: ", message)); -} - -/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, - const Literal& actual) { - EXPECT_FALSE(Equal(expected, actual)); -} - -/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( - const Literal& expected, const Literal& actual) { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, expected.ToString()); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, actual.ToString()); - - AssertEqualShapes(expected.shape(), actual.shape()); - std::vector multi_index(expected.shape().dimensions_size(), 0); - bool match = false; - switch (expected.shape().element_type()) { - case PRED: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U8: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case BF16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case C64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case TUPLE: { - bool tuple_match = true; - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(StrCat("Tuple index ", i, " in ", - ShapeUtil::HumanString(expected.shape()))); - - // Create LiteralViews of the expected and actual elements. - auto result = Equal(LiteralView::Create(expected, {i}), - LiteralView::Create(actual, {i})); - tuple_match = tuple_match ? !!result : false; - } - match = tuple_match; - break; - } - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - ::testing::AssertionResult result = ::testing::AssertionSuccess(); - if (!match) { - result = ::testing::AssertionFailure() - << "expected: " << expected.ToString() - << "\nactual: " << actual.ToString(); - VLOG(1) << result.message(); - } - return result; -} - -namespace { - -// Gets the total element count. For tuples, this is not the count of tuple -// elements, but the sum of elements of each tuple element. -int64 RecursiveElementCount(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); - int64 total = 0; - for (int64 i = 0; i < tuple_elements; ++i) { - total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); - } - return total; - } else { - return ShapeUtil::ElementsIn(shape); - } -} - -// Calling ToString on a literal with over 100 million elements takes around -// 3 minutes. The utility of printing a literal with >1000 elements is -// questionable, especially when writing the Literal proto to disk is orders -// of magnitude faster. -string TruncateHugeLiteral(const Literal& literal) { - return RecursiveElementCount(literal.shape()) < 1000 - ? 
literal.ToString() - : "[TRUNCATED, Literal with more than 1000 values]"; +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( + const Shape& expected, const Shape& actual) { + return StatusToAssertion(literal_comparison::EqualShapes(expected, actual)); } -// Returns whether the actual and expected values are mismatched with respect to -// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. -template -bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return !std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) { + if (expected.ShortDebugString() != actual.ShortDebugString()) { + return ::testing::AssertionFailure() + << "want: " << expected.ShortDebugString() + << " got: " << actual.ShortDebugString(); } + return ::testing::AssertionSuccess(); } -template <> -bool NanMismatch(complex64 expected, complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); -} - -template <> -bool NanMismatch(half expected, half actual, bool relaxed_nans) { - return NanMismatch(static_cast(expected), - static_cast(actual), relaxed_nans); -} - -// Converts the given floating-point value to a string. -template -string FpValueToString(NativeT value) { - return Printf("%8.4g", static_cast(value)); -} - -template <> -string FpValueToString(complex64 value) { - return Printf("%8.4g + %8.4fi", value.real(), value.imag()); -} - -// Returns the absolute value of the given floating point value. This function -// is used instead of std::abs directly in order to allow type-dependent -// implementations for NearComparator. -template -float FpAbsoluteValue(NativeT value) { - return std::abs(value); -} - -template <> -float FpAbsoluteValue(bfloat16 value) { - return FpAbsoluteValue(static_cast(value)); -} - -template <> -float FpAbsoluteValue(half value) { - return FpAbsoluteValue(static_cast(value)); -} - -// Helper class for comparing floating-point literals within an error bound. -template -class NearComparator { - public: - // Compares the two array literals elementwise and returns an assertion - // result. The assertion result is successful if all actual and expected - // elements are within the given error bound. In case of error, the assertion - // result contains a detailed error message in case of failure. - static ::testing::AssertionResult Compare(const Literal& expected, - const Literal& actual, - ErrorSpec error, - bool detailed_message) { - NearComparator comparator(expected, actual, error, - detailed_message); - return comparator.Run(); - } - - private: - // Data structure encapsulating metadata about a single element mismatch. - struct Mismatch { - NativeT actual; - NativeT expected; - float rel_error; - float abs_error; - - // The linear index of the failure within the shape. This linear index is - // from the 'actual' literal. 
- int64 linear_index; - - bool operator<(const Mismatch& other) const { - return rel_error < other.rel_error; - } - - string ToString(const Shape& shape) const { - return Printf( - "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", - FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), - LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), - rel_error, abs_error); - } - }; - - explicit NearComparator(const Literal& expected, const Literal& actual, - ErrorSpec error, bool detailed_message) - : expected_(expected), - actual_(actual), - error_(error), - detailed_message_(detailed_message), - abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), - abs_error_buckets_(kErrorBucketBounds.size(), 0), - rel_error_buckets_(kErrorBucketBounds.size(), 0) {} - - // Runs the comparison between expected and actual literals. - ::testing::AssertionResult Run() { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(expected_)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(actual_)); - - // If the shapes mismatch, we simply fail the expectation instead of - // printing out data, as it's a type error rather than a value error. - ::testing::AssertionResult equal_shapes = - LiteralTestUtil::EqualShapes(expected_.shape(), actual_.shape()); - if (!equal_shapes) { - return equal_shapes; - } - if (!ShapeUtil::IsArray(expected_.shape())) { - return ::testing::AssertionFailure() << "Expected array shape"; - } - - mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); - mismatches_.PopulateWithValue(false); - - CompareLiterals(); - - if (num_mismatches_ == 0) { - return ::testing::AssertionSuccess(); - } else if (!VLOG_IS_ON(1)) { - LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected_.shape()) - << " " << TruncateHugeLiteral(expected_); - LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual_.shape()) - << " " << TruncateHugeLiteral(actual_); - LOG(INFO) << "Dumping literals to temp files..."; - WriteLiteralToTempFile(expected_, "expected"); - WriteLiteralToTempFile(actual_, "actual"); - WriteLiteralToTempFile(mismatches_, "mismatches"); - } - return ::testing::AssertionFailure() << ErrorMessage(); - } - - // Insert the given absolute value into the absolute value bucket vector. The - // bounds of the buckets are given by kAbsValueBucketBounds. - void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { - // Adjust the bucket containing the absolute values of the 'actual' - // elements. - const float abs_value = FpAbsoluteValue(value); - for (int i = 0; i < abs_value_buckets_.size(); ++i) { - if (i == abs_value_buckets_.size() - 1 || - (abs_value >= kAbsValueBucketBounds[i] && - abs_value < kAbsValueBucketBounds[i + 1])) { - // The first value of the pair is the count of elements in the bucket, - // the second is the count of mismatches in the bucket. - abs_value_buckets_[i].first++; - if (is_mismatch) { - abs_value_buckets_[i].second++; - } - return; - } - } - } - - // Insert the given error into the given error bucket vector. 
- void UpdateErrorBucket( - float error, tensorflow::gtl::MutableArraySlice error_buckets) { - CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); - for (int i = 0; i < error_buckets.size(); ++i) { - if (error >= kErrorBucketBounds[i]) { - error_buckets[i]++; - } - } - } - - // Compares the two given elements from the expected and actual literals at - // the given literal_index and keeps track of various mismatch statistics. - void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { - const bool is_nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); - float abs_error; - float rel_error; - if (actual == expected) { - abs_error = 0; - rel_error = 0; - } else if (is_nan_mismatch) { - num_nan_mismatches_++; - // A nan mismatch is considered to have infinite error. rel_error is used - // for sorting a std::set of the top mismatchs, and a nan value here will - // result in undefined behavior because nan's do not satisfy the strict - // weak ordering requirement of std containers. - abs_error = std::numeric_limits::infinity(); - rel_error = std::numeric_limits::infinity(); - } else { - abs_error = FpAbsoluteValue(actual - expected); - rel_error = abs_error / FpAbsoluteValue(expected); - } - const bool is_abs_mismatch = abs_error > error_.abs; - const bool is_rel_mismatch = rel_error > error_.rel; - const bool is_mismatch = - is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); - - // Update the error of the relative bucket only if the *absolute* error - // bound is exceeded and vice versa. - if (is_abs_mismatch) { - num_abs_mismatches_++; - UpdateErrorBucket(rel_error, &rel_error_buckets_); - } - if (is_rel_mismatch) { - num_rel_mismatches_++; - UpdateErrorBucket(abs_error, &abs_error_buckets_); - } - - UpdateAbsValueBucket(actual, is_mismatch); - - if (!is_mismatch) { - return; - } - - num_mismatches_++; - - // Keep track of the kTopRelativeErrorCount relative error mismatches. - if (top_rel_mismatches_.size() < kTopRelativeErrorCount || - rel_error > top_rel_mismatches_.begin()->rel_error) { - Mismatch mismatch = {actual, expected, rel_error, abs_error, - linear_index}; - top_rel_mismatches_.insert(mismatch); - if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { - top_rel_mismatches_.erase(top_rel_mismatches_.begin()); - } - } - - mismatches_.data()[linear_index] = true; - } - - // Compares the two literals elementwise. - void CompareLiterals() { - // Fast path optimization for the case were layouts match. - if (LayoutUtil::Equal(actual_.shape().layout(), - expected_.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected_.data(); - tensorflow::gtl::ArraySlice actual_data = - actual_.data(); - const int64 len = expected_data.size(); - for (int64 i = 0; i < len; ++i) { - CompareValues(expected_data[i], actual_data[i], i); - } - return; - } - std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); - CompareLiteralsSlow(0, &multi_index); - } - - // Slow path for CompareLiterals when 'actual' and 'expected' literals have - // different layouts. In this case, multidimensional indices are constructed - // and indexed for each element. 
- void CompareLiteralsSlow(int64 dimension, std::vector* multi_index) { - if (dimension == multi_index->size()) { - CompareValues(expected_.Get(*multi_index), - actual_.Get(*multi_index), - IndexUtil::MultidimensionalIndexToLinearIndex( - actual_.shape(), *multi_index)); - } else { - for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - CompareLiteralsSlow(dimension + 1, multi_index); - } - } - } - - // Writes the given literal to a file in the test temporary directory. - void WriteLiteralToTempFile(const Literal& literal, const string& name) { - int64 now_usec = tensorflow::Env::Default()->NowMicros(); - string filename = tensorflow::io::JoinPath( - tensorflow::testing::TmpDir(), - Printf("tempfile-%s-%llx-%s", Hostname().c_str(), now_usec, - name.c_str())); - TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), - filename, literal.ToProto())); - LOG(ERROR) << "wrote to " << name << " file: " << filename; - } - - // Returns an error message string with a detailed breakdown of the - // mismatches. Called after calling Run(). - string ErrorMessage() { - string out; - int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); - - auto percent_string = [](float a, float b) { - float pct = b == 0.0 ? 0.0 : 100.0 * a / b; - return Printf("%0.4f%%", pct); - }; - - Appendf(&out, - "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " - "%g, rel bound %g\n", - num_mismatches_, - percent_string(num_mismatches_, element_count).c_str(), - ShapeUtil::HumanString(actual_.shape()).c_str(), - ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); - if (num_nan_mismatches_ > 0) { - StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); - } - Appendf(&out, "Top relative error mismatches:\n"); - for (auto it = top_rel_mismatches_.rbegin(); - it != top_rel_mismatches_.rend(); ++it) { - StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); - } - - if (!detailed_message_) { - return out; - } - - StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); - CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); - for (int i = 0; i < abs_value_buckets_.size(); ++i) { - const int64 bucket_size = abs_value_buckets_[i].first; - const int64 bucket_mismatches = abs_value_buckets_[i].second; - string mismatch_str = bucket_mismatches > 0 - ? 
Printf(", mismatches %lld", bucket_mismatches) - : ""; - Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", - kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], - bucket_size, percent_string(bucket_size, element_count).c_str(), - mismatch_str.c_str()); - } - - auto print_accum_buckets = [&](const string& header, int64 total, - tensorflow::gtl::ArraySlice buckets) { - StrAppend(&out, header, ":\n"); - Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], - total - buckets[0], - percent_string(total - buckets[0], total).c_str()); - CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); - for (int i = 0; i < kErrorBucketBounds.size(); ++i) { - Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], - buckets[i], percent_string(buckets[i], total).c_str()); - } - }; - Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", - error_.abs, num_abs_mismatches_, - percent_string(num_abs_mismatches_, element_count).c_str()); - print_accum_buckets( - "Relative error breakdown of elements exceeding abs error bound", - num_abs_mismatches_, rel_error_buckets_); - Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", - error_.rel, num_rel_mismatches_, - percent_string(num_rel_mismatches_, element_count).c_str()); - print_accum_buckets( - "Absolute error breakdown of elements exceeding rel error bound", - num_rel_mismatches_, abs_error_buckets_); - return out; - } - - // 'actual' and 'expected' literals being compared. - const Literal& expected_; - const Literal& actual_; - - // The error bounds of the comparison. - ErrorSpec error_; - - // Whether to include detailed breakdown of mismatches in the error message. - bool detailed_message_; - - // Number of element element mismatches encountered so far. - int64 num_mismatches_ = 0; - - // Number of elements with a nan mismatch. - int64 num_nan_mismatches_ = 0; - - // Number of elements which exceed the absolute/relative error bound. - int64 num_abs_mismatches_ = 0; - int64 num_rel_mismatches_ = 0; - - // A Literal containing which elements did not match in the expected and - // actual literals. mismatches_ contains PREDs and is of the same sizes as - // the comparison literals. - Literal mismatches_; - - // The number of mismatches to report in the output, sorted by relative error - // magnitude. - static constexpr int64 kTopRelativeErrorCount = 5; - - // The set of mismatches with the largest relative error. The size of this set - // is bounded by kTopRelativeErrorCount. - std::multiset top_rel_mismatches_; - - // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the - // bounds of these buckets. abs_value_buckets_ contains a pair for each - // bucket: the element count and failure count. - static constexpr std::array kAbsValueBucketBounds = { - 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; - std::vector> abs_value_buckets_; - - // Buckets for relative and absolute errors. The relative error buckets only - // contains those elements which exceed the *absolute* error bound, and vice - // versa. This makes it easy to see the effect of adjusting the relative (or - // absolute) error bound on the success of the comparison. kErrorBucketBounds - // are the lower bounds of the buckets in both vectors. The error buckets are - // a cumulative distribution so an error value may appear in more than one - // bucket. For example an error value of 0.003 may appear in the buckets - // bounded by 0.01, 0.1, and 1.0. 
- static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, - 0.01, 0.1, 1}; - std::vector abs_error_buckets_; - std::vector rel_error_buckets_; -}; - -template -constexpr std::array NearComparator::kAbsValueBucketBounds; -template -constexpr std::array NearComparator::kErrorBucketBounds; - -// Helper function for comparing two literals for nearness. Handles tuple-shapes -// via recursion. shape_index is the ShapeIndex of expected (or actual) -// currently being compared. -::testing::AssertionResult NearHelper(const Literal& expected, - const Literal& actual, - const ErrorSpec& error, - bool detailed_message, - const ShapeIndex& shape_index) { - ::testing::AssertionResult err = - LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); - if (!err) { - return err; - } - - if (ShapeUtil::IsTuple(expected.shape())) { - for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); - ShapeIndex element_index = shape_index; - element_index.push_back(i); - ::testing::AssertionResult res = - NearHelper(expected_element, actual_element, error, detailed_message, - element_index); - if (!res) { - string err_message = - Printf("\nArray at shape index %s%s", - element_index.ToString().c_str(), res.message()); - if (err) { - err = ::testing::AssertionFailure() << err_message; - } else { - err << err_message; - } - } - } - if (!err && shape_index.empty()) { - // Emit a top-level error message containing the top-level shape in case - // of mismatch. - int64 total_elements = RecursiveElementCount(actual.shape()); - err = ::testing::AssertionFailure() - << Printf("\nMismatches in shape %s (%lld elements):\n%s", - ShapeUtil::HumanString(actual.shape()).c_str(), - total_elements, err.message()); - } - return err; - } - - if (ShapeUtil::ElementIsFloating(expected.shape()) || - ShapeUtil::ElementIsComplex(expected.shape())) { - switch (expected.shape().element_type()) { - case BF16: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F16: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F32: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case F64: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - case C64: - return NearComparator::Compare(expected, actual, error, - detailed_message); - break; - default: - LOG(FATAL) << "Unsupported primitive type in near comparator: " - << PrimitiveType_Name(expected.shape().element_type()) - << ". Must be floating-point type."; - } - } - - // Non-floating point literal. 
- return LiteralTestUtil::Equal(expected, actual); +/* static */ ::testing::AssertionResult LiteralTestUtil::Equal( + const LiteralSlice& expected, const LiteralSlice& actual) { + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } -} // namespace - /* static */ ::testing::AssertionResult LiteralTestUtil::Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, - bool detailed_message) { - return NearHelper(expected, actual, error, detailed_message, - /*shape_index=*/{}); -} - -/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, - const Literal& actual, - const ErrorSpec& error, - const string& message) { - ::testing::AssertionResult res = - Near(expected, actual, error, /*detailed_message=*/false); - if (!res) { - res << "Expected: " << TruncateHugeLiteral(expected) << "\n"; - res << "Actual: " << TruncateHugeLiteral(actual) << "\n"; - if (!message.empty()) { - res << StrCat("\nmessage: ", message); - } - } - EXPECT_TRUE(res); + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, bool detailed_message) { + return StatusToAssertion(literal_comparison::Near( + expected, actual, error_spec, detailed_message, &OnMiscompare)); } -/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( - const Literal& expected, const Literal& actual, +/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) { if (error.has_value()) { VLOG(1) << "Expects near"; - return Near(expected, actual, *error); + return StatusToAssertion(literal_comparison::Near( + expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); } VLOG(1) << "Expects equal"; - return Equal(expected, actual); -} - -/*static*/ void LiteralTestUtil::ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error) { - EXPECT_TRUE(NearOrEqual(expected, actual, error)); -} - -/* static */ string LiteralTestUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { - return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); -} - -/* static */ std::unique_ptr LiteralTestUtil::Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal) { - int64 new_num_elements = 1; - for (int64 i = 0; i < new_dimensions.size(); ++i) { - new_num_elements *= new_dimensions[i]; - } - CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); - CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - - auto new_literal = MakeUnique( - ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); - - // Create a new shape with the given minor-to-major layout. This shape is used - // solely for converting linear address to multi-dimensional addresses when - // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); - *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); - - // Copy data into new literal, element-by-element. 
- for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { - std::vector from_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - std::vector to_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); - switch (literal.shape().element_type()) { - case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - default: - LOG(FATAL) << "Unhandled primitive element type: " - << PrimitiveType_Name(literal.shape().element_type()); - } - } - - return new_literal; + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index a755568c0f098e..d1b8a6cf0b2552 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -38,282 +39,190 @@ limitations under the License. namespace xla { -// Structure describing permissible absolute and relative error bounds. -struct ErrorSpec { - explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) - : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} - - float abs; // Absolute error bound. - float rel; // Relative error bound. - - // If relaxed_nans is true then any result is valid if we are expecting NaNs. - // In effect, this allows the tested operation to produce incorrect results - // for inputs outside its mathematical domain. - bool relaxed_nans; -}; - // Utility class for making expectations/assertions related to XLA literals. class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. - static ::testing::AssertionResult EqualShapes(const Shape& expected, - const Shape& actual); - static void AssertEqualShapes(const Shape& expected, const Shape& actual); + static ::testing::AssertionResult EqualShapes( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; // Asserts that the provided shapes are equal as defined in AssertEqualShapes // and that they have the same layout. - static void AssertEqualShapesAndLayouts(const Shape& expected, - const Shape& actual); - - // If the given literal's data type is bfloat16, converts it to a float - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. 
- static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); - - // If the given literal's data type is float, converts it to a bfloat16 - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); - - // Asserts that the expected and actual literals are (bitwise) equal for all - // elements in the literal. Also, asserts that the rank, dimensions sizes, and - // primitive type are equal. - static ::testing::AssertionResult Equal( - const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + static ::testing::AssertionResult EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; - // Expects that expected and actual are Equal. - static void ExpectEqual(const Literal& expected, const Literal& actual, - const string& message = ""); - - // Expects that expected and actual are Not Equal. - static void ExpectNotEqual(const Literal& expected, const Literal& actual); + static ::testing::AssertionResult Equal(const LiteralSlice& expected, + const LiteralSlice& actual) + TF_MUST_USE_RESULT; // Asserts the given literal are (bitwise) equal to given expected values. template - static void ExpectR0Equal(NativeT expected, const Literal& actual); + static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); + template static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR2Equal( std::initializer_list> expected, - const Literal& actual); + const LiteralSlice& actual); + template static void ExpectR3Equal( std::initializer_list< std::initializer_list>> expected, - const Literal& actual); + const LiteralSlice& actual); // Asserts the given literal are (bitwise) equal to given array. template static void ExpectR2EqualArray2D(const Array2D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR3EqualArray3D(const Array3D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR4EqualArray4D(const Array4D& expected, - const Literal& actual); + const LiteralSlice& actual); - // Asserts that the expected and actual literals are within the given error - // bound for all elements. Also, asserts that the rank, dimensions sizes, and - // bounds are equivalent. + // Decorates literal_comparison::Near() with an AssertionResult return type. // - // Tuples are matched recursively. When comparing tensors of - // non-floating-point type, checks for exact equality, ignoring the ErrorSpec. - // - // If the shape of the literals is neither a complex/floating-point tensor nor - // a tuple which contains a complex/floating-point tensor, Near() is - // equivalent to Equal(). We don't raise an error in this case, because we - // want to allow callers to call Near() even if they have no preconceptions - // about the shapes being compared. - // - // If detailed_message is true, then the error message in the assertion result - // will contain a more detailed breakdown of mismatches. + // See comment on literal_comparison::Near(). static ::testing::AssertionResult Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error, + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, bool detailed_message = false) TF_MUST_USE_RESULT; - // Expects expected and actual to be Near with the given error. 
- static void ExpectNear(const Literal& expected, const Literal& actual, - const ErrorSpec& error, const string& message = ""); - // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. template - static void ExpectR0Near(NativeT expected, const Literal& actual, + static void ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3Near( std::initializer_list< std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4Near( std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); // Asserts the given literal are within the given error bound to the given // array. Only supported for floating point values. template static void ExpectR2NearArray2D(const Array2D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3NearArray3D(const Array3D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4NearArray4D(const Array4D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); // If the error spec is given, returns whether the expected and the actual are // within the error bound; otherwise, returns whether they are equal. Tuples // will be compared recursively. static ::testing::AssertionResult NearOrEqual( - const Literal& expected, const Literal& actual, + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; - // If the error spec is given, expects the expected and the actual to be near; - // otherwise, expects them to be equal. Tuples will be compared recursively. - static void ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error); - - // Returns a multi-dimensional index as a string. For example: '{7, 8}' will - // be returned for a 2-dimensional index with dimension 0 index equal to 7, - // dimension 1 equal to 8. - static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); - - // Creates a literal with a new shape with the given new dimensions using the - // data in the given input literal. For reshaping purposes the (flat) data - // buffer of the input literal is assumed to have the given minor_to_major - // layout order. - static std::unique_ptr Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const Literal& literal); - - // Creates a literal with the supplied shape, and uses the provided value - // generator to populate the literal's values. - // Returns the new literal object, or an error Status if failed. 
- template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation, and using the engine as entropy generator. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, typename E, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); - private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); }; template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR0(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR1(expected), actual); + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR2(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3Equal( std::initializer_list>> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR3(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( - const Array2D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); + const Array2D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( - const Array3D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); + const Array3D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( - const Array4D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); + const Array4D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR0(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( - 
tensorflow::gtl::ArraySlice expected, const Literal& actual, + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR1(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR2(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3Near( std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR3(expected), actual, error)); } template @@ -321,63 +230,29 @@ template std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( - const Array2D& expected, const Literal& actual, + const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( - const Array3D& expected, const Literal& actual, + const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( - const Array4D& expected, const Literal& actual, + const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); - return std::move(literal); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - std::normal_distribution generator(mean, stddev); - return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { - std::minstd_rand0 engine; - return CreateRandomLiteral(shape, &engine, mean, stddev); + EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 9d619a77c7e8d6..bbac7285aefbb1 100644 --- 
a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { std::unique_ptr literal = Literal::MakeTuple({ Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); - LiteralTestUtil::ExpectEqual(*literal, *literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } } +TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { + auto expected = Literal::CreateR1({1, 2, 3}); + auto actual = Literal::CreateR1({4, 5, 6}); + ::testing::AssertionResult result = + LiteralTestUtil::Equal(*expected, *actual); + EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); + EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); +} + TEST(LiteralTestUtilTest, NearComparatorR1) { auto a = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 3023df47cda33f..2c45f19c090d26 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -62,8 +62,8 @@ void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr( std::unique_ptr hlo_module, const AotCompilationOptions& options, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - ASSERT_TRUE( - CompileToAotCompilationResult(std::move(hlo_module), options).ok()); + TF_ASSERT_OK( + CompileToAotCompilationResult(std::move(hlo_module), options).status()); ResetIrHook(); StatusOr filecheck_result = RunFileCheck(ir_, pattern); diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 3704ddd8010bf7..a366afe8262e1f 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,7 +21,8 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -29,27 +30,31 @@ limitations under the License. 
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" +namespace { + using xla::string; -xla::Computation Doubler(xla::Client* client) { - xla::ComputationBuilder builder(client, "doubler"); +xla::XlaComputation Doubler() { + xla::XlaBuilder builder("doubler"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto x = builder.Parameter(0, r0f32, "x"); builder.Mul(x, builder.ConstantR0(2.0)); return std::move(builder.Build().ValueOrDie()); } +} // namespace + int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie(); - xla::ComputationBuilder builder(client, "aot_test_helper"); + xla::XlaBuilder builder("aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); auto opaque_param = builder.Parameter(0, opaque_shape, "x"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); - builder.Call(Doubler(client), {sum}); + builder.Call(Doubler(), {sum}); if (argc != 2) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; @@ -71,8 +76,8 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); - xla::Computation computation = builder.Build().ConsumeValueOrDie(); - xla::CompileOnlyClient::AotComputationInstance instance{ + xla::XlaComputation computation = builder.Build().ConsumeValueOrDie(); + xla::CompileOnlyClient::AotXlaComputationInstance instance{ &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; xla::cpu::CpuAotCompilationOptions options( diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 44c6811df84f49..96858c00d6bbe5 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -210,12 +210,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {1})); + LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -239,16 +239,16 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 0})); + LiteralSlice(*result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {0, 1})); + LiteralSlice(*result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 2})); + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -274,9 +274,9 @@ 
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -321,9 +321,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralView::Create(*result_literal, {0})); + LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1})); + {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -361,9 +361,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0})); + {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1})); + {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -391,16 +391,16 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal( {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralView::Create(*result_0_literal, {0})); + LiteralSlice(*result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1})); + {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0})); + {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1})); + {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -447,7 +447,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}), + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } @@ -502,7 +502,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}), + i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), error_spec_); } } @@ -548,7 +548,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralView::Create(*result_literal, index)); + 165.0, LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -754,9 
+754,9 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0})); + {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1})); + {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index ca8e4cdbdb6a8f..88797a7d0a7d05 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -35,9 +35,9 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -StatusOr TestAllocator::Allocate(int device_ordinal, - uint64 size, - bool retry_on_failure) { +StatusOr TestAllocator::Allocate(int device_ordinal, + uint64 size, + bool retry_on_failure) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; { tensorflow::mutex_lock lock(count_mutex_); @@ -48,8 +48,7 @@ StatusOr TestAllocator::Allocate(int device_ordinal, retry_on_failure); } -tensorflow::Status TestAllocator::Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) { +Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { VLOG(2) << "Deallocate(" << device_ordinal << ")"; { tensorflow::mutex_lock lock(count_mutex_); @@ -149,8 +148,6 @@ ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions() ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ExecutableRunOptions run_options; - run_options.set_inter_op_thread_pool( - local_client_->backend().inter_op_thread_pool()); run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get()); run_options.set_allocator(GetOrCreateAllocator(local_client_->platform())); return run_options; diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 3bbb760c806412..258226523d830b 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -46,10 +46,9 @@ class TestAllocator : public StreamExecutorMemoryAllocator { platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) { } - StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - tensorflow::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase* mem) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. int64 allocation_count() const; diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index 174d433a9e1731..c0c02e584c2348 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -29,7 +29,7 @@ namespace { class LogTest : public ClientLibraryTestBase {}; XLA_TEST_F(LogTest, LogZeroValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR3FromArray3D(Array3D(3, 0, 0)); builder.Log(x); @@ -41,7 +41,7 @@ TEST_F(LogTest, LogTenValues) { std::vector input = {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1(input); builder.Log(x); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 7fa61eb33c2930..27fd36e06acdc5 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -52,12 +51,7 @@ class MatOpsSimpleTest : public ClientLibraryTestBase {}; template class MatOpsSimpleTest_F16F32 : public MatOpsSimpleTest {}; -// TODO(bixia): This test for F16 failed on GPU 02-25-2018. 
-#ifdef XLA_TEST_BACKEND_GPU -TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, ::testing::Types); -#else TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); -#endif XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { using T = TypeParam; @@ -72,8 +66,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 {0.36788f, 1.64872f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { @@ -101,8 +94,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { std::unique_ptr expected = Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 {-0.5f, 1.0f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { @@ -121,8 +113,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { std::unique_ptr expected = Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 {3.0f, -4.0f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-6)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } struct TestLinspaceMaxParam { @@ -171,11 +162,8 @@ string PrintTestLinspaceMaxParam( } #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 -// TODO(bixia): This test failed on GPU 02-25-2018 -#ifdef XLA_TEST_BACKEND_CPU XLA_TEST_P(TestLinspaceMaxParametric, TestF16) { TestImpl(); } #endif -#endif XLA_TEST_P(TestLinspaceMaxParametric, TestF32) { TestImpl(); } INSTANTIATE_TEST_CASE_P( diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 0a603f4954badd..7bfc8eb546d10d 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -108,7 +107,7 @@ class MultiOutputFusionTest : public HloTestBase { expect.PopulateWithValue(size * 1.5f * 3.5f); auto actual = ExecuteAndTransfer( std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); - LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { @@ -168,7 +167,7 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); } }; @@ -211,5 +210,175 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); } +XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { + const char* testcase = R"( + HloModule m + + fused_computation { + p = f32[4] parameter(0) + multiply = f32[4] multiply(p, p) + less-than = pred[4] less-than(p, multiply) + ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply) + } + + ENTRY PredFloatMOF { + p0 = f32[4] parameter(0) + fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[4] get-tuple-element(fusion), index=0 + gte1 = f32[4] get-tuple-element(fusion), index=1 + const = f32[4] constant({0, 0, 0, 0}) + ROOT select = f32[4] select(gte0, gte1, const) + })"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::CreateR1({0.0, 4.0, 9.0, 1.0}))); +} + +XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { + const char* testcase = R"( + HloModule m + + fused_computation { + p = f32[] parameter(0) + multiply = f32[] multiply(p, p) + less-than = pred[] less-than(p, multiply) + ROOT tuple = (pred[], f32[]) tuple(less-than, multiply) + } + + map_computation { + p0 = f32[] parameter(0) + fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation + gte0 = pred[] get-tuple-element(fusion), index=0 + gte1 = f32[] get-tuple-element(fusion), index=1 + const = f32[] constant(0) + ROOT select = f32[] select(gte0, gte1, const) + } + + ENTRY MapMOF { + p1 = f32[3] parameter(0) + ROOT map = f32[3] map(p1), to_apply=map_computation + })"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR1({1.0, 2.0, 3.0}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); +} + +const char* const kScalarOps = R"( + HloModule m + + Add { + lhsadd = f32[] parameter(0) + rhsadd = f32[] parameter(1) + ROOT add = f32[] add(lhsadd, rhsadd) + } + + Max { + lhsmax = f32[] parameter(0) + rhsmax = f32[] parameter(1) + ROOT max = f32[] maximum(lhsmax, rhsmax) + } +)"; + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { + const string testcase = 
tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR2({{25, 36}, {49, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(1.17549e-38) + r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max + r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add + ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), + Literal::CreateR1({36, 64}), + Literal::CreateR1({66, 138})))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 97dab860c06bdd..838f1b4e2f0f0e 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -19,7 +19,6 @@ limitations under 
the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -161,7 +160,7 @@ XLA_TEST_F(ParamsTest, MissingParameter) { auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); auto computation_status = builder.Build(); - ASSERT_NE(computation_status.status(), tensorflow::Status::OK()); + ASSERT_NE(computation_status.status(), Status::OK()); } XLA_TEST_F(ParamsTest, UnusedParameter) { diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 29a4f75001c688..1a2de6937c3e13 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -273,11 +273,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - LiteralTestUtil::ExpectEqual(*result1, *result2); - LiteralTestUtil::ExpectEqual(*result1, *result3); - LiteralTestUtil::ExpectNotEqual(*result1, *result4); - LiteralTestUtil::ExpectNotEqual(*result4, *result5); - LiteralTestUtil::ExpectNotEqual(*result5, *result6); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index c0a2c0ca4cb841..9052b188ed09a7 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -73,7 +73,7 @@ ENTRY reduce.1 { } )"; - return tools::Parse(hlo_string); + return ParseHloString(hlo_string); } // TODO(b/72454718): XLA:GPU does not support executing code compiled without diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index bcc05c2d41d843..d671d40456a276 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -34,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 10a3da3a387641..266760e8202fdd 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -356,12 +356,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 1.0f; - }; - TF_EXPECT_OK(arg_literal->Populate(generator)); - + auto arg_literal = MakeUnique(shape); + arg_literal->PopulateWithValue(1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); Padding padding = Padding::kValid; @@ -371,13 +367,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); - auto out_generator = - [&](tensorflow::gtl::ArraySlice indexes) -> float { - return 27.0f; - }; - TF_EXPECT_OK(expected->Populate(out_generator)); - + auto expected = MakeUnique(result_shape); + expected->PopulateWithValue(27.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -1348,7 +1339,7 @@ INSTANTIATE_TEST_CASE_P( class ReduceWindowTextTest : public HloTestBase {}; TEST_F(ReduceWindowTextTest, R2General256x384) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1365,7 +1356,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1382,7 +1373,7 @@ ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window= } TEST_F(ReduceWindowTextTest, R2General2x5) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1399,7 +1390,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R2Window mul { lhs = f32[] parameter(0) @@ -1417,7 +1408,7 @@ ENTRY R2Window { } TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule R3Window mul { lhs = f32[] parameter(0) @@ -1435,7 +1426,7 @@ ENTRY R3Window { } TEST_F(HloTestBase, ReduceWindowIdentity) { - const string& hlo_string = R"( + const string hlo_string = R"( HloModule ReduceWindowIdentity identity.pad_to_reduce_window { param0 = f32[] parameter(0) @@ -1444,7 +1435,26 @@ identity.pad_to_reduce_window { ENTRY reduce-window-identity { operand = f32[1,32,64]{2,1,0} parameter(0) constant.4466 = f32[] constant(0) - ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window + ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 
pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + +TEST_F(HloTestBase, ReduceWindowS32) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] { + %param0 = s32[] parameter(0) + ROOT %param1 = s32[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { + %parameter.0 = s32[81,8]{1,0} parameter(0) + %parameter.1 = s32[] parameter(1) + ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window } )"; diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 5ebd5268992846..da1b588ec41cef 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d7462d581b8596..a4580cd71d46ad 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -656,9 +656,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + expected = Literal::ConvertF32ToBF16(*expected); } - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { @@ -731,7 +731,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = - LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); + Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -753,7 +753,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = - LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); + Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -817,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. 
if (use_bfloat16()) { - auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + auto expected = Literal::ConvertF32ToBF16(*input_literal); EXPECT_EQ(expected->data(), output_literal->data()); } else { EXPECT_EQ(input_literal->data(), output_literal->data()); @@ -886,7 +886,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -915,7 +915,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -944,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -974,7 +974,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -1003,7 +1003,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) + Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 8cbfcc6f5c4272..7cfca781acda15 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { @@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index 32db45f8a66266..f334a8c1318a59 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ 
b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase { client_->TransferToServer(original).ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data).ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(original, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); } }; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 0c88bef69dfc52..308d3fc78a51e6 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -17,9 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -43,83 +44,80 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. template - void TestCompare(NativeT lhs, NativeT rhs, bool expected, - ComputationDataHandle (ComputationBuilder::*op)( - const ComputationDataHandle&, - const ComputationDataHandle&, - tensorflow::gtl::ArraySlice)) { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle lhs_op = builder.ConstantR0(lhs); - ComputationDataHandle rhs_op = builder.ConstantR0(rhs); - ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + void TestCompare( + NativeT lhs, NativeT rhs, bool expected, + XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice)) { + XlaBuilder builder(TestName()); + XlaOp lhs_op = builder.ConstantR0(lhs); + XlaOp rhs_op = builder.ConstantR0(rhs); + XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - ComputationDataHandle (ComputationBuilder::*op)( - const ComputationDataHandle&, - const ComputationDataHandle&, - tensorflow::gtl::ArraySlice)) { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle lhs_op = builder.ConstantR0(lhs); - ComputationDataHandle rhs_op = builder.ConstantR0(rhs); - ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice)) { + XlaBuilder builder(TestName()); + XlaOp lhs_op = builder.ConstantR0(lhs); + XlaOp rhs_op = builder.ConstantR0(rhs); + XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(2.1f); ComputeAndCompareR0(&builder, 2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(2.1f)); ComputeAndCompareR0(&builder, -2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); 
builder.Neg(builder.ConstantR0(2)); ComputeAndCompareR0(&builder, -2, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); ComputeAndCompareR0(&builder, 7.6f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(2), builder.ConstantR0(5)); ComputeAndCompareR0(&builder, 7, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const uint64 a = static_cast(1) << 63; const uint64 b = a + 1; builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); @@ -128,7 +126,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const int64 a = static_cast(1) << 62; const int64 b = a - 1; builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); @@ -137,7 +135,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(0.25), builder.ConstantR0(3.5)); @@ -145,21 +143,21 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Sub(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); ComputeAndCompareR0(&builder, -3.4f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Sub(builder.ConstantR0(2), builder.ConstantR0(5)); ComputeAndCompareR0(&builder, -3, {}); } XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a"); builder.ConvertElementType(a, F32); @@ -172,7 +170,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.Mul(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)), builder.ConstantR0(0.5f)); @@ -191,7 +189,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) { for (int32 x : data) { for (int32 y : data) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); // Signed integer overflow is undefined behavior in C++. 
Convert the input @@ -210,7 +208,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { for (uint32 x : data) { for (uint32 y : data) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); uint32 expected = x * y; @@ -220,7 +218,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul( builder.Mul(builder.ConstantR0(2), builder.ConstantR0(5)), builder.ConstantR0(1)); @@ -229,7 +227,7 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR0(2.1f); std::unique_ptr b_literal = Literal::CreateR0(5.5f); std::unique_ptr c_literal = Literal::CreateR0(0.5f); @@ -241,9 +239,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { std::unique_ptr c_data = client_->TransferToServer(*c_literal).ConsumeValueOrDie(); - ComputationDataHandle a = builder.Parameter(0, a_literal->shape(), "a"); - ComputationDataHandle b = builder.Parameter(1, b_literal->shape(), "b"); - ComputationDataHandle c = builder.Parameter(2, c_literal->shape(), "c"); + XlaOp a = builder.Parameter(0, a_literal->shape(), "a"); + XlaOp b = builder.Parameter(1, b_literal->shape(), "b"); + XlaOp c = builder.Parameter(2, c_literal->shape(), "c"); builder.Mul(builder.Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, @@ -252,14 +250,14 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { } XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Div(builder.ConstantR0(5.0f), builder.ConstantR0(2.5f)); ComputeAndCompareR0(&builder, 2.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(2.5f), builder.ConstantR0(5.0f)); ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); @@ -282,7 +280,7 @@ class DivS32Test : public ClientLibraryTestBase, XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Div(builder.ConstantR0(p.dividend), builder.ConstantR0(p.divisor)); @@ -291,7 +289,7 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(p.dividend), builder.ConstantR0(p.divisor)); @@ -300,9 +298,9 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividendd = CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = @@ -315,9 +313,9 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle 
dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividendd = CreateR0Parameter(p.dividend, 0, "dividend", &builder, ÷nd); auto divisord = @@ -364,13 +362,13 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; // clang-format on - Computation div_computation; + XlaComputation div_computation; { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle dividend = + XlaOp dividend = builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); - ComputationDataHandle divisor = + XlaOp divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Div(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build()); @@ -392,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend / divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } @@ -405,13 +403,13 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; // clang-format on - Computation rem_computation; + XlaComputation rem_computation; { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle dividend = + XlaOp dividend = builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); - ComputationDataHandle divisor = + XlaOp divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Rem(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build()); @@ -433,14 +431,14 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend % divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } } XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); @@ -450,7 +448,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { } XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32 // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF). 
builder.Div(builder.ConstantR0(0xFFFFFFFE), @@ -460,7 +458,7 @@ XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(11), builder.ConstantR0(3)); ComputeAndCompareR0(&builder, 2, {}); @@ -469,7 +467,7 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x && y, {}); @@ -480,7 +478,7 @@ XLA_TEST_F(ScalarComputationsTest, AndBool) { XLA_TEST_F(ScalarComputationsTest, AndS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x & y, {}); @@ -491,7 +489,7 @@ XLA_TEST_F(ScalarComputationsTest, AndS32) { XLA_TEST_F(ScalarComputationsTest, AndU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x & y, {}); @@ -502,7 +500,7 @@ XLA_TEST_F(ScalarComputationsTest, AndU32) { XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x || y, {}); @@ -513,7 +511,7 @@ XLA_TEST_F(ScalarComputationsTest, OrBool) { XLA_TEST_F(ScalarComputationsTest, OrS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x | y, {}); @@ -524,7 +522,7 @@ XLA_TEST_F(ScalarComputationsTest, OrS32) { XLA_TEST_F(ScalarComputationsTest, OrU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x | y, {}); @@ -534,7 +532,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, !x, {}); @@ -543,7 +541,7 @@ XLA_TEST_F(ScalarComputationsTest, NotBool) { XLA_TEST_F(ScalarComputationsTest, NotS32) { for (int32 x : {-1, 0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, ~x, {}); @@ -552,7 +550,7 @@ XLA_TEST_F(ScalarComputationsTest, NotS32) { XLA_TEST_F(ScalarComputationsTest, NotU32) { for (uint32 x : {0, 1, 2}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, ~x, {}); @@ -560,7 +558,7 @@ XLA_TEST_F(ScalarComputationsTest, NotU32) { } XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { - ComputationBuilder builder(client_, TestName()); + 
XlaBuilder builder(TestName()); builder.Select(builder.ConstantR0(true), // The predicate. builder.ConstantR0(123.0f), // The value on true. builder.ConstantR0(42.0f)); // The value on false. @@ -569,7 +567,7 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { } XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Select(builder.ConstantR0(false), // The predicate. builder.ConstantR0(123.0f), // The value on true. builder.ConstantR0(42.0f)); // The value on false. @@ -580,7 +578,7 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { // This test is an explicit version of what is happening in the following // templatized comparison tests. XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Gt(builder.ConstantR0(2.0f), builder.ConstantR0(1.0f)); ComputeAndCompareR0(&builder, true, {}); @@ -588,157 +586,156 @@ XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { // S32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare(2, 1, false, &ComputationBuilder::Eq); + TestCompare(2, 1, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare(3, 3, true, &ComputationBuilder::Eq); + TestCompare(3, 3, true, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare(2, 1, true, &ComputationBuilder::Ne); + TestCompare(2, 1, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare(2, 1, true, &ComputationBuilder::Ge); + TestCompare(2, 1, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare(1, 5, false, &ComputationBuilder::Gt); + TestCompare(1, 5, false, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare(2, 1, false, &ComputationBuilder::Le); + TestCompare(2, 1, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare(9, 7, false, &ComputationBuilder::Lt); + TestCompare(9, 7, false, &XlaBuilder::Lt); TestCompare(std::numeric_limits::min(), - std::numeric_limits::max(), true, - &ComputationBuilder::Lt); + std::numeric_limits::max(), true, &XlaBuilder::Lt); } // U32 comparisons. 
XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare(2, 1, false, &ComputationBuilder::Eq); + TestCompare(2, 1, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare(2, 1, true, &ComputationBuilder::Ne); + TestCompare(2, 1, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare(2, 1, true, &ComputationBuilder::Ge); + TestCompare(2, 1, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare(3, 3, true, &ComputationBuilder::Ge); + TestCompare(3, 3, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare(1, 5, false, &ComputationBuilder::Gt); - TestCompare(5, 5, false, &ComputationBuilder::Gt); - TestCompare(5, 1, true, &ComputationBuilder::Gt); + TestCompare(1, 5, false, &XlaBuilder::Gt); + TestCompare(5, 5, false, &XlaBuilder::Gt); + TestCompare(5, 1, true, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare(2, 1, false, &ComputationBuilder::Le); + TestCompare(2, 1, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare(9, 7, false, &ComputationBuilder::Lt); + TestCompare(9, 7, false, &XlaBuilder::Lt); TestCompare(0, std::numeric_limits::max(), true, - &ComputationBuilder::Lt); + &XlaBuilder::Lt); } // F32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare(2.0, 1.3, false, &ComputationBuilder::Eq); + TestCompare(2.0, 1.3, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare(2.0, 1.3, true, &ComputationBuilder::Ne); + TestCompare(2.0, 1.3, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare(2.0, 1.9, true, &ComputationBuilder::Ge); + TestCompare(2.0, 1.9, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare(3.5, 3.5, true, &ComputationBuilder::Ge); + TestCompare(3.5, 3.5, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare(1.0, 5.2, false, &ComputationBuilder::Gt); + TestCompare(1.0, 5.2, false, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare(2.0, 1.2, false, &ComputationBuilder::Le); + TestCompare(2.0, 1.2, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare(9.0, 7.2, false, &ComputationBuilder::Lt); + TestCompare(9.0, 7.2, false, &XlaBuilder::Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare(-INFINITY, -0.0, true, &ComputationBuilder::Lt); + TestCompare(-INFINITY, -0.0, true, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, false, &ComputationBuilder::Lt); + TestCompare(-0.0, 0.0, false, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare(0.0, INFINITY, true, &ComputationBuilder::Lt); + TestCompare(0.0, INFINITY, true, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare(-INFINITY, -0.0, false, &ComputationBuilder::Ge); + TestCompare(-INFINITY, -0.0, false, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. 
- TestCompare(-0.0, 0.0, true, &ComputationBuilder::Ge); + TestCompare(-0.0, 0.0, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare(0.0, INFINITY, false, &ComputationBuilder::Ge); + TestCompare(0.0, INFINITY, false, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Exp(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 7.3890562, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, LogScalar) { - ComputationBuilder builder(client_, "log"); + XlaBuilder builder("log"); builder.Log(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 0.6931471, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tanh(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tanh(builder.ConstantR0(2.0)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, PowScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Pow(builder.ConstantR0(2.0f), builder.ConstantR0(3.0f)); ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -747,7 +744,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(2), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -756,7 +753,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(-5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -765,7 +762,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. builder.ConstantR0(5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -774,7 +771,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. builder.ConstantR0(2), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -783,7 +780,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. 
builder.ConstantR0(0), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -792,7 +789,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(5.0f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -801,7 +798,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(2.5f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -810,7 +807,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(-5.0f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -819,70 +816,70 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax(10, 3, 3, &ComputationBuilder::Min); + TestMinMax(10, 3, 3, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax(-100, 3, -100, &ComputationBuilder::Min); + TestMinMax(-100, 3, -100, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax(10, 3, 10, &ComputationBuilder::Max); + TestMinMax(10, 3, 10, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax(-100, 3, 3, &ComputationBuilder::Max); + TestMinMax(-100, 3, 3, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, 3, &ComputationBuilder::Min); + TestMinMax(large, 3, 3, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax(0, 5, 0, &ComputationBuilder::Min); + TestMinMax(0, 5, 0, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, large, &ComputationBuilder::Max); + TestMinMax(large, 3, large, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax(0, 5, 5, &ComputationBuilder::Max); + TestMinMax(0, 5, 5, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min); + TestMinMax(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min); + TestMinMax(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &ComputationBuilder::Min); - TestMinMax(-3.1f, NAN, NAN, &ComputationBuilder::Min); + TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Min); + TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max); + TestMinMax(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max); + TestMinMax(-100.1f, 3.1f, 3.1f, 
&XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &ComputationBuilder::Max); - TestMinMax(-3.1f, NAN, NAN, &ComputationBuilder::Max); + TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Max); + TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Div( b.Sub(b.Mul(b.ConstantR0(1), b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), @@ -895,7 +892,7 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { // Compute the expression 1 * (3 - 1) * (7 + 0) - 4. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Sub(b.Mul(b.ConstantR0(1), b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), b.Add(b.ConstantR0(7), b.ConstantR0(0)))), @@ -905,21 +902,20 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { } XLA_TEST_F(ScalarComputationsTest, SqrtF320) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Literal zero_literal = Literal::Zero(PrimitiveType::F32); std::unique_ptr zero_data = client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - ComputationDataHandle zero = - builder.Parameter(0, zero_literal.shape(), "zero"); + XlaOp zero = builder.Parameter(0, zero_literal.shape(), "zero"); builder.SqrtF32(zero); ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RoundScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Round(builder.ConstantR0(1.4f)); ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 009e7d24c5cbfa..72707f224446c7 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -16,13 +16,12 @@ limitations under the License. 
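The scalar-computation hunks above all follow one pattern: the builder no longer takes the client handle, and the comparison/min-max helpers now receive XlaBuilder member-function pointers instead of ComputationBuilder ones. A minimal sketch of the new form, assuming the ClientLibraryTestBase helpers (TestName, ComputeAndCompareR0, TestCompare) already used by these tests:

    XlaBuilder builder(TestName());
    builder.Gt(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(1.0f));
    ComputeAndCompareR0<bool>(&builder, true, {});

    // Comparison helpers take the member-function pointer of the new builder type.
    TestCompare<int32>(2, 1, false, &XlaBuilder::Eq);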
#include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -35,7 +34,7 @@ class SelectTest : public ClientLibraryTestBase { }; TEST_F(SelectTest, SelectScalarF32True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR0(123.0f); auto on_false = builder.ConstantR0(42.0f); @@ -45,7 +44,7 @@ TEST_F(SelectTest, SelectScalarF32True) { } TEST_F(SelectTest, SelectScalarS32True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR0(-42); auto on_false = builder.ConstantR0(42); @@ -55,7 +54,7 @@ TEST_F(SelectTest, SelectScalarS32True) { } TEST_F(SelectTest, SelectScalarF32False) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto on_true = builder.ConstantR0(123.0f); auto on_false = builder.ConstantR0(42.0f); @@ -65,7 +64,7 @@ TEST_F(SelectTest, SelectScalarF32False) { } XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR1({}); auto on_true = builder.ConstantR1({}); auto on_false = builder.ConstantR1({}); @@ -75,7 +74,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { } TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR1({false, true, false, true, false}); auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); @@ -88,7 +87,7 @@ TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector // is not a constant, but rather the result of comparing two other vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({}); auto v2 = builder.ConstantR1({}); auto cmp = builder.Eq(v1, v2); @@ -102,7 +101,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is // not a constant, but rather the result of comparing two other vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({1, 2, 3, 4, 5}); auto v2 = builder.ConstantR1({9, 2, 9, 4, 9}); auto cmp = builder.Eq(v1, v2); @@ -116,7 +115,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s. 
- ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); auto v2 = builder.ConstantR1({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); auto cmp = builder.Gt(v1, v2); @@ -131,9 +130,9 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { // Selects among two R1F32s, which come from parameters. v1 and v2 are // compared, and selection between them happens based on a gt-comparison mask. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle v1, v2; + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter( {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -151,7 +150,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the // data size passed in and out is large. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Number of floats in the data passed into and out of the computation. constexpr int datalen = 15 * 1000; @@ -174,7 +173,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { expected_vec.push_back(larger); } - ComputationDataHandle v1, v2; + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -192,7 +191,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to // select between two R1F32s. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, -1, 2, -2}); auto s = builder.ConstantR0(0); auto cmp = builder.Gt(v, s); @@ -209,7 +208,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to // select between two R1F32s. 
- ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f}); auto s = builder.ConstantR0(2.5f); auto cmp = builder.Gt(v, s); @@ -225,7 +224,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { for (bool which : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(which); auto on_true = builder.ConstantR1({}); auto on_false = builder.ConstantR1({}); @@ -236,7 +235,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { } TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR1({-2.5f, 25.5f}); auto on_false = builder.ConstantR1({10.0f, 5.0f}); @@ -246,7 +245,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { } TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto on_true = builder.ConstantR1({-2.5f, 25.5f}); auto on_false = builder.ConstantR1({10.0f, 5.0f}); diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc deleted file mode 100644 index 29f79ec28a1ae6..00000000000000 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class SetReturnValueTest : public ClientLibraryTestBase {}; - -TEST_F(SetReturnValueTest, NoSetValue) { - ComputationBuilder builder(client_, "no_set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - - std::vector expected = {1.0, 3.0, 4.0, 0.0, -1.0, - 5.0, 6.0, -2.0, -3.0, 7.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValue) { - ComputationBuilder builder(client_, "set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueAndModify) { - ComputationBuilder builder(client_, "set_value_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { - ComputationBuilder builder(client_, "set_value_multiple_times_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(aax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaaax = builder.Add(alpha, aaax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 52195db2aa7471..5653bf11a7364b 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -197,9 +197,10 @@ class SliceR1Test : public ClientLibraryTestBase, // vector. 
tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); + auto literal = Literal::CreateR1(input); XlaBuilder builder(TestName()); - auto original = builder.ConstantR1(input); + auto original = builder.Parameter(0, literal->shape(), "p0"); builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); @@ -210,7 +211,9 @@ class SliceR1Test : public ClientLibraryTestBase, expected.push_back(i); } - ComputeAndCompareR1(&builder, expected, {}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); + ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -365,15 +368,18 @@ XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); + auto literal = Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(spec.layout)); + auto a = builder.Parameter(0, literal->shape(), "p0"); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputeAndCompareR2(&builder, *expected, {}); + ComputeAndCompareR2(&builder, *expected, {arg.get()}); } INSTANTIATE_TEST_CASE_P( @@ -453,7 +459,7 @@ class SliceR4Test : public ClientLibraryTestBase, void Run(const R4Spec& spec) { Array4D values(spec.input_dims[0], spec.input_dims[1], spec.input_dims[2], spec.input_dims[3]); - values.FillRandom(3.14f); + values.FillIota(3.14159); auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 997a1d8273736a..dd7c5417336342 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -26,6 +26,7 @@ namespace { template void PopulateWithRandomFloatingPointDataImpl(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal @@ -59,12 +60,14 @@ void PopulateWithRandomFloatingPointDataImpl(Literal* literal, template void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } @@ -73,6 +76,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), BF16); std::uniform_real_distribution generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate( @@ -84,6 +88,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), 
primitive_util::NativeToPrimitiveType()); std::uniform_int_distribution generator( @@ -107,7 +112,10 @@ StatusOr> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } - std::unique_ptr literal = Literal::CreateFromShape(shape); + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + auto literal = MakeUnique(shape); switch (shape.element_type()) { case BF16: PopulateWithRandomFloatingPointData(literal.get(), engine); @@ -201,11 +209,13 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector start_indices(rank); - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); - start_indices[i] = generator(*engine); + if (engine != nullptr) { + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(*engine); + } } return Literal::CreateR1(start_indices); } @@ -321,26 +331,26 @@ StatusOr> MakeConstrainedArgument( } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - std::minstd_rand0 engine; - return MakeFakeLiteralInternal(shape, &engine); +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random) { + auto engine = pseudo_random ? MakeUnique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get()); } StatusOr>> MakeFakeArguments( - HloModule* const module) { + HloModule* const module, bool pseudo_random) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::minstd_rand0 engine; + auto engine = pseudo_random ? MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN( - arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); + TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( + *dataflow, *params[i], engine.get())); } return std::move(arguments); } -Status VerifyHloModule(const se::Platform& platform, HloModule* const module, - bool allow_mixed_precision) { +Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { return HloVerifier(allow_mixed_precision).Run(module).status(); } diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 30c147910cae85..a8689f64981569 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -55,20 +55,32 @@ class PseudorandomGenerator { }; // Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); +// status if the element type is currently unhandled for fake data +// generation. See below for documentation of pseudo_random. +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // // Will handle special cases such as making sure that indices used for dynamic // slices are bounded, reduces that call adds use 0 as an init value, etc. 
+// +// If pseudo_random is true, the generated numbers will be generated +// deterministically in a pseudo random way unless the values are constrated to +// be e.g. init values as above. If pseudo_random is false, the returned values +// will be generated in a faster way that yields less interesting data, e.g. the +// values may all be just the same value. +// +// TODO(b/79942829): Make interesting argument generation fast enough that using +// pseudo_random does not save any noticeable amount of time so that the +// parameter can be removed. StatusOr>> MakeFakeArguments( - HloModule* const module); + HloModule* const module, bool pseudo_random = true); // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(const se::Platform& platform, HloModule* const module, +Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision = false); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index e2067bc1b835a9..0063e7ad415e9b 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -175,7 +175,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { @@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { @@ -209,7 +209,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { @@ -224,7 +224,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { @@ -243,7 +243,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index fe5a1778a2cecf..fe1e3da7eca00e 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -16,14 +16,13 @@ limitations under the License. 
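The test_utils changes above make the random engine optional: MakeFakeLiteral and MakeFakeArguments now accept a pseudo_random flag, and a null engine produces cheap, uninteresting values instead of pseudo-random ones. A short usage sketch, assuming a caller that returns Status and an already-built HloModule named module:

    // Deterministic pseudo-random data (the default), except where arguments are
    // constrained (e.g. reduce init values, in-bounds dynamic-slice indices).
    TF_ASSIGN_OR_RETURN(auto fake_args, MakeFakeArguments(module.get()));

    // Faster but less interesting data; the values may all be identical.
    TF_ASSIGN_OR_RETURN(auto cheap_args,
                        MakeFakeArguments(module.get(), /*pseudo_random=*/false));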
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -38,7 +37,7 @@ class TransposeTest : public ClientLibraryTestBase { }; XLA_TEST_F(TransposeTest, Transpose0x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 0)); auto result = builder.Transpose(lhs, {1, 0}); @@ -46,7 +45,7 @@ XLA_TEST_F(TransposeTest, Transpose0x0) { } XLA_TEST_F(TransposeTest, Transpose0x42) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 42)); auto result = builder.Transpose(lhs, {1, 0}); @@ -54,7 +53,7 @@ XLA_TEST_F(TransposeTest, Transpose0x42) { } XLA_TEST_F(TransposeTest, Transpose7x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(7, 0)); auto result = builder.Transpose(lhs, {1, 0}); @@ -62,7 +61,7 @@ XLA_TEST_F(TransposeTest, Transpose7x0) { } TEST_F(TransposeTest, Transpose2x2) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2({ {1.0, 2.0}, {3.0, 4.0}, }); @@ -74,7 +73,7 @@ TEST_F(TransposeTest, Transpose2x2) { } XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D(Array3D(0, 2, 3)); auto result = builder.Transpose(operand, {1, 2, 0}); @@ -82,7 +81,7 @@ XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { } TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {1, 2, 0}); @@ -92,7 +91,7 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { } TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {2, 1, 0}); @@ -102,7 +101,7 @@ TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { } TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {0, 1, 2}); @@ -116,7 +115,7 @@ TEST_F(TransposeTest, MultiTranspose3x2) { Array2D transposed({{1.0f, 3.0f, 5.0f}, {2.0f, 4.0f, 6.0f}}); for (int transposes = 0; transposes <= 10; ++transposes) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto computed = builder.ConstantR2FromArray2D(input); for (int i = 0; i < transposes; ++i) { computed = builder.Transpose(computed, {1, 0}); @@ -130,7 +129,7 @@ TEST_F(TransposeTest, MultiTranspose3x2) { 
TEST_F(TransposeTest, Small_1x1) { auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1); - ComputationBuilder builder(client_, "transpose_1x1"); + XlaBuilder builder("transpose_1x1"); auto operand = builder.ConstantR2FromArray2D(*aoperand); builder.Transpose(operand, {1, 0}); @@ -142,7 +141,7 @@ TEST_F(TransposeTest, Small_1x1) { TEST_F(TransposeTest, Small_2x2) { auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2); - ComputationBuilder builder(client_, "transpose_2x2"); + XlaBuilder builder("transpose_2x2"); auto operand = builder.ConstantR2FromArray2D(*aoperand); builder.Transpose(operand, {1, 0}); @@ -162,7 +161,7 @@ void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR3FromArray3D(aoperand); builder.Transpose(operand, {0, 2, 1}); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 61be1746530a19..41189231b90e84 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -17,8 +17,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -287,13 +285,13 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { } XLA_TEST_F(TupleTest, TuplesInAMap) { - Computation tuple_computation; + XlaComputation tuple_computation; { // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples. // // Need to put a select in there to prevent HLO-level optimizations from // optimizing out the tuples. - ComputationBuilder b(client_, "sort_square"); + XlaBuilder b("sort_square"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto x2 = b.Mul(x, x); auto x_smaller_tuple = b.Tuple({x, x2}); @@ -307,7 +305,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { tuple_computation = computation_status.ConsumeValueOrDie(); } - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = b.ConstantR1({-1.0f, 1.0f, 2.1f}); b.Map({input}, tuple_computation, {0}); ComputeAndCompareR1(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); @@ -497,7 +495,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = Literal::CreateFromShape(sum->shape()); + auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { return sum->Get(indexes) * @@ -516,7 +514,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { class TupleHloTest : public HloTestBase {}; // Disabled on the interpreter because bitcast doesn't exist on the interpreter. -TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { +XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { const char* testcase = R"( HloModule m diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 835e2d7e5594d7..c3abe22797f5ea 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -16,9 +16,9 @@ limitations under the License. 
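The tuple_test hunk above shows the subcomputation pattern under the new API: build the helper with its own XlaBuilder, take the XlaComputation from Build(), and pass it to Map on the outer builder. A condensed sketch, assuming the same test fixture; add_one is an illustrative name, not part of this patch:

    XlaComputation add_one;
    {
      XlaBuilder b("add_one");
      auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
      b.Add(x, b.ConstantR0<float>(1.0f));
      add_one = b.Build().ConsumeValueOrDie();
    }

    XlaBuilder b(TestName());
    auto input = b.ConstantR1<float>({-1.0f, 1.0f, 2.1f});
    b.Map({input}, add_one, {0});
    ComputeAndCompareR1<float>(&b, {0.0f, 2.0f, 3.1f}, {}, error_spec_);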
#include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -37,7 +37,7 @@ class UnaryOpTest : public ClientLibraryTestBase { } template void AbsSize0TestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({}); auto abs = builder.Abs(arg); @@ -50,7 +50,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({-2, 25, 0, -123, inf(), -inf()}); auto abs = builder.Abs(arg); @@ -59,7 +59,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); auto sign = builder.Sign(arg); @@ -69,7 +69,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignAbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({-2, 25, 0, -123}); auto sign = builder.Sign(arg); auto abs = builder.Abs(arg); @@ -84,9 +84,14 @@ int UnaryOpTest::inf() { return 2147483647; } +template <> +int64 UnaryOpTest::inf() { + return 0x7FFFFFFFFFFFFFFFl; +} + template <> void UnaryOpTest::AbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, @@ -102,7 +107,7 @@ void UnaryOpTest::AbsTestHelper() { template <> void UnaryOpTest::SignTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); auto sign = builder.Sign(arg); @@ -114,7 +119,7 @@ void UnaryOpTest::SignTestHelper() { template <> void UnaryOpTest::SignAbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); auto sign = builder.Sign(arg); @@ -139,7 +144,7 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1) { } XLA_TEST_F(UnaryOpTest, AbsTestR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto argi = builder.ConstantR0(-5); auto absi = builder.Abs(argi); auto argf = builder.ConstantR0(-3.0f); @@ -155,7 +160,7 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) { } XLA_TEST_F(UnaryOpTest, SignTestR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto argi = builder.ConstantR0(-5); auto sgni = builder.Sign(argi); // -1 auto argf = builder.ConstantR0(-4.0f); @@ -176,6 +181,7 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { XLA_TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); + SignTestHelper(); SignTestHelper(); SignTestHelper(); } @@ -187,7 +193,7 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { } XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {2, 25, 0, 123, std::numeric_limits::max()}); 
auto abs = builder.Abs(arg); @@ -197,7 +203,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { } XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {2, 25, 0, 123, std::numeric_limits::max()}); auto sign = builder.Sign(arg); @@ -206,7 +212,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { } XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR2({{1.0, -2.0}, {-3.0, 4.0}}); auto sign = builder.Sign(arg); auto abs = builder.Abs(arg); @@ -216,7 +222,7 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 1}); auto rhs = builder.ConstantR1({1, 1}); builder.ConvertElementType(builder.Eq(lhs, rhs), S32); @@ -225,7 +231,7 @@ XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 1}); auto rhs = builder.ConstantR1({1, 1}); builder.ConvertElementType(builder.Eq(lhs, rhs), F32); diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 3dded3f7157195..5cce7a2bf82c1a 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -350,7 +349,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { } XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0); auto one = builder.ConstantR0(10); auto x = builder.ConstantR1({-3, 3, 9, 13}); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 336fed27c6f19f..c463f3eac55e5b 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -957,22 +957,21 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); auto t = outer.Tuple({p, outer.ConstantR1({1, 1})}); - TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr tuple_shape, - outer.GetShape(t)); + TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(t)); - ComputationBuilder cond(client_, "cond"); - auto cond_t = cond.Parameter(0, *tuple_shape, "t"); + XlaBuilder cond("cond"); + auto cond_t = cond.Parameter(0, tuple_shape, "t"); TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0), cond.ConstantR1({42, 42})), &cond) .status()); - ComputationBuilder body(client_, "body"); - auto body_t = body.Parameter(0, *tuple_shape, "t"); + XlaBuilder body("body"); + auto body_t = body.Parameter(0, tuple_shape, "t"); auto 
e = body.GetTupleElement(body_t, 1); body.Tuple({e, e}); @@ -993,15 +992,15 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto cond_t = cond.Parameter(0, element_shape, "t"); TF_ASSERT_OK( Any(cond.Eq(cond_t, cond.ConstantR1({42, 42})), &cond).status()); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, element_shape, "t"); auto e = body.Broadcast(body.ConstantR0(1.0), {2}); @@ -1019,14 +1018,14 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto cond_t = cond.Parameter(0, element_shape, "t"); cond.Eq(cond_t, cond.ConstantR0(42)); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, element_shape, "t"); auto tuple = body.Tuple({body_t, body.Add(body_t, body.ConstantR0(1))}); @@ -1055,23 +1054,23 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { auto result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Tuple({outer.ConstantR0(0), outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto params = cond.Parameter(0, result_shape, "prev"); auto cond_t = cond.Add(cond.GetTupleElement(params, 1), cond.GetTupleElement(params, 0)); cond.Lt(cond_t, cond.ConstantR0(30)); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, result_shape, "t"); auto tuple = body.Tuple( - {body.Add(body.GetTupleElement(params, 0), body.ConstantR0(1)), - body.Add(body.GetTupleElement(params, 1), body.ConstantR0(1))}); + {body.Add(body.GetTupleElement(body_t, 0), body.ConstantR0(1)), + body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0(1))}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 8354bb71cb7e88..3c9a01653c6720 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,8 +17,9 @@ limitations under the License. 
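The while_test hunks above also capture the GetShape change: XlaBuilder::GetShape returns the Shape by value inside a StatusOr, where ComputationBuilder returned a unique_ptr, so the parameter declarations drop the dereference. A condensed sketch mirroring the patched test:

    XlaBuilder outer("outer");
    auto p = outer.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param");
    auto t = outer.Tuple({p, outer.ConstantR1<float>({1, 1})});
    TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(t));

    XlaBuilder cond("cond");
    auto cond_t = cond.Parameter(0, tuple_shape, "t");  // no *tuple_shape dereference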
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -83,8 +84,8 @@ Status ParseOneProfileOutputLine( string match_percentage = "\\d+\\.\\d\\d%"; string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; string match_usecs = "([0-9.]+) usec"; - string match_flops = "([^ ]+)"; - string match_trops = "([^ ]+)"; + string match_flops = "([^ ]*)"; + string match_trops = "([^ ]*)"; string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; @@ -119,7 +120,7 @@ Status ParseOneProfileOutputLine( // Returns void so that we can ASSERT. void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, - const Computation& computation, + const XlaComputation& computation, const Shape& lhs_arg_shape, const Shape& rhs_arg_shape) { LocalService* service = ClientLibrary::GetXlaService(client->platform()); @@ -185,7 +186,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(platform)); - ComputationBuilder builder(client, TestName()); + XlaBuilder builder(TestName()); auto result = builder.Tanh(builder.Add( builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); @@ -251,18 +252,18 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(platform)); - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client, "condition"); + XlaBuilder builder("condition"); auto state = builder.Parameter(0, while_result_shape, "state"); auto iteration = builder.GetTupleElement(state, 0); builder.Gt(builder.ConstantR0(5), iteration); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation body; + XlaComputation body; { - ComputationBuilder builder(client, "body"); + XlaBuilder builder("body"); auto state = builder.Parameter(0, while_result_shape, "state"); auto matrix = builder.GetTupleElement(state, 1); auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), @@ -271,7 +272,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - ComputationBuilder builder(client, TestName()); + XlaBuilder builder(TestName()); auto initial_while_state = builder.Tuple({builder.ConstantR0(0), builder.Parameter(0, matrix_shape, "initial_value")}); diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 44f874cd2ae8e6..56702feab9a4e8 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -42,7 +42,7 @@ StatusOr> TextLiteralReader::ReadPath( << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = - tensorflow::Env::Default()->NewRandomAccessFile(path.ToString(), &file); + tensorflow::Env::Default()->NewRandomAccessFile(std::string(path), &file); if (!s.ok()) { return s; } @@ -92,7 +92,7 @@ StatusOr> 
TextLiteralReader::ReadAllLines() { tensorflow::StringPiece sp(shape_string); if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { - string tmp = sp.ToString(); + string tmp = std::string(sp); shape_string = tmp; } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); @@ -124,10 +124,10 @@ StatusOr> TextLiteralReader::ReadAllLines() { line.c_str()); } float value; - if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(), + if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(), &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - value_string.ToString().c_str()); + std::string(value_string).c_str()); } SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); coordinate_values.clear(); @@ -136,7 +136,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", - piece.ToString().c_str()); + std::string(piece).c_str()); } coordinate_values.push_back(coordinate_value); } diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 3fee467594d842..373c0d2d8d8ab0 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -30,10 +30,10 @@ limitations under the License. namespace xla { -/* static */ tensorflow::Status TextLiteralWriter::WriteToPath( +/* static */ Status TextLiteralWriter::WriteToPath( const Literal& literal, tensorflow::StringPiece path) { std::unique_ptr f; - auto s = tensorflow::Env::Default()->NewWritableFile(path.ToString(), &f); + auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); if (!s.ok()) { return s; } @@ -43,7 +43,7 @@ namespace xla { return s; } - tensorflow::Status status; + Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( [f_ptr, &status](tensorflow::gtl::ArraySlice indices, diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 7375493f4309c9..0a1235b5e04675 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -37,8 +37,8 @@ namespace xla { // This should be readable by xla::TextLiteralReader. 
class TextLiteralWriter { public: - static tensorflow::Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, + tensorflow::StringPiece path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 0bc4045a549031..ff5340ee3fac51 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,11 +36,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -63,10 +62,9 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -84,11 +82,11 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -137,7 +135,7 @@ tf_cc_binary( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -164,12 +162,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -183,12 +179,11 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -201,13 +196,12 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - 
"//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc index fe03a6e7bdfe99..14d01b5bfb067c 100644 --- a/tensorflow/compiler/xla/tools/convert_computation.cc +++ b/tensorflow/compiler/xla/tools/convert_computation.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/env.h" @@ -33,7 +33,7 @@ namespace xla { namespace tools { void RealMain(const string& mode, const string& path) { - SessionModule module; + HloSnapshot module; tensorflow::Env* env = tensorflow::Env::Default(); if (mode == "txt2bin") { TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module)); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 21ae8583d7cd33..befb55453777dc 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -17,7 +17,7 @@ limitations under the License. // // Dumps a graphviz URL for a snapshot computation to the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The GraphViz URL is placed into the log stderr, whereas computation @@ -30,11 +30,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,10 +48,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index b82f1c81c84b48..cfb8f37487d649 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -21,11 +21,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -66,16 +65,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); std::unique_ptr program_shape = client->GetComputationShape(computation).ConsumeValueOrDie(); @@ -89,8 +88,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc 
b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 05c0fdf97d27c0..5dd5150be33984 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -19,11 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,16 +38,16 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); if (compile) { std::unique_ptr program_shape = @@ -65,8 +63,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); @@ -74,13 +71,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { local_service->backend().platform()->Name().c_str(), module.ToString(HloPrintOptions::ShortParsable()).c_str()); } else { - const ComputationTracker& tracker = local_service->computation_tracker(); - UserComputation* user_computation = - tracker.Resolve(computation.handle()).ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); std::unique_ptr module = - tracker.BuildHloModule(versioned_handle, HloModuleConfig()) + HloModule::CreateFromProto(computation.proto(), config) .ConsumeValueOrDie(); fprintf(stdout, "%s\n", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 51f90b07c66f7d..a5dce20456c6a2 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -28,11 +28,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -48,10 +47,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD deleted file mode 100644 index 0fa4b98d0a41a1..00000000000000 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ /dev/null @@ -1,72 +0,0 @@ -# Build file for the Hlo parser. - -licenses(["notice"]) # Apache 2.0 - -package( - default_visibility = [":friends"], -) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], - hdrs = [ - "hlo_lexer.h", - "hlo_token.h", - ], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - ], -) - -cc_library( - name = "hlo_parser", - srcs = ["hlo_parser.cc"], - hdrs = ["hlo_parser.h"], - deps = [ - ":hlo_lexer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_parser", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 62a353ad09af00..be094b7890aab0 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -17,7 +17,7 @@ limitations under the License. // // Replays computations and shows the results on the command line. 
// -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // Computations that require arguments can be replayed using fake data by @@ -36,13 +36,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -64,99 +64,157 @@ namespace { // fields. struct Options { string fake_infeed_shape; + bool generate_fake_infeed = false; bool use_fake_data = false; bool print_result = true; int num_runs = 1; - bool xla_hlo_profile_last_run = false; }; // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // -// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; -// otherwise, no infeed is performed. -StatusOr> ReplayComputation( - const SessionModule& module, Client* client, const Options& opts) { - TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); +// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided. +// If generate_fake_infeed is true, the required infeed shape is derived from +// the computation and then used to provide a fake infeed shape. +// +// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, +// no infeed is performed. +StatusOr ReplayComputation(const HloSnapshot& module, + LocalClient* client, const Options& opts) { + XlaComputation computation(module.hlo().hlo_module()); - std::vector> arguments; + // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our + // arguments. This is a bit involved, because we may have to convert from + // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our + // objects. 
+ std::vector scoped_shaped_buffer_arguments; + std::vector> global_data_arguments; + std::vector argument_ptrs; if (opts.use_fake_data) { - arguments = MakeFakeArgumentsOrDie(computation, client); + global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + for (const auto& data : global_data_arguments) { + argument_ptrs.push_back( + client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) + .ValueOrDie()); + } } else { // use recorded data if available for (const auto& proto : module.arguments()) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(proto)); - TF_ASSIGN_OR_RETURN(std::unique_ptr data, - client->TransferToServer(*literal)); - arguments.push_back(std::move(data)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer data, + client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + scoped_shaped_buffer_arguments.push_back(std::move(data)); + } + for (const auto& argument : scoped_shaped_buffer_arguments) { + argument_ptrs.push_back(&argument); } } + bool provide_infeed = false; + Shape infeed_shape; + if (!opts.fake_infeed_shape.empty()) { + StatusOr shape_status = + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); + TF_CHECK_OK(shape_status.status()); + infeed_shape = std::move(shape_status).ValueOrDie(); + provide_infeed = true; + } else if (opts.generate_fake_infeed) { + for (const auto& comp : computation.proto().computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) { + CHECK(!provide_infeed) + << "--generate_fake_infeed only works if the model has 0 or 1 " + "infeed ops, but this one has >= 2."; + provide_infeed = true; + infeed_shape = instruction.shape(); + LOG(INFO) << "Generating fake infeed shape for inferred shape: " + << ShapeUtil::HumanString(infeed_shape); + } + } + } + } // We only instantiate the thread pool if the user has requested that a - // concurrent infeed occur via the fake_infeed_shape. + // concurrent infeed occur via the fake_infeed_shape, or when + // --generate_fake_infeed is passed and there exists an infeed operation in + // the HloSnapshot. tensorflow::gtl::optional pool; - - if (!opts.fake_infeed_shape.empty()) { + std::unique_ptr data; + if (provide_infeed) { + data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); + } + auto transfer_infeed = [&data, client]() { + TF_CHECK_OK(client->TransferToInfeed(*data)); + }; + if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([opts, client]() { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); - TF_CHECK_OK(shape_status.status()); - Shape shape = std::move(shape_status).ValueOrDie(); - StatusOr> data_status = MakeFakeLiteral(shape); - TF_CHECK_OK(data_status.status()); - std::unique_ptr data = std::move(data_status).ValueOrDie(); - while (true) { - TF_CHECK_OK(client->TransferToInfeed(*data)); - } + pool->Schedule([transfer_infeed]() { + // There may be several infeed buffers needed, however we don't know how + // many. If we proactively transfer too many infeed buffers, we may run + // out of memory. If we transfer too few infeed buffers, the program will + // hang. Therefore, we register a callback that is called when the infeed + // becomes empty, and in this callback we will transfer another fake + // infeed. 
+ auto infeed_manager = xla::gpu::GetOrCreateInfeedManager(); + infeed_manager->RegisterOnEmptyCallback(transfer_infeed); + transfer_infeed(); }); } - std::vector execute_arguments; - execute_arguments.reserve(arguments.size()); - for (auto& argument : arguments) { - execute_arguments.push_back(argument.get()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); } + std::unique_ptr executable = + client->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); // Run the computation num_runs times, and return the result from the last // execution. - std::unique_ptr result; + StreamExecutorMemoryAllocator allocator( + client->platform(), + {client->platform()->ExecutorForDevice(0).ValueOrDie()}); + tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) { - execution_options.mutable_debug_options()->set_xla_hlo_profile(true); - } + ExecutableRunOptions run_options; + run_options.set_execution_profile(&profile); + run_options.set_allocator(&allocator); - if (opts.print_result) { - TF_ASSIGN_OR_RETURN( - result, client->ExecuteAndTransfer(computation, execute_arguments, - &execution_options, &profile)); - } else { - // If we're not printing the result, execute the computation but don't - // bother retrieving the result. This can be a significant speedup. - TF_RETURN_IF_ERROR(client - ->Execute(computation, execute_arguments, - &execution_options, &profile) - .status()); - } + TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); LOG(INFO) << "Execution took " << static_cast(profile.compute_time_ns()) / 1e9 << "s"; } - return std::move(result); + // Check that --num_runs > 0, otherwise *result below will fail with an + // unhelpful error (because the loop didn't run any iterations). + CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0"; + TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + client->ShapedBufferToLiteral(*result)); + return std::move(*result_literal); } int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { - Client* client = ClientLibrary::LocalClientOrDie(); + LocalClient* client = ClientLibrary::LocalClientOrDie(); tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { - SessionModule module; - TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); - StatusOr> result_status = - ReplayComputation(module, client, opts); + HloSnapshot snapshot; + auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); + if (!status.ok()) { + fprintf(stderr, "%s: is not HloSnapshot. 
Trying HloProto.\n", arg); + status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo()); + if (!status.ok()) { + fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg, + status.ToString().c_str()); + continue; + } + CHECK(opts.use_fake_data) + << "HloProto input must be handled with --use_fake_data"; + } + + StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -164,16 +222,17 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { continue; } - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (module.has_result()) { + if (opts.print_result) { + Literal result = std::move(result_status).ValueOrDie(); + fprintf(stdout, "%s: %s :: %s:%s\n", arg, + snapshot.hlo().hlo_module().name().c_str(), + ShapeUtil::HumanString(result.shape()).c_str(), + result.ToString().c_str()); + if (snapshot.has_result()) { std::unique_ptr literal = - Literal::CreateFromProto(module.result()).ConsumeValueOrDie(); + Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), + ShapeUtil::HumanString(snapshot.result().shape()).c_str(), literal->ToString().c_str()); } } @@ -198,9 +257,9 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), - tensorflow::Flag( - "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run, - "Pass --xla_hlo_profile the last time we run the computation."), + tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, + "Whether a fake infeed shape should be generated " + "derived from the computation"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 1f3340cbc6afa9..4e53fafcc97ff5 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -18,7 +18,7 @@ limitations under the License. // Shows the signature (ProgramShape) of binary snapshot proto(s) on the command // line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The output format is: @@ -31,9 +31,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -49,13 +48,14 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + auto computation = client->LoadSnapshot(module).ConsumeValueOrDie(); std::unique_ptr shape = client->GetComputationShape(computation).ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(), + fprintf(stdout, "%s: %s :: %s\n", arg, + module.hlo().hlo_module().name().c_str(), ShapeUtil::HumanString(*shape).c_str()); } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index be33bd6dd1304f..b4f45cc972d3d3 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -218,6 +218,12 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); // Passed-varargs variant of the InvalidArgument factory above. Status InvalidArgumentV(const char* format, va_list args); +template +Status InvalidArgumentStrCat(Args&&... concat) { + return InvalidArgument( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + template Status UnimplementedStrCat(Args&&... concat) { return Unimplemented( @@ -486,6 +492,12 @@ bool c_is_sorted(const C& c) { return std::is_sorted(std::begin(c), std::end(c)); } +template +bool c_is_sorted(const C& c, Compare&& comp) { + return std::is_sorted(std::begin(c), std::end(c), + std::forward(comp)); +} + template auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { return std::adjacent_find(std::begin(c), std::end(c)); @@ -514,12 +526,29 @@ typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, std::forward(binary_op)); } +template +typename std::iterator_traits< + decltype(std::begin(std::declval()))>::difference_type +c_count_if(const C& c, Pred&& pred) { + return std::count_if(std::begin(c), std::end(c), std::forward(pred)); +} + template int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); return std::distance(c.begin(), it); } +template +void InsertAt(C* c, int64 index, Value&& value) { + c->insert(c->begin() + index, std::forward(value)); +} + +template +void EraseAt(C* c, int64 index) { + c->erase(c->begin() + index); +} + // Returns true if `x` fits in 32-bits. 
template bool IsInt32(T x) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f619b8dc24038a..53ba120d21a9e1 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -17,7 +17,6 @@ syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; -import "tensorflow/compiler/xla/service/session.proto"; package xla; @@ -230,14 +229,6 @@ message SnapshotComputationRequest { ComputationHandle computation = 1; } -message SnapshotComputationResponse { - SessionModule module = 1; -} - -message LoadComputationSnapshotRequest { - SessionModule module = 1; -} - message LoadComputationSnapshotResponse { ComputationHandle computation = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index d23f9e5918f54c..6bdfb0179cd6a5 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -66,11 +66,16 @@ enum PrimitiveType { // in the dimensions field. TUPLE = 13; - // An opaque type used for passing context specific data to a custom - // operation. + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. OPAQUE = 14; - // Next = 17 + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 18 } // Describes the value held inside padding elements. @@ -134,6 +139,8 @@ enum Format { // example, Convert) are ignored. // // See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange message Layout { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. @@ -159,9 +166,12 @@ message Layout { // memory. This field must be unset unless the format is SPARSE. int64 max_sparse_elements = 5; - // Important: if any field is added, be sure to modify ShapeUtil::Equal() - // appropriately to account for the new field. + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and + // LayoutUtil::Hash appropriately to account for the new field. } +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc) // A shape describes the number of dimensions in the array, the size of each // dimension, and the primitive component type. @@ -170,6 +180,8 @@ message Layout { // defined. // // See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange message Shape { reserved 1; reserved "rank"; @@ -190,9 +202,12 @@ message Shape { // The layout used to back this shape. Layout layout = 5; - // Important: if any field is added, be sure to modify ShapeUtil::Equal() and - // ShapeUtil::Compatible() appropriately to account for the new field. + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), + // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for + // the new field. } +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc) // Shape of the parameters and output of a computation (like a traditional // function signature). @@ -804,6 +819,12 @@ enum UnaryOperation { // Elementwise, computes clz(x). 
UNOP_CLZ = 17; + + // Elementwise, computes exp(x)-1. + UNOP_EXPM1 = 18; + + // Elementwise, computes log(x+1). + UNOP_LOG1P = 19; } message UnaryOpRequest { diff --git a/tensorflow/compiler/xla/xlalogo.png b/tensorflow/compiler/xla/xlalogo.png new file mode 100644 index 00000000000000..7a0a295953d0c4 Binary files /dev/null and b/tensorflow/compiler/xla/xlalogo.png differ diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index abdbdb4cd22ff3..50b1ae5cc3cba2 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -31,13 +31,15 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", + "//tensorflow/contrib/control_flow", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", - "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/deprecated:deprecated_py", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", @@ -71,6 +73,7 @@ py_library( "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/mixed_precision:mixed_precision", "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", @@ -82,7 +85,6 @@ py_library( "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 7f33d460dce077..ad8c40395c2cdc 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -30,6 +30,7 @@ from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import constrained_optimization +from tensorflow.contrib import control_flow from tensorflow.contrib import copy_graph from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn @@ -60,6 +61,7 @@ from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import mixed_precision from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn @@ -69,7 +71,6 @@ from tensorflow.contrib import proto from tensorflow.contrib import quantization from tensorflow.contrib import quantize -from tensorflow.contrib import recurrent from tensorflow.contrib import reduce_slice_ops from tensorflow.contrib import resampler from tensorflow.contrib import rnn @@ -96,6 +97,7 @@ from tensorflow.contrib.lite.python import lite from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field +from tensorflow.contrib.recurrent.python import recurrent_api as recurrent from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from 
tensorflow.contrib.specs import python as specs from tensorflow.contrib.summary import summary diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 62d1b1cf079d04..881808a98bfd68 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -11,6 +11,16 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_py_test") +py_library( + name = "all_reduce_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":all_reduce", + "//tensorflow/python:util", + ], +) + py_library( name = "all_reduce", srcs = [ diff --git a/tensorflow/contrib/all_reduce/__init__.py b/tensorflow/contrib/all_reduce/__init__.py new file mode 100644 index 00000000000000..f9824f4cfbf83d --- /dev/null +++ b/tensorflow/contrib/all_reduce/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""All-reduce implementations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.all_reduce.python.all_reduce import * + +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'build_ring_all_reduce', + 'build_recursive_hd_all_reduce', + 'build_shuffle_all_reduce', + 'build_nccl_all_reduce', + 'build_nccl_then_ring', + 'build_nccl_then_recursive_hd', + 'build_nccl_then_shuffle', + 'build_shuffle_then_ring', + 'build_shuffle_then_shuffle' +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index 60306ebdc6cddb..c10179ba8b290b 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -72,7 +72,7 @@ cc_binary( "-s", "-Wl,--gc-sections", "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. - LINKER_SCRIPT, + "$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc index 707853b59befc2..30de7b59af79cb 100644 --- a/tensorflow/contrib/android/jni/run_stats_jni.cc +++ b/tensorflow/contrib/android/jni/run_stats_jni.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/contrib/android/jni/run_stats_jni.h" #include + #include #include "tensorflow/core/protobuf/config.pb.h" @@ -73,7 +74,8 @@ JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz, StatSummarizer* s = requireHandle(env, handle); if (s == nullptr) return nullptr; std::stringstream ret; - ret << s->GetStatsByMetric("Top 10 CPU", StatSummarizer::BY_TIME, 10) + ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME, + 10) << s->GetStatsByNodeType() << s->ShortSummary(); return env->NewStringUTF(ret.str().c_str()); } diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md new file mode 100644 index 00000000000000..a4aec8c74a9ad1 --- /dev/null +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -0,0 +1,48 @@ +# How to Contribute + +We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. + +## TensorFlow Code of Conduct +Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md). + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution; +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult [GitHub +Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. + +After a pull request is approved, we merge it. Note our merging process differs +from GitHub in that we pull and submit the change into an internal version +control system. This system automatically pushes a git commit to the GitHub +repository (with credit to the original author) and closes the pull request. + +## Style + +See the [AutoGraph style guide](STYLE_GUIDE.md). + +## Unit tests + +Please include unit tests when contributing new features ([example here](converters/continue_statements_test.py)), as they help to a) prove that your code works correctly, and b) guard against future breaking +changes to lower the maintenance cost. +It's also helpful to check that any +changes you propose do not break existing unit tests. You can run tests using the command, + +```shell +bazel test --config=opt --copt=-O3 --copt=-march=native \ + //tensorflow/contrib/autograph/... +``` + +from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md) diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 0fcbf5dd59cece..674859bed4ec15 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,6 +1,6 @@ # AutoGraph -IMPORTANT: AutoGraph is pre-alpha, under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! +IMPORTANT: AutoGraph is alpha software, and under active development. 
Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). AutoGraph is a Python to TensorFlow compiler. @@ -56,8 +56,6 @@ Use AutoGraph in one of the following ways, described below: 1. Annotations (simpler) 2. Functional API (more flexible) -NOTE: You can find more examples in this [interactive notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb). - To get started, install the latest nightly TensorFlow build: ```shell @@ -70,6 +68,13 @@ Then import the `autograph` module from `tf.contrib`: from tensorflow.contrib import autograph as ag ``` +### Interactive demo notebooks + +For more extensive examples, check out these interactive notebooks: + + * [RNN trained using Keras and Estimators](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb) + * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb) + ## Using with annotations Annotating a function or class with `@convert` converts it in place: diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md new file mode 100644 index 00000000000000..866e5f583a3457 --- /dev/null +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -0,0 +1,75 @@ +# AutoGraph Style Guide + +This page contains style decisions that developers should follow when +contributing code to AutoGraph. + +## TensorFlow Style + +Follow the [TensorFlow style +guide](https://www.tensorflow.org/community/style_guide), the [documentation
guide](https://www.tensorflow.org/community/documentation) and the +[Google Python style guide](https://google.github.io/styleguide/pyguide.html). + +Naming conventions: + +1. The name is TensorFlow, not Tensorflow. +2. The name is AutoGraph, not Autograph. + +## AutoGraph Style + +Below are AutoGraph-specific conventions. In the event of conflict, +they supersede all previous conventions. + +1. __Citations in Docstrings.__ Write a `#### References` subsection at the + bottom of any docstring with citations. Use ICLR’s bibliography style to + write references; for example, order entries by the first author's last + name. Add a link to the paper if the publication is open source (ideally, + arXiv). + + Write in-paragraph citations in general, e.g., [(Tran and Blei, 2018)][1]. + Write in-text citations when the citation is a noun, e.g., [Tran and Blei + (2018)][1]. Write citations with more than two authors using et al., e.g., + [(Tran et al., 2018)][1]. Separate multiple citations with semicolons, e.g., + ([Tran and Blei, 2018][1]; [Gelman and Rubin, 1992][2]). + + Examples: + + ```none + #### References + + # technical report + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + + # journal + [2]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation + Using Multiple Sequences. _Statistical Science_, 7(4):457-472, 1992. + + # arXiv preprint + # use "et al." for papers with too many authors to maintain + [3]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech + Synthesis.
_arXiv preprint arXiv:1711.10433_, 2017. + https://arxiv.org/abs/1711.10433 + + # conference + [4]: Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, and Roger Grosse. + Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches. In _International Conference on Learning + Representations_, 2018. + https://arxiv.org/abs/1803.04386 + ``` + +2. Avoid LaTeX in docstrings. + + * It is not rendered in many (if not most) editors and can be hard to read + for both LaTeX experts and non-experts. + +3. Write docstring and comment math using ASCII friendly notation; python using + operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, + `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: + x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`. + + * The more we stick to python style, the more someone can + copy/paste/execute. + * Python style is usually easier to read as ASCII. diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 3386c4eca4b93e..79d73af98097ae 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -23,18 +23,32 @@ # TODO(mdan): Bring only the relevant symbols to the top level. from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph.impl.api import convert from tensorflow.contrib.autograph.impl.api import converted_call from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.impl.special_functions import stack from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'utils', 'convert', 'converted_call', 'do_not_convert', 'RunMode', - 'to_code', 'to_graph', 'AutographParseError' + # Main API + 'RunMode', + 'convert', + 'converted_call', + 'do_not_convert', + 'to_code', + 'to_graph', + # Special functions and overloaded operators + 'operators', + 'stack', + # Exceptions + 'AutographParseError', + # Utilities: to be removed + 'utils', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py index 2d9e2c58e3afce..3b0db677ce5e41 100644 --- a/tensorflow/contrib/autograph/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -33,7 +33,7 @@ def visit_Assert(self, node): # Note: The lone tf.Assert call will be wrapped with control_dependencies # by side_effect_guards. 
template = """ - tf.Assert(test, [msg]) + tf.Assert(test, (msg,)) """ if node.msg is None: diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 91de82f0a78cca..775d92c1d9f8bc 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -32,14 +32,6 @@ class BreakStatementTransformer(transformer.Base): """Canonicalizes break statements into additional conditionals.""" - def _track_body(self, nodes, break_var): - self.enter_local_scope() - self.set_local(CONTROL_VAR_NAME, break_var) - nodes = self.visit_block(nodes) - break_used = self.get_local(BREAK_USED, False) - self.exit_local_scope() - return nodes, break_used - def visit_Break(self, node): self.set_local(BREAK_USED, True) var_name = self.get_local(CONTROL_VAR_NAME) @@ -54,6 +46,7 @@ def _guard_if_present(self, block, var_name): """Prevents the block from executing if var_name is set.""" if not block: return block + template = """ if not var_name: block @@ -64,9 +57,17 @@ def _guard_if_present(self, block, var_name): block=block) return node + def _track_body(self, nodes, break_var): + self.enter_local_scope() + self.set_local(CONTROL_VAR_NAME, break_var) + nodes = self.visit_block(nodes) + break_used = self.get_local(BREAK_USED, False) + self.exit_local_scope() + return nodes, break_used + def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.context.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) node.body, break_used = self._track_body(node.body, break_var) @@ -74,6 +75,10 @@ def visit_While(self, node): node.orelse = self.visit_block(node.orelse) if break_used: + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + template = """ var_name = False while test and not var_name: @@ -81,20 +86,18 @@ def visit_While(self, node): else: orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, test=node.test, body=node.body, - orelse=self._guard_if_present(node.orelse, break_var)) + orelse=guarded_orelse) return node def visit_For(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.context.namer.new_symbol('break_', scope.referenced) node.target = self.visit(node.target) node.iter = self.visit(node.iter) @@ -103,20 +106,33 @@ def visit_For(self, node): node.orelse = self.visit_block(node.orelse) if break_used: - node.orelse = self._guard_if_present(node.orelse, break_var) + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + extra_test = templates.replace_as_expression( + 'not var_name', var_name=break_var) + + # The extra test is hidden in the AST, which will confuse the static + # analysis. To mitigate that, we insert a no-op statement that ensures + # the control variable is marked as used. + # TODO(mdan): Use a marker instead, e.g. 
ag__.condition_loop_on(var_name) template = """ var_name = False - for_stmt + for target in iter_: + (var_name,) + body + else: + orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, - for_stmt=node) - extra_cond = templates.replace_as_expression( - 'not var_name', var_name=break_var) - anno.setanno(node[1], 'extra_cond', extra_cond) + iter_=node.iter, + target=node.target, + body=node.body, + orelse=guarded_orelse) + + anno.setanno(node[1], 'extra_test', extra_test) return node diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 317711a866f731..231e4ee35a72f5 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -31,9 +31,6 @@ class BuiltinFunctionTransformer(transformer.Base): TF equivalent, like `len`. """ - def __init__(self, context): - super(BuiltinFunctionTransformer, self).__init__(context) - def _convert_builtin(self, node): template = """ ag__.utils.dynamic_builtin(func, args) @@ -51,7 +48,7 @@ def visit_Call(self, node): # TODO(mdan): This won't work if the function was hidden. # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. if (isinstance(node.func, gast.Name) and - node.func.id in ('len', 'range', 'xrange')): + node.func.id in ('len', 'range', 'xrange', 'float', 'int')): return self._convert_builtin(node) # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index 554f0471d44d54..b6ecdcb7809b1a 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -292,15 +292,25 @@ def visit_Call(self, node): raise NotImplementedError( 'py_func with return values (unknown function)') else: + if anno.hasanno(node.func, anno.Basic.QN): + # Special-case a few builtins that otherwise go undetected. This + # normally doesn't pose a problem, but the dict built-in doesn't + # work with inspect.getargspec which is required for dynamic functions. + # Note: expecting this is resilient to aliasing (e.g. + # dict = an_evil_dict), because in those cases the regular mechanisms + # process a simple user function. + qn = anno.getanno(node.func, anno.Basic.QN) + # Add items to this list as needed. + if str(qn) in ('dict',): + return node + if ast_util.matches(node, 'super(_)'): # super() calls are preserved. The class conversion mechanism will # ensure that they return the correct value. - pass - elif self.context.recursive: + return node + + if self.context.recursive: node = self._insert_dynamic_conversion(node) - else: - # Unresolved functions are allowed in non-recursive mode. 
- pass return node diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 4299a8a9d59715..0417817a77e706 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -24,103 +24,115 @@ from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -class ContinueCanonicalizationTransformer(transformer.Base): - """Canonicalizes continue statements into additional conditionals.""" +# Tags for local state. +CONTROL_VAR_NAME = 'control_var_name' +CONTINUE_USED = 'continue_used' +GUARD_CREATED = 'guard_created' +CREATE_GUARD_NEXT = 'create_guard_next' - def __init__(self, context): - super(ContinueCanonicalizationTransformer, self).__init__(context) - # This is a stack structure, to correctly process nested loops. - self.continuation_uses = [] - def _create_continuation_check(self): - template = """ - if not var_name: - pass - """ - cond, = templates.replace(template, var_name=self.continuation_uses[-1][1]) - cond.body = [] - return cond +class ContinueCanonicalizationTransformer(transformer.Base): + """Canonicalizes continue statements into additional conditionals.""" - def _create_continuation_trigger(self): + def visit_Continue(self, node): + self.set_local(CONTINUE_USED, True) template = """ var_name = True """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _create_continuation_init(self): - template = """ - var_name = False - """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _visit_and_reindent_if_necessary(self, nodes): - reorganized_nodes = [] - current_dest = reorganized_nodes - continue_used_in_block = False - for i, n in enumerate(nodes): - # TODO(mdan): This could be optimized if control structures are simple. - self.continuation_uses[-1][0] = False - n = self.visit(n) - current_dest.append(n) - if self.continuation_uses[-1][0]: - continue_used_in_block = True - if i < len(nodes) - 1: # Last statement in block needs no protection. 
- cond = self._create_continuation_check() - current_dest.append(cond) - current_dest = cond.body - self.continuation_uses[-1][0] = continue_used_in_block - return reorganized_nodes - - def _process_loop_block(self, block, scope): - cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced) - self.continuation_uses.append([False, cont_var]) - block = self._visit_and_reindent_if_necessary(block) - if self.continuation_uses[-1][0]: - block.insert(0, self._create_continuation_init()) - self.continuation_uses.pop() - return block + return templates.replace( + template, var_name=self.get_local(CONTROL_VAR_NAME)) + + def _postprocess_statement(self, node): + # Example of how the state machine below works: + # + # 1| stmt # State: CONTINUE_USED = False + # | # Action: none + # 2| if cond: + # 3| continue # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = False + # | # Action: set CREATE_GUARD_NEXT = True + # 4| stmt # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = True + # | # Action: create `if not continue_used`, + # | # set GUARD_CREATED = True + # 5| stmt # State: CONTINUE_USED = True, GUARD_CREATED = True + # | # Action: none (will be wrapped under previously + # | # created if node) + + if self.get_local(CONTINUE_USED, False): + if self.get_local(GUARD_CREATED, False): + return node, None + + elif not self.get_local(CREATE_GUARD_NEXT, False): + self.set_local(CREATE_GUARD_NEXT, True) + return node, None + + else: + self.set_local(GUARD_CREATED, True) + template = """ + if not var_name: + original_node + """ + cond, = templates.replace( + template, + var_name=self.get_local(CONTROL_VAR_NAME), + original_node=node) + return cond, cond.body + return node, None + + def _visit_loop_body(self, node, nodes): + self.enter_local_scope() + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + continue_var = self.context.namer.new_symbol('continue_', scope.referenced) + self.set_local(CONTROL_VAR_NAME, continue_var) + + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + + if self.get_local(CONTINUE_USED, False): + template = """ + var_name = False + """ + control_var_init = templates.replace(template, var_name=continue_var) + nodes = control_var_init + nodes + + self.exit_local_scope() + return nodes + + def _visit_non_loop_body(self, nodes): + self.enter_local_scope(inherit=(CONTROL_VAR_NAME,)) + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + continue_used = self.get_local(CONTINUE_USED, False) + self.exit_local_scope(keep=(CONTINUE_USED,)) + return nodes, continue_used def visit_While(self, node): - self.generic_visit(node.test) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.test = self.visit(node.test) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_For(self, node): - self.generic_visit(node.target) - self.generic_visit(node.iter) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.target = self.generic_visit(node.target) + node.iter = self.generic_visit(node.iter) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. 
+ node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_If(self, node): - if self.continuation_uses: - self.generic_visit(node.test) - node.body = self._visit_and_reindent_if_necessary(node.body) - continue_used_in_body = self.continuation_uses[-1][0] - node.orelse = self._visit_and_reindent_if_necessary(node.orelse) - self.continuation_uses[-1][0] = ( - continue_used_in_body or self.continuation_uses[-1][0]) - else: - node = self.generic_visit(node) + node.test = self.generic_visit(node.test) + node.body, continue_used_body = self._visit_non_loop_body(node.body) + node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse) + self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse) return node - def visit_Continue(self, node): - self.continuation_uses[-1][0] = True - return self._create_continuation_trigger() - - def visit_Break(self, node): - assert False, 'break statement should be desugared at this point' + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body, _ = self._visit_non_loop_body(node.body) + return node def transform(node, namer): diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 2e26cdb3d9387d..d7ddbe8a04f648 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Handles control flow statements: while, if.""" +"""Handles control flow statements: while, for, if.""" from __future__ import absolute_import from __future__ import division @@ -25,6 +25,7 @@ from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -47,9 +48,6 @@ def new_symbol(self, name_root, reserved_locals): class ControlFlowTransformer(transformer.Base): """Transforms control flow structures like loops an conditionals.""" - def __init__(self, context): - super(ControlFlowTransformer, self).__init__(context) - def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -98,30 +96,63 @@ def visit_If(self, node): body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) - - if body_scope.created - orelse_scope.created: - raise ValueError( - 'The if branch creates new symbols that the else branch does not.') - if orelse_scope.created - body_scope.created: - raise ValueError( - 'The else branch creates new symbols that the if branch does not.') - - modified = tuple(body_scope.modified | orelse_scope.modified) - all_referenced = body_scope.referenced | orelse_scope.referenced + body_defs = body_scope.created | body_scope.modified + orelse_defs = orelse_scope.created | orelse_scope.modified + live = anno.getanno(node, 'live_out') + + # We'll need to check if we're closing over variables that are defined + # elsewhere in the function + # NOTE: we can only detect syntactic closure in the scope + # of the code passed in. 
If the AutoGraph'd function itself closes + # over other variables, this analysis won't take that into account. + defined = anno.getanno(node, 'defined_in') + + # We only need to return variables that are + # - modified by one or both branches + # - live (or has a live parent) at the end of the conditional + modified = [] + for def_ in body_defs | orelse_defs: + def_with_parents = set((def_,)) | def_.support_set + if live & def_with_parents: + modified.append(def_) + + # We need to check if live created variables are balanced + # in both branches + created = live & (body_scope.created | orelse_scope.created) + + # The if statement is illegal if there are variables that are created, + # that are also live, but both branches don't create them. + if created: + if created != (body_scope.created & live): + raise ValueError( + 'The main branch does not create all live symbols that the else ' + 'branch does.') + if created != (orelse_scope.created & live): + raise ValueError( + 'The else branch does not create all live symbols that the main ' + 'branch does.') # Alias the closure variables inside the conditional functions # to avoid errors caused by the local variables created in the branch # functions. - need_alias = ( - (body_scope.modified | orelse_scope.modified) - - (body_scope.created | orelse_scope.created)) - aliased_orig_names = tuple(need_alias) - aliased_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), all_referenced) - for s in aliased_orig_names) - alias_map = dict(zip(aliased_orig_names, aliased_new_names)) - node_body = ast_util.rename_symbols(node.body, alias_map) - node_orelse = ast_util.rename_symbols(node.orelse, alias_map) + # We will alias variables independently for body and orelse scope, + # because different branches might write different variables. 
+ aliased_body_orig_names = tuple(body_scope.modified - body_scope.created) + aliased_orelse_orig_names = tuple(orelse_scope.modified - + orelse_scope.created) + aliased_body_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + for s in aliased_body_orig_names) + aliased_orelse_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + for s in aliased_orelse_orig_names) + + alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) + alias_orelse_map = dict( + zip(aliased_orelse_orig_names, aliased_orelse_new_names)) + + node_body = ast_util.rename_symbols(node.body, alias_body_map) + node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) if not modified: # When the cond would return no value, we leave the cond called without @@ -134,26 +165,47 @@ def visit_If(self, node): else: results = gast.Tuple([s.ast() for s in modified], None) - body_name = self.context.namer.new_symbol('if_true', all_referenced) - orelse_name = self.context.namer.new_symbol('if_false', all_referenced) + body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.context.namer.new_symbol('if_false', + orelse_scope.referenced) if modified: - body_returns = tuple( - alias_map[s] if s in aliased_orig_names else s for s in modified) + + def build_returns(aliased_names, alias_map, scope): + """Builds list of return variables for a branch of a conditional.""" + returns = [] + for s in modified: + if s in aliased_names: + returns.append(alias_map[s]) + else: + if s not in scope.created | defined: + raise ValueError( + 'Attempting to return variable "%s" from the true branch of ' + 'a conditional, but it was not closed over, or created in ' + 'this branch.' 
% str(s)) + else: + returns.append(s) + return tuple(returns) + + body_returns = build_returns(aliased_body_orig_names, alias_body_map, + body_scope) + orelse_returns = build_returns(aliased_orelse_orig_names, + alias_orelse_map, orelse_scope) + else: - body_returns = templates.replace('tf.ones(())')[0].value + body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value body_def = self._create_cond_branch( body_name, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), + aliased_orig_names=tuple(aliased_body_orig_names), + aliased_new_names=tuple(aliased_body_new_names), body=node_body, returns=body_returns) orelse_def = self._create_cond_branch( orelse_name, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), + aliased_orig_names=tuple(aliased_orelse_orig_names), + aliased_new_names=tuple(aliased_orelse_new_names), body=node_orelse, - returns=body_returns) + returns=orelse_returns) cond_expr = self._create_cond_expr(results, node.test, body_name, orelse_name) @@ -207,7 +259,7 @@ def test_name(state_ssf): def body_name(state_ssf): body return state_ssf, - state_ast_tuple = ag__.while_loop( + state_ast_tuple = ag__.while_stmt( test_name, body_name, (state,), (extra_deps,)) """ node = templates.replace( @@ -252,31 +304,31 @@ def visit_For(self, node): state_ast_tuple = gast.Tuple([n.ast() for n in state], None) node_body = ast_util.rename_symbols(node.body, ssf_map) - if anno.hasanno(node, 'extra_cond'): - extra_cond = anno.getanno(node, 'extra_cond') - extra_cond = ast_util.rename_symbols(extra_cond, ssf_map) + if anno.hasanno(node, 'extra_test'): + extra_test = anno.getanno(node, 'extra_test') + extra_test = ast_util.rename_symbols(extra_test, ssf_map) else: - extra_cond = parser.parse_expression('True') + extra_test = parser.parse_expression('True') template = """ - def extra_cond_name(state_ssf): - return extra_cond_expr + def extra_test_name(state_ssf): + return extra_test_expr def body_name(iterate, state_ssf): body return state_ssf, - state_ast_tuple = ag__.for_loop( - iterated, extra_cond_name, body_name, (state,)) + state_ast_tuple = ag__.for_stmt( + iter_, extra_test_name, body_name, (state,)) """ node = templates.replace( template, state=state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - iterated=node.iter, + iter_=node.iter, iterate=node.target, - extra_cond_name=self.context.namer.new_symbol('extra_cond', + extra_test_name=self.context.namer.new_symbol('extra_test', all_referenced), - extra_cond_expr=extra_cond, + extra_test_expr=extra_test, body_name=self.context.namer.new_symbol('loop_body', all_referenced), body=node_body) @@ -284,6 +336,7 @@ def body_name(iterate, state_ssf): def transform(node, context): - t = ControlFlowTransformer(context) - node = t.visit(node) + cfg.run_analyses(node, cfg.Liveness(context)) + cfg.run_analyses(node, cfg.Defined(context)) + node = ControlFlowTransformer(context).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index c5610b16b4e5de..9d23d9b5b7e8e8 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -22,6 +22,7 @@ from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from 
tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test @@ -41,7 +42,7 @@ def test_fn(n): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((10, 5, 5), sess.run(result.test_fn(constant_op.constant(5)))) @@ -56,7 +57,7 @@ def test_fn(n): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) @@ -74,7 +75,7 @@ def test_fn(n): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((-1, 0), sess.run(result.test_fn(constant_op.constant(1)))) @@ -91,10 +92,95 @@ def test_fn(n): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) + def test_imbalanced_aliasing(self): + + def test_fn(n): + if n > 0: + n = 3 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_ignore_unread_variable(self): + + def test_fn(n): + b = 3 # pylint: disable=unused-variable + if n > 0: + b = 4 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_handle_temp_variable(self): + + def test_fn_using_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z, w + + node = self.parse_and_analyze(test_fn_using_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + self.assertEqual(3, w) + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + self.assertEqual(2, w) + + def test_fn_ignoring_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z + + node = self.parse_and_analyze(test_fn_ignoring_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(-3), constant_op.constant(3), + 
constant_op.constant(3))) + self.assertEqual(0, z) + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + def test_simple_for(self): def test_fn(l): diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 54424e26472b84..91ae0b9b82c6f6 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -21,6 +21,7 @@ py_library( "config.py", "conversion.py", "naming.py", + "special_functions.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -69,3 +70,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/impl/config.py index 2600088595a127..878bb7e12f2b39 100644 --- a/tensorflow/contrib/autograph/impl/config.py +++ b/tensorflow/contrib/autograph/impl/config.py @@ -33,7 +33,7 @@ (utils.__name__,), # All of tensorflow's subpackages. Unlike the root tf module, they don't - # have well-known names. Not refering to the module directly to avoid + # have well-known names. Not referring to the module directly to avoid # circular imports. ( utils.__name__[:-len('.contrib.autograph.utils')],), diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index 5edd8e74a8899a..bc61498b5422f5 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -24,7 +24,7 @@ from tensorflow.contrib.autograph.impl import api from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras.engine import training from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/impl/special_functions.py b/tensorflow/contrib/autograph/impl/special_functions.py new file mode 100644 index 00000000000000..b7a8177c44c882 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Special functions that only make sense for AutoGraph. + +These functions are meant to ensure feature parity between Python and AutoGraph, +so that the exact same code works in both modes. In general, AutoGraph will +replace these calls. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures + + +def stack(list_or_tensor, element_dtype=None): + """Stacks the input, if it admits the notion of stacking. No-op otherwise. + + For example, a list of tensors can be stacked into a larger tensor. This + function is similar to tf.stack, but it accepts non-lists and lists of + non-tensors as arguments. In the latter case, the function does nothing. + + Args: + list_or_tensor: Any entity. + element_dtype: Optional dtype for the elements in the list. Required if the + input is stackable, and the list is untyped. + + Returns: + If the input is stackable, a new object representing the stacked inputs. + Otherwise it returns list_or_tensor unchanged. + """ + return data_structures.list_stack( + list_or_tensor, + data_structures.ListStackOpts( + element_dtype=element_dtype, original_call=lambda x: x)) diff --git a/tensorflow/contrib/autograph/impl/special_functions_test.py b/tensorflow/contrib/autograph/impl/special_functions_test.py new file mode 100644 index 00000000000000..9b52d2a59b5a3e --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for special_functions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.impl import special_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SpecialFunctionsTest(test.TestCase): + + def test_basic(self): + self.assertEqual(special_functions.stack(1), 1) + self.assertListEqual(special_functions.stack([1, 2, 3]), [1, 2, 3]) + # TODO(mdan): This should probably forward to tf.stack. 
+ self.assertTrue( + isinstance( + special_functions.stack( + [constant_op.constant(1), + constant_op.constant(2)]), list)) + + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor( + t, element_shape=constant_op.constant([], dtype=dtypes.int32)) + self.assertTrue( + tensor_util.is_tensor( + special_functions.stack(l, element_dtype=dtypes.float32))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index efb8d441dd839b..0c6ab65505ee03 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -22,6 +22,7 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -51,3 +52,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 04b4734551d322..c900fd6af2ea5d 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -19,11 +19,32 @@ closures for the body. """ +# Naming conventions: +# * operator names match the name usually used for the respective Python +# idiom; examples: for_stmt, list_append +# * operator arguments match either of: +# - the corresponding Python AST attribute (e.g. the condition of an if +# statement is called test) if the operator represents an AST construct +# - the names used in the Python docs, if the operator is a function (e.g. +# list_ and x for append, see +# https://docs.python.org/3.7/tutorial/datastructures.html) +# +# All operators may accept a final argument named "opts", of a type that +# subclasses namedtuple and contains any arguments that are only required +# for some specializations of the operator. + from __future__ import absolute_import from __future__ import division from __future__ import print_function -# TODO(mdan): Add a container for implementation-specific toggles (throughout). 
- -from tensorflow.contrib.autograph.operators.control_flow import for_loop -from tensorflow.contrib.autograph.operators.control_flow import while_loop +from tensorflow.contrib.autograph.operators.control_flow import for_stmt +from tensorflow.contrib.autograph.operators.control_flow import while_stmt +from tensorflow.contrib.autograph.operators.data_structures import list_append +from tensorflow.contrib.autograph.operators.data_structures import list_pop +from tensorflow.contrib.autograph.operators.data_structures import list_stack +from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts +from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts +from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.slices import get_item +from tensorflow.contrib.autograph.operators.slices import GetItemOpts +from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index d9d8b0d593e537..671c9ccc13eaa8 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -25,44 +25,55 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_math_ops -# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature. -# TODO(mdan): Rename arguments to match the AST names. - -def for_loop(iterated, extra_cond, loop_body, init_state): +def for_stmt(iter_, extra_test, body, init_state): """Functional form of a for statement. - The loop operates on a so-called state, which includes all symbols that are - variant across loop iterations, excluding the iterate. In what follows we - refer to state as either a tuple of entities that represent an actual state, - or a list of arguments of the corresponding types. + The loop operates on a state, which includes all symbols that are + variant across loop iterations, excluding the iterate as well as the + variables local to the loop. + + For example, given the loop below that calculates the geometric and + arithmetic means of some numbers: + + geo_mean = 1 + arith_mean = 0 + for i in range(n): + a = numbers[i] + geo_mean *= a + arith_mean += a + + The state is represented by the variables geo_mean and arith_mean. The + argument for init_state may contain the tuple (1, 0); the body will + include the arguments geo_mean and arith_mean and will return a tuple + representing the new values for geo_mean and arith_mean, respectively. Args: - iterated: The entity being iterated over. - extra_cond: Callable with the state as arguments, and boolean return type. + iter_: The entity being iterated over. + extra_test: Callable with the state as arguments, and boolean return type. An additional loop condition. - loop_body: Callable with the iterate and the state as arguments, and + body: Callable with the iterate and the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. Returns: Tuple containing the final state.
""" - if tensor_util.is_tensor(iterated): - return _known_len_for_loop(iterated, extra_cond, loop_body, init_state) - elif isinstance(iterated, dataset_ops.Dataset): - return _dataset_for_loop(iterated, extra_cond, loop_body, init_state) + if tensor_util.is_tensor(iter_): + return _known_len_for_stmt(iter_, extra_test, body, init_state) + elif isinstance(iter_, dataset_ops.Dataset): + return _dataset_for_stmt(iter_, extra_test, body, init_state) else: - return _py_for_loop(iterated, extra_cond, loop_body, init_state) + return _py_for_stmt(iter_, extra_test, body, init_state) -def _py_for_loop(iterated, extra_cond, loop_body, init_state): - """Overload of for_loop that executes a Python for loop.""" +def _py_for_stmt(iter_, extra_test, body, init_state): + """Overload of for_stmt that executes a Python for loop.""" state = init_state - for iterate in iterated: - if not extra_cond(*state): + for target in iter_: + if not extra_test(*state): break - state = loop_body(iterate, *state) + state = body(target, *state) # TODO(mdan): Remove this special case. if len(state) == 1: @@ -70,23 +81,23 @@ def _py_for_loop(iterated, extra_cond, loop_body, init_state): return state -def _known_len_for_loop(iterated, extra_cond, loop_body, init_state): - """Overload of for_loop that iterates over objects that define a length.""" - n = builtins.dynamic_len(iterated) +def _known_len_for_stmt(iter_, extra_test, body, init_state): + """Overload of for_stmt that iterates over objects that define a length.""" + n = builtins.dynamic_len(iter_) def while_body(iterate_index, *state): - iterate = iterated[iterate_index] - new_state = loop_body(iterate, *state) + iterate = iter_[iterate_index] + new_state = body(iterate, *state) return (iterate_index + 1,) + new_state def while_cond(iterate_index, *state): - return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state)) + return gen_math_ops.logical_and(iterate_index < n, extra_test(*state)) - results = while_loop( + results = while_stmt( while_cond, while_body, init_state=(0,) + init_state, - extra_deps=(iterated,), + extra_deps=(iter_,), opts=dict(maximum_iterations=n)) # Dropping the iteration index because it's not syntactically visible. results = results[1:] @@ -97,8 +108,8 @@ def while_cond(iterate_index, *state): return results -def _dataset_for_loop(ds, extra_cond, loop_body, init_state): - """Overload of for_loop that iterates over TF Datasets.""" +def _dataset_for_stmt(ds, extra_test, body, init_state): + """Overload of for_stmt that iterates over TF Datasets.""" # Because Datsets only expose get_next, in the style of Python iterators, # we are forced to unpack the loop as: # @@ -117,15 +128,15 @@ def tag_with(ds, tag): epoch_number, iterate = iterator.get_next() def while_body(epoch_number, iterate, *state): - new_state = loop_body(iterate, *state) + new_state = body(iterate, *state) epoch_number, iterate = iterator.get_next() return (epoch_number, iterate) + new_state def while_cond(epoch_number, iterate, *state): del iterate - return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state)) + return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state)) - results = while_loop( + results = while_stmt( while_cond, while_body, init_state=(epoch_number, iterate) + init_state, @@ -140,7 +151,7 @@ def while_cond(epoch_number, iterate, *state): return results -def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): +def while_stmt(test, body, init_state, extra_deps, opts=None): """Functional form of a while statement. 
The loop operates on a so-called state, which includes all symbols that are @@ -149,13 +160,13 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): of the corresponding types. Args: - loop_cond: Callable with the state as arguments, and boolean return type. + test: Callable with the state as arguments, and boolean return type. The loop condition. - loop_body: Callable with the state as arguments, and state as return type. + body: Callable with the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. extra_deps: Tuple containing additional entities on which the loop may - depend, such as loop invariants referenced by loop_cond. Used + depend, such as loop invariants referenced by test. Used exclusively for dispatch control. opts: Optional dict of extra loop parameters. @@ -163,27 +174,27 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): Tuple containing the final state. """ # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch. - # That could be somethins as simple as a collection of dispatch rules, with + # That could be something as simple as a collection of dispatch rules, with # some prioritization. if any(tensor_util.is_tensor(v) for v in init_state + extra_deps): - return _tf_while_loop(loop_cond, loop_body, init_state, opts) + return _tf_while_stmt(test, body, init_state, opts) else: - return _py_while_loop(loop_cond, loop_body, init_state, opts) + return _py_while_stmt(test, body, init_state, opts) -def _tf_while_loop(loop_cond, loop_body, init_state, opts): - """Overload of while_loop that stages a TF while_loop.""" +def _tf_while_stmt(test, body, init_state, opts): + """Overload of while_stmt that stages a TF while_stmt.""" if opts is None: opts = {} - return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts) + return control_flow_ops.while_loop(test, body, init_state, **opts) -def _py_while_loop(loop_cond, loop_body, init_state, opts): - """Overload of while_loop that executes a Python while loop.""" +def _py_while_stmt(test, body, init_state, opts): + """Overload of while_stmt that executes a Python while loop.""" del opts state = init_state - while loop_cond(*state): - state = loop_body(*state) + while test(*state): + state = body(*state) return state diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py index a0cd0bfa82bb05..b14d7edba38461 100644 --- a/tensorflow/contrib/autograph/operators/control_flow_test.py +++ b/tensorflow/contrib/autograph/operators/control_flow_test.py @@ -29,28 +29,28 @@ class ForLoopTest(test.TestCase): def test_tensor(self): - s = control_flow.for_loop( + s = control_flow.for_stmt( constant_op.constant([1, 2, 3, 4]), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), init_state=(0,)) with self.test_session() as sess: self.assertEqual((10,), sess.run(s)) def test_python(self): - s = control_flow.for_loop( + s = control_flow.for_stmt( range(5), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), init_state=(0,)) self.assertEqual(10, s) def test_dataset(self): to_int32 = lambda i: math_ops.cast(i, dtypes.int32) - s = control_flow.for_loop( + s = control_flow.for_stmt( dataset_ops.Dataset.range(5).map(to_int32), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + 
extra_test=lambda s: True, + body=lambda i, s: (s + i,), init_state=(0,)) with self.test_session() as sess: self.assertEqual((10,), sess.run(s)) @@ -60,9 +60,9 @@ class WhileLoopTest(test.TestCase): def test_tensor(self): n = constant_op.constant(5) - results = control_flow.while_loop( - loop_cond=lambda i, s: i < n, - loop_body=lambda i, s: (i + 1, s + i,), + results = control_flow.while_stmt( + test=lambda i, s: i < n, + body=lambda i, s: (i + 1, s + i,), init_state=(0, 0), extra_deps=(n,)) with self.test_session() as sess: @@ -70,9 +70,9 @@ def test_tensor(self): def test_python(self): n = 5 - results = control_flow.while_loop( - loop_cond=lambda i, s: i < n, - loop_body=lambda i, s: (i + 1, s + i), + results = control_flow.while_stmt( + test=lambda i, s: i < n, + body=lambda i, s: (i + 1, s + i), init_state=(0, 0), extra_deps=(n,)) self.assertEqual((5, 10), results) diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py index c862306baa9e81..06d8727b0fcc30 100644 --- a/tensorflow/contrib/autograph/operators/data_structures.py +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -18,39 +18,250 @@ from __future__ import division from __future__ import print_function +import collections + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables + + +# TODO(mdan): Once control flow supports objects, repackage as a class. + + +def new_list(iterable=None): + """The list constructor. + + Args: + iterable: Optional elements to fill the list with. + + Returns: + A list-like object. The exact return value depends on the initial elements. + """ + if iterable: + elements = tuple(iterable) + else: + elements = () + + # TODO(mdan): Extend these criteria. + if any(isinstance(el, variables.Variable) for el in elements): + return _py_list_new(elements) + return _tf_tensor_list_new(elements) -# TODO(mdan): Add support for TensorList once functional. -# TODO(mdan): Add primitives for empty list, list with elements. +def _tf_tensor_list_new(elements): + """Overload of new_list that stages a Tensor list creation.""" + elements = tuple(ops.convert_to_tensor(el) for el in elements) + all_dtypes = set(el.dtype for el in elements) + if len(all_dtypes) == 1: + element_dtype = tuple(all_dtypes)[0] + else: + # Heterogeneous lists are ok. + element_dtype = dtypes.variant + + # TODO(mdan): This may fail for elements of variable shapes. + all_shapes = set(tuple(el.shape.as_list()) for el in elements) + if len(all_shapes) == 1: + element_shape = array_ops.shape(elements[0]) + else: + # Heterogeneous lists are ok. + element_shape = constant_op.constant(-1) # unknown shape, by convention + + l = list_ops.empty_tensor_list( + element_shape=element_shape, element_dtype=element_dtype) + for el in elements: + l = list_ops.tensor_list_push_back(l, el) + return l -def append(target, element): + +def _py_list_new(elements): + """Overload of new_list that creates a Python list.""" + return list(elements) + + +def list_append(list_, x): """The list append function. - Note: it is unspecified where target will be mutated or not. 
If target is - a TensorFlow entity, it will not be typically mutated. If target is a plain - list, it will be. In general, if the target is mutated then the return value + Note: it is unspecified whether list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value should point to the original entity. Args: - target: An entity that supports append semantics. - element: The element to append. + list_: An entity that supports append semantics. + x: The element to append. Returns: - Same as target, after the append was performed. + Same as list_, after the append was performed. + + Raises: + ValueError: if list_ is not of a known list-like type. """ - if isinstance(target, tensor_array_ops.TensorArray): - return _tf_tensorarray_append(target, element) + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_append(list_, x) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_append(list_, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) else: - return _py_append(target, element) + return _py_list_append(list_, x) + + +def _tf_tensor_list_append(list_, x): + """Overload of list_append that stages a Tensor list write.""" + def empty_list_of_elements_like_x(): + tensor_x = ops.convert_to_tensor(x) + return list_ops.empty_tensor_list( + element_shape=array_ops.shape(tensor_x), + element_dtype=tensor_x.dtype) + + list_ = control_flow_ops.cond( + list_ops.tensor_list_length(list_) > 0, + lambda: list_, + empty_list_of_elements_like_x, + ) + return list_ops.tensor_list_push_back(list_, x) + + +def _tf_tensorarray_append(list_, x): + """Overload of list_append that stages a TensorArray write.""" + return list_.write(list_.size(), x) + + +def _py_list_append(list_, x): + """Overload of list_append that executes a Python list append.""" + # Revert to the original call. + list_.append(x) + return list_ + + +class ListPopOpts( + collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))): + pass + + +def list_pop(list_, i, opts): + """The list pop function. + + Note: it is unspecified whether list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value + should point to the original entity. + + Args: + list_: An entity that supports pop semantics. + i: Optional index to pop from. May be None. + opts: A ListPopOpts. + + Returns: + Tuple (x, out_list_): + out_list_: same as list_, after the removal was performed. + x: the removed element value. + + Raises: + ValueError: if list_ is not of a known list-like type or the operation is + not supported for that type.
+ """ + assert isinstance(opts, ListPopOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + raise ValueError('TensorArray does not support item removal') + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_pop(list_, i, opts) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) + else: + return _py_list_pop(list_, i) + + +def _tf_tensor_list_pop(list_, i, opts): + """Overload of list_pop that stages a Tensor list pop.""" + if i is not None: + raise NotImplementedError('tensor lists only support removing from the end') + + if opts.element_dtype is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'type; use set_element_type to annotate it') + if opts.element_shape is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'shape; use set_element_type to annotate it') + list_out, x = list_ops.tensor_list_pop_back( + list_, element_dtype=opts.element_dtype) + x.set_shape(opts.element_shape) + return list_out, x + + +def _py_list_pop(list_, i): + """Overload of list_pop that executes a Python list append.""" + if i is None: + x = list_.pop() + else: + x = list_.pop(i) + return list_, x + + +# TODO(mdan): Look into reducing duplication between all these containers. +class ListStackOpts( + collections.namedtuple('ListStackOpts', + ('element_dtype', 'original_call'))): + pass + + +def list_stack(list_, opts): + """The list stack function. + + This does not have a direct correspondent in Python. The closest idiom to + this is tf.append or np.stack. It's different from those in the sense that it + accepts a Tensor list, rather than a list of tensors. It can also accept + TensorArray. When the target is anything else, the dispatcher will rely on + ctx.original_call for fallback. + + Args: + list_: An entity that supports append semantics. + opts: A ListStackOpts object. + + Returns: + The output of the stack operation, typically a Tensor. + """ + assert isinstance(opts, ListStackOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_stack(list_) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_stack(list_, opts) + else: + # No-op for primitive Tensor arguments. + return list_ + else: + return _py_list_stack(list_, opts) + + +def _tf_tensorarray_stack(list_): + """Overload of list_stack that stages a TensorArray stack.""" + return list_.stack() -def _tf_tensorarray_append(target, element): - """Overload of append that stages a TensorArray write at the last position.""" - return target.write(target.size(), element) +def _tf_tensor_list_stack(list_, opts): + """Overload of list_stack that stages a Tensor list write.""" + if opts.element_dtype is None: + raise ValueError('cannot stack a list without knowing its element type;' + ' use set_element_type to annotate it') + return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype) -def _py_append(target, element): - """Overload of append that executes a Python list append.""" - target.append(element) - return target +def _py_list_stack(list_, opts): + """Overload of list_stack that executes a Python list append.""" + # Revert to the original call. 
+ return opts.original_call(list_) diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py index 577d28c34da39f..8bbb52d6c10b24 100644 --- a/tensorflow/contrib/autograph/operators/data_structures_test.py +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -19,25 +19,98 @@ from __future__ import print_function from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test -class AppendTest(test.TestCase): +class ListTest(test.TestCase): - def test_tf_tensorarray(self): + def test_new_list_empty(self): + l = data_structures.new_list() + # Can't evaluate an empty list. + # TODO(mdan): sess.run should allow tf.variant maybe? + self.assertTrue(isinstance(l, ops.Tensor)) + + def test_new_list_tensor(self): + l = data_structures.new_list([3, 4, 5]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + + def test_append_tensor_list(self): + l = data_structures.new_list() + x = constant_op.constant([1, 2, 3]) + l = data_structures.list_append(l, x) + + t = list_ops.tensor_list_stack(l, element_dtype=x.dtype) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [[1, 2, 3]]) + + def test_append_tensorarray(self): l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) - l1 = data_structures.append(l, 1) - l2 = data_structures.append(l1, 2) + l1 = data_structures.list_append(l, 1) + l2 = data_structures.list_append(l1, 2) with self.test_session() as sess: self.assertAllEqual(sess.run(l1.stack()), [1]) self.assertAllEqual(sess.run(l2.stack()), [1, 2]) - def test_python(self): + def test_append_python(self): l = [] - self.assertAllEqual(data_structures.append(l, 1), [1]) - self.assertAllEqual(data_structures.append(l, 2), [1, 2]) + self.assertAllEqual(data_structures.list_append(l, 1), [1]) + self.assertAllEqual(data_structures.list_append(l, 2), [1, 2]) + + def test_pop_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListPopOpts( + element_dtype=initial_list.dtype, + element_shape=(2,)) + + with self.assertRaises(NotImplementedError): + data_structures.list_pop(l, 0, opts) + + with self.test_session() as sess: + l, x = data_structures.list_pop(l, None, opts) + self.assertAllEqual(sess.run(x), [3, 4]) + + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[1, 2]]) + + def test_pop_python(self): + l = [1, 2, 3] + opts = data_structures.ListPopOpts(element_dtype=None, element_shape=()) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3)) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2)) + + def test_stack_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListStackOpts( + element_dtype=initial_list.dtype, original_call=None) + + with self.test_session() as 
sess: + t = data_structures.list_stack(l, opts) + self.assertAllEqual(sess.run(t), sess.run(initial_list)) + + def test_stack_fallback(self): + + def dummy_function(l): + # Lazy person's mock: just transform the argument in a way in which we + # can check that this function was indeed called. + return [x * 2 for x in l] + + opts = data_structures.ListStackOpts( + element_dtype=None, original_call=dummy_function) + + self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4]) if __name__ == '__main__': diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/contrib/autograph/operators/dispatch_context.py similarity index 60% rename from tensorflow/python/keras/applications/densenet/__init__.py rename to tensorflow/contrib/autograph/operators/dispatch_context.py index 6b8ea83920733a..097002465bd140 100644 --- a/tensorflow/python/keras/applications/densenet/__init__.py +++ b/tensorflow/contrib/autograph/operators/dispatch_context.py @@ -12,18 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""DenseNet Keras applications.""" +"""Structures that allow uniform control over the dispatch process.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 -from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input +import collections -del absolute_import -del division -del print_function + +# TODO(mdan): This is where macro override controls fit. + + +class DispatchContext(collections.namedtuple( + 'DispatchContext', + ('options',))): + """Allows passing additional parameters to the specific implementations. + + Attributes: + options: Optional dict of extra arguments that may be required by specific + implementations. + """ + + def option(self, name): + return self.options[name] + + +NO_CTX = DispatchContext(options={}) diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py new file mode 100644 index 00000000000000..04fbeb2f6e3923 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -0,0 +1,133 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Operators specific to slicing operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +# TODO(mdan): Support extended slices. + + +class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))): + pass + + +def get_item(target, i, opts): + """The slice read operator (i.e. __getitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports getitem semantics. + i: Index to read from. + opts: A GetItemOpts object. + + Returns: + The read element. + + Raises: + ValueError: if target is not of a supported type. + """ + assert isinstance(opts, GetItemOpts) + + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_get_item(target, i) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_get_item(target, i, opts) + else: + return _tf_tensor_get_item(target, i) + else: + return _py_get_item(target, i) + + +def _tf_tensorarray_get_item(target, i): + """Overload of get_item that stages a TensorArray read.""" + return target.read(i) + + +def _tf_tensor_list_get_item(target, i, opts): + """Overload of get_item that stages a Tensor list read.""" + if opts.element_dtype is None: + raise ValueError('cannot retrieve from a list without knowing its ' + 'element type; use set_element_type to annotate it') + x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype) + return x + + +def _tf_tensor_get_item(target, i): + """Overload of get_item that stages a Tensor (not Tensor list) read.""" + return target[i] + + +def _py_get_item(target, i): + """Overload of get_item that executes a Python list modification.""" + return target[i] + + +def set_item(target, i, x): + """The slice write operator (i.e. __setitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports setitem semantics. + i: Index to modify. + x: The new element value. + + Returns: + Same as target, after the update was performed. + + Raises: + ValueError: if target is not of a supported type. 
+ """ + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_set_item(target, i, x) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_set_item(target, i, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % target) + else: + return _py_set_item(target, i, x) + + +def _tf_tensorarray_set_item(target, i, x): + """Overload of set_item that stages a TensorArray write.""" + return target.write(i, x) + + +def _tf_tensor_list_set_item(target, i, x): + """Overload of set_item that stages a Tensor list update.""" + return list_ops.tensor_list_set_item(target, i, x) + + +def _py_set_item(target, i, x): + """Overload of set_item that executes a Python list modification.""" + target[i] = x + return target diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py new file mode 100644 index 00000000000000..d4aacb9d2015fe --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SlicesTest(test.TestCase): + + def test_set_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + l = slices.set_item(l, 0, [5, 6]) + + with self.test_session() as sess: + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]]) + + def test_get_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + t = slices.get_item( + l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype)) + + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 796ab445c74128..989b821e53a5ce 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -130,6 +130,7 @@ py_test( name = "transformer_test", srcs = ["transformer_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/pyct/anno.py 
b/tensorflow/contrib/autograph/pyct/anno.py index cc4a7edf02ed75..ae861627fd65cc 100644 --- a/tensorflow/contrib/autograph/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -46,8 +46,15 @@ class Basic(NoValue): '`name_map` allows renaming symbols.') -def getanno(node, key, field_name='___pyct_anno'): - return getattr(node, field_name)[key] +FAIL = object() + + +def getanno(node, key, default=FAIL, field_name='___pyct_anno'): + if (default is FAIL or + (hasattr(node, field_name) and (key in getattr(node, field_name)))): + return getattr(node, field_name)[key] + else: + return default def hasanno(node, key, field_name='___pyct_anno'): @@ -73,5 +80,9 @@ def delanno(node, key, field_name='___pyct_anno'): def copyanno(from_node, to_node, key, field_name='___pyct_anno'): - if hasanno(from_node, key, field_name): - setanno(to_node, key, getanno(from_node, key, field_name), field_name) + if hasanno(from_node, key, field_name=field_name): + setanno( + to_node, + key, + getanno(from_node, key, field_name=field_name), + field_name=field_name) diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py index 1d4d9d119e0c45..f2c0c8cf05ca4b 100644 --- a/tensorflow/contrib/autograph/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -38,12 +38,14 @@ def test_basic(self): anno.setanno(node, 'foo', 3) self.assertTrue(anno.hasanno(node, 'foo')) - self.assertEqual(3, anno.getanno(node, 'foo')) + self.assertEqual(anno.getanno(node, 'foo'), 3) + self.assertEqual(anno.getanno(node, 'bar', default=7), 7) anno.delanno(node, 'foo') self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') + self.assertIsNone(anno.getanno(node, 'foo', default=None)) def test_copyanno(self): node_1 = ast.Name() diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py index 583cf7ecd7bce3..da07013cf4f430 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names.py +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -205,6 +205,7 @@ def visit_Attribute(self, node): return node def visit_Subscript(self, node): + # TODO(mdan): This may no longer apply if we overload getitem. node = self.generic_visit(node) s = node.slice if not isinstance(s, gast.Index): @@ -216,7 +217,11 @@ def visit_Subscript(self, node): elif isinstance(s.value, gast.Str): subscript = QN(StringLiteral(s.value.s)) else: - subscript = anno.getanno(node.slice.value, anno.Basic.QN) + # The index may be an expression, case in which a name doesn't make sense. 
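# --- Editor's note: illustrative sketch, not part of the patch. ---
# The default-aware getanno added in anno.py above removes the need for
# hasanno/getanno pairs like the one in this hunk; mirrors anno_test.py.
import ast
from tensorflow.contrib.autograph.pyct import anno

node = ast.Name()
anno.setanno(node, 'foo', 3)
assert anno.getanno(node, 'foo') == 3
assert anno.getanno(node, 'bar', default=7) == 7        # missing key -> default
anno.delanno(node, 'foo')
assert anno.getanno(node, 'foo', default=None) is None  # annotation removed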
+ if anno.hasanno(node.slice.value, anno.Basic.QN): + subscript = anno.getanno(node.slice.value, anno.Basic.QN) + else: + return node if anno.hasanno(node.value, anno.Basic.QN): anno.setanno(node, anno.Basic.QN, QN(anno.getanno(node.value, anno.Basic.QN), diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 83f3bafc421764..8064a967cd389e 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -19,6 +19,7 @@ py_library( srcs = [ "activity.py", "annos.py", + "cfg.py", "live_values.py", "type_info.py", ], @@ -43,6 +44,19 @@ py_test( ], ) +py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":static_analysis", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "live_values_test", srcs = ["live_values_test.py"], diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py index 2c14c2c8c23810..4d7b0cbb7b8f6e 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -23,11 +23,12 @@ import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.qual_names import QN from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). +# TODO(alexbw): Ignore named literals (e.g. None) class Scope(object): @@ -43,16 +44,20 @@ class Scope(object): used: identifiers referenced in this scope """ - def __init__(self, parent, isolated=True): + def __init__(self, parent, isolated=True, add_unknown_symbols=False): """Create a new scope. Args: parent: A Scope or None. isolated: Whether the scope is isolated, that is, whether variables created in this scope should be visible to the parent scope. + add_unknown_symbols: Whether to handle attributed and subscripts + without having first seen the base name. + E.g., analyzing the statement 'x.y = z' without first having seen 'x'. """ self.isolated = isolated self.parent = parent + self.add_unknown_symbols = add_unknown_symbols self.modified = set() self.created = set() self.used = set() @@ -134,13 +139,17 @@ def mark_param(self, name): self.params.add(name) def mark_creation(self, name, writes_create_symbol=False): + """Mark a qualified name as created.""" if name.is_composite(): parent = name.parent - if self.has(parent): - if not writes_create_symbol: - return + if not writes_create_symbol: + return else: - raise ValueError('Unknown symbol "%s".' % parent) + if not self.has(parent): + if self.add_unknown_symbols: + self.mark_read(parent) + else: + raise ValueError('Unknown symbol "%s".' % parent) self.created.add(name) def mark_write(self, name): @@ -163,17 +172,25 @@ def mark_returned(self, name): class ActivityAnalyzer(transformer.Base): - """Annotates nodes with local scope information. See Scope.""" + """Annotates nodes with local scope information. - def __init__(self, context, parent_scope): + See Scope. + + The use of this class requires that qual_names.resolve() has been called on + the node. 
This class will ignore nodes have not been + annotated with their qualified names. + """ + + def __init__(self, context, parent_scope=None, add_unknown_symbols=False): super(ActivityAnalyzer, self).__init__(context) - self.scope = Scope(parent_scope) + self.scope = Scope(parent_scope, None, add_unknown_symbols) self._in_return_statement = False + self._in_aug_assign = False @property def _in_constructor(self): - innermost = self.enclosing_entities[-1] if len(self.enclosing_entities) > 1: + innermost = self.enclosing_entities[-1] parent = self.enclosing_entities[-2] return isinstance(parent, gast.ClassDef) and innermost.name == '__init__' return False @@ -184,6 +201,7 @@ def _node_sets_self_attribute(self, node): # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'. if qn.has_attr and qn.parent.qn == ('self',): return True + return False def _track_symbol(self, node, @@ -201,12 +219,14 @@ def _track_symbol(self, self.scope.mark_write(qn.parent) if writes_create_symbol: self.scope.mark_creation(qn, writes_create_symbol=True) + if self._in_aug_assign: + self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Load): self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. - # TODO(mdan): This bay be incorrect with nested functions. + # TODO(mdan): This may be incorrect with nested functions. # For nested functions, we'll have to add the notion of hiding args from # the parent scope, not writing to them. self.scope.mark_creation(qn) @@ -222,6 +242,14 @@ def _track_symbol(self, if self._in_return_statement: self.scope.mark_returned(qn) + def visit_AugAssign(self, node): + # Special rules for AugAssign. In Assign, the target is only written, + # but in AugAssig (e.g. a += b), the target is both read and written. + self._in_aug_assign = True + self.generic_visit(node) + self._in_aug_assign = False + return node + def visit_Name(self, node): self.generic_visit(node) self._track_symbol(node) @@ -295,7 +323,7 @@ def _process_parallel_blocks(self, parent, children): def visit_FunctionDef(self, node): if self.scope: - qn = QN(node.name) + qn = qual_names.QN(node.name) self.scope.mark_write(qn) current_scope = self.scope body_scope = Scope(current_scope, isolated=True) @@ -355,5 +383,32 @@ def visit_Return(self, node): return node +def get_read(node, context): + """Return the variable names as QNs (qual_names.py) read by this statement.""" + analyzer = ActivityAnalyzer(context, None, True) + analyzer.visit(node) + return analyzer.scope.used + + +def get_updated(node, context): + """Return the variable names created or mutated by this statement. + + This function considers assign statements, augmented assign statements, and + the targets of for loops, as well as function arguments. + For example, `x[0] = 2` will return `x`, `x, y = 3, 4` will return `x` and + `y`, `for i in range(x)` will return `i`, etc. + Args: + node: An AST node + context: An EntityContext instance + + Returns: + A set of variable names (QNs, see qual_names.py) of all the variables + created or mutated. 
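# --- Editor's note: illustrative sketch, not part of the patch. ---
# Intended use of get_read/get_updated, modeled on activity_test.py below.
# The EntityContext arguments copy the test helpers and are an assumption.
from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct.static_analysis import activity

def test_fn(x, y):
  z = test_fn(x, y)
  return z

node, source = parser.parse_entity(test_fn)
ctx = context.EntityContext(
    namer=None, source_code=source, source_file=None, namespace={},
    arg_values=None, arg_types={}, owner_type=None, recursive=True)
node = qual_names.resolve(node)
stmt = node.body[0].body[0]  # the assignment `z = test_fn(x, y)`
assert activity.get_read(stmt, ctx) == set(map(qual_names.QN, ('test_fn', 'x', 'y')))
assert activity.get_updated(stmt, ctx) == {qual_names.QN('z')}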
+ """ + analyzer = ActivityAnalyzer(context, None, True) + analyzer.visit(node) + return analyzer.scope.created | analyzer.scope.modified + + def resolve(node, context, parent_scope=None): return ActivityAnalyzer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index ef79a295bfa394..fdbd349af9d332 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -123,7 +123,7 @@ def _parse_and_analyze(self, test_fn): recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) - return node + return node, ctx def test_local_markers(self): @@ -133,7 +133,7 @@ def test_fn(a): # pylint:disable=unused-argument b -= 1 return b - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) self.assertFalse( anno.getanno(node.body[0].body[0].value, NodeAnno.IS_LOCAL)) # c in b = c @@ -156,6 +156,7 @@ def assertSymbolSetsAre(self, expected, actual, name): expected - actual, actual - expected)) def assertScopeIsRmc(self, scope, used, modified, created): + """Assert the scope contains specific used, modified & created variables.""" self.assertSymbolSetsAre(used, scope.used, 'read') self.assertSymbolSetsAre(modified, scope.modified, 'modified') self.assertSymbolSetsAre(created, scope.created, 'created') @@ -168,7 +169,7 @@ def test_fn(a): print(a, b) return c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) print_node = node.body[0].body[2] if isinstance(print_node, gast.Print): # Python 2 @@ -191,7 +192,7 @@ def test_fn(a): foo(a, b) # pylint:disable=undefined-variable return c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[2].value # We basically need to detect which variables are captured by the call # arguments. 
@@ -208,7 +209,7 @@ def test_fn(a): foo(a.b, a.c) return a.d - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[1].value self.assertScopeIsRmc( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), @@ -234,7 +235,7 @@ def test_fn(a): foo(a[0], a[b]) return a[c] - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[2].value self.assertScopeIsRmc( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), @@ -258,7 +259,7 @@ def test_fn(a): b -= 1 return b, c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) while_node = node.body[0].body[1] self.assertScopeIsRmc( anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), @@ -278,7 +279,7 @@ def test_fn(a): b -= 1 return b, c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) for_node = node.body[0].body[1] self.assertScopeIsRmc( anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) @@ -299,7 +300,7 @@ def test_fn(x): u = -y return z, u - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), @@ -326,7 +327,7 @@ def test_fn(a): d = 1 return d - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), @@ -358,7 +359,7 @@ def test_fn(a, b, c, e): d = 1 return d - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), @@ -390,7 +391,7 @@ def test_fn(b): a = b * b return a - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) inner_if_node = node.body[0].body[0].body[0] self.assertScopeIsRmc( anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',), @@ -413,7 +414,7 @@ def f(x): b -= f(i) return b, c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) fn_def_node = node.body[0].body[0] self.assertScopeIsRmc( @@ -434,7 +435,7 @@ def __init__(self, a): self.b = a self.b.c = 1 - node = self._parse_and_analyze(TestClass) + node, _ = self._parse_and_analyze(TestClass) init_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(init_node, NodeAnno.BODY_SCOPE), @@ -448,15 +449,118 @@ def test_aug_assign_subscripts(self): def test_fn(a): a[0] += 1 - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) fn_node = node.body[0] self.assertScopeIsRmc( anno.getanno(fn_node, NodeAnno.BODY_SCOPE), - ('a',), + ('a', 'a[0]'), ('a', 'a[0]'), ('a',), ) + def test_return_vars_are_read(self): + + def test_fn(a, b, c): # pylint: disable=unused-argument + return c + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('c',), + (), + ( + 'a', + 'b', + 'c', + ), + ) + + def test_aug_assign(self): + + def test_fn(a, b): + a += b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('a', 'b'), + ('a'), + ('a', 'b'), + ) + + def test_aug_assign_rvalues(self): + + a = dict(bar=3) + + def foo(): + return a + + def test_fn(x): + foo()['bar'] += x + + node, _ = 
self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('foo', 'x'), + (), + ('x',), + ) + + def test_params_created(self): + + def test_fn(a, b): # pylint: disable=unused-argument + return b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')), + (('a', 'b'))) + + def test_get_read(self): + + def test_fn(x, y): + z = test_fn(x, y) + return z + + node, ctx = self._parse_and_analyze(test_fn) + node = node.body[0].body[0] + read_vars = activity.get_read(node, ctx) + self.assertEqual(read_vars, set(map(qual_names.QN, ('test_fn', 'x', 'y')))) + + def test_fn2(x, y, z): + z += test_fn2(x, y, z) + return z + + node, ctx = self._parse_and_analyze(test_fn2) + node = node.body[0].body[0] + read_vars = activity.get_read(node, ctx) + self.assertEqual(read_vars, + set(map(qual_names.QN, ('test_fn2', 'x', 'y', 'z')))) + + def test_get_updated(self): + + def test_fn(x, y): + z = test_fn(x, y) + return z + + node, ctx = self._parse_and_analyze(test_fn) + node = node.body[0].body[0] + updated_vars = activity.get_updated(node, ctx) + self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) + + def test_fn2(x, y, z): + z += test_fn2(x, y, z) + return z + + node, ctx = self._parse_and_analyze(test_fn2) + node = node.body[0].body[0] + updated_vars = activity.get_updated(node, ctx) + self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py new file mode 100644 index 00000000000000..ad97fdfa8e78d1 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -0,0 +1,445 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph analysis. + +Given a Python AST we construct a control flow graph, with edges both to the +next and previous statements (so it can easily walk the graph both ways). Its +nodes contain the AST of the statements. It can then perform forward or backward +analysis on this CFG. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import functools +import operator + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct.static_analysis import activity + + +class CfgNode(object): + """A node in the CFG.""" + __slots__ = ['next', 'value', 'prev'] + + def __init__(self, value): + self.next = set() + self.prev = set() + self.value = value + + +class Cfg(namedtuple('Cfg', ['entry', 'exit'])): + """A Control Flow Graph. + + Each statement is represented as a node. 
For control flow statements such + as conditionals and loops the conditional itself is a node which either + branches or cycles, respectively. + Attributes: + entry: The entry node, which contains the `gast.arguments` node of the + function definition. + exit: The exit node. This node is special because it has no value (i.e. no + corresponding AST node). This is because Python functions can have + multiple return statements. + """ + pass + + +class CfgBuilder(gast.NodeVisitor): + """Construct a control flow graph. + + Construct a CFG starting from a FunctionDef node. + Usage: + cfg_obj = CfgBuilder().build_cfg(fndef_node) + """ + + def __init__(self): + # The current leaves of the CFG + self.current_leaves = [] + # TODO(alexbw): generalize to break, return, continue, yield, etc. + # A stack of lists, tracking continue statements + self.continue_ = [] + # A stack of lists tracking break nodes + self.break_ = [] + + def set_current_leaves(self, cfg_node): + """Link this cfg_node to the current leaves. + + This is the central function for building the CFG. It links the current + head cfg_nodes to the passed cfg_node. It then resets the head to the + passed cfg_node. + + Args: + cfg_node: A CfgNode instance. + """ + for head in self.current_leaves: + head.next.add(cfg_node) + # While we're linking the CFG forward, add backlinks + cfg_node.prev.add(head) + self.current_leaves = [cfg_node] + + def build_cfg(self, node): + """Build a CFG for a function. + + Implementation of building a CFG for dataflow analysis. See, e.g.: + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Args: + node: A function definition the body of which to analyze. + Returns: + A CFG object. + Raises: + TypeError: If the input is not a function definition. 
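# --- Editor's note: illustrative sketch, not part of the patch. ---
# Building a CFG for a small function, as described in the docstring above.
import gast
from tensorflow.contrib.autograph.pyct.static_analysis import cfg

source = (
    'def f(x):\n'
    '  if x > 0:\n'
    '    x = x - 1\n'
    '  return x\n')
fn_node = gast.parse(source).body[0]       # the gast.FunctionDef node
graph = cfg.CfgBuilder().build_cfg(fn_node)
entry = graph.entry                        # wraps the gast.arguments node
# graph.exit has value None; every path through the body links into it.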
+ """ + if not isinstance(node, gast.FunctionDef): + raise TypeError('input must be a function definition') + entry_cfg_node = CfgNode(node.args) + self.current_leaves = [entry_cfg_node] + self.visit_statements(node.body) + exit_cfg_node = CfgNode(None) + self.set_current_leaves(exit_cfg_node) + return Cfg(entry_cfg_node, exit_cfg_node) + + def visit_statements(self, nodes): + for node in nodes: + # Check for control flow + if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break, + gast.Continue, gast.With)): + self.visit(node) + else: + expr = CfgNode(node) + self.set_current_leaves(expr) + + def generic_visit(self, node): + raise ValueError('unknown control flow') + + def visit_If(self, node): + # TODO(alexbw): change this to use immutable tuples instead of lists + # The current head will hold the conditional + test = CfgNode(node.test) + self.set_current_leaves(test) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves = [test] + # Handle the orelse + self.visit_statements(node.orelse) + self.current_leaves.extend(body_exit) + + def visit_While(self, node): + test = CfgNode(node.test) + self.set_current_leaves(test) + # Start a new level of nesting + self.break_.append([]) + self.continue_.append([]) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(test) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) + + def visit_For(self, node): + iter_ = CfgNode(node.iter) + self.set_current_leaves(iter_) + self.break_.append([]) + self.continue_.append([]) + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(iter_) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) + + def visit_Break(self, node): + self.break_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Continue(self, node): + self.continue_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Try(self, node): + self.visit_statements(node.body) + body = self.current_leaves + handlers = [] + for handler in node.handlers: + self.current_leaves = body[:] + self.visit_statements(handler.body) + handlers.extend(self.current_leaves) + self.current_leaves = body + self.visit_statements(node.orelse) + self.current_leaves = handlers + self.current_leaves + self.visit_statements(node.finalbody) + + def visit_With(self, node): + for item in node.items: + self.set_current_leaves(CfgNode(item)) + self.visit_statements(node.body) + + +# TODO(alexbw): once CFG analysis occurs at a block level, +# this extra class will not be necessary +class PropagateAnalysis(gast.NodeVisitor): + """Port analysis annotations from statements to their enclosing blocks.""" + + def __init__(self, analysis): + self.transfer_fn = analysis.transfer_fn + self.in_label = analysis.in_label + self.out_label = analysis.out_label + super(PropagateAnalysis, self).__init__() + + def visit_If(self, node): + # Depth-first. 
+ self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + outgoing |= anno.getanno(node.test, self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_For(self, node): + self.generic_visit(node) + incoming = set(anno.getanno(node.body[0], self.in_label)) + incoming -= set((anno.getanno(node.target, anno.Basic.QN),)) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, frozenset(incoming)) + anno.setanno(node, self.out_label, outgoing) + + def visit_While(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_With(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + for item in node.items: + incoming |= anno.getanno(item, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + +# TODO(alexbw): Abstract the CFG walking machinery into a superclass +# which is parameterized on which fields it selects when walking. +# TODO(alexbw): Abstract the application of dataflow analysis +class Forward(object): + """Forward analysis on CFG. + + Args: + label: A name for this analysis e.g. 'active' for activity analysis. The AST + nodes in the CFG will be given annotations 'name_in', 'name_out', + 'name_gen' and 'name_kill' which contain the incoming values, outgoing + values, values generated by the statement, and values deleted by the + statement respectively. + transfer_fn: Either the AND or OR operator. If the AND operator is used it + turns into forward must analysis (i.e. a value will only be carried + forward if it appears on all incoming paths). The OR operator means that + forward may analysis is done (i.e. the union of incoming values will be + taken). + """ + + def __init__(self, label, context, transfer_fn=operator.or_): + self.transfer_fn = transfer_fn + self.context = context + self.out_label = label + '_out' + self.in_label = label + '_in' + self.gen_label = label + '_gen' + self.kill_label = label + '_kill' + + # TODO(alexbw): see if we can simplify by visiting breadth-first + def visit(self, node): + """Depth-first walking the CFG, applying dataflow information propagtion.""" + # node.value is None only for the exit CfgNode. 
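# --- Editor's note: illustrative sketch, not part of the patch. ---
# What the transfer_fn choice described above means when merging the facts
# arriving from multiple predecessor branches.
import functools
import operator

branch_facts = [frozenset({'x', 'y'}), frozenset({'x', 'z'})]
may = functools.reduce(operator.or_, branch_facts)    # {'x', 'y', 'z'}: defined on some path
must = functools.reduce(operator.and_, branch_facts)  # {'x'}: defined on every path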
+ if not node.value: + return + + if anno.hasanno(node.value, self.out_label): + before = hash(anno.getanno(node.value, self.out_label)) + else: + before = None + preds = [ + anno.getanno(pred.value, self.out_label) + for pred in node.prev + if anno.hasanno(pred.value, self.out_label) + ] + if preds: + incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0]) + else: + incoming = frozenset() + anno.setanno(node.value, self.in_label, incoming) + gen, kill = self.get_gen_kill(node, incoming) + anno.setanno(node.value, self.gen_label, gen) + anno.setanno(node.value, self.kill_label, kill) + anno.setanno(node.value, self.out_label, (incoming - kill) | gen) + + if hash(anno.getanno(node.value, self.out_label)) != before: + for succ in node.next: + self.visit(succ) + + def get_gen_kill(self, cfg_node, incoming): + """Calculate Gen and Kill properties of a CFG node in dataflow analysis. + + A function which takes the CFG node as well as a set of incoming + values. It must return a set of newly generated values by the statement as + well as a set of deleted (killed) values. + + Args: + cfg_node: A CfgNode instance. + incoming: + """ + raise NotImplementedError() + + +class Backward(Forward): + """Backward analysis on CFG.""" + + def visit(self, cfg_node): + # cfg_node.value is None for the exit node, which will be visited only once + if not cfg_node.value: + for pred in cfg_node.prev: + self.visit(pred) + return + + if anno.hasanno(cfg_node.value, self.in_label): + before = hash(anno.getanno(cfg_node.value, self.in_label)) + else: + before = None + succs = [ + anno.getanno(succ.value, self.in_label) + for succ in cfg_node.next + if anno.hasanno(succ.value, self.in_label) + ] + if succs: + incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0]) + else: + incoming = frozenset() + anno.setanno(cfg_node.value, self.out_label, incoming) + gen, kill = self.get_gen_kill(cfg_node, incoming) + anno.setanno(cfg_node.value, self.gen_label, gen) + anno.setanno(cfg_node.value, self.kill_label, kill) + anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen) + if hash(anno.getanno(cfg_node.value, self.in_label)) != before: + for pred in cfg_node.prev: + self.visit(pred) + + +def run_analyses(node, analyses): + """Perform dataflow analysis on all functions within an AST. + + Args: + node: An AST node on which to run dataflow analysis. + analyses: Either an instance of the Forward or Backward dataflow analysis + class, or a list or tuple of them. + + Returns: + node: The node, but now with annotations on the AST nodes containing the + results of the dataflow analyses. + """ + if not isinstance(analyses, (tuple, list)): + analyses = (analyses,) + for analysis in analyses: + if not isinstance(analysis, (Forward, Backward)): + raise TypeError('not a valid forward analysis object') + + for child_node in gast.walk(node): + if isinstance(child_node, gast.FunctionDef): + cfg_obj = CfgBuilder().build_cfg(child_node) + for analysis in analyses: + if isinstance(analysis, Backward): + analysis.visit(cfg_obj.exit) + elif isinstance(analysis, Forward): + analysis.visit(cfg_obj.entry) + for analysis in analyses: + PropagateAnalysis(analysis).visit(node) + return node + + +class Liveness(Backward): + """Perform a liveness analysis. + + Each statement is annotated with a set of variables that may be used + later in the program. 
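# --- Editor's note: illustrative sketch, not part of the patch. ---
# Running liveness end to end, modeled on cfg_test.py below; the
# EntityContext construction copies the test helper and is an assumption.
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct.static_analysis import cfg

def f(x):
  a = x + 1
  return a

node, source = parser.parse_entity(f)
ctx = context.EntityContext(
    namer=None, source_code=source, source_file=None, namespace={},
    arg_values=None, arg_types={}, owner_type=None, recursive=True)
node = qual_names.resolve(node)
cfg.run_analyses(node, cfg.Liveness(ctx))
body = node.body[0].body
# 'a' is live going into the return statement that uses it.
assert qual_names.QN('a') in anno.getanno(body[1], 'live_in')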
+ """ + + def __init__(self, context): + super(Liveness, self).__init__('live', context) + + def get_gen_kill(self, node, _): + # A variable's parents are live if it is live + # e.g. x is live if x.y is live. This means gen needs to return + # all parents of a variable (if it's an Attribute or Subscript). + # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) + gen = activity.get_read(node.value, self.context) + gen = functools.reduce(lambda left, right: left | right.support_set, gen, + gen) + kill = activity.get_updated(node.value, self.context) + return gen, kill + + +class ReachingDefinitions(Forward): + """Perform reaching definition analysis. + + Each statement is annotated with a set of (variable, definition) pairs. + """ + + def __init__(self, context): + super(ReachingDefinitions, self).__init__('definitions', context) + + def get_gen_kill(self, node, incoming): + definitions = activity.get_updated(node.value, self.context) + gen = frozenset((id_, node.value) for id_ in definitions) + kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) + return gen, kill + + +class Defined(Forward): + """Perform defined variable analysis. + + Each statement is annotated with a set of variables which are guaranteed to + be defined at that point. + """ + + def __init__(self, context): + super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + + def get_gen_kill(self, node, _): + gen = activity.get_updated(node.value, self.context) + return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py new file mode 100644 index 00000000000000..fc07fa3447b23c --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -0,0 +1,306 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import cfg +from tensorflow.python.platform import test + + +class CFGTest(test.TestCase): + + def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + arg_types = arg_types or {} + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + owner_type=None, + recursive=True) + node = qual_names.resolve(node) + return node, ctx + + def _check_anno_matches(self, node, anno_name, var_names): + if isinstance(var_names, str): + var_names = (var_names,) + qual_vars = set() + for var_name in var_names: + if isinstance(var_name, str): + if '[' in var_name or ']' in var_name: + raise ValueError('Annotation matching not supported with subscript.') + if '.' not in var_name: + qual_vars.add(qual_names.QN(var_name)) + else: + attrs = var_name.split('.') + this_qn = functools.reduce(qual_names.QN, attrs[1:], + qual_names.QN(attrs[0])) + qual_vars.add(this_qn) + self.assertEqual(anno.getanno(node, anno_name), qual_vars) + + def test_reaching(self): + + def f(x): + print(x) + while True: + x = x + x = x + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + # Only the argument reaches the expression + def_in = anno.getanno(body[0], 'definitions_in') + # One element, x, from arguments + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,))) + + while_body = body[1].body + def_in = anno.getanno(while_body[0], 'definitions_in') + # One definition, two possible sources. 
+ # - One from an assignment (if the loop is entered) + # - The other from the arguments (if loop is not entered) + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def_in = anno.getanno(while_body[1], 'definitions_in') + # If we've reached this line, the only reaching definition of x is the + # Assign node in previous line + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,))) + + def_in = anno.getanno(body[2], 'definitions_in') + # Same situation as while_body[0] + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def test_defined(self): + + def f(x): + if x: + y = 2 # pylint: disable=unused-variable + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + # only x is for sure defined at the end + self._check_anno_matches(body[1], 'defined_in', 'x') + # at the end of the if body both x and y are defined + if_body = body[0].body + self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) + + def _get_live_annotated_fnbody(self, f): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Liveness(ctx)) + body = node.body[0].body + return body + + def test_live_straightline(self): + + def f1(x): + a = g(x) # pylint: disable=undefined-variable + b = h(a) # pylint: disable=undefined-variable, unused-variable + return x + + body = self._get_live_annotated_fnbody(f1) + self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) + self._check_anno_matches(body[2], 'live_out', ()) + + def test_live_stacked_conds_with_else(self): + + def f2(x, a): # pylint: disable=unused-argument + if a > 0: # x should not be live + x = 0 + if a > 1: + x = 1 + else: + x = 2 + + body = self._get_live_annotated_fnbody(f2) + self._check_anno_matches(body[0], 'live_in', ('a')) + self._check_anno_matches(body[1], 'live_in', ('a')) + + def test_live_stacked_conds(self): + + def f3(x, a): + if a > 0: # x and a should be live + x = 0 + if a > 1: # x and a should be live_in + x = 1 + return x # x should be live + + body = self._get_live_annotated_fnbody(f3) + self._check_anno_matches(body[0], 'live_in', ('a', 'x')) + self._check_anno_matches(body[1], 'live_in', ('a', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + + def test_live_possibly_unused_cond(self): + + def f4(x, a): + if a > 0: # x should be live + x = 0 + x += 1 + + body = self._get_live_annotated_fnbody(f4) + self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_attribute_in_cond(self): + + def f5(x, a): + if a > 0: # x.y should be live + x.y = 0 + return x.y + + body = self._get_live_annotated_fnbody(f5) + self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + + def test_live_noop(self): + + def f6(x): + return x # should this cause x.* to be live? 
+ + body = self._get_live_annotated_fnbody(f6) + self._check_anno_matches(body[0], 'live_in', ('x')) + + def test_live_loop(self): + + def f7(x, n): + for i in range(n): + x += i + return x + + body = self._get_live_annotated_fnbody(f7) + self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_context_manager(self): + + def f8(x, f): + with f: + x += 1 + + body = self._get_live_annotated_fnbody(f8) + self._check_anno_matches(body[0], 'live_in', ('f', 'x')) + + def test_node_equality(self): + node_a = gast.parse('y = x').body[0] + node_b = gast.parse('y = x').body[0] + self.assertNotEqual(node_a, node_b) + + def test_nested_functions_defined(self): + + def f(x): + y = x * 2 + + def g(z): + return z + y + + return g(x) + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('g', 'x', 'y')))) + + # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries. + # NOTE: 'z' is easy to find, but 'y' is not identified as + # defined, because CFG analysis is applied with each function separately. + # fndef_body = body[1].body + # self.assertEqual( + # anno.getanno(fndef_body[0], 'defined_in'), + # frozenset(map(qual_names.QN, ('z', 'y')))) + + def test_nested_functions_dont_leak_definitions(self): + + def f(x): + print(x) + + def g(): + y = 2 + return y + + return g() # y is not defined here + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('x', 'g')))) + + def test_loop_else(self): + + # Disabling useless-else-on-loop error, because 'break' and 'continue' + # canonicalization are a separate analysis pass, and here we test + # the CFG analysis in isolation. + def for_orelse(x): + y = 0 + for i in range(len(x)): + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + def while_orelse(x, i): + y = 0 + while x < 10: + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + for f in (for_orelse, while_orelse): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + return_node = body[-1] + reaching_defs = anno.getanno(return_node, 'definitions_in') + + # Y could be defined by Assign(Num(0)) or Assign(Num(1)) + # X could be defined as an argument or an AugAssign. 
+ y_defs = [node for var, node in reaching_defs if str(var) == 'y'] + x_defs = [node for var, node in reaching_defs if str(var) == 'x'] + + self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs)) + self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs)) + self.assertEqual(len(y_defs), 2) + self.assertEqual( + set((gast.arguments, gast.AugAssign)), + set(type(def_) for def_ in x_defs)) + self.assertEqual(len(x_defs), 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index c00946f9c41bc6..d6555dc7e0b3d4 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -136,14 +136,14 @@ def visit_If(self, node): def _process_function_arg(self, arg_name): str_name = str(arg_name) + type_holder = arg_name.ast() + self.scope.setval(arg_name, type_holder) if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: # Forge a node to hold the type information, so that method calls on # it can resolve the type. - type_holder = arg_name.ast() type_string, type_obj = self.context.arg_types[str_name] anno.setanno(type_holder, 'type', type_obj) anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) - self.scope.setval(arg_name, type_holder) def visit_arg(self, node): self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) @@ -167,50 +167,41 @@ def visit_Name(self, node): anno.getanno(definition, 'element_type')) return node - def _process_variable_assignment(self, source, targets): - # Special case: constructors. - if isinstance(source, gast.Call): - func = source.func + def _process_variable_assignment(self, target, value): + # Constructors + if isinstance(value, gast.Call): + func = value.func if anno.hasanno(func, 'live_val'): func_obj = anno.getanno(func, 'live_val') if tf_inspect.isclass(func_obj): - anno.setanno(source, 'is_constructor', True) - anno.setanno(source, 'type', func_obj) - anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn')) + anno.setanno(value, 'is_constructor', True) + anno.setanno(value, 'type', func_obj) + anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn')) # TODO(mdan): Raise an error if constructor has side effects. # We can have a whitelist of no-side-effects constructors. # We can also step inside the constructor and further analyze. - # Multiple targets mean multiple assignment. - for target in targets: - # Tuple target means unpacking. - if isinstance(target, (gast.Tuple, gast.List)): - for i, target_item in enumerate(target.elts): - # Two cases here: - # 1. Static unpacking, e.g. a, b = c, d - # 2. Dynamic unpacking, e.g. a, b = c - # The former case is optimized away. 
- if isinstance(source, (gast.Tuple, gast.List)): - source_item = source.elts[i] - else: - source_item = gast.Subscript(source, gast.Index(i), ctx=None) - self._process_variable_assignment(source_item, (target_item,)) - elif isinstance(target, (gast.Name, gast.Attribute)): - target_symbol = anno.getanno(target, anno.Basic.QN) - self.scope.setval(target_symbol, source) - else: - raise ValueError('assignment target has unknown type: %s' % target) + if isinstance(target, (gast.Name, gast.Attribute)): + target_symbol = anno.getanno(target, anno.Basic.QN) + self.scope.setval(target_symbol, value) + elif isinstance(target, gast.Subscript): + pass + else: + raise ValueError('assignment target has unknown type: %s' % target) def visit_With(self, node): - for wi in node.items: - if wi.optional_vars is not None: - self._process_variable_assignment(wi.context_expr, (wi.optional_vars,)) + for item in node.items: + if item.optional_vars is not None: + self.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) self.generic_visit(node) return node def visit_Assign(self, node): self.generic_visit(node) - self._process_variable_assignment(node.value, node.targets) + self.apply_to_single_assignments( + node.targets, node.value, self._process_variable_assignment) return node def visit_Call(self, node): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 46b7701624a430..95cbf5ca79a504 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -196,6 +196,19 @@ def test_fn(): f_ref = node.body[0].body[1].value self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + def test_type_annotation_args(self): + + class Foo(object): + pass + + def test_fn(f): + utils.set_element_type(f, Foo) + return f + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) + f_ref = node.body[0].body[1].value + self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + def test_nested_unpacking(self): class Foo(object): diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 4db6cc0adfad90..60bca8b38dcf62 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -70,14 +70,40 @@ def enclosing_entities(self): return tuple(self._enclosing_entities) @property - def locel_scope_level(self): + def local_scope_level(self): return len(self._local_scope_state) - def enter_local_scope(self): - self._local_scope_state.append({}) + def enter_local_scope(self, inherit=None): + """Marks entry into a new local scope. - def exit_local_scope(self): - return self._local_scope_state.pop() + Args: + inherit: Optional enumerable of variable names to copy from the + parent scope. + """ + scope_entered = {} + if inherit: + this_scope = self._local_scope_state[-1] + for name in inherit: + if name in this_scope: + scope_entered[name] = this_scope[name] + self._local_scope_state.append(scope_entered) + + def exit_local_scope(self, keep=None): + """Marks exit from the current local scope. + + Args: + keep: Optional enumerable of variable names to copy into the + parent scope. + Returns: + A dict containing the scope that has just been exited. 
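# --- Editor's note: illustrative sketch, not part of the patch. ---
# How a transformer.Base subclass might use the new inherit/keep hooks. The
# class and the 'fn_name'/'found_return' keys are hypothetical.
from tensorflow.contrib.autograph.pyct import transformer

class ScopeDemoTransformer(transformer.Base):

  def visit_FunctionDef(self, node):
    self.enter_local_scope()
    self.set_local('fn_name', node.name)
    node = self.generic_visit(node)
    scope = self.exit_local_scope()
    # 'found_return' shows up in `scope` if a nested visit_If kept it (below).
    return node

  def visit_If(self, node):
    # The child scope starts seeded with 'fn_name' from the parent scope.
    self.enter_local_scope(inherit=('fn_name',))
    node = self.generic_visit(node)
    # Copy 'found_return' (set while visiting the body) to the parent scope.
    self.exit_local_scope(keep=('found_return',))
    return node

  def visit_Return(self, node):
    self.set_local('found_return', True)
    return node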
+ """ + scope_left = self._local_scope_state.pop() + if keep: + this_scope = self._local_scope_state[-1] + for name in keep: + if name in scope_left: + this_scope[name] = scope_left[name] + return scope_left def set_local(self, name, value): self._local_scope_state[-1][name] = value @@ -91,38 +117,163 @@ def debug_print(self, node): print(pretty_printer.fmt(node)) return node - def visit_block(self, nodes): - """Helper equivalent to generic_visit, but for node lists.""" + def visit_block(self, nodes, before_visit=None, after_visit=None): + """A more powerful version of generic_visit for statement blocks. + + An example of a block is the body of an if statement. + + This function allows specifying a postprocessing callback (the + after_visit argument) argument which can be used to move nodes to a new + destination. This is done by after_visit by returning a non-null + second return value, e.g. return new_node, new_destination. + + For example, a transformer could perform the following move: + + foo() + bar() + baz() + + foo() + if cond: + bar() + baz() + + The above could be done with a postprocessor of this kind: + + def after_visit(node): + if node_is_function_call(bar): + new_container_node = build_cond() + new_container_node.body.append(node) + return new_container_node, new_container_node.body + else: + # Once we set a new destination, all subsequent items will be + # moved to it, so we don't need to explicitly handle baz. + return node, None + + Args: + nodes: enumerable of AST node objects + before_visit: optional callable that is called before visiting each item + in nodes + after_visit: optional callable that takes in an AST node and + returns a tuple (new_node, new_destination). It is called after + visiting each item in nodes. Is used in the same was as the + visit_* methods: new_node will replace the node; if not None, + new_destination must be a list, and subsequent nodes will be placed + in this list instead of the list returned by visit_block. + Returns: + A list of AST node objects containing the transformed items fron nodes, + except those nodes that have been relocated using after_visit. + """ results = [] + node_destination = results for node in nodes: + if before_visit: + # TODO(mdan): We can modify node here too, if ever needed. + before_visit() + replacement = self.visit(node) + + if after_visit and replacement: + replacement, new_destination = after_visit(replacement) + else: + new_destination = None + if replacement: if isinstance(replacement, (list, tuple)): - results.extend(replacement) + node_destination.extend(replacement) else: - results.append(replacement) + node_destination.append(replacement) + + # Allow the postprocessor to reroute the remaining nodes to a new list. + if new_destination is not None: + node_destination = new_destination return results + # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. + def apply_to_single_assignments(self, targets, values, apply_fn): + """Applies a fuction to each individual assignment. + + This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. + It tries to break down the unpacking if possible. In effect, it has the same + effect as passing the assigned values in SSA form to apply_fn. 
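# --- Editor's note: illustrative sketch, not part of the patch. ---
# A subclass using apply_to_single_assignments to record each elementary
# target/value pair of an assignment; the class and `record` are hypothetical.
from tensorflow.contrib.autograph.pyct import transformer

class AssignRecorder(transformer.Base):

  def visit_Assign(self, node):
    node = self.generic_visit(node)
    pairs = []

    def record(target, value):
      # Called once per single assignment, e.g. twice for `a, b = c, d`.
      pairs.append((target, value))

    self.apply_to_single_assignments(node.targets, node.value, record)
    # `pairs` now holds (target, value) AST node pairs for further processing.
    return node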
+ + Examples: + + The following will result in apply_fn(a, c), apply_fn(b, d): + + a, b = c, d + + The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): + + a, b = c + + The following will result in apply_fn(a, (b, c)): + + a = b, c + + It uses the visitor pattern to allow subclasses to process single + assignments individually. + + Args: + targets: list, tuple of or individual AST node. Should be used with the + targets field of an ast.Assign node. + values: an AST node. + apply_fn: a function of a single argument, which will be called with the + respective nodes of each single assignment. The signaure is + apply_fn(target, value), no return value. + """ + if not isinstance(targets, (list, tuple)): + targets = (targets,) + for target in targets: + if isinstance(target, (gast.Tuple, gast.List)): + for i in range(len(target.elts)): + target_el = target.elts[i] + if isinstance(values, (gast.Tuple, gast.List)): + value_el = values.elts[i] + else: + value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store()) + self.apply_to_single_assignments(target_el, value_el, apply_fn) + else: + # TODO(mdan): Look into allowing to rewrite the AST here. + apply_fn(target, values) + def visit(self, node): source_code = self.context.source_code source_file = self.context.source_file did_enter_function = False - local_scope_state_size = len(self._local_scope_state) + local_scope_size_at_entry = len(self._local_scope_state) try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): - self._enclosing_entities.append(node) did_enter_function = True + if did_enter_function: + self._enclosing_entities.append(node) + if source_code and hasattr(node, 'lineno'): self._lineno = node.lineno self._col_offset = node.col_offset - if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): - return node - return super(Base, self).visit(node) - except (ValueError, AttributeError, KeyError, NotImplementedError, - AssertionError) as e: + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + result = super(Base, self).visit(node) + + # On exception, the local scope integrity is not guaranteed. + if did_enter_function: + self._enclosing_entities.pop() + + if local_scope_size_at_entry != len(self._local_scope_state): + raise AssertionError( + 'Inconsistent local scope stack. Before entering node %s, the' + ' stack had length %d, after exit it has length %d. This' + ' indicates enter_local_scope and exit_local_scope are not' + ' well paired.' % ( + node, + local_scope_size_at_entry, + len(self._local_scope_state) + )) + return result + + except (ValueError, AttributeError, KeyError, NotImplementedError) as e: msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( e.__class__.__name__, str(e), try_ast_to_source(node), pretty_printer.fmt(node, color=False)) @@ -130,18 +281,11 @@ def visit(self, node): line = source_code.splitlines()[self._lineno - 1] else: line = '' + # TODO(mdan): Avoid the printing of the original exception. + # In other words, we need to find how to suppress the "During handling + # of the above exception, another exception occurred" message. six.reraise(AutographParseError, AutographParseError( msg, (source_file, self._lineno, self._col_offset + 1, line)), sys.exc_info()[2]) - finally: - if did_enter_function: - self._enclosing_entities.pop() - - if local_scope_state_size != len(self._local_scope_state): - raise AssertionError( - 'Inconsistent local scope stack. Before entering node %s, the' - ' stack had length %d, after exit it has length %d. 
This' - ' indicates enter_local_scope and exit_local_scope are not' - ' well paired.') diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index f96b0dc377521a..f110e79605945e 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import gast + from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser @@ -27,7 +29,7 @@ class TransformerTest(test.TestCase): - def _context_for_nodetesting(self): + def _context_for_testing(self): return context.EntityContext( namer=None, source_code=None, @@ -53,7 +55,7 @@ def visit_BinOp(self, node): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def test_function(): a = 0 @@ -94,7 +96,7 @@ def inner_function(x): inner_function, lambda_node), anno.getanno(lambda_expr, 'enclosing_entities')) - def test_statement_info_stack(self): + def test_local_scope_info_stack(self): class TestTransformer(transformer.Base): @@ -116,7 +118,7 @@ def visit_While(self, node): def visit_For(self, node): return self._annotate_result(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def test_function(a): """Docstring.""" @@ -142,7 +144,7 @@ def test_function(a): self.assertFalse(anno.hasanno(while_node, 'string')) self.assertEqual('1', anno.getanno(while_node, 'test')) - def test_statement_info_stack_checks_integrity(self): + def test_local_scope_info_stack_checks_integrity(self): class TestTransformer(transformer.Base): @@ -155,7 +157,7 @@ def visit_For(self, node): self.exit_local_scope() return node - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def no_exit(a): if a > 0: @@ -174,6 +176,38 @@ def no_entry(a): with self.assertRaises(AssertionError): tr.visit(node) + def test_visit_block_postprocessing(self): + + class TestTransformer(transformer.Base): + + def _process_body_item(self, node): + if isinstance(node, gast.Assign) and (node.value.id == 'y'): + if_node = gast.If(gast.Name('x', gast.Load(), None), [node], []) + return if_node, if_node.body + return node, None + + def visit_FunctionDef(self, node): + node.body = self.visit_block( + node.body, after_visit=self._process_body_item) + return node + + def test_function(x, y): + z = x + z = y + return z + + tr = TestTransformer(self._context_for_testing()) + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + node = node.body[0] + + self.assertEqual(len(node.body), 2) + self.assertTrue(isinstance(node.body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1], gast.If)) + self.assertTrue(isinstance(node.body[1].body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1].body[1], gast.Return)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d3a1b946889253..d82c17bf2afd01 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -33,6 +33,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + 
"//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:dtypes", "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 211e8eaee9082d..998087e056c2cd 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -24,6 +24,7 @@ from tensorflow.contrib.autograph.utils import py_func from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops @@ -38,7 +39,13 @@ def dynamic_builtin(f, *args, **kwargs): return dynamic_range(*args, **kwargs) if f is range: return dynamic_range(*args, **kwargs) - raise ValueError('%s is not supported' % f) + if f is int: + return dynamic_int(*args, **kwargs) + if f is float: + return dynamic_float(*args, **kwargs) + + raise NotImplementedError( + 'The "%s" builtin is not yet supported.' % f.__name__) def dynamic_len(list_or_tensor): @@ -52,6 +59,20 @@ def dynamic_len(list_or_tensor): return len(list_or_tensor) +def dynamic_int(num_or_tensor, **kwargs): + """Implementation of int() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) + return int(num_or_tensor) + + +def dynamic_float(num_or_tensor, **kwargs): + """Implementation of float() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) + return float(num_or_tensor) + + def dynamic_range(start_or_stop, stop=None, step=None): """Implementation of range using dynamic dispatch.""" if type_check.is_tensor(start_or_stop, stop, step): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index 163e6984079fea..0c2312178a9210 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -24,6 +24,7 @@ from tensorflow.contrib.autograph.utils import builtins from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.platform import test @@ -77,7 +78,7 @@ def range(x): # pylint:disable=redefined-builtin return x # Functions that just have the names of builtins are rejected. 
- with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): self.assertEqual(builtins.dynamic_builtin(range, 1), 1) if six.PY2: self.assertListEqual( @@ -87,6 +88,20 @@ def range(x): # pylint:disable=redefined-builtin self.assertListEqual( list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) + def test_casts(self): + i = constant_op.constant(2, dtype=dtypes.int32) + f = constant_op.constant(1.0, dtype=dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) + self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, True), 1) + self.assertEqual(builtins.dynamic_builtin(int, False), 0) + self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) + self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) + def test_dynamic_print_tf(self): try: out_capturer = six.StringIO() diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index d65c990c87cbc3..b27a19b16c08cb 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -49,6 +49,14 @@ cc_library( ], ) +cc_library( + name = "serial_device_batch_scheduler", + hdrs = ["serial_device_batch_scheduler.h"], + deps = [ + "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], @@ -96,6 +104,7 @@ py_test( name = "batch_ops_test", size = "small", srcs = ["python/ops/batch_ops_test.py"], + shard_count = 5, srcs_version = "PY2AND3", tags = [ "manual", diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index fac7aff29f79fa..ea8339334f9b5e 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -23,7 +23,10 @@ from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -205,6 +208,114 @@ def worker(): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchFunctionOp(self): + """Tests that the batch_function op works.""" + with self.test_session() as sess: + + @function.Defun(dtypes.int32) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = gen_batch_ops.batch_function( + [inp], + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, + Tout=[dtypes.int32], + f=computation, + captured_tensors=computation.captured_inputs) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithCapturedInput(self): + """Tests that 
batch_function op works with captured input.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32) + def computation(inp): + return inp + captured_inp0 - captured_inp1 + + result = gen_batch_ops.batch_function( + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + allowed_batch_sizes=[3, 10], + batching_queue="", + f=computation, + in_tensors=[inp], + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithInputError(self): + """Tests that batch_function op works with error in the inputs.""" + with self.test_session() as sess: + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32, dtypes.int32) + def computation(in0, in1): + return in0 + in1 + + result = gen_batch_ops.batch_function( + [inp], # computation actually expects 2 inputs. + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + batching_queue="", + f=computation, + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + with self.assertRaisesRegexp(InvalidArgumentError, + ".*2 arguments.*but 1.*"): + sess.run([result], feed_dict={inp: [2]}) + + def testBasicUnbatchDecoratedWithReshape(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return array_ops.reshape(in_t, [-1]) + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [[2]]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" with self.test_session() as sess: @@ -250,7 +361,7 @@ def worker(): def testUnbatchGrad(self): """Tests that batch and unbatch are differentiable.""" with self.test_session() as sess: - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, batch_timeout_micros=36000000, grad_timeout_micros=1000000, diff --git a/tensorflow/contrib/batching/serial_device_batch_scheduler.h b/tensorflow/contrib/batching/serial_device_batch_scheduler.h new file mode 100644 index 00000000000000..bf6b7083612018 --- /dev/null +++ b/tensorflow/contrib/batching/serial_device_batch_scheduler.h @@ -0,0 +1,21 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ + +#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" + +#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 9994c84ebdb930..758754feac31f1 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -45,6 +45,7 @@ _DNN_LEARNING_RATE = 0.001 + def _get_optimizer(optimizer): if callable(optimizer): return optimizer() @@ -73,6 +74,7 @@ def _dnn_tree_combined_model_fn(features, dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -108,6 +110,8 @@ def _dnn_tree_combined_model_fn(features, as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -132,8 +136,7 @@ def _dnn_tree_combined_model_fn(features, dnn_parent_scope = "dnn" dnn_partitioner = dnn_input_layer_partitioner or ( partitioned_variables.min_max_variable_partitioner( - max_partitions=config.num_ps_replicas, - min_slice_size=64 << 20)) + max_partitions=config.num_ps_replicas, min_slice_size=64 << 20)) with variable_scope.variable_scope( dnn_parent_scope, @@ -171,8 +174,7 @@ def _dnn_tree_combined_model_fn(features, _add_hidden_layer_summary(net, hidden_layer_scope.name) previous_layer = net with variable_scope.variable_scope( - "logits", - values=(previous_layer,)) as logits_scope: + "logits", values=(previous_layer,)) as logits_scope: dnn_logits = layers.fully_connected( previous_layer, head.logits_dimension, @@ -190,8 +192,7 @@ def _dnn_train_op_fn(loss): optimizer=_get_optimizer(dnn_optimizer), name=dnn_parent_scope, variables=ops.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES, - scope=dnn_parent_scope), + ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope), # Empty summaries to prevent optimizers from logging training_loss. 
summaries=[]) @@ -230,7 +231,10 @@ def _tree_train_op_fn(loss): update_op = state_ops.assign_add(global_step, 1).op return update_op - tree_train_logits = dnn_logits + tree_logits + if predict_with_tree_only: + tree_train_logits = tree_logits + else: + tree_train_logits = dnn_logits + tree_logits def _no_train_op_fn(loss): """Returns a no-op.""" @@ -288,10 +292,10 @@ def _no_train_op_fn(loss): finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() model_fn_ops.training_hooks.extend([ - trainer_hooks.SwitchTrainOp( - dnn_train_op, dnn_steps_to_train, tree_train_op), - trainer_hooks.StopAfterNTrees( - num_trees, attempted_trees, finalized_trees)]) + trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train, + tree_train_op), + trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees) + ]) return model_fn_ops @@ -318,6 +322,7 @@ def __init__(self, dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -360,6 +365,8 @@ def __init__(self, as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -377,16 +384,32 @@ def __init__(self, def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedRegressor(estimator.Estimator): @@ -410,6 +433,7 @@ def __init__(self, dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -452,6 +476,8 @@ def __init__(self, as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. 
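A minimal sketch of what this flag controls, with plain floats standing in for the logits tensors (this is only the combination rule from the model function above, not the estimator itself):

```python
def combined_logits(dnn_logits, tree_logits, predict_with_tree_only=False):
    # When the flag is set, final predictions come from the tree ensemble
    # alone; otherwise the DNN and tree logits are summed as before.
    if predict_with_tree_only:
        return tree_logits
    return dnn_logits + tree_logits

assert combined_logits(0.3, 1.2) == 1.5
assert combined_logits(0.3, 1.2, predict_with_tree_only=True) == 1.2
```

Note that the call sites into `_dnn_tree_combined_model_fn` are also switched to keyword arguments in this change, so inserting the new parameter in the middle of the signature cannot silently shift any positional arguments.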
tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -474,16 +500,32 @@ def __init__(self, def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedRegressor, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) class DNNBoostedTreeCombinedEstimator(estimator.Estimator): @@ -508,6 +550,7 @@ def __init__(self, dnn_input_layer_partitioner=None, dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + predict_with_tree_only=False, tree_feature_columns=None, tree_center_bias=False, use_core_versions=False): @@ -545,6 +588,8 @@ def __init__(self, as a feature to the tree. dnn_steps_to_train: Number of steps to train dnn for before switching to gbdt. + predict_with_tree_only: Whether to use only the tree model output as the + final prediction. tree_feature_columns: An iterable containing all the feature columns used by the model's boosted trees. If dnn_input_layer_to_tree is set to True, these features are in addition to dnn_feature_columns. @@ -553,15 +598,32 @@ def __init__(self, use_core_versions: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. 
""" + def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, - tree_learner_config, num_trees, tree_examples_per_layer, config, - dnn_optimizer, dnn_activation_fn, dnn_dropout, - dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias, - use_core_versions) + features=features, + labels=labels, + mode=mode, + head=head, + dnn_hidden_units=dnn_hidden_units, + dnn_feature_columns=dnn_feature_columns, + tree_learner_config=tree_learner_config, + num_trees=num_trees, + tree_examples_per_layer=tree_examples_per_layer, + config=config, + dnn_optimizer=dnn_optimizer, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + dnn_input_layer_partitioner=dnn_input_layer_partitioner, + dnn_input_layer_to_tree=dnn_input_layer_to_tree, + dnn_steps_to_train=dnn_steps_to_train, + predict_with_tree_only=predict_with_tree_only, + tree_feature_columns=tree_feature_columns, + tree_center_bias=tree_center_bias, + use_core_versions=use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, - config=config, feature_engineering_fn=feature_engineering_fn) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 44a8ffaf4b2f5a..401bec84a20a0f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -43,47 +43,60 @@ namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; } // namespace -class BaseBuildSplitOp : public OpKernel { +class SplitBuilderState { public: - explicit BaseBuildSplitOp(OpKernelConstruction* const context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", - &feature_column_group_id_)); + explicit SplitBuilderState(OpKernelContext* const context) { + const Tensor* l1_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l1_regularization", &l1_regularization_)); + context->input("l1_regularization", &l1_regularization_t)); + const Tensor* l2_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l2_regularization", &l2_regularization_)); - OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization", - &tree_complexity_regularization_)); + context->input("l2_regularization", &l2_regularization_t)); + const Tensor* tree_complexity_regularization_t; + OP_REQUIRES_OK(context, context->input("tree_complexity_regularization", + &tree_complexity_regularization_t)); + const Tensor* min_node_weight_t; OP_REQUIRES_OK(context, - context->GetAttr("min_node_weight", &min_node_weight_)); + context->input("min_node_weight", &min_node_weight_t)); - int strategy; - OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy)); + const Tensor* feature_column_group_id_t; + OP_REQUIRES_OK(context, context->input("feature_column_group_id", + &feature_column_group_id_t)); + + const Tensor* multiclass_strategy_t; + OP_REQUIRES_OK( + context, context->input("multiclass_strategy", &multiclass_strategy_t)); + int strategy = multiclass_strategy_t->scalar()(); OP_REQUIRES( context, boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid( strategy), errors::InvalidArgument("Wrong multiclass strategy passed.")); - 
multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - } - NodeStats ComputeNodeStats(const GradientStats& grad_stats) { - return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, - multiclass_strategy_, grad_stats); - } + multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - void ReadClassId(OpKernelContext* const context, int32* class_id) { const Tensor* class_id_t; OP_REQUIRES_OK(context, context->input("class_id", &class_id_t)); OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()), errors::InvalidArgument("class_id must be a scalar.")); - *class_id = class_id_t->scalar()(); + class_id_ = class_id_t->scalar()(); + + l1_regularization_ = l1_regularization_t->scalar()(); + l2_regularization_ = l2_regularization_t->scalar()(); + tree_complexity_regularization_ = + tree_complexity_regularization_t->scalar()(); + min_node_weight_ = min_node_weight_t->scalar()(); + feature_column_group_id_ = feature_column_group_id_t->scalar()(); } - void FillLeaf(const int class_id, const NodeStats& best_node_stats, + NodeStats ComputeNodeStats(const GradientStats& grad_stats) { + return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, + multiclass_strategy_, grad_stats); + } + + void FillLeaf(const NodeStats& best_node_stats, boosted_trees::trees::Leaf* leaf) const { - if (class_id == -1) { + if (class_id_ == -1) { // This would be the case either for TREE_PER_CLASS with only 2 classes, // or for other multiclass strategies. for (float f : best_node_stats.weight_contribution) { @@ -93,25 +106,31 @@ class BaseBuildSplitOp : public OpKernel { CHECK(best_node_stats.weight_contribution.size() == 1) << "Weight contribution size = " << best_node_stats.weight_contribution.size(); - leaf->mutable_sparse_vector()->add_index(class_id); + leaf->mutable_sparse_vector()->add_index(class_id_); leaf->mutable_sparse_vector()->add_value( best_node_stats.weight_contribution[0]); } } - protected: + int32 feature_column_group_id() { return feature_column_group_id_; } + float tree_complexity_regularization() { + return tree_complexity_regularization_; + } + + private: LearnerConfig_MultiClassStrategy multiclass_strategy_; - int32 feature_column_group_id_; float l1_regularization_; float l2_regularization_; - float min_node_weight_; float tree_complexity_regularization_; + float min_node_weight_; + int32 class_id_; + int32 feature_column_group_id_; }; -class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildDenseInequalitySplitsOp : public OpKernel { public: explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) {} + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -139,9 +158,6 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); - // Find the number of unique partitions before we allocate the output. 
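One practical effect of moving these hyperparameters from op attributes into op inputs is that they are now ordinary tensors, so they no longer have to be frozen into the graph when the op is constructed. A generic TF 1.x illustration of that difference (this is not the split-handler op itself; the placeholder name is made up):

```python
import tensorflow as tf

# As an op *input*, a value such as the L2 penalty can be fed per run.
l2 = tf.placeholder(tf.float32, shape=[], name="l2_regularization")
penalty = l2 * tf.constant(4.0)

with tf.Session() as sess:
    print(sess.run(penalty, feed_dict={l2: 0.1}))  # ~0.4
    print(sess.run(penalty, feed_dict={l2: 1.0}))  # 4.0
```

An attribute, by contrast, is baked into the `NodeDef` at graph-construction time and cannot change between `session.run` calls.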
std::vector partition_boundaries; partition_boundaries.push_back(0); @@ -185,6 +201,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[root_idx]; @@ -196,7 +213,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -206,10 +223,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -220,18 +237,18 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } @@ -239,13 +256,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), BuildDenseInequalitySplitsOp); -class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildSparseInequalitySplitsOp : public OpKernel { public: explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -275,8 +289,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = 
bias_feature_id_t->scalar()(); // For each partition (tree node), store starting index for each dimension. PartitionAndDimensionBoundaries partition_boundaries; @@ -354,6 +370,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { const auto& dimension_boundaries = @@ -372,7 +389,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { OP_REQUIRES( context, - bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); // Dimension for bias feature is always 0 @@ -388,7 +405,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats root_gradient_stats(*gradients_t, *hessians_t, bias_start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { @@ -408,7 +425,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { << bucket_ids_and_dimensions(start_index, 1) << " and for " << bucket_ids_and_dimensions(end_index - 1, 0) << " " << bucket_ids_and_dimensions(end_index - 1, 1); - if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) { // 0-dimension case which has a first bucket for catch all feature. CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) << "Dimension of bias feature should be 0"; @@ -422,6 +439,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } present_gradient_stats *= normalizer_ratio; + GradientStats not_present = + root_gradient_stats - present_gradient_stats; + // If there was (almost) no sparsity, fix the default direction to LEFT. + bool fixed_default_direction = not_present.IsAlmostZero(); GradientStats left_gradient_stats; for (int64 element_idx = start_index; element_idx < end_index; @@ -441,11 +462,12 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // backward pass gradients. GradientStats right_gradient_stats = present_gradient_stats - left_gradient_stats; + { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats left_stats_default_left = state.ComputeNodeStats( + root_gradient_stats - right_gradient_stats); NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); + state.ComputeNodeStats(right_gradient_stats); if (left_stats_default_left.gain + right_stats_default_left.gain > best_gain) { best_gain = @@ -457,11 +479,13 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { best_dimension_idx = dimension_id; } } - { + // Consider calculating the default direction only when there were + // enough missing examples. 
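A small numeric sketch of why this shortcut is safe (toy Python, not the kernel code): when the statistics of the missing examples are almost zero, the default-left and default-right split candidates score identically, so only one of them needs to be evaluated. The quadratic `gain` below is a hypothetical stand-in for the `NodeStats` gain computation.

```python
def split_gains(left, right, not_present, gain):
    # default-left: missing examples join the left child
    default_left = gain(left + not_present) + gain(right)
    # default-right: missing examples join the right child
    default_right = gain(left) + gain(right + not_present)
    return default_left, default_right

gain = lambda g: g * g  # toy stand-in for the node gain

print(split_gains(2.0, 1.0, 0.0, gain))  # (5.0, 5.0)   identical, skip one
print(split_gains(2.0, 1.0, 0.5, gain))  # (7.25, 6.25) direction matters
```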
+ if (!fixed_default_direction) { NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); + state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = state.ComputeNodeStats( + root_gradient_stats - left_gradient_stats); if (left_stats_default_right.gain + right_stats_default_right.gain > best_gain) { best_gain = left_stats_default_right.gain + @@ -487,7 +511,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_sparse_float_binary_split_default_left() ->mutable_split(); } - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); // Set the feature index for the best feature column. const int64 best_dimension_id = bucket_ids_and_dimensions(best_element_idx, 1); @@ -498,11 +522,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(bias_start_index); } } @@ -519,19 +543,14 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // For each partition, store start indices of feature column dimensions. typedef std::vector> PartitionAndDimensionBoundaries; - - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), BuildSparseInequalitySplitsOp); -class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { +class BuildCategoricalEqualitySplitsOp : public OpKernel { public: explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -554,8 +573,10 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; @@ -598,16 +619,17 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. 
- OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -618,8 +640,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { left_gradient_stats *= normalizer_ratio; GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -630,21 +652,18 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); - equality_split->set_feature_column(feature_column_group_id_); + equality_split->set_feature_column(state.feature_column_group_id()); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } - - private: - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 9d6cc9245aa463..409a2d8f46c331 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,8 @@ import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops +from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops @@ -72,9 +74,11 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. 
_PATTERN = re.compile(r"[\W_]+") @@ -130,11 +134,14 @@ def __init__(self, gradient_shape, hessian_shape, name="StatsAccumulator/{}".format(self._name)) - self._quantile_accumulator = quantile_ops.QuantileAccumulator( - init_stamp_token, - epsilon=epsilon, - num_quantiles=num_quantiles, - name="QuantileAccumulator/{}".format(self._name)) + # Allocate both stats accumulator and quantile accumulator on the same + # device so that we can build splits with fewer RPCs. + with ops.colocate_with(self._stats_accumulator.resource()): + self._quantile_accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token, + epsilon=epsilon, + num_quantiles=num_quantiles, + name="QuantileAccumulator/{}".format(self._name)) class DenseSplitHandler(InequalitySplitHandler): @@ -236,45 +243,74 @@ def update_stats(self, stamp_token, example_partition_ids, gradients, def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - # Get the aggregated gradients and hessians per - # pair. - # In order to distribute the computation on all the PSs we use the PS that - # had the stats accumulator on. - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - - partition_ids, gains, split_infos = ( - split_handler_ops.build_dense_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. 
- _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_dense_split_scalar + else: + handler = make_dense_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a dense feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_dense_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos class SparseSplitHandler(InequalitySplitHandler): @@ -327,9 +363,6 @@ def __init__(self, multiclass_strategy=multiclass_strategy, init_stamp_token=init_stamp_token, name=name) - # Register sparse_make_stats_update function as an Op to the graph. 
- g = ops.get_default_graph() - sparse_make_stats_update.add_to_graph(g) self._sparse_float_column = sparse_float_column def scheduled_reads(self): @@ -361,8 +394,8 @@ def update_stats(self, stamp_token, example_partition_ids, gradients, are_buckets_ready, buckets = scheduled_reads[0] with ops.name_scope(self._name, "SparseSplitHandler"): (quantile_indices, quantile_values, quantile_shapes, quantile_weights, - example_partition_ids, - feature_ids, gradients, hessians) = sparse_make_stats_update( + example_partition_ids, feature_ids, gradients, + hessians) = sparse_make_stats_update( is_active, are_buckets_ready, self._sparse_float_column.indices, self._sparse_float_column.values, self._sparse_float_column.dense_shape, buckets, @@ -379,42 +412,115 @@ def update_stats(self, stamp_token, example_partition_ids, gradients, def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - partition_ids, gains, split_infos = ( - split_handler_ops.build_sparse_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. 
- _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_sparse_split_scalar + else: + handler = make_sparse_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a sparse feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. 
+ with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + bias_feature_id=_BIAS_FEATURE_ID, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _specialize_make_split(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight): + """Function that builds splits for a sparse feature column.""" + return func( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional) + + return f + +make_dense_split_scalar = _specialize_make_split(_make_dense_split, + is_multi_dimentional=False) +make_dense_split_tensor = _specialize_make_split(_make_dense_split, + is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=True) @function.Defun( @@ -501,11 +607,18 @@ def quantiles_ready(): example_partition_ids) # Compute aggregate stats for each partition. + # Since unsorted_segment_sum can be numerically unstable, use 64bit + # operation. + gradients64 = math_ops.cast(gradients, dtypes.float64) + hessians64 = math_ops.cast(hessians, dtypes.float64) per_partition_gradients = math_ops.unsorted_segment_sum( - gradients, mapped_partitions, array_ops.size(unique_partitions)) + gradients64, mapped_partitions, array_ops.size(unique_partitions)) per_partition_hessians = math_ops.unsorted_segment_sum( - hessians, mapped_partitions, array_ops.size(unique_partitions)) - + hessians64, mapped_partitions, array_ops.size(unique_partitions)) + per_partition_gradients = math_ops.cast(per_partition_gradients, + dtypes.float32) + per_partition_hessians = math_ops.cast(per_partition_hessians, + dtypes.float32) # Prepend a bias feature per partition that accumulates the stats for all # examples in that partition. 
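The `_specialize_make_split` helper above bakes the `is_multi_dimentional` flag into two variants of the same builder so that each can be wrapped as a single non-inlined graph function (`noinline=True`). The sketch below shows only the specialization idea, with `functools.partial` standing in for the `@function.Defun` wrapper; the names and return values are illustrative, not the actual generated ops.

```python
import functools

def _make_split(stats_accumulator_handle, is_multi_dimensional):
    # Pick the flush op variant, mirroring the scalar/tensor branch above.
    flush_op = ("stats_accumulator_tensor_flush" if is_multi_dimensional
                else "stats_accumulator_scalar_flush")
    return "%s(%s)" % (flush_op, stats_accumulator_handle)

make_split_scalar = functools.partial(_make_split, is_multi_dimensional=False)
make_split_tensor = functools.partial(_make_split, is_multi_dimensional=True)

print(make_split_scalar("acc"))  # stats_accumulator_scalar_flush(acc)
print(make_split_tensor("acc"))  # stats_accumulator_tensor_flush(acc)
```

In the patch itself the specialized functions are `Defun`-wrapped and take the accumulator resource handles directly, which, together with the colocation change earlier, lets split construction run next to the accumulators with fewer round trips.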
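The per-partition aggregation above now casts gradients and hessians to float64 before `unsorted_segment_sum` and back to float32 afterwards. A minimal NumPy illustration (not the TF op) of the kind of precision loss this avoids when small per-example contributions are added to an already large accumulator:

```python
import numpy as np

big, small = np.float32(1e8), np.float32(1.0)
print((big + small) - big)  # 0.0: the float32 addition drops the small term
print((np.float64(big) + np.float64(small)) - np.float64(big))  # 1.0
```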
bias_feature_ids = array_ops.fill( @@ -533,8 +646,9 @@ def quantiles_not_ready(): empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( - [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant( - [0, 1], dtype=dtypes.int64), empty_float) + [], dtype=dtypes.int64, shape=[0, 2]), empty_float, + constant_op.constant([0, 1], dtype=dtypes.int64), + empty_float) handler_active = (sparse_column_indices, sparse_column_values, sparse_column_shape, weights) quantile_indices, quantile_values, quantile_shape, quantile_weights = ( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 54d03018d9e266..2f2c2302113bf5 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 @@ -65,9 +67,9 @@ def testGenerateFeatureSplitCandidates(self): hessian_shape = tensor_shape.scalar() split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -92,7 +94,9 @@ def testGenerateFeatureSplitCandidates(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -105,7 +109,7 @@ def testGenerateFeatureSplitCandidates(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -199,10 +203,10 @@ def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): hessian_shape = tensor_shape.TensorShape([2, 2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -227,7 +231,9 @@ def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -240,7 +246,7 @@ def 
testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -285,10 +291,10 @@ def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): hessian_shape = tensor_shape.TensorShape([2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -313,7 +319,8 @@ def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -326,7 +333,7 @@ def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -369,9 +376,9 @@ def testGenerateFeatureSplitCandidatesInactive(self): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -396,7 +403,8 @@ def testGenerateFeatureSplitCandidatesInactive(self): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -409,7 +417,7 @@ def testGenerateFeatureSplitCandidatesInactive(self): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -443,9 +451,9 @@ def testGenerateFeatureSplitCandidatesWithTreeComplexity(self): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, - min_node_weight=0, + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -470,7 +478,8 @@ def testGenerateFeatureSplitCandidatesWithTreeComplexity(self): 
example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -483,7 +492,7 @@ def testGenerateFeatureSplitCandidatesWithTreeComplexity(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -576,7 +585,7 @@ def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, min_node_weight=1.5, epsilon=0.001, @@ -603,7 +612,8 @@ def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -616,7 +626,7 @@ def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -685,10 +695,10 @@ def testGenerateFeatureSplitCandidates(self): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -713,8 +723,8 @@ def testGenerateFeatureSplitCandidates(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] - + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -727,7 +737,7 @@ def testGenerateFeatureSplitCandidates(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -811,10 +821,10 @@ def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - 
min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -839,7 +849,8 @@ def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -853,7 +864,7 @@ def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -905,10 +916,10 @@ def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -933,7 +944,8 @@ def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -947,7 +959,7 @@ def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -996,10 +1008,10 @@ def testGenerateFeatureSplitCandidatesInactive(self): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1024,7 +1036,8 @@ def testGenerateFeatureSplitCandidatesInactive(self): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1038,7 +1051,7 @@ def testGenerateFeatureSplitCandidatesInactive(self): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - 
split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1065,10 +1078,10 @@ def testEmpty(self): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1096,7 +1109,8 @@ def testEmpty(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1110,7 +1124,7 @@ def testEmpty(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1138,10 +1152,10 @@ def testDegenerativeCase(self): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1166,7 +1180,8 @@ def testDegenerativeCase(self): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1180,7 +1195,7 @@ def testDegenerativeCase(self): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index 8ad97fedc923ac..c120dd8a6c156e 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -295,7 +295,7 @@ WeightedQuantilesStream::GetQuantileSpecs( if (eps <= std::numeric_limits::epsilon()) { // Exact quantile computation at the expense of RAM. 
max_level = 1; - block_size = std::max(max_elements, 2LL); + block_size = std::max(max_elements, int64{2}); } else { // The bottom-most level will become full at most // (max_elements / block_size) times, the level above will become full @@ -315,7 +315,7 @@ WeightedQuantilesStream::GetQuantileSpecs( block_size = static_cast(ceil(max_level / eps)) + 1; } } - return std::make_tuple(max_level, std::max(block_size, 2LL)); + return std::make_tuple(max_level, std::max(block_size, int64{2})); } } // namespace quantiles diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index 7576856dc3a6d0..a7e7bfc13cadce 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -195,7 +195,7 @@ class WeightedQuantilesSummary { // designed to be cache-friendly. void Compress(int64 size_hint, double min_eps = 0) { // No-op if we're already within the size requirement. - size_hint = std::max(size_hint, 2LL); + size_hint = std::max(size_hint, int64{2}); if (entries_.size() <= size_hint) { return; } @@ -267,7 +267,7 @@ class WeightedQuantilesSummary { if (entries_.empty()) { return output; } - num_quantiles = std::max(num_quantiles, 2LL); + num_quantiles = std::max(num_quantiles, int64{2}); output.reserve(num_quantiles + 1); // Make successive rank queries to get boundaries. diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 5d0ebbf73ce127..ca5c7f3d8c78a5 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -23,12 +23,6 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_OP("BuildDenseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -36,6 +30,12 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -73,6 +73,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. 
+ If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -81,13 +92,6 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildSparseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -95,6 +99,13 @@ REGISTER_OP("BuildSparseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -133,6 +144,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -141,19 +163,19 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildCategoricalEqualitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("feature_ids: int64") .Input("gradients: float32") .Input("hessians: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -188,6 +210,17 @@ partition_ids: A rank 1 tensor of partition IDs. 
feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -196,4 +229,3 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); } // namespace tensorflow - // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 28834ef55bf8e1..5cd37ec67ec3bd 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import random + from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops @@ -399,6 +401,65 @@ def testMakeSparseMultidimensionalSplit(self): self.assertAllClose(0.6, split_node.split.threshold) + def testMakeSparseSplitDefaultDirectionIsStable(self): + """Tests default direction is stable when no sparsity.""" + random.seed(1123) + for _ in range(50): + with self.test_session() as sess: + grad = random.random() + hessian = random.random() + # The data looks like the following (divide by the num of steps 2). + # Gradients | Partition | bucket ID | + # (grad, hessian) | 0 | -1 | + # And then 100 buckets of + # (grad/100, hessian/100), so there is no sparsity. + n_buckets = 100 + + # 1 for the overall sum, and 100 buckets. + partition_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. 
+ + bucket_ids = [-1] + [n for n in range(100)] + bucket_ids = array_ops.constant(bucket_ids, dtype=dtypes.int64) + dimension_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + + gradients = [grad] + [grad / n_buckets] * n_buckets + gradients = array_ops.constant(gradients) + hessians = [hessian] + [hessian / n_buckets] * n_buckets + hessians = array_ops.constant(hessians) + + boundaries = [x * 1 for x in range(n_buckets + 1)] + bucket_boundaries = array_ops.constant(boundaries, dtype=dtypes.float32) + + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertAllEqual([0], partitions) + self.assertEqual(1, len(splits)) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + self.assertTrue( + split_info.split_node.HasField( + 'sparse_float_binary_split_default_left')) + def testMakeMulticlassSparseSplit(self): """Tests split handler op.""" with self.test_session() as sess: diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 7a5f329b7ab321..843420968ac6a6 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -20,6 +20,8 @@ import abc import collections +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -60,6 +62,7 @@ def _move_tensors(tensors, device): """Moves a list of tensors to a device by concatenating/splitting them.""" # Reset the device setting to avoid weird interactions with device merging # logic. + zero = constant_op.constant(0, dtype=dtypes.int32) with ops.device(None): if all(tensor.shape == tensor_shape.scalar() for tensor in tensors): with ops.device(tensors[0].device): @@ -68,12 +71,11 @@ def _move_tensors(tensors, device): return array_ops.unstack(values) else: with ops.device(tensors[0].device): - sizes = array_ops.stack( - [array_ops.shape(tensor)[0] for tensor in tensors]) - values = array_ops.concat(tensors, axis=0) + sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0] + values = array_ops.concat(tensors, axis=zero) with ops.device(device): sizes = array_ops.unstack(sizes) - return list(array_ops.split(values, sizes, axis=0)) + return list(array_ops.split(values, sizes, axis=zero)) def _scheduled_stamp_resource_op_runner(batch, stamp): diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 1b184d296b329c..19b6b3296db394 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -187,7 +187,7 @@ def flush(self, stamp_token, next_stamp_token): stamp_token: Expected current token. 
next_stamp_token: Next value for the token. Returns: - A list of quantiles or approximate boundaries. + The flush operation. """ return gen_quantile_ops.quantile_accumulator_flush( quantile_accumulator_handle=self._quantile_accumulator_handle, @@ -201,3 +201,6 @@ def flush_summary(self, stamp_token, next_stamp_token): stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result + + def resource(self): + return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 08c1dcdd028829..5dd2e0c7f254f3 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -180,8 +180,7 @@ def extract_features(features, feature_columns, use_core_columns): elif isinstance(fc, feature_column_lib._EmbeddingColumn): # pylint: enable=protected-access transformed_features[fc.name] = fc_core.input_layer( - features, [fc], - weight_collections=[scope]) + features, [fc], weight_collections=[scope]) else: result = feature_column_ops.transform_features(features, [fc]) if len(result) > 1: @@ -334,10 +333,12 @@ def __init__(self, self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") @@ -354,9 +355,10 @@ def __init__(self, self._sparse_int_indices = sparse_int_indices self._sparse_int_values = sparse_int_values self._sparse_int_shapes = sparse_int_shapes - self._reduce_dim = (self._learner_config.multi_class_strategy == - learner_pb2.LearnerConfig.TREE_PER_CLASS and - learner_config.num_classes == 2) + self._reduce_dim = ( + self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + learner_config.num_classes == 2) def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. @@ -369,13 +371,13 @@ def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): Returns: a dictionary of prediction results - ENSEMBLE_STAMP, PREDICTION, PARTITION_IDS, - NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPED. + NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPTED. """ ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) num_handlers = ( - len(self._dense_floats) + len(self._sparse_float_shapes) + - len(self._sparse_int_shapes)) + len(self._dense_floats) + len(self._sparse_float_shapes) + len( + self._sparse_int_shapes)) # Used during feature selection. used_handlers = model_ops.tree_ensemble_used_handlers( ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) @@ -432,8 +434,9 @@ def predict(self, mode): # Use the current ensemble to predict on the current batch of input. # For faster prediction we check if the inputs are on the same device # as the model. If not, we create a copy of the model on the worker. 
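The comment above is about keeping prediction on the same device as its inputs. A tiny, hypothetical illustration of the colocation idea only, not the model-copying logic itself (the `features` tensor and the matmul are stand-ins):

```python
import tensorflow as tf

features = tf.constant([[1.0], [2.0]])

# Place the downstream op on whatever device produced the input, so the
# prediction path does not pay for a cross-device copy of the data.
with tf.device(features.device):
  logits = tf.matmul(features, tf.ones([1, 1]))
```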
- input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) if not input_deps: raise ValueError("No input tensors for prediction.") @@ -457,8 +460,8 @@ def predict(self, mode): # Determine whether the local ensemble is stale and update it if needed. def _refresh_local_ensemble_fn(): - # Serialize the model from parameter server after reading all inputs. - with ops.control_dependencies(input_deps): + # Serialize the model from parameter server after reading the inputs. + with ops.control_dependencies([input_deps[0]]): (ensemble_stamp, serialized_model) = ( model_ops.tree_ensemble_serialize(self._ensemble_handle)) @@ -500,8 +503,9 @@ def train(self, loss, predictions_dict, labels): ValueError: if inputs are not valid. """ # Get the worker device from input dependencies. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) worker_device = input_deps[0].device # Get tensors relevant for training and form the loss. @@ -517,7 +521,7 @@ def train(self, loss, predictions_dict, labels): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = -1 + class_id = constant_op.constant(-1, dtype=dtypes.int32) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. @@ -571,31 +575,39 @@ def train(self, loss, predictions_dict, labels): # Get the weights for each example for quantiles calculation, weights = self._get_weights(hessian_shape, squeezed_hessians) - regularization_config = self._learner_config.regularization - min_node_weight = self._learner_config.constraints.min_node_weight # Create all handlers ensuring resources are evenly allocated across PS. 
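The explicit constants built just below (float32 regularizers, an int32 strategy, and the int64 stamp tokens passed by callers) are needed because the specialized split builders are `function.Defun`s, which declare their input dtypes up front and do not coerce mismatched arguments. A small sketch of that constraint with a made-up function, assuming the same TF 1.x `function.Defun` decorator used elsewhere in this change:

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import function


@function.Defun(tf.int64, tf.float32, noinline=True)
def _toy_split_builder(stamp_token, l2_regularization):
  # Stand-in body; the real builders call the split-handler ops.
  return tf.cast(stamp_token, tf.float32) + l2_regularization

# A plain Python `1` would typically be converted to an int32 tensor and
# rejected by the dtype check, which is why the tests pass np.int64 stamp
# tokens and the handlers build explicit float32/int32 constants.
out = _toy_split_builder(np.int64(1), np.float32(0.5))

with tf.Session() as sess:
  print(sess.run(out))  # 1.5
```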
fc_name_idx = 0 handlers = [] init_stamp_token = constant_op.constant(0, dtype=dtypes.int64) + l1_regularization = constant_op.constant( + self._learner_config.regularization.l1, dtypes.float32) + l2_regularization = constant_op.constant( + self._learner_config.regularization.l2, dtypes.float32) + tree_complexity_regularization = constant_op.constant( + self._learner_config.regularization.tree_complexity, dtypes.float32) + min_node_weight = constant_op.constant( + self._learner_config.constraints.min_node_weight, dtypes.float32) + epsilon = 0.01 + num_quantiles = 100 + strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns for dense_float_column_idx in range(len(self._dense_floats)): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.DenseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=dense_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -604,14 +616,13 @@ def train(self, loss, predictions_dict, labels): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.SparseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, sparse_float_column=sparse_tensor.SparseTensor( self._sparse_float_indices[sparse_float_column_idx], self._sparse_float_values[sparse_float_column_idx], @@ -619,7 +630,7 @@ def train(self, loss, predictions_dict, labels): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -628,10 +639,9 @@ def train(self, loss, predictions_dict, labels): fc_name = self._fc_names[fc_name_idx] handlers.append( categorical_split_handler.EqualitySplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_int_column_idx, sparse_int_column=sparse_tensor.SparseTensor( @@ -641,7 +651,7 @@ def train(self, loss, predictions_dict, labels): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, 
init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -694,11 +704,11 @@ def train(self, loss, predictions_dict, labels): name="continue_centering", trainable=False) stats_update_ops.append( - control_flow_ops.cond(continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), - control_flow_ops.no_op)) + control_flow_ops.cond( + continue_centering, + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator), + control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -720,8 +730,8 @@ def train(self, loss, predictions_dict, labels): shape=[len(handlers)], seed=[seed + 1, 1]) active_handlers = array_ops.stack( [active_handlers_current_layer, active_handlers_next_layer], axis=1) - active_handlers = (active_handlers < - self._learner_config.feature_fraction_per_level) + active_handlers = ( + active_handlers < self._learner_config.feature_fraction_per_level) elif subsampling_type == "feature_fraction_per_tree": seed = predictions_dict[NUM_TREES_ATTEMPTED] active_handlers_current_layer = stateless.stateless_random_uniform( @@ -729,9 +739,12 @@ def train(self, loss, predictions_dict, labels): active_handlers_current_layer = ( active_handlers_current_layer < self._learner_config.feature_fraction_per_tree) - active_handlers = array_ops.stack([ - active_handlers_current_layer, - array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1) + active_handlers = array_ops.stack( + [ + active_handlers_current_layer, + array_ops.ones([len(handlers)], dtype=dtypes.bool) + ], + axis=1) else: active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) @@ -760,6 +773,7 @@ def _feature_selection_active_handlers(): empty_hessians = constant_op.constant( [], dtype=dtypes.float32, shape=empty_hess_shape) + active_handlers = array_ops.unstack(active_handlers, axis=0) for handler_idx in range(len(handlers)): handler = handlers[handler_idx] is_active = active_handlers[handler_idx] @@ -901,7 +915,6 @@ def _get_replica_device_setter(self, worker_device): "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - "QuantileStreamResourceHandleOp", ] ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( @@ -971,7 +984,7 @@ def _update_ensemble(): # This is a workaround for the slowness of graph building in tf.cond. # See (b/36554864). split_sizes = array_ops.reshape( - array_ops.shape_n(partition_ids_list), [-1]) + array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) partition_ids = array_ops.concat(partition_ids_list, axis=0) gains = array_ops.concat(gains_list, axis=0) split_infos = array_ops.concat(split_info_list, axis=0) @@ -1036,8 +1049,11 @@ def _grow_ensemble_fn(): # Update ensemble. update_ops = [are_all_splits_ready] - update_model = control_flow_ops.cond(continue_centering, _center_bias_fn, - _grow_ensemble_fn) + if self._center_bias: + update_model = control_flow_ops.cond(continue_centering, + _center_bias_fn, _grow_ensemble_fn) + else: + update_model = _grow_ensemble_fn() update_ops.append(update_model) # Update ensemble stats. 
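The last hunk above replaces an unconditional `control_flow_ops.cond` with a Python-level branch on `self._center_bias`: when bias centering is disabled at construction time, there is no need to build the centering branch at all, which avoids the graph-building cost that `tf.cond` pays for both branches. A schematic version of that pattern with stand-in branch functions (values and names are illustrative):

```python
import tensorflow as tf

center_bias = False  # known when the graph is built

continue_centering = tf.Variable(True, trainable=False)

def _center_bias_fn():
  return tf.constant(1)

def _grow_ensemble_fn():
  return tf.constant(2)

if center_bias:
  # Runtime decision: both branches are added to the graph, one runs per step.
  update_model = tf.cond(continue_centering, _center_bias_fn, _grow_ensemble_fn)
else:
  # Build-time decision: only the tree-growing branch is ever constructed.
  update_model = _grow_ensemble_fn()
```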
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index f9c22283b7f513..289fb195db109f 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -31,7 +31,6 @@ from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn - from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util @@ -97,8 +96,8 @@ def testExtractFeatures(self): array_ops.zeros([2], dtypes.int64)) features["sparse_int"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros([2], dtypes.int64), - array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.int64), array_ops.zeros([2], + dtypes.int64)) (fc_names, dense_floats, sparse_float_indices, sparse_float_values, sparse_float_shapes, sparse_int_indices, sparse_int_values, sparse_int_shapes) = ( @@ -139,8 +138,8 @@ def testExtractFeaturesWithTransformation(self): array_ops.zeros([2], dtypes.int64)) features["sparse_categorical"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros( - [2], dtypes.string), array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.string), array_ops.zeros([2], + dtypes.int64)) feature_columns = set() feature_columns.add(layers.real_valued_column("dense_float")) feature_columns.add( @@ -235,7 +234,8 @@ def testTrainFnChiefNoBiasCentering(self): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -316,6 +316,113 @@ def testTrainFnChiefNoBiasCentering(self): }""" self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefSparseAndDense(self): + """Tests the train function with sparse and dense features.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + features["sparse_float"] = sparse_tensor.SparseTensor( + array_ops.zeros([2, 2], dtypes.int64), + array_ops.zeros([2], dtypes.float32), + array_ops.constant([4, 1], dtypes.int64)) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": predictions, + 
"predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + sparse_float_binary_split_default_right { + split{ + left_id: 1 + right_id: 2 + } + } + node_metadata { + gain: 1.125 + } + } + nodes { + leaf { + vector { + value: 1.0 + } + } + } + nodes { + leaf { + vector { + value: -0.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefScalingNumberOfExamples(self): """Tests the train function running on chief without bias centering.""" with self.test_session() as sess: @@ -339,7 +446,8 @@ def testTrainFnChiefScalingNumberOfExamples(self): ensemble_handle=ensemble_handle, examples_per_layer=num_examples_fn, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -442,7 +550,8 @@ def testTrainFnChiefWithBiasCentering(self): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -513,7 +622,8 @@ def testTrainFnNonChiefNoBiasCentering(self): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -576,7 +686,8 @@ def testTrainFnNonChiefWithCentering(self): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -622,7 +733,8 @@ def testPredictFn(self): with self.test_session() as sess: # Create ensemble with one bias node. 
ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - text_format.Merge(""" + text_format.Merge( + """ trees { nodes { leaf { @@ -659,14 +771,15 @@ def testPredictFn(self): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) # Create predict op. mode = model_fn.ModeKeys.EVAL predictions_dict = sess.run(gbdt_model.predict(mode)) self.assertEquals(predictions_dict["ensemble_stamp"], 3) - self.assertAllClose(predictions_dict["predictions"], [[0.25], [0.25], - [0.25], [0.25]]) + self.assertAllClose(predictions_dict["predictions"], + [[0.25], [0.25], [0.25], [0.25]]) self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) def testTrainFnMulticlassFullHessian(self): @@ -698,7 +811,8 @@ def testTrainFnMulticlassFullHessian(self): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -801,7 +915,8 @@ def testTrainFnMulticlassDiagonalHessian(self): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -893,8 +1008,8 @@ def testTrainFnMulticlassTreePerClass(self): learner_config.constraints.max_tree_depth = 1 learner_config.constraints.min_node_weight = 0 features = { - "dense_float": array_ops.constant( - [[1.0], [1.5], [2.0]], dtypes.float32), + "dense_float": + array_ops.constant([[1.0], [1.5], [2.0]], dtypes.float32), } gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( @@ -904,7 +1019,8 @@ def testTrainFnMulticlassTreePerClass(self): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) batch_size = 3 predictions = array_ops.constant( @@ -986,7 +1102,8 @@ def testTrainFnMulticlassTreePerClass(self): self.assertAllClose( 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0], - atol=1e-4, rtol=1e-4) + atol=1e-4, + rtol=1e-4) def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self): """Tests the train function running on chief with feature selection.""" @@ -1230,9 +1347,9 @@ def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree = tree_ensemble_config.trees.add() - _set_float_split(tree.nodes.add() - .sparse_float_binary_split_default_right.split, 2, 4.0, - 1, 2) + _set_float_split( + tree.nodes.add().sparse_float_binary_split_default_right.split, 2, + 4.0, 1, 2) _append_to_leaf(tree.nodes.add().leaf, 0, 0.5) _append_to_leaf(tree.nodes.add().leaf, 1, 1.2) tree_ensemble_config.tree_weights.append(1.0) @@ -1241,7 +1358,8 @@ def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self): metadata.num_layers_grown = 1 tree_ensemble_config = tree_ensemble_config.SerializeToString() ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, tree_ensemble_config=tree_ensemble_config, + stamp_token=0, + tree_ensemble_config=tree_ensemble_config, name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 @@ -1333,5 +1451,6 @@ 
def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self): self.assertEquals(output.growing_metadata.num_layers_attempted, 2) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 1192cc44a17823..8ae493ba998bd8 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -14,19 +14,37 @@ # ============================================================================== """Tools for working with object-based checkpoints. - -For creating and managing dependencies: +Visualization and inspection: @@dot_graph_from_checkpoint +@@object_metadata + +Managing dependencies: +@@Checkpointable +@@CheckpointableObjectGraph +@@NoDependency @@split_dependency + +Checkpointable data structures: +@@List +@@Mapping +@@UniqueNameTracker """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint +from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.data_structures import List +from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) + diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index a5681ffa61d07e..7b200a29bf6008 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -8,8 +8,35 @@ py_library( name = "checkpoint", srcs_version = "PY2AND3", deps = [ + ":containers", ":split_dependency", ":visualize", + "//tensorflow/python/training/checkpointable:data_structures", + ], +) + +py_library( + name = "containers", + srcs = ["containers.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:data_structures", + ], +) + +py_test( + name = "containers_test", + srcs = ["containers_test.py"], + deps = [ + ":containers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", + "@six_archive//:six", ], ) @@ -21,6 +48,7 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", + "//tensorflow/python/training/checkpointable:base", ], ) @@ -32,8 +60,9 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -44,6 +73,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", + 
"//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -52,10 +83,13 @@ py_test( srcs = ["visualize_test.py"], deps = [ ":visualize", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + "//tensorflow/python/training/checkpointable:util", ], ) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py new file mode 100644 index 00000000000000..4d3d5312993740 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -0,0 +1,80 @@ +"""Checkpointable data structures.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import data_structures + + +class UniqueNameTracker(data_structures.CheckpointableDataStructure): + """Adds dependencies on checkpointable objects with name hints. + + Useful for creating dependencies with locally unique names. + + Example usage: + ```python + class SlotManager(tf.contrib.checkpoint.Checkpointable): + + def __init__(self): + # Create a dependency named "slotdeps" on the container. + self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tfe.Variable(4.), "y")) + slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + ``` + """ + + def __init__(self): + super(UniqueNameTracker, self).__init__() + self._maybe_initialize_checkpointable() + self._name_counts = {} + + def track(self, checkpointable, base_name): + """Add a dependency on `checkpointable`. + + Args: + checkpointable: An object to add a checkpoint dependency on. + base_name: A name hint, which is uniquified to determine the dependency + name. + Returns: + `checkpointable`, for chaining. + Raises: + ValueError: If `checkpointable` is not a checkpointable object. 
+ """ + + if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + raise ValueError( + ("Expected a checkpointable value, got %s which does not inherit " + "from CheckpointableBase.") % (checkpointable,)) + + def _format_name(prefix, number): + if number > 0: + return "%s_%d" % (prefix, number) + else: + return prefix + + count = self._name_counts.get(base_name, 0) + candidate = _format_name(base_name, count) + while self._lookup_dependency(candidate) is not None: + count += 1 + candidate = _format_name(base_name, count) + self._name_counts[base_name] = count + 1 + self._track_value(checkpointable, name=candidate) + return checkpointable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py new file mode 100644 index 00000000000000..3717d7f583ffdc --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -0,0 +1,108 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import six + +from tensorflow.contrib.checkpoint.python import containers +from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +class UniqueNameTrackerTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testNames(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + x1 = resource_variable_ops.ResourceVariable(2.) + x2 = resource_variable_ops.ResourceVariable(3.) + x3 = resource_variable_ops.ResourceVariable(4.) + y = resource_variable_ops.ResourceVariable(5.) + slots = containers.UniqueNameTracker() + slots.track(x1, "x") + slots.track(x2, "x") + slots.track(x3, "x_1") + slots.track(y, "y") + self.evaluate((x1.initializer, x2.initializer, x3.initializer, + y.initializer)) + save_root = checkpointable_utils.Checkpoint(slots=slots) + save_path = save_root.save(checkpoint_prefix) + + restore_slots = checkpointable.Checkpointable() + restore_root = checkpointable_utils.Checkpoint( + slots=restore_slots) + status = restore_root.restore(save_path) + restore_slots.x = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.y = resource_variable_ops.ResourceVariable(0.) 
+ status.assert_consumed().run_restore_ops() + self.assertEqual(2., self.evaluate(restore_slots.x)) + self.assertEqual(3., self.evaluate(restore_slots.x_1)) + self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) + self.assertEqual(5., self.evaluate(restore_slots.y)) + + @test_util.run_in_graph_and_eager_modes() + def testExample(self): + class SlotManager(checkpointable.Checkpointable): + + def __init__(self): + self.slotdeps = containers.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(3.), "x")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(4.), "y")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(5.), "x")) + self.slots = slots + + manager = SlotManager() + self.evaluate([v.initializer for v in manager.slots]) + checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpoint.save(checkpoint_prefix) + metadata = checkpointable_utils.object_metadata(save_path) + dependency_names = [] + for node in metadata.nodes: + for child in node.children: + dependency_names.append(child.local_name) + six.assertCountEqual( + self, + dependency_names, + ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + + @test_util.run_in_graph_and_eager_modes() + def testLayers(self): + tracker = containers.UniqueNameTracker() + tracker.track(layers.Dense(3), "dense") + tracker.layers[0](array_ops.zeros([1, 1])) + self.assertEqual(2, len(tracker.trainable_weights)) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/checkpoint/python/split_dependency.py b/tensorflow/contrib/checkpoint/python/split_dependency.py index 3aec8c96e90440..7e77453f3d848c 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -20,8 +20,8 @@ import functools from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import checkpointable as checkpointable from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import base as checkpointable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py index f1d9d19b047ee6..69dc0b9be2d554 100644 --- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -23,8 +23,8 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import checkpointable -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils def _split_variable_closure(variable): diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index 86fbdb41d2c378..bac071c4cff383 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -17,10 +17,9 @@ from __future__ import division from __future__ import print_function -from tensorflow.core.protobuf 
import checkpointable_object_graph_pb2 from tensorflow.python import pywrap_tensorflow -from tensorflow.python.framework import errors_impl -from tensorflow.python.training import checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils def dot_graph_from_checkpoint(save_path): @@ -52,20 +51,9 @@ def dot_graph_from_checkpoint(save_path): A graph in DOT format as a string. """ reader = pywrap_tensorflow.NewCheckpointReader(save_path) - try: - object_graph_string = reader.get_tensor( - checkpointable.OBJECT_GRAPH_PROTO_KEY) - except errors_impl.NotFoundError: - raise ValueError( - ('The specified checkpoint "%s" does not appear to be object-based (it ' - 'is missing the key "%s"). Likely it was created with a name-based ' - 'saver and does not contain an object dependency graph.') % ( - save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY)) + object_graph = checkpointable_utils.object_metadata(save_path) shape_map = reader.get_variable_to_shape_map() dtype_map = reader.get_variable_to_dtype_map() - object_graph = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - object_graph.ParseFromString(object_graph_string) graph = 'digraph {\n' def _escape(name): return name.replace('"', '\\"') diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py index 1d9ab789235cb9..583e3bc442893d 100644 --- a/tensorflow/contrib/checkpoint/python/visualize_test.py +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -24,11 +24,11 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.keras._impl.keras.engine import training -from tensorflow.python.keras._impl.keras.layers import core +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import adam -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import util as checkpointable_utils try: import pydot # pylint: disable=g-import-not-at-top diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index f3a75e8688ece1..42ba368531468b 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -15,7 +15,10 @@ load( ) tf_gen_op_libs( - op_lib_names = ["bigquery_reader_ops"], + op_lib_names = [ + "bigquery_reader_ops", + "gcs_config_ops", + ], deps = [ "//tensorflow/core:lib", ], @@ -28,15 +31,25 @@ tf_gen_op_wrapper_py( deps = [":bigquery_reader_ops_op_lib"], ) +tf_gen_op_wrapper_py( + name = "gen_gcs_config_ops", + out = "python/ops/gen_gcs_config_ops.py", + require_shape_functions = True, + visibility = ["//tensorflow:internal"], + deps = [":gcs_config_ops_op_lib"], +) + py_library( name = "cloud_py", srcs = [ "__init__.py", "python/ops/bigquery_reader_ops.py", + "python/ops/gcs_config_ops.py", ], srcs_version = "PY2AND3", deps = [ ":gen_bigquery_reader_ops", + ":gen_gcs_config_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index 8870264b95dfd9..a6e13ea3ae9384 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -20,9 
+20,15 @@ # pylint: disable=line-too-long,wildcard-import from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * +from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * # pylint: enable=line-too-long,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['BigQueryReader'] +_allowed_symbols = [ + 'BigQueryReader', + 'ConfigureColabSession', + 'ConfigureGcs', + 'ConfigureGcsHook', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index ff46f0daa80a70..40160706f70e8f 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -73,3 +73,17 @@ tf_proto_library( srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) + +tf_kernel_library( + name = "gcs_config_ops", + srcs = ["gcs_config_ops.cc"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform/cloud:curl_http_request", + "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/platform/cloud:oauth_client", + "@jsoncpp_git//:jsoncpp", + ], +) diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc new file mode 100644 index 00000000000000..648a219fb87a6e --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "include/json/json.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { + +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +// The minimum time delta between now and the token expiration time +// for the token to be re-used. +constexpr int kExpirationTimeMarginSec = 60; + +// The URL to retrieve the auth bearer token via OAuth with a refresh token. +constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; + +// The URL to retrieve the auth bearer token via OAuth with a private key. +constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; + +// The authentication token scope to request. 
+constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";
+
+Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) {
+  DCHECK(fs != nullptr);
+  *fs = nullptr;
+
+  FileSystem* filesystem = nullptr;
+  TF_RETURN_IF_ERROR(
+      ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem));
+  if (filesystem == nullptr) {
+    return errors::FailedPrecondition("The GCS file system is not registered.");
+  }
+
+  *fs = dynamic_cast<RetryingGcsFileSystem*>(filesystem);
+  if (*fs == nullptr) {
+    return errors::Internal(
+        "The filesystem registered under the 'gs://' scheme was not a "
+        "tensorflow::RetryingGcsFileSystem*.");
+  }
+  return Status::OK();
+}
+
+template <typename T>
+Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
+                           T* output) {
+  const Tensor* argument_t;
+  TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+  if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+    return errors::InvalidArgument(argument_name, " must be a scalar");
+  }
+  *output = argument_t->scalar<T>()();
+  return Status::OK();
+}
+
+// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem.
+class GcsCredentialsOpKernel : public OpKernel {
+ public:
+  explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) override {
+    // Get a handle to the GCS file system.
+    RetryingGcsFileSystem* gcs = nullptr;
+    OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
+
+    string json_string;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "json", &json_string));
+
+    Json::Value json;
+    Json::Reader reader;
+    std::stringstream json_stream(json_string);
+    OP_REQUIRES(ctx, reader.parse(json_stream, json),
+                errors::InvalidArgument("Could not parse json: ", json_string));
+
+    OP_REQUIRES(
+        ctx, json.isMember("refresh_token") || json.isMember("private_key"),
+        errors::InvalidArgument("JSON format incompatible; did not find fields "
+                                "`refresh_token` or `private_key`."));
+
+    auto provider =
+        tensorflow::MakeUnique<ConstantAuthProvider>(json, ctx->env());
+
+    // Test getting a token
+    string dummy_token;
+    OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token));
+    OP_REQUIRES(ctx, !dummy_token.empty(),
+                errors::InvalidArgument(
+                    "Could not retrieve a token with the given credentials."));
+
+    // Set the provider.
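+    // Note: the new provider takes effect for the whole process; every session
+    // running on this runtime will authenticate GCS requests with these
+    // credentials.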
+    gcs->underlying()->SetAuthProvider(std::move(provider));
+  }
+
+ private:
+  class ConstantAuthProvider : public AuthProvider {
+   public:
+    ConstantAuthProvider(const Json::Value& json,
+                         std::unique_ptr<OAuthClient> oauth_client, Env* env,
+                         int64 initial_retry_delay_usec)
+        : json_(json),
+          oauth_client_(std::move(oauth_client)),
+          env_(env),
+          initial_retry_delay_usec_(initial_retry_delay_usec) {}
+
+    ConstantAuthProvider(const Json::Value& json, Env* env)
+        : ConstantAuthProvider(json, tensorflow::MakeUnique<OAuthClient>(), env,
+                               kInitialRetryDelayUsec) {}
+
+    ~ConstantAuthProvider() override {}
+
+    Status GetToken(string* token) override {
+      mutex_lock l(mu_);
+      const uint64 now_sec = env_->NowSeconds();
+
+      if (!current_token_.empty() &&
+          now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
+        *token = current_token_;
+        return Status::OK();
+      }
+      if (json_.isMember("refresh_token")) {
+        TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
+            json_, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_));
+      } else if (json_.isMember("private_key")) {
+        TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
+            json_, kOAuthV4Url, kOAuthScope, &current_token_,
+            &expiration_timestamp_sec_));
+      } else {
+        return errors::FailedPrecondition(
+            "Unexpected content of the JSON credentials file.");
+      }
+
+      *token = current_token_;
+      return Status::OK();
+    }
+
+   private:
+    Json::Value json_;
+    std::unique_ptr<OAuthClient> oauth_client_;
+    Env* env_;
+
+    mutex mu_;
+    string current_token_ GUARDED_BY(mu_);
+    uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0;
+
+    // The initial delay for exponential backoffs when retrying failed calls.
+    const int64 initial_retry_delay_usec_;
+    TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider);
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU),
+                        GcsCredentialsOpKernel);
+
+class GcsBlockCacheOpKernel : public OpKernel {
+ public:
+  explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  void Compute(OpKernelContext* ctx) override {
+    // Get a handle to the GCS file system.
+    RetryingGcsFileSystem* gcs = nullptr;
+    OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
+
+    size_t max_cache_size, block_size, max_staleness;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<size_t>(ctx, "max_cache_size",
+                                                    &max_cache_size));
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<size_t>(ctx, "block_size", &block_size));
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<size_t>(ctx, "max_staleness", &max_staleness));
+
+    if (gcs->underlying()->block_size() == block_size &&
+        gcs->underlying()->max_bytes() == max_cache_size &&
+        gcs->underlying()->max_staleness() == max_staleness) {
+      LOG(INFO) << "Skipping resetting the GCS block cache.";
+      return;
+    }
+    gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size,
+                                           max_staleness);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU),
+                        GcsBlockCacheOpKernel);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
new file mode 100644
index 00000000000000..9cf85f5f1811d8
--- /dev/null
+++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GcsConfigureCredentials") + .Input("json: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Configures the credentials used by the GCS client of the local TF runtime. + +The json input can be of the format: + +1. Refresh Token: +{ + "client_id": "", + "client_secret": "", + "refresh_token: "", + "type": "authorized_user", +} + +2. Service Account: +{ + "type": "service_account", + "project_id": "", + "private_key_id": "", + "private_key": "------BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY------\n", + "client_email": "@.iam.gserviceaccount.com", + "client_id": "", + # Some additional fields elided +} + +Note the credentials established through this method are shared across all +sessions run on this runtime. + +Note be sure to feed the inputs to this op to ensure the credentials are not +stored in a constant op within the graph that might accidentally be checkpointed +or in other ways be persisted or exfiltrated. +)doc"); + +REGISTER_OP("GcsConfigureBlockCache") + .Input("max_cache_size: uint64") + .Input("block_size: uint64") + .Input("max_staleness: uint64") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Re-configures the GCS block cache with the new configuration values. + +If the values are the same as already configured values, this op is a no-op. If +they are different, the current contents of the block cache is dropped, and a +new block cache is created fresh. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py new file mode 100644 index 00000000000000..8c8c5acb31af69 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -0,0 +1,188 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""GCS file system configuration for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training + + +# @tf_export('contrib.cloud.BlockCacheParams') +class BlockCacheParams(object): + """BlockCacheParams is a struct used for configuring the GCS Block Cache.""" + + def __init__(self, block_size=None, max_bytes=None, max_staleness=None): + self._block_size = block_size or 128 * 1024 * 1024 + self._max_bytes = max_bytes or 2 * self._block_size + self._max_staleness = max_staleness or 0 + + @property + def block_size(self): + return self._block_size + + @property + def max_bytes(self): + return self._max_bytes + + @property + def max_staleness(self): + return self._max_staleness + + +# @tf_export('contrib.cloud.ConfigureGcsHook') +class ConfigureGcsHook(training.SessionRunHook): + """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Example: + + ``` + sess = tf.Session() + refresh_token = raw_input("Refresh token: ") + client_secret = raw_input("Client secret: ") + client_id = "" + creds = { + "client_id": client_id, + "refresh_token": refresh_token, + "client_secret": client_secret, + "type": "authorized_user", + } + tf.contrib.cloud.configure_gcs(sess, credentials=creds) + ``` + + """ + + def _verify_dictionary(self, creds_dict): + if 'refresh_token' in creds_dict or 'private_key' in creds_dict: + return True + return False + + def __init__(self, credentials=None, block_cache=None): + """Constructs a ConfigureGcsHook. + + Args: + credentials: A json-formatted string. + block_cache: A `BlockCacheParams` + + Raises: + ValueError: If credentials is improperly formatted or block_cache is not a + BlockCacheParams. 
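+
+    Example (a minimal sketch; `estimator`, `input_fn`, and the credentials
+    dict `creds` are assumed to exist, with `creds` holding a `refresh_token`
+    or `private_key` field):
+
+    ```python
+    hook = ConfigureGcsHook(credentials=creds, block_cache=BlockCacheParams())
+    estimator.train(input_fn=input_fn, hooks=[hook])
+    ```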
+ """ + if credentials is not None: + if isinstance(credentials, str): + try: + data = json.loads(credentials) + except ValueError as e: + raise ValueError('credentials was not a well formed JSON string.', e) + if not self._verify_dictionary(data): + raise ValueError( + 'credentials has neither a "refresh_token" nor a "private_key" ' + 'field.') + elif isinstance(credentials, dict): + if not self._verify_dictionary(credentials): + raise ValueError('credentials has neither a "refresh_token" nor a ' + '"private_key" field.') + credentials = json.dumps(credentials) + else: + raise ValueError('credentials is of an unknown type') + + self._credentials = credentials + + if block_cache and not isinstance(block_cache, BlockCacheParams): + raise ValueError('block_cache must be an instance of BlockCacheParams.') + self._block_cache = block_cache + + def begin(self): + if self._credentials: + self._credentials_placeholder = array_ops.placeholder(dtypes.string) + self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_placeholder) + if self._block_cache: + self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=self._block_cache.max_bytes, + block_size=self._block_cache.block_size, + max_staleness=self._block_cache.max_staleness) + + def after_create_session(self, session, coord): + del coord + if self._credentials_op: + session.run( + self._credentials_op, + feed_dict={self._credentials_placeholder: self._credentials}) + if self._block_cache_op: + session.run(self._block_cache_op) + + +def configure_gcs(session, credentials=None, block_cache=None, device=None): + """Configures the GCS file system for a given a session. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Args: + session: A `tf.Session` session that should be used to configure the GCS + file system. + credentials: [Optional.] A JSON string + block_cache: [Optional.] A BlockCacheParams to configure the block cache . + device: [Optional.] The device to place the configure ops. + """ + + def configure(credentials, block_cache): + """Helper function to actually configure GCS.""" + if credentials: + if isinstance(credentials, dict): + credentials = json.dumps(credentials) + placeholder = array_ops.placeholder(dtypes.string) + op = gen_gcs_config_ops.gcs_configure_credentials(placeholder) + session.run(op, feed_dict={placeholder: credentials}) + if block_cache: + op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=block_cache.max_bytes, + block_size=block_cache.block_size, + max_staleness=block_cache.max_staleness) + session.run(op) + + if device: + with ops.device(device): + return configure(credentials, block_cache) + return configure(credentials, block_cache) + + +def configure_colab_session(session): + """ConfigureColabSession configures the GCS file system in Colab. + + Args: + session: A `tf.Session` session. + """ + # Read from the application default credentials (adc). 
+ with open('/content/datalab/adc.json') as f: + data = json.load(f) + configure_gcs(session, credentials=data) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1403483d287041..a5a9630a4aa382 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,6 +36,8 @@ _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_DEFAULT_ENV_VARIABLE = 'TPU_NAME' +_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' class TPUClusterResolver(ClusterResolver): @@ -70,6 +72,16 @@ def _inGke(): def _gkeMaster(): return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + @staticmethod + def _envVarFallback(): + if _DEFAULT_ENV_VARIABLE in os.environ: + return os.environ[_DEFAULT_ENV_VARIABLE] + return None + + @staticmethod + def _discoveryUrl(): + return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) + def __init__(self, tpu=None, zone=None, @@ -78,7 +90,8 @@ def __init__(self, coordinator_name=None, coordinator_address=None, credentials='default', - service=None): + service=None, + discovery_url=None): """Creates a new TPUClusterResolver object. The ClusterResolver will then use the parameters to query the Cloud TPU APIs @@ -108,6 +121,11 @@ def __init__(self, service: The GCE API object returned by the googleapiclient.discovery function. If you specify a custom service object, then the credentials parameter will be ignored. + discovery_url: A URL template that points to the location of + the discovery service. It should have two parameters {api} and + {apiVersion} that when filled in produce an absolute URL to the + discovery document for that service. The environment variable + 'TPU_API_DISCOVERY_URL' will override this. Raises: ImportError: If the googleapiclient is not installed. @@ -123,8 +141,11 @@ def __init__(self, in_gke = self._inGke() # When using GKE with Cloud TPUs, the env variable will be set. - if tpu is None and in_gke: - tpu = self._gkeMaster() + if tpu is None: + if in_gke: + tpu = self._gkeMaster() + else: + tpu = self._envVarFallback() self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name @@ -149,14 +170,22 @@ def __init__(self, if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver. Execute: `pip install ' - '--upgrade google-api-python-client` to install with ' - 'pip.') - - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. 
Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') + + final_discovery_url = self._discoveryUrl() or discovery_url + if final_discovery_url: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials, + discoveryServiceUrl=final_discovery_url) + else: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5b3f9be5a11237..5fac55fd027fa2 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -367,6 +367,10 @@ def testGkeEnvironment(self): compat.as_bytes(TPUClusterResolver._gkeMaster())) del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + def testDiscoveryUrl(self): + os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' + self.assertEqual('https://{api}.internal/{apiVersion}', + TPUClusterResolver._discoveryUrl()) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 8f6b412955a480..0676f6c665fa11 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -84,7 +84,7 @@ if (NOT WIN32) option(systemlib_ALL "Turn on every possible systemlib_* options" OFF) if (systemlib_ALL) - set (systmelib_ZLIB ON) + set (systemlib_ZLIB ON) endif (systemlib_ALL) endif() @@ -172,19 +172,20 @@ if (tensorflow_OPTIMIZE_FOR_NATIVE_ARCH) endif() endif() +include(CheckCXXCompilerFlag) + +# OpenMP Support +CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) +if (GCC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +endif() +CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) +if (MSVC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") +endif() + # MSVC SIMD instructions if (tensorflow_WIN_CPU_SIMD_OPTIONS) - include(CheckCXXCompilerFlag) - if (tensorflow_ENABLE_MKL_SUPPORT) - add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) - if (NOT tensorflow_ENABLE_MKLDNN_SUPPORT) - add_definitions(-DINTEL_MKL_ML) - endif() - endif() - CHECK_CXX_COMPILER_FLAG("-fopenmp" COMPILER_OPT_OPENMP_SUPPORT) - if (COMPILER_OPT_OPENMP_SUPPORT) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") - endif() if (WIN32) CHECK_CXX_COMPILER_FLAG(${tensorflow_WIN_CPU_SIMD_OPTIONS} COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) if(COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) @@ -323,10 +324,13 @@ if(HAIKU) list(APPEND tensorflow_EXTERNAL_LIBRARIES network) endif() +# MKL Support if (tensorflow_ENABLE_MKL_SUPPORT) + add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) if (WIN32) find_path(MKL_HOME_PLATFORM mkl PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ + $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ PATH_SUFFIXES windows) set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) set(MKL_LINK_DIRS @@ -345,6 +349,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT) # Fix me: complete the path on linux find_path(MKL_HOME_PLATFORM mkl HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ + $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ PATH_SUFFIXES linux) set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) set(MKL_LINK_DIRS) # incompleted @@ -357,6 +362,8 @@ if (tensorflow_ENABLE_MKL_SUPPORT) 
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) include_directories(${mkldnn_INCLUDE_DIRS}) + else (tensorflow_ENABLE_MKLDNN_SUPPORT) + add_definitions(-DINTEL_MKL_ML) endif() endif (tensorflow_ENABLE_MKL_SUPPORT) diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index 116d42309394b9..8942f3eecf07ff 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -31,7 +31,8 @@ else (systemlib_ZLIB) set(ZLIB_URL https://github.com/madler/zlib) set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib) set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install) - set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d) + # Match zlib version in tensorflow/workspace.bzl + set(ZLIB_TAG v1.2.11) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 6468bed4979253..015cb73bbd93bb 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -32,52 +32,13 @@ tensorflow/python/feature_column tensorflow/python/framework tensorflow/python/grappler tensorflow/python/keras -tensorflow/python/keras/activations tensorflow/python/keras/applications -tensorflow/python/keras/applications/densenet -tensorflow/python/keras/applications/inception_resnet_v2 -tensorflow/python/keras/applications/inception_v3 -tensorflow/python/keras/applications/mobilenet -tensorflow/python/keras/applications/nasnet -tensorflow/python/keras/applications/resnet50 -tensorflow/python/keras/applications/vgg16 -tensorflow/python/keras/applications/vgg19 -tensorflow/python/keras/applications/xception -tensorflow/python/keras/backend -tensorflow/python/keras/callbacks -tensorflow/python/keras/constraints tensorflow/python/keras/datasets -tensorflow/python/keras/datasets/boston_housing -tensorflow/python/keras/datasets/cifar10 -tensorflow/python/keras/datasets/cifar100 -tensorflow/python/keras/datasets/fashion_mnist -tensorflow/python/keras/datasets/imdb -tensorflow/python/keras/datasets/mnist -tensorflow/python/keras/datasets/reuters -tensorflow/python/keras/estimator -tensorflow/python/keras/initializers +tensorflow/python/keras/engine tensorflow/python/keras/layers -tensorflow/python/keras/losses -tensorflow/python/keras/metrics -tensorflow/python/keras/models -tensorflow/python/keras/optimizers tensorflow/python/keras/preprocessing -tensorflow/python/keras/preprocessing/image -tensorflow/python/keras/preprocessing/sequence -tensorflow/python/keras/preprocessing/text -tensorflow/python/keras/regularizers tensorflow/python/keras/utils tensorflow/python/keras/wrappers -tensorflow/python/keras/wrappers/scikit_learn -tensorflow/python/keras/_impl -tensorflow/python/keras/_impl/keras -tensorflow/python/keras/_impl/keras/applications -tensorflow/python/keras/_impl/keras/datasets -tensorflow/python/keras/_impl/keras/engine -tensorflow/python/keras/_impl/keras/layers -tensorflow/python/keras/_impl/keras/preprocessing -tensorflow/python/keras/_impl/keras/utils -tensorflow/python/keras/_impl/keras/wrappers tensorflow/python/kernel_tests tensorflow/python/kernel_tests/boosted_trees tensorflow/python/kernel_tests/distributions @@ -100,6 +61,7 @@ tensorflow/python/summary tensorflow/python/summary/writer tensorflow/python/tools tensorflow/python/training +tensorflow/python/training/checkpointable 
tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf @@ -153,6 +115,8 @@ tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization tensorflow/contrib/constrained_optimization/python +tensorflow/contrib/control_flow +tensorflow/contrib/control_flow/python tensorflow/contrib/copy_graph tensorflow/contrib/copy_graph/python tensorflow/contrib/copy_graph/python/util @@ -333,6 +297,8 @@ tensorflow/contrib/metrics tensorflow/contrib/metrics/python tensorflow/contrib/metrics/python/metrics tensorflow/contrib/metrics/python/ops +tensorflow/contrib/mixed_precision +tensorflow/contrib/mixed_precision/python tensorflow/contrib/mpi_collectives/python tensorflow/contrib/mpi_collectives/python/ops tensorflow/contrib/model_pruning diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index d63c41db844af2..cf1ee2ad76f2cc 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -11,7 +11,6 @@ tensorflow/contrib/mpi tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto -tensorflow/contrib/tensorboard/graph_explorer/proto tensorflow/contrib/tensorboard/plugins/projector tensorflow/contrib/tensorboard/plugins/trace tensorflow/contrib/tpu/proto diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c6a15f2ca075c8..2e0a2fcef4cbdc 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,9 +21,8 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/eager/c_api_debug.cc" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h" "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc" @@ -38,13 +37,15 @@ add_dependencies( tf_core_lib tf_protos_cc) -add_library(tf_c_python_api OBJECT - "${tensorflow_source_dir}/tensorflow/c/python_api.cc" - "${tensorflow_source_dir}/tensorflow/c/python_api.h" -) -add_dependencies( - tf_c_python_api - tf_c - tf_core_lib - tf_core_framework - tf_protos_cc) +if(tensorflow_BUILD_PYTHON_BINDINGS) + add_library(tf_c_python_api OBJECT + "${tensorflow_source_dir}/tensorflow/c/python_api.cc" + "${tensorflow_source_dir}/tensorflow/c/python_api.h" + ) + add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_core_framework + tf_protos_cc) +endif() diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index f73da0b8ab18af..6c90cf398c69c8 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -155,7 +155,7 @@ if (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") endif() else (WIN32) - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so") + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX}") endif (WIN32) add_custom_target(tf_extension_ops) diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake 
index b47c32f1c48b3d..dac84ccb0dbf48 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -213,10 +213,6 @@ else() list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() -file(GLOB tf_core_platform_exclude_srcs - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc") -list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_exclude_srcs}) - list(APPEND tf_core_lib_srcs ${tf_core_platform_srcs}) if(UNIX) @@ -286,8 +282,6 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/framework/*.h" "${tensorflow_source_dir}/tensorflow/core/framework/*.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.h" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index f38c9e05135f9f..2d76bf530a2100 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -68,6 +68,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index e558691de4b749..bc753333dba4f6 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -113,6 +113,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(gcs_config "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/gcs_config_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index c4bdb69d828b26..1959ad028a06f3 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -244,13 +244,11 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD # tf_python_op_gen_main library ######################################################## set(tf_python_op_gen_main_srcs - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - 
"${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" ) add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs}) @@ -422,6 +420,8 @@ GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" @@ -464,12 +464,12 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" @@ -715,7 +715,7 @@ if(WIN32) endif() else() add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal${CMAKE_SHARED_LIBRARY_SUFFIX} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so) endif() @@ -725,7 +725,7 @@ endif() ######################################################## # Parse tensorflow/tools/api/generator/BUILD to get list of generated files. 
-FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text) +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -736,7 +736,7 @@ foreach(api_init_file ${api_init_files_list}) string(STRIP "${api_init_file}" api_init_file) if(api_init_file) string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes - list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}") + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/${api_init_file}") endif() endforeach(api_init_file) set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") @@ -749,18 +749,14 @@ add_custom_command( # tensorflow/__init__.py depends on files generated in this step. So, remove it while # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}" - - # Re-add tensorflow/__init__.py back. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "${api_init_list_file}" COMMENT "Generating __init__.py files for Python API." WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" @@ -791,7 +787,6 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/testing/python/framework/util_test.py ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/) - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tools/pip_package/README ${CMAKE_CURRENT_BINARY_DIR}/tf_python/) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 92f2ab6dea8e7d..eb9482dc25f2be 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -212,6 +212,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" # Disable following manual tag in BUILD. 
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + # These tests depend on a .so file + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py ) if (WIN32) @@ -267,6 +271,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py" + # Flaky on Windows cpu with py36 (b/73556968) + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/sparse_reshape_op_test.py" # Windows file management related issues. "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # training tests diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index cffe069aa352f8..4f957f1e0b46fd 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -44,7 +44,8 @@ DUMPBIN = "dumpbin.exe" # Exclude if matched -EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::") +EXCLUDE_RE = re.compile(r"RTTI|deleting destructor|::internal::|Internal|" + r"python_op_gen_internal|grappler") # Include if matched before exclude INCLUDEPRE_RE = re.compile(r"google::protobuf::internal::ExplicitlyConstructed|" @@ -56,6 +57,10 @@ r"tensorflow::ops::internal::Enter|" r"tensorflow::strings::internal::AppendPieces|" r"tensorflow::strings::internal::CatPieces|" + r"tensorflow::errors::Internal|" + r"tensorflow::Tensor::CopyFromInternal|" + r"tensorflow::kernel_factory::" + r"OpKernelRegistrar::InitInternal|" r"tensorflow::io::internal::JoinPathImpl") # Include if matched after exclude @@ -64,7 +69,7 @@ r"tensorflow::|" r"functor::|" r"\?nsync_|" - r"perftools::gputools") + r"stream_executor::") # We want to identify data members explicitly in the DEF file, so that no one # can implicitly link against the DLL if they use one of the variables exported diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc index ae4d9d2836a0f8..81b36ca902b822 100644 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py index f039cb0f5265b9..0fbe3081af0b4d 100644 --- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py @@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import engine +from tensorflow.python.keras import engine from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import init_ops diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 29a593f6bcfa05..a56a01b16356e1 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -24,7 +24,6 @@ from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -170,12 +169,11 @@ def mulop(x1, x2): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) -@test_util.with_c_api class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): with self.test_session(): - x = constant_op.constant([[3]]) + x = constant_op.constant([[3.]]) y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): y_c = math_ops.matmul(y_nc, y_nc, name="compiled") @@ -200,11 +198,11 @@ def testCompilationGradientScopeNames(self): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(): # XlaScope 0 - a1 = constant_op.constant([[1]]) + a1 = constant_op.constant([[1.]]) a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(): # XlaScope 1 - a2 = constant_op.constant([[1]]) + a2 = constant_op.constant([[1.]]) a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) @@ -222,11 +220,11 @@ def testCompilationSeparateGradientScopeNames(self): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 0 - a1 = constant_op.constant([[1]]) + a1 = constant_op.constant([[1.]]) a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 1 - a2 = constant_op.constant([[1]]) + a2 = constant_op.constant([[1.]]) a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) diff --git a/tensorflow/contrib/control_flow/BUILD b/tensorflow/contrib/control_flow/BUILD new file mode 100644 index 00000000000000..746b5b5b5e2fa2 --- /dev/null +++ b/tensorflow/contrib/control_flow/BUILD @@ -0,0 +1,48 @@ +# New implementations of control flow ops + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + 
+py_library( + name = "control_flow", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":cond_v2", + ], +) + +py_library( + name = "cond_v2", + srcs = ["python/cond_v2.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:c_api_util", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops_gen", + "//tensorflow/python:gradients", + "//tensorflow/python:pywrap_tensorflow", + ], +) + +tf_py_test( + name = "cond_v2_test", + size = "small", + srcs = ["python/cond_v2_test.py"], + additional_deps = [ + ":cond_v2", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:gradients", + ], + grpc_enabled = True, +) diff --git a/tensorflow/python/keras/_impl/keras/wrappers/__init__.py b/tensorflow/contrib/control_flow/__init__.py similarity index 68% rename from tensorflow/python/keras/_impl/keras/wrappers/__init__.py rename to tensorflow/contrib/control_flow/__init__.py index 20c95929e3d2e1..582af2cf10a3d9 100644 --- a/tensorflow/python/keras/_impl/keras/wrappers/__init__.py +++ b/tensorflow/contrib/control_flow/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,11 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras API wrappers. + +"""New implementations of TF control flow ops. + +@@cond_v2 """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.wrappers import scikit_learn +# pylint: disable=unused-import +from tensorflow.contrib.control_flow.python.cond_v2 import cond_v2 +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py new file mode 100644 index 00000000000000..90c678d0f6bd21 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2.py @@ -0,0 +1,394 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""cond_v2 and gradient. + +This is a version of cond that emits a single If op, as well as the gradient +function for If ops produced by cond_v2. This will eventually replace the +current tf.cond implementation once it reaches feature and performance parity. 
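+
+Example (a minimal sketch; assumes graph mode and that `tf` is the standard
+TensorFlow module, mirroring tf.cond usage):
+
+```python
+pred = tf.placeholder(tf.bool, shape=[])
+x = tf.constant(2.0)
+# Emits one If op rather than the Switch/Merge nodes built by tf.cond.
+outputs = cond_v2(pred, lambda: x * 2.0, lambda: x * 3.0)
+with tf.Session() as sess:
+  print(sess.run(outputs, feed_dict={pred: True}))  # [4.0]
+```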
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.ops import gradients_impl + + +# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify +# that they aren't part of the official public API. These protected members +# often need to be used by implementation code however. Rather than litter the +# code with pylint comments, we ignore protected access violations for +# readability. +# pylint: disable=protected-access + + +def cond_v2(pred, true_fn, false_fn, name="cond"): + """Like tf.cond, except emits a single If op.""" + with ops.name_scope(name) as scope: + true_graph = function.func_graph_from_py_func(true_fn, [], [], + name="%s_true" % scope) + false_graph = function.func_graph_from_py_func(false_fn, [], [], + name="%s_false" % scope) + _check_same_outputs(true_graph, false_graph) + + # Add inputs to true_graph and false_graph to make them match. Note that + # this modifies true_graph and false_graph. + cond_inputs = _make_inputs_match(true_graph, false_graph, + true_graph.extra_inputs, + false_graph.extra_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # the gradient computation. + + true_intermediates = _get_intermediates(true_graph) + false_intermediates = _get_intermediates(false_graph) + + # Save the original number of outputs to return to the caller. + num_cond_outputs = len(true_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_outputs, extra_false_outputs = _pad_params( + true_graph, false_graph, true_intermediates, false_intermediates) + + true_graph.outputs.extend(extra_true_outputs) + false_graph.outputs.extend(extra_false_outputs) + + # Create the If op. + tensors = gen_functional_ops._if( + pred, cond_inputs, [t.dtype for t in true_graph.outputs], + _create_new_tf_function(true_graph), + _create_new_tf_function(false_graph), + name=scope) + + # TODO(b/79883549): if we could make Graphs from FunctionDefs, we wouldn't + # need this extra state. Requiring extra state also prevents the ability to + # take the gradient of deserialized If ops. + tensors[0].op._true_graph = true_graph + tensors[0].op._false_graph = false_graph + + return tensors[:num_cond_outputs] + + +@ops.RegisterGradient("If") +def _IfGrad(op, *grads): # pylint: disable=invalid-name + """The gradient of an If op produced by cond_v2.""" + true_graph = op._true_graph + false_graph = op._false_graph + + # Create grad functions that compute the gradient of the true/false forward + # graphs. These functions will capture tensors from the forward pass + # functions. + true_grad_graph = _create_grad_func( + true_graph, grads, "%sgrad" % true_graph.name) + false_grad_graph = _create_grad_func( + false_graph, grads, "%sgrad" % false_graph.name) + + assert ([t.dtype for t in true_grad_graph.outputs] == + [t.dtype for t in false_grad_graph.outputs]) + + # Match up the captured grad function inputs with outputs of 'op' and other + # external tensors. 
+ true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph) + false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph) + + # Make the inputs to true_grad_graph and false_grad_graph match. Note that + # this modifies true_grad_graph and false_grad_graph. + grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, + true_grad_inputs, false_grad_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # higher-order gradient computations. + + true_grad_intermediates = _get_intermediates(true_grad_graph) + false_grad_intermediates = _get_intermediates(false_grad_graph) + + # Save the original number of gradient outputs to return. + num_grad_outputs = len(true_grad_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_grad_outputs, extra_false_grad_outputs = _pad_params( + true_grad_graph, false_grad_graph, + true_grad_intermediates, false_grad_intermediates) + + true_grad_graph.outputs.extend(extra_true_grad_outputs) + false_grad_graph.outputs.extend(extra_false_grad_outputs) + + # Create the gradient If op. + tensors = gen_functional_ops._if( + op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], + _create_new_tf_function(true_grad_graph), + _create_new_tf_function(false_grad_graph)) + tensors[0].op._true_graph = true_grad_graph + tensors[0].op._false_graph = false_grad_graph + + # The predicate has no gradient. + return [None] + tensors[:num_grad_outputs] + + +def _grad_fn(func_graph, grads): + """The gradient function for each conditional branch. + + This function builds the gradient graph of the corresponding forward-pass + conditional branch in `func_graph`. This is done by differentiating + func_graph's outputs w.r.t. its inputs. + + Args: + func_graph: function._FuncGraph. The corresponding forward-pass function. + grads: The list of input gradient Tensors. + + Returns: + The output gradient Tensors. + """ + # Filter out untrainable function outputs. + # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes + # cause _GradientsHelper to raise an exception (e.g. the implementation + # doesn't expect 'ys' to contain boolean tensors). + assert len(func_graph.outputs) == len(grads) + ys = [] + grad_ys = [] + for y, grad_y in zip(func_graph.outputs, grads): + if not gradients_impl._IsTrainable(y): + continue + ys.append(y) + grad_ys.append(grad_y) + + # Build the gradient graph. Note that this builds the gradient computation of + # func_graph in the current graph, which requires capturing tensors from + # func_graph. The captured func_graph tensors are resolved to external tensors + # in _get_grad_inputs. + result = gradients_impl._GradientsHelper( + ys, func_graph.inputs, grad_ys=grad_ys, + src_graph=func_graph) + + # Functions can't return None; replace Nones with zero tensors. + # TODO(b/80444525): don't return anything here and make _IfGrad return None if + # both branches have zero gradient. + for i in range(len(result)): + if result[i] is None: + result[i] = array_ops.zeros_like(func_graph.inputs[i]) + + return result + + +def _create_grad_func(func_graph, grads, name): + """Returns the _FuncGraph representation of _grad_fn.""" + return function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), + [], [], name) + + +def _get_grad_inputs(if_op, cond_graph, grad_graph): + """Returns the tensors we should pass to grad_graph. + + This method handles tensors captured from cond_graph in grad_graph. 
It + converts these to suitable input tensors from the outer graph. + + Args: + if_op: Operation. The forward-pass If op that uses cond_graph. + cond_graph: function._FuncGraph. The forward-pass function. + grad_graph: function._FuncGraph. The gradients function. + + Returns: + A list of inputs tensors to be passed to grad_graph. + """ + inputs = [] + + # Maps placeholders in cond_graph -> input tensor in outer graph. + forward_input_map = {v: k for k, v in cond_graph._captured.items()} + + for t in grad_graph.extra_inputs: + if t.graph == ops.get_default_graph(): + # t is in the outer graph (e.g. one of the input gradients). + inputs.append(t) + elif t in forward_input_map: + # t is an input placeholder in cond_graph. Get the corresponding input + # tensor in the outer graph. + assert t.graph == cond_graph + assert forward_input_map[t].graph == ops.get_default_graph() + inputs.append(forward_input_map[t]) + else: + # t is an intermediate value in cond_graph. Get the corresponding output + # of 'if_op' (note that all intermediate values are outputs). + assert t.graph == cond_graph + output_idx = cond_graph.outputs.index(t) + inputs.append(if_op.outputs[output_idx]) + + return inputs + + +def _create_new_tf_function(func_graph): + """Converts func_graph to a TF_Function and adds it to the current graph. + + Args: + func_graph: function._FuncGraph + + Returns: + The name of the new TF_Function. + """ + func_graph.name = "%s_" % func_graph.name + c_func = c_api.TF_GraphToFunction_wrapper( + func_graph._c_graph, + func_graph.name, + False, # append_hash_to_fn_name + None, # opers + [t._as_tf_output() for t in func_graph.inputs], + [t._as_tf_output() for t in func_graph.outputs], + [], + None, # opts + None) # description + c_func = c_api_util.ScopedTFFunction(c_func) + c_api.TF_GraphCopyFunction( + ops.get_default_graph()._c_graph, c_func.func, None) + return func_graph.name + + +def _get_intermediates(func_graph): + """Returns all tensors in `func_graph` that aren't inputs or outputs.""" + intermediates = [] + for op in func_graph.get_operations(): + for t in op.outputs: + if t in func_graph.inputs: continue + if t in func_graph.outputs: continue + intermediates.append(t) + return intermediates + + +def _separate_unique_inputs(true_inputs, false_inputs): + """Separates tensors appearing only in true_inputs or false_inputs, or both. + + Args: + true_inputs: list of Tensors + false_inputs: list of Tensors + + Returns: + Three lists of Tensors: + 1. The tensors that appear in both true_inputs and false_inputs + 2. The tensors that only appear in true_inputs + 3. The tensors that only appear in false_inputs + """ + true_inputs = set(true_inputs) + false_inputs = set(false_inputs) + + shared_inputs = true_inputs.intersection(false_inputs) + true_only_inputs = true_inputs - false_inputs + false_only_inputs = false_inputs - true_inputs + + return list(shared_inputs), list(true_only_inputs), list(false_only_inputs) + + +def _pad_params(true_graph, false_graph, true_params, false_params): + """Returns new param lists that have matching signatures. + + This is done by mirroring each param list in the other using dummy params. + There is no merging of params. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_params: a list of Tensors from true_graph + false_params: a list of Tensors from false_graph + + Returns: + A new list of Tensors in true_graph and a new list of Tensors in + false_graph. 
The two lists have the same number of Tensors, with matching + types and shapes across the lists. + """ + new_true_params = (true_params + + _create_dummy_params(true_graph, false_params)) + new_false_inputs = (_create_dummy_params(false_graph, true_params) + + false_params) + return new_true_params, new_false_inputs + + +def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): + """Modifies true_graph and false_graph so they have the same input signature. + + This method reorders and/or adds parameters to true_graph and false_graph so + they have the same input signature, and updates the 'inputs', 'extra_inputs', + and '_captured' fields of both graphs accordingly. It uses the input tensors + from the outer graph to avoid duplicating shared arguments. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_inputs: a list of Tensors in the outer graph. The inputs for + true_graph. + false_inputs: a list of Tensors in the outer graph. The inputs for + false_graph. + + Returns: + A new list of Tensors from the outer graph that are the new inputs for both + true_graph and false_graph. This is a deduped version of true_inputs + + false_inputs. + """ + shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( + true_inputs, false_inputs) + + new_inputs = shared_inputs + true_only_inputs + false_only_inputs + + true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) + false_input_to_param = dict(zip(false_inputs, false_graph.inputs)) + + true_graph.inputs = ( + [true_input_to_param[t] for t in shared_inputs] + + [true_input_to_param[t] for t in true_only_inputs] + + _create_dummy_params(true_graph, false_only_inputs)) + + false_graph.inputs = ( + [false_input_to_param[t] for t in shared_inputs] + + _create_dummy_params(false_graph, true_only_inputs) + + [false_input_to_param[t] for t in false_only_inputs]) + + # Rewrite the _FuncGraphs' state to reflect the new inputs. + true_graph.extra_inputs = new_inputs + false_graph.extra_inputs = new_inputs + + true_graph._captured = dict(zip(new_inputs, true_graph.inputs)) + false_graph._captured = dict(zip(new_inputs, false_graph.inputs)) + + return new_inputs + + +def _create_dummy_params(func_graph, template_tensors): + """Creates tensors in func_graph to represent template_tensors. + + Args: + func_graph: function._FuncGraph. + template_tensors: a list of tensors in the outer graph. + + Returns: + A list of tensors in func_graph. + """ + with func_graph.as_default(): + return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape) + for t in template_tensors] + + +def _check_same_outputs(true_graph, false_graph): + """Raises an error if true_graph and false_graph have different outputs.""" + true_output_types = [t.dtype for t in true_graph.outputs] + false_output_types = [t.dtype for t in false_graph.outputs] + if (len(true_graph.outputs) != len(false_graph.outputs) or + true_output_types != false_output_types): + raise ValueError( + "true_fn() and false_fn() must return the same number and type of " + "arguments, got:\n" + " true_fn: %s\n" + " false_fn: %s" % (true_output_types, false_output_types)) diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py new file mode 100644 index 00000000000000..166002ca7faa41 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py @@ -0,0 +1,114 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for cond_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.control_flow.python import cond_v2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class NewCondTest(test.TestCase): + + def _testCond(self, true_fn, false_fn, train_vals): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") + actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") + + expected_grad = gradients_impl.gradients(expected, train_vals) + actual_grad = gradients_impl.gradients(actual, train_vals) + + with self.test_session() as sess: + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: True}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: False}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + def testBasic(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * 2.0 + + def false_fn(): + return y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testBasic2(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * y * 2.0 + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testSecondDerivative(self): + self.skipTest("b/109758172") + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + cond_grad = gradients_impl.gradients(cond, [x]) + cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) + + with self.test_session() as sess: + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = 
sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 102bc460fdadb0..a0dd3881a86c19 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -218,7 +218,6 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index a5e065b93a23c3..74f2ec22ffaab1 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -152,6 +152,22 @@ def testCrfLogNorm(self): self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + def testCrfLogNormZeroSeqLength(self): + """ + Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. + """ + with self.test_session() as sess: + inputs = constant_op.constant(np.ones([2, 10, 5], + dtype=np.float32)) + transition_params = constant_op.constant(np.ones([5, 5], + dtype=np.float32)) + sequence_lengths = constant_op.constant(np.zeros([2], + dtype=np.int32)) + expected_log_norm = np.zeros([2], dtype=np.float32) + log_norm = crf.crf_log_norm(inputs, sequence_lengths, transition_params) + tf_log_norm = sess.run(log_norm) + self.assertAllClose(tf_log_norm, expected_log_norm) + def testCrfLogLikelihood(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) @@ -292,10 +308,10 @@ def testCrfDecodeZeroSeqLength(self): dtype=np.float32)) sequence_lengths = constant_op.constant(np.zeros([2], dtype=np.int32)) - values = crf.crf_decode(inputs, transition_params, sequence_lengths) - tags, scores = sess.run(values) - self.assertEqual(len(tags.shape), 2) - self.assertEqual(len(scores.shape), 1) + tags, scores = crf.crf_decode(inputs, transition_params, sequence_lengths) + tf_tags, tf_scores = sess.run([tags, scores]) + self.assertEqual(len(tf_tags.shape), 2) + self.assertEqual(len(tf_scores.shape), 1) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index e37c029cebf30e..2d2cbdc1990ed9 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -52,6 +52,7 @@ import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -90,9 +91,13 @@ def _single_seq_fn(): batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0] example_inds = array_ops.reshape( math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) - return array_ops.gather_nd( + sequence_scores = array_ops.gather_nd( array_ops.squeeze(inputs, [1]), array_ops.concat([example_inds, tag_indices], axis=1)) + sequence_scores = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(sequence_scores), + sequence_scores) + return 
sequence_scores def _multi_seq_fn(): # Compute the scores of the given tag sequence. @@ -128,7 +133,12 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over # the "initial state" (the unary potentials). def _single_seq_fn(): - return math_ops.reduce_logsumexp(first_input, [1]) + log_norm = math_ops.reduce_logsumexp(first_input, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(log_norm), + log_norm) + return log_norm def _multi_seq_fn(): """Forward computation of alpha values.""" @@ -137,13 +147,21 @@ def _multi_seq_fn(): # Compute the alpha values in the forward algorithm in order to get the # partition function. forward_cell = CrfForwardRnnCell(transition_params) + # Sequence length is not allowed to be less than zero. + sequence_lengths_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_lengths.dtype), + sequence_lengths - 1) _, alphas = rnn.dynamic_rnn( cell=forward_cell, inputs=rest_of_input, - sequence_length=sequence_lengths - 1, + sequence_length=sequence_lengths_less_one, initial_state=first_input, dtype=dtypes.float32) log_norm = math_ops.reduce_logsumexp(alphas, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(log_norm), + log_norm) return log_norm max_seq_len = array_ops.shape(inputs)[1] @@ -479,7 +497,7 @@ def _multi_seq_fn(): initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] - # sequence length is not allowed to be less than zero + # Sequence length is not allowed to be less than zero. sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 012b17cee88aec..8285ea04926d3a 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -54,11 +54,11 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import adagrad from tensorflow.python.training import adam -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import util as checkpointable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -717,7 +717,7 @@ def _VerifyCheckpoint( inputs = 3. * array_ops.ones([num_applications, num_layers, input_size], dtype=dtypes.float32) cudnn_output, _ = cudnn_layer(inputs) - status.assert_consumed().run_restore_ops() + status.run_restore_ops() second_save_path = cudnn_checkpoint.save(checkpoint_prefix) restore_layer = compatible_cell_fn() restore_layer_checkpoint = checkpointable_utils.Checkpoint( @@ -728,7 +728,7 @@ def _VerifyCheckpoint( restore_layer_output, current_state = restore_layer( inputs=3. 
* array_ops.ones([1, input_size]), state=current_state) - status.assert_consumed().run_restore_ops() + status.run_restore_ops() self.assertTrue(restore_layer.variables) for variable, expected_value in zip( restore_layer.variables, expected_variable_values): diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 73a961992e19fa..8822a7523f6b16 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -20,11 +20,10 @@ import os from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops @@ -33,8 +32,8 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import checkpointable as checkpointable_lib from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import base as checkpointable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -1647,10 +1646,3 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): # 1 set of weight and bias parameters for the recurrent input, and 1 for the # previous layer input. _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER - - -ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNN")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNBackprop")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 077cbba9d2ae41..1af1ed08b53ee0 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,11 +23,14 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. 
@@Counter +@@CheckpointInputPipelineHook +@@CsvDataset @@SqlDataset @@assert_element_shape @@batch_and_drop_remainder @@bucket_by_sequence_length +@@choose_from_datasets @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ -72,8 +75,10 @@ from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave +from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.readers import CsvDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index c56910c7833d4c..7b69e10441eba3 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -29,6 +29,16 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "csv_dataset_op", + srcs = ["csv_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "ignore_errors_dataset_op", srcs = ["ignore_errors_dataset_op.cc"], @@ -63,6 +73,7 @@ cc_library( cc_library( name = "dataset_kernels", deps = [ + ":csv_dataset_op", ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", ":prefetching_kernels", diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc new file mode 100644 index 00000000000000..e88ad3dc32003e --- /dev/null +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -0,0 +1,758 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/parsing_ops.cc. 
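The kernel below implements the `CSVDataset` op behind the `CsvDataset` class that `tf.contrib.data` now exports (see the `__init__.py` hunk above). A minimal, hedged usage sketch; the keyword names are assumptions inferred from the op's inputs, not taken from this patch:

```python
# Illustrative sketch only: read two float32 columns from a headerless CSV.
import tensorflow as tf

dataset = tf.contrib.data.CsvDataset(
    ["/tmp/example.csv"],                 # filenames (scalar or vector)
    record_defaults=[tf.constant([0.0]),  # one default per selected column;
                     tf.constant([0.0])], # these also fix each column's dtype
    field_delim=",",
    use_quote_delim=True,                 # honor '"'-quoted fields
    na_value="NA")                        # "NA" fields fall back to defaults

iterator = dataset.make_one_shot_iterator()
next_row = iterator.get_next()

with tf.Session() as sess:
  print(sess.run(next_row))  # a tuple with one scalar per column
```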
+#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +namespace tensorflow { +namespace { + +class CSVDatasetOp : public DatasetOpKernel { + public: + explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + OpInputList record_defaults_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("record_defaults", &record_defaults_list)); + for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults_list[i].NumElements())); + } + + const Tensor* select_cols_tensor; + OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor)); + OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, + errors::InvalidArgument("`select_cols` must be a vector.")); + + int64 buffer_size; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, buffer_size > 0, + errors::InvalidArgument("buffer_size should be positive")); + + string delim; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "field_delim", &delim)); + OP_REQUIRES(ctx, delim.size() == 1, + errors::InvalidArgument("field_delim should be only 1 char")); + + bool header; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "header", &header)); + + bool use_quote_delim; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "use_quote_delim", + &use_quote_delim)); + string na_value; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "na_value", &na_value)); + + std::vector record_defaults; + record_defaults.reserve(record_defaults_list.size()); + for (const Tensor& t : record_defaults_list) { + record_defaults.push_back(t); + } + + std::vector filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat()(i)); + } + + std::vector select_cols; + select_cols.reserve(select_cols_tensor->NumElements()); + for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { + select_cols.push_back(select_cols_tensor->flat()(i)); + } + OP_REQUIRES( + ctx, output_types_.size() == select_cols.size() || select_cols.empty(), + errors::InvalidArgument("select_cols should match output size")); + for (int i = 1; i < select_cols.size(); i++) { + OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i], + errors::InvalidArgument( + "select_cols should be strictly increasing indices")); + } + OP_REQUIRES( + ctx, select_cols.empty() || select_cols.front() >= 0, + errors::InvalidArgument("select_cols should be non-negative indices")); + + *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + output_types_, output_shapes_, + std::move(record_defaults), std::move(select_cols), + use_quote_delim, delim[0], std::move(na_value)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + 
Dataset(OpKernelContext* ctx, std::vector filenames, bool header, + int64 buffer_size, const DataTypeVector& output_types, + const std::vector& output_shapes, + std::vector record_defaults, std::vector select_cols, + bool use_quote_delim, char delim, string na_value) + : GraphDatasetBase(ctx), + filenames_(std::move(filenames)), + header_(header), + buffer_size_(buffer_size), + out_type_(output_types), + output_shapes_(output_shapes), + record_defaults_(std::move(record_defaults)), + select_cols_(std::move(select_cols)), + use_quote_delim_(use_quote_delim), + delim_(delim), + na_value_(std::move(na_value)) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::CSV")})); + } + + const DataTypeVector& output_dtypes() const override { return out_type_; } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { return "CSVDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + // TODO(rachelim): Implement this + std::vector input_tensors; + TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); + return errors::Unimplemented("CSVDataset: AsGraphDefInternal"); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + bool select_all = dataset()->select_cols_.empty(); + do { + // We are currently processing a file, so try to read the next record + if (input_stream_) { + Status s = ReadRecord(ctx, out_tensors, select_all, + dataset()->select_cols_); + if (s.ok()) { + // Validate output + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument( + "Expect ", dataset()->out_type_.size(), " fields but have ", + out_tensors->size(), " in record"); + } + + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { + // Not at the end of file, return OK or non-EOF errors to caller. + *end_of_sequence = false; + return s; + } + // We have reached the end of the current file, so maybe + // move on to next file. + ResetStreamsLocked(); + ++current_file_index_; + } + // Iteration ends when there are no more files to process. + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement save + return errors::Unimplemented("CSVDataset: SaveInternal"); + } + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement restore + return errors::Unimplemented("CSVDataset: RestoreInternal"); + } + + private: + // Reads an entire CSV row from the input stream, either from the + // existing buffer or by filling the buffer as needed. Converts extracted + // fields to output tensors as we go. + // + // When this function is called, pos_ should be the index of the first + // character of the record in buffer_, or past the end of the buffer. + // Note: ctx and out_tensors are only used in this function + // when fields are included in the record. 
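+      // For example, SetupStreamsLocked() below calls ReadRecord() with null
+      // ctx and out_tensors (and include == false for every field) just to
+      // skip over a header line.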
+ Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors, + bool select_all, const std::vector& selected) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // At the end of the file, this will return errors::OutOfRange + TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); + pos_ = 0; + } + + // The first character may be \n if this is the continuation of a + // \r\n linebreak between this and the previous record. If so, skip it. + + bool end_of_record = false; // Keep track of when we find \n, \r or EOF + size_t num_parsed = 0; + size_t num_selected_parsed = 0; + + Status result = Status::OK(); + + while (!end_of_record) { // Read till we reach \n, \r or EOF + bool include = + select_all || (num_selected_parsed < selected.size() && + selected[num_selected_parsed] == num_parsed); + + // Don't fail fast, so that the next call to GetNext may still return + // a valid record + result.Update( + ParseOneField(ctx, out_tensors, &end_of_record, include)); + + num_parsed++; + if (include) num_selected_parsed++; + } + + return result; + } + + // Parses one field from position pos_ in the buffer. Fields are + // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of + // the next field. + Status ParseOneField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // If we get here, this means the previous field's end coincided + // with the end of the buffer. We can fill the buffer without abandon. + Status s = FillBuffer(&buffer_); + + if (errors::IsOutOfRange(s)) { + // Reached EOF, and last field is empty + *end_of_record = true; + if (include) { + return FieldToOutput(ctx, StringPiece(), out_tensors); + } else { + return Status::OK(); + } + } else if (!s.ok()) { + return s; // Surface other errors back to caller + } + + pos_ = 0; + } + + if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { + return ParseQuotedField(ctx, out_tensors, end_of_record, include); + } + + return ParseUnquotedField(ctx, out_tensors, end_of_record, include); + } + + // For keeping track of relevant parts of a field from a previous buffer + struct Piece { + size_t start; + size_t len; + string buffer; + + Piece(string buffer, size_t start, size_t len) + : start(start), len(len), buffer(std::move(buffer)) {} + }; + + // Given that pos_ exceeds the buffer, saves the relevant part of the + // current buffer (if necessary), fills the buffer, and resets indices to + // 0. + Status SaveAndFillBuffer(std::vector* earlier_pieces, + size_t* start, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + string temp_buffer; + + buffer_.swap(temp_buffer); + if (include && pos_ > *start) { + earlier_pieces->push_back( + Piece(std::move(temp_buffer), *start, pos_ - *start)); + } + pos_ = 0; + *start = 0; + return FillBuffer(&buffer_); + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. 
+ Status ParseQuotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + pos_++; // Starting quotation mark + + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "Reached end of file without closing quoted field in " + "record"); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + if (ch == '"') { + // When we encounter a quote, we look ahead to the next character to + // decide what to do + pos_++; + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + // This was the last field. We are done + *end_of_record = true; + return QuotedFieldToOutput(ctx, StringPiece(), out_tensors, + earlier_pieces, include); + } else if (!s.ok()) { + return s; + } + } + + char next = buffer_[pos_]; + pos_++; + if (next == dataset()->delim_) { + return QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include); + + } else if (next == '\n' || next == '\r') { + *end_of_record = true; + Status s = QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include); + if (next == '\r') SkipNewLineIfNecessary(); + return s; + } else if (next != '"') { + return errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote"); + } + + } else { + pos_++; + } + } + } + + // Converts quoted field to an output tensor, removing the starting + // and ending quotes from it and unescaping double quotations if + // necessary. + Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + if (field.find('\"', 1) == field.size() - 1) { + // `field` contains no escaped quotation marks. + // Exclude framing quotation marks + field.remove_prefix(1); + field.remove_suffix(1); + return FieldToOutput(ctx, field, out_tensors); + } + } + string field_complete; + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + field_complete.reserve(str_len); + + // This bool flips every time we see a quote, so that we skip the second + // quote of every pair of adjacent quotes in the field. We need to track + // this across iterations of the for loop because adjacent double quotes + // may be in different buffers. Initialize to true because we also skip + // the opening quotation mark of the quoted field. 
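+        // For example, the raw quoted field "say ""hi""" (including its
+        // framing quotes) is emitted as: say "hi"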
+ bool skip_next_quote = true; + for (const Piece& p : earlier_pieces) { + AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), + &field_complete, &skip_next_quote); + } + AppendUnescapedPiece(field, &field_complete, &skip_next_quote); + StringPiece result = StringPiece(field_complete); + result.remove_suffix(1); // Skip final quote + + return FieldToOutput(ctx, result, out_tensors); + } + + void AppendUnescapedPiece(StringPiece piece, string* field_complete, + bool* skip_next_quote) { + size_t from = 0; + size_t found = piece.find('\"', from); + while (found != string::npos) { + if (!*skip_next_quote) { + // This is the first quote in a pair of adjacent double quotes + field_complete->append(piece.data() + from, found + 1 - from); + } + *skip_next_quote = !*skip_next_quote; + from = found + 1; + found = piece.find('\"', from); + } + // Include the chunk after the last quotation mark in the string + if (from < piece.size()) { + field_complete->append(piece.data() + from, piece.size() - from); + } + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseUnquotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + // Handle errors + if (errors::IsOutOfRange(s)) { + // Whatever we have is the last field of the last record + *end_of_record = true; + return UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + + if (ch == dataset()->delim_) { + Status s = UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + pos_++; + return s; + } + if (ch == '\n' || ch == '\r') { + // need special case to skip over first \n of record if the line + // breaks are \r\n + Status s = UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include); + *end_of_record = true; + pos_++; + if (ch == '\r') SkipNewLineIfNecessary(); + return s; + } + if (dataset()->use_quote_delim_ && ch == '"') { + // Advance pos_ to the next field anyway so that we can ignore + // errors gracefully if required. The caller of this will be able to + // call ParseOneField and continue with the rest of the record. + AdvanceToNextField(end_of_record); + return errors::InvalidArgument( + "Unquoted fields cannot have quotes inside"); + } + // Otherwise, go to next character + pos_++; + } + } + + // Advances pos_ to the start of the next field, as delimited by delim, + // CRLF, or EOF, ignoring errors, and not keeping track of characters in + // the current field. 
+ void AdvanceToNextField(bool* end_of_record) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + while (true) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + if (!s.ok()) { + *end_of_record = true; + return; + } + } + + char ch = buffer_[pos_]; + pos_++; + + if (ch == dataset()->delim_) { + return; + } + + if (ch == '\n' || ch == '\r') { + *end_of_record = true; + if (ch == '\r') SkipNewLineIfNecessary(); + return; + } + } + } + + Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->clear(); + Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + + if (errors::IsOutOfRange(s) && !result->empty()) { + // Ignore OutOfRange error when ReadNBytes read < N bytes. + return Status::OK(); + } + return s; + } + + // Given a field, converts it to the right output tensor type + Status FieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors) { + size_t output_idx = out_tensors->size(); + if (output_idx >= dataset()->out_type_.size()) { + // We can get here if we're selecting all columns, but the number of + // fields exceeds the number of defaults provided + return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), + " fields but have more in record"); + } + const DataType& dtype = dataset()->out_type_[output_idx]; + Tensor component(ctx->allocator({}), dtype, {}); + if ((field.empty() || field == dataset()->na_value_) && + dataset()->record_defaults_[output_idx].NumElements() != 1) { + // If the field is empty or NA value, and default is not given, + // report error. + return errors::InvalidArgument("Field ", output_idx, + " is required but missing in record!"); + } + + switch (dtype) { + // For each case, if the field is empty, we use the default. + // Otherwise, we convert it to the right type. 
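+          // For example, with na_value == "NA" and a record default of 7 for
+          // an int32 column, both an empty field and the literal field "NA"
+          // yield the value 7.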
+ case DT_INT32: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + int32 value; + if (!strings::safe_strto32(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int32: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_INT64: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + int64 value; + if (!strings::safe_strto64(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int64: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_FLOAT: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + float value; + if (!strings::safe_strtof(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid float: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_DOUBLE: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + double value; + if (!strings::safe_strtod(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid double: ", field); + } + component.scalar()() = value; + } + break; + } + case DT_STRING: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar()() = + dataset()->record_defaults_[output_idx].flat()(0); + } else { + component.scalar()() = field.ToString(); + } + break; + } + default: + return errors::InvalidArgument("csv: data type ", dtype, + " not supported in field ", + output_idx); + } + out_tensors->push_back(std::move(component)); + return Status::OK(); + } + + // Records can be delimited by "\r\n" line breaks. When we encounter a + // '\r', we have to check the next character to see if it is part of the + // linebreak, and ignore it if so. + void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + // If we failed to fill buffer, it doesn't matter because we're done + // with the record + if (!s.ok()) return; + } + if (buffer_[pos_] == '\n') { + pos_++; + } + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + return FieldToOutput(ctx, field, out_tensors); + } + + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + string field_complete; + field_complete.reserve(str_len); + + for (const Piece& p : earlier_pieces) { + field_complete.append(p.buffer, p.start, p.len); + } + + field_complete.append(field.data(), field.size()); + return FieldToOutput(ctx, field_complete, out_tensors); + } + + // Sets up reader streams to read from the file at `current_file_index_`. 
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + + // Actually move on to next file. + TF_RETURN_IF_ERROR(env->NewRandomAccessFile( + dataset()->filenames_[current_file_index_], &file_)); + input_stream_.reset( + new io::RandomAccessInputStream(file_.get(), false)); + buffer_.clear(); + pos_ = 0; + if (dataset()->header_) { + // Read one line, but don't include it. Pass nullptrs as dummy + // pointers to objects that shouldn't be invoked anyway + // We need to process this as a record here instead of just finding + // the first newline because it might contain quoted fields with + // newlines in the header as well + std::vector empty; + Status s = ReadRecord(nullptr, nullptr, false, empty); + if (!s.ok()) { + return errors::InvalidArgument("Can't read header of file"); + } + } + return Status::OK(); + } + + // Resets all reader streams. + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + input_stream_.reset(); + file_.reset(); + } + + mutex mu_; + string buffer_ GUARDED_BY(mu_); // Maintain our own buffer + size_t pos_ GUARDED_BY( + mu_); // Index into the buffer must be maintained between iters + std::unique_ptr input_stream_ + GUARDED_BY(mu_); + size_t current_file_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr file_ + GUARDED_BY(mu_); // must outlive input_stream_ + }; // class Iterator + + const std::vector filenames_; + const bool header_; + const int64 buffer_size_; + const DataTypeVector out_type_; + const std::vector output_shapes_; + const std::vector record_defaults_; + const std::vector select_cols_; + const bool use_quote_delim_; + const char delim_; + const string na_value_; + }; // class Dataset + + DataTypeVector output_types_; + std::vector output_shapes_; +}; // class CSVDatasetOp + +// Register the kernel implementation for CSVDataset. 
+REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index 48d3734162525f..6a12ca06f4d6cc 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -91,7 +91,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::DirectedInterleave")})); @@ -105,7 +105,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); } @@ -130,15 +130,21 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - selector_input_impl_(params.dataset->selector_input_->MakeIterator( - params.prefix + ".selector")), - num_active_inputs_(params.dataset->data_inputs_.size()) { - data_input_impls_.reserve(params.dataset->data_inputs_.size()); - for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) { - const DatasetBase* data_input = params.dataset->data_inputs_[i]; - data_input_impls_.push_back(data_input->MakeIterator( - strings::StrCat(params.prefix, "[", i, "]"))); + num_active_inputs_(params.dataset->data_inputs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, strings::StrCat(prefix(), ".selector"), + &selector_input_impl_)); + data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); } + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index bb29df60e8f114..bbec50681c6f5d 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -44,7 +44,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); @@ -57,7 +57,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; } + string DebugString() const override { + return "IgnoreErrorsDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -72,8 +74,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : 
DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 63e19ae3f837c9..3dfc3741c2b040 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -127,7 +127,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { threadpool_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); @@ -140,7 +140,9 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "ThreadPoolDatasetOp::Dataset"; } + string DebugString() const override { + return "ThreadPoolDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -154,8 +156,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 69fbb0fcdcce87..67c237799c10a2 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -56,7 +56,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Unique")})); @@ -70,7 +70,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("UniqueDatasetOp::Dataset"); } @@ -87,8 +87,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const typename Iterator::Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 137deb63527f0b..f271d269ab1b93 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -34,6 +34,40 @@ data_input_datasets: `N` datasets with the same type that will be interleaved according to the values of `selector_input_dataset`. 
)doc"); +REGISTER_OP("CSVDataset") + .Input("filenames: string") + .Input("buffer_size: int64") + .Input("header: bool") + .Input("field_delim: string") + .Input("use_quote_delim: bool") + .Input("na_value: string") + .Input("select_cols: int64") + .Input("record_defaults: output_types") + .Output("handle: variant") + .Attr("output_types: list({float,double,int32,int64,string}) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // `buffer_size`, `header`, `field_delim`, `use_quote_delim`, + // `na_value` must be scalars + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + // `select_cols` must be a vector + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused)); + // `record_defaults` must be a list of scalars...? + for (size_t i = 7; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); + } + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("IgnoreErrorsDataset") .Input("input_dataset: variant") .Output("handle: variant") diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 23a2d7351361ca..995e283a32400d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -11,7 +11,10 @@ py_test( size = "medium", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_oss", # (b/79552534) + "no_pip", + ], deps = [ ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", @@ -32,7 +35,7 @@ py_test( py_test( name = "bucketing_test", - size = "small", + size = "medium", srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ @@ -117,6 +120,20 @@ py_library( ], ) +py_test( + name = "csv_dataset_op_test", + size = "small", + srcs = ["csv_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:readers", + "//third_party/py/numpy", + ], +) + py_test( name = "filter_dataset_op_test", size = "small", @@ -192,6 +209,23 @@ py_test( ], ) +py_test( + name = "directed_interleave_dataset_test", + size = "medium", + srcs = ["directed_interleave_dataset_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "get_single_element_test", size = "small", @@ -246,6 +280,19 @@ py_test( ], ) +py_test( + name = "optimize_dataset_op_test", + size = "small", + srcs = ["optimize_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:optimization", + 
"//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "prefetch_dataset_op_test", size = "small", @@ -287,6 +334,7 @@ py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -301,6 +349,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", + "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -320,11 +369,15 @@ py_test( deps = [ "//tensorflow/contrib/data/python/ops:resampling", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -407,6 +460,7 @@ py_test( srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index a4a0ce79b6013d..b5fbc45ad3d8d2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -427,7 +427,9 @@ def testBatchAndDropRemainderShapeInference(self): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, + num_parallel_calls=None, + num_parallel_batches=None): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). 
@@ -446,6 +448,7 @@ def _map_fn(x, y, z): batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, + num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer @@ -497,12 +500,18 @@ def _map_fn(x, y, z): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatchDataset(self): + def testMapAndBatch(self): return self._testMapAndBatchDatasetHelper() - def testMapAndBatchDatasetWithParallelBatching(self): + def testMapAndBatchWithParallelBatches(self): return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + def testMapAndBatchWithSequentialCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) + + def testMapAndBatchWithParallelCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): iterator = ( dataset_ops.Dataset.range(10).apply( @@ -543,6 +552,44 @@ def testMapAndBatchYieldsPartialBatch(self): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testMapAndBatchParallelGetNext(self): + iterator = (dataset_ops.Dataset.range(50000) + .apply(batching.map_and_batch(lambda x: x, batch_size=100)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(5): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + + def testMapAndBatchParallelGetNextDropRemainder(self): + iterator = ( + dataset_ops.Dataset.range(49999).apply( + batching.map_and_batch( + lambda x: x, batch_size=100, drop_remainder=True)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(4): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + def testMapAndBatchSparse(self): def _sparse(i): @@ -630,9 +677,7 @@ def _build_dataset_dense_to_sparse(self, components): lambda x: array_ops.fill([x], x)).apply( batching.dense_to_sparse_batch(4, [12])) - # TODO(b/70988345): Re-enable when sparse tensors are properly supported by - # the DatasetSerializationTestBase. 
- def _testDenseToSparseBatchDatasetCore(self): + def testDenseToSparseBatchDatasetCore(self): components = np.random.randint(5, size=(40,)).astype(np.int32) diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) @@ -684,7 +729,7 @@ def testCore(self): class MapAndBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): - def testSerializationCore(self): + def testNumParallelBatches(self): range_size = 11 num_repeats = 2 batch_size = 5 @@ -711,6 +756,33 @@ def _map_fn(x): self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), num_outputs_drop_remainder) + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 55a56b83a8efba..bd3e034211c4aa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -35,6 +36,179 @@ from tensorflow.python.platform import test +class GroupByReducerTest(test.TestCase): + + def checkResults(self, dataset, shapes, values): + self.assertEqual(shapes, dataset.output_shapes) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + for expected in values: + got = sess.run(get_next) + self.assertEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSum(self): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer(lambda x: x % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testAverage(self): + + def reduce_fn(x, y): + return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / ( + x[1] + 1), x[1] + 1 + + reducer = grouping.Reducer( + init_func=lambda _: (0.0, 0.0), + reduce_func=reduce_fn, + finalize_func=lambda x: x[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer( + lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer)) + self.checkResults( + dataset, 
shapes=tensor_shape.scalar(), values=[i - 1, i]) + + def testConcat(self): + components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) + reducer = grouping.Reducer( + init_func=lambda x: "", + reduce_func=lambda x, y: x + y[0], + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensor_slices(components), + dataset_ops.Dataset.range(2 * i))).apply( + grouping.group_by_reducer(lambda x, y: y % 2, reducer)) + self.checkResults( + dataset, + shapes=tensor_shape.scalar(), + values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]]) + + def testSparseSum(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1], dtype=np.int64)), + dense_shape=np.array([1, 1])) + + reducer = grouping.Reducer( + init_func=lambda _: _sparse(np.int64(0)), + reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]), + finalize_func=lambda x: x.values[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply( + grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testChangingStateShape(self): + + def reduce_fn(x, _): + # Statically known rank, but dynamic length. + larger_dim = array_ops.concat([x[0], x[0]], 0) + # Statically unknown rank. + larger_rank = array_ops.expand_dims(x[1], 0) + return larger_dim, larger_rank + + reducer = grouping.Reducer( + init_func=lambda x: ([0], 1), + reduce_func=reduce_fn, + finalize_func=lambda x: x) + + for i in range(1, 11): + dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( + grouping.group_by_reducer(lambda x: x, reducer)) + self.assertEqual([None], dataset.output_shapes[0].as_list()) + self.assertIs(None, dataset.output_shapes[1].ndims) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual([0] * (2**i), x) + self.assertAllEqual(np.array(1, ndmin=i), y) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testTypeMismatch(self): + reducer = grouping.Reducer( + init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), + reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64), + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64(0), reducer)) + + # TODO(b/78665031): Remove once non-scalar keys are supported. + def testInvalidKeyShape(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) + + # TODO(b/78665031): Remove once non-int64 keys are supported. 
+ def testInvalidKeyType(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: "wrong", reducer)) + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + class GroupByWindowTest(test.TestCase): def testSimple(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py new file mode 100644 index 00000000000000..74b90ec7d1617d --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -0,0 +1,600 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for CsvDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import string +import tempfile +import time + +import numpy as np + +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.client import session +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class CsvDatasetOpTest(test.TestCase): + + def _assert_datasets_equal(self, g, ds1, ds2): + assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' + '%s') % (ds1.output_shapes, + ds2.output_shapes) + assert ds1.output_types == ds2.output_types + assert ds1.output_classes == ds2.output_classes + next1 = ds1.make_one_shot_iterator().get_next() + next2 = ds2.make_one_shot_iterator().get_next() + with self.test_session(graph=g) as sess: + # Run through datasets and check that outputs match, or errors match. + while True: + try: + op1 = sess.run(next1) + except (errors.OutOfRangeError, ValueError) as e: + # If op1 throws an exception, check that op2 throws same exception. + with self.assertRaises(type(e)): + sess.run(next2) + break + op2 = sess.run(next2) + self.assertAllEqual(op1, op2) + + def setup_files(self, inputs, linebreak='\n'): + filenames = [] + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) + with open(fn, 'wb') as f: + f.write(linebreak.join(ip).encode('utf-8')) + filenames.append(fn) + return filenames + + def _make_test_datasets(self, inputs, **kwargs): + # Test by comparing its output to what we could get with map->decode_csv + filenames = self.setup_files(inputs) + dataset_expected = core_readers.TextLineDataset(filenames) + dataset_expected = dataset_expected.map( + lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + dataset_actual = readers.CsvDataset(filenames, **kwargs) + return (dataset_actual, dataset_expected) + + def _test_by_comparison(self, inputs, **kwargs): + """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" + with ops.Graph().as_default() as g: + dataset_actual, dataset_expected = self._make_test_datasets( + inputs, **kwargs) + self._assert_datasets_equal(g, dataset_actual, dataset_expected) + + def _verify_output_or_err(self, + sess, + dataset, + expected_output=None, + expected_err_re=None): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + + def _test_dataset(self, + inputs, + expected_output=None, + expected_err_re=None, + linebreak='\n', 
+ **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self.setup_files(inputs, linebreak) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, **kwargs) + self._verify_output_or_err(sess, dataset, expected_output, + expected_err_re) + + def testCsvDataset_requiredFields(self): + record_defaults = [[]] * 4 + inputs = [['1,2,3,4']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_int(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_float(self): + record_defaults = [[0.0]] * 4 + inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_string(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withEmptyFields(self): + record_defaults = [[0]] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_errWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_dataset( + inputs, + expected_err_re='Unquoted fields cannot have quotes inside', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_mixedTypes(self): + record_defaults = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.string), + constant_op.constant([], dtype=dtypes.float64) + ] + inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withUseQuoteDelimFalse(self): + record_defaults = [['']] * 4 + inputs = [['1,2,"3,4"', '"5,6",7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_withFieldDelim(self): + record_defaults = [[0]] * 4 + inputs = [['1:2:3:4', '5:6:7:8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, field_delim=':') + + def testCsvDataset_withNaValue(self): + record_defaults = [[0]] * 4 + inputs = [['1,NA,3,4', 'NA,6,7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, na_value='NA') + + def testCsvDataset_withSelectCols(self): + record_defaults = [['']] * 2 + inputs = [['1,2,3,4', '"5","6","7","8"']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, select_cols=[1, 2]) + + def testCsvDataset_withSelectColsTooHigh(self): + record_defaults = [[0]] * 2 + inputs = 
[['1,2,3,4', '5,6,7,8']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults, + select_cols=[3, 4]) + + def testCsvDataset_withOneCol(self): + record_defaults = [['NA']] + inputs = [['0', '', '2']] + self._test_dataset( + inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withMultipleFiles(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withLeadingAndTrailingSpaces(self): + record_defaults = [[0.0]] * 4 + inputs = [['0, 1, 2, 3']] + expected = [[0.0, 1.0, 2.0, 3.0]] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithMissingDefault(self): + record_defaults = [[]] * 2 + inputs = [['0,']] + self._test_dataset( + inputs, + expected_err_re='Field 1 is required but missing in record!', + record_defaults=record_defaults) + + def testCsvDataset_errorWithFewerDefaultsThanFields(self): + record_defaults = [[0.0]] * 2 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have more in record', + record_defaults=record_defaults) + + def testCsvDataset_errorWithMoreDefaultsThanFields(self): + record_defaults = [[0.0]] * 5 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 5 fields but have 4 in record', + record_defaults=record_defaults) + + def testCsvDataset_withHeader(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2', '1,2']] + expected = [[1, 2]] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withHeaderAndNoRecords(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2']] + expected = [] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_errorWithHeaderEmptyFile(self): + record_defaults = [[0]] * 2 + inputs = [[]] + expected_err_re = "Can't read header of file" + self._test_dataset( + inputs, + expected_err_re=expected_err_re, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withEmptyFile(self): + record_defaults = [['']] * 2 + inputs = [['']] # Empty file + self._test_dataset( + inputs, expected_output=[], record_defaults=record_defaults) + + def testCsvDataset_errorWithEmptyRecord(self): + record_defaults = [['']] * 2 + inputs = [['', '1,2']] # First record is empty + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults) + + def testCsvDataset_withChainedOps(self): + # Testing that one dataset can create multiple iterators fine. + # `repeat` creates multiple iterators from the same C++ Dataset. 
+ record_defaults = [[0]] * 4 + inputs = [['1,,3,4', '5,6,,8']] + ds_actual, ds_expected = self._make_test_datasets( + inputs, record_defaults=record_defaults) + with ops.Graph().as_default() as g: + self._assert_datasets_equal(g, + ds_actual.repeat(5).prefetch(1), + ds_expected.repeat(5).prefetch(1)) + + def testCsvDataset_withTypeDefaults(self): + # Testing using dtypes as record_defaults for required fields + record_defaults = [dtypes.float32, [0.0]] + inputs = [['1.0,2.0', '3.0,4.0']] + self._test_dataset( + inputs, + [[1.0, 2.0], [3.0, 4.0]], + record_defaults=record_defaults, + ) + + def testMakeCsvDataset_fieldOrder(self): + data = [[ + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' + ]] + file_path = self.setup_files(data) + + with ops.Graph().as_default() as g: + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + next_batch = ds.make_one_shot_iterator().get_next() + + with self.test_session(graph=g) as sess: + result = list(sess.run(next_batch).values()) + + self.assertEqual(result, sorted(result)) + +## The following tests exercise parsing logic for quoted fields + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withOneColAndQuotes(self): + record_defaults = [['']] + inputs = [['"0"', '"1"', '"2"']] + self._test_dataset( + inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withNewLineInUnselectedCol(self): + record_defaults = [['']] + inputs = [['1,"2\n3",4', '5,6,7']] + self._test_dataset( + inputs, + expected_output=[['1'], ['5']], + record_defaults=record_defaults, + select_cols=[0]) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithTerminateMidRecord(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,"a']] + self._test_dataset( + inputs, + expected_err_re= + 'Reached end of file without closing quoted field in record', + record_defaults=record_defaults) + + def testCsvDataset_withEscapedQuotes(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + +## Testing that parsing works with all buffer sizes, quoted/unquoted fields, +## and different types of line breaks + + def testCsvDataset_withInvalidBufferSize(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,d']] + self._test_dataset( + inputs, + expected_err_re='buffer_size should be positive', + record_defaults=record_defaults, + buffer_size=0) + + def testCsvDataset_withBufferSize(self): + record_defaults = [['NA']] * 3 + inputs = 
[['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + + def testCsvDataset_withCR(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withCRLF(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withBufferSizeAndQuoted(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCRAndQuoted(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLFAndQuoted(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + + +class CsvDatasetBenchmark(test.Benchmark): + """Benchmarks for the various ways of creating a dataset from CSV files. 
+ """ + FLOAT_VAL = '1.23456E12' + STR_VAL = string.ascii_letters * 10 + + def _setUp(self, str_val): + # Since this isn't test.TestCase, have to manually create a test dir + gfile.MakeDirs(googletest.GetTempDir()) + self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) + + self._num_cols = [4, 64, 256] + self._num_per_iter = 5000 + self._filenames = [] + for n in self._num_cols: + fn = os.path.join(self._temp_dir, 'file%d.csv' % n) + with open(fn, 'wb') as f: + # Just write 100 rows and use `repeat`... Assumes the cost + # of creating an iterator is not significant + row = ','.join([str_val for _ in range(n)]) + f.write('\n'.join([row for _ in range(100)])) + self._filenames.append(fn) + + def _tearDown(self): + gfile.DeleteRecursively(self._temp_dir) + + def _runBenchmark(self, dataset, num_cols, prefix): + dataset = dataset.skip(self._num_per_iter - 1) + deltas = [] + for _ in range(10): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: + start = time.time() + # NOTE: This depends on the underlying implementation of skip, to have + # the net effect of calling `GetNext` num_per_iter times on the + # input dataset. We do it this way (instead of a python for loop, or + # batching N inputs in one iter) so that the overhead from session.run + # or batch doesn't dominate. If we eventually optimize skip, this has + # to change. + sess.run(next_element) + end = time.time() + deltas.append(end - start) + # Median wall time per CSV record read and decoded + median_wall_time = np.median(deltas) / self._num_per_iter + print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, + median_wall_time)) + self.report_benchmark( + iters=self._num_per_iter, + wall_time=median_wall_time, + name='%s_with_cols_%d' % (prefix, num_cols)) + + def benchmarkMapWithFloats(self): + self._setUp(self.FLOAT_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') + self._tearDown() + + def benchmarkMapWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') + self._tearDown() + + def benchmarkCsvDatasetWithFloats(self): + self._setUp(self.FLOAT_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset') + self._tearDown() + + def benchmarkCsvDatasetWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: 
disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset') + self._tearDown() + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py new file mode 100644 index 00000000000000..34b6a080c0aae7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.platform import test + + +class DirectedInterleaveDatasetTest(test.TestCase): + + def testBasic(self): + selector_dataset = dataset_ops.Dataset.range(10).repeat(100) + input_datasets = [ + dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) + ] + dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, + input_datasets) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(100): + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def _normalize(self, vec): + return vec / vec.sum() + + def _chi2(self, expected, actual): + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected, axis=0) + return chi2 + + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. 
+ dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + + def testSampleFromDatasets(self): + random_seed.set_random_seed(1619) + num_samples = 5000 + rand_probs = self._normalize(np.random.random_sample((15,))) + + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs]: + probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + # Also check that `weights` as a dataset samples correctly. + probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + def testSelectFromDatasets(self): + words = [b"foo", b"bar", b"baz"] + datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] + choice_array = np.random.randint(3, size=(15,), dtype=np.int64) + choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array) + dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in choice_array: + self.assertEqual(words[i], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, + r"vector of length `len\(datasets\)`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[0.25, 0.25, 0.25, 0.25]) + + with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[1, 1]) + + with self.assertRaisesRegexp(TypeError, "must have the same type"): + interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(0.0) + ]) + + with self.assertRaisesRegexp(TypeError, "tf.int64"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0)) + + with self.assertRaisesRegexp(TypeError, "scalar"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: 
self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 43aa4b1bd02791..bee561e3e23a2a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -30,7 +30,6 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -907,114 +906,5 @@ def interleave_fn(x): sess.run(self.next_element) -class DirectedInterleaveDatasetTest(test.TestCase): - - def testBasic(self): - selector_dataset = dataset_ops.Dataset.range(10).repeat(100) - input_datasets = [ - dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) - ] - dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, - input_datasets) - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - sess.run(iterator.initializer) - for _ in range(100): - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def _normalize(self, vec): - return vec / vec.sum() - - def _chi2(self, expected, actual): - actual = np.asarray(actual) - expected = np.asarray(expected) - diff = actual - expected - chi2 = np.sum(diff * diff / expected, axis=0) - return chi2 - - def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): - # Create a dataset that samples each integer in `[0, num_datasets)` - # with probability given by `weights[i]`. - dataset = interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(num_datasets) - ], weights) - dataset = dataset.take(num_samples) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - freqs = np.zeros([num_datasets]) - for _ in range(num_samples): - freqs[sess.run(next_element)] += 1 - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - return freqs - - def testSampleFromDatasets(self): - random_seed.set_random_seed(1619) - num_samples = 10000 - rand_probs = self._normalize(np.random.random_sample((15,))) - - # Use chi-squared test to assert that the observed distribution matches the - # expected distribution. Based on the implementation in - # "tensorflow/python/kernel_tests/multinomial_op_test.py". - for probs in [[.85, .05, .1], rand_probs]: - probs = np.asarray(probs) - classes = len(probs) - freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) - self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - - # Also check that `weights` as a dataset samples correctly. 
- probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() - freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) - self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - - def testErrors(self): - with self.assertRaisesRegexp(ValueError, - r"vector of length `len\(datasets\)`"): - interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.range(10), - dataset_ops.Dataset.range(20)], - weights=[0.25, 0.25, 0.25, 0.25]) - - with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): - interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.range(10), - dataset_ops.Dataset.range(20)], - weights=[1, 1]) - - with self.assertRaisesRegexp(TypeError, "must have the same type"): - interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(0), - dataset_ops.Dataset.from_tensors(0.0) - ]) - - -class SampleFromDatasetsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, probs, num_samples): - dataset = interleave_ops.sample_from_datasets( - [ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(len(probs)) - ], - probs, - seed=1813) - return dataset.take(num_samples) - - def testSerializationCore(self): - self.run_core_tests( - lambda: self._build_dataset([0.5, 0.5], 100), - lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py new file mode 100644 index 00000000000000..30f1847dcddbfa --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test.TestCase): + + def testDefaultOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testEmptyOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimization(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 1075302bae96ca..e0237198b7d47e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -36,6 +36,7 @@ from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -256,6 +257,29 @@ def testTFRecordWithCompressionCore(self): lambda: 
self._build_iterator_graph(num_epochs * 2), num_outputs) +def _interleave(iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + class ReadBatchFeaturesTest(test.TestCase): def setUp(self): @@ -355,8 +379,8 @@ def _next_record(file_indices): yield j, i def _next_record_interleaved(file_indices, cycle_length): - return self._interleave([_next_record([i]) for i in file_indices], - cycle_length) + return _interleave([_next_record([i]) for i in file_indices], + cycle_length) file_batch = [] keywords_batch_indices = [] @@ -397,28 +421,6 @@ def _next_record_interleaved(file_indices, cycle_length): [len(file_batch), keywords_batch_max_len], record_batch ] - def _interleave(self, iterators, cycle_length): - pending_iterators = iterators - open_iterators = [] - num_open = 0 - for i in range(cycle_length): - if pending_iterators: - open_iterators.append(pending_iterators.pop(0)) - num_open += 1 - - while num_open: - for i in range(min(cycle_length, len(open_iterators))): - if open_iterators[i] is None: - continue - try: - yield next(open_iterators[i]) - except StopIteration: - if pending_iterators: - open_iterators[i] = pending_iterators.pop(0) - else: - open_iterators[i] = None - num_open -= 1 - def _verify_records(self, sess, batch_size, @@ -620,14 +622,12 @@ def _write_file(self, filename, rows): f.close() return fn - def _create_file(self, fileno, header=True, comment=True): + def _create_file(self, fileno, header=True): rows = [] if header: rows.append(self.COLUMNS) for recno in range(self._num_records): rows.append(self._csv_values(fileno, recno)) - if comment: - rows.append("# Some comment goes here. Ignore me.") return self._write_file("csv_file%d.csv" % fileno, rows) def _create_files(self): @@ -648,9 +648,7 @@ def _make_csv_dataset( shuffle=False, shuffle_seed=None, header=True, - comment="#", na_value="", - default_float_type=dtypes.float32, ): return readers.make_csv_dataset( filenames, @@ -662,9 +660,7 @@ def _make_csv_dataset( shuffle=shuffle, shuffle_seed=shuffle_seed, header=header, - comment=comment, na_value=na_value, - default_float_type=default_float_type, select_columns=select_cols, ) @@ -786,29 +782,6 @@ def testMakeCSVDataset_withNoLabel(self): num_epochs=10, label_name=None) - def testMakeCSVDataset_withNoComments(self): - """Tests that datasets can be created from CSV files with no header line. - """ - defaults = self.DEFAULTS - file_without_header = self._create_file( - len(self._test_filenames), comment=False) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - file_without_header, - defaults, - batch_size=2, - num_epochs=10, - comment=None, - ) - self._verify_records( - sess, - dataset, - [len(self._test_filenames)], - batch_size=2, - num_epochs=10, - ) - def testMakeCSVDataset_withNoHeader(self): """Tests that datasets can be created from CSV files with no header line. """ @@ -876,7 +849,7 @@ def testMakeCSVDataset_withTypeInference(self): In that case, we should infer the types from the first N records. 
""" - # Test that it works with standard test files (with comments, header, etc) + # Test that it works with standard test files (with header, etc) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -889,7 +862,9 @@ def testMakeCSVDataset_withTypeInference(self): num_epochs=10, defaults=[[], [], [], [], [""]]) - # Test on a deliberately tricky file + def testMakeCSVDataset_withTypeInferenceTricky(self): + # Test on a deliberately tricky file (type changes as we read more rows, and + # there are null values) fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, @@ -914,20 +889,29 @@ def testMakeCSVDataset_withTypeInference(self): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float32, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match for i in range(len(expected_dtypes)): + print(features["col%d" % i].dtype, expected_dtypes[i]) assert features["col%d" % i].dtype == expected_dtypes[i] for i in range(len(rows)): assert sess.run(features) == dict(zip(col_names, expected[i])) - # With float64 as default type for floats + def testMakeCSVDataset_withTypeInferenceAllTypes(self): + # Test that we make the correct inference for all types with fallthrough + fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64, + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string, dtypes.string ] + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]] + expected = [[ + 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8") + ]] + self._write_file("file.csv", [col_names] + rows) + with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -936,7 +920,6 @@ def testMakeCSVDataset_withTypeInference(self): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float64, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match @@ -1086,5 +1069,189 @@ def testMakeCSVDataset_withShuffle(self): self.assertFalse(all_equal) +class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length, + drop_final_batch, + use_parser_fn): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return _interleave([_next_record([i]) for i in file_indices], + cycle_length) + + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for f, r in next_records: + record = self._record(f, r) + if use_parser_fn: + record = record[1:] + record_batch.append(record) + batch_index += 1 + if len(record_batch) == batch_size: + yield record_batch + record_batch = [] + batch_index = 0 + if record_batch and not drop_final_batch: + yield record_batch + + def _verify_records(self, + sess, + outputs, + batch_size, + file_index, + num_epochs, + interleave_cycle_length, + drop_final_batch, + use_parser_fn): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in 
self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length, + drop_final_batch, use_parser_fn): + actual_batch = sess.run(outputs) + self.assertAllEqual(expected_batch, actual_batch) + + def _read_test(self, batch_size, num_epochs, file_index=None, + num_parallel_reads=1, drop_final_batch=False, parser_fn=False): + if file_index is None: + file_pattern = self.test_filenames + else: + file_pattern = self.test_filenames[file_index] + + if parser_fn: + fn = lambda x: string_ops.substr(x, 1, 999) + else: + fn = None + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs = readers.make_tf_record_dataset( + file_pattern=file_pattern, + num_epochs=num_epochs, + batch_size=batch_size, + parser_fn=fn, + num_parallel_reads=num_parallel_reads, + drop_final_batch=drop_final_batch, + shuffle=False).make_one_shot_iterator().get_next() + self._verify_records( + sess, outputs, batch_size, file_index, num_epochs=num_epochs, + interleave_cycle_length=num_parallel_reads, + drop_final_batch=drop_final_batch, use_parser_fn=parser_fn) + with self.assertRaises(errors.OutOfRangeError): + sess.run(outputs) + + def testRead(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + # Basic test: read from file 0. + self._read_test(batch_size, num_epochs, 0) + + # Basic test: read from file 1. + self._read_test(batch_size, num_epochs, 1) + + # Basic test: read from both files. + self._read_test(batch_size, num_epochs) + + # Basic test: read from both files, with parallel reads. + self._read_test(batch_size, num_epochs, num_parallel_reads=8) + + def testDropFinalBatch(self): + for batch_size in [1, 2, 10]: + for num_epochs in [1, 3]: + # Read from file 0. + self._read_test(batch_size, num_epochs, 0, drop_final_batch=True) + + # Read from both files. + self._read_test(batch_size, num_epochs, drop_final_batch=True) + + # Read from both files, with parallel reads. 
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8, + drop_final_batch=True) + + def testParserFn(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for drop_final_batch in [False, True]: + self._read_test(batch_size, num_epochs, parser_fn=True, + drop_final_batch=drop_final_batch) + self._read_test(batch_size, num_epochs, num_parallel_reads=8, + parser_fn=True, drop_final_batch=drop_final_batch) + + def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1, + seed=None): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + num_parallel_reads=num_parallel_reads, + shuffle=True, + shuffle_seed=seed) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + sess.run(iterator.initializer) + first_batches = [] + try: + while True: + first_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + sess.run(iterator.initializer) + second_batches = [] + try: + while True: + second_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + self.assertEqual(len(first_batches), len(second_batches)) + if seed is not None: + # if you set a seed, should get the same results + for i in range(len(first_batches)): + self.assertAllEqual(first_batches[i], second_batches[i]) + + expected = [] + for f in range(self._num_files): + for r in range(self._num_records): + expected.extend([self._record(f, r)] * num_epochs) + + for batches in (first_batches, second_batches): + actual = [] + for b in batches: + actual.extend(b) + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testShuffle(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for num_parallel_reads in [1, 2]: + # Test that all expected elements are produced + self._shuffle_test(batch_size, num_epochs, num_parallel_reads) + # Test that elements are produced in a consistent order if + # you specify a seed. + self._shuffle_test(batch_size, num_epochs, num_parallel_reads, + seed=21345) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 5f47dcb3399911..bdc003a8a5bd64 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -18,6 +18,9 @@ from __future__ import print_function import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +import time +from absl.testing import parameterized from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops @@ -30,52 +33,98 @@ from tensorflow.python.util import compat -class ResampleTest(test.TestCase): +def _time_resampling( + test_obj, data_np, target_dist, init_dist, num_to_sample): + dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat() - def testInitialKnownDistribution(self): - self._testDistribution(initial_known=True) + # Reshape distribution via rejection sampling. 
+ dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist, + seed=142)) - def testInitialNotKnownDistribution(self): - self._testDistribution(initial_known=False) + get_next = dataset.make_one_shot_iterator().get_next() - def _testDistribution(self, initial_known): + with test_obj.test_session() as sess: + start_time = time.time() + for _ in xrange(num_to_sample): + sess.run(get_next) + end_time = time.time() + + return end_time - start_time + + +class ResampleTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ("InitialDistributionKnown", True), + ("InitialDistributionUnknown", False)) + def testDistribution(self, initial_known): classes = np.random.randint(5, size=(20000,)) # Uniformly sampled target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] initial_dist = [0.2] * 5 if initial_known else None - iterator = (dataset_ops.Dataset.from_tensor_slices(classes).shuffle( - 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).apply( - resampling.rejection_resample( - target_dist=target_dist, - initial_dist=initial_dist, - class_func=lambda c, _: c, - seed=27)).make_one_shot_iterator()) - get_next = iterator.get_next() + classes = math_ops.to_int64(classes) # needed for Windows build. + dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle( + 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat() + + get_next = dataset.apply( + resampling.rejection_resample( + target_dist=target_dist, + initial_dist=initial_dist, + class_func=lambda c, _: c, + seed=27)).make_one_shot_iterator().get_next() with self.test_session() as sess: returned = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - returned.append(sess.run(get_next)) + while len(returned) < 4000: + returned.append(sess.run(get_next)) returned_classes, returned_classes_and_data = zip(*returned) _, returned_data = zip(*returned_classes_and_data) self.assertAllEqual([compat.as_bytes(str(c)) for c in returned_classes], returned_data) total_returned = len(returned_classes) - # Subsampling rejects a large percentage of the initial data in - # this case. - self.assertGreater(total_returned, 20000 * 0.2) class_counts = np.array([ len([True for v in returned_classes if v == c]) for c in range(5)]) returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) + @parameterized.named_parameters( + ("OnlyInitial", True), + ("NotInitial", False)) + def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist): + init_dist = [0.5, 0.5] + target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test that this works. + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Reshape distribution. 
+ dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + returned = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + returned.append(sess.run(get_next)) + def testRandomClasses(self): init_dist = [0.25, 0.25, 0.25, 0.25] target_dist = [0.0, 0.0, 0.0, 1.0] num_classes = len(init_dist) - # We don't need many samples to test a dirac-delta target distribution + # We don't need many samples to test a dirac-delta target distribution. num_samples = 100 data_np = np.random.choice(num_classes, num_samples, p=init_dist) @@ -109,5 +158,23 @@ def _remap_fn(_): self.assertAllClose(target_dist, bincount, atol=1e-2) + +class ResampleDatasetBenchmark(test.Benchmark): + + def benchmarkResamplePerformance(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution + num_samples = 1000 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + resample_time = _time_resampling( + self, data_np, target_dist, init_dist, num_to_sample=1000) + + self.report_benchmark( + iters=1000, wall_time=resample_time, name="benchmark_resample") + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 1a97a84b2cba13..eb2ceff893543f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -35,15 +36,19 @@ class ScanDatasetTest(test.TestCase): - def _count(self, start, step): - return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( - scan_ops.scan(start, lambda state, _: (state + step, state))) + def _counting_dataset(self, start, scan_fn): + return dataset_ops.Dataset.from_tensors(0).repeat().apply( + scan_ops.scan(start, scan_fn)) def testCount(self): + def make_scan_fn(step): + return lambda state, _: (state + step, state) + start = array_ops.placeholder(dtypes.int32, shape=[]) step = array_ops.placeholder(dtypes.int32, shape=[]) take = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = self._count(start, step).take(take).make_initializable_iterator() + iterator = self._counting_dataset( + start, make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() with self.test_session() as sess: @@ -78,6 +83,37 @@ def testFibonacci(self): self.assertEqual(5, self.evaluate(next_element())) self.assertEqual(8, self.evaluate(next_element())) + def testSparseCount(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def make_scan_fn(step): + return lambda state, _: (_sparse(state.values[0] + step), state) + + start = array_ops.placeholder(dtypes.int32, shape=[]) + step = array_ops.placeholder(dtypes.int32, shape=[]) + take = 
array_ops.placeholder(dtypes.int64, shape=[]) + iterator = self._counting_dataset( + _sparse(start), + make_scan_fn(step)).take(take).make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + + for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), + (10, 2, 10), (10, -1, 10), + (10, -2, 10)]: + sess.run(iterator.initializer, + feed_dict={start: start_val, step: step_val, take: take_val}) + for expected, _ in zip( + itertools.count(start_val, step_val), range(take_val)): + self.assertEqual(expected, sess.run(next_element).values[0]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testChangingStateShape(self): # Test the fixed-point shape invariant calculations: start with # initial values with known shapes, and use a scan function that @@ -132,7 +168,7 @@ def _scan_fn(unused_state, unused_input_value): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) -class ScanDatasetSerialzationTest( +class ScanDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_dataset(self, num_elements): diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index e26cef8ec522c7..4148addf2878c9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -22,6 +22,7 @@ import sqlite3 +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -29,7 +30,7 @@ from tensorflow.python.platform import test -class SqlDatasetTest(test.TestCase): +class SqlDatasetTestBase(test.TestCase): def _createSqlDataset(self, output_types, num_repeats=1): dataset = readers.SqlDataset(self.driver_name, self.data_source_name, @@ -92,6 +93,9 @@ def setUp(self): conn.commit() conn.close() + +class SqlDatasetTest(SqlDatasetTestBase): + # Test that SqlDataset can read from a database table. 
def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, @@ -652,5 +656,27 @@ def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self): sess.run(get_next) +class SqlDatasetSerializationTest( + SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 5b04c5316cfbb7..086661adb76033 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -45,6 +45,27 @@ py_library( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_test( + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", ], ) @@ -187,12 +208,27 @@ py_library( ], ) +py_library( + name = "optimization", + srcs = ["optimization.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], srcs_version = "PY2AND3", deps = [ ":batching", + ":interleave_ops", ":scan_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -202,6 +238,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) @@ -345,6 +382,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":optimization", ":prefetching_ops", ":readers", ":resampling", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 2152bcde84aae6..b9393de4e90ae2 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -364,7 +364,7 @@ def __init__(self, with the structure of `dataset`. """ super(_RestructuredDataset, self).__init__() - self._dataset = dataset + self._input_dataset = dataset if not allow_unsafe_cast: # Validate that the types are compatible. 
@@ -408,7 +408,7 @@ def __init__(self, self._output_classes = output_classes def _as_variant_tensor(self): - return self._dataset._as_variant_tensor() # pylint: disable=protected-access + return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access @property def output_classes(self): @@ -466,14 +466,14 @@ def _apply_fn(dataset): class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches_t = ops.convert_to_tensor( - num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") self._drop_remainder_t = ops.convert_to_tensor( drop_remainder, dtype=dtypes.bool, name="drop_remainder") @@ -483,12 +483,12 @@ def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, def _as_variant_tensor(self): # pylint: disable=protected-access input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset( + return gen_dataset_ops.map_and_batch_dataset_v2( input_resource, self._map_func.captured_inputs, f=self._map_func, batch_size=self._batch_size_t, - num_parallel_batches=self._num_parallel_batches_t, + num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -511,8 +511,9 @@ def output_types(self): def map_and_batch(map_func, batch_size, - num_parallel_batches=1, - drop_remainder=False): + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -528,21 +529,37 @@ def map_and_batch(map_func, nested structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the - number of batches to create in parallel. On one hand, higher values can - help mitigate the effect of stragglers. On the other hand, higher values - can increase contention if CPU is scarce. - drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the - last batch should be dropped in case its size is smaller than desired; - the default behavior is not to drop the smaller batch. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. 
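To make the new parallelism argument concrete, a minimal sketch of applying `map_and_batch` with `num_parallel_calls`, the element-level knob this change introduces; the map function and numbers are illustrative, not part of the patch:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.contrib.data.map_and_batch(
        map_func=lambda x: x * x,   # illustrative map function
        batch_size=10,
        num_parallel_calls=4))      # process up to 4 elements in parallel

# Supplying both knobs at once is rejected:
#   map_and_batch(..., num_parallel_batches=2, num_parallel_calls=4)
# raises ValueError, since the two arguments are mutually exclusive.
```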
If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. """ + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches, drop_remainder) + num_parallel_calls, drop_remainder) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 0531f9cbb9da6e..ea229b5b27b117 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -33,6 +34,35 @@ from tensorflow.python.ops import math_ops +def group_by_reducer(key_func, reducer): + """A transformation that groups elements and performs a reduction. + + This transformation maps element of a dataset to a key using `key_func` and + groups the elements by key. The `reducer` is used to process each group; its + `init_func` is used to initialize state for each group when it is created, the + `reduce_func` is used to update the state every time an element is mapped to + the matching group, and the `finalize_func` is used to map the final state to + an output value. + + Args: + key_func: A function mapping a nested structure of tensors + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to a scalar `tf.int64` tensor. + reducer: An instance of `Reducer`, which captures the reduction logic using + the `init_func`, `reduce_func`, and `finalize_func` functions. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. 
+ """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return GroupByReducerDataset(dataset, key_func, reducer) + + return _apply_fn + + def group_by_window(key_func, reduce_func, window_size=None, @@ -227,6 +257,250 @@ def output_types(self): return self._output_types +class GroupByReducerDataset(dataset_ops.Dataset): + """A `Dataset` that groups its input and performs a reduction.""" + + def __init__(self, input_dataset, key_func, reducer): + """See `group_by_reducer()` for details.""" + super(GroupByReducerDataset, self).__init__() + + self._input_dataset = input_dataset + + self._make_key_func(key_func, input_dataset) + self._make_init_func(reducer.init_func) + self._make_reduce_func(reducer.reduce_func, input_dataset) + self._make_finalize_func(reducer.finalize_func) + + def _make_key_func(self, key_func, input_dataset): + """Make wrapping Defun for key_func.""" + + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) + def tf_key_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + # pylint: disable=protected-access + if dataset_ops._should_unpack_args(nested_args): + ret = key_func(*nested_args) + # pylint: enable=protected-access + else: + ret = key_func(nested_args) + ret = ops.convert_to_tensor(ret) + if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar(): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) + return ret + + self._key_func = tf_key_func + self._key_func.add_to_graph(ops.get_default_graph()) + + def _make_init_func(self, init_func): + """Make wrapping Defun for init_func.""" + + @function.Defun(dtypes.int64) + def tf_init_func(key): + """A wrapper for Defun that facilitates shape inference.""" + key.set_shape([]) + ret = init_func(key) + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + self._state_classes = sparse.get_classes(ret) + self._state_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._state_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + self._init_func = tf_init_func + self._init_func.add_to_graph(ops.get_default_graph()) + + def _make_reduce_func(self, reduce_func, input_dataset): + """Make wrapping Defun for reduce_func.""" + + # Iteratively rerun the reduce function until reaching a fixed point on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + # Create a list in which `tf_reduce_func` will store the new shapes. 
+ flat_new_state_shapes = [] + + @function.Defun(*(nest.flatten( + sparse.as_dense_types( + self._state_types, self._state_classes)) + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) + def tf_reduce_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes)) + + nest.flatten( + sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes))): + arg.set_shape(shape) + + pivot = len(nest.flatten(self._state_shapes)) + nested_state_args = nest.pack_sequence_as(self._state_types, + args[:pivot]) + nested_state_args = sparse.deserialize_sparse_tensors( + nested_state_args, self._state_types, self._state_shapes, + self._state_classes) + nested_input_args = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + nested_input_args = sparse.deserialize_sparse_tensors( + nested_input_args, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) + + ret = reduce_func(nested_state_args, nested_input_args) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + # Extract shape information from the returned values. + flat_new_state = nest.flatten(ret) + flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state]) + + # Extract and validate type information from the returned values. + for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)): + if t.dtype != dtype: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, + nest.pack_sequence_as(self._state_types, + [t.dtype for t in flat_new_state]))) + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, + [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + # Use the private method that will execute `tf_reduce_func` but delay + # adding it to the graph in case we need to rerun the function. 
+ tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access + + flat_state_shapes = nest.flatten(self._state_shapes) + weakened_state_shapes = [ + old.most_specific_compatible_shape(new) + for old, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for old_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if old_shape.ndims is not None and ( + weakened_shape.ndims is None or + old_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._reduce_func = tf_reduce_func + self._reduce_func.add_to_graph(ops.get_default_graph()) + + def _make_finalize_func(self, finalize_func): + """Make wrapping Defun for finalize_func.""" + + @function.Defun(*(nest.flatten( + sparse.as_dense_types(self._state_types, self._state_classes)))) + def tf_finalize_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes))): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(self._state_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, self._state_types, self._state_shapes, + self._state_classes) + + ret = finalize_func(nested_args) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + self._output_classes = sparse.get_classes(ret) + self._output_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._output_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + self._finalize_func = tf_finalize_func + self._finalize_func.add_to_graph(ops.get_default_graph()) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.group_by_reducer_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._key_func.captured_inputs, + self._init_func.captured_inputs, + self._reduce_func.captured_inputs, + self._finalize_func.captured_inputs, + key_func=self._key_func, + init_func=self._init_func, + reduce_func=self._reduce_func, + finalize_func=self._finalize_func, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + class GroupByWindowDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a windowed reduction.""" @@ -336,3 +610,30 @@ def _as_variant_tensor(self): sparse.as_dense_types(self.output_types, self.output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + +class Reducer(object): + """A reducer is used for reducing a set of elements. 
+ + A reducer is represented as a tuple of the three functions: + 1) initialization function: key => initial state + 2) reduce function: (old state, input) => new state + 3) finalization function: state => result + """ + + def __init__(self, init_func, reduce_func, finalize_func): + self._init_func = init_func + self._reduce_func = reduce_func + self._finalize_func = finalize_func + + @property + def init_func(self): + return self._init_func + + @property + def reduce_func(self): + return self._reduce_func + + @property + def finalize_func(self): + return self._finalize_func diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 812a50ecbf1053..be66fbac50753c 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation @@ -240,3 +241,47 @@ def select_dataset(logits, seed): (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) return DirectedInterleaveDataset(selector_input, datasets) + + +def choose_from_datasets(datasets, choice_dataset): + """Creates a dataset that deterministically chooses elements from `datasets`. + + For example, given the following datasets: + + ```python + datasets = [tf.data.Dataset.from_tensors("foo").repeat(), + tf.data.Dataset.from_tensors("bar").repeat(), + tf.data.Dataset.from_tensors("baz").repeat()] + + # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. + choice_dataset = tf.data.Dataset.range(3).repeat(3) + + result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset) + ``` + + The elements of `result` will be: + + ``` + "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" + ``` + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between + `0` and `len(datasets) - 1`. + + Returns: + A dataset that interleaves elements from `datasets` according to the values + of `choice_dataset`. + + Raises: + TypeError: If the `datasets` or `choice_dataset` arguments have the wrong + type. 
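Putting the new `group_by_reducer` transformation and the `Reducer` class together, a minimal sketch that sums integers per parity group; it assumes both names are exposed through `tf.contrib.data` like the other transformations in this patch, and the dataset and values are illustrative:

```python
import numpy as np
import tensorflow as tf

reducer = tf.contrib.data.Reducer(
    init_func=lambda _: np.int64(0),         # key -> initial state
    reduce_func=lambda state, x: state + x,  # (state, element) -> new state
    finalize_func=lambda state: state)       # state -> output element

dataset = tf.data.Dataset.range(10).apply(
    tf.contrib.data.group_by_reducer(
        key_func=lambda x: x % 2,            # must be a scalar tf.int64
        reducer=reducer))

# The resulting dataset yields one element per key: the sum of the even
# numbers (20) and the sum of the odd numbers (25), in some group order.
```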
+ """ + if not (choice_dataset.output_types == dtypes.int64 + and choice_dataset.output_shapes.is_compatible_with( + tensor_shape.scalar()) + and choice_dataset.output_classes == ops.Tensor): + raise TypeError("`choice_dataset` must be a dataset of scalar " + "`tf.int64` tensors.") + return DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index d736029fb035e5..0d71be66018eee 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -16,10 +16,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training import saver +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import session_run_hook def make_saveable_from_iterator(iterator): @@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator): return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access -class _Saveable(saver.BaseSaverBuilder.SaveableObject): +class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject): """SaveableObject for saving/restoring iterator state.""" def __init__(self, iterator_resource): serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) specs = [ - saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "", - iterator_resource.name + "-state") + saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "", + iterator_resource.name + "-state") ] super(_Saveable, self).__init__(iterator_resource, specs, iterator_resource.name) @@ -75,3 +77,182 @@ def __init__(self, iterator_resource): def restore(self, restored_tensors, unused_restored_shapes): with ops.colocate_with(self.op): return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) + + +class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): + """Checkpoints input pipeline state every N steps or seconds. + + This hook saves the state of the iterators in the `Graph` so that when + training is resumed the input pipeline continues from where it left off. + This could potentially avoid overfitting in certain pipelines where the + number of training steps per eval are small compared to the dataset + size or if the training pipeline is pre-empted. + + Differences from `CheckpointSaverHook`: + 1. Saves only the input pipelines in the "iterators" collection and not the + global variables or other saveable objects. + 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary. + + Example of checkpointing the training pipeline: + + ```python + est = tf.estimator.Estimator(model_fn) + while True: + est.train( + train_input_fn, + hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)], + steps=train_steps_per_eval) + # Note: We do not pass the hook here. + metrics = est.evaluate(eval_input_fn) + if should_stop_the_training(metrics): + break + ``` + + This hook should be used if the input pipeline state needs to be saved + separate from the model checkpoint. Doing so may be useful for a few reasons: + 1. The input pipeline checkpoint may be large, if there are large shuffle + or prefetch buffers for instance, and may bloat the checkpoint size. + 2. 
If the input pipeline is shared between training and validation, restoring + the checkpoint during validation may override the validation input + pipeline. + + For saving the input pipeline checkpoint alongside the model weights use + @{tf.contrib.data.make_saveable_from_iterator} directly to create a + `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however, + that you will need to be careful not to restore the training iterator during + eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS + collector when building the eval graph. + """ + + def __init__(self, estimator): + """Initializes a `CheckpointInputPipelineHook`. + + Args: + estimator: Estimator. + + Raises: + ValueError: One of `save_steps` or `save_secs` should be set. + ValueError: At most one of saver or scaffold should be set. + """ + # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or + # of the form "input__.ckpt" for distributed pipelines. + # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is + # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix + # to be different to avoid conflicts with the model checkpoint. + + # pylint: disable=protected-access + checkpoint_prefix = "input" + if estimator._config.num_worker_replicas > 1: + # Distributed setting. + suffix = "_{}_{}".format(estimator._config.task_type, + estimator._config.task_id) + checkpoint_prefix += suffix + # pylint: enable=protected-access + + # We use a composition paradigm instead of inheriting from + # `CheckpointSaverHook` because `Estimator` does an `isinstance` check + # to check whether a `CheckpointSaverHook` is already present in the list + # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook` + # would thwart this behavior. This hook checkpoints *only the iterators* + # and not the graph variables. + self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook( + estimator.model_dir, + save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access + save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access + checkpoint_basename=checkpoint_prefix + ".ckpt") + + # Name for the protocol buffer file that will contain the list of most + # recent checkpoints stored as a `CheckpointState` protocol buffer. + # This file, kept in the same directory as the checkpoint files, is + # automatically managed by the `Saver` to keep track of recent checkpoints. + # The default name used by the `Saver` for this file is "checkpoint". Here + # we use the name "checkpoint_" so that in case the + # `checkpoint_dir` is the same as the model checkpoint directory, there are + # no conflicts during restore. + self._latest_filename = "checkpoint_" + checkpoint_prefix + self._first_run = True + + def begin(self): + # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` + # collection if no `Saver` or `Scaffold` is provided. 
+ # pylint: disable=protected-access + if (self._checkpoint_saver_hook._saver is None and + self._checkpoint_saver_hook._scaffold is None): + iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) + saveables = [_Saveable(i) for i in iterators] + self._checkpoint_saver_hook._saver = _CustomSaver(saveables, + self._latest_filename) + # pylint: enable=protected-access + self._checkpoint_saver_hook.begin() + + def _restore_or_save_initial_ckpt(self, session): + # Ideally this should be run in after_create_session but is not for the + # following reason: + # Currently there is no way of enforcing an order of running the + # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook` + # is run *after* this hook. That is troublesome because + # 1. If a checkpoint exists and this hook restores it, the initializer hook + # will override it. + # 2. If no checkpoint exists, this hook will try to save an initialized + # iterator which will result in an exception. + # + # As a temporary fix we enter the following implicit contract between this + # hook and the _DatasetInitializerHook. + # 1. The _DatasetInitializerHook initializes the iterator in the call to + # after_create_session. + # 2. This hook saves the iterator on the first call to `before_run()`, which + # is guaranteed to happen after `after_create_session()` of all hooks + # have been run. + + # Check if there is an existing checkpoint. If so, restore from it. + # pylint: disable=protected-access + latest_checkpoint_path = saver_lib.latest_checkpoint( + self._checkpoint_saver_hook._checkpoint_dir, + latest_filename=self._latest_filename) + if latest_checkpoint_path: + self._checkpoint_saver_hook._get_saver().restore(session, + latest_checkpoint_path) + else: + # The checkpoint saved here is the state at step "global_step". + # Note: We do not save the GraphDef or MetaGraphDef here. + global_step = session.run(self._checkpoint_saver_hook._global_step_tensor) + self._checkpoint_saver_hook._save(session, global_step) + self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step) + # pylint: enable=protected-access + + def before_run(self, run_context): + if self._first_run: + self._restore_or_save_initial_ckpt(run_context.session) + self._first_run = False + return self._checkpoint_saver_hook.before_run(run_context) + + def after_run(self, run_context, run_values): + self._checkpoint_saver_hook.after_run(run_context, run_values) + + def end(self, session): + self._checkpoint_saver_hook.end(session) + + +class _CustomSaver(saver_lib.Saver): + """`Saver` with a different default `latest_filename`. + + This is used in the `CheckpointInputPipelineHook` to avoid conflicts with + the model ckpt saved by the `CheckpointSaverHook`. 
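As a companion to the hook above, its docstring points to @{tf.contrib.data.make_saveable_from_iterator} for saving iterator state alongside the model weights; a minimal sketch of that pattern (the dataset and variable names are illustrative):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()

# Register the iterator state so an ordinary Saver checkpoints it too.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

saver = tf.train.Saver()  # now saves and restores the iterator position
```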
+ """ + + def __init__(self, var_list, latest_filename): + super(_CustomSaver, self).__init__(var_list) + self._latest_filename = latest_filename + + def save(self, + sess, + save_path, + global_step=None, + latest_filename=None, + meta_graph_suffix="meta", + write_meta_graph=True, + write_state=True, + strip_default_attrs=False): + return super(_CustomSaver, self).save( + sess, save_path, global_step, latest_filename or self._latest_filename, + meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py new file mode 100644 index 00000000000000..30a993b1f7056b --- /dev/null +++ b/tensorflow/contrib/data/python/ops/iterator_ops_test.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for experimental iterator_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import training_util + + +class CheckpointInputPipelineHookTest(test.TestCase): + + @staticmethod + def _model_fn(features, labels, mode, config): + del labels + del mode + del config + global_step = training_util.get_or_create_global_step() + update_global_step_op = global_step.assign_add(1) + latest_feature = variables.Variable( + 0, name='latest_feature', dtype=dtypes.int64) + store_latest_feature_op = latest_feature.assign(features) + ops.add_to_collection('my_vars', global_step) + ops.add_to_collection('my_vars', latest_feature) + return model_fn.EstimatorSpec( + mode='train', + train_op=control_flow_ops.group( + [update_global_step_op, store_latest_feature_op]), + loss=constant_op.constant(2.0)) + + def _read_vars(self, model_dir): + """Returns (global_step, latest_feature).""" + with ops.Graph().as_default() as g: + ckpt_path = saver_lib.latest_checkpoint(model_dir) + meta_filename = ckpt_path + '.meta' + saver_lib.import_meta_graph(meta_filename) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + saver.restore(sess, ckpt_path) + return sess.run(ops.get_collection('my_vars')) + + def _build_iterator_saver_hook(self, est): + return iterator_ops.CheckpointInputPipelineHook(est) + + def testReturnDatasetFromInputFn(self): + + def _input_fn(): + return 
dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testBuildIteratorInInputFn(self): + + def _input_fn(): + ds = dataset_ops.Dataset.range(10) + iterator = ds.make_one_shot_iterator() + return iterator.get_next() + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testDoNotRestore(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + # Hook not provided, input pipeline was not restored. + est.train(_input_fn, steps=2) + self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1)) + + def testRaiseErrorIfNoIterator(self): + + def _input_fn(): + return constant_op.constant(1, dtype=dtypes.int64) + + est = estimator.Estimator(model_fn=self._model_fn) + + with self.assertRaises(ValueError): + est.train( + _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py new file mode 100644 index 00000000000000..cad41bce2961f2 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +def optimize(optimizations=None): + """A transformation that applies optimizations. + + Args: + optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying + optimizations to use. 
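A minimal sketch of applying the new `optimize` transformation, mirroring the `OptimizeDatasetTest` cases earlier in this patch (the optimization name `"map_and_batch_fusion"` comes from those tests):

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import optimization

dataset = (tf.data.Dataset.range(10)
           .map(lambda x: x * x)
           .batch(10)
           .apply(optimization.optimize(["map_and_batch_fusion"])))

# With no argument the default set of optimizations is applied, and an
# explicit empty list applies none:
#   dataset.apply(optimization.optimize())
#   dataset.apply(optimization.optimize([]))
```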
If not specified, the default set of optimizations + is applied. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return OptimizeDataset(dataset, optimizations) + + return _apply_fn + + +class OptimizeDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and applies optimizations.""" + + def __init__(self, input_dataset, optimizations): + """See `optimize()` for details.""" + super(OptimizeDataset, self).__init__() + self._input_dataset = input_dataset + if optimizations is None: + optimizations = [] + self._optimizations = ops.convert_to_tensor( + optimizations, dtype=dtypes.string, name="optimizations") + + def _as_variant_tensor(self): + return gen_dataset_ops.optimize_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._optimizations, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index bbb808fbd77300..f938153f5f8c8b 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,16 +17,18 @@ from __future__ import division from __future__ import print_function +import collections import csv -from math import ceil import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,9 +36,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import string_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -68,7 +68,7 @@ def _is_valid_float(str_val, float_dtype): return False -def _infer_type(str_val, na_value, prev_type, float_dtype): +def _infer_type(str_val, na_value, prev_type): """Given a string, infers its tensor type. Infers the type of a value by picking the least 'permissive' type possible, @@ -79,29 +79,34 @@ def _infer_type(str_val, na_value, prev_type, float_dtype): na_value: Additional string to recognize as a NA/NaN CSV value. prev_type: Type previously inferred based on values of this column that we've seen up till now. - float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type - to parse float strings as. Returns: Inferred dtype. 
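To make the widening ladder concrete, the expected inferences for a few sample field values, consistent with the CSV type-inference tests earlier in this patch:

```python
# Least-permissive dtype that can hold each string, widening only as needed:
#   "1"           -> tf.int32
#   "2147483649"  -> tf.int64    # 2**31 + 1 does not fit in int32
#   "1.0"         -> tf.float32
#   "4e40"        -> tf.float64  # overflows float32
#   "abc"         -> tf.string
#   "" or the na_value string    -> keeps the previously inferred type
#
# Inference is monotone: once a column has been widened (say, to float32),
# later narrower values such as "3" keep the wider type.
```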
""" if str_val in ("", na_value): + # If the field is null, it gives no extra information about its type return prev_type - if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32): - return dtypes.int32 + type_list = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] # list of types to try, ordered from least permissive to most - if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32, - dtypes.int64): - return dtypes.int64 + type_functions = [ + _is_valid_int32, + _is_valid_int64, + lambda str_val: _is_valid_float(str_val, dtypes.float32), + lambda str_val: _is_valid_float(str_val, dtypes.float64), + lambda str_val: True, + ] # Corresponding list of validation functions - if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string: - return float_dtype + for i in range(len(type_list)): + validation_fn = type_functions[i] + if validation_fn(str_val) and (prev_type is None or + prev_type in type_list[:i + 1]): + return type_list[i] - return dtypes.string - -def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment): +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): + """Generator that yields rows of CSV file(s) in order.""" for fn in filenames: with file_io.FileIO(fn, "r") as f: rdr = csv.reader( @@ -112,9 +117,6 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, next(rdr) # Skip header lines for csv_row in rdr: - if comment is not None and csv_row[0].startswith(comment): - continue # Skip comment lines - if len(csv_row) != num_cols: raise ValueError( "Problem inferring types: CSV row has different number of fields " @@ -123,22 +125,21 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, - na_value, header, comment, float_dtype, - num_rows_for_inference, select_columns): + na_value, header, num_rows_for_inference, + select_columns): """Infers column types from the first N valid CSV records of files.""" if select_columns is None: select_columns = range(num_cols) inferred_types = [None] * len(select_columns) for i, csv_row in enumerate( - _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment)): + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)): if num_rows_for_inference is not None and i >= num_rows_for_inference: break for j, col_index in enumerate(select_columns): inferred_types[j] = _infer_type(csv_row[col_index], na_value, - inferred_types[j], float_dtype) + inferred_types[j]) # Replace None's with a default type inferred_types = [t or dtypes.string for t in inferred_types] @@ -198,6 +199,112 @@ def _get_sorted_col_indices(select_columns, column_names): return result +def _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): + """Optionally shuffle and repeat dataset, as requested.""" + if num_epochs != 1 and shuffle: + # Use shuffle_and_repeat for perf + return dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif shuffle: + return dataset.shuffle(shuffle_buffer_size, shuffle_seed) + elif num_epochs != 1: + return dataset.repeat(num_epochs) + return dataset + + +def make_tf_record_dataset( + file_pattern, + batch_size, + parser_fn=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=None, + shuffle_seed=None, + prefetch_buffer_size=None, + num_parallel_reads=None, + 
num_parallel_parser_calls=None, + drop_final_batch=False): + """Reads and optionally parses TFRecord files into a dataset. + + Provides common functionality such as batching, optional parsing, shuffling, + and performant defaults. + + Args: + file_pattern: List of files or patterns of TFRecord file paths. + See @{tf.gfile.Glob} for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + parser_fn: (Optional.) A function accepting string input to parse + and process the record contents. This function must map records + to components of a fixed shape, so they may be batched. By + default, uses the record contents unmodified. + num_epochs: (Optional.) An int specifying the number of times this + dataset is repeated. If None (the default), cycles through the + dataset forever. + shuffle: (Optional.) A bool that indicates whether the input + should be shuffled. Defaults to `True`. + shuffle_buffer_size: (Optional.) Buffer size to use for + shuffling. A large buffer size ensures better shuffling, but + increases memory usage and startup time. + shuffle_seed: (Optional.) Randomization seed to use for shuffling. + prefetch_buffer_size: (Optional.) An int specifying the number of + feature batches to prefetch for performance improvement. + Defaults to auto-tune. Set to 0 to disable prefetching. + num_parallel_reads: (Optional.) Number of threads used to read + records from files. By default or if set to a value >1, the + results will be interleaved. + num_parallel_parser_calls: (Optional.) Number of parallel + records to parse in parallel. Defaults to an automatic selection. + drop_final_batch: (Optional.) Whether the last batch should be + dropped in case its size is smaller than `batch_size`; the + default behavior is not to drop the smaller batch. + + Returns: + A dataset, where each element matches the output of `parser_fn` + except it will have an additional leading `batch-size` dimension, + or a `batch_size`-length 1-D tensor of strings if `parser_fn` is + unspecified. + """ + files = dataset_ops.Dataset.list_files( + file_pattern, shuffle=shuffle, seed=shuffle_seed) + + if num_parallel_reads is None: + # Note: We considered auto-tuning this value, but there is a concern + # that this affects the mixing of records from different files, which + # could affect training convergence/accuracy, so we are defaulting to + # a constant for now. + num_parallel_reads = 24 + dataset = core_readers.TFRecordDataset( + files, num_parallel_reads=num_parallel_reads) + + if shuffle_buffer_size is None: + # TODO(josh11b): Auto-tune this value when not specified + shuffle_buffer_size = 10000 + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + if parser_fn is None: + if drop_final_batch: + dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) + else: + dataset = dataset.batch(batch_size) + else: + # TODO(josh11b): if num_parallel_parser_calls is None, use some function + # of num cores instead of map_and_batch's default behavior of one batch. 
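Reading `_maybe_shuffle_and_repeat` back as plain `tf.data` calls, the three non-trivial cases expand as below (a sketch on a toy dataset; `shuffle_ops` is the contrib module already imported by this file).

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import shuffle_ops

ds = tf.data.Dataset.range(10)
num_epochs, buffer_size, seed = 3, 10000, 42

# Shuffle and repeat requested together: use the fused op for performance.
fused = ds.apply(shuffle_ops.shuffle_and_repeat(buffer_size, num_epochs, seed))

# Only shuffle requested (num_epochs == 1).
shuffled = ds.shuffle(buffer_size, seed)

# Only repeat requested (shuffle=False).
repeated = ds.repeat(num_epochs)
```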
+ dataset = dataset.apply(batching.map_and_batch( + parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, + drop_remainder=drop_final_batch)) + + if prefetch_buffer_size is None: + prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE + if prefetch_buffer_size == 0: + return dataset + else: + return dataset.prefetch(buffer_size=prefetch_buffer_size) + + def make_csv_dataset( file_pattern, batch_size, @@ -209,7 +316,6 @@ def make_csv_dataset( use_quote_delim=True, na_value="", header=True, - comment=None, num_epochs=None, shuffle=True, shuffle_buffer_size=10000, @@ -218,7 +324,6 @@ def make_csv_dataset( num_parallel_reads=1, num_parallel_parser_calls=2, sloppy=False, - default_float_type=dtypes.float32, num_rows_for_inference=100, ): """Reads CSV files into a dataset. @@ -231,8 +336,8 @@ def make_csv_dataset( Args: file_pattern: List of files or patterns of file paths containing CSV records. See @{tf.gfile.Glob} for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. column_names: An optional list of strings that corresponds to the CSV columns, in order. One per column of the input record. If this is not provided, infers the column names from the first row of the records. @@ -272,15 +377,11 @@ def make_csv_dataset( header: A bool that indicates whether the first rows of provided CSV files correspond to header lines with column names, and should not be included in the data. - comment: An optional character string that marks lines that should not be - parsed as csv records. If this is provided, all lines that start with - this character will not be parsed. num_epochs: An int specifying the number of times this dataset is repeated. If None, cycles through the dataset forever. shuffle: A bool that indicates whether the input should be shuffled. shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size - ensures better shuffling, but would increase memory usage and startup - time. + ensures better shuffling, but increases memory usage and startup time. shuffle_seed: Randomization seed to use for shuffling. prefetch_buffer_size: An int specifying the number of feature batches to prefetch for performance improvement. Recommended value is the number of @@ -294,8 +395,6 @@ def make_csv_dataset( produced is deterministic prior to shuffling (elements are still randomized if `shuffle=True`. Note that if the seed is set, then order of elements after shuffling is deterministic). Defaults to `False`. - default_float_type: Either `tf.float32` or `tf.float64`. If defaults are - not provided, float-like strings are interpreted to be this type. num_rows_for_inference: Number of rows of a file to use for type inference if record_defaults is not provided. If None, reads all the rows of all the files. Defaults to 100. 
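A minimal end-to-end sketch of `make_tf_record_dataset` with a parser. The file pattern and feature spec are hypothetical; the only requirement from the docstring is that `parser_fn` return fixed-shape components so records can be batched.

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import readers

# Hypothetical feature spec; adjust to the actual TFRecord contents.
feature_spec = {
    "label": tf.FixedLenFeature([], tf.int64),
    "image_raw": tf.FixedLenFeature([], tf.string),
}

def parse_fn(serialized):
  return tf.parse_single_example(serialized, feature_spec)

dataset = readers.make_tf_record_dataset(
    file_pattern="/tmp/data/train-*.tfrecord",  # assumed path
    batch_size=32,
    parser_fn=parse_fn,
    num_epochs=1,
    shuffle=True)

features = dataset.make_one_shot_iterator().get_next()
```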
@@ -317,8 +416,6 @@ def make_csv_dataset( dataset = dataset.shuffle(len(filenames), shuffle_seed) # Clean arguments; figure out column names and defaults - if comment is not None and len(comment) != 1: - raise ValueError("`comment` arg must be a single-character string or None") if column_names is None: if not header: @@ -341,8 +438,7 @@ def make_csv_dataset( # construction time column_defaults = _infer_column_defaults( filenames, len(column_names), field_delim, use_quote_delim, na_value, - header, comment, default_float_type, num_rows_for_inference, - select_columns) + header, num_rows_for_inference, select_columns) if select_columns is not None and len(column_defaults) != len(select_columns): raise ValueError( @@ -356,71 +452,189 @@ def make_csv_dataset( if label_name is not None and label_name not in column_names: raise ValueError("`label_name` provided must be one of the columns.") - # Define map and filter functions - def filter_fn(line): - return math_ops.not_equal(string_ops.substr(line, 0, 1), comment) - def filename_to_dataset(filename): - ds = core_readers.TextLineDataset(filename) - if header: - ds = ds.skip(1) - if comment is not None: - ds = ds.filter(filter_fn) - return ds + return CsvDataset( + filename, + record_defaults=column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + select_cols=select_columns, + header=header) - def decode_csv(line): - """Decodes CSV line into features. + def map_fn(*columns): + """Organizes columns into a features dictionary. Args: - line: String tensor corresponding to one csv record. + *columns: list of `Tensor`s corresponding to one csv record. Returns: - A dictionary of feature names to values for that particular record. If + An OrderedDict of feature names to values for that particular record. If label_name is provided, extracts the label feature to be returned as the second element of the tuple. 
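A small usage sketch of the reworked `make_csv_dataset`; the file pattern and the `"target"` label column are placeholders, and the files are assumed to have a header row so column names can be inferred.

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import readers

dataset = readers.make_csv_dataset(
    file_pattern="train-*.csv",   # assumed CSV files with a header row
    batch_size=128,
    label_name="target",          # assumed label column name
    num_epochs=1,
    shuffle=True)

features, label = dataset.make_one_shot_iterator().get_next()
# `features` is an OrderedDict mapping column names to batched tensors.
```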
""" - columns = parsing_ops.decode_csv( - line, - column_defaults, - field_delim=field_delim, - use_quote_delim=use_quote_delim, - na_value=na_value, - select_cols=select_columns, - ) - features = dict(zip(column_names, columns)) + features = collections.OrderedDict(zip(column_names, columns)) if label_name is not None: label = features.pop(label_name) return features, label return features - # Read files sequentially or in parallel + # Read files sequentially (if num_parallel_reads=1) or in parallel dataset = dataset.apply( interleave_ops.parallel_interleave( filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) - if num_epochs != 1 and shuffle: - # Use shuffle_and_repeat for perf - dataset = dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) - elif num_epochs != 1: - dataset = dataset.repeat(num_epochs) - - # Use map_and_batch for perf - # TODO(b/76425672): use num_parallel_calls for better performance tuning when - # that is added - dataset = dataset.apply( - batching.map_and_batch( - map_func=decode_csv, - batch_size=batch_size, - num_parallel_batches=int( - ceil(num_parallel_parser_calls / batch_size)))) + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + # Apply batch before map for perf, because map has high overhead relative + # to the size of the computation in each map + dataset = dataset.batch(batch_size=batch_size) + dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) + return dataset +_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB + + +class CsvDataset(dataset_ops.Dataset): + """A Dataset comprising lines from one or more CSV files.""" + + def __init__(self, + filenames, + record_defaults, + buffer_size=None, + header=False, + field_delim=",", + use_quote_delim=True, + na_value="", + select_cols=None): + """Creates a `CsvDataset` by reading and decoding CSV files. + + The elements of this dataset correspond to records from the file(s). + RFC 4180 format is expected for CSV files + (https://tools.ietf.org/html/rfc4180) + Note that we allow leading and trailing spaces with int or float field. + + + For example, suppose we have a file 'my_file0.csv' with four CSV columns of + different data types: + ``` + abcdefg,4.28E10,5.55E6,12 + hijklmn,-5.3E14,,2 + ``` + + We can construct a CsvDataset from it as follows: + ```python + dataset = tf.contrib.data.CsvDataset( + "my_file*.csv", + [tf.float32, # Required field, use dtype or empty tensor + tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 + tf.int32, # Required field, use dtype or empty tensor + ], + select_cols=[1,2,3] # Only parse last three columns + ) + ``` + + The expected output of its iterations is: + ```python + next = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(nxt)) + except tf.errors.OutOfRangeError: + break + + >> (4.28e10, 5.55e6, 12) + >> (-5.3e14, 0.0, 2) + ``` + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + record_defaults: A list of default values for the CSV fields. Each item in + the list is either a valid CSV `DType` (float32, float64, int32, int64, + string), or a `Tensor` object with one of the above types. 
One per + column of CSV data, with either a scalar `Tensor` default value for the + column if it is optional, or `DType` or empty `Tensor` if required. If + both this and `select_columns` are specified, these must have the same + lengths, and `column_defaults` is assumed to be sorted in order of + increasing column index. + buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes + to buffer while reading files. Defaults to 4MB. + header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) + have header line(s) that should be skipped when parsing. Defaults to + `False`. + field_delim: (Optional.) A `tf.string` scalar containing the delimiter + character that separates fields in a record. Defaults to `","`. + use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats + double quotation marks as regular characters inside of string fields + (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. + na_value: (Optional.) A `tf.string` scalar indicating a value that will + be treated as NA/NaN. + select_cols: (Optional.) A sorted list of column indices to select from + the input data. If specified, only this subset of columns will be + parsed. Defaults to parsing all columns. + """ + super(CsvDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + record_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in record_defaults + ] + self._record_defaults = ops.convert_n_to_tensor( + record_defaults, name="record_defaults") + self._buffer_size = convert.optional_param_to_tensor( + "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) + self._header = ops.convert_to_tensor( + header, dtype=dtypes.bool, name="header") + self._field_delim = ops.convert_to_tensor( + field_delim, dtype=dtypes.string, name="field_delim") + self._use_quote_delim = ops.convert_to_tensor( + use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") + self._na_value = ops.convert_to_tensor( + na_value, dtype=dtypes.string, name="na_value") + self._select_cols = convert.optional_param_to_tensor( + "select_cols", + select_cols, + argument_default=[], + argument_dtype=dtypes.int64, + ) + self._output_shapes = tuple( + tensor_shape.scalar() for _ in range(len(record_defaults))) + self._output_types = tuple(d.dtype for d in self._record_defaults) + self._output_classes = tuple( + ops.Tensor for _ in range(len(record_defaults))) + + def _as_variant_tensor(self): + # Constructs graph node for the dataset op. + return contrib_gen_dataset_ops.csv_dataset( + filenames=self._filenames, + record_defaults=self._record_defaults, + buffer_size=self._buffer_size, + header=self._header, + output_shapes=self._output_shapes, + field_delim=self._field_delim, + use_quote_delim=self._use_quote_delim, + na_value=self._na_value, + select_cols=self._select_cols, + ) + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + def make_batched_features_dataset(file_pattern, batch_size, features, @@ -480,8 +694,8 @@ def make_batched_features_dataset(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. 
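The iteration snippet in the docstring above switches between `next` and `nxt`; here is the same example as a self-contained script with consistent naming, assuming the class is exported as `tf.contrib.data.CsvDataset` as the docstring indicates.

```python
import tensorflow as tf

# Same sample file and defaults as in the CsvDataset docstring above.
dataset = tf.contrib.data.CsvDataset(
    "my_file*.csv",
    [tf.float32,                             # required float column
     tf.constant([0.0], dtype=tf.float32),   # optional column, defaults to 0.0
     tf.int32],                              # required int column
    select_cols=[1, 2, 3])

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  while True:
    try:
      print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
      break
```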
+ batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be @@ -537,16 +751,8 @@ def make_batched_features_dataset(file_pattern, dataset = dataset.map(lambda _, v: v) # Apply dataset repeat and shuffle transformations. - repeat_dataset = (num_epochs != 1) - if repeat_dataset and shuffle: - # Used fused shuffle_and_repeat operation for better performance - dataset = dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif repeat_dataset: - dataset = dataset.repeat(num_epochs) - elif shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) if drop_final_batch: dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) @@ -620,8 +826,8 @@ def read_batch_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index a182dddd38d23d..bad6edd5147d83 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -20,10 +20,12 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops @@ -50,79 +52,182 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. """ - def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - dist_estimation_batch_size = 32 target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") class_values_ds = dataset.map(class_func) + + # Get initial distribution. 
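Because the control flow that follows is fairly dense, a short usage sketch of `rejection_resample` on a toy, imbalanced dataset; the class counts and distributions are made up for illustration.

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import resampling

# A toy dataset of (class, value) pairs: class 0 is 99x more common.
classes = [0] * 99 + [1]
values = list(range(100))
dataset = tf.data.Dataset.from_tensor_slices((classes, values))

resampled = dataset.apply(resampling.rejection_resample(
    class_func=lambda c, v: c,
    target_dist=[0.5, 0.5],
    initial_dist=[0.99, 0.01],   # optional; estimated online when omitted
    seed=42))
# Each element of `resampled` is a (class, original_element) pair.
```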
if initial_dist is not None: initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") - acceptance_dist = _calculate_acceptance_probs(initial_dist_t, - target_dist_t) + acceptance_dist, prob_of_original = ( + _calculate_acceptance_probs_with_mixing(initial_dist_t, + target_dist_t)) initial_dist_ds = dataset_ops.Dataset.from_tensors( initial_dist_t).repeat() acceptance_dist_ds = dataset_ops.Dataset.from_tensors( acceptance_dist).repeat() + prob_of_original_ds = dataset_ops.Dataset.from_tensors( + prob_of_original).repeat() + else: + initial_dist_ds = _estimate_initial_dist_ds( + target_dist_t, class_values_ds) + acceptance_and_original_prob_ds = initial_dist_ds.map( + lambda initial: _calculate_acceptance_probs_with_mixing( + initial, target_dist_t)) + acceptance_dist_ds = acceptance_and_original_prob_ds.map( + lambda accept_prob, _: accept_prob) + prob_of_original_ds = acceptance_and_original_prob_ds.map( + lambda _, prob_original: prob_original) + filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, + class_values_ds, seed) + # Prefetch filtered dataset for speed. + filtered_ds = filtered_ds.prefetch(3) + + prob_original_static = _get_prob_original_static( + initial_dist_t, target_dist_t) if initial_dist is not None else None + if prob_original_static == 1: + return dataset_ops.Dataset.zip((class_values_ds, dataset)) + elif prob_original_static == 0: + return filtered_ds else: - num_classes = (target_dist_t.shape[0].value or - array_ops.shape(target_dist_t)[0]) - smoothing_constant = 10 - initial_examples_per_class_seen = array_ops.fill( - [num_classes], np.int64(smoothing_constant)) - - def update_estimate_and_tile(num_examples_per_class_seen, c): - updated_examples_per_class_seen, dist = _estimate_data_distribution( - c, num_examples_per_class_seen) - tiled_dist = array_ops.tile( - array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) - return updated_examples_per_class_seen, tiled_dist - - initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) - .apply(scan_ops.scan(initial_examples_per_class_seen, - update_estimate_and_tile)) - .apply(batching.unbatch())) - acceptance_dist_ds = initial_dist_ds.map( - lambda initial: _calculate_acceptance_probs(initial, target_dist_t)) - - def maybe_warn_on_large_rejection(accept_dist, initial_dist): - proportion_rejected = math_ops.reduce_sum( - (1 - accept_dist) * initial_dist) - return control_flow_ops.cond( - math_ops.less(proportion_rejected, .5), - lambda: accept_dist, - lambda: logging_ops.Print( # pylint: disable=g-long-lambda - accept_dist, [proportion_rejected, initial_dist, accept_dist], - message="Proportion of examples rejected by sampler is high: ", - summarize=100, - first_n=10)) - - acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, - initial_dist_ds)) - .map(maybe_warn_on_large_rejection)) - - def _gather_and_copy(class_val, acceptance_prob, data): - return (class_val, array_ops.gather(acceptance_prob, class_val), data) - current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( - (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) - filtered_ds = ( - current_probabilities_and_class_and_data_ds - .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) - return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + return interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds], + weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), + 
seed=seed) return _apply_fn -def _calculate_acceptance_probs(initial_probs, target_probs): - """Calculate the per-class acceptance rates. +def _get_prob_original_static(initial_dist_t, target_dist_t): + """Returns the static probability of sampling from the original. + + `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters + an Op that it isn't defined for. We have some custom logic to avoid this. + + Args: + initial_dist_t: A tensor of the initial distribution. + target_dist_t: A tensor of the target distribution. + + Returns: + The probability of sampling from the original distribution as a constant, + if it is a constant, or `None`. + """ + init_static = tensor_util.constant_value(initial_dist_t) + target_static = tensor_util.constant_value(target_dist_t) + + if init_static is None or target_static is None: + return None + else: + return np.min(target_static / init_static) + + +def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, + seed): + """Filters a dataset based on per-class acceptance probabilities. Args: - initial_probs: The class probabilities of the data. - target_probs: The desired class proportion in minibatches. + dataset: The dataset to be filtered. + acceptance_dist_ds: A dataset of acceptance probabilities. + initial_dist_ds: A dataset of the initial probability distribution, given or + estimated. + class_values_ds: A dataset of the corresponding classes. + seed: (Optional.) Python integer seed for the resampler. + Returns: - A list of the per-class acceptance probabilities. + A dataset of (class value, data) after filtering. + """ + def maybe_warn_on_large_rejection(accept_dist, initial_dist): + proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) + return control_flow_ops.cond( + math_ops.less(proportion_rejected, .5), + lambda: accept_dist, + lambda: logging_ops.Print( # pylint: disable=g-long-lambda + accept_dist, [proportion_rejected, initial_dist, accept_dist], + message="Proportion of examples rejected by sampler is high: ", + summarize=100, + first_n=10)) + + acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, + initial_dist_ds)) + .map(maybe_warn_on_large_rejection)) + + def _gather_and_copy(class_val, acceptance_prob, data): + return class_val, array_ops.gather(acceptance_prob, class_val), data + + current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( + (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) + filtered_ds = ( + current_probabilities_and_class_and_data_ds + .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) + return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + + +def _estimate_initial_dist_ds( + target_dist_t, class_values_ds, dist_estimation_batch_size=32, + smoothing_constant=10): + num_classes = (target_dist_t.shape[0].value or + array_ops.shape(target_dist_t)[0]) + initial_examples_per_class_seen = array_ops.fill( + [num_classes], np.int64(smoothing_constant)) + + def update_estimate_and_tile(num_examples_per_class_seen, c): + updated_examples_per_class_seen, dist = _estimate_data_distribution( + c, num_examples_per_class_seen) + tiled_dist = array_ops.tile( + array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) + return updated_examples_per_class_seen, tiled_dist - This method is based on solving the following analysis: + initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) + .apply(scan_ops.scan(initial_examples_per_class_seen, + 
update_estimate_and_tile)) + .apply(batching.unbatch())) + + return initial_dist_ds + + +def _get_target_to_initial_ratio(initial_probs, target_probs): + # Add tiny to initial_probs to avoid divide by zero. + denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) + return target_probs / denom + + +def _estimate_data_distribution(c, num_examples_per_class_seen): + """Estimate data distribution as labels are seen. + + Args: + c: The class labels. Type `int32`, shape `[batch_size]`. + num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, + containing counts. + + Returns: + num_examples_per_lass_seen: Updated counts. Type `int64`, shape + `[num_classes]`. + dist: The updated distribution. Type `float32`, shape `[num_classes]`. + """ + num_classes = num_examples_per_class_seen.get_shape()[0].value + # Update the class-count based on what labels are seen in batch. + num_examples_per_class_seen = math_ops.add( + num_examples_per_class_seen, math_ops.reduce_sum( + array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) + init_prob_estimate = math_ops.truediv( + num_examples_per_class_seen, + math_ops.reduce_sum(num_examples_per_class_seen)) + dist = math_ops.cast(init_prob_estimate, dtypes.float32) + return num_examples_per_class_seen, dist + + +def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): + """Calculates the acceptance probabilities and mixing ratio. + + In this case, we assume that we can *either* sample from the original data + distribution with probability `m`, or sample from a reshaped distribution + that comes from rejection sampling on the original distribution. This + rejection sampling is done on a per-class basis, with `a_i` representing the + probability of accepting data from class `i`. + + This method is based on solving the following analysis for the reshaped + distribution: Let F be the probability of a rejection (on any example). Let p_i be the proportion of examples in the data in class i (init_probs) @@ -151,39 +256,39 @@ def _calculate_acceptance_probs(initial_probs, target_probs): 0 <= t_i <= 1, sum_i(t_i) = 1 ``` - A solution for a_i in terms of the other variables is the following: ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` - """ - # Add tiny to initial_probs to avoid divide by zero. - denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) - ratio_l = target_probs / denom - # Calculate list of acceptance probabilities. - max_ratio = math_ops.reduce_max(ratio_l) - return ratio_l / max_ratio + If we try to minimize the amount of data rejected, we get the following: + M_max = max_i [ t_i / p_i ] + M_min = min_i [ t_i / p_i ] -def _estimate_data_distribution(c, num_examples_per_class_seen): - """Estimate data distribution as labels are seen. + The desired probability of accepting data if it comes from class `i`: + + a_i = (t_i/p_i - m) / (M_max - m) + + The desired probability of pulling a data element from the original dataset, + rather than the filtered one: + + m = M_min Args: - c: The class labels. Type `int32`, shape `[batch_size]`. - num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, - containing counts. + initial_probs: A Tensor of the initial probability distribution, given or + estimated. + target_probs: A Tensor of the corresponding classes. Returns: - num_examples_per_lass_seen: Updated counts. Type `int64`, shape - `[num_classes]`. - dist: The updated distribution. Type `float32`, shape `[num_classes]`. 
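A pure-numpy sketch of what `_estimate_data_distribution` computes: counts seeded with the smoothing constant and updated per batch of observed labels. This mirrors the logic only; it is not the TensorFlow ops.

```python
import numpy as np

num_classes, smoothing_constant = 2, 10
counts = np.full([num_classes], smoothing_constant, dtype=np.int64)

for batch in ([0, 0, 1, 0], [0, 1, 0, 0]):
  counts += np.bincount(batch, minlength=num_classes)
  dist = counts / counts.sum()
  print(dist)  # drifts toward the empirical class frequencies as data arrives
```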
+ (A 1D Tensor with the per-class acceptance probabilities, the desired + probability of pull from the original distribution.) """ - num_classes = num_examples_per_class_seen.get_shape()[0].value - # Update the class-count based on what labels are seen in batch. - num_examples_per_class_seen = math_ops.add( - num_examples_per_class_seen, math_ops.reduce_sum( - array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) - init_prob_estimate = math_ops.truediv( - num_examples_per_class_seen, - math_ops.reduce_sum(num_examples_per_class_seen)) - dist = math_ops.cast(init_prob_estimate, dtypes.float32) - return num_examples_per_class_seen, dist + ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) + max_ratio = math_ops.reduce_max(ratio_l) + min_ratio = math_ops.reduce_min(ratio_l) + + # Target prob to sample from original distribution. + m = min_ratio + + # TODO(joelshor): Simplify fraction, if possible. + a_i = (ratio_l - m) / (max_ratio - m) + return a_i, m \ No newline at end of file diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 60ef7efba4bb2b..e911ad0fa0541f 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.data.util import sparse from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_dataset_ops @@ -36,18 +37,22 @@ def __init__(self, input_dataset, initial_state, scan_func): self._input_dataset = input_dataset with ops.name_scope("initial_state"): + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. self._initial_state = nest.pack_sequence_as(initial_state, [ - ops.convert_to_tensor(t, name="component_%d" % i) + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor( + t, name="component_%d" % i) for i, t in enumerate(nest.flatten(initial_state)) ]) - # Compute initial values for the state shapes and types based on - # the initial state. These will be refined by running - # `tf_scan_func` one or more times below. - # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor. + # Compute initial values for the state classes, shapes and types based on + # the initial state. The shapes may be refined by running `tf_scan_func` one + # or more times below. + self._state_classes = sparse.get_classes(self._initial_state) self._state_shapes = nest.pack_sequence_as( self._initial_state, - [t.shape for t in nest.flatten(self._initial_state)]) + [t.get_shape() for t in nest.flatten(self._initial_state)]) self._state_types = nest.pack_sequence_as( self._initial_state, [t.dtype for t in nest.flatten(self._initial_state)]) @@ -62,67 +67,102 @@ def __init__(self, input_dataset, initial_state, scan_func): need_to_rerun = True while need_to_rerun: - flat_state_shapes = nest.flatten(self._state_shapes) - flat_state_types = nest.flatten(self._state_types) - - # Create a list in which `tf_scan_func` will store the s + # Create a list in which `tf_scan_func` will store the new shapes. 
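A quick numeric check of the mixing formulas in `_calculate_acceptance_probs_with_mixing` above, `m = min_i(t_i / p_i)` and `a_i = (t_i / p_i - m) / (M_max - m)`, in plain numpy with a made-up two-class distribution.

```python
import numpy as np

init_probs = np.array([0.9, 0.1])     # p_i: observed class distribution
target_probs = np.array([0.5, 0.5])   # t_i: desired class distribution

ratio = target_probs / init_probs     # t_i / p_i
m = ratio.min()                       # probability of sampling the original
a = (ratio - m) / (ratio.max() - m)   # per-class acceptance probabilities

print(m)  # ~0.556: keep the original stream a bit over half the time
print(a)  # [0., 1.]: in the rejection-sampled branch, drop class 0 and
          # accept every class-1 example
```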
flat_new_state_shapes = [] - @function.Defun(*(flat_state_types + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) + @function.Defun(*(nest.flatten( + sparse.as_dense_types( + self._state_types, self._state_classes)) + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) def tf_scan_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the state and input_dataset. - # TODO(b/69424092): Check that neither inputs nor outputs are sparse. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, - flat_state_shapes + nest.flatten(dense_shapes)): + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes)) + + nest.flatten( + sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes))): arg.set_shape(shape) - pivot = len(flat_state_shapes) - old_state = nest.pack_sequence_as(self._initial_state, args[:pivot]) - input_value = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - - ret = scan_func(old_state, input_value) + pivot = len(nest.flatten(self._state_shapes)) + print(self._state_classes) + nested_state_args = nest.pack_sequence_as(self._state_types, + args[:pivot]) + nested_state_args = sparse.deserialize_sparse_tensors( + nested_state_args, self._state_types, self._state_shapes, + self._state_classes) + print(input_dataset.output_classes) + nested_input_args = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + nested_input_args = sparse.deserialize_sparse_tensors( + nested_input_args, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) + + ret = scan_func(nested_state_args, nested_input_args) if not isinstance(ret, collections.Sequence) or len(ret) != 2: raise TypeError("The scan function must return a pair comprising the " "new state and the output value.") + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) new_state, output_value = ret - flat_new_state = [ - ops.convert_to_tensor(t) for t in nest.flatten(new_state) - ] - flat_output_value = [ - ops.convert_to_tensor(t) for t in nest.flatten(output_value) - ] + # Extract and validate class information from the returned values. + for t, clazz in zip( + nest.flatten(new_state), nest.flatten(self._state_classes)): + if not isinstance(t, clazz): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, + nest.pack_sequence_as( + self._state_types, + [type(t) for t in nest.flatten(new_state)]))) + self._output_classes = sparse.get_classes(output_value) # Extract shape information from the returned values. - flat_new_state_shapes.extend([t.shape for t in flat_new_state]) + flat_new_state_shapes.extend( + [t.get_shape() for t in nest.flatten(new_state)]) self._output_shapes = nest.pack_sequence_as( - output_value, [t.shape for t in flat_output_value]) + output_value, [t.get_shape() for t in nest.flatten(output_value)]) # Extract and validate type information from the returned values. 
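For context on what `tf_scan_func` wraps, a minimal usage sketch of the `scan` transformation computing a running sum with a dense state; the change above additionally allows the state to contain `SparseTensor`s.

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import scan_ops

# scan_func maps (old_state, input) -> (new_state, output).
dataset = tf.data.Dataset.range(5).apply(
    scan_ops.scan(tf.constant(0, dtype=tf.int64),
                  lambda state, x: (state + x, state + x)))
# Produces the running sums 0, 1, 3, 6, 10.
```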
- for t, dtype in zip(flat_new_state, flat_state_types): + for t, dtype in zip( + nest.flatten(new_state), nest.flatten(self._state_types)): if t.dtype != dtype: raise TypeError( "The element types for the new state must match the initial " "state. Expected %s; got %s." % - (self._state_types, nest.pack_sequence_as( - self._state_types, [t.dtype for t in flat_new_state]))) - self._output_classes = nest.pack_sequence_as( - output_value, [ops.Tensor for _ in flat_output_value]) + (self._state_types, + nest.pack_sequence_as( + self._state_types, + [t.dtype for t in nest.flatten(new_state)]))) self._output_types = nest.pack_sequence_as( - output_value, [t.dtype for t in flat_output_value]) - - return flat_new_state + flat_output_value + output_value, [t.dtype for t in nest.flatten(output_value)]) + + # Serialize any sparse tensors. + new_state = nest.pack_sequence_as(new_state, [ + t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) + ]) + output_value = nest.pack_sequence_as(output_value, [ + t for t in nest.flatten( + sparse.serialize_sparse_tensors(output_value)) + ]) + return nest.flatten(new_state) + nest.flatten(output_value) # Use the private method that will execute `tf_scan_func` but delay # adding it to the graph in case we need to rerun the function. tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + flat_state_shapes = nest.flatten(self._state_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) @@ -150,7 +190,7 @@ def _as_variant_tensor(self): input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access return gen_dataset_ops.scan_dataset( input_t, - nest.flatten(self._initial_state), + nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), self._scan_func.captured_inputs, f=self._scan_func, output_types=nest.flatten( diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index d39172039683fe..3cbaab5affd739 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -136,8 +136,8 @@ def _apply_fn(dataset): def bytes_produced_stats(tag): """Records the number of bytes produced by each element of the input dataset. - To consume the statistics, associate a `StatsAggregator` with an iterator - over the output dataset. + To consume the statistics, associate a `StatsAggregator` with the output + dataset. Args: tag: String. All statistics recorded by the returned transformation will @@ -158,8 +158,8 @@ def _apply_fn(dataset): def latency_stats(tag): """Records the latency of producing each element of the input dataset. - To consume the statistics, associate a `StatsAggregator` with an iterator - over the output dataset. + To consume the statistics, associate a `StatsAggregator` with the output + dataset. Args: tag: String. 
All statistics recorded by the returned transformation will diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 2038a8fc742a0c..da66e86f29460c 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -21,11 +21,11 @@ py_library( srcs = ["values.py"], visibility = ["//tensorflow:internal"], deps = [ + ":input_ops", ":prefetching_ops_v2", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", - "//tensorflow/python:checkpointable", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device_util", "//tensorflow/python:distribute", @@ -33,6 +33,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", "@six_archive//:six", ], ) @@ -42,6 +43,7 @@ gpu_py_test( srcs = ["values_test.py"], additional_deps = [ ":mirrored_strategy", + ":multi_worker_test_base", ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python/data/ops:dataset_ops", @@ -57,6 +59,9 @@ gpu_py_test( "//tensorflow/python/eager:test", "//tensorflow/python/estimator:model_fn", ], + tags = [ + "no_pip", + ], ) py_library( @@ -81,6 +86,19 @@ py_library( ], ) +py_library( + name = "multi_worker_strategy", + srcs = ["multi_worker_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + py_library( name = "one_device_strategy", srcs = ["one_device_strategy.py"], @@ -133,6 +151,7 @@ py_library( ":one_device_strategy", ":tpu_strategy", "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", @@ -216,6 +235,24 @@ gpu_py_test( ], ) +py_library( + name = "multi_worker_test_base", + testonly = 1, + srcs = ["multi_worker_test_base.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python/eager:test", + ], +) + py_library( name = "step_fn", srcs = ["step_fn.py"], @@ -274,6 +311,7 @@ gpu_py_test( tags = [ "multi_and_single_gpu", "no_pip", + "noguitar", # TODO(b/109653107): test is flaky. 
], ) @@ -408,6 +446,7 @@ py_library( srcs = ["cross_tower_utils.py"], srcs_version = "PY2AND3", deps = [ + ":values", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", @@ -415,6 +454,24 @@ py_library( ], ) +gpu_py_test( + name = "cross_tower_utils_test", + srcs = ["cross_tower_utils_test.py"], + additional_deps = [ + ":combinations", + ":cross_tower_utils", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "no_pip", + ], +) + py_library( name = "cross_tower_ops", srcs = ["cross_tower_ops.py"], @@ -433,24 +490,24 @@ py_library( ], ) -py_test( +gpu_py_test( name = "cross_tower_ops_test", srcs = ["cross_tower_ops_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ + additional_deps = [ ":combinations", ":cross_tower_ops", ":values", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", - "@absl_py//absl/testing:parameterized", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", ], ) @@ -479,3 +536,52 @@ gpu_py_test( "//tensorflow/python/data/ops:iterator_ops", ], ) + +py_library( + name = "input_ops", + srcs = ["input_ops.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + ], +) + +gpu_py_test( + name = "input_ops_test", + srcs = ["input_ops_test.py"], + additional_deps = [ + ":input_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:errors", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:io_ops", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python:util", + ], + tags = [ + "no_pip", + ], +) + +gpu_py_test( + name = "keras_test", + srcs = ["keras_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/estimator:keras", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/keras", + ], + tags = [ + "multi_and_single_gpu", + "notsan", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 946310aa6fc210..98e7228f24d8ca 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -41,16 +41,20 @@ def testOptimizer(self, optimizer): from collections import OrderedDict import sys +import types +import unittest from absl.testing import parameterized +import six -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import one_device_strategy -from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib +from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib +from tensorflow.contrib.distribute.python import 
tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adam +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import gradient_descent from tensorflow.python.util import tf_inspect @@ -66,29 +70,35 @@ def generate(combinations): combinations: a list of dictionaries created using combine() and times(). Restrictions: - -- there should always be a "mode" argument. Accepted values are "eager" - and "graph". + -- the "mode" argument can be either "eager" or "graph". It's "graph" by + default. -- arguments of the test method must match by name to get the corresponding - value of the combination. Tests must accept all arguments (except "mode", - which is optional). - -- distribution argument is special. It is meant for passing instances of - DistributionStrategy. Each instance is to be passed as `(, - )` tuple, where is the number of required - GPUs. If the required number of GPUs for the DistributionStrategy isn't - available then the test case is going to be skipped. + value of the combination. Tests must accept all arguments except the + "mode", "required_tpu" and "required_gpus". + -- "distribution" argument is special and optional. It is meant for passing + instances of DistributionStrategy. Each instance is to be passed as via + `NamedDistribution`. If using "distribution", "required_gpus" and + "required_tpu" should be specified via the NamedDistribution instance, + rather than as separate arguments. + -- "required_tpu" argument is special and optional. If not `None`, then the + test will be skipped if TPUs aren't available. + -- "required_gpus" argument is special and optional. If not `None`, then the + test will be skipped if the specified number of GPUs aren't available. Returns: - a decorator that will cause the test method to be run under the specified - conditions. + a decorator that will cause the test method or the test class to be run + under the specified conditions. Raises: - ValueError - if "mode" argument wasn't either "eager" or "graph. + ValueError - if "mode" argument wasn't either "eager" or "graph" or if other + arguments were not accepted by the test method. """ - def decorator(test_function): + def decorator(test_method_or_class): """The decorator to be returned.""" # Generate good test names that can be used with --test_filter. + named_combinations = [] for combination in combinations: # We use OrderedDicts in `combine()` and `times()` to ensure stable # order of keys in each dictionary. 
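A sketch of a test written against the updated restrictions: the decorator supplies `x`, and `mode` defaults to "graph" so the test may omit it. Imports mirror those used by `combinations_test.py`; the class and method names are made up.

```python
from absl.testing import parameterized

from tensorflow.contrib.distribute.python import combinations
from tensorflow.python.eager import test


class SquareTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(mode=["graph", "eager"], x=[2, 3]))
  def testSquare(self, x):
    # "mode" is handled by the decorator; the test only accepts "x".
    self.assertEqual(x * x, x ** 2)


if __name__ == "__main__":
  test.main()
```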
@@ -99,59 +109,96 @@ def decorator(test_function): "".join(filter(str.isalnum, str(value)))) for key, value in combination.items() ]) - combination.update({"testcase_name": "_test{}".format(name)}) - - @parameterized.named_parameters(*combinations) - def decorated(self, **kwargs): - """A wrapped test method that sets up `test_function`.""" - assert "mode" in kwargs - mode = kwargs["mode"] - - if "distribution" in kwargs: - distribution = kwargs["distribution"] - kwargs["distribution"] = distribution.strategy - if distribution.required_tpu and not TPU_TEST: - self.skipTest("Test requires a TPU, but it's not available.") - if not distribution.required_tpu and TPU_TEST: - self.skipTest("Test that doesn't require a TPU.") - - if not distribution.required_gpus: - if GPU_TEST: - self.skipTest("Test that doesn't require GPUs.") - elif context.num_gpus() < distribution.required_gpus: - self.skipTest( - "{} GPUs are not available for this test. {} GPUs are available". - format(distribution.required_gpus, context.num_gpus())) - - requested_arguments = tf_inspect.getfullargspec(test_function).args - missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( - set(requested_arguments + ["mode"])) - if missing_arguments: - raise ValueError("The test is missing arguments {} .".format( - missing_arguments)) - - kwargs_to_pass = {} - for arg in requested_arguments: - if arg == "self": - kwargs_to_pass[arg] = self - else: - kwargs_to_pass[arg] = kwargs[arg] - - if mode == "eager": - with context.eager_mode(), ops.Graph().as_default(): - test_function(**kwargs_to_pass) - elif mode == "graph": - with context.graph_mode(), ops.Graph().as_default(): - test_function(**kwargs_to_pass) - else: - raise ValueError( - "'mode' has to be either 'eager' or 'graph' and not {}".format( - mode)) + named_combinations.append( + OrderedDict( + list(combination.items()) + [("testcase_name", + "_test{}".format(name))])) + + if isinstance(test_method_or_class, type): + class_object = test_method_or_class + class_object._test_method_ids = test_method_ids = {} + for name, test_method in six.iteritems(class_object.__dict__.copy()): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + isinstance(test_method, types.FunctionType)): + delattr(class_object, name) + methods = {} + parameterized._update_class_dict_for_param_test_case( + class_object.__name__, methods, test_method_ids, name, + parameterized._ParameterizedTestIter( + _augment_with_special_arguments(test_method), + named_combinations, parameterized._NAMED, name)) + for method_name, method in six.iteritems(methods): + setattr(class_object, method_name, method) + + return class_object + else: + test_method = _augment_with_special_arguments(test_method_or_class) + return parameterized.named_parameters(*named_combinations)(test_method) - return decorated return decorator +def _augment_with_special_arguments(test_method): + def decorated(self, **kwargs): + """A wrapped test method that treats some arguments in a special way.""" + mode = kwargs.pop("mode", "graph") + + distribution = kwargs.pop("distribution", None) + required_tpu = kwargs.pop("required_tpu", False) + required_gpus = kwargs.pop("required_gpus", None) + + if distribution: + assert required_gpus is None, ( + "Do not use `required_gpus` and `distribution` together.") + assert required_tpu is False, ( + "Do not use `required_tpu` and `distribution` together.") + kwargs["distribution"] = distribution.strategy + required_gpus = distribution.required_gpus + required_tpu = distribution.required_tpu 
+ + if required_tpu and not TPU_TEST: + self.skipTest("Test requires a TPU, but it's not available.") + if not required_tpu and TPU_TEST: + self.skipTest("Test that doesn't require a TPU.") + + if not required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(required_gpus, context.num_gpus())) + + # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu` + # that the user might have specified. `kwargs` still has `mode`, which + # the test is allowed to accept or ignore. + requested_arguments = tf_inspect.getfullargspec(test_method).args + missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( + set(requested_arguments + ["mode"])) + if missing_arguments: + raise ValueError("The test is missing arguments {} .".format( + missing_arguments)) + + kwargs_to_pass = {} + for arg in requested_arguments: + if arg == "self": + kwargs_to_pass[arg] = self + else: + kwargs_to_pass[arg] = kwargs[arg] + + if mode == "eager": + with ops.Graph().as_default(), context.eager_mode(): + test_method(**kwargs_to_pass) + elif mode == "graph": + with ops.Graph().as_default(), context.graph_mode(): + test_method(**kwargs_to_pass) + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph' and not {}".format( + mode)) + return decorated + + def combine(**kwargs): """Generate combinations based on its keyword arguments. @@ -159,7 +206,8 @@ def combine(**kwargs): can be computed using `times()`. Args: - **kwargs: keyword arguments of form `option=[possibilities, ...]`. + **kwargs: keyword arguments of form `option=[possibilities, ...]` + or `option=the_only_possibility`. Returns: a list of dictionaries for each combination. 
Keys in the dictionaries are @@ -178,6 +226,8 @@ def combine(**kwargs): key = first[0] values = first[1] + if not isinstance(values, list): + values = [values] return [ OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) @@ -239,9 +289,9 @@ def __repr__(self): class NamedDistribution(object): """Translates DistributionStrategy and its data into a good name.""" - def __init__(self, name, distribution, required_gpus=None, + def __init__(self, name, distribution_fn, required_gpus=None, required_tpu=False): - self._distribution = distribution + self._distribution_fn = distribution_fn self._name = name self._required_gpus = required_gpus self._required_tpu = required_tpu @@ -251,7 +301,7 @@ def __repr__(self): @property def strategy(self): - return self._distribution + return self._distribution_fn() @property def required_gpus(self): @@ -262,21 +312,31 @@ def required_tpu(self): return self._required_tpu +# pylint: disable=g-long-lambda +default_strategy = NamedDistribution( + "Default", + lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + required_gpus=None) one_device_strategy = NamedDistribution( - "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), - None) -tpu_strategy = NamedDistribution( - "TPU", tpu_strategy.TPUStrategy(), required_tpu=True) + "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), + required_gpus=None) +tpu_strategy_single_iteration = NamedDistribution( + "TPUSingleIteration", + lambda: tpu_lib.TPUStrategy(iterations_per_step=1), + required_tpu=True) +tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) +# Note that we disable prefetching for testing since prefetching makes +# the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1) -mirrored_strategy_without_prefetch = NamedDistribution( - "MirroredCPUAndGPUNoPrefetch", - mirrored_strategy.MirroredStrategy( - ["/gpu:0", "/cpu:0"], prefetch_on_device=False), 1) + lambda: mirrored_lib.MirroredStrategy( + ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2) + lambda: mirrored_lib.MirroredStrategy( + ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + required_gpus=2) adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py index 219b24160f3902..86aa48cea889c6 100644 --- a/tensorflow/contrib/distribute/python/combinations_test.py +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from collections import OrderedDict +from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.python.eager import test @@ -41,6 +42,15 @@ def test_combine(self): "b": 3 }], combinations.combine(a=[1, 2], b=[2, 3])) + def test_combine_single_parameter(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 2, + "b": 2 + }], combinations.combine(a=[1, 2], b=2)) + def test_add(self): self.assertEqual( [{ @@ -111,5 +121,28 @@ def test_overlapping_keys(self): _ = combinations.times(c1, c2) +@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1])) +class 
CombineTheTestSuite(parameterized.TestCase): + + def test_add_things(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def test_add_things_one_more(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def not_a_test(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + def _test_but_private(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + # Check that nothing funny happens to a non-callable that starts with "_test". + test_member = 0 + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index cff717db80f0bd..a411b880e80291 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -53,15 +53,14 @@ def _validate_value_destination_pairs(value_destination_pairs): return True +# TODO(yuefengz): consider calling this function in the caller of CrossTowerOps. def _get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) elif isinstance(destinations, six.string_types): - return [device_util.canonicalize(destinations)] + return [device_util.resolve(destinations)] else: - return [ - device_util.canonicalize(destination) for destination in destinations - ] + return [device_util.resolve(destination) for destination in destinations] def _devices_match(left, right): @@ -78,12 +77,12 @@ def _all_devices_match(value_destination_pairs): return True -def _simple_broadcast(tensor, destinations): +def _simple_broadcast(value, destinations): index = {} devices = _get_devices_from(destinations) for d in devices: - with ops.device(d): - index[d] = array_ops.identity(tensor) + index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + value, d) return value_lib.Mirrored(index) @@ -99,7 +98,9 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, continue count += len(v_list) # Sum within each device before aggregating across devices. - v = math_ops.add_n(v_list) + # TODO(yuefengz): Check whether it helps to use accumulation_fn here. 
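A schematic of the `_simple_reduce` flow being modified here, with plain Python numbers standing in for per-device tensors; the real code additionally places each step on a device and routes `IndexedSlices` through the new `cross_tower_utils` helpers:

```python
def simple_reduce_sketch(per_device_values, method_string):
    """per_device_values: dict mapping a device string to its list of values."""
    all_values, count = [], 0
    for values_on_device in per_device_values.values():
        # Sum within each device before aggregating across devices.
        all_values.append(sum(values_on_device))
        count += len(values_on_device)
    reduced = sum(all_values)  # aggregation on the reduce-to device
    if method_string == "mean":
        reduced /= count
    elif method_string != "sum":
        raise ValueError("`method_string` must be 'sum' or 'mean'")
    return reduced


# Two "devices", three values in total: sum -> 12, mean -> 4.0.
print(simple_reduce_sketch({"/cpu:0": [1, 2], "/gpu:0": [9]}, "mean"))
```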
+ v = cross_tower_utils.aggregate_tensors_or_indexed_slices( + v_list, math_ops.add_n) else: count += 1 all_values.append(v) @@ -108,11 +109,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, with ops.device(reduce_to_device): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - if method_string == "sum": - reduced = accumulation_fn(all_values) - elif method_string == "mean": - reduced = accumulation_fn(all_values) / count - else: + reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( + all_values, accumulation_fn) + if method_string == "mean": + reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( + reduced, count) + elif method_string != "sum": raise ValueError("`method_string` must be 'sum' or 'mean'") return reduced @@ -445,10 +447,18 @@ def __init__(self, super(AllReduceCrossTowerOps, self).__init__() def _reduce(self, method_string, per_device_value, destinations): + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + per_device_value) if ((destinations is None or _devices_match(per_device_value, destinations)) - and not context.executing_eagerly()): + and not context.executing_eagerly() + and not contains_indexed_slices): return self._batch_all_reduce(method_string, [per_device_value])[0] else: + if contains_indexed_slices: + logging.log_first_n( + logging.WARN, + "Efficient allreduce is not supported for IndexedSlices.", 10) + devices = _get_devices_from(destinations or per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, @@ -456,14 +466,18 @@ def _reduce(self, method_string, per_device_value, destinations): return self.broadcast(reduced, devices) def _batch_reduce(self, method_string, value_destination_pairs): - if (_all_devices_match(value_destination_pairs) and - not context.executing_eagerly()): + all_devices_match = _all_devices_match(value_destination_pairs) + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + value_destination_pairs) + if (all_devices_match and not context.executing_eagerly() + and not contains_indexed_slices): return self._batch_all_reduce(method_string, [v[0] for v in value_destination_pairs]) else: - if not context.executing_eagerly(): + if not all_devices_match: logging.warning("Efficient batch_reduce is not supported if " "destinations are different.") + return [ self._reduce(method_string, t, destinations=v) for t, v in value_destination_pairs diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 7c7b0870887465..2a266326088def 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util def _make_per_device(values, devices): @@ -56,19 +57,46 @@ def _fake_mirrored(value, devices): {d: v for d, v in zip(devices, [value] * len(devices))}) +def _make_indexed_slices(values, indices, dense_shape, device): + with ops.device(device): + tensor = ops.IndexedSlices( + values=constant_op.constant(values), + indices=constant_op.constant(indices), + dense_shape=constant_op.constant(dense_shape)) + return tensor + + +def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): + return value_lib.Mirrored({ + d: 
_make_indexed_slices(values, indices, dense_shape, d) for d in devices + }) + + _cpu_device = "/device:CPU:0" class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): - def _assert_value_equal(self, left, right): + def _assert_indexed_slices_equal(self, left, right): + self.assertIsInstance(left, ops.IndexedSlices) + self.assertIsInstance(right, ops.IndexedSlices) + self.assertEqual(device_util.resolve(left.device), + device_util.resolve(right.device)) + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + def _assert_values_equal(self, left, right): if isinstance(left, list): for l, r in zip(left, right): - self._assert_value_equal(l, r) + self._assert_values_equal(l, r) else: self.assertEqual(type(left), type(right)) self.assertEqual(left.devices, right.devices) - if context.executing_eagerly(): + if isinstance(list(left._index.values())[0], ops.IndexedSlices): + for (d, v) in left._index.iteritems(): + self._assert_indexed_slices_equal(v, right._index[d]) + elif context.executing_eagerly(): self.assertEqual([v.numpy() for v in left._index.values()], list(right._index.values())) else: @@ -143,29 +171,29 @@ def testReductionAndBroadcast(self, cross_tower_ops, distribution): # test reduce() for destinations in all_destinations: - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce("mean", per_device, destinations=destinations), _fake_mirrored(mean, destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce( "mean", per_device_2, destinations=destinations), _fake_mirrored(mean_2, destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce("sum", per_device, destinations=destinations), _fake_mirrored(mean * len(devices), destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce( "sum", per_device_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations or per_device)) # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.batch_reduce( "mean", [(per_device, d1), (per_device_2, d2)]), [_fake_mirrored(mean, d1 or per_device), _fake_mirrored(mean_2, d2 or per_device_2)]) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.batch_reduce( "sum", [(per_device, d1), (per_device_2, d2)]), [_fake_mirrored(mean * len(devices), d1 or per_device), @@ -176,7 +204,7 @@ def testReductionAndBroadcast(self, cross_tower_ops, distribution): if destinations is None: continue else: - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.broadcast(constant_op.constant(1.), destinations), _fake_mirrored(1., destinations)) @@ -184,16 +212,14 @@ def testChooseAlgorithm(self): device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) self.assertEqual(result.all_reduce_alg, "hierarchical_copy") self.assertEqual(result.num_packs, 8) # if there are only 4 devices device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - 
self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) self.assertEqual(result.all_reduce_alg, "nccl") self.assertEqual(result.num_packs, 1) @@ -202,8 +228,7 @@ def testChooseAlgorithm(self): [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) self.assertEqual(result.all_reduce_alg, "hierarchical_copy") self.assertEqual(result.num_packs, 8) @@ -211,11 +236,85 @@ def testChooseAlgorithm(self): device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) self.assertEqual(result.all_reduce_alg, "nccl") self.assertEqual(result.num_packs, 1) + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testSimpleReduceWithIndexedSlices(self): + devices = ["/cpu:0", "/gpu:0"] + t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) + t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + result = cross_tower_ops_lib._simple_reduce(per_device, devices[0], + math_ops.add_n, "sum") + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices with and without duplicate indices. 
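The assertions around this test rely on an `IndexedSlices` with duplicate indices and its merged form describing the same dense tensor. A small NumPy check of that equivalence, using the same values and indices as the expected results below (illustrative only):

```python
import numpy as np


def densify(values, indices, dense_shape):
    dense = np.zeros(dense_shape)
    for value, index in zip(values, indices):
        dense[index] += value  # duplicate indices accumulate into the same row
    return dense


with_dups = densify([[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2])
without_dups = densify([[4., 6.], [5., 6.]], [1, 3], [5, 2])
assert np.allclose(with_dups, without_dups)
```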
+ total_with_dups = _make_indexed_slices( + [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0]) + total_without_dups = _make_indexed_slices( + [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) + self._assert_indexed_slices_equal(total_with_dups, result) + self._assert_indexed_slices_equal(total_without_dups, result) + + @combinations.generate(combinations.combine( + cross_tower_ops_instance=[ + combinations.NamedObject( + "ReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "AllReduceCrossTowerOps", + cross_tower_ops_lib.AllReduceCrossTowerOps()) + ], + method_string=["sum", "mean"], + batch_reduce=[True, False], + mode=["graph", "eager"], + required_gpus=1)) + def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, + method_string, batch_reduce): + devices = ["/cpu:0", "/gpu:0"] + dense_shape = [5, 2] + t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) + t1 = _make_indexed_slices( + [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + + if batch_reduce: + result = cross_tower_ops_instance.batch_reduce(method_string, + [(per_device, devices)]) + else: + result = cross_tower_ops_instance.reduce(method_string, per_device, + devices) + + total_indices_with_dups = [1, 1, 3] + total_indices_without_dups = [1, 3] + + if method_string == "sum": + total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] + total_values_without_dups = [[4., 6.], [5., 6.]] + else: + assert method_string == "mean" + total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] + total_values_without_dups = [[2., 3.], [2.5, 3.]] + + total_mirrored_with_dups = _make_mirrored_indexed_slices( + devices, total_values_with_dups, total_indices_with_dups, dense_shape) + total_mirrored_without_dups = _make_mirrored_indexed_slices( + devices, total_values_without_dups, total_indices_without_dups, + dense_shape) + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices, as well as when the duplicate indices are summed up. 
+ if batch_reduce: + total_mirrored_with_dups = [total_mirrored_with_dups] + total_mirrored_without_dups = [total_mirrored_without_dups] + + self._assert_values_equal(total_mirrored_with_dups, result) + self._assert_values_equal(total_mirrored_without_dups, result) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index fc04e2195f6d30..137fabf4c739bb 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -21,9 +21,11 @@ import collections as pycoll from tensorflow.contrib import nccl +from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -337,3 +339,46 @@ def unpack_small_tensors(tower_grads, packing): new_gv_list.insert(idx, gv[gi]) new_tower_grads.append(new_gv_list) return new_tower_grads + + +def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): + """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" + if any(isinstance(v, ops.IndexedSlices) for v in values): + return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access + else: + return accumulation_fn(values) + + +def divide_by_n_tensors_or_indexed_slices(value, n): + if isinstance(value, ops.IndexedSlices): + value = gradients_impl._HandleNestedIndexedSlices(value) # pylint: disable=protected-access + return ops.IndexedSlices( + value.values / n, value.indices, value.dense_shape) + else: + return value / n + + +def copy_tensor_or_indexed_slices_to_device(value, device): + with ops.device(device): + if isinstance(value, ops.IndexedSlices): + copied_values = array_ops.identity(value.values) + copied_indices = array_ops.identity(value.indices) + copied_shape = array_ops.identity(value.dense_shape) + result = ops.IndexedSlices(copied_values, copied_indices, copied_shape) + else: + result = array_ops.identity(value) + return result + + +def contains_indexed_slices(value): + """Check whether the value is `IndexedSlices` or contains `IndexedSlices`.""" + if isinstance(value, ops.IndexedSlices): + return True + elif isinstance(value, (list, tuple)) and value: + return any(contains_indexed_slices(v) for v in value) + elif isinstance(value, value_lib.DistributedValues): + return contains_indexed_slices(list(value._index.values())) # pylint: disable=protected-access + elif isinstance(value, value_lib.MapOutput): + return contains_indexed_slices(value.get()) + else: + return False diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py new file mode 100644 index 00000000000000..4ef8db681503dc --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -0,0 +1,152 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cross_tower_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util + + +class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): + + def _assert_values_equal(self, left, right): + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + @test_util.run_in_graph_and_eager_modes() + def testAggregateTensors(self): + t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes() + def testAggregateIndexedSlices(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes() + def testDivideTensor(self): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes() + def testDivideIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes() + def testIsIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_List(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + 
self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_Tuple(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_PerDevice(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_PerDeviceMapOutput(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({ + "/gpu:0": value_lib.MapOutput([t0]), + "/cpu:0": value_lib.MapOutput([t1])}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyTensor(self): + with ops.device("/cpu:0"): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyIndexedSlices(self): + with ops.device("/cpu:0"): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index b87224251ca384..2b05884b9b9347 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An example tf.keras model that is trained using MirroredStrategy.""" +"""An example of training tf.keras Model using MirroredStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from sys import argv + +import sys + import numpy as np import tensorflow as tf @@ -33,30 +35,37 @@ def input_fn(): def main(args): if len(args) < 2: - print('You must specify model_dir for checkpoints such as' - ' /tmp/tfkeras_example./') + print('You must specify model_dir for checkpoints such as' + ' /tmp/tfkeras_example/.') return - print('Using %s to store checkpoints.' 
% args[1]) - - strategy = tf.contrib.distribute.MirroredStrategy( - ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(train_distribute=strategy) - optimizer = tf.train.GradientDescentOptimizer(0.2) + model_dir = args[1] + print('Using %s to store checkpoints.' % model_dir) + # Define tf.keras Model. model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) model.add(tf.keras.layers.Dense(1, activation='sigmoid')) + # Compile tf.keras Model. + optimizer = tf.train.GradientDescentOptimizer(0.2) model.compile(loss='binary_crossentropy', optimizer=optimizer) model.summary() tf.keras.backend.set_learning_phase(True) + + # Define a DistributionStrategy and convert the tf.keras Model to a + # tf.Estimator that utilizes the DistributionStrategy. + strategy = tf.contrib.distribute.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + config = tf.estimator.RunConfig(train_distribute=strategy) keras_estimator = tf.keras.estimator.model_to_estimator( - keras_model=model, config=config, model_dir=args[1]) + keras_model=model, config=config, model_dir=model_dir) + # Train and evaluate the tf.Estimator. keras_estimator.train(input_fn=input_fn, steps=10) eval_result = keras_estimator.evaluate(input_fn=input_fn) print('Eval result: {}'.format(eval_result)) + if __name__ == '__main__': - tf.app.run(argv=argv) + tf.app.run(argv=sys.argv) diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py new file mode 100644 index 00000000000000..1f24f629479b6a --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_ops.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Input-pipeline utilities for Distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import readers +from tensorflow.python.data.util import nest +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging + +# TODO(priyag): Any other reader datasets to consider here? +_READER_DATASET_OPS = [ + "TextLineDataset", + "TFRecordDataset", + "FixedLengthRecordDataset" +] + + +# pylint: disable=protected-access +def auto_shard_dataset(dataset, num_shards, index): + """Shard the input pipeline by sharding the underlying list of files. + + Args: + dataset: A `tf.data.Dataset` instance, typically the result of a bunch of + dataset transformations. + num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of + shards operating in parallel. Same usage as in `Dataset.shard`. + index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. + Same usage as in `Dataset.shard`. 
+ + Returns: + A modified `Dataset` obtained by updating the pipeline sharded by the + files. + + Raises: + NotImplementedError: If we cannot automatically determine a good way to + shard the input dataset. + """ + + # TODO(priyag): Clone datasets instead of updating in place, similar to the + # clone method for TFRecordDataset. + def _auto_shard_impl(dataset, found_reader_op): + """Recursive implementation of auto sharding.""" + + if not found_reader_op: + # TODO(priyag): Make this check more robust by enforcing some common + # property on reader datasets. + if (isinstance(dataset, readers.TextLineDataset) or + isinstance(dataset, readers.FixedLengthRecordDataset)): + filenames_tensor = dataset._filenames + num_files = array_ops.size(filenames_tensor) + sharded_filenames_tensor = array_ops.gather( + filenames_tensor, math_ops.range(index, num_files, num_shards)) + dataset._filenames = sharded_filenames_tensor + return dataset + elif isinstance(dataset, readers.TFRecordDataset): + # `TFRecordDataset` needs to be handled separately than other readers + # because it converts filenames to a dataset first. Also, we clone it + # instead of updating in place because it has special logic in the + # constructor. Eventually we will change all cases to clone datasets + # instead of updating in-place. + return dataset._clone( + filenames=dataset._filenames.shard(num_shards, index)) + elif hasattr(dataset, "_map_func"): + # TODO(priyag): Make this check more robust by enforcing some common + # property on all map/flatmap/interleave datasets. + map_func_def = dataset._map_func.definition + for node in map_func_def.node_def: + if node.op in _READER_DATASET_OPS: + found_reader_op = True + break + elif node.op == "FlatMapDataset": + # TODO(priyag): Should this check for other map datasets? Should it + # be recursive? It is too specific to implementation of + # TFRecordDataset right now. + nested_func_name = node.attr["f"].func.name + nested_func = ops.get_default_graph()._functions[nested_func_name] + for nested_node in nested_func.definition.node_def: + if nested_node.op in _READER_DATASET_OPS: + found_reader_op = True + break + if found_reader_op: + break + if found_reader_op: + dataset._input_dataset = _auto_shard_impl( + dataset._input_dataset, found_reader_op) + return dataset + + # TODO(priyag): Make _input_dataset(s) a common property of all datasets to + # make this check more robust. + if hasattr(dataset, "_input_dataset"): + dataset._input_dataset = _auto_shard_impl( + dataset._input_dataset, found_reader_op) + if hasattr(dataset, "_dataset_to_concatenate"): + # Special case for `ConcatentateDataset`. We want to shard all input + # datasets. + dataset._dataset_to_concatenate = _auto_shard_impl( + dataset._dataset_to_concatenate, found_reader_op) + return dataset + + if hasattr(dataset, "_datasets"): + # Special case for `ZipDataset`. + dataset._datasets = nest.pack_sequence_as(dataset._datasets, [ + _auto_shard_impl(ds, found_reader_op) + for ds in nest.flatten(dataset._datasets) + ]) + return dataset + + if not found_reader_op: + tf_logging.warn( + "Could not find a standard reader in the input pipeline" + "(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)." + "Falling back to sharding the dataset anyway. Please verify" + "correctness of auto-sharding for your input.") + + # TODO(priyag): What do we want to do if the number of filenames is + # uneven in the number of shards? By default, this will just return as + # many items it can before throwing OutOfRangeError. 
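A rough sketch of the idea behind `auto_shard_dataset`, with hypothetical file names: shard the file list before the reader so each worker only opens its own files, falling back to element-level `Dataset.shard` when no known reader is found (the warning case above):

```python
import tensorflow as tf

num_shards, index = 2, 0
filenames = ["train-%05d.tfrecord" % i for i in range(10)]  # hypothetical names

# What auto_shard_dataset tries to recover from an existing pipeline: shard the
# *file list* so each worker only opens its own files, then apply the reader.
dataset = (tf.data.Dataset.from_tensor_slices(filenames)
           .shard(num_shards, index)
           .flat_map(tf.data.TFRecordDataset))

# The element-level fallback, `dataset.shard(num_shards, index)`, instead keeps
# every element at position i with i % num_shards == index; for the file list
# above that selection rule picks the even-numbered files 0, 2, 4, 6, 8.
print([f for i, f in enumerate(filenames) if i % num_shards == index])
```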
+ # TODO(priyag): This will shard the filenames before any shuffling of the + # filename dataset. It might be desirable to shard after shuffling + # filenames? If so, how do we achieve that? + return dataset.shard(num_shards, index) + + return _auto_shard_impl(dataset=dataset, found_reader_op=False) diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py new file mode 100644 index 00000000000000..16179c3a4903c8 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_ops_test.py @@ -0,0 +1,265 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for input pipeline modifications for distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.distribute.python import input_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import errors +from tensorflow.python.lib.io import python_io +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class AutoShardDatasetTest(test.TestCase): + + def setUp(self): + super(AutoShardDatasetTest, self).setUp() + self._num_files = 10 + self._num_records = 4 + self._num_shards = 2 + self._shard_index = 0 + self._record_bytes = 10 + + def _record(self, r, f): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _text_line(self, r, f): + return compat.as_bytes("Text line %d of file %d" % (r, f)) + + def _fixed_length_record(self, r, f): + return compat.as_bytes(str((r * f) % 10) * self._record_bytes) + + def _createTFRecordFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + record = self._record(j, i) + writer.write(record) + writer.close() + return filenames + + def _createTextFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) + filenames.append(fn) + contents = [] + for j in range(self._num_records): + contents.append(self._text_line(j, i)) + if j + 1 != self._num_records or i == 0: + contents.append(b"\r\n") + contents = b"".join(contents) + + with open(fn, "wb") as f: + f.write(contents) + return filenames + + def _createFixedLengthRecordFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) + filenames.append(fn) + with open(fn, "wb") as f: + for j in range(self._num_records): + f.write(self._fixed_length_record(j, i)) + return 
filenames + + def _verifySimpleShardingOutput(self, dataset, record_fn): + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(record_fn(r, f), sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testTFRecordDataset(self): + dataset = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._record) + + def testFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._record) + + def testInterleave(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.interleave( + readers.TFRecordDataset, cycle_length=4, block_length=self._num_records) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Since block_length == num records in each file, the output will still + # contain records in order of files. + self._verifySimpleShardingOutput(dataset, self._record) + + def testParallelInterleave(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.apply(interleave_ops.parallel_interleave( + readers.TFRecordDataset, + cycle_length=4, + block_length=self._num_records)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Since block_length == num records in each file, the output will still + # contain records in order of files. + self._verifySimpleShardingOutput(dataset, self._record) + + def testListfiles(self): + filenames = self._createTFRecordFiles() + file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt" + dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + actual, expected = [], [] + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + actual.append(sess.run(next_element)) + expected.append(self._record(r, f)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self.assertAllEqual(expected, actual) + + def testComplexPipeline(self): + # Setup a complex input pipeline. + batch_size = 2 + num_epochs = 5 + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.shuffle(buffer_size=self._num_files) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = dataset.prefetch(buffer_size=batch_size) + dataset = dataset.shuffle(2 * self._num_files * self._num_records) + dataset = dataset.repeat(num_epochs) + dataset = dataset.apply(batching.map_and_batch( + lambda x: x, batch_size=batch_size)) + dataset = dataset.prefetch(buffer_size=None) + + # Auto shard. + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Verify output. 
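The verification that follows computes how many whole batches one shard should yield; with the sizes from `setUp` (10 files, 4 records each, 2 shards) and this test's batch size of 2 over 5 epochs, the arithmetic works out as:

```python
num_files, num_records, num_epochs = 10, 4, 5
num_shards, batch_size = 2, 2

records_per_shard = (num_files // num_shards) * num_records * num_epochs  # 100
num_iterations = (num_files * num_records * num_epochs) // (
    num_shards * batch_size)
assert num_iterations == records_per_shard // batch_size == 50
```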
+ iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + actual = [] + num_iterations = (self._num_files * self._num_records * num_epochs) // ( + self._num_shards * batch_size) + for _ in range(num_iterations): + actual.extend(sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + expected = [] + for f in range(0, self._num_files, self._num_shards): + for r in range(self._num_records): + expected.append(self._record(r, f)) + expected *= num_epochs + + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testZip(self): + dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset2 = readers.TextLineDataset(self._createTextFiles()) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f)) + self._verifySimpleShardingOutput(dataset, record_fn) + + def testConcat(self): + dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset2 = readers.TextLineDataset(self._createTextFiles()) + dataset = dataset1.concatenate(dataset2) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(self._record(r, f), sess.run(next_element)) + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(self._text_line(r, f), sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testTextLineReader(self): + dataset = readers.TextLineDataset(self._createTextFiles()) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._text_line) + + def testTextLineReaderWithFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles()) + dataset = dataset.flat_map(readers.TextLineDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._text_line) + + def testFixedLengthReader(self): + dataset = readers.FixedLengthRecordDataset( + self._createFixedLengthRecordFiles(), self._record_bytes) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._fixed_length_record) + + def testFixedLengthReaderWithFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createFixedLengthRecordFiles()) + dataset = dataset.flat_map( + lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._fixed_length_record) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py new file mode 100644 index 00000000000000..75ecd90dcffa7a --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -0,0 +1,148 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
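The new `keras_test.py` below and the reworked `simple_tfkeras_example.py` above exercise the same flow: compile a `tf.keras` model, put a `MirroredStrategy` into `RunConfig(train_distribute=...)`, and train through `model_to_estimator`. A condensed sketch of that flow using the 1.x-era APIs shown in this diff (devices and hyperparameters are illustrative):

```python
import tensorflow as tf


def train_with_mirrored_strategy(input_fn, model_dir):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(1, activation='sigmoid')])
    model.compile(loss='binary_crossentropy',
                  optimizer=tf.train.GradientDescentOptimizer(0.2))

    # The DistributionStrategy reaches the Estimator through RunConfig.
    strategy = tf.contrib.distribute.MirroredStrategy(
        ['/device:GPU:0', '/device:GPU:1'])
    config = tf.estimator.RunConfig(train_distribute=strategy)

    estimator = tf.keras.estimator.model_to_estimator(
        keras_model=model, config=config, model_dir=model_dir)
    estimator.train(input_fn=input_fn, steps=10)
    return estimator.evaluate(input_fn=input_fn)
```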
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Keras Sequential and Functional models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import keras as keras_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import test_util +from tensorflow.python.keras import testing_utils +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import rmsprop + +_RANDOM_SEED = 1337 +_TRAIN_SIZE = 200 +_INPUT_SIZE = (10,) +_NUM_CLASS = 2 + + +def simple_sequential_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) + model.add(keras.layers.Dropout(0.1)) + model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax')) + return model + + +def simple_functional_model(): + a = keras.layers.Input(shape=_INPUT_SIZE) + b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dropout(0.1)(b) + b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) + model = keras.models.Model(inputs=[a], outputs=[b]) + return model + + +def get_ds_train_input_fn(): + np.random.seed(_RANDOM_SEED) + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_train = keras.utils.to_categorical(y_train) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.batch(32) + return dataset + + +def get_ds_test_input_fn(): + np.random.seed(_RANDOM_SEED) + _, (x_test, y_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_test = keras.utils.to_categorical(y_test) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) + dataset = dataset.batch(32) + return dataset + + +class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): + + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), + 'keras_mirrored_strategy_test') + gfile.MakeDirs(self._base_dir) + self._config = run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) + + def tearDown(self): + writer_cache.FileWriterCache.clear() + if os.path.isdir(self._base_dir): + gfile.DeleteRecursively(self._base_dir) + + def test_train_functional_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_functional_model() + keras_model.compile( + 
loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + def test_train_sequential_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index e134fe34e10be4..5c056a7c73def2 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -44,13 +44,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.distributions_and_v1_optimizers(), combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True]), - combinations.combine(is_tpu=[False])) + - combinations.combine( - distribution=[combinations.tpu_strategy], - optimizer_fn=[combinations.adam_optimizer_v1_fn], - mode=["graph"], - use_callable_loss=[False], - is_tpu=[True])) + combinations.combine(is_tpu=[False])) + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=[ + combinations.adam_optimizer_v1_fn, + # TODO(isaprykin): Make Adam v2 work with while_loops + # and TPUs. 
+ ], + mode=["graph"], + use_callable_loss=[False], + is_tpu=[True])) def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, is_tpu): with distribution.scope(): @@ -101,7 +104,8 @@ def run_step(): distribution=[combinations.tpu_strategy], optimizer_fn=[ combinations.adam_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v1_fn + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn, ], mode=["graph"], is_tpu=[True])) @@ -171,13 +175,28 @@ def get_expected_variables(optimizer_fn, num_parameter_devices): set(created_variables)) @combinations.generate( - combinations.times(combinations.distributions_and_v1_optimizers(), - combinations.combine( - mode=["graph", "eager"], - momentum=[0.8, 0.9, 0.99], - renorm=[False, True]))) + combinations.times( + combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]), + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=["graph", "eager"], + is_tpu=[False], + # TODO(isaprykin): Allow False here. Currently subsequent + # towers will re-execute UPDATE_OPS of previous towers. + update_ops_in_cross_tower_mode=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + mode=["graph"], + is_tpu=[True], + update_ops_in_cross_tower_mode=[False]))) def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, - renorm): + renorm, is_tpu, + update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" with distribution.scope(): num_towers = len(distribution.worker_devices) @@ -185,27 +204,30 @@ def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, optimizer_fn, batch_per_epoch=num_towers, momentum=momentum, - renorm=renorm) + renorm=renorm, + update_ops_in_tower_mode=not update_ops_in_cross_tower_mode) - # Disable prefetching since that makes the specific input on each device - # to be non deterministic, and this test relies on specific input being - # on each device. + # Make sure prefetching is disabled since that makes the + # specific input on each device to be non deterministic, and + # this test relies on specific input being on each device. 
if isinstance(distribution, mirrored_strategy.MirroredStrategy): - distribution._prefetch_on_device = False + self.assertFalse(distribution._prefetch_on_device) iterator = distribution.distribute_dataset( dataset_fn).make_one_shot_iterator() def run_step(): - return control_flow_ops.group( - distribution.unwrap( - distribution.call_for_each_tower( - model_fn, - iterator.get_next(), - run_concurrently=batchnorm.built)) + - ops.get_collection(ops.GraphKeys.UPDATE_OPS)) + fetches = distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), + run_concurrently=batchnorm.built)) + if update_ops_in_cross_tower_mode: + fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + return control_flow_ops.group(fetches) if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -229,22 +251,40 @@ def averaged_batch_mean(i): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + @combinations.generate( combinations.times( combinations.combine( - distribution=[combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], - optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v2_fn], - loss_reduction=[losses_impl.Reduction.SUM, - losses_impl.Reduction.MEAN, - losses_impl.Reduction.SUM_OVER_BATCH_SIZE, - losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]), - combinations.combine(mode=["graph"], use_callable_loss=[True, False]) - + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + loss_reduction=[ + losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN, + losses_impl.Reduction.SUM_OVER_BATCH_SIZE, + losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS + ]), + combinations.times( + combinations.combine( + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + is_tpu=[False]), + combinations.combine( + mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + is_tpu=[True], + mode=["graph"], + use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, - use_callable_loss): + use_callable_loss, is_tpu): with distribution.scope(): all_vars = [] @@ -280,12 +320,13 @@ def run_step(): if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) run_step() - self.assertEqual(distribution.num_towers, len(all_vars)) v = all_vars[0] self.assertTrue(all([v is vi for vi in all_vars[1:]])) weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) @@ -312,6 +353,10 @@ def run_step(): # One of the mean loss reductions. 
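The expected values above follow the standard moving-average update for batch norm statistics, `moving_mean -= (moving_mean - batch_mean) * (1 - momentum)`, applied once per step with the batch mean averaged across towers. A toy version of that recurrence (the batch means here are made up):

```python
def update_moving_mean(moving_mean, batch_mean, momentum):
    # Same recurrence the test uses to build its expected values.
    return moving_mean - (moving_mean - batch_mean) * (1.0 - momentum)


m = 0.0
for batch_mean in [1.0, 2.0, 3.0]:  # per-step batch means (illustrative)
    m = update_moving_mean(m, batch_mean, momentum=0.9)
print(m)  # drifts slowly toward the recent batch means
```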
self.assertNear(weight, 2 + 10.6, 0.0001) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 6efd578a775da7..cef0a2907b85d2 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import contextlib import threading import six @@ -39,6 +40,16 @@ # TODO(josh11b): Replace asserts in this file with if ...: raise ... +@contextlib.contextmanager +def _enter_graph(g): + if context.executing_eagerly(): + with g.as_default(), context.eager_mode(): + yield + else: + with g.as_default(): + yield + + def _cpu_device(device): cpu_device = tf_device.DeviceSpec.from_string(device) cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) @@ -73,13 +84,13 @@ def __init__(self, assert len(set(devices)) == len(devices), ( "No duplicates allowed in `devices` argument.") # TODO(josh11b): Require at least 2 devices? - self._devices = devices - self._canonical_device_set = set( - [device_util.canonicalize(d) for d in devices]) + self._devices = [device_util.resolve(d) for d in devices] + self._canonical_device_set = set(self._devices) self._device_index = values.PerDevice( dict((d, i) for i, d in enumerate(devices))) self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device + # TODO(yuefengz): consider setting the default device. def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" @@ -107,13 +118,19 @@ def _create_variable(self, next_creator, *args, **kwargs): if i > 0: # Give replicas meaningful distinct names: var0name = index[devices[0]].name.split(":")[0] - kwargs["name"] = "%s/replica_%d" % (var0name, i) + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): - initial_value = index[devices[0]].value() + kwargs["initial_value"] = array_ops.identity( + index[devices[0]].value()) else: - initial_value = index[devices[0]].initial_value - kwargs["initial_value"] = array_ops.identity(initial_value) + def initial_value_fn(device=d): + with ops.device(device): + return array_ops.identity(index[devices[0]].initial_value) + kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) @@ -244,8 +261,15 @@ def _call_for_each_tower(self, fn, *args, **kwargs): {t.device: t.merge_args for t in threads}) merge_kwargs = values.regroup( {t.device: t.merge_kwargs for t in threads}) - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. 
+ mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) for t in threads: t.merge_result = values.select_device(t.device, merge_result) finally: @@ -321,7 +345,6 @@ def _update_non_slot(self, colocate_with, fn, *args, **kwargs): def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" - assert isinstance(destination, six.string_types) if isinstance(val, values.TowerLocalVariable): val = self.reduce(val.reduce_method, val, destinations=destination) with ops.device(destination): @@ -386,7 +409,9 @@ def _get_devices_from(self, colocate_with=None): # pylint: disable=protected-access return list(colocate_with._index.keys()) elif isinstance(colocate_with, six.string_types): - return [colocate_with] + return [device_util.resolve(colocate_with)] + elif isinstance(colocate_with, list): + return [device_util.resolve(d) for d in colocate_with] else: return colocate_with @@ -413,6 +438,7 @@ def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, self.merge_args = None self.merge_kwargs = None self.merge_result = None + self.captured_name_scope = None # We use a thread.Event for the main thread to signal when this # thread should start running (`should_run`), and another for # this thread to transfer control back to the main thread @@ -436,13 +462,13 @@ def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, self._variable_creator_stack = self.graph._variable_creator_stack[:] self._captured_var_scope = variable_scope.get_variable_scope() # Adding a "/" at end lets us re-enter this scope later. - self._captured_name_scope = self.graph.get_name_scope() - if self._captured_name_scope: - self._captured_name_scope += "/" + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" if self.tower_id > 0: - if not self._captured_name_scope: - self._captured_name_scope = "" - self._captured_name_scope += "tower_%d/" % self.tower_id + if not self._name_scope: + self._name_scope = "" + self._name_scope += "tower_%d/" % self.tower_id def run(self): # pylint: disable=protected-access @@ -455,10 +481,10 @@ def run(self): with self.coord.stop_on_exception(), \ context.context()._mode(self.context_mode), \ context.context().device_policy(self.context_device_policy), \ - self.graph.as_default(), \ + _enter_graph(self.graph), \ MirroredTowerContext(self.distribution, self.tower_id), \ ops.device(self.device), \ - ops.name_scope(self._captured_name_scope), \ + ops.name_scope(self._name_scope), \ variable_scope.variable_scope( self._captured_var_scope, reuse=self.tower_id > 0), \ variable_scope.variable_creator_scope(self.variable_creator_fn): @@ -484,6 +510,10 @@ def _merge_call(self, fn, *args, **kwargs): t.merge_fn = fn t.merge_args = args t.merge_kwargs = kwargs + t.captured_name_scope = t.graph.get_name_scope() + # Adding a "/" at end lets us re-enter this scope later. 
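The comment above relies on a 1.x graph-mode detail: passing `ops.name_scope` a string that ends with "/" re-enters that exact scope instead of creating a uniquified one. A small sketch of that behavior, assuming a 1.x-style `tensorflow` import:

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    with tf.name_scope("tower_1"):
        captured = g.get_name_scope() + "/"  # "tower_1/"
    # Re-entering via the captured string lands back in the same scope,
    # rather than in a new "tower_1_1" scope.
    with tf.name_scope(captured):
        c = tf.constant(1.0, name="c")
    assert c.op.name == "tower_1/c"
```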
+ if t.captured_name_scope: + t.captured_name_scope += "/" t.has_paused.set() t.should_run.wait() t.should_run.clear() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 6c5c055070c0fc..bccd278847e3c8 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -28,9 +28,12 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib @@ -116,7 +119,6 @@ def run_fn(device_id): self.assertEqual(expected, self.evaluate(unwrapped[0])) -@test_util.with_c_api class MirroredStrategyVariableCreationTest(test.TestCase): config = config_pb2.ConfigProto() @@ -370,22 +372,27 @@ def model_fn(device_id): expected_sum = 0.0 expected_mean = 0.0 for i, d in enumerate(dist.worker_devices): - # Test access within a device scope, should see different values. - with ops.device(d): - v_sum_value = self.evaluate(ret_v_sum.read_value()) - v_mean_value = self.evaluate(ret_v_mean.read_value()) - expected = i + 3.0 - self.assertEqual(expected, v_sum_value) - expected_sum += expected - expected = i * 6.0 - self.assertEqual(expected, v_mean_value) - expected_mean += expected - - # fetch() should return the value you get by applying the - # reduction across all towers. - self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) + # Should see different values on different devices. + v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) + v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected expected_mean /= len(dist.worker_devices) + + # Without get(device), should return the value you get by + # applying the reduction across all towers (whether you use + # fetch(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) + if not context.executing_eagerly(): + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. @@ -431,6 +438,98 @@ def model_fn(): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + # variable_scope.variable() respects name scopes when creating + # variables. On the other hand variable_scope.get_variable() ignores name + # scopes when creating variables. We test both methods of creating variables + # to make sure that we have the same variable names in both cases. 
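+  # For example (illustrative only): under `with ops.name_scope("foo"):`,
+  # variable_scope.variable(1.0, name="v") is named "foo/v:0", whereas
+  # variable_scope.get_variable("v", []) is named "v:0".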
+ def testNameScopeWithVariable(self): + def in_cross_tower(_): + c = variable_scope.variable(1.0, name="c") + return c + + def model_fn(): + b = variable_scope.variable(1.0, name="b") + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.variable(1.0, name="a") + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("main/a:0", a0.name) + self.assertEquals("main/a/replica_1:0", a1.name) + self.assertEquals("main/b:0", b0.name) + self.assertEquals("main/b/replica_1:0", b1.name) + self.assertEquals("main/foo/c:0", c0.name) + self.assertEquals("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self): + def in_cross_tower(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("a:0", a0.name) + self.assertEquals("a/replica_1:0", a1.name) + self.assertEquals("b:0", b0.name) + self.assertEquals("b/replica_1:0", b1.name) + self.assertEquals("c:0", c0.name) + self.assertEquals("c/replica_1:0", c1.name) + + def testDynamicRnnVariables(self): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + # Two variables are created by the RNN layer. 
+ self.assertEquals(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = dist.unwrap(v) + self.assertStartsWith(v1.name, "tower_1/") + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index a1ef0ecc77a8e8..61cbe6df813bb2 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -27,7 +27,6 @@ from tensorflow.python.training import distribute as distribute_lib -@test_util.with_c_api class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def _get_distribution_strategy(self): @@ -53,7 +52,6 @@ def testCallAndMergeExceptions(self): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) -@test_util.with_c_api class VariableCreatorStackTest(test.TestCase): def testCreatorStacksAreThreadLocal(self): diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 8277e1e7919e86..4fdb9bf69b4f6a 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import monitor as monitor_lib from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import ops @@ -65,7 +66,7 @@ def testPassingASessionInEager(self): step_function, _ = single_loss_example( lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) - with self.test_session() as sess: + with session.Session() as sess, context.eager_mode(): with self.assertRaisesRegexp(ValueError, "Should not provide"): _ = monitor_lib.Monitor(step_function, sess) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py new file mode 100644 index 00000000000000..a552b370ebf359 --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Classes implementing a mirrored DistributionStrategy for multiple workers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import partial
+
+from tensorflow.contrib.distribute.python import values
+from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.training import device_util
+from tensorflow.python.training import server_lib
+from tensorflow.python.util import nest
+
+
+# TODO(yuefengz): support between-graph replication.
+# TODO(yuefengz): merge this class into its base class.
+# TODO(yuefengz): in some cases, we probably want to use configure method to
+# configure this class.
+# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the
+# class is introduced.
+class MultiWorkerMirroredStrategy(MirroredStrategy):
+  """Mirrored strategy that works on multiple workers with in-graph replication.
+
+  There are several important concepts for distributed TensorFlow, e.g.
+  `client`, `job`, `task`, `cluster`, `in-graph replication` and
+  `synchronous training`, and they have already been defined in
+  [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
+  The distribution strategy inherits these concepts as well, and in addition
+  we clarify several more concepts:
+    * **In-graph replication**: the `client` creates a single `tf.Graph` that
+    specifies tasks for devices on all workers. The `client` then creates a
+    client session which will talk to the `master` service of a `worker`. Then
+    the `master` will partition the graph and distribute the work to all
+    participating workers.
+    * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
+    physical machine. There will be multiple `worker`s, each with a different
+    `task` index. They all do similar things, except that one worker also
+    checkpoints model variables, writes summaries, etc. in addition to its
+    ordinary work.
+
+  This class maps one tower to one device on a worker. It mirrors all model
+  variables on all towers. For example, if you have two `worker`s and each
+  `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8
+  GPUs. Then like in MirroredStrategy, each tower performs its computation with
+  its own copy of the variables unless in cross-tower mode, where variable or
+  tensor reduction happens.
+  """
+
+  def __init__(self,
+               num_gpus_per_worker=1,
+               worker_job_name=None,
+               num_workers=None,
+               cluster=None,
+               cross_tower_ops=None,
+               prefetch_on_device=None):
+    """Initialize the strategy object.
+
+    Args:
+      num_gpus_per_worker: number of GPUs per worker. If it is zero, the local
+        CPU will be used.
+      worker_job_name: the job name for `worker`, typically just 'worker'.
+      num_workers: the number of workers. If it is 0, it degenerates to a
+        single-worker MirroredStrategy.
+      cluster: a `tf.train.ClusterSpec` object or a dict that can be used to
+        construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef`
+        protocol buffer. It is an alternative way to initialize this object.
+      cross_tower_ops: the cross-tower ops to use. If None, a default one will
+        be used. If the configure method is called, the best one for the
+        configuration will be chosen.
+      prefetch_on_device: a boolean to specify whether to prefetch input to
+        each worker's devices.
+
+    Raises:
+      ValueError: if an unexpected `cluster` is given.
+    """
+    if cluster is None:
+      self._workers = [
+          '/job:%s/task:%d' % (worker_job_name, task_index)
+          for task_index in range(num_workers)
+      ]
+    else:
+      if isinstance(cluster, (dict, cluster_pb2.ClusterDef)):
+        cluster_spec = server_lib.ClusterSpec(cluster)
+      elif isinstance(cluster, server_lib.ClusterSpec):
+        cluster_spec = cluster
+      else:
+        raise ValueError(
+            '`cluster_spec` should be a dict or a `tf.train.ClusterSpec` or a '
+            '`tf.train.ClusterDef` object')
+
+      self._workers = []
+      for job in sorted(cluster_spec.jobs):
+        for task in range(cluster_spec.num_tasks(job)):
+          self._workers.append('/job:%s/task:%d' % (job, task))
+
+    self._num_gpus_per_worker = num_gpus_per_worker
+    if num_gpus_per_worker > 0:
+      self._worker_device_map = {
+          worker: [
+              device_util.canonicalize(worker + '/device:GPU:%d' % gpu)
+              for gpu in range(num_gpus_per_worker)
+          ] for worker in self._workers
+      }
+    else:
+      self._worker_device_map = {
+          worker: [device_util.canonicalize(worker, '/device:CPU:0')]
+          for worker in self._workers
+      }
+    self._devices = nest.flatten(self._worker_device_map.values())
+
+    super(MultiWorkerMirroredStrategy, self).__init__(
+        devices=self._devices, prefetch_on_device=prefetch_on_device)
+
+    # Setting `_default_device` will add a device scope in the
+    # distribution.scope. We set the default device to the first worker. When
+    # users specify a device under distribution.scope by
+    #   with tf.device("/cpu:0"):
+    #     ...
+    # their ops will end up on the CPU device of the first worker, e.g.
+    # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
+    self._default_device = self._workers[0]
+
+  def distribute_dataset(self, dataset_fn):
+    return values.MultiWorkerDataset(
+        partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
+        self._prefetch_on_device)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
new file mode 100644
index 00000000000000..09c859b32a3150
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
@@ -0,0 +1,62 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Tests for MultiWorkerMirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.training import server_lib + + +class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster=server_lib.ClusterSpec({ + 'worker': ['/job:worker/task:0', '/job:worker/task:1'] + }), + num_gpus_per_worker=context.num_gpus()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + +class DeviceScopeTest(test.TestCase): + """Test the device scope of MultiWorkerMirroredStrategy.""" + + def testDeviceScope(self): + with context.graph_mode(): + strategy = multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, + num_gpus_per_worker=context.num_gpus()) + with strategy.scope(): + a = constant_op.constant(1.) + with ops.device('/cpu:0'): + b = constant_op.constant(1.) + self.assertEqual(a.device, '/job:worker/task:0') + self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py new file mode 100644 index 00000000000000..f659be5f42594b --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -0,0 +1,90 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base testing class for strategies that require multiple nodes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import copy + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util + + +class MultiWorkerTestBase(test.TestCase): + """Base class for testing multi node strategy and dataset.""" + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + num_workers = 2 + # Leave some memory for cuda runtime. 
+    gpu_mem_frac = 0.7 / num_workers
+    default_config = config_pb2.ConfigProto()
+    default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+
+    # The local cluster takes some portion of the local GPUs and there is no way
+    # for the cluster to terminate unless using multiple processes. Therefore, we
+    # have to create only one cluster throughout a test process.
+    workers, _ = test_util.create_local_cluster(
+        num_workers, num_ps=0, worker_config=default_config)
+    cls._master_target = workers[0].target
+
+  @contextlib.contextmanager
+  def test_session(self, graph=None, config=None):
+    """Create a test session with master target set to the testing cluster.
+
+    This overrides the base class's method, removes arguments that are not
+    needed by the multi-node case, and creates a test session that connects to
+    the local testing cluster.
+
+    Args:
+      graph: Optional graph to use during the returned session.
+      config: An optional config_pb2.ConfigProto to use to configure the
+        session.
+
+    Yields:
+      A Session object that should be used as a context manager to surround
+      the graph building and execution code in a test case.
+    """
+    if self.id().endswith('.test_session'):
+      self.skipTest('Not a test.')
+
+    if config is None:
+      config = config_pb2.ConfigProto(allow_soft_placement=True)
+    else:
+      config = copy.deepcopy(config)
+    # Don't perform optimizations for tests so we don't inadvertently run
+    # gpu ops on cpu
+    config.graph_options.optimizer_options.opt_level = -1
+    config.graph_options.rewrite_options.constant_folding = (
+        rewriter_config_pb2.RewriterConfig.OFF)
+
+    if graph is None:
+      if self._cached_session is None:  # pylint: disable=access-member-before-definition
+        self._cached_session = session.Session(
+            graph=None, config=config, target=self._master_target)
+      sess = self._cached_session
+      with sess.graph.as_default(), sess.as_default():
+        yield sess
+    else:
+      with session.Session(
+          graph=graph, config=config, target=self._master_target) as sess:
+        yield sess
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 646d2a5c3b3b0b..09b6d4a515ab46 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -36,9 +36,11 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
   # doing something that won't work with other DistributionStrategy
   # implementations?
- def __init__(self, device): + def __init__(self, device, prefetch_on_device=None): super(OneDeviceStrategy, self).__init__() self._device = device + self._prefetch_on_device = prefetch_on_device + self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): # No need to distinguish tower-local variables when not mirroring, @@ -61,7 +63,9 @@ def _create_variable(self, next_creator, *args, **kwargs): return next_creator(*args, **kwargs) def distribute_dataset(self, dataset_fn): - return self._call_dataset_fn(dataset_fn) + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), [self._device], + self._prefetch_on_device) def _broadcast(self, tensor, destinations): return tensor diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 7101ed0756f44b..7aad8a953cbedd 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -24,7 +24,6 @@ from tensorflow.python.framework import test_util -@test_util.with_c_api class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def _get_distribution_strategy(self): diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py index 713494d603b855..a0b452fc2d445d 100644 --- a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -44,7 +44,6 @@ def testWrongPatterns(self): self.assertEquals("foo_a", self._canonicalize("foo_a")) -@test_util.with_c_api class SharedVariableCreatorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index 0db0b59fcacee2..d1fdb3279cf2a7 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import step_fn from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.layers import normalization from tensorflow.python.ops import array_ops @@ -59,7 +60,7 @@ def dataset_fn(): # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be # fully defined for TPU. Remove this when XLA supports dynamic shapes. return dataset.apply( - batching.map_and_batch(lambda x: x, batch_size=2, drop_remainder=True)) + batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True)) # An Optimizer instance is created either outside or inside model_fn. outer_optimizer = None @@ -68,11 +69,10 @@ def dataset_fn(): layer = core.Dense(1, use_bias=use_bias) - def model_fn(xs): + def model_fn(x): """A very simple model written by the user.""" def loss_fn(): - x = math_ops.reduce_mean(xs, keepdims=True) y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) 
return y * y @@ -89,7 +89,8 @@ def loss_fn(): def batchnorm_example(optimizer_fn, batch_per_epoch=1, momentum=0.9, - renorm=False): + renorm=False, + update_ops_in_tower_mode=False): """Example of non-distribution-aware legacy code with batch normalization.""" def dataset_fn(): @@ -103,12 +104,19 @@ def dataset_fn(): optimizer = optimizer_fn() batchnorm = normalization.BatchNormalization( renorm=renorm, momentum=momentum, fused=False) + layer = core.Dense(1, use_bias=False) def model_fn(x): + """A model that uses batchnorm.""" def loss_fn(): - y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1) - loss = math_ops.reduce_mean(y - constant_op.constant(1.)) + y = batchnorm(x, training=True) + with ops.control_dependencies( + ops.get_collection(ops.GraphKeys.UPDATE_OPS) + if update_ops_in_tower_mode else []): + loss = math_ops.reduce_mean( + math_ops.reduce_sum(layer(y)) - constant_op.constant(1.)) + # `x` and `y` will be fetched by the gradient computation, but not `loss`. return loss # Callable loss. diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index a7e4fe80f3e659..75441786a615fc 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -33,7 +33,6 @@ from tensorflow.python.util import nest -# TODO(isaprykin): Consider whether inheriting is really appropriate. class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" @@ -73,7 +72,6 @@ def _call_for_each_tower(self, fn, *args, **kwargs): def infeed_input(i): """Get input, split it and then enqueue.""" iteration_inputs = [f.get(i) for f in feeds()] - infeed_inputs = [[inputs_per_core[core_id] for inputs_per_core in iteration_inputs] for core_id in range(self._num_cores_per_host)] @@ -117,3 +115,14 @@ def iterate_on_tpu(): iterate_on_tpu, [], num_shards=self._num_cores_per_host) return control_flow_ops.group(tpu_result, enqueue_ops) + + def _reduce(self, method_string, value, destinations): + del destinations # TPU is graph mode only. Rely on implicit Send/Recv. + if method_string == 'mean': + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. 
/ self._num_cores_per_host) + return tpu_ops.cross_replica_sum(value) + + @property + def num_towers(self): + return self._num_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 8cb5276579f48f..9572ade8e497fa 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -27,15 +27,18 @@ import six from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context +from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import checkpointable +from tensorflow.python.ops import math_ops from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -58,13 +61,14 @@ def get(self, device=None): else: device = distribute_lib.get_update_device() if device is None: - device = device_util.current() + return self._get_cross_tower() device = device_util.canonicalize(device) try: return self._index[device] - except KeyError: - raise ValueError("Device %s not found in %s (current device %s)" % - (device, self._index.keys(), device_util.current())) + except KeyError as e: + six.raise_from( + ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())), e) def on_device(self, device): device = device_util.canonicalize(device) @@ -312,6 +316,18 @@ def assign_add(self, *args, **kwargs): def assign(self, *args, **kwargs): return self.get(device=_get_update_device()).assign(*args, **kwargs) + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return array_ops.identity(self._index[device]) + return array_ops.identity(self._primary_var) + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. 
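The TPU `_reduce` above and `TowerLocalVariable._get_cross_tower` in the next hunk follow the same pattern: a `mean` across towers is just a scaled sum, so a single sum primitive (`tpu_ops.cross_replica_sum` on TPU, `math_ops.add_n` for tower-local variables) is all that is needed. A minimal standalone sketch of that pattern, for illustration only and not part of this patch:

```python
# Illustrative sketch of the scaled-sum reduction used by TPUStrategy._reduce
# and TowerLocalVariable._get_cross_tower; plain Python sum() stands in for
# math_ops.add_n / tpu_ops.cross_replica_sum.
def reduce_per_tower(method_string, per_tower_values):
  total = sum(per_tower_values)
  if method_string == "mean":
    # Scaling the sum by 1/num_towers yields the mean without a separate op.
    return total * (1. / len(per_tower_values))
  return total

assert reduce_per_tower("sum", [2., 4.]) == 6.
assert reduce_per_tower("mean", [2., 4.]) == 3.
```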
@@ -356,6 +372,12 @@ def restore(self, restored_tensors, restored_shapes): for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access +def _assert_tower_context(): + if not distribute_lib.get_tower_context(): + raise RuntimeError( + "Tower-local variables may only be assigned in a tower context.") + + class TowerLocalVariable(DistributedVariable, PerDevice, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" @@ -366,18 +388,35 @@ def __init__(self, index, primary_var, reduce_method): super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_sub(*args, **kwargs) def assign_add(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): + _assert_tower_context() return self.get().assign(*args, **kwargs) @property def reduce_method(self): return self._reduce_method + def _get_cross_tower(self): + all_components = tuple(self._index.values()) + # TODO(josh11b): Use a strategy-specific method. + total = math_ops.add_n(all_components) + if self._reduce_method == "mean": + return total * (1./ len(all_components)) + return total + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._get_cross_tower() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. @@ -570,11 +609,106 @@ def make_initializable_iterator(self): dataset_iterator, self._devices, self._prefetch_on_device) -class PerIteration(object): - """Holds input for multiple iterations at once.""" +class MultiWorkerDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`.""" - def __init__(self, index): - self._index = index + def __init__(self, iterators, worker_device_map): + """Initialize the MultiWorkerDataIterator object. + + Args: + iterators: a dict mapping from each worker to an iterator for + that worker. + worker_device_map: a dict mapping from each worker's devices to a list of + devices that belong to this worker. + + Raises: + ValueError: if iterators and worker_device_map are not compatible. + """ + self._iterators = iterators + self._worker_device_map = worker_device_map + if set(self._iterators) != set(self._worker_device_map): + raise ValueError("iterators and worker_device_map are not compatible.") + + @property + def initializer(self): + return control_flow_ops.group( + [iterator.initializer for iterator in self._iterators.values()]) + + def get_next(self, name=None): + """Scatter the input across hosts and devices.""" + index = {} + for worker, iterator in six.iteritems(self._iterators): + if name is not None: + d = tf_device.DeviceSpec.from_string(worker) + new_name = "%s_%s_%d" % (name, d.job, d.task) + else: + new_name = None + with ops.device(worker): + data_per_worker = iterator.get_next(name=new_name) + + worker_devices = self._worker_device_map[worker] + # Ungroup these per-device value so as to get a flat map from devices to + # values. + for d in worker_devices: + v = select_device(d, data_per_worker) + if d in index: + raise ValueError("Duplicated devices in worker_device_map: %r" % v) + index[d] = v + + return regroup(index) + + +class MultiWorkerDataset(object): + """Like a `tf.data.Dataset` that distributes data to different workers. 
+ + Each worker gets one shard of the input dataset. It is currently not working + in + eager mode. + """ + + def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None): + """Initialize the MultiWorkerDataset object. + + Args: + dataset_fn: a function that returns a `tf.data.Dataset`. + worker_device_map: a dict mapping from each worker to a list of devices + that belong to this worker. + prefetch_on_device: whether to prefetch to devices. + """ + self._worker_device_map = worker_device_map + self._datasets = {} + # TODO(yuefengz, priyag): support different set of jobs for input + # processing. + for i, (worker, worker_devices) in enumerate( + six.iteritems(worker_device_map)): + with ops.device(worker): + worker_input = dataset_fn() + worker_input = input_ops.auto_shard_dataset( + worker_input, len(worker_device_map), i) + self._datasets[worker] = PerDeviceDataset( + worker_input, worker_devices, prefetch_on_device=prefetch_on_device) + + def make_one_shot_iterator(self): + iterators = {} + for worker, dataset in six.iteritems(self._datasets): + with ops.device(worker): + iterators[worker] = dataset.make_one_shot_iterator() + return MultiWorkerDataIterator(iterators, self._worker_device_map) + + def make_initializable_iterator(self): + iterators = {} + for worker, dataset in six.iteritems(self._datasets): + with ops.device(worker): + iterators[worker] = dataset.make_initializable_iterator() + return MultiWorkerDataIterator(iterators, self._worker_device_map) + + +class _PerKey(object): + """Holds data associated by keys.""" + + def __init__(self, *index): + # pylint: disable=protected-access + self._index = list(index) def get(self, iteration): return array_ops.gather(self._index, iteration) @@ -585,6 +719,24 @@ def get_shape(self): def get_dtype(self): return self._index[-1][-1].dtype + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + +class PerIteration(_PerKey): + """Holds input for multiple iterations at once.""" + + def __init__(self, *index): + # pylint: disable=protected-access + super(PerIteration, self).__init__(*[batch._index for batch in index]) + + +class Batches(_PerKey): + pass + class MultiIterator(object): """Iterator that returns results of multiple get_next()s.""" @@ -595,11 +747,31 @@ def __init__(self, dataset_iterator, iterations, batches_per_iteration): self._batches_per_iteration = batches_per_iteration def get_next(self, name=None): - return PerIteration([[ - self._dataset_iterator.get_next(name=name) - for _ in range(self._batches_per_iteration) - ] - for _ in range(self._iterations)]) + """Return PerIteration with `iterations x batches_per_iteration` inputs.""" + data = [] + for _ in range(self._batches_per_iteration): + batch = [] + for _ in range(self._iterations): + batch.append(self._dataset_iterator.get_next(name=name)) + data.append(batch) + + # Here is an example. Suppose each get_next returns a tuple of two tensors. 
+ # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: + # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] + # + # After the first `map_structure` it gets transformed to: + # [(Batches(a, A), Batches(z, Z)), + # (Batches(b, B), Batches(y, Y)), + # (Batches(c, C), Batches(x, X))] + # + # After the second `map_structure` it gets transformed to a tuple of: + # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), + # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) + + data = nest.map_structure(Batches, *data) + data = nest.map_structure(PerIteration, *data) + + return data @property def initializer(self): diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index e96ce547415fcb..1c95758d96aba4 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -18,9 +18,11 @@ from __future__ import division from __future__ import print_function +import collections import os from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops @@ -34,11 +36,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import nest -@test_util.with_c_api class DistributedValuesTest(test.TestCase): def testGetEager(self): @@ -77,7 +80,6 @@ def testCanonicalization(self): v = values.DistributedValues({"/device:cpu:0": 42}) -@test_util.with_c_api class DistributedDelegateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() @@ -160,7 +162,6 @@ def _make_mirrored(): return v, devices, mirrored -@test_util.with_c_api class RegroupAndSelectDeviceTest(test.TestCase): def _is_per_device(self, result, expected, klass=values.PerDevice): @@ -313,7 +314,6 @@ def testNamedTupleEstimatorSpec(self): merged_estimator_spec)) -@test_util.with_c_api class PerDeviceDatasetTest(test.TestCase): config = config_pb2.ConfigProto() @@ -436,7 +436,130 @@ def testInitializableIterator(self): self.evaluate(next_element) -@test_util.with_c_api +class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): + + def _test_iterator(self, iterator, devices, expected_values): + next_element = iterator.get_next() + for device in devices: + v = values.select_device(device, next_element) + # The `v` here can be a tuple. 
+ for element in nest.flatten(v): + self.assertTrue(element.device in device) + + for expected_value in expected_values: + actual = self.evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + self.evaluate([values.select_device(d, next_element) for d in devices]) + + def _test_dataset(self, dataset_fn, worker_device_map, devices, + expected_values): + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator() + self._test_iterator(multi_worker_iterator, devices, expected_values) + + def _cpu_devices(self): + worker_device_map = collections.OrderedDict( + [("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])]) + devices = [ + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_device_map, devices + + def _cpu_and_one_gpu_devices(self): + # The worker_device_map doesn't have to be a OrderDict object, this is just + # to simplify the testing so that we can pass expected values as a list + # instead of a dict. + worker_device_map = collections.OrderedDict( + [("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ])]) + devices = [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_device_map, devices + + def testDataDistributionOneDevicePerWorker(self): + worker_device_map, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset(dataset_fn, worker_device_map, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + def testDataDistributionTwoDevicePerWorker(self): + if context.num_gpus() < 1: + self.skipTest("A GPU is not available for this test.") + worker_device_map, devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset(dataset_fn, worker_device_map, devices, + [[0, 2, 1, 3], [4, 6, 5, 7]]) + + def testTupleDataset(self): + worker_device_map, devices = self._cpu_devices() + + with context.graph_mode(): + + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(8) + dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [ + [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) + ] + self._test_dataset(dataset_fn, worker_device_map, devices, + expected_values) + + def testInitializableIterator(self): + worker_device_map, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + + self.evaluate(multi_worker_iterator.initializer) + self._test_iterator(multi_worker_iterator, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + # After re-initializing the iterator, 
should be able to iterate again. + self.evaluate(multi_worker_iterator.initializer) + self._test_iterator(multi_worker_iterator, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + def testValueErrorForIterator(self): + # Incompatiable arguments. + with self.assertRaises(ValueError): + values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) + + # Test duplicated devices under same worker. + worker_device_map, _ = self._cpu_devices() + worker_device_map["/job:worker/replica:0/task:0"].append( + "/job:worker/replica:0/task:0/device:CPU:0") + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.assertRaises(ValueError): + multi_worker_iterator.get_next() + + class MirroredVariableTest(test.TestCase): config = config_pb2.ConfigProto() @@ -582,6 +705,21 @@ def testSaveNormalRestoreMirrored(self): save_path = self._save_normal() self._restore_mirrored(save_path) + @test_util.run_in_graph_and_eager_modes(config=config) + def testFetchAMirroredVariable(self): + if context.num_gpus() < 1 or context.executing_eagerly(): + self.skipTest("A GPU is not available for this test or it's eager mode.") + + with self.test_session( + graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( + ["/device:GPU:0"]).scope(): + with ops.device("/device:GPU:0"): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + mirrored = values.MirroredVariable({"/device:GPU:0": v}, v) + sess.run(variables_lib.global_variables_initializer()) + sess.run({"complicated": mirrored}) + _devices = ["/device:GPU:0", "/device:CPU:0"] @@ -598,7 +736,6 @@ def _make_tower_local(method): return v, tower_local -@test_util.with_c_api class TowerLocalVariableTest(test.TestCase): config = config_pb2.ConfigProto() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 9a3b02cd813fa8..cbce7a88fd1cff 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -94,7 +94,7 @@ gpu_py_test( gpu_py_test( name = "distribution_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/distribution_test.py"], additional_deps = [ ":distributions_py", @@ -337,7 +337,7 @@ gpu_py_test( gpu_py_test( name = "mvn_tril_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/mvn_tril_test.py"], additional_deps = [ ":distributions_py", @@ -372,6 +372,7 @@ gpu_py_test( "//tensorflow/python:random_ops", "//tensorflow/python:variables", ], + shard_count = 4, ) gpu_py_test( @@ -459,7 +460,7 @@ gpu_py_test( gpu_py_test( name = "batch_reshape_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/batch_reshape_test.py"], additional_deps = [ ":distributions_py", @@ -578,7 +579,7 @@ gpu_py_test( gpu_py_test( name = "wishart_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/wishart_test.py"], additional_deps = [ ":distributions_py", @@ -709,6 +710,8 @@ gpu_py_test( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:client_testlib", ], + shard_count = 4, + tags = ["noasan"], # times out, http://b/78588814 ) gpu_py_test( @@ -865,7 +868,7 @@ gpu_py_test( gpu_py_test( name = "batch_normalization_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"], additional_deps = [ ":bijectors_py", 
@@ -1029,6 +1032,25 @@ gpu_py_test( ], ) +gpu_py_test( + name = "matrix_inverse_tril_test", + size = "medium", + srcs = ["python/kernel_tests/bijectors/matrix_inverse_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + gpu_py_test( name = "real_nvp_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index ddf59891e626a8..802538ba97578c 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -32,6 +32,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular +from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse @@ -156,6 +157,7 @@ 'kl_divergence', 'RegisterKL', 'fill_triangular', + 'fill_triangular_inverse', 'matrix_diag_transform', 'reduce_weighted_logsumexp', 'softplus_inverse', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py index 59d549b7b80a3d..f2bb2d3325a7cc 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -448,8 +448,7 @@ def test_bad_reshape_size(self): else: with self.test_session(): - with self.assertRaisesOpError(r"`batch_shape` size must match " - r"`distributions.batch_shape` size"): + with self.assertRaisesOpError(r"Shape sizes do not match."): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -457,8 +456,13 @@ def test_bad_reshape_size(self): def test_non_positive_shape(self): dims = 2 - new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check. - old_batch_shape = [2] + old_batch_shape = [4] + if self.is_static_shape: + # Unknown first dimension does not trigger size check. Note that + # any dimension < 0 is treated statically as unknown. + new_batch_shape = [-1, 0] + else: + new_batch_shape = [-2, -2] # -2 * -2 = 4, same size as the old shape. 
new_batch_shape_ph = ( constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape @@ -471,7 +475,7 @@ def test_non_positive_shape(self): mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) if self.is_static_shape: - with self.assertRaisesRegexp(ValueError, r".*must be positive.*"): + with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -479,7 +483,7 @@ def test_non_positive_shape(self): else: with self.test_session(): - with self.assertRaisesOpError(r".*must be positive.*"): + with self.assertRaisesOpError(r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index ca20442c394066..dc45114b1c23b5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -188,6 +189,15 @@ def testChainAffineExp(self): -np.log(6, dtype=np.float32) - np.sum(x), self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + def testChainIldjWithPlaceholder(self): + chain = Chain((Exp(), Exp())) + samples = array_ops.placeholder( + dtype=np.float32, shape=[None, 10], name="samples") + ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0) + self.assertTrue(ildj is not None) + with self.test_session(): + ildj.eval({samples: np.zeros([2, 10], np.float32)}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py index 8b279ebcd908b6..f8a52615b0f3f5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py @@ -59,7 +59,7 @@ def testConditionalBijector(self): for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]: method = getattr(b, name) with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): - method(1., event_ndims=0., arg1="b1", arg2="b2") + method(1., event_ndims=0, arg1="b1", arg2="b2") if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py new file mode 100644 index 00000000000000..18397035571561 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class MatrixInverseTriLBijectorTest(test.TestCase): + """Tests the correctness of the Y = inv(tril) transformation.""" + + @test_util.run_in_graph_and_eager_modes() + def testComputesCorrectValues(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + self.assertEqual("matrix_inverse_tril", inv.name) + x_ = np.array([[0.7, 0., 0.], + [0.1, -1., 0.], + [0.3, 0.25, 0.5]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -6. * np.sum(np.log(np.abs(np.diag(x_)))) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testOneByOneMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[5.]], dtype=np.float32) + x_inv_ = np.array([[0.2]], dtype=np.float32) + expected_fldj_ = np.log(0.04) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testZeroByZeroMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.eye(0, dtype=np.float32) + x_inv_ = np.eye(0, dtype=np.float32) + expected_fldj_ = 0. + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testBatch(self): + # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape + # (2, 1). 
+ inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[[[1., 0.], + [2., 3.]]], + [[[4., 0.], + [5., -6.]]]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -4. * np.sum( + np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) + self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputRankTooLow(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([0.1], dtype=np.float32) + rank_error_msg = "must have rank at least 2" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + # TODO(b/80481923): Figure out why these assertions fail, and fix them. + ## def testErrorOnInputNonSquare(self): + ## inv = bijectors.MatrixInverseTriL(validate_args=True) + ## x_ = np.array([[1., 2., 3.], + ## [4., 5., 6.]], dtype=np.float32) + ## square_error_msg = "must be a square matrix" + ## with self.test_session(): + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputNotLowerTriangular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 2.], + [3., 4.]], dtype=np.float32) + triangular_error_msg = "must be lower triangular" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputSingular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 0.], + [0., 0.]], dtype=np.float32) + nonsingular_error_msg = "must have all diagonal entries nonzero" + with self.test_session(): + with 
self.assertRaisesOpError(nonsingular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 46f2c63f9b0f78..d44e49b4874a5b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -22,15 +22,12 @@ from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test -@test_util.with_c_api class _ReshapeBijectorTest(object): """Base class for testing the reshape transformation. @@ -265,7 +262,6 @@ def build_shapes(self, *args, **kwargs): raise NotImplementedError("Subclass failed to implement `build_shapes`.") -@test_util.with_c_api class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -305,21 +301,13 @@ def testBijectiveAndFinite(self): bijector, x, y, event_ndims=2, rtol=1e-6, atol=0) def testInvalidDimensionsOpError(self): - if ops._USE_C_API: - error_message = "Invalid value in tensor used for shape: -2" - else: - error_message = "elements must be either positive integers or `-1`." 
- self._testInvalidDimensionsOpError(error_message) + self._testInvalidDimensionsOpError( + "Invalid value in tensor used for shape: -2") def testInputOutputMismatchOpError(self): - if ops._USE_C_API: - error_message = "Cannot reshape a tensor with" - else: - error_message = "Input to reshape is a tensor with" - self._testInputOutputMismatchOpError(error_message) + self._testInputOutputMismatchOpError("Cannot reshape a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -341,7 +329,6 @@ def testInputOutputMismatchOpError(self): self._testInputOutputMismatchOpError("Input to reshape is a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py index 7435bcbc684c16..b003526392709b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py @@ -131,8 +131,8 @@ def _random_mu_and_sigma(self, batch_shape, event_shape): return mu, sigma def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -156,6 +156,33 @@ def testKLBatch(self): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. 
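+      # mu_b and sigma_b describe a single distribution; kl_divergence should
+      # broadcast it against both batch members of mvn_a.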
+ mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalFullCovariance( + loc=mu_a, + covariance_matrix=sigma_a, + validate_args=True) + mvn_b = ds.MultivariateNormalFullCovariance( + loc=mu_b, + covariance_matrix=sigma_b, + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b): """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b).""" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py index 685f32883dae5b..b556d06123800f 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py @@ -235,8 +235,8 @@ def _random_mu_and_sigma(self, batch_shape, event_shape): return mu, sigma def testKLNonBatch(self): - batch_shape = () - event_shape = (2,) + batch_shape = [] + event_shape = [2] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -257,8 +257,8 @@ def testKLNonBatch(self): self.assertAllClose(expected_kl, kl_v) def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -282,9 +282,36 @@ def testKLBatch(self): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. 
+ mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalTriL( + loc=mu_a, + scale_tril=np.linalg.cholesky(sigma_a), + validate_args=True) + mvn_b = ds.MultivariateNormalTriL( + loc=mu_b, + scale_tril=np.linalg.cholesky(sigma_b), + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLTwoIdenticalDistributionsIsZero(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mvn_a = ds.MultivariateNormalTriL( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py index 96805733178705..b91a610acf1a90 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -65,6 +65,16 @@ def testNestingRobustness(self): self.assertAllUnique( outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + def testInitFromOtherSeedStream(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1, salt="salt") + strm3 = seed_stream.SeedStream(strm1, salt="another salt") + out1 = [strm1() for _ in range(50)] + out2 = [strm2() for _ in range(50)] + out3 = [strm3() for _ in range(50)] + self.assertAllEqual(out1, out2) + self.assertAllUnique(out1 + out3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py index ce6cf702d52279..9c4dfed83631e9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py @@ -98,23 +98,21 @@ def test_true_mean_confidence_interval_by_dkwm_one_sample(self): num_samples = 5000 # 5000 samples is chosen to be enough to find discrepancies of # size 0.1 or more with assurance 1e-6, as confirmed here: - with self.test_session() as sess: - d = st.min_discrepancy_of_true_means_detectable_by_dkwm( - num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) - d = sess.run(d) - self.assertLess(d, 0.1) + d = st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.1) # Test that the confidence interval computed for the mean includes # 0.5 and excludes 0.4 and 0.6. 
- with self.test_session() as sess: - samples = rng.uniform(size=num_samples).astype(np.float32) - (low, high) = st.true_mean_confidence_interval_by_dkwm( - samples, 0., 1., error_rate=1e-6) - low, high = sess.run([low, high]) - self.assertGreater(low, 0.4) - self.assertLess(low, 0.5) - self.assertGreater(high, 0.5) - self.assertLess(high, 0.6) + samples = rng.uniform(size=num_samples).astype(np.float32) + (low, high) = st.true_mean_confidence_interval_by_dkwm( + samples, 0., 1., error_rate=1e-6) + low, high = self.evaluate([low, high]) + self.assertGreater(low, 0.4) + self.assertLess(low, 0.5) + self.assertGreater(high, 0.5) + self.assertLess(high, 0.6) def test_dkwm_mean_one_sample_assertion(self): rng = np.random.RandomState(seed=0) @@ -123,21 +121,45 @@ def test_dkwm_mean_one_sample_assertion(self): # Test that the test assertion agrees that the mean of the standard # uniform distribution is 0.5. samples = rng.uniform(size=num_samples).astype(np.float32) - with self.test_session() as sess: - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.5, false_fail_rate=1e-6)) - - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is not 0.4. - with self.assertRaisesOpError("Mean confidence interval too high"): - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.4, false_fail_rate=1e-6)) - - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is not 0.6. - with self.assertRaisesOpError("Mean confidence interval too low"): - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.6, false_fail_rate=1e-6)) + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.5, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.6. + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.6, false_fail_rate=1e-6)) + + def test_dkwm_mean_in_interval_one_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 5000 + + # Test that the test assertion agrees that the mean of the standard + # uniform distribution is between 0.4 and 0.6. + samples = rng.uniform(size=num_samples).astype(np.float32) + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.2 and 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.6 and 0.8. 
+ with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion(self): rng = np.random.RandomState(seed=0) @@ -145,20 +167,18 @@ def test_dkwm_mean_two_sample_assertion(self): # 4000 samples is chosen to be enough to find discrepancies of # size 0.2 or more with assurance 1e-6, as confirmed here: - with self.test_session() as sess: - d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( - num_samples, 0., 1., num_samples, 0., 1., - false_fail_rate=1e-6, false_pass_rate=1e-6) - d = sess.run(d) - self.assertLess(d, 0.2) + d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + num_samples, 0., 1., num_samples, 0., 1., + false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.2) # Test that the test assertion agrees that the standard # uniform distribution has the same mean as itself. samples1 = rng.uniform(size=num_samples).astype(np.float32) samples2 = rng.uniform(size=num_samples).astype(np.float32) - with self.test_session() as sess: - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self): rng = np.random.RandomState(seed=0) @@ -168,15 +188,14 @@ def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self): # As established above, 4000 samples is enough to find discrepancies # of size 0.2 or more with assurance 1e-6. - with self.test_session() as sess: - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is different from the mean of beta(2, 1). - beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) - with self.assertRaisesOpError("samples1 has a smaller mean"): - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., - beta_high_samples, 0., 1., - false_fail_rate=1e-6)) + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(2, 1). + beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_high_samples, 0., 1., + false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self): rng = np.random.RandomState(seed=0) @@ -186,15 +205,14 @@ def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self): # As established above, 4000 samples is enough to find discrepancies # of size 0.2 or more with assurance 1e-6. - with self.test_session() as sess: - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is different from the mean of beta(1, 2). - beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) - with self.assertRaisesOpError("samples2 has a smaller mean"): - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., - beta_low_samples, 0., 1., - false_fail_rate=1e-6)) + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(1, 2). 
+ beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_low_samples, 0., 1., + false_fail_rate=1e-6)) def test_dkwm_argument_validity_checking(self): rng = np.random.RandomState(seed=0) @@ -203,18 +221,17 @@ def test_dkwm_argument_validity_checking(self): # Test that the test library complains if the given samples fall # outside the purported bounds. - with self.test_session() as sess: - with self.assertRaisesOpError("maximum value exceeds expectations"): - sess.run(st.true_mean_confidence_interval_by_dkwm( - samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) - with self.assertRaisesOpError("minimum value falls below expectations"): - sess.run(st.true_mean_confidence_interval_by_dkwm( - samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) - - # But doesn't complain if they don't. - op = st.true_mean_confidence_interval_by_dkwm( - samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) - _ = sess.run(op) + with self.assertRaisesOpError("maximum value exceeds expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) + with self.assertRaisesOpError("minimum value falls below expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) + + # But doesn't complain if they don't. + op = st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) + _ = self.evaluate(op) def test_do_maximum_mean(self): n = 117 @@ -223,10 +240,9 @@ def test_do_maximum_mean(self): samples = rng.uniform(size=n).astype(np.float32) # Compute the answer in TF using the code under test - with self.test_session() as sess: - envelope_t = ops.convert_to_tensor(envelope) - max_mean = st._do_maximum_mean(samples, envelope_t, 1) - max_mean = sess.run(max_mean) + envelope_t = ops.convert_to_tensor(envelope) + max_mean = st._do_maximum_mean(samples, envelope_t, 1) + max_mean = self.evaluate(max_mean) # Compute the correct answer for this case in numpy. In this # example, `n` and `envelope` are such that `samples[2]` is the diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD new file mode 100644 index 00000000000000..03e26b198ea02a --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -0,0 +1,48 @@ +# Description: +# Internal testing utilities, e.g., computing the correct answer to +# put in a unit test. 
+ +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "correlation_matrix_volumes_py", + srcs = [ + "correlation_matrix_volumes_lib.py", + ], + deps = [ + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "correlation_matrix_volumes", + srcs = [ + "correlation_matrix_volumes.py", + ], + deps = [ + ":correlation_matrix_volumes_py", + ], +) + +py_test( + name = "correlation_matrix_volumes_test", + size = "medium", + srcs = ["correlation_matrix_volumes_test.py"], + tags = ["no_pip"], + deps = [ + ":correlation_matrix_volumes_py", + # For statistical testing + "//tensorflow/contrib/distributions:distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + ], +) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py new file mode 100644 index 00000000000000..2eab51cd3053ea --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Executable to estimate the volume of various sets of correlation matrices. + +See correlation_matrix_volumes_lib.py for purpose and methodology. + +Invocation example: +``` +python correlation_matrix_volumes.py --num_samples 1e7 +``` + +This will compute 10,000,000-sample confidence intervals for the +volumes of several sets of correlation matrices. Which sets, and the +desired statistical significance, are hard-coded in this source file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pprint + +from absl import app +from absl import flags + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr + +FLAGS = flags.FLAGS + +# Float to support giving the number of samples in scientific notation. +# The production run used for the LKJ test used 1e7 samples. +flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.') + + +def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + # This wrapper undoes the batching in compute_true_volumes, because + # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM. 
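+  # For example, with det_bounds = [0.01, 0.25] this performs two separate
+  # [num_samples, 1, dim, dim] rejection passes rather than one
+  # [num_samples, 2, dim, dim] pass (an illustration of the intent, not an
+  # exact memory accounting).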
+ bounds = {} + for db in det_bounds: + bounds[db] = corr.compute_true_volumes( + [db], dim, num_samples, error_rate=error_rate, seed=seed)[db] + return bounds + + +# The particular bounds in all three of these functions were chosen by +# a somewhat arbitrary walk through an empirical tradeoff, for the +# purpose of testing the LKJ distribution. Setting the determinant +# bound lower +# - Covers more of the testee's sample space, and +# - Increases the probability that the rejection sampler will hit, thus +# - Decreases the relative error (at a fixed sample count) in the +# rejection-based volume estimate; +# but also +# - Increases the variance of the estimator used in the LKJ test. +# This latter variance is also affected by the dimension and the +# tested concentration parameter, and can be compensated for with more +# compute (expensive) or a looser discrepancy limit (unsatisfying). +# The values here are the projection of the points in that test design +# space that ended up getting chosen. +def compute_3x3_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 3, num_samples, error_rate=5e-7, seed=46) + + +def compute_4x4_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 4, num_samples, error_rate=5e-7, seed=47) + + +def compute_5x5_volumes(num_samples): + det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4] + return ctv_debatched( + det_bounds, 5, num_samples, error_rate=5e-7, seed=48) + + +def main(_): + full_bounds = {} + full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples)) + full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples)) + full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples)) + pprint.pprint(full_bounds) + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py new file mode 100644 index 00000000000000..455e71f00c96e7 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py @@ -0,0 +1,323 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Estimating the volume of the correlation matrices with bounded determinant. + +Why? Because lkj_test.py tests the sampler for the LKJ distribution +by estimating the same volume another way. + +How? Rejection sampling. Or, more precisely, importance sampling, +proposing from the uniform distribution on symmetric matrices with +diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation +matrix if and only if it is also positive semi-definite. 
+ +The samples can then be converted into a confidence interval on the +volume in question by the [Clopper-Pearson +method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval), +also implemented here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import sys + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import uniform +from tensorflow.python.ops.distributions import util +from tensorflow.python.platform import tf_logging + +__all__ = [ + "correlation_matrix_volume_rejection_samples", + "compute_true_volumes", +] + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +optimize = try_import("scipy.optimize") +stats = try_import("scipy.stats") + + +def _psd_mask(x): + """Computes whether each square matrix in the input is positive semi-definite. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + + Returns: + mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each + scalar is 1 if the corresponding matrix was PSD, otherwise 0. + """ + # Allegedly + # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite + # it is more efficient to test for positive semi-definiteness by + # trying to compute the Cholesky decomposition -- the matrix is PSD + # if you succeed and not PSD if you fail. However, TensorFlow's + # Cholesky raises an exception if _any_ of the input matrices are + # not PSD, from which I don't know how to extract _which ones_, so I + # proceed by explicitly computing all the eigenvalues and checking + # whether they are all positive or not. + # + # Also, as was discussed in the answer, it is somewhat dangerous to + # treat SPD-ness as binary in floating-point arithmetic. Cholesky + # factorization can complete and 'look' like everything is fine + # (e.g., O(1) entries and a diagonal of all ones) but the matrix can + # have an exponential condition number. + eigenvalues, _ = linalg_ops.self_adjoint_eig(x) + return math_ops.cast( + math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype) + + +def _det_large_enough_mask(x, det_bounds): + """Returns whether the input matches the given determinant limit. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + det_bounds: A floating-point `Tensor` that must broadcast to shape + `[B1, ..., Bn]`, giving the desired lower bound on the + determinants in `x`. + + Returns: + mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each + scalar is 1 if the corresponding matrix had determinant above + the corresponding bound, otherwise 0. + """ + # For the curious: I wonder whether it is possible and desirable to + # use a Cholesky decomposition-based algorithm for this, since the + # only matrices whose determinant this code cares about will be PSD. + # Didn't figure out how to code that in TensorFlow. + # + # Expert opinion is that it would be about twice as fast since + # Cholesky is roughly half the cost of Gaussian Elimination with + # Partial Pivoting. But this is less of an impact than the switch in + # _psd_mask. 
+ return math_ops.cast( + linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype) + + +def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed): + """Returns a uniformly random `Tensor` of "correlation-like" matrices. + + A "correlation-like" matrix is a symmetric square matrix with all entries + between -1 and 1 (inclusive) and 1s on the main diagonal. Of these, + the ones that are positive semi-definite are exactly the correlation + matrices. + + Args: + num_rows: Python `int` dimension of the correlation-like matrices. + batch_shape: `Tensor` or Python `tuple` of `int` shape of the + batch to return. + dtype: `dtype` of the `Tensor` to return. + seed: Random seed. + + Returns: + matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]` + and dtype `dtype`. Each entry is in [-1, 1], and each matrix + along the bottom two dimensions is symmetric and has 1s on the + main diagonal. + """ + num_entries = num_rows * (num_rows + 1) / 2 + ones = array_ops.ones(shape=[num_entries], dtype=dtype) + # It seems wasteful to generate random values for the diagonal since + # I am going to throw them away, but `fill_triangular` fills the + # diagonal, so I probably need them. + # It's not impossible that it would be more efficient to just fill + # the whole matrix with random values instead of messing with + # `fill_triangular`. Then would need to filter almost half out with + # `matrix_band_part`. + unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed) + tril = util.fill_triangular(unifs) + symmetric = tril + array_ops.matrix_transpose(tril) + diagonal_ones = array_ops.ones( + shape=util.pad(batch_shape, axis=0, back=True, value=num_rows), + dtype=dtype) + return array_ops.matrix_set_diag(symmetric, diagonal_ones) + + +def correlation_matrix_volume_rejection_samples( + det_bounds, dim, sample_shape, dtype, seed): + """Returns rejection samples from trying to get good correlation matrices. + + The proposal being rejected from is the uniform distribution on + "correlation-like" matrices. We say a matrix is "correlation-like" + if it is a symmetric square matrix with all entries between -1 and 1 + (inclusive) and 1s on the main diagonal. Of these, the ones that + are positive semi-definite are exactly the correlation matrices. + + The rejection algorithm, then, is to sample a `Tensor` of + `sample_shape` correlation-like matrices of dimensions `dim` by + `dim`, and check each one for (i) being a correlation matrix (i.e., + PSD), and (ii) having determinant at least the corresponding entry + of `det_bounds`. + + Args: + det_bounds: A `Tensor` of lower bounds on the determinants of + acceptable matrices. The shape must broadcast with `sample_shape`. + dim: A Python `int` dimension of correlation matrices to sample. + sample_shape: Python `tuple` of `int` shape of the samples to + compute, excluding the two matrix dimensions. + dtype: The `dtype` in which to do the computation. + seed: Random seed. + + Returns: + weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the + corresponding matrix was not a correlation matrix, or had too + small of a determinant. Otherwise, the entry is the + multiplicative inverse of the density of proposing that matrix + uniformly, i.e., the volume of the set of `dim` by `dim` + correlation-like matrices. + volume: The volume of the set of `dim` by `dim` correlation-like + matrices. 
+ """ + with ops.name_scope("rejection_sampler"): + rej_proposals = _uniform_correlation_like_matrix( + dim, sample_shape, dtype, seed=seed) + rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.) + # The density of proposing any given point is 1 / rej_proposal_volume; + # The weight of that point should be scaled by + # 1 / density = rej_proposal_volume. + rej_weights = rej_proposal_volume * _psd_mask( + rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds) + return rej_weights, rej_proposal_volume + + +def _clopper_pearson_confidence_interval(samples, error_rate): + """Computes a confidence interval for the mean of the given 1-D distribution. + + Assumes (and checks) that the given distribution is Bernoulli, i.e., + takes only two values. This licenses using the CDF of the binomial + distribution for the confidence, which is tighter (for extreme + probabilities) than the DKWM inequality. The method is known as the + [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Assumes: + + - The given samples were drawn iid from the distribution of interest. + + - The given distribution is a Bernoulli, i.e., supported only on + low and high. + + Guarantees: + + - The probability (over the randomness of drawing the given sample) + that the true mean is outside the returned interval is no more + than the given error_rate. + + Args: + samples: `np.ndarray` of samples drawn iid from the distribution + of interest. + error_rate: Python `float` admissible rate of mistakes. + + Returns: + low: Lower bound of confidence interval. + high: Upper bound of confidence interval. + + Raises: + ValueError: If `samples` has rank other than 1 (batch semantics + are not implemented), or if `samples` contains values other than + `low` or `high` (as that makes the distribution not Bernoulli). + """ + # TODO(b/78025336) Migrate this confidence interval function + # to statistical_testing.py. In order to do that + # - Get the binomial CDF from the Binomial distribution + # - Implement scalar root finding in TF. Batch bisection search + # shouldn't be too hard, and is definitely good enough for this + # problem. Batching the Brent algorithm (from scipy) that is used + # here may be more involved, but may also not be necessary---it's + # only used here because scipy made it convenient. In particular, + # robustness is more important than speed here, which may make + # bisection search actively better. + # - The rest is just a matter of rewriting in the appropriate style. + if optimize is None or stats is None: + raise ValueError( + "Scipy is required for computing Clopper-Pearson confidence intervals") + if len(samples.shape) != 1: + raise ValueError("Batch semantics not implemented") + n = len(samples) + low = np.amin(samples) + high = np.amax(samples) + successes = np.count_nonzero(samples - low) + failures = np.count_nonzero(samples - high) + if successes + failures != n: + uniques = np.unique(samples) + msg = ("Purportedly Bernoulli distribution had distinct samples" + " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2])) + raise ValueError(msg) + def p_small_enough(p): + prob = stats.binom.logcdf(successes, n, p) + return prob - np.log(error_rate / 2.) + def p_big_enough(p): + prob = stats.binom.logsf(successes, n, p) + return prob - np.log(error_rate / 2.) 
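+  # Note on the root finding below (illustrative): high_p solves
+  # P(Binomial(n, p) <= successes) = error_rate / 2, i.e. it is the largest
+  # success probability still consistent with seeing so few successes, and
+  # low_p solves P(Binomial(n, p) > successes) = error_rate / 2.  Both are
+  # then mapped from [0, 1] back to the [low, high] scale of the samples.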
+ high_p = optimize.brentq( + p_small_enough, float(successes) / n, 1., rtol=1e-9) + low_p = optimize.brentq( + p_big_enough, 0., float(successes) / n, rtol=1e-9) + low_interval = low + (high - low) * low_p + high_interval = low + (high - low) * high_p + return (low_interval, high_interval) + + +def compute_true_volumes( + det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + """Returns confidence intervals for the desired correlation matrix volumes. + + The confidence intervals are computed by the [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Args: + det_bounds: A rank-1 numpy array of lower bounds on the + determinants of acceptable matrices. Entries must be unique. + dim: A Python `int` dimension of correlation matrices to sample. + num_samples: The number of samples to draw. + error_rate: The statistical significance of the returned + confidence intervals. The significance is broadcast: Each + returned interval separately may be incorrect with probability + (under the sample of correlation-like matrices drawn internally) + at most `error_rate`. + seed: Random seed. + + Returns: + bounds: A Python `dict` mapping each determinant bound to the low, high + tuple giving the confidence interval. + """ + bounds = {} + with session.Session() as sess: + rej_weights, _ = correlation_matrix_volume_rejection_samples( + det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed) + rej_weights = sess.run(rej_weights) + for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds): + template = ("Estimating volume of {}x{} correlation " + "matrices with determinant >= {}.") + print(template.format(dim, dim, det)) + sys.stdout.flush() + bounds[det] = _clopper_pearson_confidence_interval( + rw, error_rate=error_rate) + return bounds diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py new file mode 100644 index 00000000000000..8f99300e638711 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py @@ -0,0 +1,150 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for correlation_matrix_volumes_lib.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
+from tensorflow.contrib.distributions.python.ops import statistical_testing as st
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.platform import test
+
+
+# NxN correlation matrices are determined by the N*(N-1)/2
+# lower-triangular entries. In addition to being between -1 and 1,
+# they must also obey the constraint that the determinant of the
+# resulting symmetric matrix is non-negative. In 2x2, we can even
+# compute the volume analytically when the determinant is required to
+# exceed epsilon, as that boils down to the one lower-triangular entry
+# being less than sqrt(1 - epsilon) in absolute value.
+def two_by_two_volume(det_bound):
+  return 2 * np.sqrt(1.0 - det_bound)
+
+
+# The post
+# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/
+# derives (with elementary calculus) that the volume (with respect to
+# Lebesgue^3 measure) of the set of 3x3 correlation matrices is
+# pi^2/2. The same result is also obtained by [1].
+def three_by_three_volume():
+  return np.pi**2 / 2.
+
+
+# The volume of the unconstrained set of correlation matrices is also
+# the normalization constant of the LKJ distribution from [2]. As
+# part of defining the distribution, that reference derives a general
+# formula for this volume for all dimensions. A TensorFlow
+# computation thereof gave the below result for 4x4:
+def four_by_four_volume():
+  # This constant was computed as math_ops.exp(lkj.log_norm_const(4, [1.0]))
+  return 11.6973076
+
+# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of
+# correlation matrices." The American Statistician, 48(4), 276-279.
+
+# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating
+# random correlation matrices based on vines and extended onion
+# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001.
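+
+# A quick numerical illustration of the 2x2 case (not load-bearing for the
+# tests below): a 2x2 correlation-like matrix is [[1, r], [r, 1]] with r in
+# [-1, 1] and determinant 1 - r**2, so det >= b exactly when
+# |r| <= sqrt(1 - b).  For instance,
+#   two_by_two_volume(0.)   == 2.0  # r ranges over all of [-1, 1]
+#   two_by_two_volume(0.75) == 1.0  # r restricted to [-0.5, 0.5]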
+
+
+class CorrelationMatrixVolumesTest(test.TestCase):
+
+  def testRejection2D(self):
+    num_samples = int(1e5)  # Chosen for a small min detectable discrepancy
+    det_bounds = np.array(
+        [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
+    exact_volumes = two_by_two_volume(det_bounds)
+    (rej_weights,
+     rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
+         det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43)
+    # shape of rej_weights: [num_samples, 9, 2, 2]
+    chk1 = st.assert_true_mean_equal_by_dkwm(
+        rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
+        false_fail_rate=1e-6)
+    chk2 = check_ops.assert_less(
+        st.min_discrepancy_of_true_means_detectable_by_dkwm(
+            num_samples, low=0., high=rej_proposal_volume,
+            # Correct the false fail rate due to different broadcasting
+            false_fail_rate=1.1e-7, false_pass_rate=1e-6),
+        0.036)
+    with ops.control_dependencies([chk1, chk2]):
+      rej_weights = array_ops.identity(rej_weights)
+    self.evaluate(rej_weights)
+
+  def testRejection3D(self):
+    num_samples = int(1e5)  # Chosen for a small min detectable discrepancy
+    det_bounds = np.array([0.0], dtype=np.float32)
+    exact_volumes = np.array([three_by_three_volume()], dtype=np.float32)
+    (rej_weights,
+     rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
+         det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44)
+    # shape of rej_weights: [num_samples, 1, 3, 3]
+    chk1 = st.assert_true_mean_equal_by_dkwm(
+        rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
+        false_fail_rate=1e-6)
+    chk2 = check_ops.assert_less(
+        st.min_discrepancy_of_true_means_detectable_by_dkwm(
+            num_samples, low=0., high=rej_proposal_volume,
+            false_fail_rate=1e-6, false_pass_rate=1e-6),
+        # Going for about a 3% relative error
+        0.15)
+    with ops.control_dependencies([chk1, chk2]):
+      rej_weights = array_ops.identity(rej_weights)
+    self.evaluate(rej_weights)
+
+  def testRejection4D(self):
+    num_samples = int(1e5)  # Chosen for a small min detectable discrepancy
+    det_bounds = np.array([0.0], dtype=np.float32)
+    exact_volumes = [four_by_four_volume()]
+    (rej_weights,
+     rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
+         det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45)
+    # shape of rej_weights: [num_samples, 1, 4, 4]
+    chk1 = st.assert_true_mean_equal_by_dkwm(
+        rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
+        false_fail_rate=1e-6)
+    chk2 = check_ops.assert_less(
+        st.min_discrepancy_of_true_means_detectable_by_dkwm(
+            num_samples, low=0., high=rej_proposal_volume,
+            false_fail_rate=1e-6, false_pass_rate=1e-6),
+        # Going for about a 10% relative error
+        1.1)
+    with ops.control_dependencies([chk1, chk2]):
+      rej_weights = array_ops.identity(rej_weights)
+    self.evaluate(rej_weights)
+
+  def testVolumeEstimation2D(self):
+    # Test that the confidence intervals produced by
+    # corr.compute_true_volumes are sound, in the sense of containing
+    # the exact volume.
+ num_samples = int(1e5) # Chosen by symmetry with testRejection2D + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + volume_bounds = corr.compute_true_volumes( + det_bounds, 2, num_samples, error_rate=1e-6, seed=47) + exact_volumes = two_by_two_volume(det_bounds) + for det, volume in zip(det_bounds, exact_volumes): + computed_low, computed_high = volume_bounds[det] + self.assertLess(computed_low, volume) + self.assertGreater(computed_high, volume) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 88ed0127841093..11ca90c4833d84 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -144,7 +144,7 @@ def __init__(self, `distribution_fn(sample0).event_shape.num_elements()` are both `None`. ValueError: if `num_steps < 1`. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: self._distribution_fn = distribution_fn self._sample0 = sample0 diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index bf5590cd552a91..4714caad69ee43 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -41,9 +41,6 @@ class BatchReshape(distribution_lib.Distribution): This "meta-distribution" reshapes the batch dimensions of another distribution. - Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support - `-1` for flattening. - #### Examples ```python @@ -51,7 +48,7 @@ class BatchReshape(distribution_lib.Distribution): dtype = np.float32 dims = 2 - new_batch_shape = [1, 2, 3] + new_batch_shape = [1, 2, -1] old_batch_shape = [6] scale = np.ones(old_batch_shape + [dims], dtype) @@ -85,8 +82,9 @@ def __init__(self, Args: distribution: The base distribution instance to reshape. Typically an instance of `Distribution`. - batch_shape: Positive `int`-like vector-shaped `Tensor` representing the - new shape of the batch dimensions. + batch_shape: Positive `int`-like vector-shaped `Tensor` representing + the new shape of the batch dimensions. Up to one dimension may contain + `-1`, meaning the remainder of the batch size. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -104,31 +102,28 @@ def __init__(self, ValueError: if `batch_shape` size is not the same as a `distribution.batch_shape` size. """ - parameters = locals() + parameters = dict(locals()) name = name or "BatchReshape" + distribution.name - self._distribution = distribution with ops.name_scope(name, values=[batch_shape]) as name: - self._batch_shape_ = ops.convert_to_tensor( - batch_shape, - dtype=dtypes.int32, - name="batch_shape") - self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) - if self._batch_shape_static is not None: - self._batch_shape_static = np.int32(self._batch_shape_static) - self._runtime_assertions = validate_init_args( - self._distribution, - self._batch_shape_, - validate_args, - self._batch_shape_static) + # The unexpanded batch shape may contain up to one dimension of -1. 
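+      # For example, reshaping a distribution whose batch_shape is [6] with
+      # batch_shape=[1, 2, -1] yields batch_shape [1, 2, 3], as in the
+      # docstring example above.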
+ self._batch_shape_unexpanded = ops.convert_to_tensor( + batch_shape, dtype=dtypes.int32, name="batch_shape") + validate_init_args_statically(distribution, self._batch_shape_unexpanded) + batch_shape, batch_shape_static, runtime_assertions = calculate_reshape( + distribution.batch_shape_tensor(), self._batch_shape_unexpanded, + validate_args) + self._distribution = distribution + self._batch_shape_ = batch_shape + self._batch_shape_static = batch_shape_static + self._runtime_assertions = runtime_assertions super(BatchReshape, self).__init__( - dtype=self._distribution.dtype, - reparameterization_type=self._distribution.reparameterization_type, + dtype=distribution.dtype, + reparameterization_type=distribution.reparameterization_type, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=( - [self._batch_shape_] + - self._distribution._graph_parents), # pylint: disable=protected-access + [self._batch_shape_unexpanded] + distribution._graph_parents), # pylint: disable=protected-access name=name) @property @@ -140,7 +135,7 @@ def _batch_shape_tensor(self): return array_ops.identity(self._batch_shape_) def _batch_shape(self): - return tensor_shape.TensorShape(self._batch_shape_static) + return self._batch_shape_static def _event_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): @@ -152,11 +147,13 @@ def _event_shape(self): def _sample_n(self, n, seed=None): with ops.control_dependencies(self._runtime_assertions): x = self.distribution.sample(sample_shape=n, seed=seed) - new_shape = array_ops.concat([ - [n], - self.batch_shape_tensor(), - self.event_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + [n], + self._batch_shape_unexpanded, + self.event_shape_tensor(), + ], + axis=0) return array_ops.reshape(x, new_shape) def _log_prob(self, x): @@ -213,9 +210,9 @@ def _sample_shape(self, x): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) sample_ndims = x_ndims - batch_ndims - event_ndims if isinstance(sample_ndims, int): static_sample_shape = x.shape[:sample_ndims] @@ -238,10 +235,11 @@ def _call_reshape_input_output(self, fn, x): self.event_shape_tensor(), ], axis=0) result = fn(array_ops.reshape(x, old_shape)) - new_shape = array_ops.concat([ - sample_shape, - self.batch_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + sample_shape, + self._batch_shape_unexpanded, + ], axis=0) result = array_ops.reshape(result, new_shape) if (static_sample_shape.ndims is not None and self.batch_shape.ndims is not None): @@ -261,8 +259,7 @@ def _call_and_reshape_output( if static_event_shape_list is None: static_event_shape_list = [self.event_shape] new_shape = array_ops.concat( - [self.batch_shape_tensor()] + event_shape_list, - axis=0) + [self._batch_shape_unexpanded] + event_shape_list, axis=0) result = array_ops.reshape(fn(), new_shape) if (self.batch_shape.ndims is not None and self.event_shape.ndims is not None): @@ -281,9 +278,9 @@ def _validate_sample_arg(self, x): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - 
else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) expected_batch_event_ndims = batch_ndims + event_ndims if (isinstance(x_ndims, int) and @@ -355,62 +352,56 @@ def _validate_sample_arg(self, x): return runtime_assertions -def validate_init_args( - distribution, - batch_shape, - validate_args, - batch_shape_static): +def calculate_reshape(original_shape, new_shape, validate=False, name=None): + """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" + batch_shape_static = tensor_util.constant_value_as_shape(new_shape) + if batch_shape_static.is_fully_defined(): + return np.int32(batch_shape_static.as_list()), batch_shape_static, [] + with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]): + original_size = math_ops.reduce_prod(original_shape) + implicit_dim = math_ops.equal(new_shape, -1) + size_implicit_dim = ( + original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape))) + new_ndims = array_ops.shape(new_shape) + expanded_new_shape = array_ops.where( # Assumes exactly one `-1`. + implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape) + validations = [] if not validate else [ + check_ops.assert_rank( + original_shape, 1, message="Original shape must be a vector."), + check_ops.assert_rank( + new_shape, 1, message="New shape must be a vector."), + check_ops.assert_less_equal( + math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32), + 1, + message="At most one dimension can be unknown."), + check_ops.assert_positive( + expanded_new_shape, message="Shape elements must be >=-1."), + check_ops.assert_equal( + math_ops.reduce_prod(expanded_new_shape), + original_size, + message="Shape sizes do not match."), + ] + return expanded_new_shape, batch_shape_static, validations + + +def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" - with ops.name_scope(name="validate_init_args", - values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access - runtime_assertions = [] - - if batch_shape.shape.ndims is not None: - if batch_shape.shape.ndims != 1: - raise ValueError("`batch_shape` must be a vector " - "(saw rank: {}).".format( - batch_shape.shape.ndims)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_rank( - batch_shape, - 1, - message="`batch_shape` must be a vector.", - name="assert_batch_shape_is_vector"), - ] - - batch_size_static = np.prod(batch_shape_static) - dist_batch_size_static = ( - None if not distribution.batch_shape.is_fully_defined() - else np.prod(distribution.batch_shape).value) - - if batch_size_static is not None and dist_batch_size_static is not None: - if batch_size_static != dist_batch_size_static: - raise ValueError("`batch_shape` size ({}) must match " - "`distribution.batch_shape` size ({}).".format( - batch_size_static, - dist_batch_size_static)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_equal( - math_ops.reduce_prod(batch_shape), - math_ops.reduce_prod(distribution.batch_shape_tensor()), - message=("`batch_shape` size must match " - "`distributions.batch_shape` size."), - name="assert_batch_size"), - ] - - if batch_shape_static is not None: - if np.any(batch_shape_static < 1): - raise ValueError("`batch_shape` elements must be positive " - "(i.e., larger than zero).") - elif validate_args: - runtime_assertions += [ - check_ops.assert_positive( - batch_shape, 
- message=("`batch_shape` elements must be positive " - "(i.e., larger than zero)."), - name="assert_batch_shape_positive") - ] - - return runtime_assertions + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format(batch_shape.shape.ndims)) + + batch_shape_static = tensor_util.constant_value_as_shape(batch_shape) + batch_size_static = batch_shape_static.num_elements() + dist_batch_size_static = distribution.batch_shape.num_elements() + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, dist_batch_size_static)) + + if batch_shape_static.dims is not None: + if any( + dim.value is not None and dim.value < 1 for dim in batch_shape_static): + raise ValueError("`batch_shape` elements must be >=-1.") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 51478dbeffaabc..4965381ef33e14 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -30,6 +30,7 @@ @@Invert @@Kumaraswamy @@MaskedAutoregressiveFlow +@@MatrixInverseTriL @@Ordered @@Permute @@PowerTransform @@ -68,6 +69,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.matrix_inverse_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.ordered import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 85ad23e4133ef0..16f959560ce0f1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -20,10 +20,9 @@ import itertools -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector @@ -36,15 +35,6 @@ def _use_static_shape(input_tensor, ndims): return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) -def _maybe_get_event_ndims_statically(event_ndims): - static_event_ndims = (event_ndims if isinstance(event_ndims, int) - else tensor_util.constant_value(event_ndims)) - if static_event_ndims is not None: - return static_event_ndims - - return event_ndims - - def _compute_min_event_ndims(bijector_list, compute_forward=True): """Computes the min_event_ndims associated with the give list of bijectors. 
@@ -238,13 +228,13 @@ def _inverse(self, y, **kwargs): return y def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant( - 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian") + y = ops.convert_to_tensor(y, name="y") + ildj = math_ops.cast(0., dtype=y.dtype.base_dtype) if not self.bijectors: return ildj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -258,11 +248,15 @@ def _inverse_log_det_jacobian(self, y, **kwargs): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_static_event_ndims( + event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ + y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -274,13 +268,12 @@ def _forward(self, x, **kwargs): def _forward_log_det_jacobian(self, x, **kwargs): x = ops.convert_to_tensor(x, name="x") - fldj = constant_op.constant( - 0., dtype=x.dtype, name="inverse_log_det_jacobian") + fldj = math_ops.cast(0., dtype=x.dtype.base_dtype) if not self.bijectors: return fldj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -293,13 +286,14 @@ def _forward_log_det_jacobian(self, x, **kwargs): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ x = b.forward(x, **kwargs.get(b.name, {})) return fldj - diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index ecdb8967f43e59..268c8d03426d43 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -53,7 +53,7 @@ class CholeskyOuterProduct(bijector.Bijector): its spectrum), and that the product of two positive-diagonal lower-triangular matrices is another positive-diagonal lower-triangular matrix. - A simple inductive argument (proceding one column of L_3 at a time) shows + A simple inductive argument (proceeding one column of L_3 at a time) shows that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. 
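To make the `CholeskyOuterProduct` argument above concrete: `forward` maps a positive-diagonal lower-triangular `L` to `L @ L.T`, and `inverse` recovers `L` as the Cholesky factor. A small illustrative sketch, assuming the `tf.contrib.distributions.bijectors` API available on this branch:

```python
import tensorflow as tf

tfb = tf.contrib.distributions.bijectors

bijector = tfb.CholeskyOuterProduct()
L = [[1., 0.], [2., 3.]]           # lower triangular, positive diagonal

with tf.Session() as sess:
  X = bijector.forward(L)          # L @ L^T = [[1., 2.], [2., 13.]]
  L_back = bijector.inverse(X)     # Cholesky factor of X, recovering L
  print(sess.run([X, L_back]))
```

The injectivity argument quoted above is what guarantees this round trip is exact for positive-diagonal lower-triangular inputs.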
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py new file mode 100644 index 00000000000000..71903f705232f0 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py @@ -0,0 +1,145 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "MatrixInverseTriL", +] + + +class MatrixInverseTriL(bijector.Bijector): + """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix. + + `L` must be nonsingular; equivalently, all diagonal entries of `L` must be + nonzero. + + The input must have `rank >= 2`. The input is treated as a batch of matrices + with batch shape `input.shape[:-2]`, where each matrix has dimensions + `input.shape[-2]` by `input.shape[-1]` (hence `input.shape[-2]` must equal + `input.shape[-1]`). + + #### Examples + + ```python + tfd.bijectors.MatrixInverseTriL().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [-2, 1]], i.e., inv(x) + + tfd.bijectors.MatrixInverseTriL().inverse(y=[[1., 0], [-2, 1]]) + # Result: [[1., 0], [2, 1]], i.e., inv(y). + ``` + + """ + + def __init__(self, validate_args=False, name="matrix_inverse_tril"): + """Instantiates the `MatrixInverseTriL` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + super(MatrixInverseTriL, self).__init__( + forward_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + with ops.control_dependencies(self._assertions(x)): + shape = array_ops.shape(x) + return linalg_ops.matrix_triangular_solve( + x, linalg_ops.eye(shape[-1], batch_shape=shape[:-2]), lower=True) + + def _inverse(self, y): + return self._forward(y) + + def _forward_log_det_jacobian(self, x): + # Calculation of the Jacobian: + # + # Let X = (x_{ij}), 0 <= i,j < n, be a matrix of indeterminates. Let Z = + # X^{-1} where Z = (z_{ij}). Then + # + # dZ/dx_{ij} = (d/dt | t=0) Y(t)^{-1}, + # + # where Y(t) = X + t*E_{ij} and E_{ij} is the matrix with a 1 in the (i,j) + # entry and zeros elsewhere. 
By the product rule, + # + # 0 = d/dt [Identity matrix] + # = d/dt [Y Y^{-1}] + # = Y d/dt[Y^{-1}] + dY/dt Y^{-1} + # + # so + # + # d/dt[Y^{-1}] = -Y^{-1} dY/dt Y^{-1} + # = -Y^{-1} E_{ij} Y^{-1}. + # + # Evaluating at t=0, + # + # dZ/dx_{ij} = -Z E_{ij} Z. + # + # Taking the (r,s) entry of each side, + # + # dz_{rs}/dx_{ij} = -z_{ri}z_{sj}. + # + # Now, let J be the Jacobian dZ/dX, arranged as the n^2-by-n^2 matrix whose + # (r*n + s, i*n + j) entry is dz_{rs}/dx_{ij}. Considering J as an n-by-n + # block matrix with n-by-n blocks, the above expression for dz_{rs}/dx_{ij} + # shows that the block at position (r,i) is -z_{ri}Z. Hence + # + # J = -KroneckerProduct(Z, Z), + # det(J) = (-1)^(n^2) (det Z)^(2n) + # = (-1)^n (det X)^(-2n). + with ops.control_dependencies(self._assertions(x)): + return (-2. * math_ops.cast(array_ops.shape(x)[-1], x.dtype.base_dtype) * + math_ops.reduce_sum( + math_ops.log(math_ops.abs(array_ops.matrix_diag_part(x))), + axis=-1)) + + def _assertions(self, x): + if not self.validate_args: + return [] + shape = array_ops.shape(x) + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must have rank at least 2.") + is_square = check_ops.assert_equal( + shape[-2], shape[-1], message="Input must be a square matrix.") + above_diagonal = array_ops.matrix_band_part( + array_ops.matrix_set_diag( + x, array_ops.zeros(shape[:-1], dtype=dtypes.float32)), + 0, -1) + is_lower_triangular = check_ops.assert_equal( + above_diagonal, array_ops.zeros_like(above_diagonal), + message="Input must be lower triangular.") + # A lower triangular matrix is nonsingular iff all its diagonal entries are + # nonzero. + diag_part = array_ops.matrix_diag_part(x) + is_nonsingular = check_ops.assert_none_equal( + diag_part, array_ops.zeros_like(diag_part), + message="Input must have all diagonal entries nonzero.") + return [is_matrix, is_square, is_lower_triangular, is_nonsingular] diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index 12d16031783b78..e4944beedcbca0 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -163,7 +163,7 @@ def __init__(self, more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._total_count = self._maybe_assert_valid_total_count( ops.convert_to_tensor(total_count, name="total_count"), diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index daacfe657fe154..23b6a83c17d586 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -120,7 +120,7 @@ def __init__(self, Raises: TypeError: if `loc` and `scale` have different `dtype`. 
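One cross-check on the `MatrixInverseTriL` Jacobian derivation above: the derivation counts all `n**2` entries of `X` as free coordinates, while the bijector's domain is lower-triangular matrices with `n(n+1)/2` free entries. The NumPy sketch below (hypothetical helper, not part of this change) brute-forces the Jacobian restricted to the lower-triangular entries for a 2x2 example; it matches `-(n+1) * sum(log|diag(L)|)` rather than the `-2n * sum(log|diag(L)|)` returned by `_forward_log_det_jacobian` (the two agree only at `n = 1`), so the constant may be worth double-checking.

```python
import numpy as np

def inv_tril_params(params, n):
  """Lower-triangular entries of L -> lower-triangular entries of inv(L)."""
  L = np.zeros((n, n))
  L[np.tril_indices(n)] = params
  return np.linalg.inv(L)[np.tril_indices(n)]

n = 2
params = np.array([2., 1., 3.])   # L = [[2, 0], [1, 3]], diag = [2, 3]
eps = 1e-6
jac = np.zeros((len(params), len(params)))
for j in range(len(params)):
  dp = np.zeros_like(params)
  dp[j] = eps
  jac[:, j] = (inv_tril_params(params + dp, n) -
               inv_tril_params(params - dp, n)) / (2. * eps)

log_diag_sum = np.sum(np.log([2., 3.]))
print(np.log(np.abs(np.linalg.det(jac))))  # brute force: ~ -5.375
print(-(n + 1) * log_diag_sum)             # -(n+1) * sum log|diag| ~ -5.375
print(-2. * n * log_diag_sum)              # -2n * sum log|diag|   ~ -7.167
```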
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index c77c5fd20895a6..686ae1ba74641e 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -83,7 +83,7 @@ def __init__(self, more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) # Even though all stats of chi2 are defined for valid parameters, this is # not true in the parent class "gamma." therefore, passing # allow_nan_stats=True @@ -119,7 +119,7 @@ def __init__(self, validate_args=False, allow_nan_stats=True, name="Chi2WithAbsDf"): - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[df]) as name: super(Chi2WithAbsDf, self).__init__( df=math_ops.floor( diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 10b45361358b40..3598c8d23ea900 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -20,7 +20,6 @@ from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import transformed_distribution @@ -106,7 +105,7 @@ def _log_prob(self, y, bijector_kwargs=None, distribution_kwargs=None): bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -131,7 +130,7 @@ def _prob(self, y, bijector_kwargs=None, distribution_kwargs=None): bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -220,14 +219,14 @@ def _quantile(self, value, bijector_kwargs=None, distribution_kwargs=None): inv_cdf = self.distribution.quantile(value, **distribution_kwargs) return self.bijector.forward(inv_cdf, **bijector_kwargs) - def _maybe_get_event_ndims_statically(self): + def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) - static_event_ndims = tensor_util.constant_value(event_ndims) + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - if static_event_ndims is not None: - return static_event_ndims 
+ if event_ndims_ is not None: + return event_ndims_ return event_ndims diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index a42350430e9851..c44c76a1338176 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -86,7 +86,7 @@ def __init__(self, Raises: ValueError: If `loc` is a scalar. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, atol, rtol]) as name: loc = ops.convert_to_tensor(loc, name="loc") if is_vector and validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index 53dd42f4c83fce..e1e42ee95d200d 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -85,7 +85,7 @@ def __init__(self, name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index 2c261073ee1646..9d94fd11c62ce6 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -124,7 +124,7 @@ def __init__(self, Raises: TypeError: if loc and scale are different dtypes. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index d0df2befd6e46c..9c96254d1c0a59 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -105,7 +105,7 @@ def __init__(self, if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index fbde55ef310de1..cd6eaa8407477b 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -116,7 +116,7 @@ def __init__( ValueError: if `reinterpreted_batch_ndims` exceeds `distribution.batch_ndims` """ - parameters = locals() + parameters = dict(locals()) name = name or "Independent" + distribution.name self._distribution = distribution with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 502bd4f493337b..208057b34db288 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -125,7 +125,7 @@ def __init__(self, Raises: TypeError: if `concentration` and `rate` are different dtypes. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: with ops.control_dependencies([ check_ops.assert_positive(concentration), @@ -280,7 +280,7 @@ def __init__(self, validate_args=False, allow_nan_stats=True, name="InverseGammaWithSoftplusConcentrationRate"): - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: super(InverseGammaWithSoftplusConcentrationRate, self).__init__( concentration=nn.softplus(concentration, diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 66682b2ff5493f..0ff989fc952c6f 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -31,7 +31,6 @@ from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import uniform -from tensorflow.python.util.tf_export import tf_export __all__ = [ "Kumaraswamy", @@ -59,7 +58,6 @@ def _harmonic_number(x): return math_ops.digamma(x + one) - math_ops.digamma(one) -@tf_export("distributions.Kumaraswamy") class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index c83b5bc2e3a8c5..27aa863440574e 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -119,7 +119,7 @@ def __init__(self, Raises: TypeError: if loc and scale are different dtypes. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index 2ef294af2e8bc9..bfb53a06c011ce 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -116,7 +116,7 @@ def __init__(self, matching static batch shapes, or all components do not have matching static event shapes. """ - parameters = locals() + parameters = dict(locals()) if not isinstance(cat, categorical.Categorical): raise TypeError("cat must be a Categorical distribution, but saw: %s" % cat) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 0b1301e551728f..112eefd3691815 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -130,7 +130,7 @@ def __init__(self, ValueError: if `mixture_distribution` categories does not equal `components_distribution` rightmost batch shape. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: self._mixture_distribution = mixture_distribution self._components_distribution = components_distribution diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index e3236c2db93695..d2beb2aff0481e 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -193,7 +193,7 @@ def __init__(self, Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): @@ -224,7 +224,7 @@ def __init__(self, validate_args=False, allow_nan_stats=True, name="MultivariateNormalDiagWithSoftplusScale"): - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[scale_diag]) as name: super(MultivariateNormalDiagWithSoftplusScale, self).__init__( loc=loc, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 2f6a6f198cbcfb..5117379b047f5e 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -215,7 +215,7 @@ def __init__(self, Raises: ValueError: if at most `scale_identity_multiplier` is specified. 
""" - parameters = locals() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 86fcd4db54ad85..57f47db50c496f 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -45,7 +45,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): The probability density function (pdf) is, with `@` as matrix multiplication, ```none - pdf(x; loc, covariance_matrix) = exp(-0.5 ||y||**2) / Z, + pdf(x; loc, covariance_matrix) = exp(-0.5 y) / Z, y = (x - loc)^T @ inv(covariance_matrix) @ (x - loc) Z = (2 pi)**(0.5 k) |det(covariance_matrix)|**(0.5). ``` @@ -54,8 +54,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): * `loc` is a vector in `R^k`, * `covariance_matrix` is an `R^{k x k}` symmetric positive definite matrix, - * `Z` denotes the normalization constant, and, - * `||y||**2` denotes the squared Euclidean norm of `y`. + * `Z` denotes the normalization constant. Additional leading dimensions (if any) in `loc` and `covariance_matrix` allow for batch dimensions. @@ -156,7 +155,7 @@ def __init__(self, Raises: ValueError: if neither `loc` nor `covariance_matrix` are specified. """ - parameters = locals() + parameters = dict(locals()) # Convert the covariance_matrix up to a scale_tril and call MVNTriL. with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 44c92312c7dc75..6a0383db025552 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -170,7 +170,7 @@ def __init__(self, ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index d6f8b731cbeed5..c809ef3c1cb5b8 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -179,7 +179,7 @@ def __init__(self, Raises: ValueError: if neither `loc` nor `scale_tril` are specified. """ - parameters = locals() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) if loc is None and scale_tril is None: diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index eeaf9c0a5ebc13..2bd11e24b315e0 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -90,7 +90,7 @@ def __init__(self, name: Python `str` name prefixed to Ops created by this class. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 305b138fdc2318..3e44c10fab726a 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -115,7 +115,7 @@ def __init__( more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( name=name, logits=logits, probs=probs, validate_args=validate_args, diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index a84aad6fc93723..04de8106ee0c06 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -93,7 +93,7 @@ def __init__(self, TypeError: if `rate` is not a float-type. TypeError: if `log_rate` is not a float-type. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[rate]) as name: if (rate is None) == (log_rate is None): raise ValueError("Must specify exactly one of `rate` and `log_rate`.") diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 19c99dcee92978..7b10ba998f0cea 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -255,7 +255,7 @@ def __init__(self, TypeError: if `quadrature_grid` and `quadrature_probs` have different base `dtype`. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: if loc is not None: loc = ops.convert_to_tensor(loc, name="loc") diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index 1ef7651d03a338..5ac6c34b538016 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -128,7 +128,7 @@ def _logsum_expbig_minus_expsmall(big, small): class QuantizedDistribution(distributions.Distribution): """Distribution representing the quantization `Y = ceiling(X)`. - #### Definition in terms of sampling. + #### Definition in Terms of Sampling ``` 1. Draw X @@ -138,7 +138,7 @@ class QuantizedDistribution(distributions.Distribution): 5. Return Y ``` - #### Definition in terms of the probability mass function. + #### Definition in Terms of the Probability Mass Function Given scalar random variable `X`, we define a discrete random variable `Y` supported on the integers as follows: @@ -170,12 +170,62 @@ class QuantizedDistribution(distributions.Distribution): `P[Y = j]` is still the mass of `X` within the `jth` interval. - #### Caveats + #### Examples + + We illustrate a mixture of discretized logistic distributions + [(Salimans et al., 2017)][1]. 
This is used, for example, for capturing 16-bit + audio in WaveNet [(van den Oord et al., 2017)][2]. The values range in + a 1-D integer domain of `[0, 2**16-1]`, and the discretization captures + `P(x - 0.5 < X <= x + 0.5)` for all `x` in the domain excluding the endpoints. + The lowest value has probability `P(X <= 0.5)` and the highest value has + probability `P(2**16 - 1.5 < X)`. + + Below we assume a `wavenet` function. It takes as `input` right-shifted audio + samples of shape `[..., sequence_length]`. It returns a real-valued tensor of + shape `[..., num_mixtures * 3]`, i.e., each mixture component has a `loc` and + `scale` parameter belonging to the logistic distribution, and a `logits` + parameter determining the unnormalized probability of that component. + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + net = wavenet(inputs) + loc, unconstrained_scale, logits = tf.split(net, + num_or_size_splits=3, + axis=-1) + scale = tf.nn.softplus(unconstrained_scale) + + # Form mixture of discretized logistic distributions. Note we shift the + # logistic distribution by -0.5. This lets the quantization capture "rounding" + # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`. + discretized_logistic_dist = tfd.QuantizedDistribution( + distribution=tfd.TransformedDistribution( + distribution=tfd.Logistic(loc=loc, scale=scale), + bijector=tfb.AffineScalar(shift=-0.5)), + low=0., + high=2**16 - 1.) + mixture_dist = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical(logits=logits), + components_distribution=discretized_logistic_dist) + + neg_log_likelihood = -tf.reduce_sum(mixture_dist.log_prob(targets)) + train_op = tf.train.AdamOptimizer().minimize(neg_log_likelihood) + ``` + + After instantiating `mixture_dist`, we illustrate maximum likelihood by + calculating its log-probability of audio samples as `target` and optimizing. + + #### References - Since evaluation of each `P[Y = j]` involves a cdf evaluation (rather than - a closed form function such as for a Poisson), computations such as mean and - entropy are better done with samples or approximations, and are not - implemented by this class. + [1]: Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik P. Kingma. + PixelCNN++: Improving the PixelCNN with discretized logistic mixture + likelihood and other modifications. + _International Conference on Learning Representations_, 2017. + https://arxiv.org/abs/1701.05517 + [2]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech + Synthesis. _arXiv preprint arXiv:1711.10433_, 2017. + https://arxiv.org/abs/1711.10433 """ def __init__(self, @@ -213,7 +263,7 @@ def __init__(self, `Distribution` or continuous. NotImplementedError: If the base distribution does not implement `cdf`. """ - parameters = locals() + parameters = dict(locals()) values = ( list(distribution.parameters.values()) + [low, high]) diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 84c8d29072c2f1..4182ca2b56ea80 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -165,7 +165,7 @@ def __init__(self, Raises: ValueError: If both `probs` and `logits` are passed, or if neither. 
""" - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs, temperature]) as name: with ops.control_dependencies([check_ops.assert_positive(temperature)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 325f41e37c928b..5414f347cd65e2 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -162,7 +162,7 @@ def __init__( more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs, temperature]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py index 056d349688511e..cf505ac627b62a 100644 --- a/tensorflow/contrib/distributions/python/ops/seed_stream.py +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -169,7 +169,7 @@ def __init__(self, seed, salt): and TensorFlow Probability code base. See class docstring for rationale. """ - self._seed = seed + self._seed = seed.original_seed if isinstance(seed, SeedStream) else seed self._salt = salt self._counter = 0 diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index 03828fa61277ee..a764544932cea8 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -132,7 +132,7 @@ def __init__(self, if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale, skewness, tailweight]) as name: diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py index 9c69435fac1099..c25e8c51d7705b 100644 --- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -140,6 +140,7 @@ "assert_true_mean_equal_by_dkwm", "min_discrepancy_of_true_means_detectable_by_dkwm", "min_num_samples_for_dkwm_mean_test", + "assert_true_mean_in_interval_by_dkwm", "assert_true_mean_equal_by_dkwm_two_sample", "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", "min_num_samples_for_dkwm_mean_two_sample_test", @@ -209,17 +210,17 @@ def _maximum_mean(samples, envelope, high, name=None): separately. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `envelope` and `high`. - envelope: Floating-point tensor of sizes of admissible CDF + envelope: Floating-point `Tensor` of sizes of admissible CDF envelopes (i.e., the `eps` above). - high: Floating-point tensor of upper bounds on the distributions' - supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. `samples <= high`. name: A name for this operation (optional). 
Returns: - bound: Floating-point tensor of upper bounds on the true means. + bound: Floating-point `Tensor` of upper bounds on the true means. Raises: InvalidArgumentError: If some `sample` is found to be larger than @@ -254,17 +255,17 @@ def _minimum_mean(samples, envelope, low, name=None): separately. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `envelope` and `low`. - envelope: Floating-point tensor of sizes of admissible CDF + envelope: Floating-point `Tensor` of sizes of admissible CDF envelopes (i.e., the `eps` above). - low: Floating-point tensor of lower bounds on the distributions' - supports. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. `samples >= low`. name: A name for this operation (optional). Returns: - bound: Floating-point tensor of lower bounds on the true means. + bound: Floating-point `Tensor` of lower bounds on the true means. Raises: InvalidArgumentError: If some `sample` is found to be smaller than @@ -300,12 +301,12 @@ def _dkwm_cdf_envelope(n, error_rate, name=None): probability above. Args: - n: Tensor of numbers of samples drawn. - error_rate: Floating-point tensor of admissible rates of mistakes. + n: `Tensor` of numbers of samples drawn. + error_rate: Floating-point `Tensor` of admissible rates of mistakes. name: A name for this operation (optional). Returns: - eps: Tensor of maximum distances the true CDF can be from the + eps: `Tensor` of maximum distances the true CDF can be from the empirical CDF. This scales as `O(sqrt(-log(error_rate)))` and as `O(1 / sqrt(n))`. The shape is the broadcast of `n` and `error_rate`. @@ -324,8 +325,8 @@ def _check_shape_dominates(samples, parameters): sample counts end up inflated. Args: - samples: A Tensor whose shape is to be protected against broadcasting. - parameters: A list of Tensors who are parameters for the statistical test. + samples: A `Tensor` whose shape is to be protected against broadcasting. + parameters: A list of `Tensor`s who are parameters for the statistical test. Returns: samples: Return original `samples` with control dependencies attached @@ -369,19 +370,23 @@ def true_mean_confidence_interval_by_dkwm( members. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low` and `high`. - low: Floating-point tensor of lower bounds on the distributions' + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - error_rate: *Scalar* admissible total rate of mistakes. + error_rate: *Scalar* floating-point `Tensor` admissible total rate + of mistakes. name: A name for this operation (optional). Returns: - low: A floating-point tensor of stochastic lower bounds on the true means. - high: A floating-point tensor of stochastic upper bounds on the true means. + low: A floating-point `Tensor` of stochastic lower bounds on the + true means. + high: A floating-point `Tensor` of stochastic upper bounds on the + true means. 
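For readers new to these DKWM utilities, a short usage sketch (assuming the module is imported from its contrib path; the printed numbers are only indicative): draw bounded samples, ask for stochastic bounds on the true mean, and assert that the true mean lies in an expected interval.

```python
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import statistical_testing as st

samples = tf.random_uniform([5000], minval=0., maxval=1., seed=17)
low, high = 0., 1.

# Bounds that contain the true mean except with probability ~1e-6.
min_mean, max_mean = st.true_mean_confidence_interval_by_dkwm(
    samples, low, high, error_rate=1e-6)

# Raises InvalidArgumentError only if there is strong evidence that the true
# mean lies outside [0.45, 0.55].
check = st.assert_true_mean_in_interval_by_dkwm(
    samples, low, high, expected_low=0.45, expected_high=0.55,
    false_fail_rate=1e-6)

with tf.Session() as sess:
  print(sess.run([min_mean, max_mean]))  # A narrow interval around 0.5.
  sess.run(check)                        # Passes for Uniform(0, 1) samples.
```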
""" with ops.name_scope( name, "true_mean_confidence_interval_by_dkwm", @@ -436,15 +441,17 @@ def assert_true_mean_equal_by_dkwm( the assertion will insist on stronger evidence to fail any one member. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low` and `high`. - low: Floating-point tensor of lower bounds on the distributions' + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - expected: Floating-point tensor of expected true means. - false_fail_rate: *Scalar* admissible total rate of mistakes. + expected: Floating-point `Tensor` of expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. name: A name for this operation (optional). Returns: @@ -454,20 +461,8 @@ def assert_true_mean_equal_by_dkwm( with ops.name_scope( name, "assert_true_mean_equal_by_dkwm", [samples, low, high, expected, false_fail_rate]): - samples = ops.convert_to_tensor(samples, name="samples") - low = ops.convert_to_tensor(low, name="low") - high = ops.convert_to_tensor(high, name="high") - expected = ops.convert_to_tensor(expected, name="expected") - false_fail_rate = ops.convert_to_tensor( - false_fail_rate, name="false_fail_rate") - samples = _check_shape_dominates(samples, [low, high, expected]) - min_mean, max_mean = true_mean_confidence_interval_by_dkwm( - samples, low, high, error_rate=false_fail_rate) - less_op = check_ops.assert_less( - min_mean, expected, message="Mean confidence interval too high") - with ops.control_dependencies([less_op]): - return check_ops.assert_greater( - max_mean, expected, message="Mean confidence interval too low") + return assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected, expected, false_fail_rate) def min_discrepancy_of_true_means_detectable_by_dkwm( @@ -487,30 +482,35 @@ def min_discrepancy_of_true_means_detectable_by_dkwm( with the same `false_pass_rate`. Args: - n: Tensor of numbers of samples to be drawn from the distributions + n: `Tensor` of numbers of samples to be drawn from the distributions of interest. - low: Floating-point tensor of lower bounds on the distributions' + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - discr: Tensor of lower bounds on the distances between true + discr: `Tensor` of lower bounds on the distances between true means detectable by a DKWM-based test. For each batch member `i`, of `K` total, drawing `n[i]` samples from some scalar distribution supported on `[low[i], high[i]]` is enough to detect a difference in means of size `discr[i]` or more. 
Specifically, we guarantee that (a) if the true mean is the expected - mean, `assert_true_mean_equal_by_dkwm` will fail with probability at - most `false_fail_rate / K` (which amounts to `false_fail_rate` if - applied to the whole batch at once), and (b) if the true mean - differs from the expected mean by at least `discr[i]`, - `assert_true_mean_equal_by_dkwm` will pass with probability at most - `false_pass_rate`. + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. The detectable discrepancy scales as @@ -558,17 +558,19 @@ def min_num_samples_for_dkwm_mean_test( on a scalar distribution supported on `[low, high]`. Args: - discrepancy: Floating-point tensor of desired upper limits on mean + discrepancy: Floating-point `Tensor` of desired upper limits on mean differences that may go undetected with probability higher than `1 - false_pass_rate`. - low: Tensor of lower bounds on the distributions' support. - high: Tensor of upper bounds on the distributions' support. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + low: `Tensor` of lower bounds on the distributions' support. + high: `Tensor` of upper bounds on the distributions' support. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - n: Tensor of numbers of samples to be drawn from the distributions + n: `Tensor` of numbers of samples to be drawn from the distributions of interest. The `discrepancy`, `low`, and `high` tensors must have @@ -578,12 +580,15 @@ def min_num_samples_for_dkwm_mean_test( some scalar distribution supported on `[low[i], high[i]]` is enough to detect a difference in means of size `discrepancy[i]` or more. Specifically, we guarantee that (a) if the true mean is the expected - mean, `assert_true_mean_equal_by_dkwm` will fail with probability at - most `false_fail_rate / K` (which amounts to `false_fail_rate` if - applied to the whole batch at once), and (b) if the true mean - differs from the expected mean by at least `discrepancy[i]`, - `assert_true_mean_equal_by_dkwm` will pass with probability at most - `false_pass_rate`. + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discrepancy[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. 
The required number of samples scales as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`, @@ -610,6 +615,76 @@ def min_num_samples_for_dkwm_mean_test( return math_ops.maximum(n1, n2) +def assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected_low, expected_high, + false_fail_rate=1e-6, name=None): + """Asserts the mean of the given distribution is in the given interval. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the mean of the distribution from which the given samples are + drawn is _outside_ the given interval with statistical significance + `false_fail_rate` or stronger, otherwise passes. If you also want + to check that you are gathering enough evidence that a pass is not + spurious, see `min_num_samples_for_dkwm_mean_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. + expected_low: Floating-point `Tensor` of lower bounds on the + expected true means. + expected_high: Floating-point `Tensor` of upper bounds on the + expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any expected mean + interval does not overlap with the corresponding confidence + interval. + """ + with ops.name_scope( + name, "assert_true_mean_in_interval_by_dkwm", + [samples, low, high, expected_low, expected_high, false_fail_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + expected_low = ops.convert_to_tensor(expected_low, name="expected_low") + expected_high = ops.convert_to_tensor(expected_high, name="expected_high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples = _check_shape_dominates( + samples, [low, high, expected_low, expected_high]) + min_mean, max_mean = true_mean_confidence_interval_by_dkwm( + samples, low, high, false_fail_rate) + # Assert that the interval [min_mean, max_mean] intersects the + # interval [expected_low, expected_high]. This is true if + # max_mean >= expected_low and min_mean <= expected_high. + # By DeMorgan's law, that's also equivalent to + # not (max_mean < expected_low or min_mean > expected_high), + # which is a way of saying the two intervals are not disjoint. 
+ check_confidence_interval_can_intersect = check_ops.assert_greater_equal( + max_mean, expected_low, message="Confidence interval does not " + "intersect: true mean smaller than expected") + with ops.control_dependencies([check_confidence_interval_can_intersect]): + return check_ops.assert_less_equal( + min_mean, expected_high, message="Confidence interval does not " + "intersect: true mean greater than expected") + + def assert_true_mean_equal_by_dkwm_two_sample( samples1, low1, high1, samples2, low2, high2, false_fail_rate=1e-6, name=None): @@ -630,23 +705,26 @@ def assert_true_mean_equal_by_dkwm_two_sample( the assertion will insist on stronger evidence to fail any one member. Args: - samples1: Floating-point tensor of samples from the + samples1: Floating-point `Tensor` of samples from the distribution(s) A. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low1`, `high1`, `low2`, and `high2`. - low1: Floating-point tensor of lower bounds on the supports of the + The support is bounded: `low1 <= samples1 <= high1`. + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - samples2: Floating-point tensor of samples from the + samples2: Floating-point `Tensor` of samples from the distribution(s) B. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low1`, `high1`, `low2`, and `high2`. - low2: Floating-point tensor of lower bounds on the supports of the + The support is bounded: `low2 <= samples2 <= high2`. + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of mistakes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. name: A name for this operation (optional). Returns: @@ -676,20 +754,10 @@ def assert_true_mean_equal_by_dkwm_two_sample( # and sample counts should be valid; however, because the intervals # scale as O(-log(false_fail_rate)), there doesn't seem to be much # room to win. - min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm( - samples1, low1, high1, false_fail_rate / 2.) min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm( samples2, low2, high2, false_fail_rate / 2.) - # I want to assert - # not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2), - # but I think I only have and-combination of asserts, so use DeMorgan. - check_confidence_intervals_can_intersect = check_ops.assert_greater_equal( - max_mean_1, min_mean_2, message="Confidence intervals do not " - "intersect: samples1 has a smaller mean than samples2") - with ops.control_dependencies([check_confidence_intervals_can_intersect]): - return check_ops.assert_less_equal( - min_mean_1, max_mean_2, message="Confidence intervals do not " - "intersect: samples2 has a smaller mean than samples1") + return assert_true_mean_in_interval_by_dkwm( + samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.) def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( @@ -710,22 +778,24 @@ def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( with the same `false_pass_rate`. 
Args: - n1: Tensor of numbers of samples to be drawn from the distributions A. - low1: Floating-point tensor of lower bounds on the supports of the + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - n2: Tensor of numbers of samples to be drawn from the distributions B. - low2: Floating-point tensor of lower bounds on the supports of the + n2: `Tensor` of numbers of samples to be drawn from the distributions B. + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - discr: Tensor of lower bounds on the distances between true means + discr: `Tensor` of lower bounds on the distances between true means detectable by a two-sample DKWM-based test. For each batch member `i`, of `K` total, drawing `n1[i]` samples @@ -776,24 +846,26 @@ def min_num_samples_for_dkwm_mean_two_sample_test( (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). Args: - discrepancy: Floating-point tensor of desired upper limits on mean + discrepancy: Floating-point `Tensor` of desired upper limits on mean differences that may go undetected with probability higher than `1 - false_pass_rate`. - low1: Floating-point tensor of lower bounds on the supports of the + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - low2: Floating-point tensor of lower bounds on the supports of the + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - n1: Tensor of numbers of samples to be drawn from the distributions A. - n2: Tensor of numbers of samples to be drawn from the distributions B. + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + n2: `Tensor` of numbers of samples to be drawn from the distributions B. 
For each batch member `i`, of `K` total, drawing `n1[i]` samples from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index af6ff8162b1730..8d4914e16cd374 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -395,7 +395,7 @@ def __init__(self, ValueError: if `not distribution.is_scalar_batch`. ValueError: if `not distribution.is_scalar_event`. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[mix_loc, temperature]) as name: if not scale or len(scale) < 2: raise ValueError("Must specify list (or list-like object) of scale " diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index e265b5d0f7c10b..a75b3f3df1f286 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -175,7 +175,7 @@ def __init__(self, Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 89136d6760bb66..a7d4c55be93f61 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -175,7 +175,7 @@ def __init__(self, ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 8dd983b750d9b3..4a53e7a621f273 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -210,7 +210,7 @@ def __init__(self, Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name): with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index ec485c95c15da2..0566e04fece6f9 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -191,7 +191,7 @@ def __init__(self, ValueError: if `scale` is unspecified. 
TypeError: if not `scale.dtype.is_floating` """ - parameters = locals() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 1438ede26500bc..bb33cd0762a368 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -163,7 +163,7 @@ def __init__(self, Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope( name, diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 7e78ded9df0756..21f84dcbdea8b4 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -175,7 +175,7 @@ def __init__(self, if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) graph_parents = [df, loc, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_factor, scale_perturb_diag] with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 91453fed5d2791..88d4280759da7c 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -107,7 +107,7 @@ def __init__(self, ValueError: if df < k, where scale operator event shape is `(k, k)` """ - parameters = locals() + parameters = dict(locals()) self._cholesky_input_output_matrices = cholesky_input_output_matrices with ops.name_scope(name) as name: with ops.name_scope("init", values=[df, scale_operator]): @@ -530,7 +530,7 @@ def __init__(self, more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) @@ -646,7 +646,7 @@ def __init__(self, more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = locals() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 9a3b780af888a5..4384431e7b9c3e 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -1,6 +1,6 @@ # Eager Execution -Eager execution provides an imperative interface to TensorFlow (similiar to +Eager execution provides an imperative interface to TensorFlow (similar to [NumPy](http://www.numpy.org)). When you enable eager execution, TensorFlow operations execute immediately; you do not execute a pre-constructed graph with [`Session.run()`](https://www.tensorflow.org/api_docs/python/tf/Session). @@ -37,7 +37,7 @@ support for distributed and multi-GPU training and performance. 
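Before the eager README hunk continues below, a note on the recurring `parameters = locals()` to `parameters = dict(locals())` edit in the distribution constructors above: on the CPython interpreters this codebase targets, every `locals()` call inside a function refreshes and returns the same per-frame dict, so a dict captured at the top of `__init__` can later pick up locals bound further down. Copying it freezes a snapshot of just the constructor arguments. A minimal sketch of the difference, independent of TensorFlow:

```python
# Why the constructors copy locals(); this relies on CPython's per-frame
# locals() caching and is not code from the patch itself.
def without_copy(loc, scale):
  parameters = locals()      # the frame's cached locals dict
  extra = object()           # a local bound later in the constructor
  locals()                   # a later locals() call refreshes that same dict
  return parameters          # now also contains 'extra'


def with_copy(loc, scale):
  parameters = dict(locals())  # independent snapshot of the call arguments
  extra = object()
  locals()
  return parameters            # still just loc and scale


print(sorted(without_copy(0., 1.)))  # ['extra', 'loc', 'scale']
print(sorted(with_copy(0., 1.)))     # ['loc', 'scale']
```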
## Installation -Eager execution is included in TensorFlow versions 1.7 and above. +For eager execution, we recommend using TensorFlow version 1.8 or newer. Installation instructions at https://www.tensorflow.org/install/ ## Documentation @@ -48,12 +48,3 @@ For an introduction to eager execution in TensorFlow, see: - Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) - Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) - Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) - -## Changelog - -- 2017/10/31: Initial preview release (in TensorFlow 1.5) -- 2017/12/01: Example of dynamic neural network: - [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). - See [README.md](python/examples/spinn/README.md) for details. -- 2017/03: Core functionality moved out of the experimental tf.contrib namespace - in TensorFlow 1.7. diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 4d28d2266d91b9..dccd813256355d 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -120,7 +120,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:checkpointable", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -131,6 +130,7 @@ py_library( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", + "//tensorflow/python/training/checkpointable:base", ], ) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 0783d1b5d70e50..adf92c27ea0a27 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.saver import BaseSaverBuilder _uid_counter = 0 @@ -106,7 +106,8 @@ def remote_fn(h): target_device=target, buffer_size=10, container="", - shared_name=_generate_shared_name("function_buffer_resource")) + shared_name=_generate_shared_name( + "contrib_eager_iterator_function_buffer_resource")) self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long handle=self._buffer_resource_handle, handle_device=self._device) diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 7b123707cc3a26..68bec9aee894ed 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -37,7 +37,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import util as checkpointable_utils class IteratorTest(test.TestCase): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index c1fd9e0ed020be..1d9371c7ac405d 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -7,6 +7,8 @@ py_library( name = 
"examples_pip", deps = [ "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/l2hmc", + "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b80c9090235370..cc9cf53410f641 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -227,7 +227,7 @@ def train_one_epoch(generator, discriminator, generator_optimizer, maxval=1., seed=batch_index) - with tfe.GradientTape(persistent=True) as g: + with tf.GradientTape(persistent=True) as g: generated_images = generator(noise) tf.contrib.summary.image( 'generated_images', @@ -306,7 +306,7 @@ def main(_): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py index bd35e50c1f434d..81ac05e26d23c2 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -111,5 +111,5 @@ def benchmark_generate(self): if __name__ == '__main__': - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD new file mode 100644 index 00000000000000..72341835ff10e3 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -0,0 +1,39 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "gpu_py_test") + +py_library( + name = "neural_nets", + srcs = ["neural_nets.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +py_library( + name = "l2hmc", + srcs = ["l2hmc.py"], + srcs_version = "PY2AND3", + deps = [ + ":neural_nets", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", + ], +) + +gpu_py_test( + name = "l2hmc_test", + size = "large", + srcs = ["l2hmc_test.py"], + additional_deps = [ + ":l2hmc", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py new file mode 100644 index 00000000000000..98b4ce1b26acf2 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -0,0 +1,382 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets + + +class Dynamics(tf.keras.Model): + """Dynamics engine of naive L2HMC sampler. + + Args: + x_dim: dimensionality of observed data + loglikelihood_fn: log-likelihood function of conditional probability + n_steps: number of leapfrog steps within each transition + eps: initial value learnable scale of step size + """ + + def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1): + super(Dynamics, self).__init__() + + self.x_dim = x_dim + self.potential = loglikelihood_fn + self.n_steps = n_steps + + self._construct_time() + self._construct_masks() + + self.position_fn = neural_nets.GenericNet(x_dim, factor=2.) + self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.) + + self.eps = tfe.Variable( + initial_value=eps, name="eps", dtype=tf.float32, trainable=True) + + # TODO(lxuechen): Remove this after model.add_weight is in place + self.vars_not_in_layers = [self.eps] + self.vars_not_in_layers += self.position_fn.vars_not_in_layers + self.vars_not_in_layers += self.momentum_fn.vars_not_in_layers + + def apply_transition(self, position): + """Propose a new state and perform the accept or reject step.""" + + # Simulate dynamics both forward and backward; + # Use sampled Bernoulli masks to compute the actual solutions + position_f, momentum_f, accept_prob_f = self.transition_kernel( + position, forward=True) + position_b, momentum_b, accept_prob_b = self.transition_kernel( + position, forward=False) + + # Decide direction uniformly + forward_mask = tf.cast( + tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32) + backward_mask = 1. - forward_mask + + # Obtain proposed states + position_post = ( + forward_mask[:, None] * position_f + + backward_mask[:, None] * position_b) + momentum_post = ( + forward_mask[:, None] * momentum_f + + backward_mask[:, None] * momentum_b) + + # Probability of accepting the proposed states + accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b + + # Accept or reject step + accept_mask = tf.cast( + accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32) + reject_mask = 1. - accept_mask + + # Samples after accept/reject step + position_out = ( + accept_mask[:, None] * position_post + reject_mask[:, None] * position) + + return position_post, momentum_post, accept_prob, position_out + + def transition_kernel(self, position, forward=True): + """Transition kernel of augmented leapfrog integrator.""" + + lf_fn = self._forward_lf if forward else self._backward_lf + + # Resample momentum + momentum = tf.random_normal(tf.shape(position)) + position_post, momentum_post = position, momentum + sumlogdet = 0. 
+ # Apply augmented leapfrog steps + for i in range(self.n_steps): + position_post, momentum_post, logdet = lf_fn(position_post, momentum_post, + i) + sumlogdet += logdet + + accept_prob = self._compute_accept_prob(position, momentum, position_post, + momentum_post, sumlogdet) + + return position_post, momentum_post, accept_prob + + def _forward_lf(self, position, momentum, i): + """One forward augmented leapfrog step. See eq (5-6) in paper.""" + + t = self._get_time(i) + mask, mask_inv = self._get_mask(i) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _backward_lf(self, position, momentum, i): + """One backward augmented leapfrog step. See Appendix A in paper.""" + + # Reversed index/sinusoidal time + t = self._get_time(self.n_steps - i - 1) + mask, mask_inv = self._get_mask(self.n_steps - i - 1) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _update_momentum_forward(self, position, momentum, t): + """Update v in the forward leapfrog step.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= .5 * self.eps + transformed *= self.eps + momentum = ( + momentum * tf.exp(scale) - + .5 * self.eps * (tf.exp(transformed) * grad - translation)) + + return momentum, scale + + def _update_position_forward(self, position, momentum, t, mask): + """Update x in the forward leapfrog step.""" + + mask_inv = 1. - mask + scale, translation, transformed = self.position_fn( + [momentum, mask * position, t]) + scale *= self.eps + transformed *= self.eps + position = ( + mask * position + + mask_inv * (position * tf.exp(scale) + self.eps * + (tf.exp(transformed) * momentum + translation))) + + return position, mask_inv * scale + + def _update_momentum_backward(self, position, momentum, t): + """Update v in the backward leapfrog step. Inverting the forward update.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= -.5 * self.eps + transformed *= self.eps + momentum = ( + tf.exp(scale) * (momentum + .5 * self.eps * + (tf.exp(transformed) * grad - translation))) + + return momentum, scale + + def _update_position_backward(self, position, momentum, t, mask): + """Update x in the backward leapfrog step. Inverting the forward update.""" + + mask_inv = 1. 
- mask + scale, translation, transformed = self.position_fn( + [momentum, mask_inv * position, t]) + scale *= -self.eps + transformed *= self.eps + position = ( + mask_inv * position + mask * tf.exp(scale) * + (position - self.eps * tf.exp(transformed) * momentum + translation)) + + return position, mask * scale + + def _compute_accept_prob(self, position, momentum, position_post, + momentum_post, sumlogdet): + """Compute the prob of accepting the proposed state given old state.""" + + old_hamil = self.hamiltonian(position, momentum) + new_hamil = self.hamiltonian(position_post, momentum_post) + + return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) + + def _construct_time(self): + """Convert leapfrog step index into sinusoidal time.""" + + self.ts = [] + for i in range(self.n_steps): + t = tf.constant( + [ + np.cos(2 * np.pi * i / self.n_steps), + np.sin(2 * np.pi * i / self.n_steps) + ], + dtype=tf.float32) + self.ts.append(t[None, :]) + + def _get_time(self, i): + """Get sinusoidal time for i-th augmented leapfrog step.""" + + return self.ts[i] + + def _construct_masks(self): + """Construct different binary masks for different time steps.""" + + self.masks = [] + for _ in range(self.n_steps): + idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2] + mask = np.zeros((self.x_dim,)) + mask[idx] = 1. + mask = tf.constant(mask, dtype=tf.float32) + self.masks.append(mask[None, :]) + + def _get_mask(self, i): + """Get binary masks for i-th augmented leapfrog step.""" + + m = self.masks[i] + return m, 1. - m + + def kinetic(self, v): + """Compute the kinetic energy.""" + + return .5 * tf.reduce_sum(v**2, axis=1) + + def hamiltonian(self, position, momentum): + """Compute the overall Hamiltonian.""" + + return self.potential(position) + self.kinetic(momentum) + + def grad_potential(self, position, check_numerics=True): + """Get gradient of potential function at current location.""" + + if not tf.executing_eagerly(): + # TODO(lxuechen): Change this to tfe.gradients_function when it works + grad = tf.gradients(self.potential(position), position)[0] + else: + grad = tfe.gradients_function(self.potential)(position)[0] + + if check_numerics: + return tf.check_numerics(grad, message="gradient of potential") + + return grad + + +# Defining loss and grads for training +def compute_loss(x, dynamics, scale=.1, eps=1e-4): + """Compute loss defined in equation (8).""" + + z = tf.random_normal(tf.shape(x)) + x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) + z_, _, z_accept_prob, _ = dynamics.apply_transition(z) + + # Add eps for numerical stability; following released impl + x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps + z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps + + loss = tf.reduce_mean( + (1. / x_loss + 1. 
/ z_loss) * scale - (x_loss + z_loss) / scale, axis=0) + + return loss, x_out + + +def loss_and_grads(x, dynamics): + """Obtain loss value and gradients.""" + + with tf.GradientTape() as tape: + loss_val, x_out = compute_loss(x, dynamics) + + vars_ = dynamics.variables + dynamics.vars_not_in_layers + grads = tape.gradient(loss_val, vars_) + + return loss_val, grads, x_out + + +def warmup(dynamics, optimizer, n_iters=1, n_samples=200): + """Warmup optimization to reduce overhead.""" + + samples = tf.random_normal( + shape=[n_samples, dynamics.x_dim], dtype=tf.float32) + + for _ in range(n_iters): + _, grads, samples = loss_and_grads(samples, dynamics) + vars_ = dynamics.variables + dynamics.vars_not_in_layers + optimizer.apply_gradients(zip(grads, vars_)) + + +def fit(dynamics, + optimizer, + n_samples=200, + n_iters=5000, + verbose=True, + logdir=None): + """Fit L2HMC sampler with given log-likelihood function.""" + + if logdir: + summary_writer = tf.contrib.summary.create_file_writer(logdir) + + samples = tf.random_normal( + shape=[n_samples, dynamics.x_dim], dtype=tf.float32) + + tf.train.get_or_create_global_step() + for i in range(n_iters): + loss, grads, samples = loss_and_grads(samples, dynamics) + # TODO(lxuechen): Proper learning rate decay + grads_ = [grad * .96**(i // 1000) for grad in grads] + vars_ = dynamics.variables + dynamics.vars_not_in_layers + optimizer.apply_gradients( + zip(grads_, vars_), global_step=tf.train.get_global_step()) + + if verbose: + print("Iteration %d: loss %.4f" % (i, loss)) + + if logdir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("loss", loss) + + +def get_scg_energy_fn(): + """Get energy function for 2d strongly correlated Gaussian.""" + + # Avoid recreating tf constants on each invocation of gradients + mu = tf.constant([0., 0.]) + sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]]) + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py new file mode 100644 index 00000000000000..522a7c9380131b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -0,0 +1,162 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
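Stepping back from the diff for a moment: together with `neural_nets.py` further down, the module above gives a small train-then-sample loop, and the test file that follows constructs it the same way. A hedged end-to-end sketch; the hyperparameter values are illustrative and not taken from the patch:

```python
# End-to-end sketch of the l2hmc pieces defined above; numbers are illustrative.
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc

tf.enable_eager_execution()

dynamics = l2hmc.Dynamics(
    x_dim=2,
    loglikelihood_fn=l2hmc.get_scg_energy_fn(),
    n_steps=10,
    eps=.1)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

# Train the sampler, then draw from it by repeatedly applying the kernel;
# apply_transition returns (position_post, momentum_post, accept_prob, out).
l2hmc.fit(dynamics, optimizer, n_samples=200, n_iters=100, verbose=False)
samples = tf.random_normal(shape=[200, 2])
for _ in range(50):
  _, _, _, samples = dynamics.apply_transition(samples)
```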
+# ============================================================================== +"""Tests l2hmc fit to 2D strongly correlated Gaussian executed eagerly.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc + + +def get_default_hparams(): + return tf.contrib.training.HParams( + x_dim=2, + n_samples=200, + n_steps=10, + eps=.1, + n_iters=5, + learning_rate=.001, + n_warmup_iters=1) + + +class L2hmcTest(tf.test.TestCase): + """Unit tests for l2hmc in both eager and graph mode.""" + + def testComputeLoss(self): + """Testing function l2hmc.compute_loss in both graph and eager mode.""" + + # Eager mode testing + hparams = get_default_hparams() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim]) + loss, x_out = l2hmc.compute_loss(samples, dynamics) + + # Check shape and numerical stability + self.assertEqual(x_out.shape, samples.shape) + self.assertEqual(loss.shape, []) + self.assertAllClose(loss.numpy(), loss.numpy(), rtol=1e-5) + + # Graph mode testing + with tf.Graph().as_default(): + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + loss, x_out = l2hmc.compute_loss(x, dynamics) + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + loss_np, x_out_np = sess.run([loss, x_out], feed_dict={x: samples}) + + # Check shape and numerical stability + self.assertEqual(x_out_np.shape, samples.shape) + self.assertEqual(loss_np.shape, ()) + self.assertAllClose(loss_np, loss_np, rtol=1e-5) + + +class L2hmcBenchmark(tf.test.Benchmark): + """Eager and graph benchmarks for l2hmc.""" + + def benchmarkEagerL2hmc(self): + """Benchmark Eager performance.""" + + hparams = get_default_hparams() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + # TODO(lxuechen): Add learning rate decay + optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) + + # Warmup to reduce initialization effect when timing + l2hmc.warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters) + + # Time + start_time = time.time() + l2hmc.fit( + dynamics, + optimizer, + n_samples=hparams.n_samples, + n_iters=hparams.n_iters) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="eager_train_%s" % ("gpu" if tfe.num_gpus() > 0 else "cpu"), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + def benchmarkGraphL2hmc(self): + """Benchmark Graph performance.""" + + hparams = get_default_hparams() + with tf.Graph().as_default(): + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + loss, x_out = l2hmc.compute_loss(x, dynamics) + + global_step = tf.Variable(0., name="global_step", trainable=False) + learning_rate = 
tf.train.exponential_decay( + hparams.learning_rate, global_step, 1000, 0.96, staircase=True) + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + # Warmup to reduce initialization effect when timing + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + for _ in range(hparams.n_warmup_iters): + samples, _, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + + # Time + start_time = time.time() + for _ in range(hparams.n_iters): + samples, _, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="graph_train_%s" % ("gpu" + if tf.test.is_gpu_available() else "cpu"), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py new file mode 100644 index 00000000000000..c902e1f1f4862d --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -0,0 +1,86 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Neural nets utility for L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.contrib.eager as tfe + + +class GenericNet(tf.keras.Model): + """Generic neural net with different initialization scale based on input. + + Args: + x_dim: dimensionality of observed data + factor: factor of variance scaling initializer + n_hidden: number of hidden units + """ + + def __init__(self, x_dim, factor, n_hidden=10): + super(GenericNet, self).__init__() + + self.v_layer = _custom_dense(n_hidden, 1. / 3.) + self.x_layer = _custom_dense(n_hidden, factor / 3.) + self.t_layer = _custom_dense(n_hidden, 1. / 3.) 
+ self.h_layer = _custom_dense(n_hidden) + + # Scale + self.scale_layer = _custom_dense(x_dim, .001) + self.coeff_scale = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True) + # Translation + self.translation_layer = _custom_dense(x_dim, factor=.001) + # Transformation + self.transformation_layer = _custom_dense(x_dim, .001) + self.coeff_transformation = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), + name='coeff_transformation', + trainable=True) + # TODO(lxuechen): Remove this after model.add_weight is in place + self.vars_not_in_layers = [self.coeff_scale, self.coeff_transformation] + + def call(self, inputs): + v, x, t = inputs + h = self.v_layer(v) + self.x_layer(x) + self.t_layer(t) + h = tf.nn.relu(h) + h = self.h_layer(h) + h = tf.nn.relu(h) + scale = tf.nn.tanh(self.scale_layer(h)) * tf.exp(self.coeff_scale) + translation = self.translation_layer(h) + transformation = ( + tf.nn.tanh(self.transformation_layer(h)) * tf.exp( + self.coeff_transformation)) + + return scale, translation, transformation + + +def _custom_dense(units, factor=1.): + """Custom dense layer with specified weight initialization.""" + + return tf.keras.layers.Dense( + units=units, + use_bias=True, + kernel_initializer=tf.contrib.layers.variance_scaling_initializer( + factor=factor * 2., mode='FAN_IN', uniform=False), + bias_initializer=tf.constant_initializer(0., dtype=tf.float32)) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 4e1380afb2e6e7..099b712fc06d1d 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -75,7 +75,6 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): mse = lambda xs, ys: mean_square_loss(model, xs, ys) loss_and_grads = tfe.implicit_value_and_gradients(mse) - tf.train.get_or_create_global_step() if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= @@ -87,12 +86,13 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if verbose: print("Iteration %d: loss = %s" % (i, loss.numpy())) - optimizer.apply_gradients(grads, global_step=tf.train.get_global_step()) + optimizer.apply_gradients(grads) if logdir: with summary_writer.as_default(): with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("loss", loss) + tf.contrib.summary.scalar("loss", loss, step=i) + tf.contrib.summary.scalar("step", i, step=i) def synthetic_dataset(w, b, noise_level, batch_size, num_batches): @@ -119,7 +119,7 @@ def batch(_): def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() # Ground-truth constants. 
true_w = [[-2.0], [4.0], [1.0]] true_b = [0.5] diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py index e53234b51a7dcc..2bc2fc2aa9150a 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py @@ -117,5 +117,5 @@ def benchmarkEagerLinearRegression(self): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb index 459f2f4a7d2afa..51d10a778413cf 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb @@ -7,17 +7,13 @@ "id": "U9i2Dsh-ziXr" }, "source": [ - "# Eager Execution Tutorial: Basics\n", + "# An introduction to TensorFlow\n", "\n", - "This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n", + "This is an introductory tutorial for using TensorFlow. It will cover:\n", "\n", "* Importing required packages\n", - "* Enabling eager execution\n", - "* Creating and using TensorFlow Tensors and Variables\n", - "* Using TensorFlow interactively\n", - "* Using GPUs with eager execution enabled\n", - "\n", - "This notebook does *not* cover modeling topics, such as gradients." + "* Creating and using Tensors\n", + "* Using GPU acceleration\n" ] }, { @@ -27,9 +23,10 @@ "id": "z1JcS5iBXMRO" }, "source": [ - "# Step 1: Import Eager\n", + "## Import TensorFlow\n", "\n", - "The key imports for eager execution are the following:" + "To get started, import the `tensorflow` module and enable eager execution.\n", + "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later." ] }, { @@ -48,11 +45,9 @@ }, "outputs": [], "source": [ - "# Import TensorFlow.\n", "import tensorflow as tf\n", "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "import tensorflow.contrib.eager as tfe" + "tf.enable_eager_execution()" ] }, { @@ -62,10 +57,9 @@ "id": "H9UySOPLXdaw" }, "source": [ - "# Step 2: Enable eager execution\n", + "## Tensors\n", "\n", - "All future TensorFlow calls will execute the\n", - "underlying TensorFlow ops immediately:" + "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv) etc.) that consume and produce Tensors. These operations automatically convert native Python types. 
For example:\n" ] }, { @@ -77,60 +71,47 @@ "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 125 }, "colab_type": "code", - "id": "WPTUfGq6kJ5w" - }, - "outputs": [], - "source": [ - "tfe.enable_eager_execution()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "twBfWd5xyu_d" - }, - "source": [ - "# Step 3: Interactively Use TensorFlow!\n", - "\n", - "Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n", - "\n", - "TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1526420535530, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "colab_type": "code", - "id": "ngUe237Wt48W" + "id": "ngUe237Wt48W", + "outputId": "b1a1cd60-4eb3-443d-cd6b-68406390784e" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(3, shape=(), dtype=int32)\n", + "tf.Tensor([4 6], shape=(2,), dtype=int32)\n", + "tf.Tensor(25, shape=(), dtype=int32)\n", + "tf.Tensor(6, shape=(), dtype=int32)\n", + "tf.Tensor(aGVsbG8gd29ybGQ, shape=(), dtype=string)\n", + "tf.Tensor(13, shape=(), dtype=int32)\n" + ] + } + ], "source": [ "print(tf.add(1, 2))\n", "print(tf.add([1, 2], [3, 4]))\n", "print(tf.square(5))\n", "print(tf.reduce_sum([1, 2, 3]))\n", "print(tf.encode_base64(\"hello world\"))\n", - "print(\"\")\n", - "\n", - "x = tf.constant(2)\n", - "y = tf.constant(3)\n", - "print(x * y + 1)\n", "\n", - "# Most TensorFlow ops are directly usable with eager execution, giving\n", - "# results immediately.\n", - "print(tf.contrib.signal.hamming_window(x * y + 1))" + "# Operator overloading is also supported\n", + "print(tf.square(2) + tf.square(3))" ] }, { @@ -140,7 +121,7 @@ "id": "IDY4WsYRhP81" }, "source": [ - "Numpy arrays are supported, too:" + "Each Tensor has a shape and a datatype" ] }, { @@ -151,178 +132,144 @@ "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 }, "colab_type": "code", - "id": "lCUWzso6mbqR" + "executionInfo": { + "elapsed": 215, + "status": "ok", + "timestamp": 1526420538162, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "srYWH1MdJNG7", + "outputId": "5e4ac41c-5115-4e50-eba0-42e249c16561" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 2)\n", + "\u003cdtype: 'int32'\u003e\n" + ] + } + ], "source": [ - "import numpy as np\n", - "\n", - "ones = np.ones([3, 3])\n", - "\n", - "print(\"numpy 3x3 matrix of 1s:\")\n", - "print(ones)\n", - "print(\"\")\n", - "\n", - "print(\"Multiplied by 42:\")\n", - "print(tf.multiply(ones, 42))" + "x = tf.matmul([[1]], [[2, 3]])\n", + "print(x.shape)\n", + "print(x.dtype)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "PBNP8yTRfu_X" + "id": "eBPw8e8vrsom" }, "source": [ - "# Step 4: Define and Print TensorFlow Variables\n", + "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n", "\n", - "To define TensorFlow variables, use the `get_variable()` function as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": 
{ - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "3Twf_Rw-gQFM" - }, - "outputs": [], - "source": [ - "x = tf.get_variable(name=\"x\", shape=[], dtype=tf.float32, initializer=tf.zeros_initializer)" + "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n", + "2. Tensors are immutable." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "45G7094TxsMb" + "id": "Dwi1tdW3JBw6" }, "source": [ - "## Printing TensorFlow Variables" + "### NumPy Compatibility\n", + "\n", + "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple as:\n", + "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n", + "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n", + "\n", + "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n", + "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory." ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 251 }, "colab_type": "code", - "id": "UJBJeZ5XxuwA" + "executionInfo": { + "elapsed": 238, + "status": "ok", + "timestamp": 1526420540562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "lCUWzso6mbqR", + "outputId": "fd0a22bc-8249-49dd-fcbd-63161cc47e46" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorFlow operations convert numpy arrays to Tensors automatically\n", + "tf.Tensor(\n", + "[[ 42. 42. 42.]\n", + " [ 42. 42. 42.]\n", + " [ 42. 42. 42.]], shape=(3, 3), dtype=float64)\n", + "And NumPy operations convert Tensors to numpy arrays automatically\n", + "[[ 43. 43. 43.]\n", + " [ 43. 43. 43.]\n", + " [ 43. 43. 43.]]\n", + "The .numpy() method explicitly converts a Tensor to a numpy array\n", + "[[ 42. 42. 42.]\n", + " [ 42. 42. 42.]\n", + " [ 42. 42. 42.]]\n" + ] + } + ], "source": [ - "# This does NOT print the Variable's actual value:\n", - "print(\"Printing a TensorFlow Variable:\")\n", - "print(x)\n", - "print(\"\")\n", + "import numpy as np\n", "\n", - "# A TensorFlow variable represents a reference to a tensor.\n", - "# The `read_value()` method provides access to the current value of the\n", - "# variable. 
Tensorflow Variables are automatically initialized according to the\n", - "# semantics defined in tf.get_variable().\n", - "print(\"Printing a TensorFlow Variable's value using .read_value():\")\n", - "print(x.read_value())\n", - "print(\"\")\n", + "ndarray = np.ones([3, 3])\n", "\n", - "print(\"Printing a TensorFlow Variable's value using .read_value().numpy():\")\n", - "print(x.read_value().numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "2njjWHcTpBEn" - }, - "source": [ - "## Changing a TensorFlow Variable's value\n", + "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n", + "tensor = tf.multiply(ndarray, 42)\n", + "print(tensor)\n", "\n", - "To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "v3wr6Erbo_hB" - }, - "outputs": [], - "source": [ - "x.assign(42)\n", - "print(x.read_value())\n", "\n", - "x.assign_add(3)\n", - "print(x.read_value())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "uhtynjHVpTB5" - }, - "source": [ - "## Use a Variable just like any other Tensor" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "7PbktdnHoehR" - }, - "outputs": [], - "source": [ - "print(x + 3)\n", + "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n", + "print(np.add(tensor, 1))\n", "\n", - "# This code will broadcast the value across the list of numbers:\n", - "print(x * [1, 2, 4])" + "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n", + "print(tensor.numpy())" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "GVChqwlwy1SI" + "id": "PBNP8yTRfu_X" }, "source": [ - "# Step 5: Debug Errors with Instant Feedback\n", + "## GPU acceleration\n", "\n", - "TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n", - "\n", - "Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n", - "one being legal and the other being illegal, leading to a runtime error that is\n", - "raised immediately." + "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. 
For example:" ] }, { @@ -334,125 +281,68 @@ "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 }, "colab_type": "code", - "id": "23ap04N0v4k0" - }, - "outputs": [], - "source": [ - "vector = tf.constant([10.0, 20.0, 30.0, 40.0])" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "FCUMsIYxxRRa" - }, - "outputs": [], - "source": [ - "# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n", - "# arguments) are within the bound of `vector`.\n", - "print(tf.slice(vector, [1], [3]))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "executionInfo": { + "elapsed": 340, + "status": "ok", + "timestamp": 1526420543562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "colab_type": "code", - "id": "T8me2oCNxpFp" + "id": "3Twf_Rw-gQFM", + "outputId": "2239ae2b-adf3-4895-b1f3-464cf5361d1b" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is there a GPU available: False\n", + "Is the Tensor on GPU #0: False\n" + ] + } + ], "source": [ - "# The following does NOT work, because the value of `size` (the 3rd\n", - "# argument) causes the indices to go out of the bounds of `vector`. The\n", - "# error is raised immediately.\n", - "try:\n", - " print(tf.slice(vector, [1], [4]))\n", - "except tf.OpError as e:\n", - " print(\"Caught error: %s\" % e)" + "x = tf.random_uniform([3, 3])\n", + "\n", + "print(\"Is there a GPU available: \"),\n", + "print(tf.test.is_gpu_available())\n", + "\n", + "print(\"Is the Tensor on GPU #0: \"),\n", + "print(x.device.endswith('GPU:0'))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "irxJhAgar84v" + "id": "vpgYzgVXW2Ud" }, "source": [ - "# Step 6: Using the GPU\n", + "### Device Names\n", "\n", - "You can place Tensors on the GPU by calling a Tensor's `.gpu()` method.\n", - "\n", - "The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster." + "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:\u003cN\u003e` if the tensor is placed on the `N`-th tensor on the host." ] }, { - "cell_type": "code", - "execution_count": 0, + "cell_type": "markdown", "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "7J4N9baqaKCL" + "colab_type": "text", + "id": "ZWZQCimzuqyP" }, - "outputs": [], "source": [ - "# The example code from here on will work only if your notebook\n", - "# is running on a machine with a functional CUDA GPU. 
The following\n", - "# line checks that.\n", - "is_gpu_available = tfe.num_gpus() \u003e 0\n", "\n", - "# Create some Tensors\n", - "SIZE = 1000\n", - "cpu_tensor = tf.random_normal([SIZE, SIZE])\n", "\n", - "if is_gpu_available:\n", - " gpu_tensor = cpu_tensor.gpu()\n", - "else:\n", - " print(\"GPU not available.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "4E-2n7VbzY1n" - }, - "outputs": [], - "source": [ - "# Time a CPU-based matrix multiplication\n", + "### Explicit Device Placement\n", "\n", - "print(\"Time to conduct matmul on CPU:\")\n", - "%time tf.matmul(cpu_tensor, cpu_tensor)" + "The term \"placement\" in TensorFlow refers to how individual operations are assigned (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. For example:" ] }, { @@ -463,65 +353,73 @@ "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 }, "colab_type": "code", - "id": "vbSFW-T5zhZF" + "executionInfo": { + "elapsed": 1762, + "status": "ok", + "timestamp": 1526420547562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "RjkNZTuauy-Q", + "outputId": "2e613293-ccac-4db2-b793-8ceb5b5adcfd" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "On CPU:\n", + "10 loops, best of 3: 35.8 ms per loop\n" + ] + } + ], "source": [ - "# Time GPU-based matrix multiplications.\n", + "def time_matmul(x):\n", + " %timeit tf.matmul(x, x)\n", "\n", - "if is_gpu_available:\n", - " # First use of the GPU will be slow:\n", - " print(\"Time to conduct first matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)\n", - " print()\n", + "# Force execution on CPU\n", + "print(\"On CPU:\")\n", + "with tf.device(\"CPU:0\"):\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"CPU:0\")\n", + " time_matmul(x)\n", "\n", - " # Subsequent uses are much faster:\n", - " print(\"Time to conduct second matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)" + "# Force execution on GPU #0 if available\n", + "if tf.test.is_gpu_available():\n", + " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"GPU:0\")\n", + " time_matmul(x)" ] }, { - "cell_type": "code", - "execution_count": 0, + "cell_type": "markdown", "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "E5pIOe3Rz7iW" + "colab_type": "text", + "id": "YEOJTNiOvnpQ" }, - "outputs": [], "source": [ - "# Second timing demo for GPUs, after it has been used once:\n", - "\n", - "cpu_tensor = tf.random_normal([SIZE, SIZE])\n", - "print(\"Time to conduct CPU matmul:\")\n", - "%time tf.matmul(cpu_tensor, cpu_tensor)\n", - "print()\n", + "## Next Steps\n", "\n", - "if is_gpu_available:\n", - " gpu_tensor = cpu_tensor.gpu()\n", - " print(\"Time to conduct GPU matmul:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)" + "In this tutorial we covered the most fundamental concepts in TensorFlow - `Tensor`s, operations, and devices.\n", + "In 
[the next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/2_gradients.ipynb) we will cover automatic differentiation - a building block required for training many machine learning models like neural networks." ] } ], "metadata": { "colab": { + "collapsed_sections": [], "default_view": {}, - "name": "Eager Execution Tutorial: Basics", - "provenance": [ - { - "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg", - "timestamp": 1504118841551 - } - ], + "name": "TensorFlow: An introduction", + "provenance": [], "version": "0.3.2", "views": {} } diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb index e6c7c117333e1e..9c1af9c2084bac 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb @@ -7,12 +7,9 @@ "id": "vDJ4XzMqodTy" }, "source": [ - "# Eager Execution: Working with Gradients\n", + "# Automatic Differentiation\n", "\n", - "This notebook demonstrates:\n", - "\n", - "* How to get gradients using TensorFlow's eager execution capabilities\n", - "* How to apply the gradients so you can update your variables" + "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." ] }, { @@ -22,7 +19,7 @@ "id": "GQJysDM__Qb0" }, "source": [ - "# Setup: Import eager and enable eager execution.\n" + "## Setup\n" ] }, { @@ -40,14 +37,10 @@ }, "outputs": [], "source": [ - "# Import TensorFlow.\n", "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "import tensorflow.contrib.eager as tfe\n", - "\n", - "# Enable eager execution.\n", - "tfe.enable_eager_execution()" + "tfe = tf.contrib.eager # Shorthand for some symbols" ] }, { @@ -57,28 +50,15 @@ "id": "1CLWJl0QliB0" }, "source": [ - "# Fitting a Simple Linear Model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-39gouo7mtgu" - }, - "source": [ - "## Step 1: Synthesize some data\n", + "## Derivatives of a function\n", "\n", - "To demonstrate fitting a model with TensorFlow's eager execution, we'll fit a linear model to some synthesized data (which includes some noise).\n", - "\n", - "In the code, we use the variable names `w` and `b` to represent the single weight and bias we'll use to fit our model." + "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. 
For example: " ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -86,170 +66,115 @@ } }, "colab_type": "code", - "id": "rQsdCg9PfIL-" + "id": "9FViq92UX7P8" }, "outputs": [], "source": [ - "# The constants we'll try to fit our variables to:\n", - "true_w = 3\n", - "true_b = 2\n", + "from math import pi\n", "\n", - "NUM_EXAMPLES = 1000\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", "\n", - "# Our inputs:\n", - "inputs = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", + "assert f(pi/2).numpy() == 1.0\n", "\n", - "# Our labels, with noise:\n", - "noise = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", - "labels = inputs * true_w + true_b + noise" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 360, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 127, - "status": "ok", - "timestamp": 1505502830690, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "O4lsC4ckAcar", - "outputId": "2f760690-cafb-4777-b970-91d839f99faf" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAesAAAFXCAYAAACC+2avAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXt8VPWd99+TK7kykxtJQIebqZfaqogtrhKNa1ooEKl9\nCrpVn9ZNW6x9VWsbCi7aVUt01NZ9tq21KVZlFey2YkQNohhj3QWK2liCF5RIBCc3yEwmIZnMTOY8\nf/zmzJwzSSBAYibh+369eIU5c87vXLh8zvdu0TRNQxAEQRCEmCVurC9AEARBEISjI2ItCIIgCDGO\niLUgCIIgxDgi1oIgCIIQ44hYC4IgCEKMI2ItCIIgCDHOiIj16tWrufjii1m8eHF4269//Wvmz5/P\n0qVLWbp0Ka+//vpInEoQBEEQTjksI1Fn/eabb5KWlkZFRQWbN28GlFinpaXx7W9/+6QvUhAEQRBO\nZUbEsr7wwgvJzMwcsF36rQiCIAjCyTOqMesnn3ySsrIybr/9drq6ukbzVIIgCIIwYRk1sb722mt5\n5ZVXqK6uJicnh8rKytE6lSAIgiBMaEZNrLOysrBYLAB885vfZPfu3cc8RtzmgiAIgjCQhJFaKFpo\n29vbyc3NBeDll1+mqKjomGtYLBba2yeuuzw3N0Pubxwzke9vIt8byP2Nd06F+zsWIyLWt912Gzt3\n7sTtdnPZZZfxwx/+kJ07d/Lee+8RFxfH1KlTueuuu0biVIIgCIJwyjEiYv3ggw8O2Hb11VePxNKC\nIAiCcMojHcwEQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfE\nWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYR\nsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGCdhrC9AEARBOHXo6HCzcmUt\nTU2Z2O2dOBwl2GzWsb6smEfEWhAEQfjMWLmylurq6wAL9fUasJ6qqqVjfVkxj7jBBUEQhM+MpqZM\nwBL6ZAl9Fo6FiLUgCILwmWG3dwJa6JOG3e4Zy8sZN4gbXBAEQfjMcDhKgPWhmLUHh+Pysb6kcYGI\ntSAIgvCZYbNZJUZ9AogbXBAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFr\nQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfE\nWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYR\nsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhx\nRKwFQRAEIcYRsRYEQRCEGCdhrC9AEARBODE6OtysXFmL02mjsLADh6MEm806rGOamjKx2zuHdYww\n9oyIWK9evZrXXnuN7OxsNm/eDEBnZye33norn376KdOmTeOhhx4iIyNjJE4nCIIgACtX1lJdfR1g\nATRgPVVVS037RIuzz9dDTc33AQv19YMfI8QeI+IG//rXv866detM237/+98zb948XnrpJb70pS/x\nyCOPjMSpBEEQhBBNTZkooQawhD6b0QW9vv4qqquvZ/v27mMeI8QeIyLWF154IZmZ5j/wbdu2sXSp\neltbunQpr7zyykicShAEQQhht3eiLGoADbvdM2CfaEGH7GMeI8Qeoxaz7ujoICcnB4Dc3FxcLtdo\nnUoQBOGUxOEoAdaHYtYuHI7LAbPru61tD1AM2AAXkyY5sVr/CBxi3rwMHI5FY3cDwrCJuQSz3NyJ\nHdeW+xvfTOT7m8j3BhPz/uLi+klOTgQgOTmBnJwMsrIyuPnm5w2x7DKmTbuPgoJzaG7ew8GDt6PH\nuDMyNpKdncFNNz3Pxx+nM2NGFw8/vJCsrNhLOJuIf37Hw6iJdXZ2NocOHSInJ4f29naysrKGdVx7\ne9doXdKYk5ubIfc3jpnI9zeR7w0m7v2Vlz8XFuVduzT6+lSy2N69KRhd3zk5Z/LCC5dRWtrPwYOR\n7Xv3pnDjjYOvEUtM1D8/neG8iIxYnbWmaabPJSUlPPPMMwBs2rSJK66
BYRETE5hbWIiIjJKaxFRERMTmEtIiJicgprERERk1NYi4iImJzC\nWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicoMK68OHD7N48WJmzpzJBx984PK1\nbdu2kZGRwcKFC/nzn/88qEaKiIgMZ4MK67S0NLZs2cLs2bNdPl9eXk5RURGHDh1i+/btPPvss9hs\ntkE1VEREZLgaVFhPnTqVKVOmeARxcXExmZmZREZGMnnyZFJSUjh9+vSgGioiIjJcBWTOuqamhqSk\npO6PJ06cSE1NTSAeJSIiEvIi+7pg5cqVXL582ePzOTk5LFiwwOv3eBvyDgsLG0DzREREpM+w/vWv\nf93vm06aNImqqqruj6urq5kwYYJP35uYGN/v5w0ler+hLZTfL5TfDfR+Q12ov19f/DYM7tybXrBg\nAYcOHaK9vZ1PPvmECxcu8LnPfc5fjxIRERlWwmyDWKb99ttvs2nTJhobGxk9ejQzZszglVdeAYyt\nW/v27SMyMpKnn36au+66y2+NFhERGU4GFdYiIiISeKpgJiIiYnIKaxEREZNTWIuIiJicacN6x44d\nzJgxA6vVGuym+NXPf/5z7rvvPpYuXcrjjz9OXV1dsJvkV3l5eSxcuJCsrCzWrFlDS0tLsJvkN73V\nwh/Kjh8/zr333ss999zDyy+/HOzm+NXGjRuZO3cuS5YsCXZTAqK6upoVK1aQmZnJkiVL2LlzZ7Cb\n5Dft7e089NBDLF26lCVLlrBly5ZgNykgurq6uP/++1m1alWv15kyrKurq3n33XdJTk4OdlP87tvf\n/jb79++noKCA+fPnh9x/gHfddRcHDx6ksLCQlJQUtm3bFuwm+U1PtfCHsq6uLjZt2sSOHTv4wx/+\nwMGDBykvLw92s/zmgQceYMeOHcFuRsBERESwYcMGDh06xJ49e9i9e3fI/Pyio6PZuXMnBQUFFBQU\ncPz48ZAsW71z506mTZvW53WmDOvc3FzWrl0b7GYERGxsbPefW1tbCQ835Y9gwObOndv9TrNmzaK6\nujrILfKfnmrhD2WnT58mJSWFm266iaioKBYtWkRxcXGwm+U3X/ziFxk9enSwmxEwiYmJzJw5EzD+\n3zJt2jRqa2uD3Cr/GTlyJGD0sjs7O4PcGv+rrq6mpKSEhx56qM9r+6xgdqMdOXKEpKQkbrnllmA3\nJWBefPFFCgsLiY+PD6lhK3f79u1j0aJFwW6G9MJbHf/3338/iC2Sgbp48SJnzpwJqQJUXV1dPPDA\nA1y4cIFHHnkkpN4NHB3T5ubmPq8NSlj3VG/8qaeeYtu2bbz66qvdnxuKvZi+6qnn5OSQk5PDyy+/\nzG9/+1vWrFkThFYOnC/14rdu3UpUVNSQmyscSC38oWwo/v0ST1euXOHJJ59k48aNLqN3Q114eDgF\nBQW0tLSwevVqzp07x/Tp04PdLL84duwY48ePZ+bMmZw8ebLP64MS1j3VGz979iyXLl0iKysLm81G\nTU0Ny5Yt4/e//z3jxo27wa0cOF/rqS9evJgnnnhiyIV1X++Xn59PSUnJkBw1GEgt/KFs0qRJVFZW\ndn9cU1Pjcx1/MYfOzk6efPJJsrKy+MpXvhLs5gREXFwcX/rSl/jTn/4UMmH93nvvceTIEUpKSmhr\na+PKlSusXbuWvLw8r9ebasI0LS2Nd955h+LiYo4cOcLEiRPJz88fUkHdl4qKiu4/FxcXM3Xq1CC2\nxv+OHz/OK6+8wtatW4mOjg52cwImVHqkt99+OxcuXODSpUu0t7dz8OBB7r777mA3y69C5WfVk40b\nNzJ9+nS++c1vBrspftXQ0NA9PHz16lVOnDgRUv+//N73vsexY8coLi7mZz/7GXfeeWePQQ0mnLN2\nFhYWFnJ/0V544QXOnz9PeHg4ycnJPPvss8Fukl/9+Mc/pqOjg8ceewyAz3/+8/zwhz8MbqP8xLkW\n/qpVq1xq4Q9VERERPPPMMzz22GPYbDYefPBBn1amDhXf//73OXnyJFarlfnz57NmzRqWLVsW7Gb5\nzV//+lcOHDhAWloaS5cuJSwsjJycHObNmxfspg1aXV0d69evp6uri66uLjIzM0lPTw92s4JGtcFF\nRERMzlTD4CIiIuJJYS0iImJyCmsRERGTU1iLiIiYnMJaRETE5BTWIiIiJqewFhERMTmFtYiIiMn9\nPyQ+uNKCpR6MAAAAAElFTkSuQmCC\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0xa813090\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "# Plot the Data (Optional)\n", "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "plt.scatter(inputs.numpy(), labels.numpy())\n", - "plt.show()" + "# grad_f will return a list of derivatives of f\n", + "# with respect to its arguments. Since f() has a single argument,\n", + "# grad_f will return a list with a single element.\n", + "grad_f = tfe.gradients_function(f)\n", + "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "JaFHyAG9nDET" + "id": "v9fPs8RyopCf" }, "source": [ - "## Step 2: Define our TensorFlow variables\n", + "### Higher-order gradients\n", "\n", - "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/contrib/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias.\n", - "\n", - "(**Note**: We're using the implementation of `Dense` found in `tf.layers.Dense` though the documentation link is for `tf.contrib.keras.layers.Dense`. 
When TensorFlow 1.4 is released, the documentation will also be in `tf.layers.Dense`) " + "The same API can be used to differentiate as many times as you like:\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "height": 34, - "output_extras": [ - { - "item_id": 1 - } - ] + "height": 276 }, "colab_type": "code", "executionInfo": { - "elapsed": 22, + "elapsed": 730, "status": "ok", - "timestamp": 1505502830753, + "timestamp": 1527005655565, "user": { "displayName": "", "photoUrl": "", "userId": "" }, - "user_tz": 240 + "user_tz": 420 }, - "id": "z9r-ZeyrXu3A", - "outputId": "6230a7a3-29fe-4d08-f101-da80425bad82" + "id": "3D0ZvnGYo0rW", + "outputId": "e23f8cc6-6813-4944-f20f-825b8a03c2ff" }, "outputs": [ { "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEDCAYAAAAhsS8XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXd0HNX5sJ/ZXrTq3ZLV3IvcDdgGGwOm2WCbHhJa6C2B\nUBISQioBfoQPkjhACA4QCIQSDITQbGMbsHHvVbZ6s7q0vc18f4xmJVltJa0q+5zDOXhn9s7dqzvv\nfe/briBJkkSYMGHChBkxqAa7A2HChAkTJrSEBXuYMGHCjDDCgj1MmDBhRhhhwR4mTJgwI4ywYA8T\nJkyYEUZYsIcJEybMCCNkgl0URVasWMHtt98eqibDhAkTJkwvCJlgf+2118jJyQlVc2HChAkTppeE\nRLBXVlayceNGrrjiilA0FyZMmDBh+kBIBPvjjz/OQw89hCAIoWguTJgwYcL0gT4L9g0bNhAfH8/E\niRMJVycIEyZMmMFH6GutmGeeeYYPP/wQtVqN2+3Gbrdz3nnn8dRTT3X6HUmSwtp9CKittvH8UxsQ\nxZY/4aXXTGfa7PRB7NXAU1dj5y9PrIfmYUgeFcnya2aQmBI5uB0bYE5WNPHS/9uE6JcHYukVucw8\nPWOQezXw7NhcyCfvH0Bqfi+uumkO4ycnD3KvBpY+C/bWbNu2jdWrV/PCCy90e291tTVUj+03EhIs\nQ7qfWzfls2tzMTNPH01UrJEv/3eU5LRIVnx/5mB3rUP6azw3fnaMQ7vLOX1RNrVVNvIOVZGeFcPS\nq6YNmT6GmlP7KYoi/3ltF9WVNhacO4btXxfi9fi5+Mpc0jJjhkw/+5t9O0r5Zu1xDEYtpy/KZuOn\nR4mOM3HlTbNRqTo3UAynv3swhOPYhymSJJF3sAqtTs35l05mQm4K6VkxVJY2UVdtH+zuDRgOu4ej\n+yqIjDYwbW4a514yiYTkCMqKGnC7vIPdvQFjz9YSqittjJuSxNTZaVywcgoAX3xwCL9PHOTeDRyH\ndpej0ai47PqZTJyWwoTcFOprHBzdf3KwuzaghFSwz507NyhtPUzfOVnehLXRRdbYeLQ6DQATp6UC\ncGhv+WB2bUA5sLMMv19i2pz0gEaWNS4BUZQoOlE3yL0bGDxuHzu+LsRk1jH/nDEApI6OZtL0VFxO\nLyfLmwa5hwNDU4OT+loHozJiiIw2AjB7QSYajYrtXxXg9foHuYcDR1hjH6bkHawCYOzkxMBnmWPj\nMJq1HDtwEt93YBJ7PT4O7CrDYNQwPrfFhpo1Lh6AgmPVg9W1AaWyrBG/X2JCbjIGozbweXqWbIIp\nLawfrK4NKMX58kI+Oic28FmERc/UOWnYbR7yDn53tPawYB+GiKLI8SNVGEzaNvZTtVrFhKkpuF0+\n8o+OfKGWd7gKt8vHlJmj0GrVgc9j4kxExRopzq/7Tixw5cWNAKSkR7f5PHV0NIIApUXfEcHevEMb\nnR3b5vPxU5IAqChpHPA+DRZhwT4MKS2sx+XwMmZCYjuH0MRpKQAc3lsxGF0bUJQXNWdiYpvPBUEg\ne1w8Pq9IyXdAW60oaUAQ5Gig1uj0GhJTIqkqb8Lj9g1S7wYGn89PWXE90XGmgBlGITrWhN6gobIs\nLNjDDGE6MsMoRMUYSUiOoLKsacQ7zaoqrGh1amLiTO2uZY1LAKDgWM1Ad2tA8Xr9VFVYSUi2oNNr\n2l1Py4xBkqC8pGEQejdwVJQ04vOKZJyirYO80CeNiqSpwYXD5h6E3g08YcE+zJAkiZKCOswWHUmp\nHcdpJyRbEEWJupqRGx3jdvloqHWQmGLpMCciMcWCOUJH0fEaRHHkLnBV5U2IokRKelSH10dlyOaZ\nkW5nD5hhcuI6vJ48Sh6fyrLvhiM5LNiHGQ67B6fDS2JyZKdJXgnJcqxrdeXQj8vtLcpv6ywJSRAE\nMsfF43L6RvTLXF4sa+Kn2tcVkkdFodGoKCvqu8b+zjtv8f3vX8Fvf/ton9sKNUX5tWi0KlLSOl7g\nFDPVSJ4LrWm/dwszpKk5aQMgLimi03u+C4JdCeFLTOk8YSMlLYqDu8qpOWkjtRPBN9wpb/YzpHai\nsas1KlLSoygpqMdhc2OK0Pf6WWvWvMsf//hnkpNTet1Gf9BY76Sxzknm2DjUmo51VXlnBye/I3b2\nsMY+zKitkgV7fGLngj02wYxKLVBdaRuobg04VRWyYO/MHAUQlyCPkTJmIw2/X+RkeRNxCWb0Bm2n\n943KaA577IPW/vTTf6C8vIyHH76ft99+s9ft9AeKUzQto/MMW61OQ1xiBFWV1hHve4Kwxj7sUDT2\n+C40drVaRVyCmdpqG36/iFo9stZvSZKoKrditugwWzrXQKNijahUwojLxH17/XF25VXj8fhx+Hzo\nGh1s/+vmTu8X/SJ2RA59egTDxhMd3jNnQiJXLh7TaRsPPPAztm79lj//+UUiI4dWDZ76Zl9SXBfK\nDshmqZqTNqpPWgM295HKyHrjvwPUVNnQGzRERHa9pU5ItiD6pREn1ADsVjcOu6fbIl9qtYqYeBN1\nNfYRWXlU0Ty7W7hVzddFf181VYlApbUhhDLHYxPMXd6XnCbPl5PfATt7WGMfRng9PhrrnM2JJ11X\nx5Tt7BVUn7QGbO4jhZPliuO0+98VlxBB
bZWdpgYnUTHtwyKHI1cuHsNdV83g1b9upuhELdf/cG63\ntvN/v7ydpgYnN99xxoirrFpX48Bk1rXJuu2IlsiYRqYxsiughjX2YURts2bSlRlGocWBOvLsy4p9\nPZiyvLGJshZXWzXydi71tXYMJm1QDtHYeBM+r4itaWTFcXs9PqyNLmLiu1+0IyL1mCN0VJY2jcgd\nXGvCgn0YEbCvd2NLBIiNN6NSCdSMwMiYqoqeaeww8hyoPq9fFmixwe1CouPkBa6+ti8L3NDT9Otq\nHED3ZhiQQ2ATUyNx2D3YbZ7+7tqgEhbsw4hgHKcKao2K2AQztVWyA3WkIIoS1ZVWYuJNHWZankpc\n8wtfO8J8DbLfAKI7yLrtiNhmjba+WRD2hnfe+YDIyKHldAzY1+O7F+xAIEu5sa734zAcCAv2YURt\nlQ2VWgj6ZU5ItuD3S4GogZFAU4MTr8dPQlJwfgNThA6DUTPinMi11fIi31E5hY5Q5kx97cgSaMrc\nDkZjB7luDEBDnbPf+jQUCAv2YYIoitRW24mNNwcdvjgS7eyN9fILGR1r7OZOGUEQiE2IoLFeXhBG\nCjXNpqXoYE0xMSYEoa+mmKGHUjYjJi44wR7VPG/CGnuYIUFDnRO/TwzKDKOg3DuS7MuNzZpWVJAC\nDVrMMSOpdk5AsAepsas1KiJjjNTXOEaU47Cuxo7ZokdvCC7AL6yxhxlS9MRxqqBotU0NI2cSN9bL\nmlZUTHAaO7Qkrijmi5GAYpazRBmC/k5snBm3y4fTMTKODHS7vNitnqDNMAAGoxaDUUNDWGMPMxRQ\nbMTdZde1Rm/QojdoaGxw9Ve3BhzFFNMTwa68+HUjJORRkiRqquxExciZtcESHXCgjoxxCETEBBHq\n2JqoWBNNDc4RFVRwKn0W7B6PhyuuuILly5ezbNky/vKXv4SiX2FOQdG6eyLQlPubGpyI4sjYfjfW\nOzGatEFFxCgoERMjJTLGYfPgcfuCdpwqxI4wB2rAcRpkRIxCdIwRSQJr48hReE6lz4Jdp9Px2muv\nsWbNGtasWcOmTZvYt29fKPoWphVNDU7UGhWmCF2PvhcZbUT0S9itwz8xxe8XsTa6Ag6wYNHq1ETF\nGKkbIaYYRTAHa19XiGkWgL0NeWxdtvebb77ijTdeDfq7lZUVfPHFp0Hd+/jjv2bjxvXd3te6lMCa\nNe/x2Wf/C6r9qICdXR6HTz75L9XVLUdJPvnk7ykqKgyqraFKSEoKGI3yi+bxePD5RvYRXINFU4OL\nyChDj9PBFQ2/sd7ZI3vsUMTa6EKS6FVpgMgYIyX5TjxuX4+0/aGIIpCCTU5SUByHvY2MObVs7/z5\nZ7a7x+/3o1ar231eXl7GF198xnnnXdCrZ3eE4gyPjDawfPllQX9PGQfFEf+//33EzJlTSUrKAODh\nh38esj4OFiGZ4aIosnLlSoqLi7n22mvJzc0NRbNhmnG7vLhdvnZnWgZDZLQszGVTTudlTYcDgYiY\nHpqjACKjlHFw9SiyaCjS0EuNXatTY4nU98oU07ps78UXX4LFYuHIkUPcd99DPP74r7FYIsnLO8r4\n8ROZP/9MnnvuaQRBQKvV8OyzL/Dii6soKirkppuu5YILlnLllde0af+ZZ55k9+6dpKSktonaOXr0\nCH/+8zO4XC6ioqL5+c8fIzY2jnvuuQ3RFUt1XSGxH5Zht9sxmUycccYCfve7x3jpJXk3UVlZwcMP\n38+rr77JK6/8nW+++QqHw4lWSmTS9HvYsGEdR44c5sEHH0Sj0fL886t54IF7ufvu+zh8+ADl5eXc\neee9gKzZHz16hB//+AE+//wT3nnnLfx+H5MmTeEnP/npkKrBExLBrlKpWLNmDTabjTvvvJPjx48z\nZkznJUDD9IymZufnqYf0BoMiBEdCZExDLyJiFJQFztroHPaC/Vv315ROK+RP+ZsRCnomTJxjPfh8\nInnfbGgjiGYkTmXlmKWdfu/Usr2ffPLfNt8vLS3mT396AYCHH76Pn/zkp0yZkktEhIamJg+33343\nb731Ok8++f/atb1x45eUlpbwz3++TU1NDd///hUsXXopPp+PZ599iieeeIaoqGjWrfuCF19cxc9+\n9kskScLpsHPdlT9l6VXTWL36bwBkZGTi9/uoqCgnJSWVdes+55xzzgPgssuu4oYbbsbn9XPj9+9k\n566t3P/Idbz33tv88pe/ICGhbWGwRYvO5fbbbwwI9nXrPuf6639IUVEh69Z9zgsvrEatVvPHPz7J\n559/wvnnX9Sjv0V/EtI9aUREBHPnzuWrr77qVrAnJAyPioNDoZ/VzdUMU9OiO+1PZ58b9HLFO5fD\nNyR+S1/64HHKCUaZ2fE9bidttLxb8fukbr87FMapK9xuH6oIAU0v6uyrNSp8PhEkUKtbBLPJqOv2\nd6tUEBdnJjragsViwNj8HYNBy8KFSwPfP/30uTz//HMsW7aMJUuWkJSURHS0CZ1O0+Ezjh07wIoV\nl5KQYCEhwcK8eWcQGWnEZquhoCCfBx+8F0mSEEWRxMREEhIsCAhkpE4nMTmShAQLZrMes9lAQoKF\npUsvZuvWTdxyyy1s2rSeZ599loQEC7t2bebll1/G6XRSVV9FWXkaCQkWtFo1ktQyL7RaNTExJsaO\nTSczM4OKigJGjx5NeXkpixcv4I033uD48WPccceNSJKE2+0mLS15SM2bPgv2uro6tFotFosFl8vF\nli1buPXWW7v9XnX10C9OlZBgGRL9LC2WDyJWaYQO+9NVPyVJQqNVUV1pHfTf0tfxPFkhn5QjIva4\nHalZhlWUNnb53aHyN+8Mr9dPbN5YZo45gwvPn9Lj7x/aU87GT49x9sUTmDA1uc217n63KErU1trw\netVYrS6cTg/V1VZcLi8+X8vcXLHiGqZNm8uWLV9z5ZVX8swzq2hocODx+Dp8htPpwWZzB6653V6a\nmpzU1dnIysrm+edXt+uny+VFY9Gh0amorrZit7uRJDXV1VZOO+0sHn30p8yaNQ+/X8JojKGsrJZf\n/erXrF79OvHxCTx8/2+wNjgpL6vH6/W3+f1er5/6egfV1VYWLDibd99dQ0ZGJvPnL6S62orV6mTJ\nkou47ba7ejR+oSDYxaPPUTHV1dVcd911XHrppVxxxRUsWLCAhQsX9rXZMK1QzCi9McUIgkBktJHG\nBuewzzhsqHNiMut65fxsbYoZziip8PGJPQvxU1Ac6LZ+DPUrKyslOzuHa6+9nilTplBcXIjJZMZu\n79hpO23aTNau/RxRFKmpqWHXrp0AjB6dSX19AwcO7AfA5/NRUJAPtBwy0tE7MWpUGmq1ilde+TuL\nF8tmGI/HgyBAZGQUDoeD44W7geY5ZTJhs3UcMbVw4WK++mpDG5POrFlz2bBhHfX1ssLV1NREZWVl\nr8aqv+izxj5+/Hjef//9UPQlTCcoSTmW6N5FtcihfnacDi8mc8/CJYcKfr+IrcnV6yPN9AY59r1p\
nmMcuK6nwSjninqIIdmtTb8YhOHv+O++8ya5dO1Cr1YwfP47TT58PgFqt4cYbv8eFFy5r4zxduPBs\ndu3azvXXX016egYzZswCQKPR8LvfPcmzz/4fNpsNUfRz5ZXXkJWVjd8vtfk9p7J48RKef/5P3HLL\nnYBsJl62bAXXXXcVKSmpZGeNw14vv1sXXbSMxx57DK1Wx/PPr27jO7BYLGRmZlNcXMiECZMAyMzM\n4pZb7uT+++9CFCW0Wi333/8QycnJHfZlMBCkQVLjhvJ2V2GobMtff/5b/H6R6++e1+H17vq5ef0J\n9m4rYcX3Z5CcNnhlV/synvW1dt56aTsTpiZz9sUTetXGO//YQUOtg5t/cmanEQxD5W/eGbu/Lebb\nDflcddOcwCEiPcHn8/PS018xKiOaS66Z3g89bEt/jeen/zlAwbEarr9nXq+UleL8Wj5+ez9zFmQy\ne0HmkP+7KwyYKSZM/6JoqpG91NahVSz7MI6MCZQS6GFyUmssUQZ8PhGnffgesqBkS0b38pg/jUaN\nyawb9lmX1kYXGo0Ko6nr4/A6I1AMrH5kZOGeSliwD3FsTW4kCSKjei/QomLkRUERjsORvsSwKyj2\n2OFsjrE1m1D6Mg4RUfrmeTV8fS7WRheWXiTsKUREGlCpBJrqh+9c6IqwYB/iBBynoRBoI0Fj78OB\n1C3JWsP3ZbY2udHp1d0e3NwVkVEGRFEatsfDuV0+3C5fnzKpVSoBc4QOm3X4zoWuCAv2IU5LclLv\nJ7GinQxrjT0g2Hs/DgHH4TBd4CRJwtroIiKyb6UhlO/3Z2RMf6LsWvpaIiMi0oDd6hmRVR7Dgn2I\n05dQRwWVSsASbRjW205rkwuDSYtW1/tAroDGPkwFmsftw+vxY4nU96kdRSAO13FQ+t3bKDGFiCh5\nHEdCgbxTCQv2IU6LYO/bJI6KNuJyyjVnhhuSJGFvchNhCZFAG6amGGujLIAi+qipBmLZexXyOPhY\nlV1sCDR2kP1YI42wYB/iNDXI3v++xp8PZzu72+XD5xOJ6KOmqtGoMUcM34gQJfbc0kdTjPL94TYO\nu3fv5KGH7gv0uzNTzD333MbRo0e6bU/Z+diaXPzpT39i587tverX22+/idvdsjg89NCPsdsHt0R0\nWLAPYSRJoqnBiSW6995/BWXbaRuG207lRY6w9L3ssCXagK3JNSztqoqG3dcFztI8F4abYAcQBLoV\n7MGiaOyNDU7uvfdeZs2a06t23nnnTdzulrF86qlnMZsHt9Dc8C5MPcJxu3x43H5S0ntvX1dQzBjD\n0Z6oLEZ9FWggh41WljZht7r75LcYDBRTTF8FmlanwWDU9Eiwu1wufvnLn1JdXYUoilx//c0sXnxu\np2V1y8pK+b//exybrQlJEvjtb58gNXUUq1Y9x9atmxEEFddddxPnnHMeu3fvZPXqvxEVFU1BwQkm\nTJjIo4/+FoBvv93Mn//8DNHRMYwdOx6ApkYnGq0qEBnkdrt5/PFfU1RUSEZGBh5PS7TP9u3f8vLL\nf8Pr9TJqVBqPPPIYBoOBK664hLMXXcAXm7/Er1vGV9veZNas09HrDfzvfx/xm9/8AZB3Cf/+9xs8\n8cQzPP30Exw9egi3282iRedw00238u67b1FTU80999xOdHQ0zz33PFdccQkvv/xP3njjNZKTU1ix\n4nIAVq/+G2azmauuupZ//euffPnlF3i9Ps46axE33dR9fa2eEBbsQxjlxeurLRFaBPtwtCfam0In\n2C2tQh6Hm2BXNHb/hv+yY9WePu065tg8iH6J/IffBsAyew4JV1zd6f1bt24mPj6Bp556FgCHw95l\nWd1f//oXXHfdjaxYsZTy8jpEUWTjxvWcOJHHa6/9m/r6Om6++TpmzJgJQF7eMV5//R3i4uK4444f\nsn//XsaPn8hTT/2eP//5RUaNSuOXv/wZ0D6Gfc2adzEajbzyyr84ceI4N910LQCNjQ28+upqnnvu\nr+j1Bt5441Xeeut1brjhZvk3R5pZMu8uRqfHcrSkVB6XOafx9NN/wO12odcbWLfuCxYvXgLAbbfd\nhcViQRRFfvSjO8jPP87ll1/Nv//9ZqCcsYzcr3PPXcJzz/0xINjXr1/LM8/8me3bv6W0tJiXXnoN\nSZJ4+OH72bt3D9OmhS4TOCzYhzC2EAo087DW2BUTRN8XuMCBG43D7+ARa5MLlUpAq1XTVxe4oBKQ\n/CKSJJs3uiM7ewyrVj3HCy/8hTPOWMC0adPJzz9Bfv4J7rvvruayuhLx8Qk4HA5qaqpZsEAuBqjV\nypr1vn17OPfc8wGIiYllxoxZHD58CJPJxKRJk4mPjwdgzJhxVFRUYDAYSU0dxahRaQAsWXIhH3zw\nH3kXm9ayKO/Zs5srmhelnJwxjBkzDoCDBw9QWJjPHXf8EEmS8Pl8TJkyLfC9JUvO57//ymuj7KjV\nak477Qy+/vorFi1azJYtX3PXXT8CYN26z/jwwzX4/X7q6mopKCggO3sMIDX/pyD//9ix42loaKC2\ntob6+noiIyNJTEzinXfeYvv2bdx007VyXXmni9LS4rBg/66gCGFzH6NBWrcxHCMhrMoCF4JxaHEi\nD79xsDW6MVv0JF55NQl33dKn2ibfrDvOvu2lrLxuJkmp3Z/MlZ4+mpdffp0tW77hxRf/wty5p3PW\nWYvIzs5pV1bX4ei4iuOpma6t/60IfwC1WoXf3/HS5fPKu5RTzVGtfVBKu5IkMWfO6Tz22O86bMto\nNBIRaWj3TixefB7/+c/bREZamDhxMkajkYqKct566w1efvmfmM0RPP74r/F4uleSzj77HL78ci21\ntbWcc86SQL9+8IMbuOSSFd1+v7eEnadDmIBtOQQCTa2WD8Iejs5TW5MbQQCzpe+VKZXdj32YmaT8\nPhGH3ROyc2t7GvJYU1ODXq9nyZILuOaa73Ps2NFOy+qaTGYSE5P46qsNAHi9XtxuF9OmzWTdui8Q\nRZH6+nr27dvDpEmTO31mRkYmlZUVlJeXAbB27Wf4mmuntx6H6dNn8PnnnwCQn3+cEyfyAJg8eSr7\n9++lrEw2s7jdLkpKits8IyJSj8ftlw8faWbGjFkcO3aUDz9cEyjVa7fbMRqNmExm6upq+fbbzYH7\nuypJvHjxeaxb9zkbN67n7LPPAeC0007n448/xOl0No9tdaAEcKgIa+xDmFBq7CAvEDVVNiRJGlLn\nM3aHvcmFKUKHStV3PSSwcxlmC5xijuprcpKCEvIYbJJSfv5xVq16DpVKQKPR8sADP+uyrO4vfvFr\n/u//HueVV15CENT89rdPsHDh2Rw8uI8bbrgGQVBx5533EhMTS2FhQZtnKXNTp9Px4IOP8OCDPyI6\nOobc3OmcrKiT+99KsC9ffjmPP/5rbrjhe4wdO45Jk+QDSKKjo3nkkcf41a8ewePxIggCt9xyB+np\no1Hs4Ip5T1kwQD7qc968BXzyycf84he/BmDMmLGMHTue
H/zgKlJTR5Gb22LSueSS5TzwwL3Exyfw\n3HPP07q8cVZWNg6Hg4SEJGJj4wCYM+d0iooKuf32GwEwmUw8+uhviYkJnWkwXLa3Cwa7lOcH/9pD\neXEDtz54FuoujkELtp99LXXaV3oznqIo8dLTm0hIsbDyBzND0o9X/vQNOr2G7912Wkj6OBCUFtbz\n0Vt7mTUvg7lnZfW5nzUnrbzzj51MmZnKmUvGhbCnbQn1eH79RR77d5Zx+Q2zSEju+1F0u7YUsXVj\nAVf/cC4xCb2vQzRQhMv2jgDsVjdGs7ZLod4ThmPIo8PuQRSlkJijFMwWPXbr8KpuGKr6KAqBujnD\nLJY9lKGvcjtKlNTwS9zrirBgH6JIkoTN2vc0+tZERA6/kMdQJeW0JsKix+cTh1V5hUCSVojGQT5R\nSh1wTA8X7FY3KrXQp+qWrVHGczifVdARYcE+RHG7fPh9Ysjs69CqNsYwKlVqDziQQ6OpwvAM/VQW\n41Bp7CDb2a2NrmG1c7Fb3Zgj9CHzEQV8DcO48mlH9FmwV1ZWct1113HRRRexbNkyXnvttVD06zuP\nLYQhfgrDWaCFUmMfjg5UpU5MSOdDpB6vx4/X4+/+5iGAKMqRQaGIjlIwRegQhJGnsfc5KkatVvOz\nn/2MiRMnYrfbWblyJfPnzycnJycU/fvOEuqIGBie2afWfjDFBBY42zAah0YXRpMWjVYdsjbNES0L\nvU4/9APkHHYvktTS71AghwHrh/VZBR3RZ409ISGBiRMnAmA2m8nJyaGqqqrPHfuuE8oYdgVThKzp\nDCfB3qKxh84EEXAiD5NxCPhbQjgGMPwWuP5QdkAOIW1qdCGKw8ck1R0htbGXlpZy5MgRcnNzQ9ls\nv2I/sA9nfv5gd6Md/TGJ1WpV83Fg7V9kT2UFjsOHQvasUKE4y3p7aHFHKFv51uMgSRKi14vo9SL5\nhpZT1enwIvqlkO5aAMzNC73d2lI0S3S7se3Zjd/WtuyszWbj/fffDfxbKaHbEU8++XuKigq7fX5X\nbbRGKcMbeCeC0NhffvnFoMvwRkQakEQJR/MC9/bbb+JyOrHu2IansmJIlOHtKSHbf9ntdu69914e\neeQRzGZzt/cHG4/Z3xT+8xU8dfWkLL2IjB9ci1rfdtIMVj+V1OnRmXHExoduPKNiTVSWNRIfFwGS\nSPlHH1P15QYchUUApF9zFaOvvrL3HQ9RPxUcNg9R0UYSE7tPew+WSItcVsDr8ZOQYEH0ejn8uz/Q\nsGcvxwFUKjK+/z3SLuu/lO+eUOlpBCA+IaLN+PV1bqamRcv/I8lt+Z1ODj3zJE2HDiOo1URNyyV1\n6UXEzJqJ293IRx/9h1tvlZNqoqNN6PWaDvvw9NNPtPm3co8oim2SzLpqozVarZqYGBP2etlhmjIq\nqsvviKLIT3/6QPcD0ExisoXjh6vQqNUkJFh49+03mJF/HOn4CZIvWMI//vFy0G0NFUIi2H0+H/fe\ney+XXnoGQVm+AAAgAElEQVQp5557blDfGSpJIEm33U3l6r9R8dHH1Gzbwagf/wRdQiIwuMkqNVXy\nc90eb7d96Ek/DUYNol+iuKgW19frqHn3bVCrMU+bjqesjJI3/43D4SFu2aV9/g196SfIafQ2q5vU\n9KiQ/x10ejX1tQ6qq61UvfUGDXv2ohuVhikhDmt+AUX/fANfXDLmyVNC+tzeUFoip5urNEJgHEIx\nN33N1SGrKps4WVpD2XPP4Dx2FOOEiYgOBw27dtOwZy8Zj/2Gx//2V4qLi1m27BJmzz6NM86YT0ND\nE7fddme7Urv33HMbd999H+PHT2DJkrO46qpr2bbtW+6++8fY7fY2ZXg9Hl+733FqGV673Ul9vYP6\nCh8V1cf42aOrEVRSuzK8F198Cdu3b2XlyivZunUz8+efGVQZ3sYGG3GWCZQUTeTdF/8fVSdP8ugX\nnxEdHcOq85exaNHZg16GVyHYxTwkgv2RRx5hzJgxXH/99aFobkAxZmeT8cvfUPOfd2hY+wU1775N\n6h13D3a3sFvdGIyhdZZBS9hgQ1k19o8+QB1hIePXv0MTFYW3tpbS/3uC2g/eR9DpiD3/wpA+u6co\ntt9Q25ahJUnJumsnDWu/QJeSyuhHHiUpLZ6SbXspfuL3VP79RTIe+y2a6OiQP78nKONgajZBbF5/\ngsK8GsQ+HhaimJSP7KvEsXsHWXlHiZg9h5RbbkdQq7Ht3kX5qj9R9cY/uf32uykszGf16jcAWUB2\nVGp36tRpbZ7hdDrJyRnDD394Gx6Ph6uvXtGuDO+pdFaGt7qqhgN5a3nxpb+RkBTdrgyvTqdn1aqX\nALnMMARXhvfEkZM89PC9HDt8mNOLi3lPq+WPj/6G1IVnN4dVDn4Z3p7SZxv7zp07+eijj/j2229Z\nvnw5K1asYNOmTaHo24Ch0ulIuOp76DOzsO3cgfuUQkEDTX8kJykoNvvyz9Yjud3EX34lmqgoALRx\ncaQ9+DDqqGhqP3i/nZ11oOmPUEeFCIset8tH+auvIOh0pNx+F6pmM5whK5uEK67Cb7VS8dILSOLg\nnrak2MAVm3ioUELBRb+Ir64Oc+40Um6+DUEtKxMRM2ZinjET57Gj2Hbvavd9pdSuIAiBUrunotFo\nWLhwMQBFRYXtyvB2xJ49uwPXWpfhPX7iCI22kzz007u48cbv8emnH3Py5MnA95SCXa1pXYbX7/ez\nZcvXnHmmXE543brPuOmm7/PL395Do/Ukx3ftQLTZEIwmLDNntYqVb1+G9/jxvEAZ3m3btgbK8N50\n07UUFxdRWjq4MqTPGvusWbM4fPhwKPoyqAiCQPzyFZQ9+wy1H35A6l33DFpfPG4fPm9ok5MUApl2\nReUk5owhct78Nte1cfHELDmfmnf+TeNXm4i98KKQ9yFY+iPrVEFxwDk9kPW9a9GPGtXmevQ55+E4\nchj7nt3Y9+4hYkZo6tT0BsWpp8yHeYtzuPSq6SExT73+/Ld4GxsZW7uDhB8/jqBpKxISr7qGwoMH\nqPv4w3YLXDCldnU6Xa+SiToqw+tyekhLnsA//vFih98xGjs+OKW7MrySX8Odt/4Ya1UtKrMZVSft\nwOCV4e0p4czTVpgmT8WQnYNt905cxUWD1g8lWsPcLwJN1vpcmggSr/0BQgcVE6POPAtBr6dh/dpB\njRCx2xRNNfTjYDLJWqkvbhSR889sd10QBOIvXQlA49eDuwPtL40dwKgRcUtajJNz0aWktruujU8g\n9qKlaB0ObLW1PW6/dVZrR2V4O6KzMrwW4yiq6gq6LMPbEd2V4XW6rZRXH8EraIg9/0LM5oghV4a3\np4QFeysEQSDuUnnVrf1wzaD1w94PMewK2qZqAPyJ6RhGZ3R4j9pkJmr+Anz1dR1uwQeK/opbBlDX\nyMJFGJ/b4eIGoE9PR5+ZhX3/PnwNDSHvQ7DYbW40GlW/JBFprDVIggrDWZ0HPcScfwGRlkhydDqu\nu+5q/vrXP7W
7p7WG3dn/63Q6Hnro5zz44I+4665bSOlgIQG5DK/D4eCGG77Hm2++zqRJU/B6fKgF\nI8vOv5lf/eoRrr/+Gm677SaKAwpY57sCpQzv1q1bmDdPXsRbl+F96onfkBSdgU+tJ3rxOYEyvD/6\n0R3t2u6sDO95553P7bffyPXXX82jjz6M0+notD8DQbhs7ylIkkTJE7/HdeI4s/72PFbVwJ+LeWhv\nORs/OcbZF09gwtTkbu/vSYTEybfe5D8FSSTGaLns9vaaqoLnZCWFP/8phpwxjP7ZL4Lue6j6CfDZ\n+wfJP1rNdXefEXKtffsfVrFDmMycuUnMXjyx0z42bFhP1euvEb/ycmIvWhrSPgTLK3/+Bp2ubZnh\nkETFNNTz6ZNvURI1kcuun0liSuchpZWvrKbp602kPfAwpgkTO73vVEIVWVZfY+etv29n4rQUFl04\nvs/ttWl7/Vo+3tSI0xTLzQ8uGtJnFYTL9vYSQRCInLcAgLqt2walD/Z+qBMDIIki9p3b0IsunGLX\n2p8uKRlz7jRcJ47jzD8R0n4Ei8Mun5xkNIXWBOEuL0dVIv8mp6/rqCPL3NMRtFoav/lqUIpl+f0i\nTrs3kDUcShq+XI/eKzvIFbNXZ0SedjoA1m1bQ96PYLDb+m/3Ztu1E53fgU8Uhk3dnO4IC/YOiJg+\nAwSB2i3fDsrzbf1kgnDmHcNXX49Rr8Jh93QrqKIXy9tza6tjwAYSu9WDyaxDpQqtBtX09Sb0Pnvg\nGV2hNpmImDUb78mTOPOOhbQfweC094+fQZIkmrZuwaCSfSjdFYYzjp+AOioK687tg+J3USKkQlkA\nDMBvteI8dpSICNkRPFzKK3RHWLB3gCYqCuOYsTQdPoKvqWnAn99iYw/tJLZukxeqiPhI/H4Jj7vr\nF9Q0YSIqgwH7/n0Drq1KkoTD7gm5pir5fDRt+Qa9SYtaLQRV4TFqwVkANH61MaR9CYYWB3Jox8FT\nUY6vpoao0SmAnOHbFYJKhWX2XES7HfuhgyHtSzD0lyPdtncPiCIxaXJSYncL/XAhLNg7IWLGTJAk\n7Ht2D/izbVY3Or0arS50zjLJ58O6cwfqqCgik2SnT3eTWNBoME2egre6Gm9l+xjl/sTjluvRm0L8\nIjvzjuG3Womae7qcpBSEhmYcPwFNfDz23bsGXFvtLweyfd9eAGInjmnznK6wzJVt/IqCMJD0V0CB\nbfdOABInZAItoaXDnbBg74SIGbOAlj/8QOKweUKumdgPHUS02bDMnoup+eVw2LufxObmTEJbsyAY\nKPorxM9+8IDcbm4uZoseh82Dv5sMTkEQME/JRXS5cBUMbME4RZMO9c7Fvm8vCALxM+SSCcEscIbs\nHLTxCdh270Z0D6wA7I8FTnS5cBw8gG5UGjHpSfJzutm5DBfCgr0TtAkJmLMycRw+hN85cLWa/c1H\ntoX6Rbbtkhcoy9zTWqr6BTGJzVOnyvfu3xfS/nSHsuiEWmN3HDyAoNFgHDs+ICQUO3ZXmCdPBhhw\nM0TAaRjCcfA77DiP52HIysIQG41OrwlqLgiCgGXuaUhu14BXArXb3KjVAnpD6Hax9gP7kXw+ImbM\nDJykFLaxfweIPf00JJ8P+/6B01Zb6oKEVrA7jxxGZTJhyMoOtN2dXRVAExWNPjNLNmEM4ALXHxq7\nr6kJd0kxxrHjUOn1LQePBGGGMI6fCCoVjmaNf6DoD03VcfAgiGJgN2a26II+VcvUXBTNcWSABbvV\ng9kSuiPxoGU3HjFzVuDIwWDeieFAWLB3QdzpcwGwD2CSjqNZezSZQ/cie2uq8dZUYxw3HkGlajk5\nJ0jtxDw1F/x+HIcGTqgFxiGEgt1xWNa2TZNk4WQyB7/AqU0mDNk5uAry8XeSldgf2PvBFKPY1825\nzYI9Qq6b4/N2H+pnyM5B0GpxHDkSsv50h9/ffCReCHctkt+Pfd9eNHFx6NNHN5+jGjbFfCcwZWSg\njo7GcfTIgEWFOPohCkJ5CZXEkp4INICIZgFg3zdw5pieHKoQLIq2bWo2q/Rk5wLIJXwlaUC1VbtN\nPrZOG6Iqn5IoYj+wD3VUNPrmzOOWk5S6HweVVotxzFg8pSX4rAMTMRYI+QzhrsVdXITodGKeMhVB\nEFCpBExmXdh5+l1AEARM48bjb2rC26qKXH/SH84yx1G5SJsi2I1mbY+0E31GJmpLJPb9ewes0mGo\nw/wkScJ+8CBqiwV9Wnpz280CLQgnMoBpkrwgOAbQzu6whfbwZldhAX6rFfPU3IBZo+UkpeDGwdg8\nj5xHj4asX13RktcRwnfimNx347gJgc9MEXrstu7zO4YDYcHeDcaxcvqy89jATGJFyChadV+RJAnn\nkcOoLRZ0qXIFQ5VKhdEUvHYiqFSYp0zF39SEp6wsJP3qDiXr1BCirFNPeRn+xgZMkyYHasP0VGM3\nZGahMhqxHzwwIC+/z+vH7fKFdtfSvCgpTnHo+dmnioLgODIwVV0d/RDDrrzPxrHjAp+ZI3T4fWK3\n+R3DgbBg7wbjOFmwO/IGRrAHJnGItp3eqpNytun4CW2KXZkidEFlnyooL4DzeF5I+tUdoc46DZhh\nJrWciNRTk5SgVmOaOAlfTQ3eATiwvT+Sk5S/n6KwyO03C/Ygk3MMGZkIegPOgRLsIfa3SKKIM+8Y\n2oQEtLGxgc+VMOCRkKQUFuzdoEtJQRURMWAae8AUEyKNXdGqTi3cZI7Q4fOKeNzB1cYwjh0LDIxg\n74+sUyVMUQlbBNDpNWi0qh5FQgSiQgbAkRyIkArRIi+JIq4Tx9EmJaGJbCn4pZg4gtXY5XDRcXgq\nK/A19H952lC/E56yMkSHo83iBq1MUiPAzh4W7N0gqFQYx47DV1uLt7am35/nsHnQaENXotXZiWBX\n4sODSVIC0CY3L3An+l+whzrrVBJFXMfz0CYno4mOaXPNZNYFbWMHMI1vti/n9f84hNqR7ikvQ3Q6\nMeaMbfO5orH3xHFomiDbph1H+z865tSjAfuKsvtWduMKph7kdwx1woI9CEwBO3v/F4Gy290hsyVK\nkoTjyBHU0dFok9qW/+2xGUIQMOaMwVdT0+9aWqhj2D3lZYguF8bsMe2umSL0uBxeRDE4k5Q2KUle\n4PKPh6RvXRHqyKCAGWZMW8FuNMsFsHq0c5kwSf7OAJye5rSHVmMP2NfHnaqx93yBG6qEBXsQKBPA\n2c92dlFsLtEaqi1nRTl+axOm8RPbJXa0bL+Df5kVgdDf5phQZ50qZYcNOe0FuzlChySB09GDBS47\nR17gGvv38I1Ql6pV/m6GUwS77EzXYg8iA1dBP3o0KpMJ59H+F+x2m6f5oJG+h3xKkoTz2FFZ2UlI\naHOtJToorLED8MgjjzBv3jyWLVsWiuaGHPr0dFQGQyBEqr9w2r1A6JxErhOyVqnYx1ujJED1RDsZ\nKMEeao3ddUIW7MacnHbXerpzATlJB8DVz3XqQ+08dR0/jspsRpfc/vAW
U4QuqNIKCoJKhTFnDN7q\n6n6vgKr4W0KRdeo9eRJ/UxOmcePbtWfqYeLeUCYkgn3lypW8/PLLoWhqSCKo1RjGjMVbWYmvsbHf\nnhNq779SsEoRRK3paagfgD4zE0GjwXm8f80QIR+H/BOoDIZAuGdrejMOxmbN33mifwW70idjCHZw\nvoYGOfs4Z0yHRwGazDo8bj/eILJPFQxZ2QD9WhhNFCWcdk/ozTCnOE4BjCYtKpUwIsoKhESwz549\nm8jIzo/VGgmYAuaY/rOzh7rgkzM/H0GnQz8qrd21nhQCU1BpdegzMuWsvX6s7hdK27LfbsdTUY4h\nK7tjgdbDJCUAQ1YWCEJgR9RfOOweDEYtanXfX9PO7OsKyjj0RGs3ZCuCvf8WOJfTiySFbpFX3l/j\nuHHtrgmCIIcBjwCNPfSn445QAtvvgnwss+cAYPPa2V65G5WgIjMyndSIFLSq4IZUkiRKqmwcKqzH\n6/Oj16pxVcs1SEKhnYhuN56yUoxjxiKo29smjQETRM8msXHMGFwnjuMqyA9E2tS7GjhQewS7186M\nxFySTAndtNKCy+OjqNJKQYUVm9NLXJSB6pPyGZmhMEEoQqejXUvrZ/RES1MZjOhSR+EqKkTy+RA0\nGlw+F2W2SmpddVg9NqbGTySxB+NQ3eCk+KSVmkYXjXYPybEmbFY3lsj+ta8rKHPObvMQGR3cOb+G\nTEWwFwQ+a/JYOVhzBLVKjU6tY5pxLALB/4ayGjvHShpweXx4vCKG5jyLUNVOchacQGU0ouvkIG1T\nhI6aShuSJA3ps0+7Y9AEe7CHsg42Sj995qmUCgL+smI0ESJrDn/G+vxvcPtbBIJereOHs65mUdYZ\nnbbncvt4d30ea7cXU9voanMtFRiFig0HK7GkRzNtbPCC4dTxbDxYDJJEzKTxnY61KUKH2+Xr0d9C\nNWsa9Z99iqqiGPuUFJ7f/k8K6ksC1z/K/4zxcdlcPuVipiVP6rSftY1O3vz8KGu3FeM/JSJlAgIR\nCHyxr4JLzxpDQkzvDxR3VpYCkDRzKrEd/E7RKz9b8kuBvgUzHo1TJnLys1JM9joKLV6e3foyTW5b\n4PoHJ/7HOTkLuHzyxUQbOt7NSpLEoYI63t9wnG2HKmmdKyYAs1FRWu9k8+EqLpqXiVbTdoHuyd+t\nvCgfQaMhbfZU1Pr2QjIxSW5Lq1YF326ChbKUZNyFBcTGGvmycAtv7H0fu7elCqjmoIZrc5dz4biz\nUQkd7zx8fpEvthbxxbZi8kraOqQjgfGo2Ha8moRJSSyYntprgeuz2zlWWUlU7lQSk6La/5wECzGx\nZqrKrUSY9CEvGT2QDJpgD8XJ5f3NqSes65JTsB4/zs8+e4I6dwMx+miWZi3BrDVT2FTCjpO7+eu2\n1yisKueirPPaTEBJktiTV8O/1h6jtsmN2aDh9MlJ5GbHYTHpcHv9HPy2GFu5lX2FdWx9YTNn5qZw\n9TljMXYT097RSfB1u+WEHCk5vdOxNpq0NDW4evS38CXIduqSndt5Ub0Zl8/FxNhxTImbiElrZGvF\nTo7WHucPm1Zx69TrmBrfItwTEixUVDby4TcFfLatBK9PJCnWxPQxcWSlRBJl1lHX5GbfF3l4PH4+\n2JTPf78uYMVZ2Vxw2mhUvXiha/fLBbs8cakd/k63V3ZY19bYqa62djiWHZI6GoAvv3iP12MLUQkq\nFqbNJ9mUiFqlYm3RRj4/vomvCrfx4xm3k2ZpqyHaXV5Wf3yY3XlybkRWSiRzJiQSH2Ug0qyjsKSB\nE5sK8UgSf//gAO9/mcfV54xj1viEwFgG+3cT3W5s+QUYMjKpa/IA7XcnIvKqUlHeSHxK8AuGdnQW\nrq1b+MM7v2efUIlBrWdZ9gVEaE04vE6+LPuKV/e8y9aivVw/+WoidW3bLqux8/f/HqKo0oogQG5O\nHLPGJWAx6dBpVRzeW0HV4WqqrW6een0H//06hmvPG0dKnDnoPiooNeRVqe3fCWU8NVp58SkuriMu\nIaLHz+hvgl10QybYR0LhnO7QjE7HU1GOVFXDBdPO56LMc1GrZC3qtJRZLEybx1/3ruZ/hWupdzdy\n7YTLEQQBUZR4Y+0xvtxVhlolcPEZGSw9IxO9rq0GdnJfJTbgnqum8eaXJ/hqXwWHCuu5fflkclLb\naxhdoURsKHbQjjBF6KmtsuP1+II+hk9jiUSKjcZZkI97ZgLXTbqKuckzA9fnJs8krz6fv+59mb/v\n/ye35t7A5DjZP9Fk9/D/3t7L4aJ6Yix6li/IYt7UZNStbN+SJLH/02MkJ0bww9mjeG/jCd7dcIJj\nJQ3cvHQSEUZt0GMgiSKu/BNok5JQR3T8khqMvXOYGZtNO1VH9hC5KI2bp/6A7KjMwPXTk2ezsWwz\n7+V9xPP7/sFDs+8hSi9r7gUVTTy/5gA1jS7GpUez8qxsxqZFtVEEIlUCJyhk/oxR5Khh3c4yVr2/\nn0vmZ3LJgqwe9dVdUgx+f9dzQTHN9cDGDqDPysK6dQuewgJy58zmqvHLida3zNWLpy7iua//wcHa\nI/xt32vcN/P2wDuzbmcp/15/HJ9fZP6UZFYuzCHmlNBOZ7mVqsPVfO+C8aw7Ws3+/FoeW72dW5dN\nYvaExB711VUom4wMWZ2PnzIOTrsHgt8wDzlC4jz9yU9+wtVXX01BQQGLFi3ivffeC0WzQwqv38s2\nXSUAC8VMlmYtCUxQhWRzIg/MvovRllFsqdjOloodeLx+Vr2/ny93lZGWEMGvb5rLZQtz2gl1kF8q\ntVpgfGYsj14/m6XzMqizunj6zT0cKqzrUX9dBfmoLZFoYuM6vcds7rkDtdZZzwmLG6Nb5Oa0S9oI\ndYWxMdncnnsjgiDw0v5XKWgspqzGzk+e28jhonpmjI3ndzefxpnTUtsIdWjJOjVb9MyfmsKvbpzL\n5KxY9p2o5TevbKemIfjDPjyVFXKmZQeJSQqCIGDsRbnWfK0Nl05gVK3Iw3N+3EaoA6hVahann8ml\n2RfS4G7kxX2v4vF72Hm0isf/uZPaRheXzM/koWtmMC49up15QVlooqMMXLV4LI/dMJuEaAMfflPI\nX98/gKsHhapcRYUAGDK6EGi98DUA7NLLO46JNjO3TP1BG6EOEG2I5I7cG5mdNJ2CpiI+yP8ESZJ4\nb+MJ3vjiGEa9mrtXTuWHSye1E+rQ4sxNSbLw4ytyuXP5FNRqgefXHGDtjpJ293dFQLBndqXs9G4c\nhhohEex//OMf+frrrzlw4AAbNmzgsssuC0WzQ4qPC77ggFGO1811xXRq54vUWbhl6nUY1HrezfuQ\nJ975ht15NUzMiOGn184kNb7zLaTdJod1CYKARq1i5Vk53L1iKn5R5Nl39rEnL7iSBr6GBnx1dRiy\ns7u0R5osPZvEkiTx5tH3qIyRp022tXPn5vjYMdw85Qd4RR+vHnybJ/+1g8paB8vmZXLXyqmdmpdO\nTaOPNOu478ppLJ2XSU2ji6f
e3E1NY3DCPbBr6SB+vTXmCB32HhREs3nsvHbk31TGa7FYvZjdnX/v\nvIxFnJ48myJrCX/Z9i9e+OAgGo2K+66axvIzszstcnZqyOeohAgevX4OE0ZHs+tYNb//xza8vuBC\nE92FhYBcfrkzeqOx76s+yIeu3fhVkNOk79SGLggC14xfSaIpnnXFm1i1di0fbykiMcbIo9fPZua4\nzlXj1rH8giAwe0IiP/3eTCLNOv61No/3NgYfkeMqKGhWdmI7vUcJKuhJstZQJJx5GgQn7VWsL/kK\nX1IcqFS4iwq6vD/WEMOKMctw+92UGzczd1Ii9105DVMX5zVKUnO87ikOmxnjEvjRFdNQqWDV+/vZ\nd6K22/4G4tezOtdMAMzmniVkbKvcxeG6YwGNx11U1OX9U+InMit+FtWuKlxRedxxWS4rzsru0lau\nCJbWsdsqQWDlWdmsODNLFu7/Ck64K9Ea3Y2DyaxD9Eu4Xd1rwZIk8caRd2n0WIkeI0cFKZpgRwiC\nwDUTVhKvTeaE8xCaqHruv3IaU7I630lBx4WvIoxa7r9qOtPHxLMnr5oXPjiIr5uDuEHW2AW9ocPE\nJIWeFkRz+py8ceRdVFodmrQ0vKWliN7Ov2vQGPjh5O+jktQckjaQkqziZ9fOJD6qa8d4R+WbM5It\n/PwHs0iKMfLxliI+3VrcbX99TU346moxZGV1qewoCoUzrLGPbCRJ4p28D/FLflZOvBR9Wjru4mIk\nX+dCQJQkDuw04a9PQB1Vx4QZTWi6iUV2OeV6JR3F607OjOX+K6ejUgk8/8EBik927TQLVrD3ZNvZ\n5LHybt6H6NU6zjvjGvk5XQg0gEa7h6Nbk5G8OvTp+czO7d7x4+iiLsiy+Vksbxbu/+/tvTi6EcTu\n4iJQqzuM429NT8Zhb81B9tUcZGx0NhNzF7Y8pwsKym1U7pPNIElTCsgZ1X3OR2dJWhq1ijuWT2ba\n2Hh259Xwj/8d7nKnIbrdchz/6NEdxvG3xmTWBa2xf160AZvXzgWZ5xA1Zjz4/biLuxawhw77cBWN\nQ9B4mTCnhqggok4cNg9GU/vyzfHRRh64egbRETre/vI4Ww5UdtmOMle72rVA730NQ42wYO+GvTUH\nOVx3jImx45iWMAVDZhaSz4e7vPMDJ9798gTbDlUxyn0GBrWeTwq/aBMW2RHdnZw0Lj2aW5ZOwuPx\n8+w7e6lrcnV4H7QW7F072QICLYhJvOb4/3D4nFyacxEJcaloE5PkOO5OhIrXJ7Lq/f1U1/qZrJ+P\niI+Xd/272+d0V6L1kvlZLJmTTkWtg+c/OIC/kxOdJJ8Pd0kx+lFpCJquHcPBVroUJZGP8j9DQDYt\nKDZrxYbdETUNTv7yn/34bdGMi5hMtfsk31bs6PI50PU4aDVqfn7jaeSkRrLl4En+u7nz57uL5bBX\nfWb3DldThB6n3dNtQbQ6Vz1flnxFtD6KxekLMGQpOR6dL/Q7j1bx7/XHMTtyiNPHsa1qBycd1V0+\np7vyzXFRBu6/ajomvYbV/zvMwS78UO4gHKcARlNYsI94vH4v7+V9hFpQc8XYSxAEAUPzC9LZJN5y\nsJJPtxWTEmfivhWncXb6mdi8djaVbu7yWQFbYhfJSbMnJHLl4jE02Dw89+4+3B2kf0uShKuwAG1S\nMmpT1yFhwdZJqXLUsK1yF6nmZM4cdToAhsxMRLsdX017u78kSbzxxVGOlzYyd2Iid5x1PuNixrC7\n4gDHG7rW8oMpJ3Dl2WOYlhPHwYI63lrbcfanp6ICyefDkJnZ5fOgbXJOV+w4uYdK+0lOS5lFkjkR\nTXQ06sjITk1STreP597bh9Xh5drzxnL9tOXoVFo+PPEpTl/nCzO0ONI7K99s1Gu457Jc4iL1vP9V\nAbuPdSwkXUWKwzCzy+eBPA6SJO8eu+Kj/M/wij4uyb4AnVrXUlqgsOPSAkWVVv720SF0WjX3XT6D\n5WMvDCySXeH1+PF5xS7nQlpCBPdenosgwAtrDlDdiXM9GMcpgFqjQm/QhJ2nI5mvirZT56rnrLQz\nSN51YmQAACAASURBVDLLoVXKit/RJC6qtPLqJ0cw6tXcc1kuEUYti9PPxKgxsLZ4Iy5f5xqhI8ia\n00vmpLNoeiolVTZe+7T9IdvemmpEpxNDN1tOCH7b+VnReiQkLsg8J+AgU7a0rg78Det3lbFpbwUZ\nSRZuvGgiKpWKZdnny20Vru/yWcEcqqBSCdx6yWRGJZhZt6uUTXvL292jaNHKgc1dEczOxS/6+Tj/\nc9SCmosyzwVk+7l+dCa+ulr81rbmMUmS+Pt/D1FWbeecWWmcPTONaH0USzIWY/XaWF+8qcs+OZr9\nLV3ZgyPNOu65LBedVsXf/nuI0mpbu3sCAi2I+dCShdv5PC2xlrG9cjdpEanMSZ4BgDYxEUFv6NAU\nY3V4WPX+frw+kdsumUxGsoUZCVPJsKSzu2ofRU2dR7Z0ZZZrzbj0aK49bxx2l49V/9nfTuGRJAlX\nQQGa2Lg2B4x0hnK62HAmLNg7QZREPjwiv8jnjl4Y+FyXOgpBpwts7RRsTi+r3t+Pxydy89JJJMea\nADBpjUFp7cEWvhIEgWvOHUd28zb8y91tTUKK9qgfPbrb36jRqtHp1V1O4hpnHdsqd5FkSmRGYss5\nmYqgcDVHXCjklTbw5to8Ik1a7rlsKnqtHNaZHZXB5MRxHKo7SnFTaafPC/ZlNuo1/OiyXMwGDa9/\nfoyiyraC1V0s90s/OrPLdiC4sgJbKrZT46pjfuppxBlboioMGfLC4TrFzv7p1uJANNTV57SEWy4e\nfSZmjYmNZZvxdGKeCzjSgygtMTrJwg8vnoTb42fVf/bjPCUM0l1UhMpgQJuY1G1bxiAW+k8L5UV+\n+ZiLAou8oFJhGD0aT0V5mxpCoiTx9Bs7qWkO7Zw+Nl6+XxC4NOdCAD488Wmnz+rJwe4Lp49i4fRU\niqtsvHqKwuOrq8NvberWDKNgMssZ2X7fwBzc3h+EBXsn7Ks5RLn1JHOTZ7aJzRXUavTpo3GXlSF6\n5IknShIvfXQoMIFnnFIKYHH6AowaY7PW3vEWvCfHf2k1Ku5cPoUIo5Y31+ZxpJVtUXHkBaOhKc/r\n6kX+vOhLREnkgszFbcLZFE3Y3cq+3OTw8MIHB5GQuGP5FGIjDW3aWjHxAgA+K/qy0+c57B50ejUa\nbfe1t+Ojjdy8dBI+v8hf1+zH4WoxIbiKikClQp/eteMUujdJ+UU/nxauR6vSckHm4jbXlJ1L63E4\nWlzPuxtPEB2h47ZLJreJ1derdZyZdgZ2r6NTW3tXjvSOmDMhkQtPG83Jeif/+KRFqIkuJ57KCvSj\nM7p1nEL3C1yNs5a91QcYbRnFhJi2NWf0ozNAknCXtSzaH35dwK4jVUzJjm2XVDU+dgzjonM4Up9H\nqbX9jgtaFcULsk7M984dR05qJN8ePMmGVgpPixkmSMHeA9/TUCUs2DtAkiQ+L/oSAaGN
[... base64-encoded PNG output truncated: matplotlib figure plotting f(x) = sin²(x) together with its first, second, and third derivatives ...]
dO0VGCcjqdMBqNyM3NhVarRVdXl+hIE2poaMCa\nNWug0+lQWloKl2vig378SBCDwUAtLS1ERGSz2WjLli2iooTU3t5Oer2e3G43ERG9ePFCcKKJPXv2\njAwGA6lUKhoeHhYdZ0Ktra3k9XqJiOjEiRNUU1MjONEfvF4vaTQa6u/vJ5fLRWvXrqXe3l7RsQIM\nDg5ST08PERGNjo7S6tWrwzInEVF9fT2VlpbSjh07REcJ6sCBA9TY2EhERG63m5xOp+BEgRwOB6nV\nahobGyMiot27d5PFYgk5R9iKXSKRwOl8e+KK0+mEQhGe/SQuXLiA7du3Qyp9e/fcrFnh2ZJ3MvTs\nyczMxLRpb99yGRkZcDgc75nx+XR3d2P+/PmYN28eZDIZtFotrFar6FgB5syZg5SUFABAXFwckpKS\nMDg4KDhVIIfDgVu3bmHTpk2iowQ1OjqKzs5ObNy4EQAglUrx5Zfh2WLZ5/Ph9evX8Hg8ePPmDb7+\nOnQbZGH3+paVlaGoqAhVVVUgIly8eFFUlJAeP36Mzs5OnD59GjExMdi/fz+USqXoWONMxp49jY2N\n0Gq1omP4DQwMICEhwT9WKBRhvT0IAP39/bDb7UhLSxMdJcC7hca7xVs46u/vh1wuR1lZGex2OxYt\nWoTy8nLExobXgSAKhQJ6vR4rV67E9OnTkZWVhczMzJBzPmlhD9ZnpqSkBLdv30Z5eTk0Gg2uX78O\nk8mE+vr6TxknqFD9cLxeL0ZGRnDp0iV0d3ejuLhYyEpusvTsCfWaq9VqAEBtbS1kMhl0Ot3njheU\nyOfsY7x69QpGoxEmkwlxcXGi44xjs9kQHx+PlJQU3LlzR3ScoDweD3p6enD48GEolUocO3YMdXV1\nMBqNoqONMzIyAqvVips3b2LGjBkwGo1oamoK/fn55BtEQSxZsmTcePHixYKShFZUVEQdHR3+sUaj\noZcvXwpMNN7Dhw8pMzOT1Go1qVQqSk1NJZVKRUNDQ6KjTejKlStUWFjo3y8MF/fv3yeDweAfm81m\nMpvNAhMF53a7yWAwUENDg+goEzp58iStWLGC1Go1ZWVlUUZGBu3bt090rADPnz8ntVrtH9+9ezcs\n/w+4du0alZeX+8cWi4WOHj0aco6wPXaFQoGOjg4AQFtbGxYsWCAqSkgajQZtbW0AgL6+Png8Hsjl\ncsGp/pCcnIzW1lZYrVbcuHEDCoUCFoslLBuxNTc348yZM6itrUV0dHgdvKBUKvHkyRM8ffoULpcL\nV69exapVq0THmpDJZMLChQuxbds20VEmtGfPHthsNlitVpw6dQrLli1DdXW16FgB4uPjkZCQgL6+\nPgBAe3s7kpKSBKcKNHfuXHR1dWFsbAxE9EE5he2xV1RUoLKyEj6fDzExMaioqBAVJaQNGzbAZDJB\np9NBJpOhqqpKdKSQwrlnT2VlJdxuNwwGAwAgPT0dR44cERvqf6KionDo0CEYDAYQEQoKCsLyQ37v\n3j00NTUhOTkZ+fn5kEgkKCkpwfLly0VHm5QOHjyIvXv3wuPxIDExEcePHxcdKUBaWhqys7ORn58P\nqVSK1NRUbN68OeQc7hXDGGMRhu88ZYyxCMOFnTHGIgwXdsYYizBc2BljLMJwYWeMsQjDhZ0xxiIM\nF3bGGIswXNgZYyzC/Be68EGj7hfMcwAAAABJRU5ErkJggg==\n", "text/plain": [ - "[]" + "\u003cmatplotlib.figure.Figure at 0x7f385e198650\u003e" ] }, - "execution_count": 4, "metadata": { "tags": [] }, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "# Create TensorFlow Variables using Keras's Dense layer.\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", + "\n", + "def grad(f):\n", + " return lambda x: tfe.gradients_function(f)(x)[0]\n", "\n", - "wb = tf.layers.Dense(units=1, use_bias=True)\n", + "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", "\n", - "# We can access the underlying TensorFlow variables using wb.variables.\n", - "# However, the variables won't exist until the dimensions of the input\n", - "# tensors are known. Once the dimensions of the input tensors are known,\n", - "# Keras can create and initialize the variables. Until then, Keras will\n", - "# report the variables as an empty list: [].\n", + "import matplotlib.pyplot as plt\n", "\n", - "wb.variables" + "plt.plot(x, f(x), label=\"f\")\n", + "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", + "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", + "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "docKLUaonYG_" + "id": "-39gouo7mtgu" }, "source": [ - "## Step 3: Define our loss function\n", + "## Gradient tapes\n", "\n", - "Our loss function is the standard L2 loss (where we reduce the loss to its mean across its inputs)." + "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. 
To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", + "\n", + "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", + "\n" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -257,145 +182,42 @@ } }, "colab_type": "code", - "id": "0_w8ZJSCtuY7" + "id": "MH0UfjympWf7" }, "outputs": [], "source": [ - "def loss_fn(inputs, labels, wb):\n", - " \"\"\"Calculates the mean L2 loss for our linear model.\"\"\"\n", - " predictions = wb(inputs)\n", - " return tf.reduce_mean(tf.square(predictions - labels))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 34, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 24, - "status": "ok", - "timestamp": 1505502830875, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "RkNbXoXkpjVH", - "outputId": "c36fc98d-3a57-4074-901d-c10ae017ae3f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=40, shape=(), dtype=float32, numpy=7.3549819\u003e" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# Test loss function (optional).\n", - "\n", - "loss_fn(inputs, labels, wb)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 51, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 57, - "status": "ok", - "timestamp": 1505502830981, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "K_7beXoHOU7t", - "outputId": "1ad0856a-02ec-4117-a6c0-b41030981d87" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "w: tf.Tensor([[ 1.56891453]], shape=(1, 1), dtype=float32)\n", - "b: tf.Tensor([ 0.], shape=(1,), dtype=float32)\n" - ] - } - ], - "source": [ - "# At this point, the variables exist, and can now be queried:\n", - "\n", - "w, b = wb.variables\n", - "print(\"w: \" + str(w.read_value()))\n", - "print(\"b: \" + str(b.read_value()))" + "def f(x, y):\n", + " output = 1\n", + " for i in range(y):\n", + " output = tf.multiply(output, x)\n", + " return output\n", + "\n", + "def g(x, y):\n", + " # Return the gradient of `f` with respect to it's first parameter\n", + " return tfe.gradients_function(f)(x, y)[0]\n", + "\n", + "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", + "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", + "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", + "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "YIlebeb_qYtC" + "id": "aNmR5-jhpX2t" }, "source": [ - 
"## Step 4: Create our gradients function using `implicit_value_and_gradients()`\n", - "\n", - "With a loss function defined, we can calculate gradients and apply them to our variables to update them.\n", - "\n", - "To calculate the gradients, we wrap our loss function using the `implicit_value_and_gradients()` function.\n", - "\n", - "`implicit_value_and_gradients()` returns a function that accepts the same inputs as the function passed in, and returns a tuple consisting of:\n", + "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", "\n", - "1. the value returned by the function passed in (in this case, the loss calculated by `loss_fn()`), and\n", - "1. a list of tuples consisting of:\n", - " 1. The value of the gradient (a `tf.Tensor`) with respect to a given variable\n", - " 1. The corresponding variable (`tf.Variable`)\n", - "\n", - "Test it out below to get a feel for what it does. Notice how the first value of the returned tuple (the loss) is the same as the value returned in the cell above that tests our loss function." + "For example:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -403,94 +225,48 @@ } }, "colab_type": "code", - "id": "v1spZQ4NwW1U" + "id": "bAFeIE8EuVIq" }, "outputs": [], "source": [ - "# Produce our gradients function. See description above for details about\n", - "# the returned function's signature.\n", - "\n", - "value_and_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 153, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 46, - "status": "ok", - "timestamp": 1505502831114, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "21WMcpsmFFLd", - "outputId": "f51b3171-33f5-4f87-8bf7-0be2dc8edc8a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Outputs of value_and_gradients_fn:\n", - "Loss: tf.Tensor(7.35498, shape=(), dtype=float32)\n", - "\n", - "Gradient: tf.Tensor([[-3.00773573]], shape=(1, 1), dtype=float32)\n", - "Variable: \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e\n", - "\n", - "Gradient: tf.Tensor([-4.06519032], shape=(1,), dtype=float32)\n", - "Variable: \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e\n" - ] - } - ], - "source": [ - "# Show outputs of value_and_gradients_fn.\n", - "\n", - "print(\"Outputs of value_and_gradients_fn:\")\n", - "\n", - "value, grads_and_vars = value_and_gradients_fn(inputs, labels, wb)\n", - "\n", - "print('Loss: {}'.format(value))\n", - "for (grad, var) in grads_and_vars:\n", - " print(\"\")\n", - " print('Gradient: {}\\nVariable: {}'.format(grad, var))" + "x = tf.ones((2, 2))\n", + " \n", + "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", + "# a single t.gradient() call when the bug is resolved.\n", + "with tf.GradientTape(persistent=True) as t:\n", + " # 
TODO(ashankar): Explain with \"watch\" argument better?\n", + " t.watch(x)\n", + " y = tf.reduce_sum(x)\n", + " z = tf.multiply(y, y)\n", + "\n", + "# Use the same tape to compute the derivative of z with respect to the\n", + "# intermediate value y.\n", + "dz_dy = t.gradient(z, y)\n", + "assert dz_dy.numpy() == 8.0\n", + "\n", + "# Derivative of z with respect to the original input tensor x\n", + "dz_dx = t.gradient(z, x)\n", + "for i in [0, 1]:\n", + " for j in [0, 1]:\n", + " assert dz_dx[i][j].numpy() == 8.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "JVDWpL9VYWdP" + "id": "DK05KXrAAld3" }, "source": [ - "## Step 5: Create an optimizer\n", + "### Higher-order gradients\n", "\n", - "We'll use a `GradientDescentOptimizer` to fit our model." + "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -498,362 +274,45 @@ } }, "colab_type": "code", - "id": "DudNEebMKDWN" + "id": "cPQgthZ7ugRJ" }, "outputs": [], "source": [ - "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "YBeJYxY8YaiO" - }, - "source": [ - "### Step 5a: Test Our Optimizer\n", - "\n", - "Now we have everything needed to start fitting our variables to the data!\n", + "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", "\n", - "In the next cell, we'll demo these capabilities. We'll:\n", + "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", "\n", - "1. Print the current values of `w` and `b`\n", - "1. Calculate the loss and gradients\n", - "1. Apply the gradients\n", - "1. Print out the new values of `w` and `b`\n", + "with tf.GradientTape() as t:\n", + " with tf.GradientTape() as t2:\n", + " t2.watch(x)\n", + " y = x * x * x\n", + " # Compute the gradient inside the 't' context manager\n", + " # which means the gradient computation is differentiable as well.\n", + " dy_dx = t2.gradient(y, x)\n", + "d2y_dx2 = t.gradient(dy_dx, x)\n", "\n", - "You can run the cell multiple times. Each time, you should see the values of `w` and `b` get closer to their true values of 3 and 2." 
- ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 102, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 103, - "status": "ok", - "timestamp": 1505502831285, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "diDZfrMJM3OC", - "outputId": "d585fff0-ecb3-4e98-9b33-bbae07a95d8c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Values of w, b, BEFORE applying gradients:\n", - "(array([[ 1.56891453]], dtype=float32), array([ 0.], dtype=float32))\n", - "()\n", - "Values of w, b, AFTER applying gradients:\n", - "(array([[ 1.86968815]], dtype=float32), array([ 0.40651903], dtype=float32))\n" - ] - } - ], - "source": [ - "# Test the optimizer.\n", - "\n", - "print(\"Values of w, b, BEFORE applying gradients:\")\n", - "w, b = wb.variables\n", - "print(w.read_value().numpy(), b.read_value().numpy())\n", - "print()\n", - "\n", - "# Calculate the gradients:\n", - "empirical_loss, gradients_and_variables = value_and_gradients_fn(\n", - " inputs, labels, wb)\n", - "optimizer.apply_gradients(gradients_and_variables)\n", - "\n", - "print(\"Values of w, b, AFTER applying gradients:\")\n", - "print(w.read_value().numpy(), b.read_value().numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "61TgeLVlKEQp" - }, - "source": [ - "## Step 6: Create a training loop\n", - "\n", - "Of course, now we can simply turn all of this code into a self-standing training loop. We'll also capture our loss and approximations of `w` and `b` and plot them over time." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 397, - "output_extras": [ - { - "item_id": 1 - }, - { - "item_id": 2 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 225, - "status": "ok", - "timestamp": 1505502831550, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "VukGe-huNaJ4", - "outputId": "f0a8d665-1910-477c-d8ab-c94ccdc4afcd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[2.111051321029663, 2.3047544956207275, 2.4602210521698, 2.5850086212158203, 2.6851789951324463, 2.7655951976776123, 2.830157995223999, 2.8819968700408936, 2.9236228466033936, 2.9570505619049072]\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd0AAAFXCAYAAADnFpTQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xd4FFUbBfAzu+m9koSShBQCSC+igIAgRRGkChJEiggo\nHURAEBQBQeADRcWCha50ULFLk6IivYRQQwskhPS6O/P9sckmm4Rkk2x2difn9zz7bLuZvC8JHO7M\n7FxBkiQJREREVOlUchdARERUVTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eIiMhMjArdlJQU\njB8/Hk8//TS6d++OkydPVnZdREREiiMY8znd6dOno2XLlujbty80Gg0yMzPh4uJijvqIiIgUo9TQ\nTU1NRa9evfDbb7+ZqyYiIiJFKnX38s2bN+Hp6YkZM2agd+/emD17NjIzM81RGxERkaKUGroajQbn\nzp3DoEGDsH37djg4OOCzzz4zR21ERESKUmro+vv7w9/fHw0bNgQAdO3aFefOnSvxa3g5ZyIioqJs\nShvg4+ODgIAAXL16FbVr18aRI0cQGhpa4tcIgoC4uBSTFSkHX19Xq+8BUEYfSugBYB+WRAk9AMro\nQwk9ALo+jFFq6ALArFmzMHXqVGg0GtSqVQsLFy6sUHFERERVkVGhW7duXWzdurWyayEiIlI0XpGK\niIjITBi6REREZsLQJSIiMhOGLhERkZkwdImIiMyEoUtERCbRuXM7uUuweAxdIiIyCUEQ5C7B4hn1\nOV0iIqKy+OijFTh69BAEQYUhQ4ajU6fOuH8/HnPmzER6ehq0Wi2mTJmOJ59sgwUL3kZU1HkAArp3\n74nnn39B7vIrDUOXiEhh5s6dhd27d5h0mz169MLcue8aNXbv3t9x+XI01qz5Fg8eJODll4egadNm\n+PXXn9Cq1eN48cVhkCQJmZmZOH/+POLi7uGbbzYBANLSUk1at6Xh7mUiIjKp06dP4qmnugIAPD29\n0LRpc5w/fw716j2CH37Yha+++hyXLkXD0dERtWrVwp07t7F8+RIcPXoYTk7OMldfuTjTJSJSmLlz\n3zV6VloZCq80l/e8ceOm+Oijz3H48EEsWDAXAwcOxuDBA/D11xtx9Ohh7Ny5DX/88StmzHhLjrLN\ngjNdIiIyifxwbYbff/8VoijiwYMHOHXqBOrXfwSxsbHw8PDEs8/2wrPP9sLFixeQmJgIUdSiffsn\n8fLLoxEdHSVzF5WLM10iIjKJvLOX27d/EmfPnsbQoS9AEFR49dXx8PT0wp4932PjxrWwsbGBk5Mz\nZs16G7GxsXj99TcgSSIEQcDo0eNk7qJyCVIlrThv7esjKmmNR2vvQwk9AOzDkiihB0AZfSihB8D4\n9XS5e5mIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIismjHjx/DmTOn9M93\n7NiKn3/+0STbXrv2K5Nsx1gMXSIismjHjx/D6dP5odurV1907fqMSba9Zo15Q5dXpCIiogrbsGEN\n7O3t0bfvAHzwwVJcvnwJK1Z8gmPH/sGPP+7C7NnzDMZHRV3Ahx8ug0aTDWdnN7z55hx4eXlj8+ZN\n2LlzG2xsbBAcXBujR4/Fzp1boVbb4Ndf92DixNfx779/w8nJCQMHDsa4caNQp04ETp48gczMTMya\nNRdr136FK1cuo2PHzhg5cgwAYMaMqYiLu4fs7Cz07/8CevTohVWrViI7OwvDh0eidu0QzJ49D7/8\nsgebN2+CVqtB/foNMGXKdJOuE8zQJSJSGOe5s2Bv4qX9snr0QloJiyg0btwM3367Hn37DkBU1AXk\n5ORAq9Xi1KkTaNy4mcFYjUaD5csX4733liEsrBY2bdqGTz/9CDNmvIX167/Bli27YWNjg7S0VDg7\nu+C55/rqQxYA/v33b4Pt2dra4Ysv1mDz5k2YPn0KvvpqPVxcXDFgQC8MGBAJNzc3zJw5B66ursjK\nysLIkUPQvn1HjB49Ftu2bcaXX64HAFy/fg2///4LVq36Emq1GkuXLsIvv+wx2awaYOgSEZEJRETU\nRVTUeaSnp8PW1hYREXVx/vw5nDx5HJMmTTMYGxNzHVeuXMakSa9BrVYhO1sDHx9fAEBYWDjmzn0T\n7dp1wBNPdDDqe7dt2w4AEBoahpCQUHh6egEAqlevgXv37sLNzQ3ffbcBBw7sAwDcu3cPN2/GoH79\nBgYrIv3779+4eDEKI0cOgSRJyM7OhpeXV0X/aAwwdImIFCZt7rslzkorg42NDfz9A/Djj7vQsGFj\nhIWF4/jxf3H79i0EBQUXGi0hJCQUn3zyZZFrL7///gqcOPEfDh7cjzVrvsSaNd+W+r1tbe0A6BZc\nsLW11b8uCAK0Wi2OHz+G//77F5999jXs7OwwbtwoZGdnF7MlCd26dceoUa+V40/AODyRioiITKJx\n46bYuHEdmjRphkaNmmDHjq0ID69TZFxgYDAePEjEmTOnAeh2N1+9egUAcPduLJo2bY4xY8YhLS0N\nGRnpcHJyQlpaWrnrSktLhaurK+zs7HD9+jWcPXtG/56trS20Wi0AoHnzR7F37+948OABACA5ORmx\nsbHl/r7F4UyXiIhMonHjpli79is0aNAQ9vYOsLe3L3I8F9DNit
99dxGWL38fy5cvQnZ2Dp5//gXU\nqhWId96ZnRuwEvr3HwhnZxe0adMOs2a9gb/+2o+JE183OLGppJOc8t5r1ao1duzYisGDn0dgYBAa\nNGioH9OzZ2+89NJARETUxezZ8/Dyy2MwefJrEEUJtra2mDx5Gvz9/U32Z8Sl/R5CSctNWXsfSugB\nYB+WRAk9AMroQwk9AFzaj4iIyOIwdImIiMyEoUtERGQmDF0iIiIzYegSERGZCUOXiIjITBi6RERk\ndt99txFZWVlyl2F2DF0iIjK7zZs3Iisrs9j3RFE0czXmw9AlIqIK27BhDbZu1V0n+YMPlmLCBN2S\neseO/YN582YbjN2yZRPi4+MwbtxovPTSSwCAzp3bYeXK5Rg2bBDOnDmF/v17Ijk5CQBw4cJ5jBs3\nCgCQmZmJhQvfwciRL2H48ME4eHC/uVo0CV4GkohIgbyaNyj29YRjZ4p9vazjCyvL0n79+g3Et99u\nxIcfforQ0BqIi0tBZmYGGjRoiLFjJ+aOMry8Y94lHb/5ZjWaN38UM2a8hdTUVIwcOQQtWz4Ke3sH\no+qUG0OXiIgqrCxL++lIuTcdtVqN9u07Fnq/qH/+OYpDhw5g48Y1AHSLJdy9G4vAwGCT9VKZGLpE\nRApk7Ay1vOMLK9vSfkXZ2dkbLF6gVqshirrgzc7OP+FKkiS8++5i1KoVWKF65cJjukREZBLGLu0H\nAE5OzgbL9RVeeycgoDqios4DAPbt+0P/+qOPPoYtWzbpn0dHR5myhUpn1Ey3Y8eOcHFxgUqlgo2N\nDbZs2VLZdRERkZUxdmk/AOjZsxemTh2PgAB/LFmyssgSfUOHjsR7770DFxcXNG3avMDrL+ODD5bi\npZcGAgD8/QOwaNH/Kq8pEzNqab9OnTph27ZtcHd3N2qjFy9ehKdnQIWLk5OSlpuy9j6U0APAPiyJ\nEnoAlNGHEnoATLy0nyRJZfrc1IABA5CTk2P0eCIioqrAqNAVBAEjRoxA37598d1335U6/sSJE/jw\nQ+uZ7hMREZmDUcd0N23aBF9fXyQkJGDYsGEICQlBixYtHjq+Ro0aWLp0Ebp164769R8xWbFERETW\nzKhjugWtXLkSzs7OGDZs2EPH/PDDD3j22WfRvHlzHDlyBDY2/GQSERFRqWmYkZEBURTh7OyM9PR0\nHDx4EGPHji3xa7p3747nn38B3323EXPnvosJE6aYrGBzUdLBfWvvQwk9AOzDkiihB0AZfSihB8D4\nE6lKDd34+HiMHTsWgiBAq9WiR48eaNu2bakbfvfd97Bv3594//2F6NatOyIi6hpVEBERkVKVeiJV\nrVq1sHPnTuzYsQO7d+/GK6+8YtSGPTw88f77y5GdnY0JE8ZAo9FUuFgiIrJMsbF3MGTIAJNuMzr6\nIg4f/kv//ODB/Vi//huTbFuupQUr9YpU3bo9g759n8d//x3DqlUfVea3IiIimRW+wEVFXbp0EUeO\n5Idu27btEBn5kkm2XdLSgpWp0s9wmj9/Efbv34tFi95F165PP/SSYEREZN00Gg3eeWc2Ll68gNq1\nQzFr1tuwt7c3GHPr1k0sW7YYSUmJcHBwwHvvLYCLiw/++OM3fP3151Cr1XB2dsHy5R/jiy9WITs7\nG6dPn8TgwcOQlZWJCxfOYdKkaViw4G3Y2dkjOjoKiYkPMGPGW9iz53ucPXsa9es3wMyZcwAAS5a8\nh6ioc8jKykKHDp0wfPgrBksLenh4YMWKT/D330fw5ZefIScnBzVq1MTMmXPg4GD6lYsqPXS9vLyx\nePH/MGxYJCZMeBW7d/8MtVpd2d+WiKjKmjvXHrt3m/af9x49NJg7t+TdsTEx1zFjxhw0aNAQCxe+\ng+3bN2PgwMEGYxYvXoBp02aiRo2aOHfuDObOnYslS1bim2++wLJlH8HHxwdpaamwsbHByy+PRlTU\neUyc+DoAYM+e7w1m06mpKfj0069w8OA+vPHGJKxa9RVq1w7BiBEv4tKlaISFhWPUqNfg6uoKURQx\nYcIYXLlyyWBpQTc3NyQlJWLNmi+xYsXHsLd3wPr132DTpnUYOvRlk/4ZAmZaZah79x7o1asPduzY\nhs8//wSjR5d89jMREVkfPz9/NGjQEADQtesz2LLlW4PQzcjIwJkzJzF79hsFFjjQ3Tds2Bjz589B\nx46d0b79k0Z9vzZtngAAhISEwcvLG7VrhwAAatcOQWzsbYSFheP333/Grl07oNVqkZBwH1evXkVI\nSBgKLi149uwZXLt2BWPGjIAkSdBoNGjQoFHF/0CKYbYP0C5YsAQHD+7HggXvoEuXbrlNExGRqc2d\nm1XqrLQyFD6mW/gQrySJcHV1w5dfrte/lveRoalTZ+D8+bM4dOggRox4EatXryv1+9nZ2QEAVCqV\n/nHec61Wizt3bmPTpvVYvXotnJ1dsGDB2wbLBObXJaFly8cwZ867ZWm3XMy2tJ+Pjw/ee28pMjMz\nMWHCa2W6ljMREVm+2Ng7OHtWty7vr7/+jEaNmhi87+TkjICA6vjzz9/0r124cAGA7lhvvXqPYMSI\nUfDw8MS9e3fh5ORksPxfSYq7zlNaWhocHR3h5OSMhIT7OHLkkEEtedt+5JGGOH36JG7dugkAyMrK\nxI0bMWXo3HhmvVRUz5690aPHduzevQOrV3+KkSPHmPPbExFRJQoKCsa2bd9h4cK3ERwcgl69+hUZ\nM2fOu3j//YX45psvodVq0LNnD/Tv/yI+/ngFbt68AQBo3rwlwsLCUa2aH9at+xrDh0di8OCHXwUR\nKP7M6bCwcISHRyAysh+qVfNDo0aN9e/lLS3o4+OLFSs+wcyZczB37kxkZ+dAEASMHDkGtWoFVvBP\npJg6y3oZSGM97AojcXFxeOKJlsjMzMSffx7S74O3NEq6Soq196GEHgD2YUmU0AOgjD6U0ANg4qX9\nTMnX1xcLFy5Beno6Jk0ay93MRERUZZg9dAGgV6++ePrpZ3Ho0EF8/fVqOUogIiIyO1lCVxAELF78\nP3h4eOCdd97C9evX5CiDiIjIrGQJXQDw8/PD/PmLkZ6ehsmTxxV75hkREZGSyBa6ANCv3wB06dIN\nBw7sw5o1X8lZChERUaWTNXQFQcCSJSvg7u6BuXNnVdrnooiIiCyBrKELAP7+AZg3byHS0lK5m5mI\nyEoZu7Tfnj3f4/79eDNUZJlkD10AGDBgEDp16ox9+/7Ehg1r5S6HiIjKwZil/X78cTfi4uKKfa8q\nfITUIkJXEAQsXfoBXF3d8NZbM3H79i25SyIiojLKW9pv8OD+mD17epFF4vfu/R0XLpzHvHmzMXx4\nJLKystCxY0d88smHGDHiRfz5528YN24UoqJ0l4ZMSkpE//49AegC+eOPV2DkyJcwdOgg7Nq13ez9\nmYJFhC4AVK9eA++8swApKcmYMmU8dzMTEVVA8+bOxd5MNb44MTHX0afP81i3bjOcnJywfftmg/c7\ndOiEevXqY86cd/Hll
+v1a+26u3tg9eq16NSpSzFb1c2ev/9+J1xcXPH559/g88+/wa5d2xEbe6dM\n9VkCiwldABg06EV06NARv//+K779doPc5RARURkUXtrv1KmTRcZIkoTCc6pOnTqXuu2//z6Cn376\nAcOGDcIrr7yE5OQkqzz51qwLHpRGEAQsW/Yh2rV7DLNnz0CHDh3h7x8gd1lERFbn2DHjVucp7/ji\nlLa038M4OjrqH6vVakiS7thudnZ2gVESJk16HS1bPlbRMmVlUTNdAKhZsxbmzJmHpKRETJ06gbuZ\niYisRGlL+wGAs7Mz0tJSH7qNgIAauHDhHAAYLAH46KOPY9u2LdBoNACAGzdikJWVacryzcLiQhcA\nhgwZhieeaI9ffvkJW7Z8K3c5RERkhLyl/QYP7o+UlORil/Z7+ulnsWTJQv2JVIVnxy+8EInt27di\n+PDBSE5O1r/eo0cvBAfXxogRgzFkyAAsWbIQWq220nsyNbMv7WesmJjraNfuMdjZ2eLAgX/g5+dn\nosqMo6Tlpqy9DyX0ALAPS6KEHgBl9KGEHgALXtrPWIGBQZg9+20kJiZi2rRJ3M1MRERWz2JDFwCG\nDXsZrVu3xZ4932PHjq1yl0NERFQhFh26KpUK//vfSjg5OWHGjKm4d++e3CURERGVm0WHLgDUrh2C\nN9+cg4SEBMyYMVXucoiIiMrN4kMXAEaMGIVWrR7H7t07rPbSX0RERFYRuiqVCitWfAQHBwdMnz4F\n8fFVd4UKIiKyXlYRugAQEhKGGTPeQnx8PGbO5G5mIiKyPlYTugDwyitj0KLFo9ixYxt++GG33OUQ\nERGViVWFrlqtxooVH8Pe3h7Tpk1CQsJ9uUsiIiIymlWFLgCEh9fBG2/MQlzcPbz55htyl0NERGQ0\nqwtdABgzZiyaNWuOrVu/w08//Sh3OUREREaxytDV7Wb+BHZ2dnj99YlITHwgd0lERESlssrQBYCI\niLp4/fUZuHs3FrNnz5C7HCIiolJZbegCwGuvTUDjxk3x7bcb8OuvP8ldDhERUYmsOnRtbGzwwQef\nwNbWFlOnTkRSUqLcJRERET2UVYcuANSrVx+TJ0/DnTu3MWfOm3KXQ0RE9FBWH7oAMH78ZDRo0Agb\nNqzFH3/8Jnc5RERExVJE6Nra2uKDDz6BjY0NJk8eh5SUZLlLIiIiKkIRoQsADRo0xMSJU3H79i3M\nnTtb7nKIiIiKUEzoAsDEiVNRv34DrF37Ffbt+1PucoiIiAwYHbqiKKJ3794YPXp0ZdZTIXZ2dvjg\ng4+hVqsxefI4pKamyF0SERGRntGhu2bNGoSGhlZmLSbRqFETjB8/CTduxGDevDlyl0NERKRnVOjG\nxsZi37596N+/f2XXYxKTJ7+BunXr4auvvsDBg/vlLoeIiAiAkaG7YMECTJs2DYIgVHY9JmFvb48V\nKz6GSqXCxIljkZaWJndJREREsCltwN69e+Hj44N69erh6NGjRm/Y19e1QoVVVJcuHTBt2jS89957\nWLZsAT744IMyb0PuHkxFCX0ooQeAfVgSJfQAKKMPJfRgLEGSJKmkAcuWLcOuXbugVquRlZWFtLQ0\ndO7cGYsXLy5xw3Fx8p/ElJmZiaeeegIXL0Zh5849ePzxNkZ/ra+vq0X0UFFK6EMJPQDsw5IooQdA\nGX0ooQfA+P84lLp7efLkydi7dy9+//13LFu2DK1atSo1cC2Fg4MDli//CCqVChMmvIr09HS5SyIi\noipMUZ/TLU6LFo9i9OixuHbtKhYunCd3OUREVIWVKXQfffRRrFq1qrJqqTRvvPEmQkPD8NlnH+Po\n0SNyl0NERFWU4me6AODo6Ijlyz8GAEyc+CoyMjJkroiIiKqiKhG6ANCq1WN45ZUxuHz5EhYtmi93\nOUREVAVVmdAFgBkz3kJwcG2sWrUS//77t9zlEBFRFVOlQtfJyQkrVnwMURQxYcKryMzMlLskIiKq\nQqpU6ALA44+3wcsvj0J09EUsWfKe3OUQEVEVUuVCFwDefHMuAgODsXLlchw/fkzucoiIqIqokqHr\n7OyM5ctX6nczZ2VlyV0SERFVAVUydAGgbdt2GDp0BC5cOI///c86rrBFRETWrcqGLgC89dY7qFUr\nECtWLMOpUyfkLoeIiBSuSoeui4srli37EFqtFuPHv4rs7Gy5SyIiIgWr0qELAO3bP4kXXxyGc+fO\nYPnyJXKXQ0REClblQxcA5s6dhxo1amL58iU4c+a03OUQEZFCMXQBuLq6YenSD6DRaDB+/Bjk5OTI\nXRIRESkQQzdXx45PYdCgF3HmzCl8+OH/5C6HiIgUiKFbwNtvz4e/fwCWLl2E06e5m5mIiEyLoVuA\nu7sHli5dgZycHAwdOhSpqalyl0RERArC0C2kc+duiIwcgv/++w8DBvRGcnKS3CUREZFCMHSL8f77\nyzFo0CD8889R9O3bEwkJ9+UuiYiIFIChWwwbGxusWbMGgwa9iJMnj6N372cRFxcnd1lERGTlGLoP\noVarsWzZhxg+fCTOnz+LXr2exp07t+Uui4iIrBhDtwQqlQoLFy7Bq6+OR3T0RfTs2Q03bsTIXRYR\nEVkphm4pBEHAnDnzMGXKG7h+/Rp69uyGK1cuy10WERFZIYauEQRBwBtvvIlZs+bi1q2beO65pxEV\ndUHusoiIyMowdMtg/PjJmD9/Ee7ejUWvXk/j9OlTcpdERERWhKFbRiNHjsGSJSuQkJCAPn2exfHj\nx+QuiYiIrARDtxyGDBmGDz9chZSUZPTt2xNHjhyWuyQiIrICDN1yev75F/DZZ18hMzMDAwf2xoED\n++QuiYiILBxDtwJ69uyNr75aD41Gg0GD+uG3336WuyQiIrJgDN0K6tr1aaxb9x1UKhVeemkQfvhh\nt9wlERGRhWLomkCHDh2xceNW2NnZ4+WXh2Dbts1yl0RERBaIoWsirVu3xebNO+Ds7IIxY17Ghg1r\n5S6JiIgsDEPXhFq0eBTbtu2Gp6cnJk58DatXfyZ3SUREZEEYuibWqFETbN/+I3x9q2HGjKn4+OMP\n5S6JiIgsBEO3EtSrVx87d+5BQEB1zJ37JpYuXQRJkuQui4iIZMbQrSRhYeHYuXMPAgODsGjRfCxY\n8A6Dl4ioimPoVqLg4NrYuXMPQkJCsWLFUsyePZ3BS0RUhTF0K1mNGjWxc+dPqFu3Hj777BNMnToR\noijKXRYREcmAoWsGfn5+2L79RzRo0Ahr136FceNGQ6PRyF0WERGZGUPXTLy9vbFt2240b94Cmzdv\nwujRI5CTkyN3WUREZEYMXTPy8PDE5s078fjjbbBr13YMHz4YmZmZcpdFRERmwtA1MxcXV2zcuBXt\n2z+Jn3/egyFDBiI9PV3usoiIyAwYujJwcnLC2rXfokuXbti79w8MGtQPqakpcpdFRESVrNTQzc7O\nRv/+/dGrVy/06NEDK1euNEddiufg4IAvv1yHHj164dChg+jfvxeSkhLlLouIiC
qRTWkD7OzssGbN\nGjg6OkKr1eKFF15Au3bt0KhRI3PUp2h2dnb49NMvYW9vjy1bvkWfPj3w3Xc74O3tLXdpRERUCYza\nvezo6AhAN+vlR11My8bGBitXfooXXxyK06dPok+f7rh7967cZRERUSUodaYLAKIook+fPoiJiUFk\nZGTps9zgYHiJRa+8lHDsTLHDvZo3KPZ1WcerhCI9VGY9XwFwGDkan3++Cr16PY2tW3ejevUaFd9+\ngT6s6s+/oNweLKaeco5HzHWLqofjOd4SxisiL4CH/v0uzKjQValU2LFjB1JTU/Hqq6/i0qVLCAsL\nK/Fr1CqhyGu+vq4P+QZFx1rC+MI9VHY9n376Mby83LFo0SL07v0M/vjjDwQHB1d4+3l9yP3nWZHx\napVgUfWUZ/xDv8ZK6i843uBrLaCe8ozXP7eQeso7vrh/a+Wsp8zjoYy8MJYglfFiwCtXroSzszOG\nDRtW4ri4OOs+G9fX11WWHiRJwtKli7B48QJUr14D27btRkhIyf/BKYlcfZiSEnoA2IclUUIPgDL6\nsPgeRBHIzISQmQEh9x4ZmRCyMiFkZgKZGRAyMuE+dJBRmyt1ppuQkABbW1u4uroiMzMThw8fxiuv\nvFLhPqh4giBg6tTpcHBwxDvvzEbPnk9jy5ZdqFu3ntylERHJy8gAzHtfN7bg89z3szLzt5M7Hpm5\n28nIHZf3fna2cbWZKnTj4uIwffp0iKIIURTxzDPPoH379sYVQeU2duwEODo6YMaM19G79zP47rsd\naNiwsdxlEREVJUlAdjaE9DQI6em5tzT9PdLTIaQ95D1JA9fEFMMAzMrSPy9XAJa1fEEAHB0hOThA\ncnCE5OICyccXkoM9JAdHIO91BwdIjo6Avb3hcwcHuBj5vUoN3YiICGzfvr2CLVF5jBgxCg4Ojpg8\neRz69OmBTZu2onnzlnKXRUTWSJKAjIwioWd4nw4UfC2taEgWHZf7ulZb7tIcCpZZXAB6+0BydCga\ngA4ORQPRwQGSfe57jo4FxjoCDvaGz/O2aWsLCGU7NluYyUKX5BUZOQQODg4YO3YU+vV7Dhs2bMbj\nj7eRuywiqkyiqAuylBQIqakQUpINH6emQJWaCkg5cI5/UCQQ9Y/T8h8jIx2CCdbzllQqSE7OkJyc\nACcniN4+kJyc9K9JTk6QnAs8dnIGDN43fM+rpi/i00VdANo7AHZ2FQ5AS8bQtQJ9+z4POzt7jB49\nHAMH9sGaNZvQvv2TcpdFRAVJku64YEpKbiimFB+aqbrHKv17KbrX8h6npEBISzU6IJ2KK8XWVh9u\nors7pIDqucH38PArGJYlhSTs7U0bir6ukCz5RCoTY+haiR49noODw3oMH/4iBg9+HqtXr0GXLk/L\nXRaR9cvJyZ095oeeKi0lPwALhmZaam5gFgzRlPyvL+fFgyQ7O0iurpBcXCEGBUN0dc197gLJxS3/\nsasrJFfc0+i4AAAgAElEQVQ3iC4ukFxc4FHTDwlZAJxzw9HRUReMtram/TMik2HoWpHOnbth/frN\nGDJkIIYOjcSnn36JHj16yV0WkbxEEUJyEoTERKiSEiEkJkJISoQqMbH415ISgdRkeCcl6YKynMtr\nSmo1JBddOIoB1SE560JRdHXLD0gXV/2Y/OB0g+iS/1hycdHNHsvD1xXaKjRLVAKGrpVp164DNm3a\nhkGD+mPkyKH48MNV6N9/oNxlEVWMkcGpSnxQJECF5KQyHauUnJwAd3eInl6QagUWmUnqQzMvLAuG\npqsrRGfdPRwdFX3skSoHQ9cKPfZYa2zZshMDBvTB2LGjkJWVhcGDX5K7LKrqSgzOB/qQzAvSigan\n6O4BsXp1iPXqQ/LwgOTuAbHQveThAdHDE5KHJ0R3D0ju7oC9PXx9XfGAM0SSAUPXSjVr1gLbtn2P\n559/DpMnj0NmZgZefnm03GWRUmRkQBUfB9X9eKjux0OIj4fq/n2o7scDmalwi40zTXB6eEKsXgNi\n/UfyQ1Iflh6FXvPUvwc7u0psnqjyMHStWMOGjbBjxx707dsDM2dOQ0ZGJsaNmyh3WWSJ0tL0AaoP\n0fgCz+/H54bsfaji43UXLShB3hFIyckZoocHg5PISAxdKxcRURe7du1B3749MW/eW8jISMfrr8+A\nwGNNyiVJhiEaHwchNyzzn+cFqm52KqSnl75Ze3uI3j7QhIVD8vaG6O2ju/n4QPLxzX3uDc/QWojX\n2up21TI4icqEoasAISFh2LlTN+NdsuQ9ZGZmYvbstxm81kKSdB9FiYszDMr4eMNdvPfv658bc8at\n5OCgC9HwiEIh6gvJx0cfonnPJWcX404MqmKfqyQyJYauQgQGBmHXrp/Qt28PrFy5HBkZ6Zg/f7Hc\nZVVdkqQ7eSg2Fqo7t6G6GwukJcL5+q1Cx0lzH2dllb5JR0eIPr7Q1K2nuwpQboDqZ6O5AZoXrnB2\n5tm1RBaGoasgAQHVsWPHHvTv/xxWr/4MWVlZ+Prr1XKXpTzp6VDF3oH6bm6g6oP1DtR37kAVeweq\nu7HFzkYLXj1IcnKG6OMDTf1HdCFaIDAfGqJEZNUYugpTrVo1bN/+PQYM6IN1677BvXt3MG/eYtSu\nHSJ3aZZPo4Hq3l1daBYIT/Wd27rHsXd0AZuU+NBNSCoVRN9qutmof4D+pg2oDrewIDywdc4PUafi\nLuBHRErG0FUgLy9vbN26C6+8Mgy//PIL9u/fjwkTpmDs2ImwL++Vb6yZJEF4kKAL0oKz0dhYqGIL\nzFTj7pX4kRfRwwNiQAA0TZvlBmkARL8AiAHVIfr76+59fAGbh/y18nWFhsdCiao0hq5Cubm5Y+PG\nrfjzzz2YMGEiFi2ajy1bvsWiRcvQrl0HucsznbQ0qO8WmJnmBqsqNm+GGgvV3TslHjOVHBwg+gcg\np9XjEIsJUq2fP0T/AN0ViIiIKoChq2CCIGDAgAFo0aINFi2aj9WrP0O/fj3Rp08/vP32Qvj5+cld\n4sPlnoikvhEDJMXB4eKV/BlqgWBVJSc9fBMqFUQ/f90xU78AXaDm7uoV/fz1wSq5e/CEIyIyC4Zu\nFeDm5o758xdjwIBBmDZtErZt24Jff/0FM2fOxtChL0OtVstTWFoa1DdioI65BlXMdaivX4c6RndT\nxVyHKiVZP9S10JeKnp4Qa9SEpnkLaP0Dip2hij6+gFy9EREVg6FbhTRq1AQ//PAb1q79GvPnv40Z\nM17Hpk0b8P77/0OTJs1M/w2zs6G6dVMfpLowvaZ7fP06VPFxxX6Z5OQEbWAQcgJbQxsYBKd6dZDs\n6gWtf26g+gcADg6mr5eIqJIxdKsYtVqNoUNH4JlneuDtt2dh8+ZN6Nr1SQwdOgIzZ74Fd3cP4zcm\nirqPzsRch+r6NYNZqjrmOlR3bkMQxSJfJtnaQluzFjSPNIA2MAjawCCIuffawGBIPj4Gu3udfF2R\nxROQiEgBGLpVVLVq1fDRR59h0KAXM
IiIyE4YuERGRmRgVuvv370e3bt3QtWtXfPbZZ5VdExERkSKVGrqiKGLevHlYvXo1vv/+\ne/zwww+4fPmyOWojIiJSlFJD99SpUwgKCkKNGjVga2uL7t274/fffzdHbURERIpSaujevXsXAQEB\n+ud+fn64d+9epRZFRESkRKWGriRJ5qiDiIhI8WxKG+Dv74/bt2/rn9+9exfVqlUrdcO+vq4Vq8wC\nKKEHQBl9KKEHgH1YEiX0ACijDyX0YKxSZ7oNGzZETEwMbt26hezsbPzwww/o1KmTOWojIiJSlFJn\numq1GrNnz8bw4cMhSRL69euH0NBQc9RGRESkKILEg7ZERERmwStSERERmQlDl4iIyEwYukRERGZS\n6olUZbF//34sWLAAkiShb9++eOWVV0y5ebOYOXMm9u7dC29vb+zevVvucsolNjYW06ZNQ3x8PNRq\nNfr3748hQ4bIXVaZZWdnIzIyEjk5OdBqtejatSvGjh0rd1nlIooi+vbtCz8/P6xatUrucsqlY8eO\ncHFxgUqlgo2NDbZs2SJ3SeWSkpKCN998E9HR0VCpVFiwYAEaN24sd1lGu3r1KiZNmgRBECBJEm7c\nuIEJEyZY5d/xr7/+Glu2bIEgCKhTpw4W/r+9u3mJag8DOP6dHKRQexElCyzIjCySFr1AEyamSTXV\nxGCLNiVRbdIow14oghYJLfoHWkREEBEaRG1EszGmQiuGYIgwIhhMKkRT5yXPnOcu4l64G+89x7nz\na7rPZz1n+A6HmYcznHmmo4P8/HzTWY7cunXrr/fCv/qslQxJp9NSX18vsVhMfvz4IXv37pWhoaFM\nPX3WDAwMSDQaFb/fbzrFtS9fvkg0GhURkcnJSdmxY0dOngsRkXg8LiIilmVJU1OTRCIRw0Xu3Lx5\nU9ra2uT48eOmU1yrq6uTsbEx0xmzdvbsWbl//76IiExPT8vExIThIvfS6bT4fD4ZHh42neLYyMiI\n1NXVSSqVEhGRkydPSldXl+EqZ96/fy9+v19SqZRYliWHDx+WT58+zXhMxr5e/l12NG/YsIH58+eb\nzpiV0tJSqqqqACgoKKCioiJnV3fOmzcP+HnVa1mW4Rp3RkZGePr0KU1NTaZTZkVEsG3bdMasTE5O\nMjg4SDAYBMDr9VJYWGi4yr1wOMyyZcv+tqo3l9i2TSKRwLIsksnkv1q89Cv58OED69evJz8/n7y8\nPDZu3Eh3d/eMx2Rs6OqO5l9TLBbj3bt3VFdXm05xxbZtAoEAPp8Pn8+Xk6/j6tWrtLe34/F4TKfM\nisfj4ciRIwSDQe7du2c6x5VYLMaiRYs4f/48+/fv59KlSySTSdNZrj1+/Jjdu3ebznBl8eLFNDc3\nU1tbS01NDUVFRWzZssV0liOVlZUMDAwwPj5OIpEgFArx+fPnGY/J2NAV/bnvL2dqaorW1lYuXLhA\nQUGB6RxX5syZw4MHDwiFQkQiEYaGhkwnOdLX10dJSQlVVVU5/x65e/cunZ2d3Lhxgzt37jA4OGg6\nyTHLsohGoxw8eJCuri7mzp2bs/8RPj09TW9vLzt37jSd4sr379/p6enhyZMn9Pf3E4/Hc+4+moqK\nCo4ePUpzczPHjh1j9erVeL0z3yqVsaHrdkez+m9YlkVrayv79u2jvr7edM6sFRYWsmnTJvr7+02n\nOPL69Wt6e3vZvn07bW1tvHz5kvb2dtNZrpSWlgJQXFxMQ0MDb9++NVzkXFlZGWVlZaxbtw6AxsZG\notGo4Sp3QqEQa9eupbi42HSKK+FwmPLychYuXEheXh4NDQ28efPGdJZjwWCQzs5Obt++zYIFC1i+\nfPmMj8/Y0P2ddjTn+hUJ/LwLe+XKlRw6dMh0imujo6NMTEwAkEwmef78OStWrDBc5czp06fp6+uj\np6eH69evs3nzZq5du2Y6y7FEIsHU1BQA8XicZ8+eUVlZabjKuZKSEpYsWcLHjx8BePHiRc6utX30\n6BF+v990hmtLly4lEomQSqUQkZw9F6OjowAMDw/T3d39j+ckYz8Z+l12NP95NTI2NkZtbS0tLS1/\n3XSRK169esXDhw9ZtWoVgUAAj8fDqVOnqKmpMZ3myNevXzl37hy2bWPbNrt27WLbtm2ms/6Xvn37\nxokTJ/B4PKTTafbs2cPWrVtNZ7ly8eJFzpw5g2VZlJeX09HRYTrJsWQySTgc5sqVK6ZTXKuurqax\nsZFAIIDX62XNmjUcOHDAdJZjLS0tjI+P4/V6uXz5MkVFM/9jku5eVkoppbJEN1IppZRSWaJDVyml\nlMoSHbpKKaVUlujQVUoppbJEh65SSimVJTp0lVJKqSzRoauUUkpliQ5dpZRSKkv+AO2e4yf8wTuC\nAAAAAElFTkSuQmCC\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0xc1dc310\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "# Train our variables.\n", - "\n", - "# numpy is used for its asscalar() function.\n", - "import numpy as np\n", - "\n", - "num_training_steps = 10\n", - "\n", - "def train_model(inputs, labels, wb, optimizer, num_training_steps):\n", - " loss_at_step = []\n", - " w_at_step = []\n", - " b_at_step = []\n", - " for step_num in range(num_training_steps):\n", - " loss, gradients_and_variables = value_and_gradients_fn(inputs, labels, wb)\n", - " loss_at_step.append(np.asscalar(loss.numpy()))\n", - " \n", - " optimizer.apply_gradients(gradients_and_variables)\n", - " w, b = wb.variables\n", - " w_at_step.append(np.asscalar(w.read_value().numpy()))\n", - " b_at_step.append(np.asscalar(b.read_value().numpy()))\n", - "\n", - " print(w_at_step)\n", - " t = range(0, num_training_steps)\n", - " plt.plot(t, loss_at_step, 'k',\n", - " t, w_at_step, 'r',\n", - " t, [true_w] * num_training_steps, 'r--',\n", - " t, b_at_step, 'b',\n", - " t, [true_b] * 
num_training_steps, 'b--')\n", - " plt.legend(['loss', 'w estimate', 'w true', 'b estimate', 'b true'])\n", - " plt.show()\n", - "\n", - "train_model(inputs, labels, wb, optimizer, num_training_steps)" + "assert dy_dx.numpy() == 3.0\n", + "assert d2y_dx2.numpy() == 6.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "UNurY9VJ-hpH" - }, - "source": [ - "## Other Ways to Compute Gradients\n", - "\n", - "Using our loss function as an example (`loss_fn()`), there are several other ways we could compute gradients:\n", - "\n", - "1. `tfe.implicit_gradients()`\n", - "1. `tfe.gradients_function()`\n", - "1. `tfe.implicit_value_and_gradients()`\n", - "1. `tfe.value_and_gradients_function()`\n", - "\n", - "Each of these functions does the following:\n", - "* Wraps a function.\n", - "* Returns a function with the same input signature as the wrapped function.\n", - "\n", - "They differ only in what information they return.\n", - "\n", - "### Gradients-only functions\n", - "\n", - "The following two functions return a function that returns only the variables' gradients:\n", - "\n", - "1. `tfe.gradients_function()`: Returns the partial derivatives of the function `f()` with respect to the parameters of `f()`.\n", - "1. `tfe.implicit_gradients()`: Returns the partial derivatives of the function `f()` with respect to the trainable parameters (`tf.Variable`) used by `f()`.\n", - "\n", - "In our example above, the `tf.layers.Dense` object encapsulates the trainable parameters.\n", - "\n", - "### Value and gradients functions\n", - "\n", - "The following two functions are identical to their counterparts above, except that they also return the value of the wrapped function.\n", - "\n", - "1. `tfe.implicit_value_and_gradients()`\n", - "1. `tfe.value_and_gradients_function()`\n", - "\n", - "### Gradient demos\n", - "\n", - "In the demos below, we show examples for the `implicit_*` functions, since our existing loss function works seamlessly with these versions. 
(The other versions require that your parameters are tensors and tensors only; in our example, we're using a `Dense` layer.)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 85, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 100, - "status": "ok", - "timestamp": 1505502831671, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "aEoCftnfAIH5", - "outputId": "72f1c1dc-a574-463f-f860-c4e5f48fcdaa" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[(\u003ctf.Tensor: id=673, shape=(1, 1), dtype=float32, numpy=array([[-0.26846504]], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e),\n", - " (\u003ctf.Tensor: id=671, shape=(1,), dtype=float32, numpy=array([-0.32890949], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e)]" - ] - }, - "execution_count": 13, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# tfe.implicit_gradients() demo\n", - "gradients_fn = tfe.implicit_gradients(loss_fn)\n", - "\n", - "# Returns only gradients and variables:\n", - "gradients_fn(inputs, labels, wb)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 102, - "output_extras": [ - { - "item_id": 1 - } - ] - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 88, - "status": "ok", - "timestamp": 1505502831785, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "bbgCUdCzAVhH", - "outputId": "152aa9b6-9e42-4b7e-848a-9423c0b1929c" + "id": "4U1KKzUpNl58" }, - "outputs": [ - { - "data": { - "text/plain": [ - "(\u003ctf.Tensor: id=688, shape=(), dtype=float32, numpy=1.0623235\u003e,\n", - " [(\u003ctf.Tensor: id=720, shape=(1, 1), dtype=float32, numpy=array([[-0.26846504]], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e),\n", - " (\u003ctf.Tensor: id=718, shape=(1,), dtype=float32, numpy=array([-0.32890949], dtype=float32)\u003e,\n", - " \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e)])" - ] - }, - "execution_count": 14, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], "source": [ - "# tfe.implicit_value_and_gradients() demo\n", - "value_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)\n", + "## Next Steps\n", "\n", - "# Returns the value returned by the function passed in, gradients, and variables:\n", - "value_gradients_fn(inputs, labels, wb)" + "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." 
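The two assertions added above (`dy_dx.numpy() == 3.0` and `d2y_dx2.numpy() == 6.0`) refer to a tape-based example whose setup sits earlier in the notebook and is not visible in this hunk. For orientation, here is a minimal sketch of the nested `tf.GradientTape` pattern that would produce exactly those values; the choice of `y = x**3` evaluated at `x = 1.0` is an assumption made for illustration, not something taken from the diff.

```python
import tensorflow as tf

tf.enable_eager_execution()

x = tf.constant(1.0)

with tf.GradientTape() as outer_tape:
    outer_tape.watch(x)
    with tf.GradientTape() as inner_tape:
        inner_tape.watch(x)
        y = x * x * x                      # assumed example: y = x^3
    dy_dx = inner_tape.gradient(y, x)      # 3 * x^2 == 3.0 at x = 1.0
d2y_dx2 = outer_tape.gradient(dy_dx, x)    # 6 * x   == 6.0 at x = 1.0

assert dy_dx.numpy() == 3.0
assert d2y_dx2.numpy() == 6.0
```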
] } ], "metadata": { "colab": { + "collapsed_sections": [], "default_view": {}, - "last_runtime": { - "build_target": "", - "kind": "local" - }, - "name": "Eager Execution Tutorial: Working with Gradients", + "name": "Automatic Differentiation", "provenance": [], "version": "0.3.2", "views": {} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb index 0088da5c4b583d..bfcc7feb075c40 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb @@ -16,7 +16,9 @@ "\n", "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", "\n", - "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly different. You will use a Pythonic `Iterator()` class instead of using `make_one_shot_iterator()` and `get_next()`. As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled." + "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", + "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", + "As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled." ] }, { @@ -48,11 +50,8 @@ "# Import TensorFlow.\n", "import tensorflow as tf\n", "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "import tensorflow.contrib.eager as tfe\n", - "\n", "# Enable eager execution\n", - "tfe.enable_eager_execution()" + "tf.enable_eager_execution()" ] }, { @@ -137,32 +136,27 @@ "source": [ "# Step 3: Iterate\n", "\n", - "Use `tfe.Iterator` on the `Dataset` object to get a Python iterator over the contents of the dataset.\n", - "\n", - "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that this process of iteration is different. Here there are no calls to `Dataset.make_one_shot_iterator()` and no `get_next()` calls." + "When eager execution is enabled `Dataset` objects support iteration.\n", + "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." 
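The sentence above about Python iteration is the heart of this change: with eager execution enabled, a `tf.data.Dataset` can be consumed with a plain `for` loop, and no `tf.data.Iterator` has to be created. A small self-contained sketch of that pattern follows; the dataset contents here are illustrative and not taken from the notebook.

```python
import tensorflow as tf

tf.enable_eager_execution()

# Illustrative pipeline; the notebook builds its own ds_tensors and ds_file datasets.
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
ds = ds.map(tf.square).batch(2)

# The Dataset object itself is iterable when eager execution is on.
for batch in ds:
    print(batch)          # each batch is an EagerTensor of shape (2,)
    print(batch.numpy())  # ...which can be converted to a NumPy array
```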
] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 0, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "height": 153, - "output_extras": [ - { - "item_id": 1 - } - ] + "base_uri": "https://localhost:8080/", + "height": 153 }, "colab_type": "code", "executionInfo": { - "elapsed": 201, + "elapsed": 388, "status": "ok", - "timestamp": 1505952405928, + "timestamp": 1525154629129, "user": { "displayName": "", "photoUrl": "", @@ -171,7 +165,7 @@ "user_tz": 420 }, "id": "lCUWzso6mbqR", - "outputId": "ec027d30-96c6-4ea4-9ee1-ef74ec1ae29a" + "outputId": "8e4b0298-d27d-4ac7-e26a-ef94af0594ec" }, "outputs": [ { @@ -179,9 +173,9 @@ "output_type": "stream", "text": [ "Elements of ds_tensors:\n", - "tf.Tensor([4 9], shape=(2,), dtype=int32)\n", + "tf.Tensor([1 9], shape=(2,), dtype=int32)\n", "tf.Tensor([16 25], shape=(2,), dtype=int32)\n", - "tf.Tensor([36 1], shape=(2,), dtype=int32)\n", + "tf.Tensor([ 4 36], shape=(2,), dtype=int32)\n", "\n", "Elements in ds_file:\n", "tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n", @@ -191,22 +185,19 @@ ], "source": [ "print('Elements of ds_tensors:')\n", - "for x in tfe.Iterator(ds_tensors):\n", + "for x in ds_tensors:\n", " print(x)\n", "\n", "print('\\nElements in ds_file:')\n", - "for x in tfe.Iterator(ds_file):\n", + "for x in ds_file:\n", " print(x)" ] } ], "metadata": { "colab": { + "collapsed_sections": [], "default_view": {}, - "last_runtime": { - "build_target": "", - "kind": "local" - }, "name": "Eager Execution Tutorial: Importing Data", "provenance": [], "version": "0.3.2", diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb new file mode 100644 index 00000000000000..84f1d031d40604 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "k2o3TTG4TFpt" + }, + "source": [ + "# Training Models\n", + "\n", + "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", + "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", + "\n", + "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3LXMVuV0VhDr" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "PJ64L90aVir3" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager # Shorthand for some symbols" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eMAWbDJFVmMk" + }, + "source": [ + "## Variables\n", + "\n", + "Tensors in TensorFlow are immutable stateless objects. 
Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VkJwtLS_Jbn8" + }, + "outputs": [], + "source": [ + "# Using python state\n", + "x = tf.zeros([10, 10])\n", + "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", + " # value of x\n", + "print(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wfneTXy7JcUz" + }, + "source": [ + "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", + "\n", + "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "itxmrMil6DQi" + }, + "outputs": [], + "source": [ + "v = tfe.Variable(1.0)\n", + "assert v.numpy() == 1.0\n", + "\n", + "# Re-assign the value\n", + "v.assign(3.0)\n", + "assert v.numpy() == 3.0\n", + "\n", + "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", + "v.assign(tf.square(v))\n", + "assert v.numpy() == 9.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-paSaeq1JzwC" + }, + "source": [ + "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", + "\n", + "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BMiFcDzE7Qu3" + }, + "source": [ + "## Example: Fitting a linear model\n", + "\n", + "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", + "\n", + "1. Define the model.\n", + "2. Define a loss function.\n", + "3. Obtain training data.\n", + "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", + "\n", + "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "gFzH64Jn9PIm" + }, + "source": [ + "### Define the model\n", + "\n", + "Let's define a simple class to encapsulate the variables and the computation." 
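Before the model class that follows, one detail from the Variables discussion above is worth making concrete: because computations that use a `Variable` are traced automatically, a `tf.GradientTape` needs no explicit `watch()` call for it. That is what lets the `train()` function defined later in this notebook ask the tape for `dW` and `db` directly. A small sketch, separate from the notebook's own cells:

```python
import tensorflow as tf

tf.enable_eager_execution()
tfe = tf.contrib.eager

v = tfe.Variable(3.0)

with tf.GradientTape() as tape:
    # No tape.watch(v) is needed: trainable variables are watched automatically.
    y = v * v

dy_dv = tape.gradient(y, v)
assert dy_dv.numpy() == 6.0  # d(v^2)/dv = 2v = 6.0 at v = 3.0
```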
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "_WRu7Pze7wk8" + }, + "outputs": [], + "source": [ + "class Model(object):\n", + " def __init__(self):\n", + " # Initialize variable to (5.0, 0.0)\n", + " # In practice, these should be initialized to random values.\n", + " self.W = tfe.Variable(5.0)\n", + " self.b = tfe.Variable(0.0)\n", + " \n", + " def __call__(self, x):\n", + " return self.W * x + self.b\n", + " \n", + "model = Model()\n", + "\n", + "assert model(3.0).numpy() == 15.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xa6j_yXa-j79" + }, + "source": [ + "### Define a loss function\n", + "\n", + "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Y0ysUFGY924U" + }, + "outputs": [], + "source": [ + "def loss(predicted_y, desired_y):\n", + " return tf.reduce_mean(tf.square(predicted_y - desired_y))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qutT_fkl_CBc" + }, + "source": [ + "### Obtain training data\n", + "\n", + "Let's synthesize the training data with some noise." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gxPTb-kt_N5m" + }, + "outputs": [], + "source": [ + "TRUE_W = 3.0\n", + "TRUE_b = 2.0\n", + "NUM_EXAMPLES = 1000\n", + "\n", + "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "outputs = inputs * TRUE_W + TRUE_b + noise" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-50nq-wPBsAW" + }, + "source": [ + "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." 
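As a rough sanity check on the number the next cell prints: the model starts at `W = 5.0, b = 0.0`, while the data were generated as `outputs = inputs * 3.0 + 2.0 + noise` with unit-normal `inputs` and `noise`. The expected initial L2 loss is therefore `E[((5x) - (3x + 2 + n))^2] = E[(2x - 2 - n)^2] = 4*E[x^2] + 4 + E[n^2] = 4 + 4 + 1 = 9`, since the cross terms vanish for independent, zero-mean `x` and `n`. The printed starting loss of roughly 9.5 is in that expected range; the exact value varies with the random draw.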
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 293 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1210, + "status": "ok", + "timestamp": 1527005898290, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "_eb83LtrB4nt", + "outputId": "3873f508-72fb-41e7-a7f5-3f513deefe38" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEDCAYAAAA2k7/eAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXlgU1X2xz/pAhRautCWUsCwWVlcUHHGBUFQcSg7uM8P\nFLUICo4VpygObihI3UdmUHBB0IGZQbEgFNGCqKgMolV2pKylCy1pukDp+n5/3LxmaUsDTUjSns8/\nbZKXd09C+b7zvvfccw2apmkIgiAITR4/TwcgCIIgnB9E8AVBEJoJIviCIAjNBBF8QRCEZoIIviAI\nQjNBBF8QBKGZENDYE+Tk5JCUlER+fj7+/v7cdtttTJgwgcLCQhITEzl27BidOnXijTfeICQkxBUx\nC4IgCOeAobF1+Hl5eeTn59OrVy9OnjzJ2LFj+ec//8mnn35KWFgYCQkJLFy4kKKiIh5//HFXxS0I\ngiCcJY22dKKioujVqxcAbdq0oXv37uTm5pKWlsaYMWMAGDNmDF999VVjhxIEQRAagUs9/MzMTPbs\n2cNll13GiRMniIyMBNRFoaCgwJVDCYIgCGeJywT/5MmTPPLII8ycOZM2bdpgMBhcdWpBEATBBbhE\n8CsrK3nkkUcYNWoUN910EwDt2rUjPz8fUD5/REREg+eRtj6CIAjuo9FVOgAzZ86kR48e3HPPPTXP\nDR48mE8//ZRJkyaxcuVKbrzxxgbPYzAYyMsrdkVIbiUqKkTidCESp2vxhTh9IUbwrTidodGCv23b\nNlavXk1cXByjR4/GYDCQmJhIQkICjz76KJ988gmxsbG8+eabjR1KEARBaASNFvwrr7yS3bt31/na\n4sWLG3t6QRAEwUXISltBEIRmggi+IAhCM0EEXxAEoZkggi8IgtBMEMEXBEFoJojgC4IgNBNE8AVB\nEJoJIviCIAjNBBF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCaI4AuCIDQTRPAFQRCa\nCSL4giAIzQQRfEEQhLOk0GTi84R7+XbIDXyecA+FBSZPh+QULtnTVhAEoTnx7YzHuDflUwyAlv4z\nizEwfNFiT4fVIJLhC4IgnCWhhw9hsPxusDz2BVwi+DNnzuTaa69lxIgRNc/Nnz+fAQMGMGbMGMaM\nGcM333zjiqEEQRA8TqHRiGb5XQMKjV08GI3zuMTSGTt2LOPHjycpKcnu+YkTJzJx4kRXDCEIguA1\nXJ/8OosxEHr4EIXGLlyf/JqnQ3IKlwh+v379OHbsWK3nNU2r42hBEATfJjQ8wic8e0fc6uF//PHH\njBo1iqeeeori4mJ3DiUIgiA0gNsE/+677+arr74iJSWFyMhI5s6d666hBEEQXMLRjAwW9u3FWmN7\nFvbtxeGMDE+H5FLcVpYZERFR8/vtt9/O5MmTnXpfVFSIu0JyKRKna5E4XYsvxOmNMb53xQhmZh1T\n5Zalx5h3ww08cfSop8NyGS4TfEe/Pi8vj6ioKAC+/PJL4uLinDpPXp73Wz9RUSESpwuROF2LL8Tp\nTTEWmkx8O+MxQg8fIjory67cMtZk8po4z4SzF0+XCP706dPZsmULZrOZG264gWnTprFlyxZ2796N\nn58fHTt25Pnnn3fFUIIgCC7FdhHVXFSZpcHyM8vGqWgKuETwX3311VrPjRs3zhWnFgRBcCu2i6ju\nBp4JDKR7QACZ4RH839dfezAy1yMrbQVBaNbYLqK6AOgaP4L4w7lMSt+NsXt3T4bmcqSXjiAIzRpf\nXUR1LojgC4LQrPHVRVTnglg6giA0WXy1jbG7kAxfEIQmi6+2MXYXIviCIDQZbGvqC41Ggg9k+GQb\nY3chgi8IQpPBMaOfE9vRrq7eV9oYuwsRfEEQfB49s/dbn2qX0V8QEcHiq/7odAWOyWRmxoyNHD7c\nFqOxkPffHwX4uzv884YIviAIPk2hycS/B1/HJVnH2In9StnK7heelWc/Y8ZGUlLGAwbS0zWmTFnO\n/PnD3RK3JxDBFwTBJzmakUHquOFE5mTTpbqaAcD1wDygQ1AQ1UOGnnVN/eHDbcHmHuHgwWDXBu1h\nRPAFQfBJUscNt3a2BJYDdwG9gRNDhp5TNY7RWEh6uvUeoWvXEhdG7HlE8AVB8BkKTSY2PDqVgB+/\nI8ZstvPrg1HCvz22I3ec42rZ5OTBwFKLh1/EggUjqapyTezegAi+IAhejz4pq23aQBuzmWHAAuz9\n+t8CA8mPH8Edya8RGn5uXS7tu7w3vS1aRfAFQfBq9ElZR/vmbuBZwGgwkN0hlqEr19C5a7dGjdXU\nJ22ltYIgCF7NtzMe4xKL2IPVvrkAuAgwjBzDpPTdjRZ7aPqTtiL4giB4NaGHD1GC1WDR7ZvktqFs\nbH85r2eMJiHhUwoKzI0ey2gstBtJJm0FQRDciF5u2anARGZ4BC169eIBlI3TBsuk7MbNPJ60Sdkv\nuQa279CApSxaNOasx7NdbNWhw0mGDn2P7OxImbQVBEFwF/rEbNba1cysqKjZSPyF6mo+GzWW0MOH\nOGHsUjMp62i/HD7cttZK2eTkwYSHh51xXEffftSopaxffyMAERHes/euKxDBFwTBK9D74HwO9u0R\nCs3E11FT71gzbzQW1RJvZ7L+ui4cTRWXCP7MmTP5+uuvadeuHatXrwagsLCQxMREjh07RqdOnXjj\njTcICXFuZ3VBEJo+u7Zt48vRQ+ladpqDBgNRrVtjAIqxL7fMrKfE0rFmPjl5EHfcsY2zFe+6LhxN\nFZcI/tixYxk/fjxJSUk1zy1cuJBrrrmGhIQEFi5cyDvvvMPjjz/uiuEEQWgCfDkmntllp5XMahpP\nnzyJBsQDy4BC/MhoFcawxcvqfH94eFit7P1cxLuuC0dTxSWC369fP44dO2b3XFpaGh999BEAY8aM\nYfz48SL4giCwa9s20sbG0+10qZ110x14JSwMf8LZZO7HKt6G0+Hs/8dSFi3q69S5z0W867pwNMS5\nzBV4A27
\ngnCWiOB7GV+OiefFsrKaTH4ZkIry6SMHDub6RYs9GZ4gCD6MCL6X0a3stJ1XHwKYgoJYPGQo1ye/\n5sHIBEHwdUTwPYztxGyh0cjuwBZo5dYMvxioHjKU4ZLZC4LQSETwPYxduWX6z/x94CCe+vF7jGVl\nHDcYaHX9QMZIZi8IggsQwfcwoYcP2Vk4nQsLuftonidDEgShiSKbmJ9HCk0mPk+4l2+H3MDnCfdQ\nWGCi0Gi02e4cCo1dPBihIAhNGcnwzyOO9s1iDFyf/DqLMVg8/C4yMSsIgtsQwT+PONo3oYcPERoe\nIROygiCcF8TSOY+IfSMIgieRDN8NOJZaXp/8OqHhEWLfCILgUUTw3UBdXv3wRYvFvhEEwaOIpeMG\n6vLqBUEQPI0IvhsQr14QBG9ELB03IF69IAjeSKMEf926dcyfP5+MjAxWrFhBnz59ADh27Bjx8fF0\n69YNgMsuu4xnn3220cH6CuLVC4LgjTRK8OPi4pg/fz5PP/10rdcuuOACVq5c2ZjTC4IgCC6kUYKv\nZ/CapjVwpCAIguBp3DZpm5mZydixYxk/fjw//fSTu4YRBEEQnKTBDH/ixInk5+fXej4xMZHBgwfX\n+Z7o6Gi+/vprQkND2blzJw8//DBr1qyhTZs2DQYUFRXiRNjnD/OJE6Q+9BDBBw9S3LUr8QsWAN4X\nZ31InK5F4nQdvhAj+E6cztCg4H/wwQdnfdLAwEBCQ0MB6NOnD507d+bQoUM1k7pnIi+v+KzHcyef\nJ0yyLqLaupXFZZVM/OwTr4uzLqKiQiROFyJxug5fiBF8K05ncJmlY+vjm0wmqqurATh69ChHjhyh\nc+fOrhrqvCKLqARBaCo0atL2q6++Yvbs2RQUFDB58mR69uzJu+++y08//cTf//53AgIC8PPz4/nn\nn6dt27auivm8Umg0oqX/XLPloCyiEgTBV2mU4N90003cdNNNtZ4fMmQIQ4YMacypvQZZRCUIQlNB\nVto2gCyiEgShqSC9dARBEJoJzVLw69pbVhAEoanTLC2d+vrVC4IgNGWaZYYvpZaCIDRHmqXgS796\nQRCaI03e0qlrf1kptRQEoTnS5AW/Pr9ePHtBEJobTd7SEb9eEARB0eQFX/x6QRAERZO3dMSvFwRB\nUDR5wZfWCIIgCIomb+kIgiAIChF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCY0SvCT\nk5MZOnQoo0aNYtq0aZSUlNS89s477zBkyBCGDh3Kd9991+hABUEQhMbRKMHv378/a9asISUlBaPR\nyDvvvAPA/v37SU1NZe3atSxatIjnnnsOTdMaOJsgCILgThol+Ndeey1+fuoUffv2JScnB4ANGzYQ\nHx9PQEAAnTp1wmg08ttvvzU+WkEQBOGccZmHv2LFCgYOHAhAbm4uHTp0qHmtffv25ObmumooQRAE\n4RxosLXCxIkTyc/Pr/V8YmIigwcPBmDBggUEBgYyfPhwgDrtG4PBUOs5QRAE4fzRoOB/8MEHZ3x9\n5cqVbNq0iSVLltQ8FxMTQ3Z2ds3jnJwcoqOjnQooKirEqeM8jcTpWiRO1+ILcfpCjOA7cTpDoyyd\nb775hnfffZcFCxbQokWLmucHDx7M2rVrKS8v5+jRoxw5coRLL7200cEKgiAI545Ba0T5zJAhQ6io\nqIMzjrUAAATvSURBVCAsLAyAyy67jGeffRZQZZkrVqwgICCAp556iv79+7skYEEQBOHcaJTgC4Ig\nCL6DrLQVBEFoJojgC4IgNBNE8AVBEJoJXiv47733Hj179sRsNns6lDp58803GTlyJKNHj+b+++8n\nLy/P0yHVyZn6HXkT69atY/jw4fTq1YudO3d6Ohw7vvnmG/70pz9xyy23sHDhQk+HUy8zZ87k2muv\nZcSIEZ4OpV5ycnKYMGEC8fHxjBgxwq6c25soLy/ntttuY/To0YwYMYL58+d7OqR6qa6uZsyYMUye\nPLnhgzUvJDs7W7vvvvu0QYMGaQUFBZ4Op05KSkpqfl+yZIn29NNPezCa+tm8ebNWVVWlaZqmvfzy\ny9orr7zi4YjqJiMjQzt48KA2fvx4bceOHZ4Op4aqqirtpptu0jIzM7Xy8nJt5MiR2v79+z0dVp1s\n3bpV27VrlzZ8+HBPh1Ivx48f13bt2qVpmvo/NGTIEK/9Pk+dOqVpmqZVVlZqt912m/brr796OKK6\n+eCDD7Tp06drDz74YIPHemWGP2fOHJKSkjwdxhlp06ZNze+lpaU1PYW8jfr6HXkb3bp1o0uXLl7X\nZO+3337DaDTSsWNHAgMDGTZsGGlpaZ4Oq0769etH27ZtPR3GGYmKiqJXr16A+j/UvXt3jh8/7uGo\n6iYoKAhQ2X5lZaWHo6mbnJwcNm3axG233ebU8Q2utD3fbNiwgQ4dOnDRRRd5OpQGef3110lJSSEk\nJMRrb01tWbFiBcOGDfN0GD5FXX2htm/f7sGImg6ZmZns2bPHaxdlVldXM3bsWI4cOcKf//xnr4xT\nT46Li4udOt4jgl9ff55HH32Ud955h/fff7/mOU9mfA31EUpMTCQxMZGFCxfy0UcfMW3aNA9EeXb9\njjzp7zoTp7fhbXccTYWTJ0/yyCOPMHPmTLu7ZW/Cz8+Pzz77jJKSEh566CH2799Pjx49PB1WDV9/\n/TWRkZH06tWLLVu2OPUejwh+ff159u3bx7Fjxxg1ahSappGbm8u4ceP473//S7t27c5zlA33EdIZ\nPnw4Dz74oMcE/1z6HXkCZ79PbyImJoasrKyax7m5uU73hRLqprKykkceeYRRo0Zx0003eTqcBgkO\nDuYPf/gD3377rVcJ/s8//8yGDRvYtGkTZWVlnDx5kqSkJJKTk+t9j1cZz3FxcWzevJm0tDQ2bNhA\n+/btWblypUfEviEOHz5c83taWhrdunXzYDT1U1+/I2/Gm7LqSy65hCNHjnDs2DHKy8tZs2YNN954\no6fDqhdv+u7qY+bMmfTo0YN77rnH06HUi8lkqrFJTp8+zQ8//OB1/8cfe+wxvv76a9LS0njttdf4\n4x//eEaxBy/08G0xGAxe+wf86quvcvDgQfz8/IiNjeW5557zdEh18sILL1BRUcF9990H2Pc78ia+\n+uorZs+eTUFBAZMnT6Znz568++67ng4Lf39/Zs2axX333Yemadx66610797d02HVyfTp09myZQtm\ns5kbbriBadOmMW7cOE+HZce2bdtYvXo1cXFxjB49GoPBQGJiIgMGDPB0aHbk5eXxxBNPUF1dTXV1\nNfHx8TX7ffgy0ktHEAShmeBVlo4gCILgPkTwBUEQmgki+IIgCM0EEXxBEIRmggi+IAhCM0EEXxAE\noZkggi8IgtBMEMEXBEFoJvw//5K32R/vBHAAAAAASUVORK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 
0x7f5be3c99f50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Current loss: 9.48636\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.scatter(inputs, outputs, c='b')\n", + "plt.scatter(inputs, model(inputs), c='r')\n", + "plt.show()\n", + "\n", + "print('Current loss: '),\n", + "print(loss(model(inputs), outputs).numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "sSDP-yeq_4jE" + }, + "source": [ + "### Define a training loop\n", + "\n", + "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "MBIACgdnA55X" + }, + "outputs": [], + "source": [ + "def train(model, inputs, outputs, learning_rate):\n", + " with tf.GradientTape() as t:\n", + " current_loss = loss(model(inputs), outputs)\n", + " dW, db = t.gradient(current_loss, [model.W, model.b])\n", + " model.W.assign_sub(learning_rate * dW)\n", + " model.b.assign_sub(learning_rate * db)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RwWPaJryD2aN" + }, + "source": [ + "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." 
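The epoch-by-epoch printout and the final plot that follow come from a driver cell whose source lies outside this hunk; only its outputs are visible here. Below is a sketch of the kind of loop that would produce output in that format. The epoch count of 10 and the learning rate of 0.1 are assumptions inferred from the printed numbers, not values read from the diff.

```python
# Hypothetical driver loop; epochs=10 and learning_rate=0.1 are assumed.
import matplotlib.pyplot as plt  # already imported earlier in the notebook

Ws, bs = [], []
epochs = range(10)
for epoch in epochs:
    Ws.append(model.W.numpy())
    bs.append(model.b.numpy())
    current_loss = loss(model(inputs), outputs)
    print('Epoch %d: W=%.2f b=%.2f, loss=%.5f' %
          (epoch, Ws[-1], bs[-1], current_loss.numpy()))
    train(model, inputs, outputs, learning_rate=0.1)

# Plot how W and b evolve toward their true values.
plt.plot(epochs, Ws, 'r', epochs, bs, 'b')
plt.plot(epochs, [TRUE_W] * len(epochs), 'r--',
         epochs, [TRUE_b] * len(epochs), 'b--')
plt.legend(['W', 'b', 'true W', 'true b'])
plt.show()
```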
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 446 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 569, + "status": "ok", + "timestamp": 1527005915434, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "XdfkR223D9dW", + "outputId": "c43591ae-d5ac-4f2b-a8e7-bfce607e0919" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0: W=5.00 b=0.00, loss=9.48636\n", + "Epoch 1: W=4.58 b=0.42, loss=6.28101\n", + "Epoch 2: W=4.24 b=0.76, loss=4.29357\n", + "Epoch 3: W=3.98 b=1.02, loss=3.06128\n", + "Epoch 4: W=3.78 b=1.23, loss=2.29721\n", + "Epoch 5: W=3.61 b=1.39, loss=1.82345\n", + "Epoch 6: W=3.49 b=1.52, loss=1.52970\n", + "Epoch 7: W=3.38 b=1.62, loss=1.34756\n", + "Epoch 8: W=3.30 b=1.70, loss=1.23463\n", + "Epoch 9: W=3.24 b=1.76, loss=1.16460\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW0AAAEDCAYAAAD+/1UIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VOXdPvD7zJZ9XwmELQkQIAELsiTsi6xiEBGXAiIW\nbV8WBY2K0tLa4lbsr283qxURtIoioAi8SpFNg6whi0FJKAoJBgLZt5k5c87vj5OZLIRkgEnOGXJ/\nritXJsmZyT0sN1+enPOMIMuyDCIicgs6tQMQEZHzWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu\nxODMQePGjYOvry90Oh0MBgM2b97c1rmIiKgZTpW2IAjYuHEjAgIC2joPERG1wKnlEVmWIUlSW2ch\nIqJWCM5cETl+/HgEBARAEATMmTMH9957b3tkIyKiJpxaHvnggw8QFhaG4uJiLFiwAD179sTgwYPb\nOhsRETXh1PJIWFgYACA4OBgTJ05EVlZWi8fL3t6AIADdugFvvglYrTeflIiIWl8eqampgSRJ8PHx\nQXV1NR5++GEsXrwYI0aMuPadCgtRvfoFeL2zDkJtLWxdu6NqRSrMs+8DDE4N9y4XFuaHoqIKVb73\ntTCTc7SYCdBmLmZyjlYzOaPVSfvy5ct44IEHkJKSgjlz5mDcuHEtFzYAREai6oWXUHwkA9WPPApd\n4QX4L/sVgpMGwWPTvwFRdCocERE15tQPIm9Ew3/FdBcK4P3ntfB89x0IVivEmFhUP/kMzCmzAL2+\nLb79VbT6LysztU6LmQBt5mIm52g1kzPa5YpIKaozKl9+DcWHT6Jm7gLof/wB/r98BEGjh8Fj28cA\nTyckInJKu17GLnWJRuXaP6P40AnUPDgP+jN58F+0AEFjhsO0fRvLm4ioFarsPSJ1647KP/0VxWnH\nUTvnAehPf4+AhfMQNG4ETDs/A/hiOkREzVJ1wyipR09U/OV1lHx9FLX3zIH+uxwEPPQAAieMgunz\nXSxvIqImNLHLny0mDhV/fxMlB4+g9u57YMjORMDcOQicNAamPV+wvImI6miitO1scb1Q8fo6lOz/\nBrUzZsJ4Mh0B99+DwKnjYdy7h+VNRNftL395DR999IHj4+XLl2DVqlWOj//61/+HDz/8txrRboim\nStvO1iceFf96B8V702CeNgPG48cQOGcmAu+cBOOBfSxvInJa//6JyM7OAKBsfldWVorc3FzH17Oz\nM5GQMECteNdNk6VtZ+vXH+Vvv4uSPQdhnjwVxiPfIPCeGQhImQpj2ldqxyMiN5CQMBBZWZkAgLNn\nz6Bnzxj4+PigsrISVqsVP/74A+Liequc0nnqXFN+ncSEASjf8AEMJ0/A+9UX4bH7c5hSpsIycjSq\nnloJcdhwtSMSkRN8Vj8Pj+3bXPqY5jtTULX699f8emhoKPR6Ay5duoisrEz075+I6uoyZGdnwsfH\nBzExsTCotL3GjdD0pN2UOPBnKH/vI5Ts2gPL2PEwHdyPoBmTEDD7LhiOHlY7HhFpVGJiIrKyMpCd\nrZT2gAEDkJWVgaws91oaAdxk0m5KHHQ7yjZtheHIYfi8sgam/Xth2r8X5vETUZ26EuJtg9SOSETN\nqFr9+xan4rbSr18isrIy8d//KssjHh4y/vnPf8HX1wfTpt3V7nluhltN2k2JQ4aibPMnKP1kFyzJ\nI+GxZzeCJo2F/8/vhSHzpNrxiEgjEhIGIC3tIPz9/SEIAgICAlBZWYHs7Cz075+gdrzr4talbWcd\nnoyyrTtQuuUzWIYlweOL/0PQhFHwn3c/9HU/gCCijismJhbl5WXo3z+x0ef8/Pzg7+9er33bLrv8\ntStZhvHAPvi8/AcYjx0BAJin3wWPXz+Hom69lRdn0Ait7jTGTM7RYi5mco5WMznjlpi0GxEEWEeP\nRemO3Sj9YAusPxsEj88+AYYMQeCEUfBc/xaEinK1UxIR3ZBbr7TtBAHWcRNQuutLlG7aCsycCUNO\nNvxSn0BIQm/4Ll8CQ/pxXqhDRG7l1i1tO0GAdex4YMsWFJ88hapnV0EKCYHXu+8gaNJYBI4fCc+3\n/wWhvEztpERErbr1S7sBKSIS1U88heKjmSj9YAvM02bAcOpb+D29HCGJveH7xGIYThzj9E1EmtWh\nSttBp4N13ASUv/2uMn2v/DWk0DB4vbcBQZPHIWjcCHiue5PTNxFpTscs7QakiEhUP/4kio9koHTT\nVpinzYD++1Pwe2aFMn0//j8wHD/K6ZuINKHDl7aDTgfr2PHK9J2eg8rnfgMpNBxe/96IoCnjETQ2\nmdM3kZsqLPwJ8+bNUTuGS7C0myFFRKJm2QoUHzmpTN/T74L+9HfK9J3QC77LfgXDsSOcvonciKCh\nazRuBku7Jfbpe91GXEk/hcrnV0MKj4DX++8iaOoEZfp+6w0IZaVqJyWiVoiiiD/8YTXmz78fy5Yt\ng9lsVjvSDbn1roi8BpddASVJMB7YB6+N62Ha9RkEUYT
s5QXzXXejZu5DEAcPcfqqS61elcVMztFi\nLq1nWr3aA9u3u3afujvvFLF6dcsFXFj4E2bPnoF//GMd+vdPwJ/+9CI6dYrGfff93KVZbkbHvSKy\nrel0sI4Zh/K3NuDKye9Q+fxvIYVHwPOD9xA0bSKCxiTB861/cvom0piIiEjH5lAzZsxAZmaGyolu\njFtuzaoVcng4apY+gZrFy2A8uB+eG9fDY+d2+D37FHx/92uYZ8xEzbwF1zV9E93KVq82tzoVt5Wm\na9ru+leSk7Yr6HSwjh6Lin+9o0zfq34HKSISnpv+XTd9D4fnv16HUFqidlKiDquw8Cd8+202AGDH\njh1ITByocqIbw9J2MTk8HDVLHkfxN+ko3fwpau+6G/q8XPitTEVIYm/4LXkMhiOHeeYJUTvr3r0H\ndu36DPPn34+ysjKkpNyjdqQbwh9EtgOhqAiem/4Nz41vw3D2vwAAsU88DAsfxpVREyH16KlKruZo\n/QdZWqLFXMzkHK1mcgYn7XYgh4WhZvEylBw6gdKPt6M25W7oz+QBTz2FkKEDETR6OLxf/gMMWRmc\nwImoRfxBZHvS6WAdORrWkaNReeUKQr/eA/Omj2A6sA8+a1+Gz9qXYYvuCvOUabBMvRPWIcMAN3qV\naCJqe2wElcghIcDChSifcS+EygoYv/wPPHZ+BtPuz+H9xj/g/cY/IAUHwzxpKixTpsMyeizg5aV2\nbCJSGUtbA2RfP1hmzIRlxkzAYoHx64NKgf/fDni9/y683n8Xsrc3LGMnwDx1OiwTJ0EODFI7NhGp\ngKWtNSYTrGPHKy/c8PJaGE4cg8euHTDt3A6PHZ/CY8enkA0GWJNGKgU+ZRqkTlFqpyaidsLS1jKd\nDuLgIRAHD0HV86uhP/09PHZ9BtPO7TAd2AvTgb3AMytg/dkgmKdMh2XqnbDF9VI7NRG1IZ494i4E\nAbbefVD9+JMo/WI/rqTnoOLFV2EZOQaGjJPw/cNvEZw8GEFJg+Dz+9XKHuCSpHZqItVVVlZi69bN\nbfb406dPQGVlJQDgypXLGDnydmRlZTT4+kSUl7vuxcSdLm1JkjBz5kw89thjLvvmdOOkzl1Qu/BR\nlH38Ka7knEH5X/8J89Q7oS/Ih/f/voagKeMRPDAevqlPwLjvS8BiUTsykSoqKsqxdetHzX5NcsFg\n07dvArKzMwEA2dmZ6NWrD7KylI/PnfsRgYFB8Pf3v+nvY+d0aW/YsAExMTEu+8bkOnJQMMz33o/y\n9e/h8qmzKHvnfdTOeQCCuRZe699C4L0pCOkbA79fPgLT9m1A3VRA1BG8/vpfceFCAR5++EH8/e//\ni/T045g3bx5++9vnMX/+fVe9QML777+Lt99+EwBQUJCPFSuW4pFH5mHx4kU4d+7Hqx4/ISHRUdpZ\nWZmYM+dBfPttfYknJCS69Pk4taZdWFiI/fv347HHHsPbb7/t0gDkYt7esEyZBsuUaYAowvhNGky7\nPoPHzs/g+fGH8Pz4Q8geHrCMGQfLlOkw3zEFcmio2qmpAwke1L/Zzxcfz3bJ8U398pdL8MMP/8W6\nde8BANLTjyMrKwsbNnyIyMhIFBb+dM0XSHjllTVITV2Jzp27ICcnG2vXvoQ///kfjY7p3z8R69e/\nBQA4depbPPLIY/joo38DUEo8IWGAUzmd5VRpr1mzBqmpqaio0NZln9QKgwHWEaNgHTEKVb9/GYas\nDOUslF074PH5Lnh8vgu+Oh2sQ4fDMnU6zFOmA2HN/wUhupUkJiYiMjKyxWNqamqQnZ2BVauehn23\nD1EUrzqub99+yM39HrW1tbDZbPD09ERUVGcUFOQjOzsD99/v2j27Wy3tffv2ITQ0FPHx8Th8+LDT\nD+zsdfTtqcNnGj9SeVv7CpCbC3zyCYStW2E6lAbToa/hu+pZICEBYWPGAGPGAKNGARqZwrX4ewdo\nM5fmMzWzxAAAYde68/Ue34TFUg69XufIEBjoDS8vL8fHklQNQajPaDQCOp0JwcHeCAgIwPbtn7by\nHfzQvXs37N//OQYMSEBYmB+GDBmMrKxjKC8vw6Br/E/hRrVa2idOnMCXX36J/fv3w2w2o6qqCqmp\nqXjllVdavJ8WN2NhpgYCI4H5jwLzH4Vw8SI8Pt8Jj53bYUr7CsjKAv7yFwCAGN8X1uHJsCSPhHVY\nMuQwZ/+quI4Wf+8AbeZipqvV1sqoqKh0ZCgtrQZQ31GSZMLly1dw5kwBPD09sXv3HgwbloSaGhkR\nEZ3w4YdbMXbsBABAXl4uYmPjrvoeffr0w7p1b2PhwkdRVFSBbt164YUXViE+vp/Tz93Zf2xbLe3l\ny5dj+fLlAIAjR45g3bp1rRY2uRc5IgK18xagdt4ChPmbUPrFPhi/Pghj2tcwHjsMw6kceK1TfjAj\n9u4D6/BkWJNHwjJ8BOTwcJXTE7XM3z8ACQkDMH/+fRg6NAnDhyc3+rrBYMCCBY9g0aL5iIrqjG7d\nuju+9utfv4A//vElvPPOOthsIsaPv6PZ0k5IGIDNmzehXz/llXF69+6DoqIizJgx0+XP57q2ZrWX\n9uuvv97qsfzXvnVukcligSH9BEyHvlKK/OhhCNXVji+Lcb1gHT4C1uQRsCaNgBTR8jqhSzJphBZz\nMZNztJrJGdxPW0VumclqheHkCRgPfQ3T1wdhOHIYuqr6UwjFmFhYk0Y43lxxib0Wf50AbeZiJudo\nNZMzeBk7XR+jEeLtQyHePhQ1S5crJZ55UllKSTsI4+Fv4LVxPbw2rgcA2Lr3UNbD65ZUpM5d1M1P\n5OZY2nRzjEaIg26HOOh21Cx5HBBFGLIylBI/9BWMh9Lg9d4GeL23AQBg69odluQR9SUe3VXlJ0Dk\nXlja5FoGA8TbBkG8bRBq/mcpYLPB8G0WjF9/VV/iddvNAoAtuiusSSNgsS+ndO3mvi+TTdQOWNrU\ntvR6iIkDISYORM0vFwM2G/Q538KUdtAxjXtu+jc8NylXkNk6d3Gsh1uSRkDq3kPlJ0CkLSxtal96\nPWwJiahJSETNo/8DSBL0p3Ial/hHH8Dzow8AALZOUcCY0fDqkwAxcQDEhETI/gEqPwki9bC0SV06\nHWz9+qOmX3/U/OKXSol//x2MaV/BlKYsqeD99+GL9x13EXv0VKb3hAF1RT5Aefk2og6ApU3aotPB\nFt8Xtvi+qF24CJBlhJUWonx/GgyZGcpb1kl4frIF+GSL4262LtH1JZ44AGLiwDY5Z5zcT2VlJXbv\n/j/MnHlPm32PNWt+i+TkkRg9elybfQ87ljZpmyAAvXrBHNQJ5pRZyudkGbr8844CN2RmwJhxEh67\nPoPHrs8cd7WFR9SXeMJAiIkDIHWJ5g86Oxj7ftpNS1uSJOh07vc6MCxtcj+CACm6KyzRXWGZdqfj\n07qLhTBknmwwkWfA4z9fwOM/XziOkYKCHAVuf7N17wm44V9edzVokE+znz9+vMolxzfVcD9tvV4P\nLy9vREVF4t
tvc/Dqq39Gaurj2LBhEwBlL+3a2hosWPALFBTk47XXXkFZWSk8PT2Rmvocunbtds3v\nc/ToYXz44fsoKSnG4sVPIClphFP5rhdLm24ZUkQkLBMnwzJxsuNzwpUrMGTVl7gh82T962va7+fr\nBzEh0bE+LiYOhC02DjDwr8etoOF+2unpx5Ga+gTWrn0VRqPfTe+l3VBh4U/429/eRH7+eSxd+hg2\nbdoGo9Ho8ufDP5V0S5NDQmAdMw7WMfVrjUJ5GQzZWfVTeVYGjIcPwXTo6/r7eXlB7NvfsT4uJg6A\n2DseMJnUeBq3FGcn5Bs9vjV9+/ZDVFRUi5exO7uXdkPjxk0EAHTpEo2oqM748ccfmt1c6maxtKnD\nkf0DHOeCO1RVwZCT3WAiz4AhIx3G40fr72c0QozvpxR4/0Rg2CAIIZ2VnQ65Tu42PD09Hbf1ej1s\ntvrXibRYzAAAWZbg5+fveLUbZzSd2K81wd8sljYRAPj4OPZUcTCbYfgup9FZK4Zvs2HMPOk4JBSA\n5B8AW2wsbDFxsMX1glj33tajJ+Dh0f7PhRrx9vZGdd3OlE33xwsKCkZpaQnKy8vh6emJtLSvMGxY\nEry9fdCpUxT27v1Pq3tp2+3d+x9MnjwNFy4U4MKFghbXv28GS5voWjw8IA64DeKA2+o/Z7VCn3sa\nhqwM+F/4EeaMbOjP5MKQlQnjieON7i7rdJC6doMYG+codFtsHMTYXsqLSXA6bxcN99M2mTwQHBzs\n+Jor9tK2i47uhsWLF6GkpBhPPbWyTdazAW7Nqipmco4WMwFNcokidOd+hOFMLvS5udCfyVXKPS8X\nustFV93XMZ3H1he5LTbupqdzLf5aMZNzuDUrUXsyGCD1jIGlZwzQ4OwVABBKS6DPy4U+LxeGuvf6\nvNOtT+f2Iq9bcuF0TgBLm6jNyYFBEAcPgTh4CMwNvyCK0J/7oa7E86DPO11X7KeVc8sbnF8O1E3n\nccpSixjXS1lyccF0Ts7bsGEd9u79DwRBgCzLEAQBY8dOwNy5C9otA5dHVMRMztFiJqBtc101neee\nVpZczv4XgtXa6FjHdB7XCx59eqEyJBK26GhIXaJh69IVcmioqhO6Fn//tJrJGZy0iTTIqem8bu3c\nUFfoHrs/B3Z/Dt+mj+XlBVvnLkqJR3eFFN0VtrpCl6KjIUV2AvT6dnx2dDNY2kTuxGCArWcsbD1j\ngTumNPqSUFKM0MorKMv8Dvr8c9Dln4f+/Pm69z/CkJfb7EPKBgOkqM6wdbFP59GOYpeio2HrHM3l\nFw1haRPdIuSgYKBXN1iir3FaWmUl9PnnlUI/fx76/PPQ5Z9zFLvx0NcQrrFaaguPUAo8uiukLg0K\nvW5al32d+6893TyWNlFH4esLW5942PrEN/91sxm6CwV1ZX4e+vPn6m+fOwdDxkkYjx9r9q5SYKBS\n4F2i69bT64sdiX0A2YNLMC7C0iYihYcHpB49IfXo2fzXbTboLhbWTen1yy/224b/5kHIzmz2rqF6\nPaTQMEjhEZAiIhq/D49s9DG8vdvwSbo/ljYROUevhxTVGVJUZ4hDh139dVmGUFzcYPlFKXPv4iKI\n5wuUrXPP5ELIymjx20h+/pDCwyFFRCrvHcVu/1wEpIhIyMHBHXJLXZY2EbmGIEAOCYEYEgI0uPTf\nO8wPpQ1OrxMqK6C7dBG6ixfr3hdCd+lS3fv6z+v/e+aaa+xA3Q9Qw8KbTO0RjlJvWPJosEmUu2Np\nE1G7kn39YPP1U86AaYnVCt2Vy1eVeePCvwjD96cgZKS3+FBSQGD91B4RAXTtAm9PX0jBIZCCgyEH\nBUMKCoYcEgIpKFjTJc/SJiJtMhohRXZSziNviSxDqChvMq03md7r3gy5px13a/71cOoe0ttbKfSg\nukIPaVDswcH1X6u7LQcHQ/bxbZeLmFjaROTeBAGyfwBs/gHKKw61xGKB7nIRQsQqlJ45D11JMYSS\nYuiuXGl0Wygpga6kGIYzeRCqnXsRBtlobDSty0H1hS4FBSsTfXDj4pcDAq97XZ6lTUQdh8kEKaoz\nEOYHa9dezt3HbFYKvbgYuuIrSrHbbxcX15e9/eOfLsBwKseph5Z1OsiBgZCCQ4AG/wtoCUubiKgl\nHh7KEk1kJ9icvY8oQigtVQq9boq/ZvHX3XYWS5uIyNUMBsihobCFhgJOvkxkmJMP3fFOciQicmMs\nbSIiN8LSJiJyIyxtIiI30uoPIi0WCx588EFYrVbYbDZMmjQJixcvbo9sRETURKulbTKZsGHDBnh5\necFms+H+++/HqFGjkJiY2B75iIioAaeWR7y8vAAoU7coim0aiIiIrs2p0pYkCSkpKUhOTkZycjKn\nbCIilTh1cY1Op8O2bdtQWVmJX/3qV8jLy0NsbAs7dHXvjmDp6i0Vi49nN3t48KD+zX7epcfrhKsy\nqZoHuCqT6nmaZNJEngaZNJPH7tyPmsrD42+N41tzXVdE+vr6YsiQITh48GDLpQ1Ar7t6t6trvkR8\nM8e2xfFNM6mdp2kmLeRpmEkreeyZtJSnxfuolMd+/FX3UznPVffVQJ5GH2skj7MEWW5hl3EAxcXF\nMBqN8PPzQ21tLRYuXIhFixZh9OjRLT5wUYNNz7UgLMyPmZzATM7TYi5mco5WMzmj1Um7qKgIzzzz\nDCRJgiRJmDp1aquFTUREbaPV0u7duze2bt3aHlmIiKgVvCKSiMiNsLSJiNwIS5uIyI2wtImI3AhL\nm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uI\nyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiN\nsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0\niYjciKG1AwoLC5GamorLly9Dr9dj9uzZmDdvXntkIyKiJlotbb1ej2effRbx8fGoqqrC3XffjeTk\nZMTExLRHPiIiaqDV5ZGwsDDEx8cDAHx8fBATE4NLly61eTAiIrrada1p5+fn47vvvkNiYmJb5SEi\noha0ujxiV1VVhaVLl2LlypXw8fFp8dju3QFJuvqY48ermj1+0KDmH8+Vx+t0V2dSMw+AqzKpnadp\nJi3kaZhJK3nszp1r9tOq5eHxt8bxrXGqtEVRxNKlS3HXXXdhwoQJTj2wTnf1EB8W5neNY5t/DFcf\n3zST2nmaZtJCnoaZtJLHnklLeVq6j1p57Mc3vZ/aeZre1kKehh9rJY+zBFmW5dYOSk1NRVBQEJ59\n9lmnH7ioqOKGArWVsDA/ZnICMzlPi7mYyTlazeSMVte0jx8/ju3bt+Obb75BSkoKZs6ciQMHDtx0\nQCIiun6tLo8MGjQIp06dao8sRETUCl4RSUTkRljaRERuhKVNRORGWNp
ERG6EpU1E5EacviKSiIiu\nnyQBZWVASYmA4uL6t5ISwfG5khIBn37q3OOxtImInGSxoFHRNizgpkWsfAyUlgqQJMFlGVjaRNTh\nyDJQWYmrCvdaU7D981VVzpWvXi8jKEhGaKiMuDgJQUEygoOVt6Ag1L2XHe+DgmQAvk49NkubiG4Z\nNTVAUZGAixcFXLqkw6VLyu2iIuVj5fMCLl8GLBbnLhv39lZKtUcPqVHR1pdw4/INCZHh5wcIrhuu\nG2FpE5GmSZIyEV+6JDhK2F7IDd8uXtShvLzlpvTwkBERIWPgQMDfX2yxfO23vbza6Yk6iaVNRKqo\nqUGjwm1cwvVTcVGRAFFsuYxDQiR07izhtttkhIfLiIiQEB5uvy3X3Zbg769MwMqGUTXt9Exdi6VN\nRC4likBhoYD8fB3OnxdQUQGcPevRYEpWStn5qVhCeLjUqIAblnJYmAyjsZ2enAawtInoutTUABcu\nCDh/Xof8fB3y8+23laK+cEGAzda0kE2OW9eaiusnYuVzbbku7M5Y2kTUSFkZGpVw49sCLl9u/po8\nQZARGSnjZz+T0KWL/U1GfLwnPD2rEBGhnE3RkabitsDSJupAZFlZR25Ywsq0XH+7oqL58dZolNG5\ns4z4eBFdusjo0kVCdLTkuB0VJcNkuvp+YWGeKCqS2viZdRwsbaJbiNUKnDvXtJDrlzIKCgSYzc2X\nso+P3KiEu3SxfywhOlpZtmjppdeofbC0idyMLAM//SQgL0+H3FwdzpzRIS9PeV9QAEhS8xdphIZK\niI9X1pPrC7m+mAMDuYbsDljaRBpVXQ2cOaOUccNyzsvTobr66naNiJCQlARERFgbTczR0TI6d5bg\n7a3CkyCXY2kTqcg+Nefm1heyfWrOz796LcLTU0bPnhJiYxu/xcQoZ1so5x/XqvBMqL2wtInagX1q\nbljK9um5uak5MlLCyJEiYmIal3OXLlxX7uhY2kQuIsvK+csNJ2b7W0HBtafmuDjJUc72277O7R1E\nHRBLm+g6WSzA99/rcOkScOKEqdH03NzU3KmTMjU3XMqIi5PQuTOnZrp+LG2iFtTUADk5OmRm6pGV\npbw/dUoHq9Vezh4AAC+va681c2omV2JpE9WprASys/XIzKwv6dOndY0uyfbwkJGQIKF/fxsGDzYh\nIqIasbGcmqn9sLSpQyopAbKylIJW3utx5kzj1vX2ljF4sA2JiRISEpT3cXGS4zLssDATiopsKqSn\njoylTbe8S5cEx9KGvaTPnWtc0AEBMkaOFJGQICEx0YbERBt69uT0TNrD0qZbhv3sjYblnJmpQ2Fh\n4+YNDZUwbpyIxESbo6S7dpV5NSC5BZY2uSVZBn74QXAUs30N+sqVxgUdFSVh8mRrgwlaQmQkC5rc\nF0ubNM9mA06f1jUq56ws/VWb6HfrJiEpyepYg05IkBAWJquUmqhtsLRJc8xmID1dj6+/1iMtTY/j\nx4Hqah/H1wVBeYXrCRPqp+f+/W0IDFQxNFE7YWmT6mprgRMnlIJOS9Pj2DE9amvrp+h+/YCEBKtj\nDbpfPxvPfaYOi6VN7a6mBjh+vL6kjx/XO/Z4FgQZfftKSE62YfhwG4YPF9G7NzdBIrJjaVObq64G\njh2rL+lQSYIFAAANpklEQVQTJ/SwWOpLun9/CUlJNiQl2TBsmIigIJUDE2kYS5tcrqrq6pK2X/at\n09WXdHKyiKFDuRZNdD1Y2nTTKiuBo0ftJW1AeroOolhf0omJ9klaKemAAJUDE7kxljZdt8pK4MgR\n+9kdBmRk1Je0Xi9jwAAJw4crk/SQITb4+6scmOgW0mppr1y5Evv27UNISAi2b9/eHplIYyoqgMOH\n6yfpjIz6TZT0ehkDB0pIShKRnGzDkCE8s4OoLbVa2nfffTfmzp2L1NTU9shDGlBeDnzzjVLQaWnK\nFYeSpJS0wSDjttskJCeLGD6cJU3U3lot7cGDB6OgoKA9spBKZBk4eVKHXbsMOHgQSE/3dZS00ajs\ndGc/u+P2223w8WnlAYmozXBNu4OyWoFDh/TYtcuAXbsMuHBB2bPDaARuv92G5GSlpAcPtvFVvIk0\npM1KOyzMr60e+oZ19EzV1cAXXwBbtwLbtyt7SgNAYCAwdy6QkgJMmgT4+BigtX/Ptfh7B2gzFzM5\nR4uZnNFmfzOLiira6qFvSFiYX4fMVFICfPGFATt3GrBvnwE1NcqyR2SkhAULREydKiIpyebY2N/H\np2P+Ot0ILeZiJudoNZMznCptWeZOae7kwgUBu3YpRZ2Wpnec6REba8PUqUpRDxwocYN/IjfUammv\nWLEChw8fRmlpKcaMGYMlS5Zg1qxZ7ZGNrkNurg47dypFnZ6ud3z+ttuUop4yRUSvXpKKCYnIFVot\n7bVr17ZHDrpOkqSc8WEv6rw8paj1euVls+xFHRXF/yUR3Uq09dMmapHVCqSl1Z/x8dNPyvqGl5eM\nKVOsmDpVxB13cMMlolsZS1vjqquBvXuVaXr3bgNKS5X16cBAGffeqxT1mDEiT8sj6iBY2hpUUgJ8\n/rlS1Pv315/xERUlYdYspaiHDas/44OIOg6WtkYUFAiOZY+GZ3z06lX/g8SBAyW+IC1RB8fSVtGp\nU8C775qwc6cBJ0/Wn/Hxs5/ZT82zIjaWP0gkonos7XZWWgp89JER775rxKlTAOABg0HGqFH1Z3x0\n6sSiJqLmsbTbgSwDR4/qsGGDCZ9+akBtrQCjUcbMmcCECTWYOFHkq7cQkVNY2m2orEyZqjduNOLU\nKWX5o0cPCXPnmjFnjoi+fX1RVCSqnJKI3AlL28VkGTh2TIeNG0345BPlzA+jUcZdd1kxd64VI0bY\nePk4Ed0wlraLlJUBmzcbsWFD/VTdrZuEuXMtuP9+K8LCuE5NRDePpX0TZBk4cUJZq962TZmqDQYZ\nM2YoU/XIkZyqici1WNo3oLy8fqrOyWk8Vd93nxXh4ZyqiahtsLSdJMtAeroOGzYYsW2bEdXVylQ9\nfboV8+ZZMWoUp2oianss7VZUVChT9caNRmRnK1N11671U3VEBKdqImo/LO1m2F/oduNGI7ZsUaZq\nvV7GtGnKVD16NKdqIlIHS7uBysr6qTorq36q/vnPlTNAOFUTkdpY2gAyMpS16o8/rp+qp05Vpuox\nYzhVE5F2dNjSrqwEtmxRzgDJzFSm6i5dJCxdasEDD1gRGcmpmoi0p8OVdmamDu+8o6xVV1UpU/Xk\nyVbMn69M1Xp9649BRKSWDlHalZXAtm3A3//u7dgCtXNnCYsXK1M1d9UjIndxS5d2eTnwxhsmvP66\nCeXlgE6nw+TJylr12LGcqonI/dySpV1ZCbz1lgl/+5sJpaUCQkIkrF4tICWliq9OTkRu7ZYq7epq\nYN06I/72NxOuXNEhMFDGc8+ZsXChBT16+KGoiIVNRO7tlijt2lpgwwYj/vxnE4qKdPD3l5Gaasai\nRRb4+6udjojIddy6tM1m4N
13lbIuLNTBx0fG8uVmPPaYha8EQ0S3JLcsbasVeP99I/70JxMKCnTw\n9paxZIkZv/qVFSEhXAIholuXW5W2KAIffWTA2rUeOHdOB09PGY89ZsGSJRa+yAARdQhuUdo2G7Bl\niwF//KMHzp7VwWSS8cgjFixbZuF+IETUoWi6tCUJ+PRTA1591YTcXD2MRhkPPWTB449beOoeEXVI\nmixtSQJ27lTK+tQpPfR6GT//uVLWXbuyrImo49JUacsy8MUXerz8sgeys/XQ6WTMmWPF8uVm9OjB\nsiYi0kRpyzKwd69S1unpegiCjLvvtuLJJ82IjWVZExHZqVrasgwcPKiU9dGjykYgM2ZY8eSTFvTp\nI6kZjYhIk1Qr7UOH9HjpJRMOHVIiTJlixVNPWdC/P8uaiOha2r20jx7V4aWXPHDwoPKtJ04UkZpq\nxoABLGsiotY49UJaBw4cwOTJkzFp0iS88cYbN/SNTpzQ4b77vDBtmg8OHjRgzBgRu3ZV4b33aljY\nREROanXSliQJL7zwAtavX4/w8HDcc889GD9+PGJiYpz6BllZOrzyigc+/1z5ViNGiEhNtWDYMNvN\nJSci6oBaLe3MzEx069YNnTt3BgBMmzYNe/bsabW0c3J0ePVVE3bsMAIAhg4V8fTTFowYwbImIrpR\nrZb2xYsX0alTJ8fHERERyMrKavE+990HfPihN2RZwKBBNjz9tBmjR9sgCDcfmIioI2u1tGX5+s+T\n3rQJGDBAwtNPmzF+PMuaiMhVWi3tyMhIXLhwwfHxxYsXER4e3uJ9lJ7XA/C+yXiuFRbmp3aEqzCT\nc7SYCdBmLmZyjhYzOaPVs0cSEhJw7tw5FBQUwGKxYMeOHRg/fnx7ZCMioiZanbT1ej1WrVqFhx9+\nGLIs45577nH6zBEiInItQb6RRWsiIlKFUxfXEBGRNrC0iYjcCEubiMiNuHTDqAMHDmDNmjWQZRmz\nZs3CokWLXPnwN2TlypXYt28fQkJCsH37drXjAAAKCwuRmpqKy5cvQ6/XY/bs2Zg3b56qmSwWCx58\n8EFYrVbYbDZMmjQJixcvVjWTnSRJmDVrFiIiIvD666+rHQfjxo2Dr68vdDodDAYDNm/erHYkVFRU\n4LnnnkNubi50Oh3WrFmDAQMGqJrp7NmzeOKJJyAIAmRZxvnz57Fs2TLV/6yvX78emzdvhiAI6NWr\nF1588UWYTCZVM73zzjuOP0et9oHsIjabTZ4wYYKcn58vWywWecaMGXJeXp6rHv6GHT16VM7JyZGn\nT5+udhSHS5cuyTk5ObIsy3JlZaV8xx13aOLXqrq6WpZlWRZFUZ49e7ackZGhciLF22+/La9YsUJ+\n9NFH1Y4iy7Isjxs3Ti4tLVU7RiNPP/20vHnzZlmWZdlqtcoVFRUqJ2rMZrPJycnJ8oULF1TNUVhY\nKI8bN042m82yLMvysmXL5K1bt6qa6fTp0/L06dNls9ksi6IoP/TQQ/KPP/54zeNdtjzScI8So9Ho\n2KNEbYMHD4a/v7/aMRoJCwtDfHw8AMDHxwcxMTG4dOmSyqkALy8vAMrULYqiymkUhYWF2L9/P2bP\nnq12FAdZliFJ2tmZsrKyEseOHcOsWbMAAAaDAb6+viqnaiwtLQ1du3ZttCWGWiRJQk1NDURRRG1t\nbasXC7a1M2fOYODAgTCZTNDr9bj99tuxe/fuax7vstJubo8SLRSR1uXn5+O7775DYmKi2lEgSRJS\nUlKQnJyM5ORkTWRas2YNUlNTIWhoLwRBELBw4ULMmjULH374odpxkJ+fj6CgIDz77LOYOXMmVq1a\nhdraWrVjNbJz505MmzZN7RiIiIjAggULMGbMGIwaNQp+fn5ISkpSNVNcXByOHj2KsrIy1NTU4MCB\nA/jpp5+uebzLSlvm6d7XraqqCkuXLsXKlSvh4+OjdhzodDps27YNBw4cQEZGBvLy8lTNs2/fPoSG\nhiI+Pl5Tf74++OADbNmyBW+++Sbee+89HDt2TNU8oigiJycHDzzwALZu3QpPT88b3ve+LVitVnz5\n5ZeYMmWK2lFQXl6OPXv2YO/evTh48CCqq6tV/1lXTEwMfvGLX2DBggVYtGgR+vTpA4Ph2j9udFlp\n38geJR2ZKIpYunQp7rrrLkyYMEHtOI34+vpiyJAhOHjwoKo5Tpw4gS+//BLjx4/HihUrcPjwYaSm\npqqaCVCWtwAgODgYEydObHXXy7YWGRmJyMhIJCQkAAAmTZqEnJwcVTM1dODAAfTr1w/BwcFqR0Fa\nWhqio6MRGBgIvV6PiRMnIj09Xe1YmDVrFrZs2YKNGzciICAA3bp1u+axLittLe9RoqUpzW7lypWI\njY3F/Pnz1Y4CACguLkZFRQUAoLa2FocOHULPnj1VzbR8+XLs27cPe/bswWuvvYahQ4filVdeUTVT\nTU0NqqqqAADV1dX46quvEBcXp2qm0NBQdOrUCWfPngUAfPPNN5raamLHjh2YPn262jEAAFFRUcjI\nyIDZbIYsy5r5tSouLgYAXLhwAbt3727x18tlp/xpdY8S+4RWWlqKMWPGYMmSJY4f2Kjl+PHj2L59\nO3r16oWUlBQIgoAnnngCo0aNUi1TUVERnnnmGUiSBEmSMHXqVIwePVq1PFp1+fJlLF68GIIgwGaz\n4c4778SIESPUjoXnn38eTz75JERRRHR0NF588UW1IwFQBoC0tDT87ne/UzsKACAxMRGTJk1CSkoK\nDAYD+vbti3vvvVftWFiyZAnKyspgMBjwm9/8Bn5+196BkHuPEBG5EV4RSUTkRljaRERuhKVNRORG\nWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu5P8D+7Wym3BFpegAAAAASUVORK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f5be4b8ec50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "model = Model()\n", + "\n", + "# Collect the history of W-values and b-values to plot later\n", + "Ws, bs = [], []\n", + "epochs = range(10)\n", + "for epoch in epochs:\n", + " Ws.append(model.W.numpy())\n", + " bs.append(model.b.numpy())\n", + " current_loss = loss(model(inputs), outputs)\n", + "\n", + " train(model, inputs, outputs, learning_rate=0.1)\n", + " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", + " (epoch, Ws[-1], 
bs[-1], current_loss))\n", + "\n", + "# Let's plot it all\n", + "plt.plot(epochs, Ws, 'r',\n", + " epochs, bs, 'b')\n", + "plt.plot([TRUE_W] * len(epochs), 'r--',\n", + " [TRUE_b] * len(epochs), 'b--')\n", + "plt.legend(['W', 'b', 'true W', 'true_b'])\n", + "plt.show()\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vPnIVuaSJwWz" + }, + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", + "\n", + "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", + "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", + "\n", + "The [next tutorial](TODO) will cover these higher level APIs." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Training Models", + "provenance": [], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb new file mode 100644 index 00000000000000..4fe3a0e3f3d431 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb @@ -0,0 +1,551 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "pwX7Fii1rwsJ" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UEu3q4jmpKVT" + }, + "source": [ + "# High level API\n", + "\n", + "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zSFfVVjkrrsI" + }, + "source": [ + "## Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "8PyXlPl-4TzQ" + }, + "outputs": [], + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. 
Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "# The number of input dimensionss is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Fn69xxPO5Psr" + }, + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 204 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1527783641557, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "E3XKNknP5Mhb", + "outputId": "c5d52434-d980-4488-efa7-5660819d0207" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=30, shape=(10, 10), dtype=float32, numpy=\n", + "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\u003e" + ] + }, + "execution_count": 3, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# To use a layer, simply call it.\n", + "layer(tf.zeros([10, 5]))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1527783642457, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "Wt_Nsv-L5t2s", + "outputId": "f0d96dce-0128-4080-bfe2-0ee6fbc0ad90" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 
0., 0.], dtype=float32)\u003e]" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Layers have many useful methods. For example, you can inspect all variables\n", + "# in a layer by calling layer.variables. In this case a fully-connected layer\n", + "# will have variables for weights and biases.\n", + "layer.variables" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 226, + "status": "ok", + "timestamp": 1527783643252, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "6ilvKjz8_4MQ", + "outputId": "f647fced-c2d7-41a3-c237-242036784665" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e)" + ] + }, + "execution_count": 5, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# The variables are also accessible through nice accessors\n", + "layer.kernel, layer.bias" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "O0kDbE54-5VS" + }, + "source": [ + "## Implementing custom layers\n", + "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", + " * `__init__` , where you can do all input-independent initialization\n", + " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", + " * `call`, where you do the forward computation\n", + "\n", + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes requires to create the variables will need to be explicitly specified." 
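As a side-by-side illustration of the trade-off described just above (creating a layer's variables in `__init__` versus deferring them to `build`), here is a minimal sketch. It is not part of the notebook being added by this patch: the class names are invented, and it assumes the same eager-execution setup and `add_variable` API used throughout these example notebooks.

```python
import tensorflow as tf

tf.enable_eager_execution()


class EagerKernelDense(tf.keras.layers.Layer):
  """Creates its kernel in __init__, so the input size must be given up front."""

  def __init__(self, num_inputs, num_outputs):
    super(EagerKernelDense, self).__init__()
    self.kernel = self.add_variable("kernel", shape=[num_inputs, num_outputs])

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)


class LazyKernelDense(tf.keras.layers.Layer):
  """Defers kernel creation to build, so the input size is inferred on first use."""

  def __init__(self, num_outputs):
    super(LazyKernelDense, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    # input_shape is only known here, at the first call.
    self.kernel = self.add_variable(
        "kernel", shape=[input_shape[-1].value, self.num_outputs])

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)


eager_layer = EagerKernelDense(5, 10)   # input size fixed at construction time
lazy_layer = LazyKernelDense(10)        # input size discovered from the data
print(eager_layer(tf.zeros([2, 5])).shape)  # (2, 10)
print(lazy_layer(tf.zeros([2, 5])).shape)   # (2, 10); build() ran automatically
```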
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 391 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 251, + "status": "ok", + "timestamp": 1527783661512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "5Byl3n1k5kIy", + "outputId": "6e7f9285-649a-4132-82ce-73ea92f15862" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)\n", + "[\u003ctf.Variable 'my_dense_layer_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + "array([[-0.4011991 , 0.22458655, -0.33237562, -0.25117266, 0.33528614,\n", + " -0.01392961, 0.58580834, -0.16346583, 0.28465688, -0.47191954],\n", + " [-0.52922136, 0.22416979, -0.58209574, -0.60914612, 0.05226624,\n", + " -0.18325993, 0.5591442 , -0.24718609, 0.37148207, 0.40475875],\n", + " [ 0.16912812, -0.47618777, -0.38989353, 0.30105609, -0.08085585,\n", + " 0.44758242, 0.545829 , 0.51421839, 0.11063248, 0.20159996],\n", + " [ 0.34073615, -0.59835428, 0.06498981, -0.44489855, -0.34302285,\n", + " 0.20969599, 0.35527444, -0.03173476, -0.22227573, 0.09303057],\n", + " [ 0.41764337, -0.06435019, -0.52509922, -0.39957345, 0.56811184,\n", + " 0.23481232, -0.61666459, 0.31144124, -0.11532354, -0.42421889]], dtype=float32)\u003e]\n" + ] + } + ], + "source": [ + "class MyDenseLayer(tf.keras.layers.Layer):\n", + " def __init__(self, num_outputs):\n", + " super(MyDenseLayer, self).__init__()\n", + " self.num_outputs = num_outputs\n", + " \n", + " def build(self, input_shape):\n", + " self.kernel = self.add_variable(\"kernel\", \n", + " shape=[input_shape[-1].value, \n", + " self.num_outputs])\n", + " \n", + " def call(self, input):\n", + " return tf.matmul(input, self.kernel)\n", + " \n", + "layer = MyDenseLayer(10)\n", + "print(layer(tf.zeros([10, 5])))\n", + "print(layer.variables)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tk8E2vY0-z4Z" + }, + "source": [ + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", + "\n", + "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Qhg4KlbKrs3G" + }, + "source": [ + "## Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. 
Implementing one is done by inheriting from tf.keras.Model." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 190 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 420, + "status": "ok", + "timestamp": 1527783698512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "N30DTXiRASlb", + "outputId": "a8b23a8e-5cf9-4bbf-f93b-6c763d74e2b3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]\n", + "\n", + " [[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", + "['resnet_identity_block_1/conv2d_3/kernel:0', 'resnet_identity_block_1/conv2d_3/bias:0', 'resnet_identity_block_1/batch_normalization_3/gamma:0', 'resnet_identity_block_1/batch_normalization_3/beta:0', 'resnet_identity_block_1/conv2d_4/kernel:0', 'resnet_identity_block_1/conv2d_4/bias:0', 'resnet_identity_block_1/batch_normalization_4/gamma:0', 'resnet_identity_block_1/batch_normalization_4/beta:0', 'resnet_identity_block_1/conv2d_5/kernel:0', 'resnet_identity_block_1/conv2d_5/bias:0', 'resnet_identity_block_1/batch_normalization_5/gamma:0', 'resnet_identity_block_1/batch_normalization_5/beta:0', 'resnet_identity_block_1/batch_normalization_3/moving_mean:0', 'resnet_identity_block_1/batch_normalization_3/moving_variance:0', 'resnet_identity_block_1/batch_normalization_4/moving_mean:0', 'resnet_identity_block_1/batch_normalization_4/moving_variance:0', 'resnet_identity_block_1/batch_normalization_5/moving_mean:0', 'resnet_identity_block_1/batch_normalization_5/moving_variance:0']\n" + ] + } + ], + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + "\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wYfucVw65PMj" + }, + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. 
This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "base_uri": "https://localhost:8080/", + "height": 153 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 361, + "status": "ok", + "timestamp": 1526674830777, + "user": { + "displayName": "Alexandre Passos", + "photoUrl": "//lh4.googleusercontent.com/-kmTTWXEgAPw/AAAAAAAAAAI/AAAAAAAAAC0/q_DoOzKGwds/s50-c-k-no/photo.jpg", + "userId": "108023195365833072773" + }, + "user_tz": 420 + }, + "id": "L9frk7Ur4uvJ", + "outputId": "882e9076-b6d9-4380-bb1e-7c6b57d54c39" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=1423, shape=(1, 2, 3, 3), dtype=float32, numpy=\n", + "array([[[[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]],\n", + "\n", + " [[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]]]], dtype=float32)\u003e" + ] + }, + "execution_count": 26, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "c5YwYcnuK-wc" + }, + "source": [ + "# Next steps\n", + "\n", + "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "4 - High level API - TensorFlow Eager.ipynb", + "provenance": [], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 8517a3bf7b6aeb..b14ef1df8ff4c6 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -36,9 +36,7 @@ def device_and_data_format(): 'channels_last') -def random_batch(batch_size, device_and_format=None): - _, data_format = device_and_format or device_and_data_format() - +def random_batch(batch_size, data_format): shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3) shape = (batch_size,) + shape @@ -51,15 +49,17 @@ def random_batch(batch_size, device_and_format=None): return images, one_hot -def train_one_step(model, images, labels, optimizer): - - with tfe.GradientTape() as tape: +def compute_gradients(model, images, labels): + with tf.GradientTape() as tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - grads = tape.gradient(loss, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) + return tape.gradient(loss, model.variables) + + +def apply_gradients(model, optimizer, gradients): + optimizer.apply_gradients(zip(gradients, model.variables)) class ResNet50Test(tf.test.TestCase): @@ -70,7 +70,7 @@ def _apply(self, defun=False, 
execution_mode=None): if defun: model.call = tfe.defun(model.call) with tf.device(device), tfe.execution_mode(execution_mode): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) tfe.async_wait() self.assertEqual((2, 1000), output.shape) @@ -91,7 +91,7 @@ def test_apply_no_top(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) with tf.device(device): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) output_shape = ((2, 2048, 1, 1) if data_format == 'channels_first' else (2, 1, 1, 2048)) @@ -101,7 +101,7 @@ def test_apply_with_pooling(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') with tf.device(device): - images, _ = random_batch(2) + images, _ = random_batch(2, data_format) output = model(images, training=False) self.assertEqual((2, 2048), output.shape) @@ -115,8 +115,9 @@ def _test_train(self, execution_mode=None): name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) - images, labels = random_batch(2) - train_one_step(model, images, labels, optimizer) + images, labels = random_batch(2, data_format) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) self.assertEqual(320, len(model.variables)) tfe.async_wait() events = summary_test_util.events_from_logdir(logdir) @@ -134,20 +135,22 @@ def test_no_garbage(self): model = resnet50.ResNet50(data_format) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): - images, labels = random_batch(2) + images, labels = random_batch(2, data_format) gc.disable() # Warm up. Note that this first run does create significant amounts of # garbage to be collected. The hope is that this is a build-only effect, # and a subsequent training loop will create nothing which needs to be # collected. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() previous_gc_debug_flags = gc.get_debug() gc.set_debug(gc.DEBUG_SAVEALL) for _ in range(2): # Run twice to ensure that garbage that is created on the first # iteration is no longer accessible. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() # There should be no garbage requiring collection. self.assertEqual(0, len(gc.garbage)) @@ -182,9 +185,7 @@ def _train_batch_sizes(self): return (16, 32, 64) if tf.DeviceSpec.from_string(device.name).device_type == 'TPU': - # TODO(iga): Training fails with batch size of 16, probably because of - # no layout optimizations with op-by-op mode. Investigate more. - return (8,) + return (32,) return (16, 32) def _report(self, label, start, num_iters, device, batch_size, data_format): @@ -202,18 +203,18 @@ def _force_device_sync(self): # which forces a sync. This is a roundabout way, yes. 
tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False, execution_mode=None, - device_and_format=None): + def _benchmark_eager_apply(self, label, device_and_format, defun=False, + execution_mode=None, compiled=False): with tfe.execution_mode(execution_mode): - device, data_format = device_and_format or device_and_data_format() + device, data_format = device_and_format model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.defun(model.call, compiled=compiled) batch_size = 64 num_burn = 5 num_iters = 30 with tf.device(device): - images, _ = random_batch(batch_size, device_and_format) + images, _ = random_batch(batch_size, data_format) for _ in xrange(num_burn): model(images, training=False).cpu() if execution_mode: @@ -227,37 +228,44 @@ def _benchmark_eager_apply(self, label, defun=False, execution_mode=None, self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_apply_sync(self): - self._benchmark_eager_apply('eager_apply', defun=False) + self._benchmark_eager_apply('eager_apply', device_and_data_format(), + defun=False) def benchmark_eager_apply_async(self): self._benchmark_eager_apply( - 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC) + 'eager_apply_async', device_and_data_format(), defun=False, + execution_mode=tfe.ASYNC) def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) + self._benchmark_eager_apply('eager_apply_with_defun', + device_and_data_format(), defun=True) def _benchmark_eager_train(self, label, make_iterator, + device_and_format, defun=False, execution_mode=None, - device_and_format=None): + compiled=False): with tfe.execution_mode(execution_mode): - device, data_format = device_and_format or device_and_data_format() + device, data_format = device_and_format for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size, device_and_format) - num_burn = 3 - num_iters = 10 + (images, labels) = random_batch(batch_size, data_format) model = resnet50.ResNet50(data_format) - if defun: - model.call = tfe.defun(model.call) optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = apply_gradients + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + apply_grads = tfe.defun(apply_gradients, compiled=compiled) + num_burn = 3 + num_iters = 10 with tf.device(device): iterator = make_iterator((images, labels)) for _ in xrange(num_burn): (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) if execution_mode: tfe.async_wait() self._force_device_sync() @@ -266,25 +274,29 @@ def _benchmark_eager_train(self, start = time.time() for _ in xrange(num_iters): (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) if execution_mode: tfe.async_wait() self._force_device_sync() self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_train_sync(self): - self._benchmark_eager_train('eager_train', MockIterator, defun=False) + self._benchmark_eager_train('eager_train', MockIterator, + device_and_data_format(), defun=False) def benchmark_eager_train_async(self): self._benchmark_eager_train( 'eager_train_async', MockIterator, + device_and_data_format(), defun=False, execution_mode=tfe.ASYNC) def benchmark_eager_train_with_defun(self): 
self._benchmark_eager_train( - 'eager_train_with_defun', MockIterator, defun=True) + 'eager_train_with_defun', MockIterator, + device_and_data_format(), defun=True) def benchmark_eager_train_datasets(self): @@ -294,7 +306,8 @@ def make_iterator(tensors): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset', make_iterator, defun=False) + 'eager_train_dataset', make_iterator, + device_and_data_format(), defun=False) def benchmark_eager_train_datasets_with_defun(self): @@ -304,7 +317,8 @@ def make_iterator(tensors): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset_with_defun', make_iterator, defun=True) + 'eager_train_dataset_with_defun', make_iterator, + device_and_data_format(), defun=True) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 492adbe1d80941..5ee2176154ec70 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -152,7 +152,7 @@ def __init__(self, rnn_cell_sizes, label_dimension, keep_prob): self.label_dimension = label_dimension self.keep_prob = keep_prob - self.cells = self._add_cells( + self.cells = tf.contrib.checkpoint.List( [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) self.relu = layers.Dense( label_dimension, activation=tf.nn.relu, name="relu") @@ -204,14 +204,6 @@ def call(self, inputs, training=False): hidden_states = tf.gather_nd(chars, indices) return self.relu(hidden_states) - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of layers.Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. - for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - def loss(labels, predictions): """Computes mean squared loss.""" diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py index 75b342ba78bd5d..b7d8395e277b52 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py @@ -67,5 +67,5 @@ def testTest(self): if __name__ == "__main__": - tfe.enable_eager_execution() + tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index be5d60449d7e08..c2340a293a8092 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -50,7 +50,7 @@ class RNN(tf.keras.Model): def __init__(self, hidden_dim, num_layers, keep_ratio): super(RNN, self).__init__() self.keep_ratio = keep_ratio - self.cells = self._add_cells([ + self.cells = tf.contrib.checkpoint.List([ tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim) for _ in range(num_layers) ]) @@ -74,14 +74,6 @@ def call(self, input_seq, training): # tuple (output, output_states). return [input_seq] - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. 
- for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - class Embedding(layers.Layer): """An Embedding layer.""" @@ -304,7 +296,7 @@ def test_model(use_cudnn_rnn): def main(_): - tfe.enable_eager_execution() + tf.enable_eager_execution() if not FLAGS.data_path: raise ValueError("Must specify --data-path") diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD new file mode 100644 index 00000000000000..69660fbeabf2f9 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "gpu_py_test") + +gpu_py_test( + name = "scan_test", + size = "small", + srcs = ["scan_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) + +gpu_py_test( + name = "scan_graph_test", + size = "small", + srcs = ["scan_graph_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py new file mode 100644 index 00000000000000..d4b8c8941ec411 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under graph mode execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + with tf.Session() as sess: + sess.run(sum_op) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py new file mode 100644 index 00000000000000..a02fc24c79dae6 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under eager execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan64000(self): + self.runScan(64000) + + def benchmarkScan128000(self): + self.runScan(128000) + + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index f825a2a7363fbe..8ac553e0ae7138 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -34,10 +34,10 @@ from tensorflow.contrib.eager.python.examples.spinn import data from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util -from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order @@ -421,10 +421,8 @@ def testTrainSpinn(self): # 5. Verify that checkpoints exist and contains all the expected variables. 
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - object_graph_string = checkpoint_utils.load_variable( - config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH") - object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph() - object_graph.ParseFromString(object_graph_string) + object_graph = checkpointable_utils.object_metadata( + saver.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: for attribute in node.attributes: diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 907f9204c2d31a..c947ed9dcc4156 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -25,12 +25,13 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training import checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") @@ -367,6 +368,9 @@ def call(self, labels, predictions, weights=None): Returns: The arguments, for easy chaining. """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index f0fe4ce8c53bb8..02ee05487515b8 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -26,12 +26,13 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils class MetricsTest(test.TestCase): @@ -117,6 +118,11 @@ def testAccuracy(self): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testAccuracyDifferentShapes(self): + m = metrics.Accuracy() + with self.assertRaises(errors.InvalidArgumentError): + m([[0], [0]], [0, 1]) + def testWeightedAccuracy(self): m = metrics.Accuracy() # 1 correct, total weight of 2 @@ -146,8 +152,6 @@ def testTwoMeans(self): self.assertAllEqual(2.0, m2.result()) def testNamesWithSpaces(self): - # Verify two metrics with the same class and name don't - # accidentally share state. m1 = metrics.Mean("has space") m1(0) self.assertEqual(m1.name, "has space") @@ -186,8 +190,8 @@ def testGraphAndEagerTensor(self): self.assertEqual(self.evaluate(value), 2.5) def testTwoMeansGraph(self): - # Verify two metrics with the same class and name don't - # accidentally share state. 
+ # Verify two metrics with the same name in the same graph raises a + # ValueError. with context.graph_mode(): m1 = metrics.Mean() m1(0) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 2f8721324f5fc1..f801d9a47b2f83 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -23,14 +23,16 @@ import weakref from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer +from tensorflow.python.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util +from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils # pylint: disable=protected-access # Explanation for protected-access disable: Network has lots of same-class and @@ -52,9 +54,40 @@ def _network_name_scope_naming(current_variable_scope): return current_variable_scope.name + "/" +_NETWORK_DEPRECATION_MESSAGE = ( + "Please inherit from `tf.keras.Model`, and see its documentation for " + "details. `tf.keras.Model` should be a drop-in replacement for " + "`tfe.Network` in most cases, but note that `track_layer` is no longer " + "necessary or supported. Instead, `Layer` instances are tracked on " + "attribute assignment (see the section of `tf.keras.Model`'s documentation " + "on subclassing). Since the output of `track_layer` is often assigned to " + "an attribute anyway, most code can be ported by simply removing the " + "`track_layer` calls.\n\n`tf.keras.Model` works with all TensorFlow " + "`Layer` instances, including those from `tf.layers`, but switching to " + "the `tf.keras.layers` versions along with the migration to " + "`tf.keras.Model` is recommended, since it will preserve variable names. " + "Feel free to import it with an alias to avoid excess typing :)." +) + + class Network(base.Layer): """Represents the composition of a set of Layers. + *Deprecated*. Please inherit from `tf.keras.Model`, and see its documentation + for details. `tf.keras.Model` should be a drop-in replacement for + `tfe.Network` in most cases, but note that `track_layer` is no longer + necessary or supported. Instead, `Layer` instances are tracked on attribute + assignment (see the section of `tf.keras.Model`'s documentation on + subclassing). Since the output of `track_layer` is often assigned to an + attribute anyway, most code can be ported by simply removing the `track_layer` + calls. + + `tf.keras.Model` works with all TensorFlow `Layer` instances, including those + from `tf.layers`, but switching to the `tf.keras.layers` versions along with + the migration to `tf.keras.Model` is recommended, since it will preserve + variable names. Feel free to import it with an alias to avoid excess typing + :). + `Network` implements the `Layer` interface and adds convenience methods for managing sub-`Layer`s, such as listing variables. @@ -112,6 +145,7 @@ def call(self, inputs): # - Detect layers used in __call__ that weren't registered with track_layer. # - Convert inputs to __call__ to tensors. 
+ @deprecation.deprecated(date=None, instructions=_NETWORK_DEPRECATION_MESSAGE) def __init__(self, name=None): """Configure the `Network`. @@ -130,6 +164,10 @@ def __init__(self, name=None): ValueError: If `name` is not valid. Note that some naming errors will instead be raised when the `Network` is called. """ + if context.executing_eagerly(): + logging.warning( + ("** tfe.Network is deprecated and will be removed in a future " + "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE) if isinstance(name, variable_scope.VariableScope): raise ValueError("VariableScopes are not valid Network names.") if name is not None and "/" in name: @@ -152,6 +190,11 @@ def __init__(self, name=None): self._variable_scope_counts_on_init = ( variable_scope.get_variable_scope_store().variable_scopes_count) + def _gather_saveables_for_checkpoint(self): + raise NotImplementedError( + "tfe.Network does not support object-based checkpointing.\n\n%s" + % _NETWORK_DEPRECATION_MESSAGE) + def _name_scope_name(self, current_variable_scope): """Overrides Layer op naming to match variable naming.""" return _network_name_scope_naming( @@ -502,10 +545,10 @@ def __init__(self, layers_funcs=None, name=None): def add(self, layer_func): if isinstance(layer_func, base.Layer): - args = estimator_util.fn_args(layer_func.call) + args = function_utils.fn_args(layer_func.call) self.track_layer(layer_func) elif callable(layer_func): - args = estimator_util.fn_args(layer_func) + args = function_utils.fn_args(layer_func) else: raise TypeError( "Sequential.add() takes only tf.layers.Layer objects or callables; " @@ -706,6 +749,9 @@ def _strip_variable_prefix(original_variable_name): return _strip_variable_prefix +@deprecation.deprecated(date=None, instructions=( + "Please inherit from tf.keras.Model instead of tfe.Network, and use " + "tf.keras.Model.save_weights.")) def save_network_checkpoint( network, save_path, global_step=None, map_func=None): """Save variables from the Network to a checkpoint. @@ -905,6 +951,9 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func, _add_deferred_restoration(network, deferred_restoration) +@deprecation.deprecated(date=None, instructions=( + "Please inherit from tf.keras.Model instead of tfe.Network, and use " + "tf.keras.Model.load_weights.")) def restore_network_checkpoint(network, save_path, map_func=None): """Restore the Network from a checkpoint. 
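The deprecation notice added above describes the `tfe.Network` to `tf.keras.Model` port only in prose. Below is a minimal sketch of that port, assuming a TensorFlow build from roughly this era where `tf.keras.Model` tracks `Layer` attributes on assignment; the class and layer names (`TwoLayerNet`, `l1`, `l2`) are illustrative and not taken from the patch.

```python
import tensorflow as tf

# Before (deprecated): sub-layers had to be registered explicitly.
#
#   class TwoLayerNet(tfe.Network):
#     def __init__(self):
#       super(TwoLayerNet, self).__init__()
#       self.l1 = self.track_layer(tf.layers.Dense(10, activation=tf.nn.relu))
#       self.l2 = self.track_layer(tf.layers.Dense(2))

# After: tf.keras.Model tracks Layer attributes automatically, so the
# track_layer calls are simply removed. tf.keras.layers versions are used
# so variable names are preserved, as the deprecation message recommends.
class TwoLayerNet(tf.keras.Model):

  def __init__(self):
    super(TwoLayerNet, self).__init__()
    self.l1 = tf.keras.layers.Dense(10, activation=tf.nn.relu)
    self.l2 = tf.keras.layers.Dense(2)

  def call(self, inputs):
    return self.l2(self.l1(inputs))
```

Checkpointing the ported model then goes through `tf.keras.Model.save_weights` and `load_weights`, which is what the deprecation messages on `save_network_checkpoint` and `restore_network_checkpoint` point to.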
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index f43376d5d777a7..c92bd15b253b67 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: disable=not-callable @@ -62,6 +63,12 @@ def call(self, values): class NetworkTest(test.TestCase): + def test_checkpointing_not_implemented(self): + checkpoint_directory = self.get_temp_dir() + checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork()) + with self.assertRaises(NotImplementedError): + checkpoint.save(checkpoint_directory) + def _save_modify_load_network_built(self, net, global_step=None): checkpoint_directory = self.get_temp_dir() checkpoint_path = network.save_network_checkpoint( diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 4032e755f6e7de..90a3711475719a 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -60,15 +60,9 @@ def model(): def testSameNameNoClobbering(self): with ops.device(self._dev()): - # Note that this test purposefully uses Graphs rather than - # IsolateTest. Users are more likely to accidentally create the same - # variable name this way. - first_graph = ops.Graph() - with first_graph.as_default(): - v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') - with ops.Graph().as_default(): - v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') - saver = _saver.Saver([v1_first_graph, v1_second_graph]) + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + v2 = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1, v2]) ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') with self.assertRaisesRegexp(ValueError, 'v1'): saver.save(ckpt_prefix) @@ -126,12 +120,11 @@ def model(init_val): saver = _saver.Saver([v1]) saver.save(ckpt_prefix) - with ops.Graph().as_default(): - saver = _saver.Saver([v1]) - with _saver.restore_variables_on_create(ckpt_prefix): - # Value is from checkpoint, but not from argument. - ret, _ = model(2.0) - self.assertEqual(ret.numpy(), 1.0) + saver = _saver.Saver([v1]) + with _saver.restore_variables_on_create(ckpt_prefix): + # Value is from checkpoint, but not from argument. 
+ ret, _ = model(2.0) + self.assertEqual(ret.numpy(), 1.0) def testRestoreNotFound(self): with ops.device(self._dev()): @@ -184,17 +177,17 @@ def model(x): 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) # reset the graph and reload on create, so that 1 + 2 = 3 - with ops.Graph().as_default(): - with _saver.restore_variables_on_create(ckpt_prefix): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model2(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - self.assertEqual( - 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + ops.reset_default_graph() + with _saver.restore_variables_on_create(ckpt_prefix): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model2(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + self.assertEqual( + 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) class GetOptimizerTests(test.TestCase): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 79dd117854e5fe..fee9db46fa4f79 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -115,14 +115,15 @@ from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable import Checkpointable -from tensorflow.python.training.checkpointable_utils import CheckpointableSaver -from tensorflow.python.training.checkpointable_utils import Checkpoint +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.util import CheckpointableSaver +from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index e80ccbb74d8623..db50b33af2e4f1 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -57,7 +57,7 @@ def square(x): return math_ops.multiply(x, x) grad = tfe.gradients_function(square) - self.assertEquals([6], [x.numpy() for x in grad(3)]) + self.assertEquals([6], [x.numpy() for x in grad(3.)]) def testGradOfGrad(self): @@ -66,7 +66,7 @@ def square(x): grad = tfe.gradients_function(square) gradgrad = tfe.gradients_function(lambda x: grad(x)[0]) - self.assertEquals([2], [x.numpy() for x in gradgrad(3)]) + self.assertEquals([2], [x.numpy() for x in gradgrad(3.)]) def testCustomGrad(self): @@ -80,7 +80,7 @@ def grad_fn(_): return y, grad_fn grad = tfe.gradients_function(f) - self.assertEquals([12], [x.numpy() for x in grad(3)]) + self.assertEquals([12], [x.numpy() for x in grad(3.)]) def 
testGPU(self): if tfe.num_gpus() <= 0: diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index cca441e9fb6996..ef9932f8355443 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -14,11 +14,14 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ + ":baseline", ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":export", ":extenders", ":head", + ":hooks", ":linear", ":logit_fns", ":multi_head", @@ -28,6 +31,49 @@ py_library( ], ) +py_library( + name = "baseline", + srcs = ["python/estimator/baseline.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:baseline", + ], +) + +py_test( + name = "baseline_test", + size = "small", + srcs = ["python/estimator/baseline_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":baseline", + ":head", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:metric_keys", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "boosted_trees", srcs = ["python/estimator/boosted_trees.py"], @@ -77,6 +123,7 @@ py_test( tags = [ "no_pip", "notsan", + "optonly", # times out http://b/79220679 ], deps = [ ":dnn", @@ -179,6 +226,43 @@ py_test( ], ) +py_library( + name = "export", + srcs = [ + "python/estimator/export.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "export_test", + size = "medium", + srcs = ["python/estimator/export_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 + deps = [ + ":export", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + ], +) + py_library( name = "head", srcs = [ @@ -228,6 +312,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", + "//tensorflow/python:variables", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", @@ -238,6 +323,37 @@ py_test( ], ) +py_library( + name = "hooks", + srcs = [ + "python/estimator/hooks.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "hooks_test", + size = "medium", + srcs = ["python/estimator/hooks_test.py"], + srcs_version = "PY2AND3", + tags 
= ["notsan"], + deps = [ + ":hooks", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "linear", srcs = ["python/estimator/linear.py"], @@ -283,9 +399,9 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_ops", + "//tensorflow/python:util", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", - "//tensorflow/python/estimator:util", ], ) @@ -450,20 +566,25 @@ py_test( "no_pip", "noasan", # times out "notsan", + "optonly", # times out http://b/79220679 ], deps = [ + ":head", ":rnn", + "//tensorflow/contrib/data", "//tensorflow/core:protos_all_py", "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", "//tensorflow/python:math_ops", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variables", "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:parsing_utils", "//tensorflow/python/feature_column", "//third_party/py/numpy", "@six_archive//:six", diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index be20d1b7770d3f..788ac5ca7046d6 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -19,11 +19,14 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.hooks import * from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * @@ -38,11 +41,14 @@ 'binary_classification_head', 'clip_gradients_by_norm', 'forward_features', + 'InMemoryEvaluatorHook', + 'logistic_regression_head', 'multi_class_head', 'multi_head', 'multi_label_head', 'poisson_regression_head', 'regression_head', + 'BaselineEstimator', 'DNNEstimator', 'DNNLinearCombinedEstimator', 'LinearEstimator', @@ -54,6 +60,9 @@ 'replicate_model_fn', 'TowerOptimizer', 'RNNClassifier', + 'RNNEstimator', + 'export_saved_model_for_mode', + 'export_all_saved_models', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline.py b/tensorflow/contrib/estimator/python/estimator/baseline.py new file mode 100644 index 00000000000000..beffbee73064b9 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Baseline estimators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import baseline
+
+
+class BaselineEstimator(estimator.Estimator):
+  """An estimator that can establish a simple baseline.
+
+  The estimator uses a user-specified head.
+
+  This estimator ignores feature values and will learn to predict the average
+  value of each label. E.g. for single-label classification problems, this will
+  predict the probability distribution of the classes as seen in the labels.
+  For multi-label classification problems, it will predict the ratio of examples
+  that contain each class.
+
+  Example:
+
+  ```python
+
+  # Build baseline multi-label classifier.
+  estimator = BaselineEstimator(
+      head=tf.contrib.estimator.multi_label_head(n_classes=3))
+
+  # Input builders
+  def input_fn_train():  # returns x, y (where y represents label's class index).
+    pass
+
+  def input_fn_eval():  # returns x, y (where y represents label's class index).
+    pass
+
+  # Fit model.
+  estimator.train(input_fn=input_fn_train)
+
+  # Evaluates cross entropy between the test and train labels.
+  loss = estimator.evaluate(input_fn=input_fn_eval)["loss"]
+
+  # For each class, predicts the ratio of training examples that contain the
+  # class.
+  predictions = estimator.predict(new_samples)
+
+  ```
+
+  Input of `train` and `evaluate` should have the following features,
+  otherwise there will be a `KeyError`:
+
+  * if `weight_column` passed to the `head` constructor is not `None`, a feature
+    with `key=weight_column` whose value is a `Tensor`.
+  """
+
+  def __init__(self,
+               head,
+               model_dir=None,
+               optimizer='Ftrl',
+               config=None):
+    """Initializes a BaselineEstimator instance.
+
+    Args:
+      head: A `_Head` instance constructed with a method such as
+        `tf.contrib.estimator.multi_label_head`.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+      optimizer: String, `tf.Optimizer` object, or callable that creates the
+        optimizer to use for training. If not specified, will use
+        `FtrlOptimizer` with a default learning rate of 0.3.
+      config: `RunConfig` object to configure the runtime settings.
+ """ + def _model_fn(features, labels, mode, config): + return baseline._baseline_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + optimizer=optimizer, + config=config) + super(BaselineEstimator, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py new file mode 100644 index 00000000000000..d0e3e670f73328 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -0,0 +1,430 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for baseline.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import baseline +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.python.client import session as tf_session +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer +from tensorflow.python.training import saver + +# Names of variables created by model. 
+BIAS_NAME = 'baseline/bias' + + +def assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def save_variables_to_ckpt(model_dir): + init_all_op = [variables.global_variables_initializer()] + with tf_session.Session() as sess: + sess.run(init_all_op) + saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + + +def _baseline_estimator_fn( + weight_column=None, label_dimension=1, *args, **kwargs): + """Returns a BaselineEstimator that uses regression_head.""" + return baseline.BaselineEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), + *args, **kwargs) + + +class BaselineEstimatorEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_evaluation_batch(self): + """Tests evaluation for batch_size==2.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate( + input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the sum over batch = 9 + 9 = 18 + # Average loss is the average over batch = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 18., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_weights(self): + """Tests evaluation with weights.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))} + labels = ((10.,), (10.,)) + return features, labels + + baseline_estimator = _baseline_estimator_fn( + weight_column='weights', + model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. 
+ # Training loss is the weighted sum over batch = 9 + 2*9 = 27 + # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 27., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_for_multi_dimensions(self): + label_dim = 2 + with ops.Graph().as_default(): + variables.Variable([46.0, 58.0], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dim, + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + x={ + 'age': np.array([[2., 4., 5.]]), + }, + y=np.array([[46., 58.]]), + batch_size=1, + num_epochs=None, + shuffle=False) + eval_metrics = baseline_estimator.evaluate(input_fn=input_fn, steps=1) + + self.assertItemsEqual( + (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, + ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + + # Logit is bias which is [46, 58] + self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) + + +class BaselineEstimatorPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_1d(self): + """Tests predict when all variables are one-dimensional.""" + with ops.Graph().as_default(): + variables.Variable([.2], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[2.]])}, + y=None, + batch_size=1, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # x * weight + bias = 2. * 10. 
+ .2 = 20.2
+    self.assertAllClose([[.2]], predicted_scores)
+
+  def testMultiDim(self):
+    """Tests predict when all variables are multi-dimensional."""
+    batch_size = 2
+    label_dimension = 3
+    with ops.Graph().as_default():
+      variables.Variable(  # shape=[label_dimension]
+          [.2, .4, .6], name=BIAS_NAME)
+      variables.Variable(100, name='global_step', dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    baseline_estimator = _baseline_estimator_fn(
+        label_dimension=label_dimension,
+        model_dir=self._model_dir)
+
+    predict_input_fn = numpy_io.numpy_input_fn(
+        # x shape=[batch_size, x_dim]
+        x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},
+        y=None,
+        batch_size=batch_size,
+        num_epochs=1,
+        shuffle=False)
+    predictions = baseline_estimator.predict(input_fn=predict_input_fn)
+    predicted_scores = list([x['predictions'] for x in predictions])
+    # score = bias, shape=[batch_size, label_dimension]
+    self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]],
+                        predicted_scores)
+
+
+class BaselineEstimatorIntegrationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+                          input_dimension, label_dimension, prediction_length):
+    feature_columns = [
+        feature_column_lib.numeric_column('x', shape=(input_dimension,))
+    ]
+    est = _baseline_estimator_fn(
+        label_dimension=label_dimension,
+        model_dir=self._model_dir)
+
+    # TRAIN
+    # learn y = x
+    est.train(train_input_fn, steps=200)
+
+    # EVALUATE
+    scores = est.evaluate(eval_input_fn)
+    self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP])
+    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))
+
+    # PREDICT
+    predictions = np.array(
+        [x['predictions'] for x in est.predict(predict_input_fn)])
+    self.assertAllEqual((prediction_length, label_dimension), predictions.shape)
+
+    # EXPORT
+    feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+        feature_spec)
+    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+                                       serving_input_receiver_fn)
+    self.assertTrue(gfile.Exists(export_dir))
+
+  def test_numpy_input_fn(self):
+    """Tests complete flow with numpy_input_fn."""
+    label_dimension = 2
+    input_dimension = label_dimension
+    batch_size = 10
+    prediction_length = batch_size
+    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, label_dimension)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=data,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    eval_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=data,
+        batch_size=batch_size,
+        num_epochs=1,
+        shuffle=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=None,
+        batch_size=batch_size,
+        num_epochs=1,
+        shuffle=False)
+
+    self._test_complete_flow(
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=input_dimension,
+        label_dimension=label_dimension,
+        prediction_length=prediction_length)
+
+
+class BaselineEstimatorTrainingTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _mock_optimizer(self, expected_loss=None):
expected_var_names = [ + '%s:0' % BIAS_NAME + ] + + def _minimize(loss, global_step=None, var_list=None): + trainable_vars = var_list or ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + assert_loss = assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint(self, + label_dimension, + expected_global_step, + expected_bias=None): + shapes = { + name: shape + for (name, shape) in checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(self._model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([label_dimension], shapes[BIAS_NAME]) + if expected_bias is not None: + self.assertEqual(expected_bias, + checkpoint_utils.load_variable(self._model_dir, + BIAS_NAME)) + + def testFromScratch(self): + # Create BaselineRegressor. + label = 5. + age = 17 + # loss = (logits - label)^2 = (0 - 5.)^2 = 25. + mock_optimizer = self._mock_optimizer(expected_loss=25.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=num_steps, + expected_bias=[0.]) + + def testFromCheckpoint(self): + # Create initial checkpoint. + bias = 7.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([bias], name=BIAS_NAME) + variables.Variable( + initial_global_step, + name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # logits = bias = 6. + # loss = (logits - label)^2 = (7 - 5)^2 = 4 + mock_optimizer = self._mock_optimizer(expected_loss=4.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. 
+ num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=initial_global_step + num_steps, + expected_bias=[bias]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index cf6e3329d2e277..7ff25b95c079c7 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -93,7 +93,7 @@ def __init__(self, dropout=None, input_layer_partitioner=None, config=None): - """Initializes a `DNNClassifier` instance. + """Initializes a `DNNEstimator` instance. Args: head: A `_Head` instance constructed with a method such as diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py new file mode 100644 index 00000000000000..03cf6f107c1c55 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -0,0 +1,223 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for methods to export train/eval graphs from Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import model_fn as model_fn_lib + + +def export_saved_model_for_mode( + estimator, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn, steps=1000) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + export_dir = tf.contrib.estimator.export_saved_model_for_mode( + classifier, + export_dir_base='my_model/', + input_receiver_fn=train_rcvr_fn, + mode=model_fn_lib.ModeKeys.TRAIN) + + # export_dir is a timestamped directory with the SavedModel, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name(''linear/linear_model/age/weights') + ... 
+ ``` + + This method is a wrapper for _export_all_saved_models, and wraps a raw + input_receiver_fn in a dictionary to pass in to that function. + See _export_all_saved_models for full docs. + + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating with mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + # pylint: enable=protected-access + + +def export_all_saved_models( + estimator, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + For each mode passed in via the input_receiver_fn_map, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Only one of the modes is used for saving variables to the SavedModel + (order of preference: TRAIN, EVAL, then PREDICT), such that up to three + MetaGraphDefs are saved with a single set of variables in a single + SavedModel directory. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. 
+ + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( + feature_spec) + + rcvr_fn_map = { + model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn, + model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn, + } + + export_dir = tf.contrib.estimator.export_all_saved_models( + classifier, + export_dir_base='my_model/', + input_receiver_fn_map=rcvr_fn_map) + + # export_dirs is a dict of directories with SavedModels, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name('linear/linear_model/age/weights') + ... + ``` + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. 
+ """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_all_saved_models( + export_dir_base, input_receiver_fn_map, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py new file mode 100644 index 00000000000000..050821ee672f30 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -0,0 +1,373 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib wrapping of export_saved_model_for_mode functionality. + +These are direct copies of the tests included in core, with import locations +changed. These should be removed when the functionality in core is part of the +public API. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.python.client import session +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import training +from tensorflow.python.util import compat + + +def _model_fn_for_export_tests(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = constant_op.constant(2.) 
+ return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + + +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +class EstimatorExportTest(test.TestCase): + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_saved_model_for_mode( + est, export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for tag_set in model_fn_lib.EXPORT_TAG_MAP.values(): + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_all_saved_models( + est, export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + self._validate_exported_files(export_dir) + + return export_dir, tmpdir + + def _validate_exported_files(self, export_dir): + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 201699ed775f70..bf08be09e7baf6 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -22,12 +22,12 @@ from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.util import function_utils _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) @@ -330,7 +330,7 @@ def get_slot_names(self, *args, **kwargs): def _verify_metric_fn_args(metric_fn): - args = set(estimator_util.fn_args(metric_fn)) + args = set(function_utils.fn_args(metric_fn)) invalid_args = list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % @@ -339,7 +339,7 @@ def _verify_metric_fn_args(metric_fn): def _call_metric_fn(metric_fn, features, labels, predictions, config): """Calls metric fn with proper arguments.""" - metric_fn_args = estimator_util.fn_args(metric_fn) + metric_fn_args = function_utils.fn_args(metric_fn) kwargs = {} if 'features' in metric_fn_args: kwargs['features'] = features diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 3dcf0374c8a12b..b798769d2cfde6 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import six + from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys @@ -72,6 +74,33 @@ def multi_class_head(n_classes, shape `[D0, D1, ... DN, 1]`. 
Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use `binary_classification_head`). @@ -139,6 +168,33 @@ def binary_classification_head( shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.binary_classification_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.binary_classification_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -205,11 +261,39 @@ def regression_head(weight_column=None, shape `[D0, D1, ... DN, label_dimension]`. Also supports custom `inverse_link_fn`, also known as 'mean function'. - `inverse_link_fn` takes `logits` as argument and returns predicted values. - This function is the inverse of the link function defined in + `inverse_link_fn` is only used in `PREDICT` mode. It takes `logits` as + argument and returns predicted values. This function is the inverse of the + link function defined in https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function Namely, for poisson regression, set `inverse_link_fn=tf.exp`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -234,7 +318,7 @@ def regression_head(weight_column=None, Raises: ValueError: If `label_dimension` or `loss_reduction` is invalid. 
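As a concrete illustration of the `inverse_link_fn` behaviour described above, here is a minimal, standalone TF 1.x sketch (illustrative only) of the mapping applied to logits in `PREDICT` mode when `inverse_link_fn=tf.exp`, the inverse of the Poisson log link:

```python
import tensorflow as tf

# Logits produced by the model, shape [batch_size, label_dimension].
logits = tf.constant([[0.0], [1.0], [2.0]])

# With inverse_link_fn=tf.exp, PREDICT-mode predictions are simply exp(logits).
# The training loss is unaffected: inverse_link_fn is applied only in PREDICT.
predictions = tf.exp(logits)

with tf.Session() as sess:
  print(sess.run(predictions))  # approximately [[1.0], [2.718], [7.389]]
```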
""" - return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + return head_lib._regression_head( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, @@ -269,6 +353,33 @@ def poisson_regression_head( This is implemented as a generalized linear model, see https://en.wikipedia.org/wiki/Generalized_linear_model. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.poisson_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.poisson_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -296,7 +407,7 @@ def poisson_regression_head( def _poisson_loss(labels, logits): return nn.log_poisson_loss( targets=labels, log_input=logits, compute_full_loss=compute_full_loss) - return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + return head_lib._regression_head( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, @@ -305,12 +416,103 @@ def _poisson_loss(labels, logits): name=name) +def logistic_regression_head( + weight_column=None, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, + name=None): + """Creates a `_Head` for logistic regression. + + Uses `sigmoid_cross_entropy_with_logits` loss, which is the same as + `binary_classification_head`. The differences compared to + `binary_classification_head` are: + + * Does not support `label_vocabulary`. Instead, labels must be float in the + range [0, 1]. + * Does not calculate some metrics that do not make sense, such as AUC. + * In `PREDICT` mode, only returns logits and predictions + (`=tf.sigmoid(logits)`), whereas `binary_classification_head` also returns + probabilities, classes, and class_ids. + * Export output defaults to `RegressionOutput`, whereas + `binary_classification_head` defaults to `PredictOutput`. + + The head expects `logits` with shape `[D0, D1, ... DN, 1]`. + In many applications, the shape is `[batch_size, 1]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`. + + This is implemented as a generalized linear model, see + https://en.wikipedia.org/wiki/Generalized_linear_model. + + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.logistic_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. 
Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.logistic_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + + Args: + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. + + Returns: + An instance of `_Head` for logistic regression. + + Raises: + ValueError: If `loss_reduction` is invalid. + """ + def _logistic_loss(labels, logits): + labels = head_lib._assert_range( # pylint:disable=protected-access + labels, n_classes=2, message='Labels must be in range [0, 1]') + return nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) + return head_lib._regression_head( # pylint:disable=protected-access + weight_column=weight_column, + label_dimension=1, + loss_reduction=loss_reduction, + loss_fn=_logistic_loss, + inverse_link_fn=math_ops.sigmoid, + name=name) + + def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): """Creates a `_Head` for multi-label classification. @@ -342,6 +544,33 @@ def multi_label_head(n_classes, shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use `binary_classification_head`). @@ -363,6 +592,10 @@ def multi_label_head(n_classes, reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. + classes_for_class_based_metrics: List of integer class IDs or string class + names for which per-class metrics are evaluated. If integers, all must be + in the range `[0, n_classes - 1]`. If strings, all must be in + `label_vocabulary`. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. 
Also used as `name_scope` when creating ops. @@ -370,8 +603,8 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is - invalid. + ValueError: if `n_classes`, `thresholds`, `loss_reduction`, `loss_fn` or + `metric_class_ids` is invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -396,10 +629,31 @@ def multi_label_head(n_classes, if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) + classes_for_class_based_metrics = tuple( + [] if classes_for_class_based_metrics is None + else classes_for_class_based_metrics) + if classes_for_class_based_metrics: + if isinstance(classes_for_class_based_metrics[0], six.string_types): + if not label_vocabulary: + raise ValueError( + 'label_vocabulary must be provided when ' + 'classes_for_class_based_metrics are sting.') + class_ids = [] + for class_string in classes_for_class_based_metrics: + class_ids.append(label_vocabulary.index(class_string)) + classes_for_class_based_metrics = tuple(class_ids) + else: + for class_id in classes_for_class_based_metrics: + if (class_id < 0) or (class_id >= n_classes): + raise ValueError( + 'All classes_for_class_based_metrics must be in range [0, {}]. ' + 'Given: {}'.format(n_classes - 1, class_id)) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, - loss_fn=loss_fn, name=name) + loss_fn=loss_fn, + classes_for_class_based_metrics=classes_for_class_based_metrics, + name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -412,6 +666,7 @@ def __init__(self, label_vocabulary=None, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): self._n_classes = n_classes self._weight_column = weight_column @@ -419,6 +674,7 @@ def __init__(self, self._label_vocabulary = label_vocabulary self._loss_reduction = loss_reduction self._loss_fn = loss_fn + self._classes_for_class_based_metrics = classes_for_class_based_metrics self._name = name @property @@ -496,10 +752,10 @@ def create_loss(self, features, mode, logits, labels): weights=weights, processed_labels=processed_labels) - def create_estimator_spec( + def _create_tpu_estimator_spec( self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None, regularization_losses=None): - """Returns an `EstimatorSpec`. + """Returns an `model_fn._TPUEstimatorSpec`. Args: features: Input `dict` of `Tensor` or `SparseTensor` objects. @@ -522,7 +778,7 @@ def create_estimator_spec( `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to avoid scaling errors. Returns: - `EstimatorSpec`. + `model_fn._TPUEstimatorSpec`. Raises: ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode, or if both are set. @@ -542,7 +798,7 @@ def create_estimator_spec( classifier_output = head_lib._classification_output( # pylint:disable=protected-access scores=probabilities, n_classes=self._n_classes, label_vocabulary=self._label_vocabulary) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ @@ -565,16 +821,18 @@ def create_estimator_spec( # Eval. 
if mode == model_fn.ModeKeys.EVAL: - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=regularized_training_loss, - eval_metric_ops=self._eval_metric_ops( - labels=processed_labels, - probabilities=probabilities, - weights=weights, - unreduced_loss=unreduced_loss, - regularization_loss=regularization_loss)) + eval_metrics=head_lib._create_eval_metrics_tuple( # pylint:disable=protected-access + self._eval_metric_ops, { + 'labels': processed_labels, + 'probabilities': probabilities, + 'weights': weights, + 'unreduced_loss': unreduced_loss, + 'regularization_loss': regularization_loss, + })) # Train. if optimizer is not None: @@ -587,6 +845,7 @@ def create_estimator_spec( train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -608,7 +867,7 @@ def create_estimator_spec( summary.scalar( head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access regularization_loss) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, @@ -671,4 +930,36 @@ def _eval_metric_ops( weights=weights, threshold=threshold, name=recall_key)) + for class_id in self._classes_for_class_based_metrics: + batch_rank = array_ops.rank(probabilities) - 1 + begin = array_ops.concat( + [array_ops.zeros([batch_rank], dtype=dtypes.int32), [class_id]], + axis=0) + size = array_ops.concat( + [-1 * array_ops.ones([batch_rank], dtype=dtypes.int32), [1]], + axis=0) + class_probabilities = array_ops.slice( + probabilities, begin=begin, size=size) + class_labels = array_ops.slice(labels, begin=begin, size=size) + prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access + head_lib._predictions_mean( # pylint:disable=protected-access + predictions=class_probabilities, + weights=weights, + name=prob_key)) + auc_key = keys.AUC_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + name=auc_key)) + auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + curve='PR', + name=auc_pr_key)) return metric_ops diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 98962ca4277a3e..b2b57fa06ba818 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import 
losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -175,6 +176,21 @@ def _loss_fn(labels, logits, name=None): r'loss_fn has unexpected args: \[\'name\'\]'): head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + def test_classes_for_class_based_metrics_invalid(self): + with self.assertRaisesRegexp( + ValueError, + r'All classes_for_class_based_metrics must be in range \[0, 2\]\. ' + r'Given: -1'): + head_lib.multi_label_head( + n_classes=3, classes_for_class_based_metrics=[2, -1]) + + def test_classes_for_class_based_metrics_string_invalid(self): + with self.assertRaisesRegexp( + ValueError, r'\'z\' is not in list'): + head_lib.multi_label_head( + n_classes=3, label_vocabulary=['a', 'b', 'c'], + classes_for_class_based_metrics=['c', 'z']) + def test_name(self): head = head_lib.multi_label_head(n_classes=4, name='foo') self.assertEqual('foo', head.name) @@ -591,6 +607,81 @@ def test_eval_with_thresholds(self): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_classes_for_class_based_metrics(self): + head = head_lib.multi_label_head( + n_classes=2, classes_for_class_based_metrics=[0, 1]) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + + def test_eval_with_classes_for_class_based_metrics_string(self): + head = head_lib.multi_label_head( + n_classes=2, label_vocabulary=['a', 'b'], + classes_for_class_based_metrics=['a', 'b']) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = sparse_tensor.SparseTensor( + values=['a', 'a', 'b'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + labels_onehot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_onehot, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. 
+ keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_weights(self): n_classes = 2 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') @@ -899,6 +990,34 @@ def minimize(self, loss, global_step): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) + def test_train_with_update_ops(self): + head = head_lib.multi_label_head(n_classes=2) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=np.array([[1, 0], [1, 1]], dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) @@ -1211,5 +1330,124 @@ def test_predict(self): self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) +class LogisticRegressionHead(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def test_train(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [.6], [.8]], dtype=np.float32) + # Following the documentation in + # tf.nn.sigmoid_cross_entropy_with_logits: + # With x = logits, z = labels. + # loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + # loss = [0 - 0 * 0.4 + ln(1 + exp(-0)), + # 0 + 1 * 0.6 + ln(1 + exp(-1)), + # 1 - 1 * 0.8 + ln(1 + exp(-1))] + # = [0.6931, 0.9133, 0.5133] + # training_loss = (0.6931 + 0.9133 + 0.5133) / 3 + expected_loss = 0.7066 + atol = 0.001 + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + with ops.control_dependencies((check_ops.assert_near( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + atol=atol, name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run([spec.loss, spec.train_op]) + self.assertAlmostEqual(expected_loss, loss, delta=atol) + self.assertEqual(expected_train_result, train_result) + + def test_train_labels_too_large(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. 
+ logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [1.2], [.8]], dtype=np.float32) + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Labels must be in range \[0, 1\]\] .* \[\[0.4\]\[1.2\]\[0.8\]\]'): + _ = sess.run(spec.loss) + + def test_train_labels_negative(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [-0.2], [.8]], dtype=np.float32) + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Labels must be in range \[0, 1\]\] .* \[\[0.4\]\[-0.2\]\[0.8\]\]' + ): + _ = sess.run(spec.loss) + + def test_predict(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + expected_predictions = 1. / (1. + np.exp(-logits)) + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + # Assert spec contains expected tensors. + keys = prediction_keys.PredictionKeys + self.assertItemsEqual( + (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys()) + self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype) + self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype) + + # Assert predictions. + with self.test_session(): + _initialize_variables(self, spec.scaffold) + self.assertAllClose( + expected_predictions, spec.predictions[keys.PREDICTIONS].eval()) + self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py new file mode 100644 index 00000000000000..ddd6aa442f82ba --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Some useful session run hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import training + + +# pylint: disable=protected-access +class InMemoryEvaluatorHook(training.SessionRunHook): + """Hook to run evaluation in training without a checkpoint. + + Example: + + ```python + def train_input_fn(): + ... + return train_dataset + + def eval_input_fn(): + ... + return eval_dataset + + estimator = tf.estimator.DNNClassifier(...) + + evaluator = tf.contrib.estimator.InMemoryEvaluatorHook( + estimator, eval_input_fn) + estimator.train(train_input_fn, hooks=[evaluator]) + ``` + + Current limitations of this approach are: + * It doesn't support multi-node distributed mode. + * It doesn't support saveable objects other than variables (such as boosted + tree support) + * It doesn't support custom saver logic (such as ExponentialMovingAverage + support) + + """ + + def __init__(self, + estimator, + input_fn, + steps=None, + hooks=None, + name=None, + every_n_iter=100): + """Initializes a `InMemoryEvaluatorHook`. + + Args: + estimator: A `tf.estimator.Estimator` instance to call evaluate. + input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A + function that constructs the input data for evaluation. + See @{$premade_estimators#create_input_functions} for more + information. The function should construct and return one of + the following: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + steps: Equivalent to the `steps` arg to `estimator.evaluate`. Number of + steps for which to evaluate model. If `None`, evaluates until `input_fn` + raises an end-of-input exception. + hooks: Equivalent to the `hooks` arg to `estimator.evaluate`. List of + `SessionRunHook` subclass instances. Used for callbacks inside the + evaluation call. + name: Equivalent to the `name` arg to `estimator.evaluate`. Name of the + evaluation if user needs to run multiple evaluations on different data + sets, such as on training data vs test data. Metrics for different + evaluations are saved in separate folders, and appear separately in + tensorboard. + every_n_iter: `int`, runs the evaluator once every N training iteration. + + Raises: + ValueError: if `every_n_iter` is non-positive or it's not a single machine + training + """ + if every_n_iter is None or every_n_iter <= 0: + raise ValueError('invalid every_n_iter=%s.' 
% every_n_iter) + if (estimator.config.num_ps_replicas > 0 or + estimator.config.num_worker_replicas > 1): + raise ValueError( + 'InMemoryEvaluator supports only single machine (aka Local) setting.') + self._estimator = estimator + self._input_fn = input_fn + self._steps = steps + self._name = name + self._every_n_iter = every_n_iter + self._eval_dir = os.path.join(self._estimator.model_dir, 'eval' + if not name else 'eval_' + name) + + self._graph = None + self._hooks = estimator_lib._check_hooks_type(hooks) + self._hooks.extend(self._estimator._convert_eval_steps_to_hooks(steps)) + self._timer = training.SecondOrStepTimer(every_steps=every_n_iter) + + def begin(self): + """Build eval graph and restoring op.""" + self._timer.reset() + self._iter_count = 0 + self._graph = ops.Graph() + with self._graph.as_default(): + (self._scaffold, self._update_op, self._eval_dict, + self._all_hooks) = self._estimator._evaluate_build_graph( + self._input_fn, self._hooks, checkpoint_path=None) + + if self._scaffold.saver is not None: + raise ValueError('InMemoryEvaluator does not support custom saver') + if self._scaffold.init_fn is not None: + raise ValueError('InMemoryEvaluator does not support custom init_fn') + + self._var_name_to_eval_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + self._var_name_to_placeholder = { + v.name: array_ops.placeholder(v.dtype) + for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + + def after_create_session(self, session, coord): # pylint: disable=unused-argument + """Does first run which shows the eval metrics before training.""" + if ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS): + raise ValueError( + 'InMemoryEvaluator does not support saveables other than global ' + 'variables.') + self._var_name_to_train_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + var_names_to_transfer = set(self._var_name_to_placeholder.keys()) & set( + self._var_name_to_train_var.keys()) + # Filter training var names that are not exist in evaluation + self._var_name_to_train_var = { + v_name: self._var_name_to_train_var[v_name] + for v_name in var_names_to_transfer + } + # Filter eval var names that are not exist in training + self._var_name_to_eval_var = { + v_name: self._var_name_to_eval_var[v_name] + for v_name in var_names_to_transfer + } + + with self._graph.as_default(): + self._var_feed_op = control_flow_ops.group([ + state_ops.assign(self._var_name_to_eval_var[v_name], + self._var_name_to_placeholder[v_name]) + for v_name in var_names_to_transfer + ]) + + self._evaluate(session) + + def _evaluate(self, train_session): + var_name_to_value = train_session.run(self._var_name_to_train_var) + placeholder_to_value = { + self._var_name_to_placeholder[v_name]: var_name_to_value[v_name] + for v_name in var_name_to_value + } + + def feed_variables(scaffold, session): + del scaffold + session.run(self._var_feed_op, feed_dict=placeholder_to_value) + + scaffold = training.Scaffold( + init_fn=feed_variables, copy_from_scaffold=self._scaffold) + + with self._graph.as_default(): + return self._estimator._evaluate_run( + checkpoint_path=None, + scaffold=scaffold, + update_op=self._update_op, + eval_dict=self._eval_dict, + all_hooks=self._all_hooks, + output_dir=self._eval_dir) + + self._timer.update_last_triggered_step(self._iter_count) + + def after_run(self, run_context, run_values): # pylint: disable=unused-argument + """Runs evaluator.""" + self._iter_count += 1 + if 
self._timer.should_trigger_for_step(self._iter_count): + self._evaluate(run_context.session) + + def end(self, session): # pylint: disable=unused-argument + """Runs evaluator for final model.""" + self._evaluate(session) + + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py new file mode 100644 index 00000000000000..95ae971852ee6d --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -0,0 +1,318 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import os + +from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import training + + +def summary_step_keyword_to_value_mapping(dir_): + writer_cache.FileWriterCache.clear() + + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + step_keyword_to_value = {} + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step not in step_keyword_to_value: + step_keyword_to_value[last_event.step] = {} + if last_event.summary is not None: + for value in last_event.summary.value: + step_keyword_to_value[last_event.step][value.tag] = value.simple_value + + return step_keyword_to_value + + +def get_summary_value(dir_, step, keyword): + """Get summary value for given step and keyword.""" + + writer_cache.FileWriterCache.clear() + # Get last Event written. 
+ event_paths = glob.glob(os.path.join(dir_, 'events*')) + print('XXX', event_paths) + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step == step and last_event.summary is not None: + for value in last_event.summary.value: + if keyword in value.tag: + return value.simple_value + return None + + +class InMemoryEvaluatorHookTest(test.TestCase): + + def test_runs_eval_metrics(self): + + def model_fn(features, labels, mode): + _ = labels + if estimator_lib.ModeKeys.TRAIN == mode: + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training.get_global_step(), 1) + return estimator_lib.EstimatorSpec( + mode, loss=constant_op.constant(3.), train_op=train_op) + if estimator_lib.ModeKeys.EVAL == mode: + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(5.), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # 4.5 = sum(range(10))/10 + # before training + self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features']) + # intervals (every_n_iter=4) + self.assertEqual(4.5, step_keyword_to_value[4]['mean_of_features']) + self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features']) + # end + self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features']) + + def test_uses_latest_variable_value(self): + + def model_fn(features, labels, mode): + _ = labels + step = training.get_global_step() + w = variable_scope.get_variable( + 'w', + shape=[], + initializer=init_ops.zeros_initializer(), + dtype=dtypes.int64) + if estimator_lib.ModeKeys.TRAIN == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + step_inc = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([step_inc]): + assign_w_to_step_plus_2 = w.assign(step + 2) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + train_op=assign_w_to_step_plus_2) + if estimator_lib.ModeKeys.EVAL == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + loss = constant_op.constant(5.) + return estimator_lib.EstimatorSpec( + mode, + loss=loss, + # w is constant in each step, so the mean. 
+ # w = 0 if step==0 else step+2 + eval_metric_ops={'mean_of_const': metrics_lib.mean(w)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # w = 0 if step==0 else step+2 + self.assertEqual(0, step_keyword_to_value[0]['mean_of_const']) + self.assertEqual(6, step_keyword_to_value[4]['mean_of_const']) + self.assertEqual(12, step_keyword_to_value[10]['mean_of_const']) + + def test_dnn_classifier(self): + embedding = feature_column_lib.embedding_column( + feature_column_lib.categorical_column_with_vocabulary_list( + 'wire_cast', ['kima', 'omar', 'stringer']), 8) + dnn = estimator_lib.DNNClassifier( + feature_columns=[embedding], hidden_units=[3, 1]) + + def train_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['omar'], ['kima']] + }, [[0], [1]])).repeat(3) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['stringer'], ['kima']] + }, [[0], [1]])).repeat(2) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + dnn, eval_input_fn, name='in-memory') + dnn.train(train_input_fn, hooks=[evaluator]) + self.assertTrue(os.path.isdir(dnn.eval_dir('in-memory'))) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + dnn.eval_dir('in-memory')) + + final_metrics = dnn.evaluate(eval_input_fn) + step = final_metrics[ops.GraphKeys.GLOBAL_STEP] + for summary_tag in final_metrics: + if summary_tag == ops.GraphKeys.GLOBAL_STEP: + continue + self.assertEqual(final_metrics[summary_tag], + step_keyword_to_value[step][summary_tag]) + + def test_raise_error_with_multi_worker(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_ps(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1'], + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_custom_saver_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(saver=training.Saver()), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def 
input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom saver'): + evaluator.begin() + + def test_raise_error_with_custom_init_fn_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + + def init_fn(scaffold, session): + _, _ = scaffold, session + + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_fn=init_fn), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom init_fn'): + evaluator.begin() + + def test_raise_error_with_saveables_other_than_global_variables(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + w = variables.Variable( + initial_value=[0.], + trainable=False, + collections=[ops.GraphKeys.SAVEABLE_OBJECTS]) + init_op = control_flow_ops.group( + [w.initializer, training.get_global_step().initializer]) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_op=init_op), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support saveables'): + estimator.train(input_fn, hooks=[evaluator]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index 09c2862ccd3f90..c8b0dd62970e34 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -41,10 +41,10 @@ import six -from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core from tensorflow.python.framework import ops +from tensorflow.python.util import function_utils # pylint: disable=protected-access dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder @@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config): ValueError: if logit_fn does not return a Tensor or a dictionary mapping strings to Tensors. 
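The argument-dispatch pattern implemented here (and the reason `function_utils.fn_args` is substituted for the estimator-local helper throughout this change) is to inspect the user-supplied function's signature and forward only the optional arguments it actually declares. A rough, self-contained sketch of the idea, using `inspect.signature` as a stand-in for the internal `fn_args` helper:

```python
import inspect


def _fn_args(fn):
  """Names of the callable's parameters (simplified stand-in for fn_args)."""
  return tuple(inspect.signature(fn).parameters)


def call_logit_fn_sketch(logit_fn, features, mode, params, config):
  """Calls logit_fn, passing only the optional args its signature declares."""
  kwargs = {}
  fn_args = _fn_args(logit_fn)
  if 'mode' in fn_args:
    kwargs['mode'] = mode
  if 'params' in fn_args:
    kwargs['params'] = params
  if 'config' in fn_args:
    kwargs['config'] = config
  return logit_fn(features=features, **kwargs)


# A logit_fn that only accepts `features` and `mode` still works:
def my_logit_fn(features, mode):
  return [f * (2.0 if mode == 'train' else 1.0) for f in features]


print(call_logit_fn_sketch(my_logit_fn, features=[1.0, 2.0], mode='train',
                           params=None, config=None))  # [2.0, 4.0]
```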
""" - logit_fn_args = util.fn_args(logit_fn) + logit_fn_args = function_utils.fn_args(logit_fn) kwargs = {} if 'mode' in logit_fn_args: kwargs['mode'] = mode diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index f8564446e5da3e..cda23aa437f954 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -32,7 +32,6 @@ from tensorflow.core.framework import node_def_pb2 from tensorflow.python.client import device_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import ops as ops_lib @@ -48,6 +47,7 @@ from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils @deprecation.deprecated( @@ -521,7 +521,7 @@ def _get_loss_towers(model_fn, """Replicate the loss computation across devices.""" tower_specs = [] - model_fn_args = util.fn_args(model_fn) + model_fn_args = function_utils.fn_args(model_fn) optional_params = {} if 'params' in model_fn_args: optional_params['params'] = copy.deepcopy(params) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index b475c12f5af3ae..7c49cd00d16777 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -229,6 +229,7 @@ def rnn_logit_fn(features, mode): rnn_outputs, _ = rnn.dynamic_rnn( cell=cell, inputs=sequence_input, + sequence_length=sequence_length, dtype=dtypes.float32, time_major=False) last_activations = _select_last_activations(rnn_outputs, sequence_length) @@ -328,6 +329,19 @@ def _train_op_fn(loss): logits=logits) +def _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type): + """Assert arguments are valid and return rnn_cell_fn.""" + if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): + raise ValueError( + 'num_units and cell_type must not be specified when using rnn_cell_fn' + ) + if not rnn_cell_fn: + if cell_type == USE_DEFAULT: + cell_type = 'basic_rnn' + rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + return rnn_cell_fn + + class RNNClassifier(estimator.Estimator): """A classifier for TensorFlow RNN models. @@ -341,8 +355,8 @@ class RNNClassifier(estimator.Estimator): token_emb = embedding_column(categorical_column=token_sequence, ...) estimator = RNNClassifier( - num_units=[32, 16], cell_type='lstm', - sequence_feature_columns=[token_emb]) + sequence_feature_columns=[token_emb], + num_units=[32, 16], cell_type='lstm') # Input builders def input_fn_train: # returns x, y @@ -438,8 +452,8 @@ def __init__(self, encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` or string specifying optimizer + type. Defaults to Adagrad optimizer. input_layer_partitioner: Optional. Partitioner for input layer. 
Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. @@ -448,14 +462,7 @@ def __init__(self, ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not compatible. """ - if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): - raise ValueError( - 'num_units and cell_type must not be specified when using rnn_cell_fn' - ) - if not rnn_cell_fn: - if cell_type == USE_DEFAULT: - cell_type = 'basic_rnn' - rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + rnn_cell_fn = _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type) if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access @@ -479,3 +486,137 @@ def _model_fn(features, labels, mode, config): config=config) super(RNNClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +class RNNEstimator(estimator.Estimator): + """An Estimator for TensorFlow RNN models with user-specified head. + + Example: + + ```python + token_sequence = sequence_categorical_column_with_hash_bucket(...) + token_emb = embedding_column(categorical_column=token_sequence, ...) + + estimator = RNNEstimator( + head=tf.contrib.estimator.regression_head(), + sequence_feature_columns=[token_emb], + num_units=[32, 16], cell_type='lstm') + + # Or with custom RNN cell: + def rnn_cell_fn(mode): + cells = [ tf.contrib.rnn.LSTMCell(size) for size in [32, 16] ] + if mode == tf.estimator.ModeKeys.TRAIN: + cells = [ tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=0.5) + for cell in cells ] + return tf.contrib.rnn.MultiRNNCell(cells) + + estimator = RNNEstimator( + head=tf.contrib.estimator.regression_head(), + sequence_feature_columns=[token_emb], + rnn_cell_fn=rnn_cell_fn) + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if the head's `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `sequence_feature_columns`: + - a feature with `key=column.name` whose `value` is a `SparseTensor`. + * for each `column` in `context_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss and predicted output are determined by the specified head. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + sequence_feature_columns, + context_feature_columns=None, + num_units=None, + cell_type=USE_DEFAULT, + rnn_cell_fn=None, + model_dir=None, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Initializes a `RNNClassifier` instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. 
This specifies the model's + output and loss function to be optimized. + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. All items in the set should either be + sequence columns (e.g. `sequence_numeric_column`) or constructed from + one (e.g. `embedding_column` with `sequence_categorical_column_*` as + input). + context_feature_columns: An iterable containing the `FeatureColumn`s + for contextual input. The data represented by these columns will be + replicated and given to the RNN at each timestep. These columns must be + instances of classes derived from `_DenseColumn` such as + `numeric_column`, not the sequential variants. + num_units: Iterable of integer number of hidden units per RNN layer. If + set, `cell_type` must also be specified and `rnn_cell_fn` must be + `None`. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn` + must be `None`. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell` that will be used to + construct the RNN. If set, `num_units` and `cell_type` cannot be set. + This is for advanced users who need additional customization beyond + `num_units` and `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is + needed for stacked RNNs. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + optimizer: An instance of `tf.Optimizer` or string specifying optimizer + type. Defaults to Adagrad optimizer. + input_layer_partitioner: Optional. Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not + compatible. 
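To make the feature expectations above concrete, a minimal hypothetical `input_fn` for a single sequence column keyed `'tokens'` could look like the sketch below (the names, shapes, and label values are made up; the label shape ultimately depends on the supplied head):

```python
import tensorflow as tf


def _input_fn():
  # Sequence features are fed as SparseTensors: one row per example,
  # one column per timestep.
  tokens = tf.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],  # (example_index, timestep)
      values=['the', 'cat', 'dog'],
      dense_shape=[2, 2])                # batch of 2, max sequence length 2
  labels = tf.constant([[1.0], [0.0]])   # shape dictated by the chosen head
  return {'tokens': tokens}, labels
```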
+ """ + rnn_cell_fn = _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type) + + def _model_fn(features, labels, mode, config): + return _rnn_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=tuple(sequence_feature_columns or []), + context_feature_columns=tuple(context_feature_columns or []), + optimizer=optimizer, + input_layer_partitioner=input_layer_partitioner, + config=config) + super(RNNEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py index 393f94f5c7de02..959b40371aa5fa 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -25,12 +25,15 @@ import numpy as np import six +from tensorflow.contrib.data.python.ops import readers +from tensorflow.contrib.estimator.python.estimator import head as head_lib from tensorflow.contrib.estimator.python.estimator import rnn from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.canned import parsing_utils from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io @@ -38,9 +41,9 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.lib.io import python_io from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import state_ops @@ -50,7 +53,6 @@ from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import input as input_lib from tensorflow.python.training import monitored_session from tensorflow.python.training import optimizer from tensorflow.python.training import training_util @@ -984,7 +986,10 @@ def predict_input_fn(): predictions[prediction_keys.PredictionKeys.CLASSES]) -class RNNClassifierIntegrationTest(test.TestCase): +class BaseRNNClassificationIntegrationTest(object): + + def __init__(self, _create_estimator_fn): + self._create_estimator_fn = _create_estimator_fn def setUp(self): self._model_dir = tempfile.mkdtemp() @@ -994,20 +999,11 @@ def tearDown(self): writer_cache.FileWriterCache.clear() shutil.rmtree(self._model_dir) - def _test_complete_flow( - self, train_input_fn, eval_input_fn, predict_input_fn, n_classes, - batch_size): - col = seq_fc.sequence_categorical_column_with_hash_bucket( - 'tokens', hash_bucket_size=10) - embed = fc.embedding_column(col, dimension=2) - feature_columns = [embed] - + def _test_complete_flow(self, feature_columns, train_input_fn, eval_input_fn, + predict_input_fn, n_classes, batch_size): cell_units = [4, 2] - est = rnn.RNNClassifier( - num_units=cell_units, - sequence_feature_columns=feature_columns, - n_classes=n_classes, - 
model_dir=self._model_dir) + est = self._create_estimator_fn(feature_columns, n_classes, cell_units, + self._model_dir) # TRAIN num_steps = 10 @@ -1026,10 +1022,10 @@ def _test_complete_flow( self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) # EXPORT - feature_spec = { - 'tokens': parsing_ops.VarLenFeature(dtypes.string), - 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), - } + feature_spec = parsing_utils.classifier_parse_example_spec( + feature_columns, + label_key='label', + label_dtype=dtypes.int64) serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( feature_spec) export_dir = est.export_savedmodel(tempfile.mkdtemp(), @@ -1069,7 +1065,13 @@ def testNumpyInputFn(self): batch_size=batch_size, shuffle=False) + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + self._test_complete_flow( + feature_columns=feature_columns, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn, predict_input_fn=predict_input_fn, @@ -1082,7 +1084,8 @@ def testParseExampleInputFn(self): batch_size = 10 words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept'] - serialized_examples = [] + _, examples_file = tempfile.mkstemp() + writer = python_io.TFRecordWriter(examples_file) for _ in range(batch_size): sequence_length = random.randint(1, len(words)) sentence = random.sample(words, sequence_length) @@ -1096,30 +1099,36 @@ def testParseExampleInputFn(self): feature_pb2.Feature(int64_list=feature_pb2.Int64List( value=[label])), })) - serialized_examples.append(example.SerializeToString()) + writer.write(example.SerializeToString()) + writer.close() + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + feature_spec = parsing_utils.classifier_parse_example_spec( + feature_columns, + label_key='label', + label_dtype=dtypes.int64) - feature_spec = { - 'tokens': parsing_ops.VarLenFeature(dtypes.string), - 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), - } def _train_input_fn(): - features = parsing_ops.parse_example(serialized_examples, feature_spec) - labels = features.pop('label') - return features, labels + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec) + return dataset.map(lambda features: (features, features.pop('label'))) def _eval_input_fn(): - features = parsing_ops.parse_example( - input_lib.limit_epochs(serialized_examples, num_epochs=1), - feature_spec) - labels = features.pop('label') - return features, labels + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec, num_epochs=1) + return dataset.map(lambda features: (features, features.pop('label'))) def _predict_input_fn(): - features = parsing_ops.parse_example( - input_lib.limit_epochs(serialized_examples, num_epochs=1), - feature_spec) - features.pop('label') - return features, None + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec, num_epochs=1) + def features_fn(features): + features.pop('label') + return features + return dataset.map(features_fn) self._test_complete_flow( + feature_columns=feature_columns, train_input_fn=_train_input_fn, eval_input_fn=_eval_input_fn, predict_input_fn=_predict_input_fn, @@ -1127,5 +1136,37 @@ def _predict_input_fn(): batch_size=batch_size) +def _rnn_classifier_fn(feature_columns, 
n_classes, cell_units, model_dir): + return rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=feature_columns, + n_classes=n_classes, + model_dir=model_dir) + + +class RNNClassifierIntegrationTest(BaseRNNClassificationIntegrationTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + BaseRNNClassificationIntegrationTest.__init__(self, _rnn_classifier_fn) + + +def _rnn_estimator_fn(feature_columns, n_classes, cell_units, model_dir): + return rnn.RNNEstimator( + head=head_lib.multi_class_head(n_classes=n_classes), + num_units=cell_units, + sequence_feature_columns=feature_columns, + model_dir=model_dir) + + +class RNNEstimatorIntegrationTest(BaseRNNClassificationIntegrationTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + BaseRNNClassificationIntegrationTest.__init__(self, _rnn_estimator_fn) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 8baf0f2f1d07ca..4c0419b9a7623c 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -215,7 +215,7 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:sparse_tensor", ], - tags = ["noasan"], # times out b/78588193 + shard_count = 4, ) # Estimators tests diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 811fa89bc38c61..8f73274c2a0ebb 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -107,7 +107,7 @@ class WALSModel(object): # the prep_gramian_op for row(column) can be run. worker_init_op = model.worker_init - # To be run once per integration sweep before the row(column) update + # To be run once per iteration sweep before the row(column) update # initialize ops can be run. Note that in the distributed training # situations, this should only be run by the chief trainer. All other # trainers need to block until this is done. @@ -197,7 +197,8 @@ def __init__(self, row_weights=1, col_weights=1, use_factors_weights_cache=True, - use_gramian_cache=True): + use_gramian_cache=True, + use_scoped_vars=False): """Creates model for WALS matrix factorization. Args: @@ -239,6 +240,8 @@ def __init__(self, weights cache to take effect. use_gramian_cache: When True, the Gramians will be cached on the workers before the updates start. Defaults to True. + use_scoped_vars: When True, the factor and weight vars will also be nested + in a tf.name_scope. 
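The new `use_scoped_vars` flag described above only changes how the graph is organized: variable construction is wrapped in `tf.name_scope` blocks so the factor and weight nodes group together (for example in TensorBoard), with no change to the math. A minimal sketch of the pattern, with made-up variable names and shapes:

```python
import tensorflow as tf

use_scoped_vars = True  # mirrors the new constructor argument

def _create_factors():
  # Stand-ins for the sharded factor/weight variables the model creates.
  row_factors = tf.Variable(tf.zeros([100, 10]), name='row_factors')
  col_factors = tf.Variable(tf.zeros([50, 10]), name='col_factors')
  return row_factors, col_factors

if use_scoped_vars:
  # Nodes are created as "factors/row_factors", "factors/col_factors", ...
  with tf.name_scope('factors'):
    row_factors, col_factors = _create_factors()
else:
  row_factors, col_factors = _create_factors()
```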
""" self._input_rows = input_rows self._input_cols = input_cols @@ -251,25 +254,46 @@ def __init__(self, regularization * linalg_ops.eye(self._n_components) if regularization is not None else None) assert (row_weights is None) == (col_weights is None) - self._row_weights = WALSModel._create_weights( - row_weights, self._input_rows, self._num_row_shards, "row_weights") - self._col_weights = WALSModel._create_weights( - col_weights, self._input_cols, self._num_col_shards, "col_weights") self._use_factors_weights_cache = use_factors_weights_cache self._use_gramian_cache = use_gramian_cache - self._row_factors = self._create_factors( - self._input_rows, self._n_components, self._num_row_shards, row_init, - "row_factors") - self._col_factors = self._create_factors( - self._input_cols, self._n_components, self._num_col_shards, col_init, - "col_factors") + + if use_scoped_vars: + with ops.name_scope("row_weights"): + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + with ops.name_scope("col_weights"): + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + with ops.name_scope("row_factors"): + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, + row_init, "row_factors") + with ops.name_scope("col_factors"): + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, + col_init, "col_factors") + else: + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, row_init, + "row_factors") + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, col_init, + "col_factors") + self._row_gramian = self._create_gramian(self._n_components, "row_gramian") self._col_gramian = self._create_gramian(self._n_components, "col_gramian") - self._row_update_prep_gramian = self._prepare_gramian( - self._col_factors, self._col_gramian) - self._col_update_prep_gramian = self._prepare_gramian( - self._row_factors, self._row_gramian) - self._create_transient_vars() + with ops.name_scope("row_prepare_gramian"): + self._row_update_prep_gramian = self._prepare_gramian( + self._col_factors, self._col_gramian) + with ops.name_scope("col_prepare_gramian"): + self._col_update_prep_gramian = self._prepare_gramian( + self._row_factors, self._row_gramian) + with ops.name_scope("transient_vars"): + self._create_transient_vars() @property def row_factors(self): @@ -436,7 +460,7 @@ def _prepare_gramian(self, factors, gramian): gramian: Variable storing the gramian calculated from the factors. Returns: - A op that updates the gramian with the calculated value from the factors. + An op that updates the gramian with the calculated value from the factors. 
""" partial_gramians = [] for f in factors: diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 555beddeaab419..b588f75efe9d0b 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -346,7 +346,8 @@ def sequence_numeric_column( key, shape=(1,), default_value=0., - dtype=dtypes.float32): + dtype=dtypes.float32, + normalizer_fn=None): """Returns a feature column that represents sequences of numeric data. Example: @@ -370,6 +371,12 @@ def sequence_numeric_column( default_value: A single value compatible with `dtype` that is used for padding the sparse data into a dense `Tensor`. dtype: The type of values. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of Tensorflow transformations. Returns: A `_SequenceNumericColumn`. @@ -383,12 +390,16 @@ def sequence_numeric_column( if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError( + 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) return _SequenceNumericColumn( key, shape=shape, default_value=default_value, - dtype=dtype) + dtype=dtype, + normalizer_fn=normalizer_fn) def _assert_all_equal_and_return(tensors, name=None): @@ -407,7 +418,7 @@ class _SequenceNumericColumn( fc._SequenceDenseColumn, collections.namedtuple( '_SequenceNumericColumn', - ['key', 'shape', 'default_value', 'dtype'])): + ['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])): """Represents sequences of numeric data.""" @property @@ -419,7 +430,10 @@ def _parse_example_spec(self): return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - return inputs.get(self.key) + input_tensor = inputs.get(self.key) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return input_tensor @property def _variable_shape(self): diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 88f5d535162939..89b5f4c4137f6c 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import monitored_session @@ -670,6 +671,7 @@ def test_defaults(self): self.assertEqual((1,), a.shape) self.assertEqual(0., a.default_value) self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) def test_shape_saved_as_tuple(self): a = sfc.sequence_numeric_column('aaa', shape=[1, 
2]) @@ -688,6 +690,10 @@ def test_dtype_is_convertible_to_float(self): ValueError, 'dtype must be convertible to float'): sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') + def test_get_sequence_dense_tensor(self): sparse_input = sparse_tensor.SparseTensorValue( # example 0, values [[0.], [1]] @@ -708,6 +714,41 @@ def test_get_sequence_dense_tensor(self): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_normalizer_fn(self): + + def _increment_two(input_sparse_tensor): + return sparse_ops.sparse_add( + input_sparse_tensor, + sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2)) + ) + + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + + # Before _increment_two: + # [[0.], [1.]], + # [[10.], [0.]], + # After _increment_two: + # [[2.], [1.]], + # [[10.], [2.]], + expected_dense_tensor = [ + [[2.], [1.]], + [[10.], [2.]], + ] + numeric_column = sfc.sequence_numeric_column( + 'aaa', normalizer_fn=_increment_two) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_shape(self): """Tests get_sequence_dense_tensor with shape !=(1,).""" sparse_input = sparse_tensor.SparseTensorValue( diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index daba965a98893b..484ffee3e7afe5 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -28,7 +28,6 @@ from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio -from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h index a8d5a0dd83fb50..bf2aa75545813f 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h +++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h @@ -53,7 +53,7 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, int32 samples_per_second, int32 channel_count, const std::vector& samples, string* output_data); -// Reads an video file using ffmpeg adn converts it into a RGB24 in uint8 +// Reads an video file using ffmpeg and converts it into a RGB24 in uint8 // [frames, height, width, 3]. The w, h, and frames are obtained from ffmpeg. 
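Outside of the test above, typical use of the new `normalizer_fn` argument looks like the following sketch. The feature key is made up, and the lambda mirrors the example given in the docstring; the function is applied to the parsed input tensor before the column produces its dense output.

```python
from tensorflow.contrib.feature_column.python.feature_column import (
    sequence_feature_column as sfc)

# default_value is applied during parsing first, then normalizer_fn runs on
# the parsed tensor before the column emits its dense representation.
price_history = sfc.sequence_numeric_column(
    'price_history',
    default_value=0.,
    normalizer_fn=lambda x: (x - 3.0) / 4.2)
```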
Status ReadVideoFile(const string& filename, std::vector* output_data, uint32* width, uint32* height, uint32* frames); diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 020b5c99c61019..b1b5126d9e9e51 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -21,7 +21,6 @@ from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py -from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 10d1ecc738de67..dc49383c5c300e 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -119,14 +119,13 @@ from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec -from tensorflow.python.ops.array_ops import broadcast_to from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['nest', 'broadcast_to'] +_allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', 'is_sequence', diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 8fc4f60492b0bf..af1b404cb51bf5 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -78,7 +78,6 @@ def test_assert_scalar_int(self): [3, 4], dtype=dtypes.int32)) -@test_util.with_c_api class WithShapeTest(test.TestCase): def _assert_with_shape(self, tensor, expected_value, expected_shape, @@ -216,25 +215,18 @@ def test_with_shape_partial(self): tensor_partial_shape.set_shape([None, 2]) for incompatible_shape in [[0], [1]]: - if ops._USE_C_API: - error_message = "Shapes must be equal rank, but are 2 and 1" - else: - error_message = r"Shapes \(\?, 2\) and \([01],\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, "Shapes must be equal rank, but are 2 and 1", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[1, 2, 1]]: self.assertRaisesRegexp(ValueError, "Dimensions must be equal", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[2, 1]]: - if ops._USE_C_API: - error_message = (r"Dimension 1 in both shapes must be equal, but are " - r"2 and 1. Shapes are \[\?,2\] and \[2,1\].") - else: - error_message = r"Shapes \(\?, 2\) and \(2, 1\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, + r"Dimension 1 in both shapes must be equal, but are 2 and 1. 
" + r"Shapes are \[\?,2\] and \[2,1\].", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) compatible_shape = [2, 2] diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index bd764ed57a6da0..72835c3ad86e63 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -202,7 +202,7 @@ def execute(self, fn, *args, **kwargs): or lazy way that may cause a deadlock. ValueError: If `exclusive_resource_access` is not provided (is `True`) and another `CriticalSection` has an execution requesting the same - resources as in `*args`, `**kwargs`, and any additionaly captured + resources as in `*args`, `**kwargs`, and any additionally captured inputs in `fn`. Note, even if `exclusive_resource_access` is `True`, if another execution in another `CriticalSection` was created without `exclusive_resource_access=True`, a `ValueError` will be raised. diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index e1ac5d77139786..6f1e4d3626140a 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -75,6 +75,7 @@ tf_kernel_library( "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "//third_party/eigen3", + "@local_config_cuda//cuda:cudnn_header", ], alwayslink = 1, ) @@ -94,6 +95,7 @@ tf_custom_op_library( "//tensorflow/core/kernels:conv_ops_gpu_hdrs", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", + "@local_config_cuda//cuda:cudnn_header", ], ) diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 3d0ed899322c26..4d62ac65ff619f 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -289,8 +289,8 @@ def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides, conv = tensors[i] value = values[i] ref_value = ref_values[i] - print("expected = ", ref_value) - print("actual = ", value) + tf_logging.info("expected = ", ref_value) + tf_logging.info("actual = ", value) tol = 1e-5 if value.dtype == np.float16: tol = 1e-3 @@ -831,7 +831,8 @@ def runTest(self, test_param): vertical_stride, padding_type) output_width = CalculateConvolvedOutputDim(input_width, filter_width, horizontal_stride, padding_type) - print("output_height=", output_height, ", output_width=", output_width) + tf_logging.info("output_height=", output_height, ", output_width=", + output_width) side_input, _, _ = gen_array_ops.quantize_v2( random_ops.random_uniform( @@ -866,8 +867,8 @@ def runTest(self, test_param): with self.test_session(use_gpu=True) as sess: actual_y, expected_y = sess.run([actual, expected]) - print("actual_y = ", actual_y) - print("expected_y = ", expected_y) + tf_logging.info("actual_y = ", actual_y) + tf_logging.info("expected_y = ", expected_y) self.assertTrue(np.array_equal(actual_y, expected_y)) def testFusedConvInt8(self): diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index e3fc6bf0f03405..4092b320042162 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ 
b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -112,6 +112,7 @@ def __init__(self, generator_optimizer=None, discriminator_optimizer=None, get_hooks_fn=None, + get_eval_metric_ops_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -146,6 +147,9 @@ def __init__(self, list of hooks. These hooks are run on the generator and discriminator train ops, and can be used to implement the GAN training scheme. Defaults to `train.get_sequential_train_hooks()`. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. @@ -160,7 +164,8 @@ def _model_fn(features, labels, mode): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries, get_hooks_fn=get_hooks_fn) + use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 387a62bd741bd4..955482599b372b 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -194,6 +195,12 @@ def make_opt(): lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) return training.GradientDescentOptimizer(lr) + def get_metrics(gan_model): + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + gan_model.real_data, gan_model.generated_data) + } + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) est = estimator.GANEstimator( @@ -203,6 +210,7 @@ def make_opt(): discriminator_loss_fn=losses.wasserstein_discriminator_loss, generator_optimizer=gopt, discriminator_optimizer=dopt, + get_eval_metric_ops_fn=get_metrics, model_dir=self._model_dir) # TRAIN @@ -213,6 +221,9 @@ def make_opt(): scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', six.iterkeys(scores)) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index a21358c50bbdb4..ff903a78cc36c1 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -25,17 +25,21 @@ from tensorflow.python.estimator import model_fn as model_fn_lib from 
tensorflow.python.estimator.canned import head from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib __all__ = [ 'GANHead', 'gan_head', ] +def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=tfgan_train.get_sequential_train_hooks(), - name=None): + get_eval_metric_ops_fn=None, name=None): """Creates a `GANHead`. Args: @@ -47,9 +51,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + If `None`, uses defaults. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. @@ -62,6 +69,7 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer=discriminator_optimizer, use_loss_summaries=use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn, name=name) @@ -72,6 +80,7 @@ def __init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=None, + get_eval_metric_ops_fn=None, name=None): """`Head` for GAN training. @@ -85,8 +94,11 @@ def __init__(self, generator_loss_fn, discriminator_loss_fn, discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. Defaults to `train.get_sequential_train_hooks()` + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. Defaults to `train.get_sequential_train_hooks()` + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. 
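A short sketch of the `get_eval_metric_ops_fn` hook documented above, modeled on the metric used in the GANEstimator test earlier in this patch; the metric name is arbitrary.

```python
from tensorflow.python.ops import metrics as metrics_lib

def get_eval_metrics(gan_model):
  # Called with the GANModel at eval time; must return a dict mapping metric
  # names to (value, update_op) pairs, which are merged with the default
  # generator/discriminator loss metrics.
  return {
      'mse_custom_metric':
          metrics_lib.mean_squared_error(gan_model.real_data,
                                         gan_model.generated_data),
  }

# The function is then passed as
#   gan_head(..., get_eval_metric_ops_fn=get_eval_metrics)
# or GANEstimator(..., get_eval_metric_ops_fn=get_eval_metrics),
# and the resulting keys appear in the dict returned by Estimator.evaluate().
```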
""" @@ -104,6 +116,8 @@ def __init__(self, generator_loss_fn, discriminator_loss_fn, self._generator_optimizer = generator_optimizer self._discriminator_optimizer = discriminator_optimizer self._get_hooks_fn = get_hooks_fn + self._get_eval_metric_ops_fn = get_eval_metric_ops_fn + self._name = name @property def name(self): @@ -173,13 +187,26 @@ def create_estimator_spec( gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + eval_metric_ops = { + _summary_key(self._name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(self._name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if self._get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.EVAL, predictions=gan_model.generated_data, loss=scalar_loss, - # TODO(joelshor): Add metrics. If head name provided, append it to - # metric keys. - eval_metric_ops={}) + eval_metric_ops=eval_metric_ops) elif mode == model_fn_lib.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None.') diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8168f005cd1105..6587f1fc600b94 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -62,9 +62,14 @@ def setUp(self): generator_loss_fn=dummy_loss, discriminator_loss_fn=dummy_loss, generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0)) + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + get_eval_metric_ops_fn=self.get_metrics) self.assertTrue(isinstance(self.gan_head, head.GANHead)) + def get_metrics(self, gan_model): + self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) + return {} + def _test_modes_helper(self, mode): self.gan_head.create_estimator_spec( features=None, diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py index df71187fbd98c8..a9b8faa7126253 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Miscellanous utilities for TFGAN code and examples.""" +"""Miscellaneous utilities for TFGAN code and examples.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 2889e937436d2f..9f5fee45422e0b 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -570,7 +570,7 @@ def setUp(self): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mutual_information_loss/mul' + self._expected_op_name = 'mutual_information_loss/mul_1' self._batch_size = 2 diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index a320a3f232fc1d..026a3d12000334 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -189,9 +189,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): if op._original_op: op_._original_op = op._original_op - # Add op to the graph - info.graph_._add_op(op_) - return op_, op_.outputs @@ -492,7 +489,7 @@ def _finalize_cycles(self, info): t_ = info.transformed_ts[t] consumer_op_ = info.transformed_ops[consumer_op] t_index_ = list(consumer_op_.inputs).index(tmp_t_) - consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access def _connect_control_inputs(self, info): """Connect the previously copied ops.""" @@ -677,7 +674,7 @@ def replace_t_with_replacement_handler(info, t): def _add_control_flow_ops(ops, control_ios): - """Complete `ops` so that the tranformed graph is valid. + """Complete `ops` so that the transformed graph is valid. Partially copying a graph can lead to a malformed graph. For instance, copying half of a while construct is likely to result in an invalid graph. diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c index 6a5d982dc8514d..2e5c84704f8464 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c +++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c @@ -19,7 +19,7 @@ limitations under the License. 
#include "hexagon_controller.h" -#include +#include #include #include "adspmsgd.h" diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 60281951dda940..66939fbb0f0d3b 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -115,7 +115,7 @@ static void CheckOpsSupport(const GraphDef& graph_def, HexagonOpsDefinitions::getInstance(); LOG(INFO) << "Checking " << graph_def.node_size() << " nodes"; LOG(INFO) << "dump_all_nodes = " << dump_all_nodes - << ", dump_shape_and_tpye = " << dump_shape_and_type; + << ", dump_shape_and_type = " << dump_shape_and_type; std::unordered_set unsupported_ops; bool all_supported = true; diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index 8f406ace1d5dcc..f230d93da4a9c0 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -17,7 +17,7 @@ ### API This module provides functions for image manipulation; currently, chrominance -transformas (including changing saturation and hue) in YIQ space and +transforms (including changing saturation and hue) in YIQ space and projective transforms (including rotation) are supported. ## Image Transformation `Ops` diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index a4cd4a2cc4b99b..2638b25ec424b5 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -64,7 +64,7 @@ class KafkaDatasetOp : public DatasetOpKernel { eof_(eof), timeout_(timeout) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Kafka")})); @@ -81,7 +81,7 @@ class KafkaDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "KafkaDatasetOp::Dataset"; } + string DebugString() const override { return "KafkaDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/contrib/keras/api/keras/activations/__init__.py b/tensorflow/contrib/keras/api/keras/activations/__init__.py index d04838c218d664..3f0184276f6b90 100644 --- a/tensorflow/contrib/keras/api/keras/activations/__init__.py +++ b/tensorflow/contrib/keras/api/keras/activations/__init__.py @@ -19,22 +19,22 @@ from __future__ import print_function # Activation functions. 
-from tensorflow.python.keras._impl.keras.activations import elu -from tensorflow.python.keras._impl.keras.activations import hard_sigmoid -from tensorflow.python.keras._impl.keras.activations import linear -from tensorflow.python.keras._impl.keras.activations import relu -from tensorflow.python.keras._impl.keras.activations import selu -from tensorflow.python.keras._impl.keras.activations import sigmoid -from tensorflow.python.keras._impl.keras.activations import softmax -from tensorflow.python.keras._impl.keras.activations import softplus -from tensorflow.python.keras._impl.keras.activations import softsign -from tensorflow.python.keras._impl.keras.activations import tanh +from tensorflow.python.keras.activations import elu +from tensorflow.python.keras.activations import hard_sigmoid +from tensorflow.python.keras.activations import linear +from tensorflow.python.keras.activations import relu +from tensorflow.python.keras.activations import selu +from tensorflow.python.keras.activations import sigmoid +from tensorflow.python.keras.activations import softmax +from tensorflow.python.keras.activations import softplus +from tensorflow.python.keras.activations import softsign +from tensorflow.python.keras.activations import tanh # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.activations import deserialize -from tensorflow.python.keras._impl.keras.activations import serialize -from tensorflow.python.keras._impl.keras.activations import get +from tensorflow.python.keras.activations import deserialize +from tensorflow.python.keras.activations import serialize +from tensorflow.python.keras.activations import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py index abf8393ae45d71..6dfb5cab17c088 100644 --- a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py @@ -18,9 +18,9 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras.applications.inception_v3 import decode_predictions +from tensorflow.python.keras.applications.inception_v3 import InceptionV3 +from tensorflow.python.keras.applications.inception_v3 import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py index b809e91193b459..67306cc51e1927 100644 --- a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py @@ -18,9 +18,9 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input +from tensorflow.python.keras.applications.mobilenet import decode_predictions +from tensorflow.python.keras.applications.mobilenet import 
MobileNet +from tensorflow.python.keras.applications.mobilenet import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py index 530805d150bfe3..a25ff48b593a9a 100644 --- a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py @@ -18,9 +18,9 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 +from tensorflow.python.keras.applications.resnet50 import decode_predictions +from tensorflow.python.keras.applications.resnet50 import preprocess_input +from tensorflow.python.keras.applications.resnet50 import ResNet50 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py index 118361604bbc7e..4964b1b7deb56f 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py @@ -18,9 +18,9 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 +from tensorflow.python.keras.applications.vgg16 import decode_predictions +from tensorflow.python.keras.applications.vgg16 import preprocess_input +from tensorflow.python.keras.applications.vgg16 import VGG16 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py index cda52628f3c10d..afb3abebdd6735 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py @@ -18,9 +18,9 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 +from tensorflow.python.keras.applications.vgg19 import decode_predictions +from tensorflow.python.keras.applications.vgg19 import preprocess_input +from tensorflow.python.keras.applications.vgg19 import VGG19 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py index ae9cd9cd18c5cc..2e3335d02aff0f 100644 --- a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py @@ -18,9 +18,9 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions -from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input -from 
tensorflow.python.keras._impl.keras.applications.xception import Xception +from tensorflow.python.keras.applications.xception import decode_predictions +from tensorflow.python.keras.applications.xception import preprocess_input +from tensorflow.python.keras.applications.xception import Xception del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/backend/__init__.py b/tensorflow/contrib/keras/api/keras/backend/__init__.py index 10ef5a75852deb..a755364014206e 100644 --- a/tensorflow/contrib/keras/api/keras/backend/__init__.py +++ b/tensorflow/contrib/keras/api/keras/backend/__init__.py @@ -19,144 +19,144 @@ from __future__ import print_function # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras.backend import abs -from tensorflow.python.keras._impl.keras.backend import all -from tensorflow.python.keras._impl.keras.backend import any -from tensorflow.python.keras._impl.keras.backend import arange -from tensorflow.python.keras._impl.keras.backend import argmax -from tensorflow.python.keras._impl.keras.backend import argmin -from tensorflow.python.keras._impl.keras.backend import backend -from tensorflow.python.keras._impl.keras.backend import batch_dot -from tensorflow.python.keras._impl.keras.backend import batch_flatten -from tensorflow.python.keras._impl.keras.backend import batch_get_value -from tensorflow.python.keras._impl.keras.backend import batch_normalization -from tensorflow.python.keras._impl.keras.backend import batch_set_value -from tensorflow.python.keras._impl.keras.backend import bias_add -from tensorflow.python.keras._impl.keras.backend import binary_crossentropy -from tensorflow.python.keras._impl.keras.backend import cast -from tensorflow.python.keras._impl.keras.backend import cast_to_floatx -from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import clear_session -from tensorflow.python.keras._impl.keras.backend import clip -from tensorflow.python.keras._impl.keras.backend import concatenate -from tensorflow.python.keras._impl.keras.backend import constant -from tensorflow.python.keras._impl.keras.backend import conv1d -from tensorflow.python.keras._impl.keras.backend import conv2d -from tensorflow.python.keras._impl.keras.backend import conv2d_transpose -from tensorflow.python.keras._impl.keras.backend import conv3d -from tensorflow.python.keras._impl.keras.backend import cos -from tensorflow.python.keras._impl.keras.backend import count_params -from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost -from tensorflow.python.keras._impl.keras.backend import ctc_decode -from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse -from tensorflow.python.keras._impl.keras.backend import dot -from tensorflow.python.keras._impl.keras.backend import dropout -from tensorflow.python.keras._impl.keras.backend import dtype -from tensorflow.python.keras._impl.keras.backend import elu -from tensorflow.python.keras._impl.keras.backend import epsilon -from tensorflow.python.keras._impl.keras.backend import equal -from tensorflow.python.keras._impl.keras.backend import eval -from tensorflow.python.keras._impl.keras.backend import exp -from tensorflow.python.keras._impl.keras.backend import expand_dims -from tensorflow.python.keras._impl.keras.backend import eye -from tensorflow.python.keras._impl.keras.backend import flatten -from tensorflow.python.keras._impl.keras.backend import floatx -from 
tensorflow.python.keras._impl.keras.backend import foldl -from tensorflow.python.keras._impl.keras.backend import foldr -from tensorflow.python.keras._impl.keras.backend import function -from tensorflow.python.keras._impl.keras.backend import gather -from tensorflow.python.keras._impl.keras.backend import get_session -from tensorflow.python.keras._impl.keras.backend import get_uid -from tensorflow.python.keras._impl.keras.backend import get_value -from tensorflow.python.keras._impl.keras.backend import gradients -from tensorflow.python.keras._impl.keras.backend import greater -from tensorflow.python.keras._impl.keras.backend import greater_equal -from tensorflow.python.keras._impl.keras.backend import hard_sigmoid -from tensorflow.python.keras._impl.keras.backend import image_data_format -from tensorflow.python.keras._impl.keras.backend import in_test_phase -from tensorflow.python.keras._impl.keras.backend import in_top_k -from tensorflow.python.keras._impl.keras.backend import in_train_phase -from tensorflow.python.keras._impl.keras.backend import int_shape -from tensorflow.python.keras._impl.keras.backend import is_sparse -from tensorflow.python.keras._impl.keras.backend import l2_normalize -from tensorflow.python.keras._impl.keras.backend import learning_phase -from tensorflow.python.keras._impl.keras.backend import less -from tensorflow.python.keras._impl.keras.backend import less_equal -from tensorflow.python.keras._impl.keras.backend import log -from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization -from tensorflow.python.keras._impl.keras.backend import map_fn -from tensorflow.python.keras._impl.keras.backend import max -from tensorflow.python.keras._impl.keras.backend import maximum -from tensorflow.python.keras._impl.keras.backend import mean -from tensorflow.python.keras._impl.keras.backend import min -from tensorflow.python.keras._impl.keras.backend import minimum -from tensorflow.python.keras._impl.keras.backend import moving_average_update -from tensorflow.python.keras._impl.keras.backend import name_scope -from tensorflow.python.keras._impl.keras.backend import ndim -from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training -from tensorflow.python.keras._impl.keras.backend import not_equal -from tensorflow.python.keras._impl.keras.backend import one_hot -from tensorflow.python.keras._impl.keras.backend import ones -from tensorflow.python.keras._impl.keras.backend import ones_like -from tensorflow.python.keras._impl.keras.backend import permute_dimensions -from tensorflow.python.keras._impl.keras.backend import placeholder -from tensorflow.python.keras._impl.keras.backend import pool2d -from tensorflow.python.keras._impl.keras.backend import pool3d -from tensorflow.python.keras._impl.keras.backend import pow -from tensorflow.python.keras._impl.keras.backend import print_tensor -from tensorflow.python.keras._impl.keras.backend import prod -from tensorflow.python.keras._impl.keras.backend import random_binomial -from tensorflow.python.keras._impl.keras.backend import random_normal -from tensorflow.python.keras._impl.keras.backend import random_normal_variable -from tensorflow.python.keras._impl.keras.backend import random_uniform -from tensorflow.python.keras._impl.keras.backend import random_uniform_variable -from tensorflow.python.keras._impl.keras.backend import relu -from tensorflow.python.keras._impl.keras.backend import repeat -from tensorflow.python.keras._impl.keras.backend import repeat_elements -from 
tensorflow.python.keras._impl.keras.backend import reset_uids -from tensorflow.python.keras._impl.keras.backend import reshape -from tensorflow.python.keras._impl.keras.backend import resize_images -from tensorflow.python.keras._impl.keras.backend import resize_volumes -from tensorflow.python.keras._impl.keras.backend import reverse -from tensorflow.python.keras._impl.keras.backend import rnn -from tensorflow.python.keras._impl.keras.backend import round -from tensorflow.python.keras._impl.keras.backend import separable_conv2d -from tensorflow.python.keras._impl.keras.backend import set_epsilon -from tensorflow.python.keras._impl.keras.backend import set_floatx -from tensorflow.python.keras._impl.keras.backend import set_image_data_format -from tensorflow.python.keras._impl.keras.backend import set_learning_phase -from tensorflow.python.keras._impl.keras.backend import set_session -from tensorflow.python.keras._impl.keras.backend import set_value -from tensorflow.python.keras._impl.keras.backend import shape -from tensorflow.python.keras._impl.keras.backend import sigmoid -from tensorflow.python.keras._impl.keras.backend import sign -from tensorflow.python.keras._impl.keras.backend import sin -from tensorflow.python.keras._impl.keras.backend import softmax -from tensorflow.python.keras._impl.keras.backend import softplus -from tensorflow.python.keras._impl.keras.backend import softsign -from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding -from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding -from tensorflow.python.keras._impl.keras.backend import sqrt -from tensorflow.python.keras._impl.keras.backend import square -from tensorflow.python.keras._impl.keras.backend import squeeze -from tensorflow.python.keras._impl.keras.backend import stack -from tensorflow.python.keras._impl.keras.backend import std -from tensorflow.python.keras._impl.keras.backend import stop_gradient -from tensorflow.python.keras._impl.keras.backend import sum -from tensorflow.python.keras._impl.keras.backend import switch -from tensorflow.python.keras._impl.keras.backend import tanh -from tensorflow.python.keras._impl.keras.backend import temporal_padding -from tensorflow.python.keras._impl.keras.backend import to_dense -from tensorflow.python.keras._impl.keras.backend import transpose -from tensorflow.python.keras._impl.keras.backend import truncated_normal -from tensorflow.python.keras._impl.keras.backend import update -from tensorflow.python.keras._impl.keras.backend import update_add -from tensorflow.python.keras._impl.keras.backend import update_sub -from tensorflow.python.keras._impl.keras.backend import var -from tensorflow.python.keras._impl.keras.backend import variable -from tensorflow.python.keras._impl.keras.backend import zeros -from tensorflow.python.keras._impl.keras.backend import zeros_like +from tensorflow.python.keras.backend import abs +from tensorflow.python.keras.backend import all +from tensorflow.python.keras.backend import any +from tensorflow.python.keras.backend import arange +from tensorflow.python.keras.backend import argmax +from tensorflow.python.keras.backend import argmin +from tensorflow.python.keras.backend import backend +from tensorflow.python.keras.backend import batch_dot +from tensorflow.python.keras.backend import batch_flatten +from tensorflow.python.keras.backend import batch_get_value +from tensorflow.python.keras.backend import 
batch_normalization +from tensorflow.python.keras.backend import batch_set_value +from tensorflow.python.keras.backend import bias_add +from tensorflow.python.keras.backend import binary_crossentropy +from tensorflow.python.keras.backend import cast +from tensorflow.python.keras.backend import cast_to_floatx +from tensorflow.python.keras.backend import categorical_crossentropy +from tensorflow.python.keras.backend import clear_session +from tensorflow.python.keras.backend import clip +from tensorflow.python.keras.backend import concatenate +from tensorflow.python.keras.backend import constant +from tensorflow.python.keras.backend import conv1d +from tensorflow.python.keras.backend import conv2d +from tensorflow.python.keras.backend import conv2d_transpose +from tensorflow.python.keras.backend import conv3d +from tensorflow.python.keras.backend import cos +from tensorflow.python.keras.backend import count_params +from tensorflow.python.keras.backend import ctc_batch_cost +from tensorflow.python.keras.backend import ctc_decode +from tensorflow.python.keras.backend import ctc_label_dense_to_sparse +from tensorflow.python.keras.backend import dot +from tensorflow.python.keras.backend import dropout +from tensorflow.python.keras.backend import dtype +from tensorflow.python.keras.backend import elu +from tensorflow.python.keras.backend import epsilon +from tensorflow.python.keras.backend import equal +from tensorflow.python.keras.backend import eval +from tensorflow.python.keras.backend import exp +from tensorflow.python.keras.backend import expand_dims +from tensorflow.python.keras.backend import eye +from tensorflow.python.keras.backend import flatten +from tensorflow.python.keras.backend import floatx +from tensorflow.python.keras.backend import foldl +from tensorflow.python.keras.backend import foldr +from tensorflow.python.keras.backend import function +from tensorflow.python.keras.backend import gather +from tensorflow.python.keras.backend import get_session +from tensorflow.python.keras.backend import get_uid +from tensorflow.python.keras.backend import get_value +from tensorflow.python.keras.backend import gradients +from tensorflow.python.keras.backend import greater +from tensorflow.python.keras.backend import greater_equal +from tensorflow.python.keras.backend import hard_sigmoid +from tensorflow.python.keras.backend import image_data_format +from tensorflow.python.keras.backend import in_test_phase +from tensorflow.python.keras.backend import in_top_k +from tensorflow.python.keras.backend import in_train_phase +from tensorflow.python.keras.backend import int_shape +from tensorflow.python.keras.backend import is_sparse +from tensorflow.python.keras.backend import l2_normalize +from tensorflow.python.keras.backend import learning_phase +from tensorflow.python.keras.backend import less +from tensorflow.python.keras.backend import less_equal +from tensorflow.python.keras.backend import log +from tensorflow.python.keras.backend import manual_variable_initialization +from tensorflow.python.keras.backend import map_fn +from tensorflow.python.keras.backend import max +from tensorflow.python.keras.backend import maximum +from tensorflow.python.keras.backend import mean +from tensorflow.python.keras.backend import min +from tensorflow.python.keras.backend import minimum +from tensorflow.python.keras.backend import moving_average_update +from tensorflow.python.keras.backend import name_scope +from tensorflow.python.keras.backend import ndim +from tensorflow.python.keras.backend import 
normalize_batch_in_training +from tensorflow.python.keras.backend import not_equal +from tensorflow.python.keras.backend import one_hot +from tensorflow.python.keras.backend import ones +from tensorflow.python.keras.backend import ones_like +from tensorflow.python.keras.backend import permute_dimensions +from tensorflow.python.keras.backend import placeholder +from tensorflow.python.keras.backend import pool2d +from tensorflow.python.keras.backend import pool3d +from tensorflow.python.keras.backend import pow +from tensorflow.python.keras.backend import print_tensor +from tensorflow.python.keras.backend import prod +from tensorflow.python.keras.backend import random_binomial +from tensorflow.python.keras.backend import random_normal +from tensorflow.python.keras.backend import random_normal_variable +from tensorflow.python.keras.backend import random_uniform +from tensorflow.python.keras.backend import random_uniform_variable +from tensorflow.python.keras.backend import relu +from tensorflow.python.keras.backend import repeat +from tensorflow.python.keras.backend import repeat_elements +from tensorflow.python.keras.backend import reset_uids +from tensorflow.python.keras.backend import reshape +from tensorflow.python.keras.backend import resize_images +from tensorflow.python.keras.backend import resize_volumes +from tensorflow.python.keras.backend import reverse +from tensorflow.python.keras.backend import rnn +from tensorflow.python.keras.backend import round +from tensorflow.python.keras.backend import separable_conv2d +from tensorflow.python.keras.backend import set_epsilon +from tensorflow.python.keras.backend import set_floatx +from tensorflow.python.keras.backend import set_image_data_format +from tensorflow.python.keras.backend import set_learning_phase +from tensorflow.python.keras.backend import set_session +from tensorflow.python.keras.backend import set_value +from tensorflow.python.keras.backend import shape +from tensorflow.python.keras.backend import sigmoid +from tensorflow.python.keras.backend import sign +from tensorflow.python.keras.backend import sin +from tensorflow.python.keras.backend import softmax +from tensorflow.python.keras.backend import softplus +from tensorflow.python.keras.backend import softsign +from tensorflow.python.keras.backend import sparse_categorical_crossentropy +from tensorflow.python.keras.backend import spatial_2d_padding +from tensorflow.python.keras.backend import spatial_3d_padding +from tensorflow.python.keras.backend import sqrt +from tensorflow.python.keras.backend import square +from tensorflow.python.keras.backend import squeeze +from tensorflow.python.keras.backend import stack +from tensorflow.python.keras.backend import std +from tensorflow.python.keras.backend import stop_gradient +from tensorflow.python.keras.backend import sum +from tensorflow.python.keras.backend import switch +from tensorflow.python.keras.backend import tanh +from tensorflow.python.keras.backend import temporal_padding +from tensorflow.python.keras.backend import to_dense +from tensorflow.python.keras.backend import transpose +from tensorflow.python.keras.backend import truncated_normal +from tensorflow.python.keras.backend import update +from tensorflow.python.keras.backend import update_add +from tensorflow.python.keras.backend import update_sub +from tensorflow.python.keras.backend import var +from tensorflow.python.keras.backend import variable +from tensorflow.python.keras.backend import zeros +from tensorflow.python.keras.backend import zeros_like del 
absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py index 2d884790ddb9cc..10e05f2969bc40 100644 --- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py +++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py @@ -18,19 +18,19 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.callbacks import BaseLogger -from tensorflow.python.keras._impl.keras.callbacks import Callback -from tensorflow.python.keras._impl.keras.callbacks import CSVLogger -from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping -from tensorflow.python.keras._impl.keras.callbacks import History -from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback -from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler -from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint -from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger -from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau -from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor -from tensorflow.python.keras._impl.keras.callbacks import TensorBoard -from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN +from tensorflow.python.keras.callbacks import BaseLogger +from tensorflow.python.keras.callbacks import Callback +from tensorflow.python.keras.callbacks import CSVLogger +from tensorflow.python.keras.callbacks import EarlyStopping +from tensorflow.python.keras.callbacks import History +from tensorflow.python.keras.callbacks import LambdaCallback +from tensorflow.python.keras.callbacks import LearningRateScheduler +from tensorflow.python.keras.callbacks import ModelCheckpoint +from tensorflow.python.keras.callbacks import ProgbarLogger +from tensorflow.python.keras.callbacks import ReduceLROnPlateau +from tensorflow.python.keras.callbacks import RemoteMonitor +from tensorflow.python.keras.callbacks import TensorBoard +from tensorflow.python.keras.callbacks import TerminateOnNaN del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/constraints/__init__.py b/tensorflow/contrib/keras/api/keras/constraints/__init__.py index 152606d8ebbcad..08debf974ec3a3 100644 --- a/tensorflow/contrib/keras/api/keras/constraints/__init__.py +++ b/tensorflow/contrib/keras/api/keras/constraints/__init__.py @@ -19,21 +19,21 @@ from __future__ import print_function # Constraints functions / callable classes. 
-from tensorflow.python.keras._impl.keras.constraints import Constraint -from tensorflow.python.keras._impl.keras.constraints import max_norm -from tensorflow.python.keras._impl.keras.constraints import MaxNorm -from tensorflow.python.keras._impl.keras.constraints import min_max_norm -from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm -from tensorflow.python.keras._impl.keras.constraints import non_neg -from tensorflow.python.keras._impl.keras.constraints import NonNeg -from tensorflow.python.keras._impl.keras.constraints import unit_norm -from tensorflow.python.keras._impl.keras.constraints import UnitNorm +from tensorflow.python.keras.constraints import Constraint +from tensorflow.python.keras.constraints import max_norm +from tensorflow.python.keras.constraints import MaxNorm +from tensorflow.python.keras.constraints import min_max_norm +from tensorflow.python.keras.constraints import MinMaxNorm +from tensorflow.python.keras.constraints import non_neg +from tensorflow.python.keras.constraints import NonNeg +from tensorflow.python.keras.constraints import unit_norm +from tensorflow.python.keras.constraints import UnitNorm # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.constraints import deserialize -from tensorflow.python.keras._impl.keras.constraints import serialize -from tensorflow.python.keras._impl.keras.constraints import get +from tensorflow.python.keras.constraints import deserialize +from tensorflow.python.keras.constraints import serialize +from tensorflow.python.keras.constraints import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py index b5371a03fd5f57..a5a6fdab445d2d 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data +from tensorflow.python.keras.datasets.boston_housing import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py index 68d3eb789ea2c4..e74e5f347df2ee 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data +from tensorflow.python.keras.datasets.cifar10 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py index ca937426733416..8f5753a6360dfb 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data +from tensorflow.python.keras.datasets.cifar100 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py 
b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py index 1c6396d2d32b88..bd6ec4b8dfb034 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index -from tensorflow.python.keras._impl.keras.datasets.imdb import load_data +from tensorflow.python.keras.datasets.imdb import get_word_index +from tensorflow.python.keras.datasets.imdb import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py index 364255f3387b59..f61145655bd5d9 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.mnist import load_data +from tensorflow.python.keras.datasets.mnist import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py index bb6791a344ad0c..ade31f4ea9c332 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index -from tensorflow.python.keras._impl.keras.datasets.reuters import load_data +from tensorflow.python.keras.datasets.reuters import get_word_index +from tensorflow.python.keras.datasets.reuters import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/initializers/__init__.py b/tensorflow/contrib/keras/api/keras/initializers/__init__.py index 6b1fcfd2d9585d..c6bdc4f0dac3f4 100644 --- a/tensorflow/contrib/keras/api/keras/initializers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/initializers/__init__.py @@ -19,30 +19,30 @@ from __future__ import print_function # Initializer functions / callable classes. 
-from tensorflow.python.keras._impl.keras.initializers import Constant -from tensorflow.python.keras._impl.keras.initializers import Identity -from tensorflow.python.keras._impl.keras.initializers import Initializer -from tensorflow.python.keras._impl.keras.initializers import Ones -from tensorflow.python.keras._impl.keras.initializers import Orthogonal -from tensorflow.python.keras._impl.keras.initializers import RandomNormal -from tensorflow.python.keras._impl.keras.initializers import RandomUniform -from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal -from tensorflow.python.keras._impl.keras.initializers import VarianceScaling -from tensorflow.python.keras._impl.keras.initializers import Zeros +from tensorflow.python.keras.initializers import Constant +from tensorflow.python.keras.initializers import Identity +from tensorflow.python.keras.initializers import Initializer +from tensorflow.python.keras.initializers import Ones +from tensorflow.python.keras.initializers import Orthogonal +from tensorflow.python.keras.initializers import RandomNormal +from tensorflow.python.keras.initializers import RandomUniform +from tensorflow.python.keras.initializers import TruncatedNormal +from tensorflow.python.keras.initializers import VarianceScaling +from tensorflow.python.keras.initializers import Zeros # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.initializers import glorot_normal -from tensorflow.python.keras._impl.keras.initializers import glorot_uniform -from tensorflow.python.keras._impl.keras.initializers import he_normal -from tensorflow.python.keras._impl.keras.initializers import he_uniform -from tensorflow.python.keras._impl.keras.initializers import lecun_normal -from tensorflow.python.keras._impl.keras.initializers import lecun_uniform +from tensorflow.python.keras.initializers import glorot_normal +from tensorflow.python.keras.initializers import glorot_uniform +from tensorflow.python.keras.initializers import he_normal +from tensorflow.python.keras.initializers import he_uniform +from tensorflow.python.keras.initializers import lecun_normal +from tensorflow.python.keras.initializers import lecun_uniform # Auxiliary utils. -from tensorflow.python.keras._impl.keras.initializers import deserialize -from tensorflow.python.keras._impl.keras.initializers import serialize -from tensorflow.python.keras._impl.keras.initializers import get +from tensorflow.python.keras.initializers import deserialize +from tensorflow.python.keras.initializers import serialize +from tensorflow.python.keras.initializers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index acf0a5e1799b7c..938c881fcbe186 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,128 +20,128 @@ # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras.engine import Input +from tensorflow.python.keras.engine import InputLayer +from tensorflow.python.keras.engine import InputSpec +from tensorflow.python.keras.engine import Layer # Advanced activations. 
-from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras.layers.advanced_activations import LeakyReLU +from tensorflow.python.keras.layers.advanced_activations import PReLU +from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU # Convolution layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D +from tensorflow.python.keras.layers.convolutional import Conv1D +from tensorflow.python.keras.layers.convolutional import Conv2D +from tensorflow.python.keras.layers.convolutional import Conv3D +from tensorflow.python.keras.layers.convolutional import Conv2DTranspose +from tensorflow.python.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras.layers.convolutional import Convolution1D +from tensorflow.python.keras.layers.convolutional import Convolution2D +from tensorflow.python.keras.layers.convolutional import Convolution3D +from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose +from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D # Image processing layers. 
-from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D +from tensorflow.python.keras.layers.convolutional import UpSampling1D +from tensorflow.python.keras.layers.convolutional import UpSampling2D +from tensorflow.python.keras.layers.convolutional import UpSampling3D +from tensorflow.python.keras.layers.convolutional import ZeroPadding1D +from tensorflow.python.keras.layers.convolutional import ZeroPadding2D +from tensorflow.python.keras.layers.convolutional import ZeroPadding3D +from tensorflow.python.keras.layers.convolutional import Cropping1D +from tensorflow.python.keras.layers.convolutional import Cropping2D +from tensorflow.python.keras.layers.convolutional import Cropping3D # Convolutional-recurrent layers. -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D +from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D # Core layers. -from tensorflow.python.keras._impl.keras.layers.core import Masking -from tensorflow.python.keras._impl.keras.layers.core import Dropout -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D -from tensorflow.python.keras._impl.keras.layers.core import Activation -from tensorflow.python.keras._impl.keras.layers.core import Reshape -from tensorflow.python.keras._impl.keras.layers.core import Permute -from tensorflow.python.keras._impl.keras.layers.core import Flatten -from tensorflow.python.keras._impl.keras.layers.core import RepeatVector -from tensorflow.python.keras._impl.keras.layers.core import Lambda -from tensorflow.python.keras._impl.keras.layers.core import Dense -from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization +from tensorflow.python.keras.layers.core import Masking +from tensorflow.python.keras.layers.core import Dropout +from tensorflow.python.keras.layers.core import SpatialDropout1D +from tensorflow.python.keras.layers.core import SpatialDropout2D +from tensorflow.python.keras.layers.core import SpatialDropout3D +from tensorflow.python.keras.layers.core import Activation +from tensorflow.python.keras.layers.core import Reshape +from tensorflow.python.keras.layers.core import Permute +from tensorflow.python.keras.layers.core import Flatten +from tensorflow.python.keras.layers.core import RepeatVector +from tensorflow.python.keras.layers.core import Lambda +from tensorflow.python.keras.layers.core import Dense +from tensorflow.python.keras.layers.core import ActivityRegularization # Embedding layers. -from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding +from tensorflow.python.keras.layers.embeddings import Embedding # Locally-connected layers. 
-from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D +from tensorflow.python.keras.layers.local import LocallyConnected1D +from tensorflow.python.keras.layers.local import LocallyConnected2D # Merge layers. -from tensorflow.python.keras._impl.keras.layers.merge import Add -from tensorflow.python.keras._impl.keras.layers.merge import Multiply -from tensorflow.python.keras._impl.keras.layers.merge import Average -from tensorflow.python.keras._impl.keras.layers.merge import Maximum -from tensorflow.python.keras._impl.keras.layers.merge import Concatenate -from tensorflow.python.keras._impl.keras.layers.merge import Dot -from tensorflow.python.keras._impl.keras.layers.merge import add -from tensorflow.python.keras._impl.keras.layers.merge import multiply -from tensorflow.python.keras._impl.keras.layers.merge import average -from tensorflow.python.keras._impl.keras.layers.merge import maximum -from tensorflow.python.keras._impl.keras.layers.merge import concatenate -from tensorflow.python.keras._impl.keras.layers.merge import dot +from tensorflow.python.keras.layers.merge import Add +from tensorflow.python.keras.layers.merge import Multiply +from tensorflow.python.keras.layers.merge import Average +from tensorflow.python.keras.layers.merge import Maximum +from tensorflow.python.keras.layers.merge import Concatenate +from tensorflow.python.keras.layers.merge import Dot +from tensorflow.python.keras.layers.merge import add +from tensorflow.python.keras.layers.merge import multiply +from tensorflow.python.keras.layers.merge import average +from tensorflow.python.keras.layers.merge import maximum +from tensorflow.python.keras.layers.merge import concatenate +from tensorflow.python.keras.layers.merge import dot # Noise layers. -from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout -from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise -from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout +from tensorflow.python.keras.layers.noise import AlphaDropout +from tensorflow.python.keras.layers.noise import GaussianNoise +from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras.layers.normalization import BatchNormalization # Pooling layers. 
-from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D +from tensorflow.python.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras.layers.pooling import AveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D # Pooling layer aliases. 
-from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D +from tensorflow.python.keras.layers.pooling import MaxPool1D +from tensorflow.python.keras.layers.pooling import MaxPool2D +from tensorflow.python.keras.layers.pooling import MaxPool3D +from tensorflow.python.keras.layers.pooling import AvgPool1D +from tensorflow.python.keras.layers.pooling import AvgPool2D +from tensorflow.python.keras.layers.pooling import AvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. -from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN -from tensorflow.python.keras._impl.keras.layers.recurrent import GRU -from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM +from tensorflow.python.keras.layers.recurrent import SimpleRNN +from tensorflow.python.keras.layers.recurrent import GRU +from tensorflow.python.keras.layers.recurrent import LSTM # Wrapper functions -from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper -from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional -from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed +from tensorflow.python.keras.layers.wrappers import Wrapper +from tensorflow.python.keras.layers.wrappers import Bidirectional +from tensorflow.python.keras.layers.wrappers import TimeDistributed del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index 66721b694f5fd5..c4476a7bbd5056 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -19,26 +19,26 @@ from __future__ import print_function # Loss functions. 
-from tensorflow.python.keras._impl.keras.losses import binary_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_hinge -from tensorflow.python.keras._impl.keras.losses import cosine_proximity -from tensorflow.python.keras._impl.keras.losses import hinge -from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.losses import logcosh -from tensorflow.python.keras._impl.keras.losses import mean_absolute_error -from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.losses import poisson -from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import squared_hinge +from tensorflow.python.keras.losses import binary_crossentropy +from tensorflow.python.keras.losses import categorical_crossentropy +from tensorflow.python.keras.losses import categorical_hinge +from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import hinge +from tensorflow.python.keras.losses import kullback_leibler_divergence +from tensorflow.python.keras.losses import logcosh +from tensorflow.python.keras.losses import mean_absolute_error +from tensorflow.python.keras.losses import mean_absolute_percentage_error +from tensorflow.python.keras.losses import mean_squared_error +from tensorflow.python.keras.losses import mean_squared_logarithmic_error +from tensorflow.python.keras.losses import poisson +from tensorflow.python.keras.losses import sparse_categorical_crossentropy +from tensorflow.python.keras.losses import squared_hinge # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.losses import deserialize -from tensorflow.python.keras._impl.keras.losses import serialize -from tensorflow.python.keras._impl.keras.losses import get +from tensorflow.python.keras.losses import deserialize +from tensorflow.python.keras.losses import serialize +from tensorflow.python.keras.losses import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 59faf037bce0f0..7317fdb52c5b79 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -19,28 +19,28 @@ from __future__ import print_function # Metrics functions. 
-from tensorflow.python.keras._impl.keras.metrics import binary_accuracy -from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy -from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import cosine_proximity -from tensorflow.python.keras._impl.keras.metrics import hinge -from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.metrics import poisson -from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import squared_hinge -from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy +from tensorflow.python.keras.metrics import binary_accuracy +from tensorflow.python.keras.metrics import binary_crossentropy +from tensorflow.python.keras.metrics import categorical_accuracy +from tensorflow.python.keras.metrics import categorical_crossentropy +from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import hinge +from tensorflow.python.keras.metrics import kullback_leibler_divergence +from tensorflow.python.keras.metrics import mean_absolute_error +from tensorflow.python.keras.metrics import mean_absolute_percentage_error +from tensorflow.python.keras.metrics import mean_squared_error +from tensorflow.python.keras.metrics import mean_squared_logarithmic_error +from tensorflow.python.keras.metrics import poisson +from tensorflow.python.keras.metrics import sparse_categorical_crossentropy +from tensorflow.python.keras.metrics import sparse_top_k_categorical_accuracy +from tensorflow.python.keras.metrics import squared_hinge +from tensorflow.python.keras.metrics import top_k_categorical_accuracy # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.metrics import deserialize -from tensorflow.python.keras._impl.keras.metrics import serialize -from tensorflow.python.keras._impl.keras.metrics import get +from tensorflow.python.keras.metrics import deserialize +from tensorflow.python.keras.metrics import serialize +from tensorflow.python.keras.metrics import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/models/__init__.py b/tensorflow/contrib/keras/api/keras/models/__init__.py index 2fb4ac0960d38f..3a196984cd88cb 100644 --- a/tensorflow/contrib/keras/api/keras/models/__init__.py +++ b/tensorflow/contrib/keras/api/keras/models/__init__.py @@ -18,13 +18,13 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.models import load_model -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.models import model_from_config -from tensorflow.python.keras._impl.keras.models import model_from_json -from tensorflow.python.keras._impl.keras.models import model_from_yaml -from tensorflow.python.keras._impl.keras.models import save_model -from tensorflow.python.keras._impl.keras.models import Sequential +from tensorflow.python.keras.models import load_model +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.models import model_from_config +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.keras.models import model_from_yaml +from tensorflow.python.keras.models import save_model +from tensorflow.python.keras.models import Sequential del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py index 44f47bc47f4a0e..4849a06747958a 100644 --- a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py @@ -19,20 +19,20 @@ from __future__ import print_function # Optimizer classes. -from tensorflow.python.keras._impl.keras.optimizers import Adadelta -from tensorflow.python.keras._impl.keras.optimizers import Adagrad -from tensorflow.python.keras._impl.keras.optimizers import Adam -from tensorflow.python.keras._impl.keras.optimizers import Adamax -from tensorflow.python.keras._impl.keras.optimizers import Nadam -from tensorflow.python.keras._impl.keras.optimizers import Optimizer -from tensorflow.python.keras._impl.keras.optimizers import RMSprop -from tensorflow.python.keras._impl.keras.optimizers import SGD +from tensorflow.python.keras.optimizers import Adadelta +from tensorflow.python.keras.optimizers import Adagrad +from tensorflow.python.keras.optimizers import Adam +from tensorflow.python.keras.optimizers import Adamax +from tensorflow.python.keras.optimizers import Nadam +from tensorflow.python.keras.optimizers import Optimizer +from tensorflow.python.keras.optimizers import RMSprop +from tensorflow.python.keras.optimizers import SGD # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.optimizers import deserialize -from tensorflow.python.keras._impl.keras.optimizers import serialize -from tensorflow.python.keras._impl.keras.optimizers import get +from tensorflow.python.keras.optimizers import deserialize +from tensorflow.python.keras.optimizers import serialize +from tensorflow.python.keras.optimizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index b96e7675527041..1f9e82b41bf09b 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,20 +18,20 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform -from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img -from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis -from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator -from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array -from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator -from tensorflow.python.keras._impl.keras.preprocessing.image import load_img -from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom +from tensorflow.python.keras.preprocessing.image import apply_transform +from tensorflow.python.keras.preprocessing.image import array_to_img +from tensorflow.python.keras.preprocessing.image import DirectoryIterator +from tensorflow.python.keras.preprocessing.image import flip_axis +from tensorflow.python.keras.preprocessing.image import ImageDataGenerator +from tensorflow.python.keras.preprocessing.image import img_to_array +from tensorflow.python.keras.preprocessing.image import Iterator +from tensorflow.python.keras.preprocessing.image import load_img +from tensorflow.python.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras.preprocessing.image import random_channel_shift +from tensorflow.python.keras.preprocessing.image import random_rotation +from tensorflow.python.keras.preprocessing.image import random_shear +from tensorflow.python.keras.preprocessing.image import random_shift +from tensorflow.python.keras.preprocessing.image import random_zoom del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py index 112f6af5e588bc..9a93b6fb57ff5a 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py @@ -18,9 +18,9 @@ from __future__ import division from __future__ import print_function -from 
tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table -from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences -from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams +from tensorflow.python.keras.preprocessing.sequence import make_sampling_table +from tensorflow.python.keras.preprocessing.sequence import pad_sequences +from tensorflow.python.keras.preprocessing.sequence import skipgrams del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py index 5bf1a2fb21dc27..86386a9b6762d1 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py @@ -18,9 +18,9 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot -from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence -from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer +from tensorflow.python.keras.preprocessing.text import one_hot +from tensorflow.python.keras.preprocessing.text import text_to_word_sequence +from tensorflow.python.keras.preprocessing.text import Tokenizer del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py index 3e707ccab577b5..d668e39c09ca28 100644 --- a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py @@ -19,19 +19,19 @@ from __future__ import print_function # Regularizer functions / callable classes. -from tensorflow.python.keras._impl.keras.regularizers import L1L2 -from tensorflow.python.keras._impl.keras.regularizers import Regularizer +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.keras.regularizers import Regularizer # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.regularizers import l1 -from tensorflow.python.keras._impl.keras.regularizers import l2 -from tensorflow.python.keras._impl.keras.regularizers import l1_l2 +from tensorflow.python.keras.regularizers import l1 +from tensorflow.python.keras.regularizers import l2 +from tensorflow.python.keras.regularizers import l1_l2 # Auxiliary utils. 
-from tensorflow.python.keras._impl.keras.regularizers import deserialize -from tensorflow.python.keras._impl.keras.regularizers import serialize -from tensorflow.python.keras._impl.keras.regularizers import get +from tensorflow.python.keras.regularizers import deserialize +from tensorflow.python.keras.regularizers import serialize +from tensorflow.python.keras.regularizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index a7c2179fe7ad43..47cd01b924fb43 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -18,21 +18,21 @@ from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.np_utils import normalize +from tensorflow.python.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py index a46f859273ea01..c4b7aa765c26ba 100644 --- a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py +++ b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py @@ -18,8 +18,8 @@ from __future__ import division from __future__ import 
print_function -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor +from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier +from tensorflow.python.keras.wrappers.scikit_learn import KerasRegressor del absolute_import del division diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py index 91929184a2e6f3..2ff4d41d75fe59 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py @@ -31,7 +31,7 @@ def _inner_product(x, y): - """Inner product between tensors x and y. + r"""Inner product between tensors x and y. The input tensors are assumed to be in ROW representation, that is, the method returns \\(x * y^T\\). @@ -131,10 +131,6 @@ def testGoodKernelApproximationAmortized(self): mapped_dim = 5000 stddev = 5.0 - # TODO(sibyl-vie3Poto): Reduce test's running time before moving to third_party. One - # possible way to speed the test up is to compute both the approximate and - # the exact kernel matrix directly using matrix operations instead of - # computing the values for each pair of points separately. points_shape = [1, input_dim] points = [ random_ops.random_uniform(shape=points_shape, maxval=1.0) diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 762a2f0b57e95e..102626925db560 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,5 +1,10 @@ # K-FAC: Kronecker-Factored Approximate Curvature +# WARNING: +# ==third_party/tensorflow/contrib/kfac is deprecated. This will be== +# ==removed on 15-07-2018. Please import third_party/tensorflow_kfac.== +# ==== + **K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an approximate second-order optimization method, in TensorFlow. 
When applied to feedforward and convolutional neural networks, K-FAC can converge `>3.5x` diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index e8e3353091df25..d6b1a61b716ab7 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -223,26 +223,26 @@ def minimize_loss_single_machine(loss, (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - with tf.device(device): - train_op = optimizer.minimize(loss, global_step=g_step) - def make_update_op(update_thunks): - update_op = [thunk() for thunk in update_thunks] - return tf.group(*update_op) + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([train_op, cov_update_op]): + with tf.control_dependencies([cov_update_op]): inverse_op = tf.cond( - tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + with tf.device(device): + train_op = optimizer.minimize(loss, global_step=g_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( - [g_step, loss, accuracy, inverse_op]) + [g_step, loss, accuracy, train_op]) - if (global_step_ + 1) % _INVERT_EVERY == 0: + if global_step_ % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) @@ -325,7 +325,7 @@ def distributed_grads_only_and_ops_chief_worker( All workers perform gradient computation. Chief worker applies gradient after averaging the gradients obtained from all the workers. All workers block - execution untill the update is applied. Chief worker runs covariance and + execution until the update is applied. Chief worker runs covariance and inverse update ops. Covariance and inverse matrices are placed on parameter servers in a round robin manner. For further details on synchronous distributed optimization check `tf.train.SyncReplicasOptimizer`. 
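The convnet and MLP example changes in this diff all apply the same reordering: rather than fetching a separate inverse op in the training loop, the covariance update, the periodic inverse update, and the weight update are chained through `tf.control_dependencies`, so fetching `train_op` alone drives all three. A minimal sketch of that ordering follows, assuming the TF 1.x graph API; `make_train_op` is an illustrative wrapper that does not appear in the diff, while `_INVERT_EVERY` mirrors the constant used in `convnet.py`.

```python
import tensorflow as tf

_INVERT_EVERY = 10  # how often the expensive inverse update should run


def make_train_op(optimizer, loss, global_step):
  """Chains cov update -> periodic inverse update -> weight update (sketch)."""
  cov_thunks, inv_thunks = optimizer.make_vars_and_create_op_thunks()

  def make_update_op(thunks):
    # Each thunk builds one update op when called.
    return tf.group(*[thunk() for thunk in thunks])

  cov_update_op = make_update_op(cov_thunks)
  with tf.control_dependencies([cov_update_op]):
    # Only rebuild the inverses every _INVERT_EVERY steps; otherwise no-op.
    inverse_op = tf.cond(
        tf.equal(tf.mod(global_step, _INVERT_EVERY), 0),
        lambda: make_update_op(inv_thunks), tf.no_op)
  with tf.control_dependencies([inverse_op]):
    # Running the returned train op now implicitly runs the covariance update
    # and, every _INVERT_EVERY steps, the inverse update as well.
    return optimizer.minimize(loss, global_step=global_step)
```

With this ordering the training loop only needs `sess.run([global_step, loss, accuracy, train_op])`, which matches the `sess.run` changes below.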
@@ -357,24 +357,25 @@ def distributed_grads_only_and_ops_chief_worker( task_id, num_worker_tasks, num_ps_tasks, layer_collection) (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - train_op = sync_optimizer.minimize(loss, global_step=global_step) tf.logging.info("Starting training.") hooks = [sync_optimizer.make_session_run_hook(is_chief)] def make_update_op(update_thunks): - update_op = [thunk() for thunk in update_thunks] - return tf.group(*update_op) + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) if is_chief: cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([train_op, cov_update_op]): - update_op = tf.cond( - tf.equal(tf.mod(global_step + 1, invert_every), 0), + with tf.control_dependencies([cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(global_step, invert_every), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = sync_optimizer.minimize(loss, global_step=global_step) else: - update_op = train_op + train_op = sync_optimizer.minimize(loss, global_step=global_step) with tf.train.MonitoredTrainingSession( master=master, @@ -384,7 +385,7 @@ def make_update_op(update_thunks): stop_grace_period_secs=0) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( - [global_step, loss, accuracy, update_op]) + [global_step, loss, accuracy, train_op]) tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) return accuracy_ @@ -577,25 +578,25 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - train_op = optimizer.minimize(loss, global_step=g_step) - def make_update_op(update_thunks): - update_op = [thunk() for thunk in update_thunks] - return tf.group(*update_op) + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([train_op, cov_update_op]): + with tf.control_dependencies([cov_update_op]): inverse_op = tf.cond( - tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=g_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( - [g_step, loss, accuracy, inverse_op]) + [g_step, loss, accuracy, train_op]) - if (global_step_ + 1) % _INVERT_EVERY == 0: + if global_step_ % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 87eed03888c894..ea2b252a05702d 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -105,18 +105,21 @@ def build_model(examples, labels, num_labels, layer_collection): return loss, accuracy -def minimize(loss, accuracy, layer_collection, session_config=None): +def minimize(loss, accuracy, layer_collection, num_towers, session_config=None): """Minimize 'loss' with KfacOptimizer. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. 
layer_collection: LayerCollection instance. Describes layers in model. + num_towers: int. Number of CPUs to split minibatch across. session_config: tf.ConfigProto. Configuration for tf.Session(). Returns: accuracy of classifier on final minibatch. """ + devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers)) + # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2 # every 10k iterations. tf.logging.info("Building KFAC Optimizer.") @@ -125,27 +128,38 @@ def minimize(loss, accuracy, layer_collection, session_config=None): learning_rate=tf.train.exponential_decay( 0.00002, global_step, 10000, 0.5, staircase=True), cov_ema_decay=0.95, - damping=0.0001, + damping=0.0005, layer_collection=layer_collection, - momentum=0.99) - train_op = optimizer.minimize(loss, global_step=global_step) + momentum=0.99, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices) + + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt + # once that gets moved over? Could still leave more advanced examples as they + # are (e.g. train_mnist_estimator in this file) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + # We update the inverses only every 100 iterations. + inverse_op = tf.cond( + tf.equal(tf.mod(global_step, 100), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=global_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): - # K-FAC has 3 primary ops, - # - train_op: Update the weights with the minibatch's gradient. - # - cov_update_op: Update statistics used for building K-FAC's - # preconditioner matrix. - # - inv_update_op: Update preconditioner matrix using statistics. - # - # The first 2 of these are cheap and should be done with each step. The - # latter is more expensive, and should be updated ~100 iterations. - global_step_, loss_, accuracy_, _, _ = sess.run( - [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) - - if global_step_ % 100 == 0: - sess.run(optimizer.inv_update_op) + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, train_op]) if global_step_ % 100 == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %f", @@ -180,7 +194,7 @@ def train_mnist(data_dir, num_epochs, use_fake_data=False): loss, accuracy = build_model(examples, labels, 10, layer_collection) # Fit model.
- minimize(loss, accuracy, layer_collection) + minimize(loss, accuracy, layer_collection, 1) def train_mnist_multitower(data_dir, @@ -238,7 +252,8 @@ def train_mnist_multitower(data_dir, "CPU": num_towers }) return minimize( - loss, accuracy, layer_collection, session_config=session_config) + loss, accuracy, layer_collection, num_towers, + session_config=session_config) def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): @@ -298,13 +313,26 @@ def model_fn(features, labels, mode, params): layer_collection=layer_collection, momentum=0.99) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + def make_batch_executed_op(update_thunks, batch_size=1): + return tf.group(*tf.contrib.kfac.utils.batch_execute( + global_step, update_thunks, batch_size=batch_size)) + # Run cov_update_op every step. Run 1 inv_update_ops per step. - cov_update_op = optimizer.cov_update_op - inv_update_op = tf.group( - tf.contrib.kfac.utils.batch_execute( - global_step, optimizer.inv_update_thunks, batch_size=1)) - with tf.control_dependencies([cov_update_op, inv_update_op]): - train_op = optimizer.minimize(loss, global_step=global_step) + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + # But make sure to execute all the inverse ops on the first step + inverse_op = tf.cond(tf.equal(global_step, 0), + lambda: make_update_op(inv_update_thunks), + lambda: make_batch_executed_op(inv_update_thunks)) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=global_step) # Print metrics every 5 sec. hooks = [ diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 6de775cc79953b..adecda71666ee7 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -157,7 +157,7 @@ def testTrainMnistDistributed(self): num_ps_tasks=0, master="", data_dir=None, - num_epochs=1, + num_epochs=2, op_strategy="chief_worker", use_fake_data=True) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 2477d2bfc12c2d..6e4a8d71baa85d 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -58,6 +58,7 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:linear_operator", "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -96,6 +97,7 @@ py_test( srcs = ["optimizer_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", "//tensorflow/contrib/kfac/python/ops:layer_collection", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index f22dbcf2156629..0e65d419a31838 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -81,7 +81,7 @@ def testEstimatorInitManualRegistration(self): damping=0.2, 
layer_collection=self.layer_collection ) - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() # Check that we throw an error if we don't include registered variables, # i.e. self.weights @@ -91,7 +91,7 @@ def testEstimatorInitManualRegistration(self): cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection) - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): @@ -101,7 +101,7 @@ def testVariableWrongNumberOfUses(self, mock_uses): cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection) - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): @@ -111,7 +111,7 @@ def testInvalidEstimationMode(self): damping=0.2, layer_collection=self.layer_collection, estimation_mode="not_a_real_mode") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testGradientsModeBuild(self): with self._graph.as_default(): @@ -121,7 +121,7 @@ def testGradientsModeBuild(self): damping=0.2, layer_collection=self.layer_collection, estimation_mode="gradients") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testEmpiricalModeBuild(self): with self._graph.as_default(): @@ -131,7 +131,7 @@ def testEmpiricalModeBuild(self): damping=0.2, layer_collection=self.layer_collection, estimation_mode="empirical") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testCurvaturePropModeBuild(self): with self._graph.as_default(): @@ -141,7 +141,7 @@ def testCurvaturePropModeBuild(self): damping=0.2, layer_collection=self.layer_collection, estimation_mode="curvature_prop") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testExactModeBuild(self): with self._graph.as_default(): @@ -151,7 +151,7 @@ def testExactModeBuild(self): damping=0.2, layer_collection=self.layer_collection, estimation_mode="exact") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" @@ -215,8 +215,11 @@ def test_round_robin_placement(self): inv_devices=["/cpu:{}".format(i) for i in range(2)]) # Construct an op that executes one covariance update per step. 
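In the estimator tests above, every `make_ops_and_vars()` call becomes `make_vars_and_create_op_thunks()`, and the round-robin placement test that follows materializes the returned thunks into ops before checking their devices. One reason thunks are useful for placement strategies is that an op inherits whatever device scope is active when the thunk is called, not when it is defined. A small, self-contained sketch of that mechanism (the `tf.add` thunks are stand-ins, not K-FAC code):

```python
import tensorflow as tf

# Two deferred "update" ops; their device is decided at call time.
thunks = [lambda: tf.add(1.0, 2.0, name="update_0"),
          lambda: tf.add(3.0, 4.0, name="update_1")]

ops = []
for i, thunk in enumerate(thunks):
  with tf.device("/cpu:%d" % i):  # round-robin over two (logical) CPU devices
    ops.append(thunk())

print(ops[0].device)  # /device:CPU:0
print(ops[1].device)  # /device:CPU:1
```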
- (cov_update_ops, _, inv_update_ops, _, _, - _) = fisher_estimator.make_ops_and_vars(scope="test") + (cov_update_thunks, + inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks( + scope="test") + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 6eda6c31e34370..86ec7a095afdf4 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -21,7 +21,9 @@ import numpy as np from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import linear_operator as lo from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -34,6 +36,19 @@ from tensorflow.python.platform import test +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + +# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our +# inverse is something other than the identity" are actually broken. They never +# run the covariance update ops and so the inverse actually is the identity +# (possible plus the damping term, which would still make it a multiple of the +# identity). + + def _make_psd(dim): """Constructs a PSD matrix of the given dimension.""" mat = np.ones((dim, dim), dtype=np.float32) @@ -46,8 +61,9 @@ class UtilsTest(test.TestCase): def testComputePiTracenorm(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) + diag = ops.convert_to_tensor([1., 2., 0., 1.]) + left_factor = lo.LinearOperatorDiag(diag) + right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2])) # pi is the sqrt of the left trace norm divided by the right trace norm pi = fb.compute_pi_tracenorm(left_factor, right_factor) @@ -245,7 +261,6 @@ def testMultiplyInverseAgainstExplicit(self): full = sess.run(block.full_fisher_block()) explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) - self.assertAllClose(output_flat, explicit) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 432b67e5690003..fad47cd02f372e 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -35,6 +35,13 @@ from tensorflow.python.platform import test +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. 
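The comment just above refers to the default-flipping change in `fisher_factors.py` further down in this diff: `INIT_COVARIANCES_AT_ZERO`, `ZERO_DEBIAS`, and the new `INIT_INVERSES_AT_ZERO` now default to `True`, so the tests pin the old values via the `ff.set_global_constants(...)` call that follows. A rough, hedged illustration of what those flags control, assuming the `tf.contrib.kfac` module from this branch is importable:

```python
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_factors as ff

# Restore the old defaults, as the updated tests do.
ff.set_global_constants(init_covariances_at_zero=False,
                        zero_debias=False,
                        init_inverses_at_zero=False)

with tf.Session() as sess:
  # With the flags above set to False, both initializers produce the identity;
  # with the new defaults (True) they would produce all-zero matrices instead.
  print(sess.run(ff.covariance_initializer([3, 3], tf.float32)))
  print(sess.run(ff.inverse_initializer([3, 3], tf.float32)))
```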
+ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + + def make_damping_func(damping): return fb._package_func(lambda: damping, damping) @@ -70,35 +77,44 @@ def make_inverse_update_ops(self): def get_cov(self): return NotImplementedError - def left_multiply(self, x, damping): + def instantiate_inv_variables(self): return NotImplementedError - def right_multiply(self, x, damping): - return NotImplementedError + def _num_towers(self): + raise NotImplementedError - def left_multiply_matpower(self, x, exp, damping): - return NotImplementedError + def _get_data_device(self): + raise NotImplementedError - def right_multiply_matpower(self, x, exp, damping): - return NotImplementedError + def register_matpower(self, exp, damping_func): + raise NotImplementedError - def instantiate_inv_variables(self): - return NotImplementedError + def register_cholesky(self, damping_func): + raise NotImplementedError - def _num_towers(self): + def register_cholesky_inverse(self, damping_func): raise NotImplementedError - def _get_data_device(self): + def get_matpower(self, exp, damping_func): + raise NotImplementedError + + def get_cholesky(self, damping_func): + raise NotImplementedError + + def get_cholesky_inverse(self, damping_func): + raise NotImplementedError + + def get_cov_as_linear_operator(self): raise NotImplementedError -class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): - """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. +class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor): + """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor. """ def __init__(self, shape): self._shape = shape - super(InverseProvidingFactorTestingDummy, self).__init__() + super(DenseSquareMatrixFactorTestingDummy, self).__init__() @property def _var_scope(self): @@ -230,13 +246,13 @@ def testMakeInverseUpdateOps(self): self.assertEqual(0, len(factor.make_inverse_update_ops())) -class InverseProvidingFactorTest(test.TestCase): +class DenseSquareMatrixFactorTest(test.TestCase): def testRegisterDampedInverse(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) shape = [2, 2] - factor = InverseProvidingFactorTestingDummy(shape) + factor = DenseSquareMatrixFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' damping_funcs = [make_damping_func(0.1), @@ -248,22 +264,25 @@ def testRegisterDampedInverse(self): factor.instantiate_inv_variables() - inv = factor.get_inverse(damping_funcs[0]) - self.assertEqual(inv, factor.get_inverse(damping_funcs[1])) - self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2])) - self.assertEqual(factor.get_inverse(damping_funcs[2]), - factor.get_inverse(damping_funcs[3])) + inv = factor.get_inverse(damping_funcs[0]).to_dense() + self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense()) + self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense()) + self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(), + factor.get_inverse(damping_funcs[3]).to_dense()) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]), - set(factor_vars)) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + self.assertEqual(set([inv, + factor.get_inverse(damping_funcs[2]).to_dense()]), + set(factor_tensors)) self.assertEqual(shape, inv.get_shape()) def testRegisterMatpower(self): with 
tf_ops.Graph().as_default(): random_seed.set_random_seed(200) shape = [3, 3] - factor = InverseProvidingFactorTestingDummy(shape) + factor = DenseSquareMatrixFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' # TODO(b/74201126): Change to using the same func for both once @@ -278,10 +297,13 @@ def testRegisterMatpower(self): factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - matpower1 = factor.get_matpower(-0.5, damping_func_1) - matpower2 = factor.get_matpower(2, damping_func_2) - self.assertEqual(set([matpower1, matpower2]), set(factor_vars)) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense() + matpower2 = factor.get_matpower(2, damping_func_2).to_dense() + + self.assertEqual(set([matpower1, matpower2]), set(factor_tensors)) self.assertEqual(shape, matpower1.get_shape()) self.assertEqual(shape, matpower2.get_shape()) @@ -297,7 +319,7 @@ def testMakeInverseUpdateOpsManyInversesEigenDecomp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[1., 2.], [3., 4.]]) - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) damping_funcs = [] @@ -316,7 +338,8 @@ def testMakeInverseUpdateOpsManyInversesEigenDecomp(self): sess.run(ops) for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): # The inverse op will assign the damped inverse of cov to the inv var. - new_invs.append(sess.run(factor.get_inverse(damping_funcs[i]))) + new_invs.append( + sess.run(factor.get_inverse(damping_funcs[i]).to_dense())) # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): @@ -328,7 +351,7 @@ def testMakeInverseUpdateOpsMatPowerEigenDecomp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[6., 2.], [2., 4.]]) - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power damping = 0.5 @@ -341,7 +364,7 @@ def testMakeInverseUpdateOpsMatPowerEigenDecomp(self): sess.run(tf_variables.global_variables_initializer()) sess.run(ops[0]) - matpower = sess.run(factor.get_matpower(exp, damping_func)) + matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense()) matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) self.assertAllClose(matpower, matpower_np) @@ -349,7 +372,7 @@ def testMakeInverseUpdateOpsNoEigenDecomp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) damping_func = make_damping_func(0) @@ -361,12 +384,12 @@ def testMakeInverseUpdateOpsNoEigenDecomp(self): sess.run(tf_variables.global_variables_initializer()) # The inverse op will assign the damped inverse of cov to the inv var. 
- old_inv = sess.run(factor.get_inverse(damping_func)) + old_inv = sess.run(factor.get_inverse(damping_func).to_dense()) self.assertAllClose( sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) sess.run(ops) - new_inv = sess.run(factor.get_inverse(damping_func)) + new_inv = sess.run(factor.get_inverse(damping_func).to_dense()) self.assertAllClose(new_inv, np.linalg.inv(cov)) @@ -411,7 +434,7 @@ def testNaiveDiagonalFactorInit(self): tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) factor.instantiate_cov_variables() - self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list()) + self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) def testNaiveDiagonalFactorInitFloat64(self): with tf_ops.Graph().as_default(): @@ -420,7 +443,7 @@ def testNaiveDiagonalFactorInitFloat64(self): tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) factor.instantiate_cov_variables() - cov = factor.get_cov_var() + cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 1], cov.get_shape().as_list()) @@ -444,7 +467,7 @@ def testInitialization(self): vocab_size = 5 factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) factor.instantiate_cov_variables() - cov = factor.get_cov_var() + cov = factor.get_cov() self.assertEqual(cov.shape.as_list(), [vocab_size]) def testCovarianceUpdateOp(self): @@ -502,7 +525,7 @@ def testInit(self): self.kernel_height * self.kernel_width * self.in_channels, self.out_channels ], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(): @@ -564,7 +587,7 @@ def testHasBias(self): self.kernel_height * self.kernel_width * self.in_channels + 1, self.out_channels ], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure update op doesn't crash. cov_update_op = factor.make_covariance_update_op(0.0) @@ -654,13 +677,13 @@ def test3DConvolution(self): # Ensure shape of covariance matches input size of filter. input_size = in_channels * (width**3) self.assertEqual([input_size, input_size], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure cov_update_op doesn't crash. with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank-8, as the filter will be applied at each corner of # the 4-D cube. @@ -685,13 +708,13 @@ def testPointwiseConv2d(self): # Ensure shape of covariance matches input size of filter. self.assertEqual([in_channels, in_channels], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure cov_update_op doesn't crash. with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank-9, as the filter will be applied at each location. self.assertMatrixRank(9, cov) @@ -716,7 +739,7 @@ def testStrides(self): with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be the sum of 3 * 2 = 6 outer products. 
self.assertMatrixRank(6, cov) @@ -742,7 +765,7 @@ def testDilationRate(self): with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank = in_channels, as only the center of the filter # receives non-zero input for each input channel. diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py index 9325aa1b7325fa..560a9b0b426ecc 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py @@ -20,6 +20,7 @@ import numpy as np +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import optimizer from tensorflow.python.framework import ops @@ -32,6 +33,13 @@ from tensorflow.python.platform import test +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + + def dummy_layer_collection(): lcoll = lc.LayerCollection() dummy = array_ops.constant([1., 2.]) @@ -186,6 +194,11 @@ def testApplyGradients(self): layer_collection, momentum=0.5, momentum_type='regular') + (cov_update_thunks, + inv_update_thunks) = opt.make_vars_and_create_op_thunks() + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) + grads_and_vars = opt.compute_gradients(output, [weights, bias]) all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] @@ -193,6 +206,8 @@ def testApplyGradients(self): sess.run(tf_variables.global_variables_initializer()) old_vars = sess.run(all_vars) + sess.run(cov_update_ops) + sess.run(inv_update_ops) sess.run(op) new_vars = sess.run(all_vars) diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index cb0917bb851cff..3c01eb65e7a687 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -35,6 +35,7 @@ py_library( srcs = ["fisher_factors.py"], srcs_version = "PY2AND3", deps = [ + ":linear_operator", ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -63,6 +64,19 @@ py_library( ], ) +py_library( + name = "linear_operator", + srcs = ["linear_operator.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/ops/linalg", + "@six_archive//:six", + ], +) + py_library( name = "loss_functions", srcs = ["loss_functions.py"], diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index d11c9c82881074..854f885c26f2b4 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -57,8 +57,8 @@ def make_fisher_estimator(placement_strategy=None, **kwargs): if placement_strategy in [None, "round_robin"]: return FisherEstimatorRoundRobin(**kwargs) else: - raise ValueError("Unimplemented vars and ops placement strategy : %s", - placement_strategy) + raise ValueError("Unimplemented vars and ops " + "placement strategy : {}".format(placement_strategy)) # pylint: 
enable=abstract-class-instantiated @@ -81,7 +81,9 @@ def __init__(self, exps=(-1,), estimation_mode="gradients", colocate_gradients_with_ops=True, - name="FisherEstimator"): + name="FisherEstimator", + compute_cholesky=False, + compute_cholesky_inverse=False): """Create a FisherEstimator object. Args: @@ -124,6 +126,12 @@ def __init__(self, name: A string. A name given to this estimator, which is added to the variable scope when constructing variables and ops. (Default: "FisherEstimator") + compute_cholesky: Bool. Whether or not the FisherEstimator will be + able to multiply vectors by the Cholesky factor. + (Default: False) + compute_cholesky_inverse: Bool. Whether or not the FisherEstimator + will be able to multiply vectors by the Cholesky factor inverse. + (Default: False) Raises: ValueError: If no losses have been registered with layer_collection. """ @@ -142,6 +150,8 @@ def __init__(self, self._made_vars = False self._exps = exps + self._compute_cholesky = compute_cholesky + self._compute_cholesky_inverse = compute_cholesky_inverse self._name = name @@ -170,44 +180,6 @@ def factors(self): def name(self): return self._name - @abc.abstractmethod - def make_ops_and_vars(self, scope=None): - """Make ops and vars with a specific placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. For example in case of - round robin placement a new device is chosen for each factor by cycling - through list of devices in the cov_devices argument. If cov_devices is None - then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all ops will execute, inside of a variable scope of the given - name. (Default: None) - - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. - """ - pass - @abc.abstractmethod def make_vars_and_create_op_thunks(self, scope=None): """Make vars and create op thunks with a specific placement strategy. @@ -300,9 +272,54 @@ def multiply_matpower(self, exp, vecs_and_vars): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ + assert exp in self._exps + fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) return self._apply_transformation(vecs_and_vars, fcn) + def multiply_cholesky(self, vecs_and_vars, transpose=False): + """Multiplies the vecs by the corresponding Cholesky factors. 
+ + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factors are transposed before + multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky + + fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): + """Mults the vecs by the inverses of the corresponding Cholesky factors. + + Note: if you are using Cholesky inverse multiplication to sample from + a matrix-variate Gaussian you will want to multiply by the transpose. + Let L be the Cholesky factor of F and observe that + + L^-T * L^-1 = (L * L^T)^-1 = F^-1 . + + Thus we want to multiply by L^-T in order to sample from Gaussian with + covariance F^-1. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factor inverses are transposed + before multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky_inverse + + fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + def _instantiate_factors(self): """Instantiates FisherFactors' variables. @@ -333,9 +350,13 @@ def made_vars(self): return self._made_vars def _register_matrix_functions(self): - for exp in self._exps: - for block in self.blocks: + for block in self.blocks: + for exp in self._exps: block.register_matpower(exp) + if self._compute_cholesky: + block.register_cholesky() + if self._compute_cholesky_inverse: + block.register_cholesky_inverse() def _finalize_layer_collection(self): self._layers.create_subgraph() diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py index 33c969650615bf..9c9fef471f8033 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py +++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py @@ -25,6 +25,7 @@ _allowed_symbols = [ 'FisherEstimator', + 'make_fisher_estimator', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 00b3673a742e92..3a5c8eb5f9630f 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -83,34 +83,22 @@ def normalize_damping(damping, num_replications): def compute_pi_tracenorm(left_cov, right_cov): - """Computes the scalar constant pi for Tikhonov regularization/damping. + r"""Computes the scalar constant pi for Tikhonov regularization/damping. $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. Args: - left_cov: The left Kronecker factor "covariance". - right_cov: The right Kronecker factor "covariance". + left_cov: A LinearOperator object. The left Kronecker factor "covariance". + right_cov: A LinearOperator object. The right Kronecker factor "covariance". Returns: The computed scalar constant pi for these Kronecker Factors (as a Tensor). """ - - def _trace(cov): - if len(cov.shape) == 1: - # Diagonal matrix. - return math_ops.reduce_sum(cov) - elif len(cov.shape) == 2: - # Full matrix. 
- return math_ops.trace(cov) - else: - raise ValueError( - "What's the trace of a Tensor of rank %d?" % len(cov.shape)) - # Instead of dividing by the dim of the norm, we multiply by the dim of the # other norm. This works out the same in the ratio. - left_norm = _trace(left_cov) * right_cov.shape.as_list()[0] - right_norm = _trace(right_cov) * left_cov.shape.as_list()[0] + left_norm = left_cov.trace() * int(right_cov.domain_dimension) + right_norm = right_cov.trace() * int(left_cov.domain_dimension) return math_ops.sqrt(left_norm / right_norm) @@ -188,6 +176,16 @@ def register_matpower(self, exp): """ pass + @abc.abstractmethod + def register_cholesky(self): + """Registers a Cholesky factor to be computed by the block.""" + pass + + @abc.abstractmethod + def register_cholesky_inverse(self): + """Registers an inverse Cholesky factor to be computed by the block.""" + pass + def register_inverse(self): """Registers a matrix inverse to be computed by the block.""" self.register_matpower(-1) @@ -228,6 +226,33 @@ def multiply(self, vector): """ return self.multiply_matpower(vector, 1) + @abc.abstractmethod + def multiply_cholesky(self, vector, transpose=False): + """Multiplies the vector by the (damped) Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor is transposed before + multiplying the vector. (Default: False) + + Returns: + The vector left-multiplied by the (damped) Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod + def multiply_cholesky_inverse(self, vector, transpose=False): + """Multiplies vector by the (damped) inverse Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor inverse is transposed + before multiplying the vector. (Default: False) + Returns: + Vector left-multiplied by (damped) inverse Cholesky-factor of the block. + """ + pass + @abc.abstractmethod def tensors_to_compute_grads(self): """Returns the Tensor(s) with respect to which this FisherBlock needs grads. 
@@ -275,15 +300,32 @@ def instantiate_factors(self, grads_list, damping): def register_matpower(self, exp): self._factor.register_matpower(exp, self._damping_func) - def multiply_matpower(self, vector, exp): + def register_cholesky(self): + self._factor.register_cholesky(self._damping_func) + + def register_cholesky_inverse(self): + self._factor.register_cholesky_inverse(self._damping_func) + + def _multiply_matrix(self, matrix, vector, transpose=False): vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply_matpower( - vector_flat, exp, self._damping_func) + out_flat = matrix.matmul(vector_flat, adjoint=transpose) return utils.column_to_tensors(vector, out_flat) + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + def full_fisher_block(self): """Explicitly constructs the full Fisher block.""" - return self._factor.get_cov() + return self._factor.get_cov_as_linear_operator().to_dense() def tensors_to_compute_grads(self): return self._params @@ -305,7 +347,47 @@ def _batch_size(self): return math_ops.reduce_sum(self._batch_sizes) -class NaiveDiagonalFB(FisherBlock): +@six.add_metaclass(abc.ABCMeta) +class DiagonalFB(FisherBlock): + """A base class for FisherBlocks that use diagonal approximations.""" + + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass + + def register_cholesky(self): + # Not needed for this. Cholesky's are computed on demand in the + # diagonal case + pass + + def register_cholesky_inverse(self): + # Not needed for this. Cholesky inverses's are computed on demand in the + # diagonal case + pass + + def _multiply_matrix(self, matrix, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = matrix.matmul(vector_flat) + return utils.column_to_tensors(vector, out_flat) + + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def full_fisher_block(self): + return self._factor.get_cov_as_linear_operator().to_dense() + + +class NaiveDiagonalFB(DiagonalFB): """FisherBlock using a diagonal matrix approximation. This type of approximation is generically applicable but quite primitive. @@ -333,20 +415,6 @@ def instantiate_factors(self, grads_list, damping): self._factor = self._layer_collection.make_or_get_factor( fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - def register_matpower(self, exp): - # Not needed for this. 
Matrix powers are computed on demand in the - # diagonal case - pass - - def multiply_matpower(self, vector, exp): - vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply_matpower( - vector_flat, exp, self._damping_func) - return utils.column_to_tensors(vector, out_flat) - - def full_fisher_block(self): - return self._factor.get_cov() - def tensors_to_compute_grads(self): return self._params @@ -452,7 +520,7 @@ def _outputs(self): return self.__outputs -class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock): +class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for fully-connected (dense) layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a fully @@ -497,32 +565,8 @@ def instantiate_factors(self, grads_list, damping): self._damping_func = _package_func(lambda: damping, (damping,)) - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - def multiply_matpower(self, vector, exp): - """Multiplies the vector by the (damped) matrix-power of the block. - - Args: - vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape - [input_size, output_size] corresponding to layer's weights. If not, a - 2-tuple of the former and a Tensor of shape [output_size] corresponding - to the layer's bias. - exp: A scalar representing the power to raise the block before multiplying - it by the vector. - - Returns: - The vector left-multiplied by the (damped) matrix-power of the block. - """ - reshaped_vec = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_matpower( - reshaped_vec, exp, self._damping_func) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - -class ConvDiagonalFB(InputOutputMultiTower, FisherBlock): +class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional @@ -621,17 +665,6 @@ def damping_func(): self._num_locations) self._damping_func = _package_func(damping_func, damping_id) - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - - def multiply_matpower(self, vector, exp): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_matpower( - reshaped_vect, exp, self._damping_func) - return utils.mat2d_to_layer_params(vector, reshaped_out) - class KroneckerProductFB(FisherBlock): """A base class for blocks with separate input and output Kronecker factors. @@ -640,9 +673,6 @@ class KroneckerProductFB(FisherBlock): output factors. 
""" - def __init__(self, layer_collection): - super(KroneckerProductFB, self).__init__(layer_collection) - def _setup_damping(self, damping, normalization=None): """Makes functions that compute the damping values for both factors.""" def compute_damping(): @@ -651,9 +681,10 @@ def compute_damping(): else: maybe_normalized_damping = damping - return compute_pi_adjusted_damping(self._input_factor.get_cov(), - self._output_factor.get_cov(), - maybe_normalized_damping**0.5) + return compute_pi_adjusted_damping( + self._input_factor.get_cov_as_linear_operator(), + self._output_factor.get_cov_as_linear_operator(), + maybe_normalized_damping**0.5) if normalization is not None: damping_id = ("compute_pi_adjusted_damping", @@ -675,6 +706,14 @@ def register_matpower(self, exp): self._input_factor.register_matpower(exp, self._input_damping_func) self._output_factor.register_matpower(exp, self._output_damping_func) + def register_cholesky(self): + self._input_factor.register_cholesky(self._input_damping_func) + self._output_factor.register_cholesky(self._output_damping_func) + + def register_cholesky_inverse(self): + self._input_factor.register_cholesky_inverse(self._input_damping_func) + self._output_factor.register_cholesky_inverse(self._output_damping_func) + @property def _renorm_coeff(self): """Kronecker factor multiplier coefficient. @@ -687,17 +726,47 @@ def _renorm_coeff(self): """ return 1.0 - def multiply_matpower(self, vector, exp): + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = self._output_factor.right_multiply_matpower( - reshaped_vector, exp, self._output_damping_func) - reshaped_out = self._input_factor.left_multiply_matpower( - reshaped_out, exp, self._input_damping_func) - if self._renorm_coeff != 1.0: - renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype) - reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype) + reshaped_out = right_factor.matmul_right(reshaped_vector, + adjoint=transpose_right) + reshaped_out = left_factor.matmul(reshaped_out, + adjoint=transpose_left) + if extra_scale != 1.0: + reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) + def multiply_matpower(self, vector, exp): + left_factor = self._input_factor.get_matpower( + exp, self._input_damping_func) + right_factor = self._output_factor.get_matpower( + exp, self._output_damping_func) + extra_scale = float(self._renorm_coeff)**exp + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale) + + def multiply_cholesky(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky(self._input_damping_func) + right_factor = self._output_factor.get_cholesky(self._output_damping_func) + extra_scale = float(self._renorm_coeff)**0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky_inverse( + self._input_damping_func) + right_factor = self._output_factor.get_cholesky_inverse( + self._output_damping_func) + extra_scale = float(self._renorm_coeff)**-0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + 
transpose_left=transpose, + transpose_right=not transpose) + def full_fisher_block(self): """Explicitly constructs the full Fisher block. @@ -706,8 +775,8 @@ def full_fisher_block(self): Returns: The full Fisher block. """ - left_factor = self._input_factor.get_cov() - right_factor = self._output_factor.get_cov() + left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() + right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() return self._renorm_coeff * utils.kronecker_product(left_factor, right_factor) @@ -796,7 +865,7 @@ def instantiate_factors(self, grads_list, damping): class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): - """FisherBlock for convolutional layers using the basic KFC approx. + r"""FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional layer. @@ -945,10 +1014,10 @@ def __init__(self, self._filter_shape = (filter_height, filter_width, in_channels, in_channels * channel_multiplier) - def multiply_matpower(self, vector, exp): + def _multiply_matrix(self, matrix, vector): conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower( - conv2d_vector, exp) + conv2d_result = super( + DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) @@ -1016,10 +1085,14 @@ def __init__(self, self._filter_shape = (filter_height, filter_width, in_channels, in_channels * channel_multiplier) - def multiply_matpower(self, vector, exp): + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower( - conv2d_vector, exp) + conv2d_result = super( + DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( + left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, + transpose_left=transpose_left, transpose_right=transpose_right) return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) @@ -1233,6 +1306,8 @@ def _process_data(self, grads_list): else: raise ValueError("Global config variable TOWER_STRATEGY must be one of " "'concat' or 'separate'.") + else: + inputs = tuple(inputs) # Now we perform the analogous processing for grads_list if isinstance(grads_list[0][0], (list, tuple)): @@ -1275,6 +1350,8 @@ def _process_data(self, grads_list): else: raise ValueError("Global config variable TOWER_STRATEGY must be one of " "'concat' or 'separate'.") + else: + grads_list = tuple(tuple(grads) for grads in grads_list) if self._num_uses is None: raise ValueError("You must supply a value for the num_uses argument if " @@ -1664,3 +1741,12 @@ def gamma(x): return utils.mat2d_to_layer_params(vector, Z) # pylint: enable=invalid-name + + def multiply_cholesky(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + + def multiply_cholesky_inverse(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 7988a3b92bf013..b43232dfafaa6d 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -24,6 +24,7 
@@ import numpy as np import six +from tensorflow.contrib.kfac.python.ops import linear_operator as lo from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops @@ -42,10 +43,14 @@ # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). -INIT_COVARIANCES_AT_ZERO = False +INIT_COVARIANCES_AT_ZERO = True # Whether to zero-debias the moving averages. -ZERO_DEBIAS = False +ZERO_DEBIAS = True + +# Whether to initialize inverse (and other such matrices computed from the cov +# matrices) to the zero matrix (or the identity matrix). +INIT_INVERSES_AT_ZERO = True # When the number of inverses requested from a FisherFactor exceeds this value, # the inverses are computed using an eigenvalue decomposition. @@ -82,6 +87,7 @@ def set_global_constants(init_covariances_at_zero=None, zero_debias=None, + init_inverses_at_zero=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None, max_num_outer_products_per_cov_row=None, @@ -92,6 +98,7 @@ def set_global_constants(init_covariances_at_zero=None, """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS + global INIT_INVERSES_AT_ZERO global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW @@ -104,6 +111,8 @@ def set_global_constants(init_covariances_at_zero=None, INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero if zero_debias is not None: ZERO_DEBIAS = zero_debias + if init_inverses_at_zero is not None: + INIT_INVERSES_AT_ZERO = init_inverses_at_zero if eigenvalue_decomposition_threshold is not None: EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: @@ -121,19 +130,21 @@ def set_global_constants(init_covariances_at_zero=None, def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - return array_ops.diag(array_ops.ones(shape[0], dtype)) + if INIT_INVERSES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument if INIT_COVARIANCES_AT_ZERO: - return array_ops.diag(array_ops.zeros(shape[0], dtype)) - return array_ops.diag(array_ops.ones(shape[0], dtype)) + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) -def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: disable=unused-argument +def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument if INIT_COVARIANCES_AT_ZERO: - return array_ops.zeros(shape, dtype) - return array_ops.ones(shape, dtype) + return array_ops.zeros(shape, dtype=dtype) + return array_ops.ones(shape, dtype=dtype) @contextlib.contextmanager @@ -399,7 +410,7 @@ def _compute_new_cov(self, source, tower): the cov update. Returns: - Tensor of same shape as self.get_cov_var(). + Tensor of same shape as self.get_cov(). """ pass @@ -448,78 +459,43 @@ def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" pass - @abc.abstractmethod def get_cov(self): - """Get full covariance matrix. - - Returns: - Tensor of shape [n, n]. Represents all parameter-parameter correlations - captured by this FisherFactor. 
- """ - pass - - def get_cov_var(self): - """Get variable backing this FisherFactor. - - May or may not be the same as self.get_cov() - - Returns: - Variable of shape self._cov_shape. - """ return self._cov @abc.abstractmethod - def left_multiply_matpower(self, x, exp, damping_func): - """Left multiplies 'x' by matrix power of this factor (w/ damping applied). - - This calculation is essentially: - (C + damping * I)**exp * x - where * is matrix-multiplication, ** is matrix power, I is the identity - matrix, and C is the matrix represented by this factor. - - x can represent either a matrix or a vector. For some factors, 'x' might - represent a vector but actually be stored as a 2D matrix for convenience. - - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - exp: float. The matrix exponent to use. - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). + def get_cov_as_linear_operator(self): + pass - Returns: - Tensor of same shape as 'x' representing the result of the multiplication. - """ + @abc.abstractmethod + def register_matpower(self, exp, damping_func): pass @abc.abstractmethod - def right_multiply_matpower(self, x, exp, damping_func): - """Right multiplies 'x' by matrix power of this factor (w/ damping applied). + def register_cholesky(self, damping_func): + pass - This calculation is essentially: - x * (C + damping * I)**exp - where * is matrix-multiplication, ** is matrix power, I is the identity - matrix, and C is the matrix represented by this factor. + @abc.abstractmethod + def register_cholesky_inverse(self, damping_func): + pass - Unlike left_multiply_matpower, x will always be a matrix. + @abc.abstractmethod + def get_matpower(self, exp, damping_func): + pass - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - exp: float. The matrix exponent to use. - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). + @abc.abstractmethod + def get_cholesky(self, damping_func): + pass - Returns: - Tensor of same shape as 'x' representing the result of the multiplication. - """ + @abc.abstractmethod + def get_cholesky_inverse(self, damping_func): pass -class InverseProvidingFactor(FisherFactor): - """Base class for FisherFactors that maintain inverses explicitly. +class DenseSquareMatrixFactor(FisherFactor): + """Base class for FisherFactors that are stored as dense square matrices. - This class explicitly calculates and stores inverses of covariance matrices - provided by the underlying FisherFactor implementation. It is assumed that - vectors can be represented as 2-D matrices. + This class explicitly calculates and stores inverses of their `cov` matrices, + which must be square dense matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. 
@@ -538,7 +514,19 @@ def __init__(self): self._eigendecomp = None self._damping_funcs_by_id = {} # {hashable: lambda} - super(InverseProvidingFactor, self).__init__() + self._cholesky_registrations = set() # { hashable } + self._cholesky_inverse_registrations = set() # { hashable } + + self._cholesky_by_damping = {} # { hashable: variable } + self._cholesky_inverse_by_damping = {} # { hashable: variable } + + super(DenseSquareMatrixFactor, self).__init__() + + def get_cov_as_linear_operator(self): + assert self.get_cov().shape.ndims == 2 + return lo.LinearOperatorFullMatrix(self.get_cov(), + is_self_adjoint=True, + is_square=True) def _register_damping(self, damping_func): damping_id = graph_func_to_id(damping_func) @@ -563,8 +551,6 @@ def register_matpower(self, exp, damping_func): be the damping value used. i.e. damping = damping_func(). """ if exp == 1.0: - # We don't register these. The user shouldn't even be calling this - # function with exp = 1.0. return damping_id = self._register_damping(damping_func) @@ -572,6 +558,38 @@ def register_matpower(self, exp, damping_func): if (exp, damping_id) not in self._matpower_registrations: self._matpower_registrations.add((exp, damping_id)) + def register_cholesky(self, damping_func): + """Registers a Cholesky factor to be maintained and served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_registrations: + self._cholesky_registrations.add(damping_id) + + def register_cholesky_inverse(self, damping_func): + """Registers an inverse Cholesky factor to be maintained/served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky_inverse. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). 
+ """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_inverse_registrations: + self._cholesky_inverse_registrations.add(damping_id) + def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" @@ -589,6 +607,32 @@ def instantiate_inv_variables(self): assert (exp, damping_id) not in self._matpower_by_exp_and_damping self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower + for damping_id in self._cholesky_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + chol = variable_scope.get_variable( + "cholesky_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_by_damping + self._cholesky_by_damping[damping_id] = chol + + for damping_id in self._cholesky_inverse_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + cholinv = variable_scope.get_variable( + "cholesky_inverse_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_inverse_by_damping + self._cholesky_inverse_by_damping[damping_id] = cholinv + def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] @@ -606,7 +650,8 @@ def make_inverse_update_ops(self): # We precompute these so we don't need to evaluate them multiple times (for # each matrix power that uses them) - damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]() + damping_value_by_id = {damping_id: math_ops.cast( + self._damping_funcs_by_id[damping_id](), self._dtype) for damping_id in self._damping_funcs_by_id} if use_eig: @@ -627,29 +672,91 @@ def make_inverse_update_ops(self): self._matpower_by_exp_and_damping.items()): assert exp == -1 damping = damping_value_by_id[damping_id] - ops.append(matpower.assign(utils.posdef_inv(self._cov, damping))) + ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping))) + + # TODO(b/77902055): If inverses are being computed with Cholesky's + # we can share the work. Instead this code currently just computes the + # Cholesky a second time. It does at least share work between requests for + # Cholesky's and Cholesky inverses with the same damping id. 
+ for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): + cholesky_ops = [] + + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + + if damping_id in self._cholesky_by_damping: + cholesky = self._cholesky_by_damping[damping_id] + cholesky_ops.append(cholesky.assign(cholesky_value)) + + identity = linalg_ops.eye(cholesky_value.shape.as_list()[0], + dtype=cholesky_value.dtype) + cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value, + identity) + cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value)) + + ops.append(control_flow_ops.group(*cholesky_ops)) + + for damping_id, cholesky in self._cholesky_by_damping.items(): + if damping_id not in self._cholesky_inverse_by_damping: + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + ops.append(cholesky.assign(cholesky_value)) self._eigendecomp = False return ops def get_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests - damping_id = graph_func_to_id(damping_func) - return self._matpower_by_exp_and_damping[(-1, damping_id)] + return self.get_matpower(-1, damping_func) def get_matpower(self, exp, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + if exp != 1: + damping_id = graph_func_to_id(damping_func) + matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] + else: + matpower = self.get_cov() + identity = linalg_ops.eye(matpower.shape.as_list()[0], + dtype=matpower.dtype) + matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity + + assert matpower.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(matpower, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # get_cov(). damping_id = graph_func_to_id(damping_func) - return self._matpower_by_exp_and_damping[(exp, damping_id)] + cholesky = self._cholesky_by_damping[damping_id] + assert cholesky.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky, + is_non_singular=True, + is_square=True) + + def get_cholesky_inverse(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + damping_id = graph_func_to_id(damping_func) + cholesky_inv = self._cholesky_inverse_by_damping[damping_id] + assert cholesky_inv.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky_inv, + is_non_singular=True, + is_square=True) def get_eigendecomp(self): """Creates or retrieves eigendecomposition of self._cov.""" # Unlike get_matpower this doesn't retrieve a stored variable, but instead # always computes a fresh version from the current value of get_cov(). if not self._eigendecomp: - eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov()) # The matrix self._cov is positive semidefinite by construction, but the # numerical eigenvalues could be negative due to numerical errors, so here @@ -660,45 +767,8 @@ def get_eigendecomp(self): return self._eigendecomp - def get_cov(self): - # Variable contains full covariance matrix. 
- return self.get_cov_var() - - def left_multiply_matpower(self, x, exp, damping_func): - if isinstance(x, tf_ops.IndexedSlices): - raise ValueError("Left-multiply not yet supported for IndexedSlices.") - - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - if exp == 1: - return math_ops.matmul(self.get_cov(), x) + damping_func() * x - - return math_ops.matmul(self.get_matpower(exp, damping_func), x) - - def right_multiply_matpower(self, x, exp, damping_func): - if isinstance(x, tf_ops.IndexedSlices): - if exp == 1: - n = self.get_cov().shape[0] - damped_cov = self.get_cov() + damping_func() * array_ops.eye(n) - return utils.matmul_sparse_dense(x, damped_cov) - - return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func)) - - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - if exp == 1: - return math_ops.matmul(x, self.get_cov()) + damping_func() * x - - return math_ops.matmul(x, self.get_matpower(exp, damping_func)) - - -class FullFactor(InverseProvidingFactor): +class FullFactor(DenseSquareMatrixFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. Note that this uses the naive "square the sum estimator", and so is applicable @@ -757,41 +827,51 @@ class DiagonalFactor(FisherFactor): """ def __init__(self): - self._damping_funcs_by_id = {} # { hashable: lambda } super(DiagonalFactor, self).__init__() + def get_cov_as_linear_operator(self): + assert self._matrix_diagonal.shape.ndims == 1 + return lo.LinearOperatorDiag(self._matrix_diagonal, + is_self_adjoint=True, + is_square=True) + @property def _cov_initializer(self): return diagonal_covariance_initializer + @property + def _matrix_diagonal(self): + return array_ops.reshape(self.get_cov(), [-1]) + def make_inverse_update_ops(self): return [] def instantiate_inv_variables(self): pass - def get_cov(self): - # self.get_cov() could be any shape, but it must have one entry per - # parameter. Flatten it into a vector. - cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1]) - return array_ops.diag(cov_diag_vec) + def register_matpower(self, exp, damping_func): + pass - def left_multiply_matpower(self, x, exp, damping_func): - matpower = (self.get_cov_var() + damping_func())**exp + def register_cholesky(self, damping_func): + pass - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x) + def register_cholesky_inverse(self, damping_func): + pass - if x.shape != matpower.shape: - raise ValueError("x (%s) and cov (%s) must have same shape." 
% - (x, matpower)) - return matpower * x + def get_matpower(self, exp, damping_func): + matpower_diagonal = (self._matrix_diagonal + + math_ops.cast(damping_func(), self._dtype))**exp + return lo.LinearOperatorDiag(matpower_diagonal, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) - def right_multiply_matpower(self, x, exp, damping_func): - raise NotImplementedError("Only left-multiply is currently supported.") + def get_cholesky(self, damping_func): + return self.get_matpower(0.5, damping_func) - def register_matpower(self, exp, damping_func): - pass + def get_cholesky_inverse(self, damping_func): + return self.get_matpower(-0.5, damping_func) class NaiveDiagonalFactor(DiagonalFactor): @@ -1167,7 +1247,7 @@ def _get_data_device(self, tower): return self._inputs[tower].device -class FullyConnectedKroneckerFactor(InverseProvidingFactor): +class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ @@ -1220,7 +1300,7 @@ def _get_data_device(self, tower): return self._tensors[0][tower].device -class ConvInputKroneckerFactor(InverseProvidingFactor): +class ConvInputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the input side of a convolutional layer. Estimates E[ a a^T ] where a is the inputs to a convolutional layer given @@ -1384,7 +1464,7 @@ def _get_data_device(self, tower): return self._inputs[tower].device -class ConvOutputKroneckerFactor(InverseProvidingFactor): +class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the output side of a convolutional layer. Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer @@ -1674,6 +1754,7 @@ def make_inverse_update_ops(self): psi_var) in self._option1quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) invsqrtC0 = math_ops.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) @@ -1702,6 +1783,7 @@ def make_inverse_update_ops(self): mu_var) in self._option2quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) # compute C0^(-1/2) invsqrtC0 = math_ops.matmul( diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 366e2a82d56602..cbbfe7212c9d94 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -182,7 +182,7 @@ def __init__(self, self._graph = graph or ops.get_default_graph() self._loss_dict = {} # {str: LossFunction} self._subgraph = None - self._default_generic_approximation = APPROX_FULL_NAME + self._default_generic_approximation = APPROX_DIAGONAL_NAME self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_conv2d_approximation = APPROX_KRONECKER_NAME diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py new file mode 100644 index 00000000000000..61cb955ae85df9 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/linear_operator.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SmartMatrices definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg +from tensorflow.python.ops.linalg import linalg_impl +from tensorflow.python.ops.linalg import linear_operator_util as lou + + +class LinearOperatorExtras(object): # pylint: disable=missing-docstring + + def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + if isinstance(x, ops.IndexedSlices): + return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -2 if adjoint else -1 + arg_dim = -1 if adjoint_arg else -2 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + + if isinstance(x, ops.IndexedSlices): + return self._matmul_right_sparse( + x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -1 if adjoint else -2 + arg_dim = -2 if adjoint_arg else -1 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + +class LinearOperatorFullMatrix(LinearOperatorExtras, + linalg.LinearOperatorFullMatrix): + + # TODO(b/78117889) Remove this definition once core LinearOperator + # has _matmul_right. 
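+  # Right-multiplication, i.e. x * A: implemented by swapping the operand
+  # order (and the corresponding adjoint flags) in matmul_with_broadcast.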
+ def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + return lou.matmul_with_broadcast( + x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint) + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + assert not adjoint and not adjoint_arg + return utils.matmul_sparse_dense(x, self._matrix) + + +class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring + linalg.LinearOperatorDiag): + + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + x = linalg_impl.adjoint(x) if adjoint_arg else x + return diag_mat * x + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + assert not adjoint_arg + return utils.matmul_diag_sparse(diag_mat, x) + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index f01c5a832212f8..03b9da793307b9 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -19,6 +19,7 @@ from __future__ import print_function import warnings + # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -67,7 +68,7 @@ def __init__(self, the local approximation with the Fisher information matrix, and to regularize the update direction by making it closer to the gradient. If damping is adapted during training then this value is used for - initializing damping varaible. + initializing damping variable. (Higher damping means the update looks more like a standard gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher @@ -108,6 +109,10 @@ def __init__(self, ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ + warnings.warn( + "third_party.tensorflow.contrib.kfac is deprecated." + "This will be removed on 15-07-2018. Check README for further details.", + DeprecationWarning) # Parameters to be passed to the Fisher estimator: self._variables = var_list or tf_variables.trainable_variables self._cov_ema_decay = cov_ema_decay @@ -115,7 +120,7 @@ def __init__(self, self._estimation_mode = estimation_mode self._colocate_gradients_with_ops = colocate_gradients_with_ops - # The below paramaters are required only if damping needs to be adapated. + # The below parameters are required only if damping needs to be adapated. # These parameters can be set by calling # set_damping_adaptation_params() explicitly. self._damping_adaptation_decay = 0.95 @@ -196,7 +201,7 @@ def set_damping_adaptation_params(self, min_damping: `float`(Optional), Minimum value the damping parameter can take. Default value 1e-5. damping_adaptation_decay: `float`(Optional), The `damping` parameter is - multipled by the `damping_adaptation_decay` every + multiplied by the `damping_adaptation_decay` every `damping_adaptation_interval` number of iterations. Default value 0.99. damping_adaptation_interval: `int`(Optional), Number of steps in between updating the `damping` parameter. Default value 5. 
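
The hunk below removes the `cov_update_op` / `inv_update_op` convenience
properties along with `make_ops_and_vars`, so callers now build the update ops
themselves from the thunks returned by `make_vars_and_create_op_thunks`. A
minimal sketch of that usage; the toy model, registration calls, and
constructor arguments are illustrative rather than taken from this patch:

```python
import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc
from tensorflow.contrib.kfac.python.ops import optimizer as opt_lib

# Toy single-layer softmax model (illustrative only).
x = tf.random_normal([8, 5])
labels = tf.constant([0, 1, 2, 0, 1, 2, 0, 1])
w = tf.get_variable("w", [5, 3])
b = tf.get_variable("b", [3])
logits = tf.matmul(x, w) + b
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                   logits=logits))

# Register the layer and the predictive distribution with K-FAC.
layer_collection = lc.LayerCollection()
layer_collection.register_fully_connected((w, b), x, logits)
layer_collection.register_categorical_predictive_distribution(logits)

optimizer = opt_lib.KfacOptimizer(
    learning_rate=0.01,
    cov_ema_decay=0.95,
    damping=1e-3,
    layer_collection=layer_collection)

# Build the covariance / inverse update ops from the returned thunks.
cov_thunks, inv_thunks = optimizer.make_vars_and_create_op_thunks()
cov_update_op = tf.group(*(thunk() for thunk in cov_thunks))
inv_update_op = tf.group(*(thunk() for thunk in inv_thunks))
train_op = optimizer.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(10):
    sess.run([cov_update_op, inv_update_op, train_op])
```
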
@@ -243,62 +248,6 @@ def damping(self): def damping_adaptation_interval(self): return self._damping_adaptation_interval - @property - def cov_update_thunks(self): - self._maybe_make_and_save_everything() - return self._cov_update_thunks - - @property - def cov_update_ops(self): - self._maybe_make_and_save_everything() - return self._cov_update_ops - - @property - def cov_update_op(self): - self._maybe_make_and_save_everything() - return self._cov_update_op - - @property - def inv_update_thunks(self): - self._maybe_make_and_save_everything() - return self._inv_update_thunks - - @property - def inv_update_ops(self): - self._maybe_make_and_save_everything() - return self._inv_update_ops - - @property - def inv_update_op(self): - self._maybe_make_and_save_everything() - return self._inv_update_op - - def _maybe_make_and_save_everything(self): - if not self._fisher_est.made_vars(): - warnings.warn("These convenience properties will be depcrecated soon. " - "Please use explicit op/thunk creation methods instead " - "(e.g. make_ops_and_vars, etc).", - DeprecationWarning) - (self._cov_update_ops, self._cov_update_op, self._inv_update_ops, - self._inv_update_op, self._cov_update_thunks, - self._inv_update_thunks) = self.make_ops_and_vars() - - def make_ops_and_vars(self): - """Make ops and vars with device placement `self._placement_strategy`. - - See `FisherEstimator.make_ops_and_vars` for details. - - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_op: inv_update_ops grouped into a single op. - """ - return self._fisher_est.make_ops_and_vars(scope=self.get_name()) - def make_vars_and_create_op_thunks(self): """Make vars and create op thunks. @@ -385,7 +334,6 @@ def apply_gradients(self, grads_and_vars, *args, **kwargs): Returns: An `Operation` that applies the specified gradients. """ - self._maybe_make_and_save_everything() # In Python 3, grads_and_vars can be a zip() object which can only be # iterated over once. By converting it to a list, we ensure that it can be # iterated over more than once. diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py index bf12dbaa9adbaa..c4454325aebe13 100644 --- a/tensorflow/contrib/kfac/python/ops/placement.py +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -21,8 +21,6 @@ import itertools from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope def _make_thunk_on_device(func, device): @@ -35,7 +33,7 @@ def thunk(): class RoundRobinPlacementMixin(object): """Implements round robin placement strategy for ops and variables.""" - def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs): + def __init__(self, cov_devices=None, inv_devices=None, **kwargs): """Initializes the RoundRobinPlacementMixin class. Args: @@ -45,66 +43,15 @@ def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs): inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. 
- *args: - **kwargs: + **kwargs: Need something here? """ - super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs) + super(RoundRobinPlacementMixin, self).__init__(**kwargs) self._cov_devices = cov_devices self._inv_devices = inv_devices - def make_ops_and_vars(self, scope=None): - """Make ops and vars with a round-robin device placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the - `self._cov_devices` attribute. If `self._cov_devices` is `None` then no - explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the `self._inv_devices` attribute. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all ops will execute, inside of a variable scope of the given - name. (Default: None) - - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. - """ - (cov_update_thunks, - inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope) - cov_update_ops = [thunk() for thunk in cov_update_thunks] - inv_update_ops = [thunk() for thunk in inv_update_thunks] - - scope = self.name if scope is None else scope - with variable_scope.variable_scope(scope): - cov_update_op = control_flow_ops.group(cov_update_ops, - name="cov_update_op") - inv_update_op = control_flow_ops.group(inv_update_ops, - name="inv_update_op") - - return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op, - cov_update_thunks, inv_update_thunks) - def make_vars_and_create_op_thunks(self, scope=None): - """Make vars and create op thunks w/ a round-robin device placement strat. + """Make vars and create op thunks w/ a round-robin device placement start. For each factor, all of that factor's cov variables and their associated update ops will be placed on a particular device. 
A new device is chosen diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index b6f42815e79fa5..144295f4c7e36f 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -235,6 +235,13 @@ def posdef_eig_self_adjoint(mat): } +def cholesky(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return linalg_ops.cholesky(tensor + damping * identity) + + class SubGraph(object): """Defines a subgraph given by all the dependencies of a given set of outputs. """ @@ -553,13 +560,17 @@ def is_data_format_channel_last(data_format): return data_format.endswith("C") -def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name +def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name """Computes matmul(A, B) where A is sparse, B is dense. Args: A: tf.IndexedSlices with dense shape [m, n]. B: tf.Tensor with shape [n, k]. name: str. Name of op. + transpose_a: Bool. If true we transpose A before multiplying it by B. + (Default: False) + transpose_b: Bool. If true we transpose B before multiplying it by A. + (Default: False) Returns: tf.IndexedSlices resulting from matmul(A, B). @@ -573,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name raise ValueError("A must represent a matrix. Found: %s." % A) if B.shape.ndims != 2: raise ValueError("B must be a matrix.") - new_values = math_ops.matmul(A.values, B) + new_values = math_ops.matmul( + A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) return ops.IndexedSlices( new_values, A.indices, diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index b527cdad7088a9..7e79b48cdbe14e 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -381,7 +381,7 @@ py_test( py_test( name = "rev_block_lib_test", - size = "small", + size = "medium", srcs = ["python/layers/rev_block_lib_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 49c3faf3b7f5ea..60e1d85ea9c08a 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -458,7 +458,7 @@ def scattered_embedding_lookup_sparse(params, return embeddings -def embedding_lookup_unique(params, ids, name=None): +def embedding_lookup_unique(params, ids, partition_strategy="mod", name=None): """Version of embedding_lookup that avoids duplicate lookups. This can save communication in the case of repeated ids. @@ -470,6 +470,9 @@ def embedding_lookup_unique(params, ids, name=None): `PartitionedVariable`. Shape `[index, d1, d2, ...]`. ids: A one-dimensional `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`. + partition_strategy: A string specifying the partitioning strategy, relevant + if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default + is `"mod"`. name: A name for this operation (optional). 
Returns: @@ -485,7 +488,8 @@ def embedding_lookup_unique(params, ids, name=None): ids_flat = array_ops.reshape( ids, math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) - unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids) + unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids, + partition_strategy) embeds_flat = array_ops.gather(unique_embeddings, idx) embed_shape = array_ops.concat( [shape, array_ops.shape(unique_embeddings)[1:]], 0) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index bf2514498202e9..dd2395f8c9748d 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import init_ops @@ -691,11 +692,12 @@ def _GroupByBatchEntry(self, vals, vals_per_batch_entry): index += num_val return grouped_vals + @test_util.enable_c_shapes def testEmbeddingLookupSparse(self): vocab_size = 13 batch_size = 10 param_shape = [2, 5] - expected_lookup_result_shape = [None] + param_shape + expected_lookup_result_shape = param_shape sp_ids, sp_weights, ids, weights, vals_per_batch_entry = ( self._RandomIdsAndWeights(batch_size, vocab_size)) @@ -719,7 +721,7 @@ def testEmbeddingLookupSparse(self): None if ignore_weights else sp_weights, combiner=combiner) - self.assertEqual(embedding_sum.get_shape().as_list(), + self.assertEqual(embedding_sum.get_shape().as_list()[1:], expected_lookup_result_shape) tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 2f3e57653c5d6d..b7194ae3330450 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2022,6 +2022,7 @@ def build(self, input_shape): def beta_initializer(shape, dtype=None, partition_info=None): del partition_info # unused + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal) def gamma_initializer(shape, dtype=None, partition_info=None): @@ -2029,6 +2030,7 @@ def gamma_initializer(shape, dtype=None, partition_info=None): assert len(shape) == 2 assert shape[0] == shape[1] eye = linalg_ops.eye(shape[0], dtype=dtype) + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(self._gamma_init * eye + pedestal) beta = self.add_variable( diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index b01fd5d5c95ac1..56e9194cebbe46 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1333,7 +1333,7 @@ def testCreateDropout(self): with self.test_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) - self.assertEqual(output.op.name, 'Dropout/dropout/mul') + self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') 
output.get_shape().assert_is_compatible_with( ops.convert_to_tensor(images).get_shape()) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 02d294c68f1e10..0e35b1aa8bf682 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -33,23 +33,32 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops +from tensorflow.python.eager import backprop from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops as framework_ops from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect __all__ = ["rev_block", "RevBlock", "recompute_grad"] LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") _USE_DEFAULT = "__rev_block_lib_default" +_WRONG_VARS_ERR = """\ +The variables used on recompute were different than the variables originally +used. The function wrapped with @recompute_grad likley creates its own variable +scope with a default name and has been called twice in the same enclosing scope. +To fix, ensure each call to the function happens in its own unique variable +scope. +""" def _acc_grads(*lists_of_grads): @@ -146,7 +155,7 @@ def _scope_wrap(fn, scope): @functools.wraps(fn) def wrap(*args, **kwargs): - with variable_scope.variable_scope(scope): + with variable_scope.variable_scope(scope, use_resource=True): return fn(*args, **kwargs) return wrap @@ -221,95 +230,95 @@ def build(self, _): "build.") self.built = True - def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): - """Custom gradient fn for a block of reversible residual layers.""" - # Inputs have passed through an Identity. Recover the original Tensors to - # be able to match up side inputs. 
- assert [u"Identity"] == list(set([x.op.type for x in inputs])) - inputs = [x.op.inputs[0] for x in inputs] - side_inputs = inputs[2:] - del inputs - - f_side_idxs = [None] * len(self.f_side_input) - g_side_idxs = [None] * len(self.g_side_input) - assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) - - for i, t in enumerate(side_inputs): - if t in self.f_side_input: - f_side_idxs[self.f_side_input.index(t)] = i - elif t in self.g_side_input: - g_side_idxs[self.g_side_input.index(t)] = i - else: - assert False - - f_vars = [[] for _ in range(self.num_layers)] - g_vars = [[] for _ in range(self.num_layers)] - f_vars_idxs = [[] for _ in range(self.num_layers)] - g_vars_idxs = [[] for _ in range(self.num_layers)] - - for i, ref in enumerate(variables): - # Use the name to identify the layer number and function (f or g) - regex = LAYER_RE.match(ref.name) - layer_no = int(regex.group(1)) - fn_name = regex.group(2) - if fn_name == "f": - f_vars[layer_no].append(ref) - f_vars_idxs[layer_no].append(i) - else: - assert fn_name == "g" - g_vars[layer_no].append(ref) - g_vars_idxs[layer_no].append(i) - - f_var_grads = [] - g_var_grads = [] - f_side_grads = [] - g_side_grads = [] - - # Reverse variable containers to go backward - f_vars.reverse() - g_vars.reverse() - f = list(self.f) - g = list(self.g) - f.reverse() - g.reverse() - - with variable_scope.variable_scope(self.scope_name, reuse=True): - for i in xrange(self.num_layers): - ys, grad_ys, f_ret, g_ret = _rev_layer_backward( - ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], - self.g_side_input) - - grad_f_vars, grad_f_side = f_ret - grad_g_vars, grad_g_side = g_ret - f_var_grads.append(grad_f_vars) - g_var_grads.append(grad_g_vars) - f_side_grads.append(grad_f_side) - g_side_grads.append(grad_g_side) - - # Accumulate layer gradients for f_side_input and g_side_input - acc_f_side_grads = _acc_grads(*f_side_grads) - acc_g_side_grads = _acc_grads(*g_side_grads) - - # Use the stored idxs to put gradients in the passed-in order. - side_input_grads = [None] * len(side_inputs) - variable_grads = [None] * len(variables) - - # Variable gradients were collected in reverse layer order. Reverse to match - # idxs. 
- f_var_grads.reverse() - g_var_grads.reverse() - for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( - zip(g_vars_idxs, g_var_grads)): - for i, grad in zip(idxs, grads): - variable_grads[i] = grad - - for i, grad in zip(f_side_idxs, acc_f_side_grads): - side_input_grads[i] = grad - for i, grad in zip(g_side_idxs, acc_g_side_grads): - side_input_grads[i] = grad - - grad_x1, grad_x2 = grad_ys - return [grad_x1, grad_x2] + side_input_grads, variable_grads + def _make_efficient_grad_fn(self, inputs_, ys_): + def _efficient_grad_fn(*grad_ys, **kwargs): + """Custom gradient fn for a block of reversible residual layers.""" + inputs = inputs_ + ys = ys_ + variables = kwargs["variables"] + side_inputs = inputs[2:] + + f_side_idxs = [None] * len(self.f_side_input) + g_side_idxs = [None] * len(self.g_side_input) + assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) + + for i, t in enumerate(side_inputs): + if t in self.f_side_input: + f_side_idxs[self.f_side_input.index(t)] = i + elif t in self.g_side_input: + g_side_idxs[self.g_side_input.index(t)] = i + else: + assert False + + f_vars = [[] for _ in range(self.num_layers)] + g_vars = [[] for _ in range(self.num_layers)] + f_vars_idxs = [[] for _ in range(self.num_layers)] + g_vars_idxs = [[] for _ in range(self.num_layers)] + + for i, ref in enumerate(variables): + # Use the name to identify the layer number and function (f or g) + regex = LAYER_RE.match(ref.name) + layer_no = int(regex.group(1)) + fn_name = regex.group(2) + if fn_name == "f": + f_vars[layer_no].append(ref) + f_vars_idxs[layer_no].append(i) + else: + assert fn_name == "g" + g_vars[layer_no].append(ref) + g_vars_idxs[layer_no].append(i) + + f_var_grads = [] + g_var_grads = [] + f_side_grads = [] + g_side_grads = [] + + # Reverse variable containers to go backward + f_vars.reverse() + g_vars.reverse() + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() + + with variable_scope.variable_scope(self.scope_name, reuse=True): + for i in xrange(self.num_layers): + ys, grad_ys, f_ret, g_ret = _rev_layer_backward( + ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], + self.g_side_input) + + grad_f_vars, grad_f_side = f_ret + grad_g_vars, grad_g_side = g_ret + f_var_grads.append(grad_f_vars) + g_var_grads.append(grad_g_vars) + f_side_grads.append(grad_f_side) + g_side_grads.append(grad_g_side) + + # Accumulate layer gradients for f_side_input and g_side_input + acc_f_side_grads = _acc_grads(*f_side_grads) + acc_g_side_grads = _acc_grads(*g_side_grads) + + # Use the stored idxs to put gradients in the passed-in order. + side_input_grads = [None] * len(side_inputs) + variable_grads = [None] * len(variables) + + # Variable gradients were collected in reverse layer order. Reverse to + # match idxs. 
+ f_var_grads.reverse() + g_var_grads.reverse() + for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( + zip(g_vars_idxs, g_var_grads)): + for i, grad in zip(idxs, grads): + variable_grads[i] = grad + + for i, grad in zip(f_side_idxs, acc_f_side_grads): + side_input_grads[i] = grad + for i, grad in zip(g_side_idxs, acc_g_side_grads): + side_input_grads[i] = grad + + grad_x1, grad_x2 = grad_ys + return [grad_x1, grad_x2] + side_input_grads, variable_grads + return _efficient_grad_fn def _forward(self, x1, x2): """Run forward through the reversible layers.""" @@ -317,10 +326,6 @@ def _forward(self, x1, x2): side_inputs = [self.f_side_input, self.g_side_input] flat_side_inputs = nest.flatten(side_inputs) - custom_grad_fn = ( - self._efficient_grad_fn if self._use_efficient_backprop else None) - - @_fn_with_custom_grad(custom_grad_fn) def _forward_wrap(x1_, x2_, *flat_side_inputs): f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs) return _rev_block_forward( @@ -333,7 +338,16 @@ def _forward_wrap(x1_, x2_, *flat_side_inputs): g_side_input=g_side, gate_outputs=self._use_efficient_backprop) - return _forward_wrap(x1, x2, *flat_side_inputs) + @custom_gradient.custom_gradient + def _forward_with_custom_grad(*args): + out = _forward_wrap(*args) # pylint: disable=no-value-for-parameter + grad_fn = self._make_efficient_grad_fn(args, out) + return out, grad_fn + + if self._use_efficient_backprop: + return _forward_with_custom_grad(x1, x2, *flat_side_inputs) + else: + return _forward_wrap(x1, x2, *flat_side_inputs) def _backward(self, y1, y2): """Run backward through the reversible layers.""" @@ -432,6 +446,19 @@ def new_dec(*args, **kwargs): def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """Decorator that recomputes the function on the backwards pass. + To use this function, you must use `ResourceVariable`s (i.e. + `variable_scope(name, use_resource=True), which are the default in Eager mode + and when running on TPU. + + Warning: Because the function will be called again on the backwards pass, the + user should be careful to not use ops in their function that mutate state or + have randomness (for example, batch normalization or dropout). If the function + does have such operations, it is recommended that the function take the + `is_recomputing` keyword argument which will be `False` on the forward pass + and `True` on the backwards pass so that it can disable state changes when + `is_recomputing=True` (for example, not updating the moving averages in batch + normalization). + Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. 
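
The docstring added above asks callers to use resource variables and, for
functions with state-changing ops, to accept an `is_recomputing` keyword. A
minimal sketch of such a layer, closely following the `testWithIsRecomputeKwarg`
test added later in this patch (layer sizes and scope names are illustrative):

```python
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import rev_block_lib

@rev_block_lib.recompute_grad
def layer_with_recompute(inputs, is_recomputing=False):
  out = tf.layers.dense(inputs, 2)
  out = tf.layers.batch_normalization(out, training=True)
  if is_recomputing:
    # The recomputation pass re-creates the moving-average update ops; drop
    # the duplicates so each statistic is only updated once per step.
    update_ops = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
    update_ops.pop()
    update_ops.pop()
  return out

x = tf.ones((2, 4), dtype=tf.float32)
with tf.variable_scope("layer1", use_resource=True):
  y = layer_with_recompute(x)
loss = tf.reduce_sum(y)
grads = tf.gradients(loss, [x] + tf.trainable_variables())
```
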
@@ -465,6 +492,7 @@ def _is_on_tpu(): def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """See recompute_grad.""" + has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args for arg in args: if not isinstance(arg, framework_ops.Tensor): raise ValueError("All inputs to function must be Tensors") @@ -472,44 +500,61 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): if use_data_dep_ == _USE_DEFAULT: use_data_dep_ = _is_on_tpu() - cached_vs = [] - cached_arg_scope = [] - - def grad_fn(inputs, variables, outputs, output_grads): - """Recompute outputs for gradient computation.""" - del outputs - # Recompute outputs - with framework_ops.control_dependencies(output_grads): - if use_data_dep_: - inputs = _force_data_dependency(output_grads, inputs) - with contrib_framework_ops.arg_scope(cached_arg_scope[0]): - with variable_scope.variable_scope(cached_vs[0], reuse=True): - outputs = fn(*inputs) - - if not (isinstance(outputs, list) or isinstance(outputs, tuple)): - outputs = [outputs] - outputs = list(outputs) - grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) - - if tupleize_grads: - if use_data_dep_: - grads = _tuple_with_data_dep(grads) - else: - grads = control_flow_ops.tuple(grads) - - grad_inputs = grads[:len(inputs)] - grad_vars = grads[len(inputs):] - return grad_inputs, grad_vars - - @_fn_with_custom_grad(grad_fn) + @custom_gradient.custom_gradient def fn_with_recompute(*args): - cached_vs.append(variable_scope.get_variable_scope()) - # TODO(rsepassi): Rm conditional in TF 1.4 - if hasattr(contrib_framework_ops, "current_arg_scope"): - cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) - else: - cached_arg_scope.append({}) - return fn(*args) + """Wrapper for fn.""" + # Forward pass + vs = variable_scope.get_variable_scope() + arg_scope = contrib_framework_ops.current_arg_scope() + with backprop.GradientTape() as tape: + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = False + outputs = fn(*args, **fn_kwargs) + original_vars = set(tape.watched_variables()) + + # Backward pass + def grad_fn(*output_grads, **kwargs): + """Recompute outputs for gradient computation.""" + variables = [] + if original_vars: + variables = kwargs["variables"] + if set(variables) != original_vars: + raise ValueError(_WRONG_VARS_ERR) + del kwargs + inputs = list(args) + # Recompute outputs + with framework_ops.control_dependencies(output_grads): + if use_data_dep_: + inputs = _force_data_dependency(output_grads, inputs) + with contrib_framework_ops.arg_scope(arg_scope): + with variable_scope.variable_scope(vs, reuse=True): + with backprop.GradientTape() as tape: + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = True + outputs = fn(*inputs, **fn_kwargs) + recompute_vars = set(tape.watched_variables()) + if original_vars != recompute_vars: + raise ValueError(_WRONG_VARS_ERR) + + if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + outputs = [outputs] + outputs = list(outputs) + grads = gradients_impl.gradients(outputs, inputs + variables, + output_grads) + + if tupleize_grads: + if use_data_dep_: + grads = _tuple_with_data_dep(grads) + else: + grads = control_flow_ops.tuple(grads) + + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + return outputs, grad_fn return fn_with_recompute(*args) @@ -536,107 +581,6 @@ def _underlying_variable_ref(t): return None -def 
_fn_with_custom_grad(grad_fn, use_global_vars=False): - """Decorator to create a subgraph with a custom gradient function. - - The subgraph created by the decorated function is NOT put in a Defun and so - does not suffer from the limitations of the Defun (all subgraph ops on the - same device, no summaries). - - Args: - grad_fn: function with signature - (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars), - all of which are lists of Tensors. - use_global_vars: if True, variables will be the global variables created. - If False, will be the trainable variables. - - Returns: - Decorator for function such that the gradient is defined by grad_fn. - """ - - def dec(fn): - - @functools.wraps(fn) - def wrapped(*args): - return _fn_with_custom_grad_internal( - fn, args, grad_fn, use_global_vars=use_global_vars) - - return wrapped - - return dec - - -def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): - """Create a subgraph with a custom gradient. - - Args: - fn: function that takes inputs as arguments and produces 1 or more Tensors. - inputs: list, will be passed as fn(*inputs). - grad_fn: function with signature - (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars), - all of which are lists of Tensors. - use_global_vars: if True, variables will be the global variables created. - If False, will be the trainable variables. - - Returns: - fn(*inputs) - """ - vs = variable_scope.get_variable_scope() - get_vars_fn = ( - vs.global_variables if use_global_vars else vs.trainable_variables) - len_before_vars = len(get_vars_fn()) - inputs = [array_ops.identity(x) for x in inputs] - outputs = fn(*inputs) - train_vars = get_vars_fn()[len_before_vars:] - - if grad_fn is None: - return outputs - - if not (isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = [outputs] - outputs = list(outputs) - - defun_inputs = [inputs, train_vars, outputs] - - def custom_grad_fn(op, *dys): - """Custom grad fn applying grad_fn for identity Defun.""" - fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as( - defun_inputs, list(op.inputs)) - fn_vars = [_underlying_variable_ref(v) for v in fn_vars] - dys = list(dys) - assert len(fn_outputs) == len(outputs) - assert len(fn_outputs) == len(dys) - - grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys) - grad_outputs = [None] * len(fn_outputs) - return tuple(grad_inputs + grad_vars + grad_outputs) - - # The Defun takes as input the original inputs, the trainable variables - # created in fn, and the outputs. In the forward it passes through the - # outputs. In the backwards, it produces gradients for the original inputs - # and the trainable variables. - in_types = [t.dtype for t in inputs] - out_types = [t.dtype for t in outputs] - var_types = [t.dtype for t in train_vars] - - # Get a unique name for the Defun - with framework_ops.name_scope("identity_custom_grad") as ns: - defun_name = ns - - @function.Defun( - *(in_types + var_types + out_types), - func_name=defun_name, - python_grad_func=custom_grad_fn, - shape_func=lambda _: [t.get_shape() for t in outputs]) - def identity(*args): - _, _, outs = nest.pack_sequence_as(defun_inputs, args) - return tuple([array_ops.identity(t) for t in outs]) - - flat_inputs = nest.flatten(defun_inputs) - id_out = identity(*flat_inputs) - return id_out - - def _force_data_dependency(first_compute, then_compute): """Force all of `then_compute` to depend on all of `first_compute`. 
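
The Defun-based `_fn_with_custom_grad` helper deleted above is superseded by
`tf.custom_gradient`, whose gradient function receives any variables created in
the forward pass through a `variables` keyword argument and returns input and
variable gradients separately. A small sketch of that pattern, illustrative
only and not taken from this patch:

```python
import tensorflow as tf

@tf.custom_gradient
def scale_by_var(x):
  # Resource variable created inside the decorated function so that
  # custom_gradient can track it.
  v = tf.get_variable("scale", shape=(), use_resource=True,
                      initializer=tf.ones_initializer())
  out = v * x

  def grad_fn(dy, variables=None):
    # Return two structures: gradients for the inputs and gradients for the
    # variables used by the forward pass.
    return [dy * v], [tf.reduce_sum(dy * x)]

  return out, grad_fn

x = tf.constant([1.0, 2.0, 3.0])
y = scale_by_var(x)
grads = tf.gradients(tf.reduce_sum(y), [x])
```
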
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index 8c118402a4c85d..bc09ba8d439808 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -21,9 +21,11 @@ from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.layers import rev_block_lib from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.layers import convolutional from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import normalization as normalization_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops @@ -83,8 +85,8 @@ def g(x): sess.run(variables.global_variables_initializer()) y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv]) - self.assertAllClose(y1, y1_inv) - self.assertAllClose(y2, y2_inv) + self.assertAllClose(y1, y1_inv, rtol=1e-5) + self.assertAllClose(y2, y2_inv, rtol=1e-5) def _testRevBlock(self, x=None, @@ -179,18 +181,16 @@ def f2(x): self._testRevBlock(f=[f1, f2, f1, f2]) - # TODO(rsepassi): Recent change to conv seems to have broken this test. Find - # out why. - def _testConvAndBatchNorm(self): + def testConvAndBatchNorm(self): x = random_ops.random_uniform( [self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32) def f(x): x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = layers.batch_norm(x, is_training=True) + x = layers.batch_norm(x, is_training=False) x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = layers.batch_norm(x, is_training=True) + x = layers.batch_norm(x, is_training=False) return x self._testRevBlock(x=x, f=f) @@ -278,7 +278,7 @@ def fn_both(x): ] outputs_and_vars = [] for name, wrapped_fn in names_and_fns: - with variable_scope.variable_scope(name) as vs: + with variable_scope.variable_scope(name, use_resource=True) as vs: out = math_ops.reduce_sum(wrapped_fn(x)) outputs_and_vars.append((out, vs.trainable_variables())) @@ -304,103 +304,73 @@ def fn_both(x): self.assertAllClose(current, g) current = g - def testResourceVariable(self): - @rev_block_lib.recompute_grad(tupleize_grads=True) + def testDoubleCallInSameScopeFails(self): + + @rev_block_lib.recompute_grad def layer_with_recompute(inputs): - var = variable_scope.get_variable("var", ()) - return var * inputs + return core_layers.dense(inputs, 2) - inputs = array_ops.ones((), dtypes.float32) with variable_scope.variable_scope("layer", use_resource=True): - outputs = layer_with_recompute(inputs) - loss = math_ops.square(outputs) - grads = gradients_impl.gradients(loss, variables.trainable_variables()) - self.assertEqual(1, len(grads)) - self.assertTrue(grads[0] is not None) + inputs = array_ops.ones((2, 4), dtypes.float32) + out1 = layer_with_recompute(inputs) + out2 = layer_with_recompute(inputs) + out1 + out = math_ops.reduce_sum(out2) + tvars = variables.trainable_variables() + assert len(tvars) == 4 + with self.assertRaisesWithPredicateMatch( + ValueError, "called twice in the same enclosing scope"): + gradients_impl.gradients(out, [inputs] + tvars) -class FnWithCustomGradTest(test.TestCase): + def testDoubleCallInUniqueScope(self): - def testCorrectness(self): + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs): + 
with variable_scope.variable_scope("inner", use_resource=True): + return core_layers.dense(inputs, 2) - w = random_ops.random_uniform([6, 10]) + with variable_scope.variable_scope("layer", use_resource=True): + inputs = array_ops.ones((2, 4), dtypes.float32) - def fn(a, b, c): - return core_layers.dense( - a, - 10, - use_bias=False, - kernel_initializer=lambda shape, dtype, partition_info: w - ) + math_ops.matmul(b, c) - - def grad_fn(inputs, trainable_variables, outputs, grad_outputs): - outputs = outputs[0] - grad_outputs = grad_outputs[0] - grad_inputs = gradients_impl.gradients( - outputs, inputs, grad_ys=grad_outputs) - grad_vars = gradients_impl.gradients( - outputs, trainable_variables, grad_ys=grad_outputs) - return grad_inputs, grad_vars - - custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn) - - a = random_ops.random_uniform([11, 6]) - b = random_ops.random_uniform([11, 7]) - c = random_ops.random_uniform([7, 10]) - - out = fn(a, b, c) - custom_out = custom_fn(a, b, c) - self.assertEqual(out.get_shape().as_list(), - custom_out.get_shape().as_list()) - - loss = math_ops.reduce_mean(out) - custom_loss = math_ops.reduce_mean(custom_out) - - grads = gradients_impl.gradients( - loss, [a, b, c] + [variables.trainable_variables()[0]]) - custom_grads = gradients_impl.gradients( - custom_loss, [a, b, c] + [variables.trainable_variables()[1]]) + with variable_scope.variable_scope("layer1", use_resource=True): + out1 = layer_with_recompute(inputs) + with variable_scope.variable_scope("layer2", use_resource=True): + out2 = layer_with_recompute(inputs) + out1 + out = math_ops.reduce_sum(out2) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - out_val, custom_out_val, grads_val, custom_grads_val = sess.run( - [out, custom_out, grads, custom_grads]) - self.assertAllClose(out_val, custom_out_val) - for g1, g2 in zip(grads_val, custom_grads_val): - self.assertAllClose(g1, g2) - - def testCustomGrad(self): - - def fn(a, b, c): - return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c) - - def grad_fn(inputs, trainable_variables, unused_outputs, - unused_grad_outputs): - grad_inputs = [ - array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs) - ] - grad_vars = [ - array_ops.ones_like(t) * (i + len(inputs) + 1.) - for i, t in enumerate(trainable_variables) - ] - return grad_inputs, grad_vars - - a = random_ops.random_uniform([11, 6]) - b = random_ops.random_uniform([11, 7]) - c = random_ops.random_uniform([7, 10]) - w = random_ops.random_uniform([6, 10]) - out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c) - loss = math_ops.reduce_mean(out) - grads = gradients_impl.gradients( - loss, [a, b, c, variables.trainable_variables()[0]]) - expected_grads = [ - array_ops.ones_like(t) * (i + 1.) 
for i, t in enumerate([a, b, c, w]) - ] - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - g_val, eg_val = sess.run([grads, expected_grads]) - for g1, g2 in zip(g_val, eg_val): - self.assertAllClose(g1, g2) + tvars = variables.trainable_variables() + assert len(tvars) == 4 + grads = gradients_impl.gradients(out, [inputs] + tvars) + for grad in grads: + self.assertTrue(grad is not None) + + def testWithIsRecomputeKwarg(self): + + kwarg_values = [] + + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs, is_recomputing=False): + kwarg_values.append(is_recomputing) + out = core_layers.dense(inputs, 2) + out = normalization_layers.batch_normalization(out, training=True) + if is_recomputing: + # Ensure that the updates are not duplicated by popping off the latest + # 2 additions. + update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS) + update_ops.pop() + update_ops.pop() + return out + + x = array_ops.ones((2, 4), dtypes.float32) + with variable_scope.variable_scope("layer1", use_resource=True): + y = layer_with_recompute(x) + loss = math_ops.reduce_sum(y) + tvars = variables.trainable_variables() + gradients_impl.gradients(loss, [x] + tvars) + + update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS) + self.assertEqual(2, len(update_ops)) + self.assertEqual([False, True], kwarg_values) if __name__ == "__main__": diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 3b053cd4c66952..b56a88659bbd44 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -284,6 +284,7 @@ py_test( tags = [ "manual", "noasan", # times out + "optonly", # test is flaky without optimization. ], deps = [ ":learn", @@ -434,6 +435,7 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/learn/estimators/kmeans_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "noasan", # b/73741358 @@ -485,6 +487,7 @@ py_test( name = "state_saving_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], deps = [ @@ -744,7 +747,7 @@ py_test( tf_py_test( name = "graph_io_test", - size = "small", + size = "medium", srcs = ["python/learn/learn_io/graph_io_test.py"], additional_deps = [ ":learn", diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index e28e6854a5097d..339c4e0e360ed9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -1862,12 +1862,12 @@ def _get_arguments(func): if hasattr(func, "__code__"): # Regular function. return tf_inspect.getargspec(func) - elif hasattr(func, "__call__"): - # Callable object. - return _get_arguments(func.__call__) elif hasattr(func, "func"): # Partial function. return _get_arguments(func.func) + elif hasattr(func, "__call__"): + # Callable object. 
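+    # Checked after the partial-function case above: functools.partial objects
+    # are also callable, so they must be routed to the `func` branch first.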
+ return _get_arguments(func.__call__) def _verify_loss_fn_args(loss_fn): diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 70b70af98c51dc..e100bc7a1e7be4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -31,7 +31,6 @@ from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -51,6 +50,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import session_run_hook from tensorflow.python.training import training as train +from tensorflow.python.training import training_util # The default learning rate of 0.2 is a historical artifact of the initial @@ -244,7 +244,9 @@ def sdca_model_fn(features, labels, mode, params): parent_scope = "linear" with variable_scope.variable_scope( - values=features.values(), name_or_scope=parent_scope) as scope: + values=features.values(), + name_or_scope=parent_scope, + partitioner=optimizer.partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index d3bb0fda5765d8..597ca4e86dbf66 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -43,6 +43,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test from tensorflow.python.training import ftrl from tensorflow.python.training import input as input_lib @@ -863,6 +864,38 @@ def input_fn(): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerWeightedSparseFeaturesOOVWithNoOOVBuckets(self): + """LinearClassifier with SDCAOptimizer with OOV features (-1 IDs).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + sparse_tensor.SparseTensor( + values=[2., 3., 1.], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 5]), + 'country': + sparse_tensor.SparseTensor( + # 'GB' is out of the vocabulary. 
+ values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 5]) + }, constant_op.constant([[1], [0], [1]]) + + country = feature_column_lib.sparse_column_with_keys( + 'country', keys=['US', 'CA', 'MK', 'IT', 'CN']) + country_weighted_by_price = feature_column_lib.weighted_sparse_column( + country, 'price') + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id') + classifier = linear.LinearClassifier( + feature_columns=[country_weighted_by_price], optimizer=sdca_optimizer) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerCrossedFeatures(self): """Tests LinearClassifier with SDCAOptimizer and crossed features.""" @@ -934,6 +967,63 @@ def input_fn(): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearClassifier with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + classifier = linear.LinearClassifier( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + print('all scores = {}'.format(scores)) + self.assertGreater(scores['accuracy'], 0.9) + def testEval(self): """Tests that eval produces correct metrics. 
""" @@ -1508,6 +1598,60 @@ def input_fn(): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearRegressor with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([0.6, 0.8, 0.3]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', symmetric_l2_regularization=1.0, + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + regressor = linear.LinearRegressor( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """Tests LinearClassifier with SDCAOptimizer and sparse features.""" diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 3744abd860e7f4..541da9061732ad 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.util import compat +from tensorflow.python.util import function_utils __all__ = ["Experiment"] def _get_standardized_predicate_fn(predicate_fn): - pred_fn_args = estimator_util.fn_args(predicate_fn) + pred_fn_args = function_utils.fn_args(predicate_fn) if "checkpoint_path" not in pred_fn_args: # pylint: disable=unused-argument def _pred_fn_wrapper(eval_results, checkpoint_path): @@ -468,10 +468,15 @@ def _continuous_eval(self, on which that evaluation was 
based. At the beginning of evaluation, the passed `eval_results` will be None so it's expected that the predicate function handles that gracefully. - When `predicate_fn` is not specified, continuous eval will run in an - infinite loop (if `train_steps` is None). or exit once global step - reaches `train_steps`. - + Continuous eval behavior under different conditions: + * When `predicate_fn` is specified: + + if `train_steps` is None, run until `predicate_fn` returns False. + + if `train_steps` is specified, run until either global step + reaches `train_steps` or `predicate_fn` returns False. + * When `predicate_fn` is not specified: + + if `train_steps` is None, run in an infinite loop. + + if `train_steps` is specified, run until global step reaches + `train_steps`. export: Whether to export from this step. Default is 'True'. Raises: diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index c7cdb4131215c3..f8106d1e4a7e79 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -343,7 +343,8 @@ def get_temp_export_dir(timestamped_export_dir): """ (dirname, basename) = os.path.split(timestamped_export_dir) temp_export_dir = os.path.join( - compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename))) + compat.as_bytes(dirname), + compat.as_bytes('temp-{}'.format(compat.as_text(basename)))) return temp_export_dir diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index fb6a989b76db97..6c18a0712531e4 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -42,47 +42,3 @@ gpu_py_test( "//tensorflow/python:platform_test", ], ) - -gpu_py_test( - name = "linear_operator_block_diag_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_block_diag_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], - shard_count = 5, - tags = [ - "noasan", - "optonly", - ], -) - -gpu_py_test( - name = "linear_operator_kronecker_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_kronecker_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], - shard_count = 8, - tags = [ - "noasan", - "optonly", - ], -) diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 554854da84715e..a262a099cf8f84 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -39,14 +39,14 @@ # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * -from tensorflow.contrib.linalg.python.ops.linear_operator_block_diag import * -from tensorflow.contrib.linalg.python.ops.linear_operator_kronecker import * from 
tensorflow.python.ops.linalg.linear_operator import * +from tensorflow.python.ops.linalg.linear_operator_block_diag import * from tensorflow.python.ops.linalg.linear_operator_circulant import * from tensorflow.python.ops.linalg.linear_operator_composition import * from tensorflow.python.ops.linalg.linear_operator_diag import * from tensorflow.python.ops.linalg.linear_operator_full_matrix import * from tensorflow.python.ops.linalg.linear_operator_identity import * +from tensorflow.python.ops.linalg.linear_operator_kronecker import * from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index b5741967ab5256..ef0e08a777779e 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -35,6 +35,8 @@ from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import googletest @@ -132,15 +134,22 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): return examples_dict, variables_dict -def make_variable_dict(max_age, max_gender): +def make_variable_dict(max_age, max_gender, partitioned=False): # TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from # examples_dict. - age_weights = variables_lib.Variable( - array_ops.zeros( - [max_age + 1], dtype=dtypes.float32)) - gender_weights = variables_lib.Variable( - array_ops.zeros( - [max_gender + 1], dtype=dtypes.float32)) + partitioner = None + if partitioned: + partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2, + axis=0) + with variable_scope.variable_scope( + name_or_scope='variables', + partitioner=partitioner): + age_weights = variables_lib.Variable( + array_ops.zeros( + [max_age + 1], dtype=dtypes.float32)) + gender_weights = variables_lib.Variable( + array_ops.zeros( + [max_gender + 1], dtype=dtypes.float32)) return dict( sparse_features_weights=[age_weights, gender_weights], dense_features_weights=[]) @@ -265,6 +274,54 @@ def testSimple(self): self.assertAllClose( 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testPartitionedPrimals(self): + # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0], + 'gender': [0] + }, 0), + make_example_proto({ + 'age': [1], + 'gender': [1] + }, 1), + ] + example_weights = [1.0, 1.0] + for num_shards in _SHARD_NUMBERS: + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1, partitioned=True) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + num_table_shards=num_shards, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + unregularized_loss = lr.unregularized_loss(examples) + loss = lr.regularized_loss(examples) + predictions = lr.predictions(examples) + self.assertAllClose(0.693147, unregularized_loss.eval()) + self.assertAllClose(0.693147, loss.eval()) + train_op = 
lr.minimize() + for _ in range(_MAX_ITERATIONS): + train_op.run() + lr.update_weights(train_op).run() + # The high tolerance in unregularized_loss comparisons is due to the + # fact that it's possible to trade off unregularized_loss vs. + # regularization and still have a sum that is quite close to the + # optimal regularized_loss value. SDCA's duality gap only ensures that + # the regularized_loss is within 0.01 of optimal. + # 0.525457 is the optimal regularized_loss. + # 0.411608 is the unregularized_loss at that optimum. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) + self.assertAllClose(0.525457, loss.eval(), atol=0.01) + predicted_labels = get_binary_predictions_for_logistic(predictions) + self.assertAllEqual([0, 1], predicted_labels.eval()) + self.assertAllClose( + 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testSparseRandom(self): dim = 20 num_examples = 1000 @@ -320,7 +377,10 @@ def testSparseDuplicate(self): train_op.run() def testDistributedSimple(self): - # Setup test data + # Distributed SDCA may not converge if the workers update concurrently the + # same example. In this test the examples are partitioned across workers. + # The examples are the same for all workers, just the example_ids are + # different. example_protos = [ make_example_proto({ 'age': [0], @@ -332,13 +392,19 @@ def testDistributedSimple(self): }, 1), ] example_weights = [1.0, 1.0] + examples = make_example_dict(example_protos, example_weights) + example_ids = array_ops.placeholder( + dtypes.string, shape=(len(example_weights),)) + examples['example_ids'] = example_ids + variables = make_variable_dict(1, 1) for num_shards in _SHARD_NUMBERS: for num_loss_partitions in _NUM_LOSS_PARTITIONS: with self._single_threaded_test_session(): - examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) options = dict( - symmetric_l2_regularization=1, + # Keep the same solution as for TestSimple: since the number of + # examples is multplied by num_loss_partitions, multiply also + # L2 by the same value. + symmetric_l2_regularization=num_loss_partitions, symmetric_l1_regularization=0, loss_type='logistic_loss', num_table_shards=num_shards, @@ -354,32 +420,30 @@ def testDistributedSimple(self): train_op = lr.minimize() - def minimize(): + def minimize(worker_id): with self._single_threaded_test_session(): + feed_dict = {example_ids: [ + str(i + worker_id*len(example_weights)) for i in range( + len(example_weights))]} for _ in range(_MAX_ITERATIONS): - train_op.run() # pylint: disable=cell-var-from-loop + train_op.run(feed_dict=feed_dict) # pylint: disable=cell-var-from-loop threads = [] - for _ in range(num_loss_partitions): - threads.append(threading.Thread(target=minimize)) + for worker_id in range(num_loss_partitions): + threads.append(threading.Thread(target=minimize, args=(worker_id,))) threads[-1].start() for t in threads: t.join() - lr.update_weights(train_op).run() - - # The high tolerance in unregularized_loss comparisons is due to the - # fact that it's possible to trade off unregularized_loss vs. - # regularization and still have a sum that is quite close to the - # optimal regularized_loss value. SDCA's duality gap only ensures - # that the regularized_loss is within 0.01 of optimal. - # 0.525457 is the optimal regularized_loss. - # 0.411608 is the unregularized_loss at that optimum. 
- self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) - self.assertAllClose(0.525457, loss.eval(), atol=0.01) + lr.update_weights(train_op).run(feed_dict={ + example_ids: [str(i) for i in range(len(example_weights))]}) + + # Test only the unregularized loss because the optimal value of the + # regularized loss depends on num_loss_partitions. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.02) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 1], predicted_labels.eval()) - self.assertTrue(lr.approximate_duality_gap().eval() < 0.02) + self.assertNear(0.0, lr.approximate_duality_gap().eval(), 0.02) def testSimpleNoL2(self): # Same as test above (so comments from above apply) but without an L2. diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f980746a19fb8e..0047d5753a773c 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -22,12 +22,14 @@ from six.moves import range from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -43,9 +45,6 @@ class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - This class currently only supports a single machine (multi-threaded) - implementation. We expect the weights and duals to fit in a single machine. - Loss functions supported: * Binary logistic loss @@ -182,18 +181,41 @@ def _num_table_shards(self): # TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic. def _create_slots(self): - # Make internal variables which have the updates before applying L1 - # regularization. + """Make unshrinked internal variables (slots).""" + # Unshrinked variables have the updates before applying L1 regularization. + # Each unshrinked slot variable is either a `Variable` or list of + # `Variable`, depending on the value of its corresponding primary variable. + # We avoid using `PartitionedVariable` for the unshrinked slots since we do + # not need any of the extra information. self._slots = collections.defaultdict(list) for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is - # fixed - self._slots['unshrinked_' + name].append( - var_ops.Variable( - array_ops.zeros_like(var.initialized_value(), dtypes.float32), - name=var.op.name + '_unshrinked/SDCAOptimizer')) + # Our primary variable may be either a PartitionedVariable, or a list + # of Variables (each representing a partition). 
+ if (isinstance(var, var_ops.PartitionedVariable) or + isinstance(var, list)): + var_list = [] + # pylint: disable=protected-access + for v in var: + with ops.colocate_with(v): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 + # is fixed. + slot_var = var_ops.Variable( + initial_value=array_ops.zeros_like(v.initialized_value(), + dtypes.float32), + name=v.op.name + '_unshrinked/SDCAOptimizer') + var_list.append(slot_var) + self._slots['unshrinked_' + name].append(var_list) + # pylint: enable=protected-access + else: + with ops.device(var.device): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is + # fixed. + self._slots['unshrinked_' + name].append( + var_ops.Variable( + array_ops.zeros_like(var.initialized_value(), + dtypes.float32), + name=var.op.name + '_unshrinked/SDCAOptimizer')) def _assertSpecified(self, items, check_in): for x in items: @@ -205,16 +227,25 @@ def _assertList(self, items, check_in): if not isinstance(check_in[x], list): raise ValueError(x + ' must be a list.') + def _var_to_list(self, var): + """Wraps var in a list if it is not a list or PartitionedVariable.""" + if not (isinstance(var, list) or + isinstance(var, var_ops.PartitionedVariable)): + var = [var] + return var + def _l1_loss(self): """Computes the (un-normalized) l1 loss of the model.""" with name_scope('sdca/l1_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.abs(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append( + math_ops.reduce_sum( + math_ops.abs(math_ops.cast(weights, dtypes.float64)))) # SDCA L1 regularization cost is: l1 * sum(|weights|) return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums) @@ -223,17 +254,37 @@ def _l2_loss(self, l2): with name_scope('sdca/l2_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.square(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append(math_ops.reduce_sum(math_ops.square(math_ops.cast( + weights, dtypes.float64)))) # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2 return l2 * math_ops.add_n(sums) / 2.0 def _convert_n_to_tensor(self, input_list, as_ref=False): """Converts input list to a set of tensors.""" - return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list] + # input_list can be a list of Variables (that are implicitly partitioned), + # in which case the underlying logic in internal_convert_to_tensor will not + # concatenate the partitions together. This method takes care of the + # concatenating (we only allow partitioning on the first axis). + output_list = [] + for x in input_list: + tensor_to_convert = x + if isinstance(x, list) or isinstance(x, var_ops.PartitionedVariable): + # We only allow for partitioning on the first axis. 
+ tensor_to_convert = array_ops.concat(x, axis=0) + output_list.append(internal_convert_to_tensor( + tensor_to_convert, as_ref=as_ref)) + return output_list + + def _get_first_dimension_size_statically(self, w, num_partitions): + """Compute the static size of the first dimension for a sharded variable.""" + dim_0_size = w[0].get_shape()[0] + for p in range(1, num_partitions): + dim_0_size += w[p].get_shape()[0] + return dim_0_size def _linear_predictions(self, examples): """Returns predictions of the form w*x.""" @@ -286,6 +337,28 @@ def predictions(self, examples): result = math_ops.sigmoid(result) return result + def _get_partitioned_update_ops(self, + v_num, + num_partitions_by_var, + p_assignments_by_var, + gather_ids_by_var, + weights, + full_update, + p_assignments, + num_partitions): + """Get updates for partitioned variables.""" + num_partitions = num_partitions_by_var[v_num] + p_assignments = p_assignments_by_var[v_num] + gather_ids = gather_ids_by_var[v_num] + updates = data_flow_ops.dynamic_partition( + full_update, p_assignments, num_partitions) + update_ops = [] + for p in range(num_partitions): + with ops.colocate_with(weights[p]): + result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p]) + update_ops.append(result) + return update_ops + def minimize(self, global_step=None, name=None): """Add operations to train a linear model by minimizing the loss function. @@ -318,18 +391,89 @@ def minimize(self, global_step=None, name=None): # Solver returns example_state_update, new delta sparse_feature_weights # and delta dense_feature_weights. - weights_tensor = self._convert_n_to_tensor(self._slots[ - 'unshrinked_sparse_features_weights']) sparse_weights = [] sparse_indices = [] - for w, i in zip(weights_tensor, sparse_feature_indices): - # Find the feature ids to lookup in the variables. - with ops.device(w.device): - sparse_indices.append( - math_ops.cast( - array_ops.unique(math_ops.cast(i, dtypes.int32))[0], - dtypes.int64)) - sparse_weights.append(array_ops.gather(w, sparse_indices[-1])) + # If we have partitioned variables, keep a few lists of Tensors around + # that we need for the assign_add after the op call to + # gen_sdca_ops.sdca_optimizer(). + num_partitions_by_var = [] + p_assignments_by_var = [] + gather_ids_by_var = [] + for w, i in zip(self._slots['unshrinked_sparse_features_weights'], + sparse_feature_indices): + # Append the sparse_indices (in full-variable space). + sparse_idx = math_ops.cast( + array_ops.unique(math_ops.cast(i, dtypes.int32))[0], + dtypes.int64) + sparse_indices.append(sparse_idx) + if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable): + num_partitions = len(w) + flat_ids = array_ops.reshape(sparse_idx, [-1]) + # We use div partitioning, which is easiest to support downstream. + # Compute num_total_ids as the sum of dim-0 of w, then assign + # to partitions based on a constant number of ids per partition. + # Optimize if we already know the full shape statically. 
+ dim_0_size = self._get_first_dimension_size_statically( + w, num_partitions) + + if dim_0_size.value: + num_total_ids = constant_op.constant(dim_0_size.value, + flat_ids.dtype) + else: + dim_0_sizes = [] + for p in range(num_partitions): + if w[p].get_shape()[0].value is not None: + dim_0_sizes.append(w[p].get_shape()[0].value) + else: + with ops.colocate_with(w[p]): + dim_0_sizes.append(array_ops.shape(w[p])[0]) + num_total_ids = math_ops.reduce_sum( + math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype)) + ids_per_partition = num_total_ids // num_partitions + extras = num_total_ids % num_partitions + + p_assignments = math_ops.maximum( + flat_ids // (ids_per_partition + 1), + (flat_ids - extras) // ids_per_partition) + + # Emulate a conditional using a boolean indicator tensor + new_ids = array_ops.where(p_assignments < extras, + flat_ids % (ids_per_partition + 1), + (flat_ids - extras) % ids_per_partition) + + # Cast partition assignments to int32 for use in dynamic_partition. + # There really should not be more than 2^32 partitions. + p_assignments = math_ops.cast(p_assignments, dtypes.int32) + # Partition list of ids based on assignments into num_partitions + # separate lists. + gather_ids = data_flow_ops.dynamic_partition(new_ids, + p_assignments, + num_partitions) + # Append these to the lists for use in the later update. + num_partitions_by_var.append(num_partitions) + p_assignments_by_var.append(p_assignments) + gather_ids_by_var.append(gather_ids) + + # Gather the weights from each partition. + partition_gathered_weights = [] + for p in range(num_partitions): + with ops.colocate_with(w[p]): + partition_gathered_weights.append( + array_ops.gather(w[p], gather_ids[p])) + + # Stitch the weights back together in the same order they were before + # we dynamic_partitioned them. + condition_indices = data_flow_ops.dynamic_partition( + math_ops.range(array_ops.shape(new_ids)[0]), + p_assignments, num_partitions) + batch_gathered_weights = data_flow_ops.dynamic_stitch( + condition_indices, partition_gathered_weights) + else: + w_as_tensor = internal_convert_to_tensor(w) + with ops.device(w_as_tensor.device): + batch_gathered_weights = array_ops.gather( + w_as_tensor, sparse_idx) + sparse_weights.append(batch_gathered_weights) # pylint: disable=protected-access esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( @@ -355,12 +499,25 @@ def minimize(self, global_step=None, name=None): with ops.control_dependencies([esu]): update_ops = [self._hashtable.insert(example_ids_hashed, esu)] # Update the weights before the proximal step. 
- for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'], - sparse_indices, sfw): - update_ops.append(state_ops.scatter_add(w, i, u)) + for v_num, (w, i, u) in enumerate( + zip(self._slots['unshrinked_sparse_features_weights'], + sparse_indices, sfw)): + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + update_ops += self._get_partitioned_update_ops( + v_num, num_partitions_by_var, p_assignments_by_var, + gather_ids_by_var, w, u, p_assignments, num_partitions) + else: + update_ops.append(state_ops.scatter_add(w, i, u)) for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw): - update_ops.append(w.assign_add(u)) - + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + split_updates = array_ops.split( + u, num_or_size_splits=[v.shape.as_list()[0] for v in w]) + for v, split_update in zip(w, split_updates): + update_ops.append(state_ops.assign_add(v, split_update)) + else: + update_ops.append(state_ops.assign_add(w, u)) if not global_step: return control_flow_ops.group(*update_ops) with ops.control_dependencies(update_ops): @@ -385,21 +542,22 @@ def update_weights(self, train_op): for name in ['sparse_features_weights', 'dense_features_weights']: for var, slot_var in zip(self._variables[name], self._slots['unshrinked_' + name]): - update_ops.append(var.assign(slot_var)) + for v, sv in zip(self._var_to_list(var), self._var_to_list(slot_var)): + update_ops.append(v.assign(sv)) # Apply proximal step. with ops.control_dependencies(update_ops): update_ops = [] for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # pylint: disable=protected-access - update_ops.append( - gen_sdca_ops.sdca_shrink_l1( - self._convert_n_to_tensor( - [var], as_ref=True), - l1=self._symmetric_l1_regularization(), - l2=self._symmetric_l2_regularization())) + for v in self._var_to_list(var): + with ops.device(v.device): + # pylint: disable=protected-access + update_ops.append( + gen_sdca_ops.sdca_shrink_l1( + self._convert_n_to_tensor([v], as_ref=True), + l1=self._symmetric_l1_regularization(), + l2=self._symmetric_l2_regularization())) return control_flow_ops.group(*update_ops) def approximate_duality_gap(self): diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index d4e54c82f988e0..200e7de6b95f17 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -116,6 +116,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): num_loss_partitions = params["num_loss_partitions"] weight_column_name = params["weight_column_name"] update_weights_hook = params.get("update_weights_hook", None) + partitioner = params["partitioner"] loss_type = None if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access @@ -136,12 +137,14 @@ def sdca_model_fn(features, labels, mode, params, config=None): example_id_column=example_id_column, num_loss_partitions=n_loss_partitions, symmetric_l1_regularization=l1_regularization, - symmetric_l2_regularization=l2_regularization) + symmetric_l2_regularization=l2_regularization, + partitioner=partitioner) parent_scope = "linear" with variable_scope.variable_scope( - values=features.values(), name_or_scope=parent_scope) as scope: + values=features.values(), name_or_scope=parent_scope, + partitioner=partitioner) as scope: features = features.copy() 
features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( @@ -213,7 +216,8 @@ def __init__(self, l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `_SDCAEstimator` estimator object. Args: @@ -241,6 +245,8 @@ def __init__(self, feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `_SDCAEstimator` estimator. @@ -267,6 +273,7 @@ def __init__(self, "l2_regularization": l2_regularization, "weight_column_name": weight_column_name, "update_weights_hook": _SdcaUpdateWeightsHook(), + "partitioner": partitioner, } super(_SDCAEstimator, self).__init__( @@ -336,7 +343,8 @@ def __init__(self, l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALogisticClassifier` object. Args: @@ -361,6 +369,8 @@ def __init__(self, feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALogisiticClassifier` estimator. @@ -376,7 +386,8 @@ def __init__(self, l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_classes(self, input_fn=None): """Runs inference to determine the predicted class. @@ -463,7 +474,8 @@ def __init__(self, l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALinearRegressor` estimator object. @@ -489,6 +501,8 @@ def __init__(self, feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALinearRegressor` estimator. @@ -503,7 +517,8 @@ def __init__(self, l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_scores(self, input_fn): """Returns predicted scores for given features. 
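The new `partitioner` argument is threaded from the estimator constructors down to `SDCAOptimizer` and the `linear` variable scope above. A rough usage sketch follows, mirroring the partitioned tests in the next file; `feature_column_lib`, `sdca_estimator`, and `input_fn` are assumed to be defined as in `sdca_estimator_test.py`, and the shard count of 2 is illustrative only:

```python
from tensorflow.python.ops import partitioned_variables

# Feature columns as in testPartitionedMixedFeatures below.
price = feature_column_lib.real_valued_column('price')
country = feature_column_lib.sparse_column_with_hash_bucket(
    'country', hash_bucket_size=5)

classifier = sdca_estimator.SDCALogisticClassifier(
    example_id_column='example_id',
    feature_columns=[price, country],
    # Split the primal weight variables into two shards along axis 0; the
    # SDCA weight lookups assume the `div` partitioning strategy.
    partitioner=partitioned_variables.fixed_size_partitioner(
        num_shards=2, axis=0))

classifier.fit(input_fn=input_fn, steps=50)
metrics = classifier.evaluate(input_fn=input_fn, steps=1)
```

Without a partitioner the behavior is unchanged: each primal weight remains a single `Variable`, as in the non-partitioned code paths of `sdca_ops.py` above.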
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index bed3d5139fcbf9..647667188238dc 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -25,6 +25,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test @@ -273,6 +274,47 @@ def input_fn(): metrics = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(metrics['accuracy'], 0.9) + def testPartitionedMixedFeatures(self): + """Tests SDCALogisticClassifier with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([900.0, 700.0, 600.0]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + classifier = sdca_estimator.SDCALogisticClassifier( + example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + classifier.fit(input_fn=input_fn, steps=50) + metrics = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(metrics['accuracy'], 0.9) + class SDCALinearRegressorTest(test.TestCase): @@ -350,6 +392,48 @@ def input_fn(): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testMixedFeaturesArbitraryWeightsPartitioned(self): + """Tests SDCALinearRegressor works with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + regressor = sdca_estimator.SDCALinearRegressor( + example_id_column='example_id', + feature_columns=[ + price, 
sq_footage_bucket, country, sq_footage_country + ], + l2_regularization=1.0, + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """SDCALinearRegressor works with sparse features and L1 regularization.""" diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 5d4572bf6c761e..9872c6f97c879d 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -37,18 +37,18 @@ class SDCAOptimizer(object): Example usage: ```python - real_feature_column = real_valued_column(...) - sparse_feature_column = sparse_column_with_hash_bucket(...) - sdca_optimizer = linear.SDCAOptimizer(example_id_column='example_id', - num_loss_partitions=1, - num_table_shards=1, - symmetric_l2_regularization=2.0) - classifier = tf.contrib.learn.LinearClassifier( - feature_columns=[real_feature_column, sparse_feature_column], - weight_column_name=..., - optimizer=sdca_optimizer) - classifier.fit(input_fn_train, steps=50) - classifier.evaluate(input_fn=input_fn_eval) + real_feature_column = real_valued_column(...) + sparse_feature_column = sparse_column_with_hash_bucket(...) + sdca_optimizer = linear.SDCAOptimizer(example_id_column='example_id', + num_loss_partitions=1, + num_table_shards=1, + symmetric_l2_regularization=2.0) + classifier = tf.contrib.learn.LinearClassifier( + feature_columns=[real_feature_column, sparse_feature_column], + weight_column_name=..., + optimizer=sdca_optimizer) + classifier.fit(input_fn_train, steps=50) + classifier.evaluate(input_fn=input_fn_eval) ``` Here the expectation is that the `input_fn_*` functions passed to train and @@ -64,7 +64,8 @@ class SDCAOptimizer(object): of workers running the train steps. It defaults to 1 (single machine). `num_table_shards` defines the number of shards for the internal state table, typically set to match the number of parameter servers for large - data sets. + data sets. You can also specify a `partitioner` object to partition the primal + weights during training (`div` partitioning strategy will be used). 
""" def __init__(self, @@ -73,13 +74,15 @@ def __init__(self, num_table_shards=None, symmetric_l1_regularization=0.0, symmetric_l2_regularization=1.0, - adaptive=True): + adaptive=True, + partitioner=None): self._example_id_column = example_id_column self._num_loss_partitions = num_loss_partitions self._num_table_shards = num_table_shards self._symmetric_l1_regularization = symmetric_l1_regularization self._symmetric_l2_regularization = symmetric_l2_regularization self._adaptive = adaptive + self._partitioner = partitioner def get_name(self): return 'SDCAOptimizer' @@ -108,6 +111,10 @@ def symmetric_l2_regularization(self): def adaptive(self): return self._adaptive + @property + def partitioner(self): + return self._partitioner + def get_train_step(self, columns_to_variables, weight_column_name, loss_type, features, targets, global_step): """Returns the training operation of an SdcaModel optimizer.""" @@ -175,10 +182,12 @@ def _training_examples_and_variables(): sparse_feature_column = _dense_tensor_to_sparse_feature_column( dense_bucket_tensor) sparse_feature_with_values.append(sparse_feature_column) - # For bucketized columns, the variables list contains exactly one - # element. - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) elif isinstance( column, ( @@ -198,6 +207,14 @@ def _training_examples_and_variables(): example_ids = array_ops.reshape(id_tensor.indices[:, 0], [-1]) flat_ids = array_ops.reshape(id_tensor.values, [-1]) + # Prune invalid IDs (< 0) from the flat_ids, example_ids, and + # weight_tensor. These can come from looking up an OOV entry in the + # vocabulary (default value being -1). + is_id_valid = math_ops.greater_equal(flat_ids, 0) + flat_ids = array_ops.boolean_mask(flat_ids, is_id_valid) + example_ids = array_ops.boolean_mask(example_ids, is_id_valid) + weight_tensor = array_ops.boolean_mask(weight_tensor, is_id_valid) + projection_length = math_ops.reduce_max(flat_ids) + 1 # project ids based on example ids so that we can dedup ids that # occur multiple times for a single example. @@ -218,8 +235,12 @@ def _training_examples_and_variables(): array_ops.shape(ids)[0]), [-1]) sparse_feature_with_values.append( SparseFeatureColumn(example_ids_filtered, reproject_ids, weights)) - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) else: raise ValueError('SDCAOptimizer does not support column type %s.' 
% type(column).__name__) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 1534f97d760015..9c804d27854b80 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") -exports_files(["LICENSE"]) - exports_files(glob([ "testdata/*.bin", "testdata/*.pb", @@ -92,6 +90,18 @@ cc_library( deps = [":context"], ) +cc_library( + name = "kernel_api", + hdrs = [ + "builtin_op_data.h", + "builtin_ops.h", + "context.h", + "context_util.h", + ], +) + +exports_files(["builtin_ops.h"]) + cc_library( name = "string", hdrs = [ @@ -112,6 +122,7 @@ cc_library( "interpreter.cc", "model.cc", "nnapi_delegate.cc", + "op_resolver.cc", "optional_debug_tools.cc", ], hdrs = [ @@ -122,6 +133,7 @@ cc_library( "interpreter.h", "model.h", "nnapi_delegate.h", + "op_resolver.h", "optional_debug_tools.h", ], copts = tflite_copts(), @@ -224,6 +236,18 @@ cc_test( ], ) +# Test OpResolver. +cc_test( + name = "op_resolver_test", + size = "small", + srcs = ["op_resolver_test.cc"], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test the C extension API code. cc_test( name = "context_test", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index 65fba52d461461..cc8a8035d1dade 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -1,4 +1,3 @@ - # Find where we're running from, so we can store generated files here. ifeq ($(origin MAKEFILE_DIR), undefined) MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) @@ -30,7 +29,7 @@ GENDIR := $(MAKEFILE_DIR)/gen/obj/ CXX := $(CC_PREFIX)gcc CXXFLAGS := --std=c++11 -O3 -DNDEBUG CC := $(CC_PREFIX)gcc -CFLAGS := -O3 -DNDEBUG +CCFLAGS := -O3 -DNDEBUG LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r @@ -69,12 +68,12 @@ LIB_NAME := libtensorflow-lite.a LIB_PATH := $(LIBDIR)$(LIB_NAME) # A small example program that shows how to link against the library. -BENCHMARK_PATH := $(BINDIR)benchmark_model +MINIMAL_PATH := $(BINDIR)minimal -BENCHMARK_SRCS := \ -tensorflow/contrib/lite/tools/benchmark_model.cc -BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) +MINIMAL_SRCS := \ +tensorflow/contrib/lite/examples/minimal/minimal.cc +MINIMAL_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. @@ -100,7 +99,7 @@ $(wildcard tensorflow/contrib/lite/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ -$(BENCHMARK_SRCS) +$(MINIMAL_SRCS) # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # File names of the intermediate files target compilation generates. @@ -119,17 +118,17 @@ $(OBJDIR)%.o: %.c $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(LIB_PATH) $(BENCHMARK_PATH) +all: $(LIB_PATH) $(MINIMAL_PATH) # Gathers together all the objects we've compiled into a single '.a' archive. 
$(LIB_PATH): $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) -$(BENCHMARK_PATH): $(BENCHMARK_OBJS) $(LIB_PATH) +$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ - -o $(BENCHMARK_PATH) $(BENCHMARK_OBJS) \ + -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) # Gets rid of all generated files. diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md new file mode 100644 index 00000000000000..8fd63d5cee7db3 --- /dev/null +++ b/tensorflow/contrib/lite/RELEASE.md @@ -0,0 +1,8 @@ +# Release 0.1.7 + +* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit + fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0). +* To reproduce the iOS library, it's required to cherry pick git commit + f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue. +* The code is based on TensorFlow 1.8.0 release candidate and it's very close + to TensorFlow 1.8.0 release. diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 85216776823eab..aa6a60dc9ed308 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -1,4 +1,8 @@ """Generate Flatbuffer binary from json.""" +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) def tflite_copts(): """Defines compile time flags.""" @@ -185,32 +189,106 @@ def json_to_tflite(name, src, out): tools = [flatc], ) -def gen_zipped_test_files(name, files): +# This is the master list of generated examples that will be made into tests. A +# function called make_XXX_tests() must also appear in generate_examples.py. +# Disable a test by commenting it out. If you do, add a link to a bug or issue. +def generated_test_models(): + return [ + "add", + "arg_max", + "avg_pool", + "batch_to_space_nd", + "concat", + "constant", + "control_dep", + "conv", + "depthwiseconv", + "div", + "exp", + "expand_dims", + "floor", + "fully_connected", + "fused_batch_norm", + "gather", + "global_batch_norm", + "greater", + "greater_equal", + "l2norm", + "l2_pool", + "less", + "less_equal", + "local_response_norm", + "log_softmax", + "lstm", + "max_pool", + "maximum", + "mean", + "minimum", + "mul", + "neg", + "pad", + "padv2", + # "prelu", + "relu", + "relu1", + "relu6", + "reshape", + "resize_bilinear", + "sigmoid", + "sin", + "slice", + "softmax", + "space_to_batch_nd", + "space_to_depth", + "sparse_to_dense", + "split", + "squeeze", + "strided_slice", + "strided_slice_1d_exhaustive", + "sub", + "tile", + "topk", + "transpose", + "transpose_conv", + "where", + ] + +def gen_zip_test(name, test_name, **kwargs): + """Generate a zipped-example test and its dependent zip files. + + Args: + name: Resulting cc_test target name + test_name: Test targets this model. Comes from the list above. + **kwargs: tf_cc_test kwargs. + """ + gen_zipped_test_file( + name = "zip_%s" % test_name, + file = "%s.zip" % test_name, + ) + tf_cc_test(name, **kwargs) + +def gen_zipped_test_file(name, file): """Generate a zip file of tests by using :generate_examples. Args: - name: Name of output. We will produce "`name`_files" as a target. - files: A list of zip file basenames. + name: Name of output. We will produce "`file`.files" as a target. + file: The name of one of the generated_examples targets, e.g. 
"transpose" """ toco = "//tensorflow/contrib/lite/toco:toco" - out_files = [] - for f in files: - out_file = name + "/" + f - out_files.append(out_file) - native.genrule( - name = name + "_" + f + ".files", - cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco - + " --zip_to_output " + f + " $(@D)"), - outs = [out_file], - tools = [ - ":generate_examples", - toco, - ], - ) + native.genrule( + name = file + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + file + " $(@D)"), + outs = [file], + tools = [ + ":generate_examples", + toco, + ], + ) native.filegroup( name = name, - srcs = out_files, + srcs = [file], ) def gen_selected_ops(name, model): diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 4910c89eaebabb..c1cc4476fbd45f 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -148,10 +148,20 @@ typedef struct { float beta; } TfLiteLocalResponseNormParams; +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + typedef struct { + // Parameters for LSTM version 1. TfLiteFusedActivation activation; float cell_clip; float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; typedef struct { @@ -161,6 +171,9 @@ typedef struct { typedef struct { } TfLitePadParams; +typedef struct { +} TfLitePadV2Params; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. @@ -227,6 +240,16 @@ typedef struct { TfLiteType output_type; } TfLiteArgMaxParams; +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 859bc7ab70dc36..fc6fdd6eefb4ce 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -33,6 +33,7 @@ typedef enum { kTfLiteBuiltinDepthwiseConv2d = 4, kTfLiteBuiltinDequantize = 6, kTfLiteBuiltinEmbeddingLookup = 7, + kTfLiteBuiltinFloor = 8, kTfLiteBuiltinFullyConnected = 9, kTfLiteBuiltinHashtableLookup = 10, kTfLiteBuiltinL2Normalization = 11, @@ -83,10 +84,21 @@ typedef enum { kTfLiteBuiltinArgMax = 56, kTfLiteBuiltinMinimum = 57, kTfLiteBuiltinLess = 58, + kTfLiteBuiltinNeg = 59, + kTfLiteBuiltinPadv2 = 60, + kTfLiteBuiltinGreater = 61, + kTfLiteBuiltinGreaterEqual = 62, + kTfLiteBuiltinLessEqual = 63, + kTfLiteBuiltinSelect = 64, + kTfLiteBuiltinSlice = 65, + kTfLiteBuiltinSin = 66, + kTfLiteBuiltinTransposeConv = 67, + kTfLiteBuiltinSparseToDense = 68, + kTfLiteBuiltinTile = 69, + kTfLiteBuiltinExpandDims = 70, } TfLiteBuiltinOperator; #ifdef __cplusplus } // extern "C" #endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -} diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 12841d233cc1d3..4eb66cc225eb04 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -370,13 +370,21 @@ typedef struct _TfLiteRegistration { // Builtin codes. If this kernel refers to a builtin this is the code // of the builtin. 
This is so we can do marshaling to other frameworks like - // NN API. Note, it is the responsibility of the registration binder to - // set this properly. + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. int32_t builtin_code; // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. // WARNING: This is an experimental interface that is subject to change. const char* custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; } TfLiteRegistration; // WARNING: This is an experimental interface that is subject to change. diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h new file mode 100644 index 00000000000000..abe802e34214ca --- /dev/null +++ b/tensorflow/contrib/lite/context_util.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This provides a few C++ helpers that are useful for manipulating C structures +// in C++. +#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite +// C api uses. Can't use the google array_view, since we can't depend on even +// absl for embedded device reasons. +class TfLiteIntArrayView { + public: + // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null + // and this view does not take ownership of it. 
+ explicit TfLiteIntArrayView(const TfLiteIntArray* int_array) + : int_array_(int_array) {} + + TfLiteIntArrayView(const TfLiteIntArrayView&) = default; + TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default; + + typedef const int* const_iterator; + const_iterator begin() const { return int_array_->data; } + const_iterator end() const { return &int_array_->data[int_array_->size]; } + size_t size() const { return end() - begin(); } + + private: + const TfLiteIntArray* int_array_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD new file mode 100644 index 00000000000000..35a8f6ca4166e3 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = [ + "//visibility:public", +]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "nnapi_delegate", + srcs = ["nnapi_delegate.cc"], + hdrs = ["nnapi_delegate.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "nnapi_delegate_test", + size = "small", + srcs = ["nnapi_delegate_test.cc"], + deps = [ + ":nnapi_delegate", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc new file mode 100644 index 00000000000000..0731d14419d2de --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -0,0 +1,464 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +namespace tflite { +namespace { + +// TODO(b/80621585): Consider printing error string, but don't for now to +// minimize binary size. 
+#define CHECK_NN(context, code) \ + if (code != ANEURALNETWORKS_NO_ERROR) { \ + context->ReportError(context, "NN API returned error (%d).\n", code); \ + return kTfLiteError; \ + } + +// RAII NN API Model Destructor for use with std::unique_ptr +struct NNFreeModel { + void operator()(ANeuralNetworksModel* model) { + ANeuralNetworksModel_free(model); + } +}; +// RAII NN API Compilation Destructor for use with std::unique_ptr +struct NNFreeCompilation { + void operator()(ANeuralNetworksCompilation* model) { + ANeuralNetworksCompilation_free(model); + } +}; + +// Track tensor indices to NN API tensor indices mapping. +class OperandMapping { + public: + // Given a TFLite index return the ANN index. If it doesn't exist + // return -1. + int lite_index_to_ann(int index) const { + if (index < lite_tensor_to_ann_tensor_.size()) + return lite_tensor_to_ann_tensor_[index]; + else + return -1; + } + + // NN API uses non tensor operands instead of structs. This creates one + // and returns the index. It uses a std::vector and resizes it as needed + // keeping -1 to unmapped values. Intermediate tensors likely will not + // be mapped. + int add_new_non_tensor_operand() { return next_ann_tensor_index_++; } + + // Add a new mapping from `tflite_index` and return the NN API tensor index. + int add_new_ann_tensor_index(int tflite_index) { + if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { + lite_tensor_to_ann_tensor_.resize(tflite_index + 1); + } + int new_tensor_index = next_ann_tensor_index_++; + lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; + return new_tensor_index; + } + + private: + // Next index of ann tensor + int next_ann_tensor_index_ = 0; + + // Mapping from lite index. Use a std::vector for speed and code size + // rather than a map. + std::vector lite_tensor_to_ann_tensor_; +}; + +// Abstract builder for building an op in the NN API graph. This handles +// the disparity between TFLite and NN API operand types. NN API has singular +// operands for both tensors and parameters, and TFLite separates the two. +class NNAPIOpBuilder { + public: + NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping, + ANeuralNetworksModel* nn_model) + : context_(context), + operand_mapping_(tensor_mapping), + nn_model_(nn_model) {} + + TfLiteStatus AddScalarInt32Operand(int value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(int32_t))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + TfLiteStatus AddTensorInput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_inputs_.push_back(ann_index); + return kTfLiteOk; + } + + TfLiteStatus AddTensorOutput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_outputs_.push_back(ann_index); + return kTfLiteOk; + } + + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. + // This returns the NN API tensor index corresponding to the created tensor. + // If another caller previously created a NN API tensor for `tensor_index` + // then the existing one is returned. 
+  TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) {
+    int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index);
+    if (ann_tensor_index != -1) {
+      *ann_tensor_index_out = ann_tensor_index;
+      return kTfLiteOk;
+    }
+    // Allocate a new tensor index
+    ann_tensor_index = operand_mapping_->add_new_ann_tensor_index(tensor_index);
+
+    // Parameters needed for new type.
+    int32_t nn_type = 0;
+    float scale = 0.0f;
+    int32_t zeroPoint = 0;
+    TfLiteTensor* tensor = &context_->tensors[tensor_index];
+    switch (tensor->type) {
+      case kTfLiteNoType:
+        // Tensors added during initialization of Ops don't have a type yet and
+        // should not be registered with the NNAPI.
+        *ann_tensor_index_out = -1;
+        return kTfLiteOk;
+      case kTfLiteFloat32:
+        nn_type = ANEURALNETWORKS_TENSOR_FLOAT32;
+        scale = 0.f;
+        break;
+      case kTfLiteUInt8:
+        nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM;
+        scale = tensor->params.scale;
+        zeroPoint = tensor->params.zero_point;
+        break;
+      case kTfLiteInt32:
+        nn_type = ANEURALNETWORKS_TENSOR_INT32;
+        scale = 0.f;
+        zeroPoint = 0;
+        break;
+      default:
+        context_->ReportError(context_, "Logic error in NN API Delegate.\n");
+        return kTfLiteError;
+    }
+
+    ANeuralNetworksOperandType operand_type{
+        nn_type, static_cast<uint32_t>(tensor->dims->size),
+        reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
+    CHECK_NN(context_,
+             ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+
+    if (tensor->allocation_type == kTfLiteMmapRo) {
+      // TODO(b/80630405): Use NNAPIAllocation.
+      CHECK_NN(context_, ANeuralNetworksModel_setOperandValue(
+                             nn_model_, ann_tensor_index, tensor->data.raw,
+                             tensor->bytes));
+    }
+
+    *ann_tensor_index_out = ann_tensor_index;
+    return kTfLiteOk;
+  }
+
+  // Finish emitting the op (of type `type`) into the NN API.
+  TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) {
+    // Actually add a NN API operation
+    CHECK_NN(context_, ANeuralNetworksModel_addOperation(
+                           nn_model_, type,
+                           static_cast<uint32_t>(augmented_inputs_.size()),
+                           augmented_inputs_.data(),
+                           static_cast<uint32_t>(augmented_outputs_.size()),
+                           augmented_outputs_.data()));
+    augmented_inputs_.clear();
+    augmented_outputs_.clear();
+    return kTfLiteOk;
+  }
+
+ private:
+  // TfLiteContext for error handling. Must be named context for macros to
+  // work.
+  TfLiteContext* context_;
+
+  // Tracks the mapping between TF Lite and NN API tensor indices.
+  OperandMapping* operand_mapping_;
+
+  // The NN API model being built.
+  ANeuralNetworksModel* nn_model_;
+
+  // Inputs and outputs for the current op. These are augmented in the sense
+  // that NN API uses operands for all arguments, not just tensors, unlike
+  // TensorFlow Lite.
+  std::vector<uint32_t> augmented_inputs_;
+  std::vector<uint32_t> augmented_outputs_;
+};
+
+// The kernel that represents the subgraph of TF Lite being run on NN API.
+class NNAPIDelegateKernel {
+ public:
+  NNAPIDelegateKernel() = default;
+
+  typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*,
+                                                    NNAPIOpBuilder* builder,
+                                                    TfLiteNode* node);
+
+  // Returns a function that knows how to translate a node into its operands
+  // when called. You can use this function to see if a node is supported
+  // (i.e. the returned MappingFn is not nullptr).
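+  // For example (a sketch of how the callers later in this file use it;
+  // `kernel`, `reg` and `builder` stand in for the surrounding variables):
+  //   if (MappingFn fn = kernel.Map(context, reg->builtin_code, node)) {
+  //     // Supported: fn adds the operands and returns the NN API op type.
+  //     ANeuralNetworksOperationType nn_op = fn(context, &builder, node);
+  //   }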
+ MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + switch (builtin_code) { + case kTfLiteBuiltinAdd: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + break; + case kTfLiteBuiltinAveragePool2d: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->filter_width); + builder->AddScalarInt32Operand(builtin->filter_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + break; + default: + return nullptr; + } + } + + // Initialize the kernel (a NN model). + TfLiteStatus Init(TfLiteContext* context, + const TfLiteDelegateParams* params) { + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + nodes_.push_back(node_index); + } + + if (!nn_model_) { + ANeuralNetworksModel* model; + CHECK_NN(context, ANeuralNetworksModel_create(&model)); + nn_model_.reset(model); + + TF_LITE_ENSURE_STATUS( + BuildGraph(context, params->input_tensors, params->output_tensors)); + } + + if (!nn_compilation_) { + ANeuralNetworksCompilation* compilation; + CHECK_NN(context, ANeuralNetworksCompilation_create(nn_model_.get(), + &compilation)); + CHECK_NN(context, ANeuralNetworksCompilation_finish(compilation)); + nn_compilation_.reset(compilation); + } + return kTfLiteOk; + } + + TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + ANeuralNetworksExecution* execution = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_create(nn_compilation_.get(), + &execution)); + + // Set the input tensor buffers. Note: we access tflite tensors using + // absolute indices but NN api indices inputs by relative indices. + int relative_input_index = 0; + for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + TfLiteTensor* tensor = &context->tensors[absolute_input_index]; + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } + + // Set the output tensor buffers. + int relative_output_index = 0; + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[output_index]; + CHECK_NN(context, ANeuralNetworksExecution_setOutput( + execution, relative_output_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_output_index++; + } + // Invoke ANN in blocking fashion. + ANeuralNetworksEvent* event = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(context, ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); + ANeuralNetworksExecution_free(execution); + + return kTfLiteOk; + } + + private: + // ANN API state. + std::unique_ptr nn_model_; + std::unique_ptr + nn_compilation_; + // Node indices that this delegate is responsible for. Indices here + // indexes into the nodes array in the TfLiteContext. 
+ std::vector nodes_; + // Track indices we use + OperandMapping operand_mapping_; + + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { + // The operand builder allows creating a single op. We create it at this + // reduced power position rather than in the for loop to avoid reallocating + // the vectors. + NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get()); + // Add Tensors + // allocate outside to avoid realloc + for (auto node_index : nodes_) { + // Obtain the op and registration. + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + // Map inputs to NN API tensor indices. + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } + // Get op type and operands + int nn_op_type = + Map(context, reg->builtin_code, node)(context, &builder, node); + // Map outputs to NN API tensor indices. + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); + } + + builder.FinalizeAddOperation(nn_op_type); + } + return kTfLiteOk; + } + + TfLiteStatus BuildGraph(TfLiteContext* context, + const TfLiteIntArray* input_tensors, + const TfLiteIntArray* output_tensors) { + // Build the ops and tensors. + TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context)); + // Map input and output tensor indices to ANN + std::vector inputs; + inputs.reserve(input_tensors->size); + std::vector outputs; + outputs.reserve(output_tensors->size); + // Make the TensorFlow lite inputs and outputs to ann_indices. + for (int i : TfLiteIntArrayView(input_tensors)) + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + for (int i : TfLiteIntArrayView(output_tensors)) + outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + // Tell ANN to declare inputs/outputs + CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( + nn_model_.get(), inputs.size(), inputs.data(), + outputs.size(), outputs.data())); + // Finalize the model + CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get())); + + return kTfLiteOk; + } +}; + +} // namespace + +// Return a NN API Delegate struct that can check for support of ops. +TfLiteDelegate* NnApiDelegate() { + static TfLiteDelegate delegate = { + .data_ = nullptr, + .Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // Do not check nodes_ if NN API is unavailable. + if (!NNAPIExists()) return kTfLiteOk; + + std::vector supported_nodes(1); + // We don't care about all nodes_, we only care about ones in the + // current plan. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + int total_supported_nodes = 0; + // Check for every node if it is supported + // TODO(b/80625235): Fix this to do more careful checking of versioning. + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + NNAPIDelegateKernel dummy_kernel; + if (dummy_kernel.Map(context, registration->builtin_code, node)) { + supported_nodes.push_back(node_index); + } + total_supported_nodes += 1; + } + // Put the size at the beginning of the array. 
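+        // supported_nodes is laid out like a TfLiteIntArray
+        // ([size, data[0], data[1], ...]), so element 0 must hold the number
+        // of node indices that follow; see the reinterpret_cast below.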
+        supported_nodes[0] = supported_nodes.size() - 1;
+
+        // NN API Delegate Registration (the pseudo kernel that will invoke
+        // NN API subgraphs)
+        static const TfLiteRegistration nnapi_delegate_kernel = {
+            .init = [](TfLiteContext* context, const char* buffer,
+                       size_t length) -> void* {
+              const TfLiteDelegateParams* params =
+                  reinterpret_cast<const TfLiteDelegateParams*>(buffer);
+              NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel;
+              kernel_state->Init(context, params);
+              return kernel_state;
+            },
+
+            .free = [](TfLiteContext* context, void* buffer) -> void {
+              delete reinterpret_cast<NNAPIDelegateKernel*>(buffer);
+            },
+
+            .prepare = [](TfLiteContext* context,
+                          TfLiteNode* node) -> TfLiteStatus {
+              // Since the underlying resize happened ahead of delegation,
+              // there is nothing to do here.
+              return kTfLiteOk;
+            },
+
+            .invoke = [](TfLiteContext* context,
+                         TfLiteNode* node) -> TfLiteStatus {
+              NNAPIDelegateKernel* state =
+                  reinterpret_cast<NNAPIDelegateKernel*>(node->user_data);
+              return state->Invoke(context, node);
+            },
+
+            .builtin_code = kTfLiteBuiltinDelegate,
+        };
+
+        // Request TFLite to partition the graph and make kernels
+        // for each independent subgraph a new nnapi_delegate_kernel.
+        context->ReplaceSubgraphsWithDelegateKernels(
+            context, nnapi_delegate_kernel,
+            reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()),
+            delegate);
+        return kTfLiteOk;
+      }};
+
+  return &delegate;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
new file mode 100644
index 00000000000000..44cca2fd285370
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// Return a delegate that can be used to run models on the NN API.
+// e.g.
+//   TfLiteDelegate* delegate = NnApiDelegate();
+//   interpreter->ModifyGraphWithDelegate(delegate);
+// NnApiDelegate() returns a singleton, so you should not free this
+// pointer or worry about its lifetime.
+TfLiteDelegate* NnApiDelegate();
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
new file mode 100644
index 00000000000000..ff2e721423f078
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class FloatAddOpModel : public SingleOpModel {
+ public:
+  FloatAddOpModel(const TensorData& input1, const TensorData& input2,
+                  const TensorData& output,
+                  ActivationFunctionType activation_type) {
+    this->SetApplyDelegate([](Interpreter* interpreter) {
+      interpreter->ModifyGraphWithDelegate(NnApiDelegate());
+    });
+    input1_ = AddInput(input1);
+    input2_ = AddInput(input2);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
+                 CreateAddOptions(builder_, activation_type).Union());
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+  }
+
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  int input1_;
+  int input2_;
+  int output_;
+};
+
+// Do a test with the NN API using no activation.
+TEST(NNAPIDelegate, AddWithNoActivation) {
+  FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
+}
+
+// Do a test with the NN API using ReLU activation.
+TEST(NNAPIDelegate, AddWithRelu) {
+  FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {1, 2, 2, 1}},
+                    {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU);
+  m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+  m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3}));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh
index 436c3e1d4cad5e..840015a7fad173 100755
--- a/tensorflow/contrib/lite/download_dependencies.sh
+++ b/tensorflow/contrib/lite/download_dependencies.sh
@@ -30,9 +30,7 @@ if [ ! -f $BZL_FILE_PATH ]; then
 fi
 
 EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)"
-# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once
-# the archive has been propagated in mirror.bazel.build.
-GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index 49280129971e38..57000072561303 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -42,7 +42,6 @@ android_binary( custom_package = "org.tensorflow.lite.demo", inline_constants = 1, manifest = "AndroidManifest.xml", - manifest_merger = "android", nocompress_extensions = [ ".tflite", ], diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index 2a64c1de725b60..e36218e4f12057 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -62,8 +62,8 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, {1, wanted_height, wanted_width, wanted_channels}, quant); ops::builtin::BuiltinOpResolver resolver; - TfLiteRegistration* resize_op = - resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR); + const TfLiteRegistration* resize_op = + resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR, 1); auto* params = reinterpret_cast( malloc(sizeof(TfLiteResizeBilinearParams))); params->align_corners = false; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index a91467d345fdce..966fcd2a31fd4d 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -70,6 +71,22 @@ TfLiteStatus ReadLabelsFile(const string& file_name, return kTfLiteOk; } +void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t op_index, + TfLiteRegistration registration) { + // output something like + // time (ms) , Node xxx, OpCode xxx, symblic name + // 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D + + LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3) + << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0 + << ", Node " << std::setw(3) << std::setprecision(3) << op_index + << ", OpCode " << std::setw(3) << std::setprecision(3) + << registration.builtin_code << ", " + << EnumNameBuiltinOperator( + static_cast(registration.builtin_code)) + << "\n"; +} + void RunInference(Settings* s) { if (!s->model_name.c_str()) { LOG(ERROR) << "no model file name\n"; @@ -166,19 +183,36 @@ void RunInference(Settings* s) { exit(-1); } + profiling::Profiler* profiler = new profiling::Profiler(); + interpreter->SetProfiler(profiler); + + if (s->profiling) profiler->StartProfiling(); + struct timeval start_time, stop_time; - gettimeofday(&start_time, NULL); + gettimeofday(&start_time, nullptr); for (int i = 0; i < s->loop_count; i++) { if (interpreter->Invoke() != kTfLiteOk) { LOG(FATAL) << "Failed to invoke tflite!\n"; } } - gettimeofday(&stop_time, NULL); + gettimeofday(&stop_time, nullptr); LOG(INFO) << "invoked \n"; LOG(INFO) << "average time: " << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000) << " ms \n"; + if (s->profiling) { + profiler->StopProfiling(); + auto profile_events = profiler->GetProfileEvents(); + for (int i = 0; i < profile_events.size(); i++) { + auto op_index = profile_events[i]->event_metadata; + const auto node_and_registration = + interpreter->node_and_registration(op_index); + const TfLiteRegistration registration = node_and_registration->second; + PrintProfilingInfo(profile_events[i], op_index, registration); + } + } + const int output_size = 1000; const size_t num_results = 5; const float threshold = 0.001f; @@ -217,13 +251,14 @@ void RunInference(Settings* s) { void display_usage() { LOG(INFO) << "label_image\n" - << "--accelerated, -a: [0|1], use Android NNAPI or note\n" + << "--accelerated, -a: [0|1], use Android NNAPI or not\n" << "--count, -c: loop interpreter->Invoke() for certain times\n" << "--input_mean, -b: input mean\n" << "--input_std, -s: input standard deviation\n" << "--image, -i: image_name.bmp\n" << "--labels, -l: labels for the model\n" << "--tflite_model, -m: model_name.tflite\n" + << "--profiling, -p: [0|1], profiling or not\n" << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "\n"; @@ -235,21 +270,22 @@ int Main(int argc, char** argv) { int c; while (1) { static struct option long_options[] = { - {"accelerated", required_argument, 0, 'a'}, - {"count", required_argument, 0, 'c'}, - {"verbose", required_argument, 0, 'v'}, - {"image", required_argument, 0, 'i'}, - {"labels", required_argument, 0, 'l'}, - {"tflite_model", required_argument, 0, 'm'}, - {"threads", required_argument, 0, 't'}, - {"input_mean", required_argument, 0, 'b'}, - {"input_std", required_argument, 0, 's'}, - {0, 0, 0, 0}}; + {"accelerated", required_argument, nullptr, 'a'}, + {"count", required_argument, nullptr, 'c'}, + {"verbose", required_argument, nullptr, 'v'}, + {"image", required_argument, nullptr, 'i'}, + {"labels", required_argument, nullptr, 'l'}, + {"tflite_model", required_argument, nullptr, 'm'}, + 
{"profiling", required_argument, nullptr, 'p'}, + {"threads", required_argument, nullptr, 't'}, + {"input_mean", required_argument, nullptr, 'b'}, + {"input_std", required_argument, nullptr, 's'}, + {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. */ int option_index = 0; - c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options, + c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:s:t:v:", long_options, &option_index); /* Detect the end of the options. */ @@ -257,15 +293,14 @@ int Main(int argc, char** argv) { switch (c) { case 'a': - s.accel = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'b': - s.input_mean = strtod(optarg, NULL); + s.input_mean = strtod(optarg, nullptr); break; case 'c': - s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.loop_count = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'i': s.input_bmp_name = optarg; @@ -276,16 +311,20 @@ int Main(int argc, char** argv) { case 'm': s.model_name = optarg; break; + case 'p': + s.profiling = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) + break; case 's': - s.input_std = strtod(optarg, NULL); + s.input_std = strtod(optarg, nullptr); break; case 't': s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + optarg, nullptr, 10); break; case 'v': - s.verbose = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.verbose = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'h': case '?': diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h index 4de32e33fb4ef2..4b48014e1c77ec 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.h +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -25,6 +25,7 @@ struct Settings { bool verbose = false; bool accel = false; bool input_floating = false; + bool profiling = false; int loop_count = 1; float input_mean = 127.5f; float input_std = 127.5f; diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc new file mode 100644 index 00000000000000..8b0ace96ccaf06 --- /dev/null +++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include + +// This is an example that is minimal to read a model +// from disk and perform inference. There is no data being loaded +// that is up to you to add as a user. 
+// +// NOTE: Do not add any dependencies to this that cannot be built with +// the minimal makefile. This example must remain trivial to build with +// the minimal build tool. +// +// Usage: minimal + +using namespace tflite; + +#define TFLITE_MINIMAL_CHECK(x) \ + if(!(x)) { \ + fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ + exit(1); \ + } + + +int main(int argc, char *argv[]) { + if(argc != 2) { + fprintf(stderr, "minimal \n"); + return 1; + } + const char* filename = argv[1]; + + // Load model + std::unique_ptr model + = tflite::FlatBufferModel::BuildFromFile(filename); + TFLITE_MINIMAL_CHECK(model != nullptr); + + // Build the interpreter + tflite::ops::builtin::BuiltinOpResolver resolver; + InterpreterBuilder builder(*model.get(), resolver); + std::unique_ptr interpreter; + builder(&interpreter); + TFLITE_MINIMAL_CHECK(interpreter != nullptr); + + // Allocate tensor buffers. + TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); + + // Fill input buffers + // TODO(user): Insert code to fill input tensors + + // Run inference + TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); + + // Read output buffers + // TODO(user): Insert getting data out code. + + return 0; +} diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index d7cc854ebac08e..972e57f73e8296 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -39,7 +39,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); int num_dims = NumDimensions(input); @@ -54,7 +54,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { using namespace tflite; - TfLiteTensor* input = GetInput(context, node,0); + const TfLiteTensor* input = GetInput(context, node,0); TfLiteTensor* output = GetOutput(context, node,0); float* input_data = input->data.f; diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index d8134d5a00097b..c1c8ef049f693d 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,28 +1,63 @@ # List of Hosted Models -* [NASNet large](https://storage.googleapis.com/download.tensorflow.org/models/tflite/nasnet_large_2018_03_27.zip) -* [NASNet mobile](https://storage.googleapis.com/download.tensorflow.org/models/tflite/nasnet_mobile_2018_03_27.zip) -* [ResNet v2 101](https://storage.googleapis.com/download.tensorflow.org/models/tflite/resnet_v2_101_2018_03_27.zip) -* [ResNet v2 50](https://storage.googleapis.com/download.tensorflow.org/models/tflite/resnet_v2_50_2018_03_27.zip) -* [Inception ResNet v2](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_resnet_v2_2018_03_27.zip) -* [Inception v4](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v4_2018_03_27.zip) -* [Inception v3 2015](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_2015_2017_11_10.zip) -* [Inception v3 Slim 2016](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) -* [Mobilenet 0.25 128 
Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_128_float_2017_11_08.zip) -* [Mobilenet 0.25 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_160_float_2017_11_08.zip) -* [Mobilenet 0.25 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_192_float_2017_11_08.zip) -* [Mobilenet 0.25 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_224_float_2017_11_08.zip) -* [Mobilenet 0.50 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_128_float_2017_11_08.zip) -* [Mobilenet 0.50 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_160_float_2017_11_08.zip) -* [Mobilenet 0.50 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_192_float_2017_11_08.zip) -* [Mobilenet 0.50 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_224_float_2017_11_08.zip) -* [Mobilenet 0.75 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_128_float_2017_11_08.zip) -* [Mobilenet 0.75 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_160_float_2017_11_08.zip) -* [Mobilenet 0.75 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_192_float_2017_11_08.zip) -* [Mobilenet 0.75 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_224_float_2017_11_08.zip) -* [Mobilenet 1.0 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_128_float_2017_11_08.zip) -* [Mobilenet 1.0 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_160_float_2017_11_08.zip) -* [Mobilenet 1.0 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_192_float_2017_11_08.zip) -* [Mobilenet 1.0 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_float_2017_11_08.zip) -* [Mobilenet 1.0 224 Quant](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) -* [Smart Reply 1.0 Android ](https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip) +## Image classification (Float Models) + +Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance +------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------: +DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms +SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms +NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), 
[tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 72.2% | 90.6% | 261 ms | 389 ms +NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.1% | 95.8% | 6697 ms | 7940 ms +ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms +ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_101_2018_04_27.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms +Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 76.9% | 93.5% | 1433 ms | 1522 ms +Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 79.6% | 94.6% | 2986 ms | 3139 ms +Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 76.8% | 93.5% | 2731 ms | 2926 ms +Mobilenet_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.5% | 66.3% | 6.2 ms | 13.0 ms +Mobilenet_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.5% | 70.3% | 8.6 ms | 19.5 ms +Mobilenet_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.7% | 72.3% | 12.1 ms | 27.8 ms +Mobilenet_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.8% | 74.2% | 16.2 ms | 37.3 ms +Mobilenet_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.3% | 79.4% | 18.1 ms | 29.9 ms +Mobilenet_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.1% | 81.9% | 26.8 ms | 45.9 ms +Mobilenet_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.6% | 35.6 ms | 65.3 ms +Mobilenet_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.3% | 84.9% | 47.6 ms | 164.2 ms +Mobilenet_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.1% | 
83.9% | 34.6 ms | 48.7 ms +Mobilenet_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.3% | 86.0% | 51.3 ms | 75.2 ms +Mobilenet_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.2% | 87.3% | 71.7 ms | 107.0 ms +Mobilenet_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.4% | 88.2% | 95.7 ms | 143.4 ms +Mobilenet_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.8% | 57.4 ms | 76.8 ms +Mobilenet_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms +Mobilenet_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.0% | 89.2% | 118.6 ms | 167.3 ms +Mobilenet_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 70.9% | 89.9% | 160.1 ms | 224.3 ms + +^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph. + +^^ The performance numbers are generated in the benchmark on Pixel-2 using +single thread large core. + +## Image classification (Quantized Models) + +Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance +------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: +Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.9% | 65.8% | 3.7 ms +Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.5% | 69.1% | 5.5 ms +Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.8% | 71.9% | 7.9 ms +Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 73.8% | 10.4 ms +Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.9% | 8.8 ms +Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 81.3% | 13.0 ms +Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.4% | 83.2% | 18.3 ms +Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 62.2% | 84.5% | 24.7 ms +Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 59.8% | 82.8% | 16.2 ms +Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.9% | 85.5% | 24.3 ms +Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.2% | 87.1% | 33.8 ms +Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.9% | 88.1% | 45.4 ms +Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 64.0% | 85.5% | 24.9 ms +Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.3% | 87.7% | 37.4 ms +Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.0% | 88.9% | 51.9 ms +Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.7% | 89.5% | 70.2 ms + +## Other models + +Model | TF Lite FlatBuffer +----------------------- | :----------------: +Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip) diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md index 7a3a231626d0e1..ab507893074142 100644 --- a/tensorflow/contrib/lite/g3doc/rpi.md +++ b/tensorflow/contrib/lite/g3doc/rpi.md @@ -32,7 +32,7 @@ This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc v Log in to you RPI, install the toolchain. ```bash -sudo apt-get instal build-essential +sudo apt-get install build-essential ``` First, clone this TensorFlow repository. Run this at the root of the repository: diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 203924f03d3101..b2f6444e9e1dc3 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -132,10 +132,7 @@ TensorFlow operation not listed above are likely unsupported. 
Notably, the following common ops are not supported at the moment: * [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space) -* [tf.floor](https://www.tensorflow.org/api_docs/python/tf/floor) -* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather) * [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear) -* [tf.slice](https://www.tensorflow.org/api_docs/python/tf/slice) * [tf.tanh](https://www.tensorflow.org/api_docs/python/tf/tanh) ## TensorFlow Lite Operations @@ -223,6 +220,23 @@ Options { } ``` +**CONV_2D_TRANSPOSE** + +``` +Inputs { + 0: output_shape + 1: filter + 2: 4D tensor +} +Outputs { + 0: the transpose (gradient) of conv2d +} +Options { + padding: SAME|VALID + stride_w,stride_h: stride of the filter window +} +``` + **DEPTHWISE_CONV_2D** ``` @@ -254,6 +268,17 @@ Outputs { } ``` +**FLOOR** + +``` +inputs { + 0: tensor +} +outputs: { + 0: result of computing element-wise floor of the input tensor +} +``` + **FULLY_CONNECTED** ``` @@ -271,6 +296,45 @@ Options { } ``` +**GATHER** + +``` +Inputs { + 0: params tensor + 1: indices tensor + 2: axis tensor (optional) +} +Outputs { + 0: a tensor with same type as the params tensor. +} +``` + +**GREATER** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than the corresponding element of the second tensor. +} +``` + +**GREATER_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than or equal to the corresponding element of the second tensor. +} +``` + **L2_NORMALIZATION** ``` @@ -315,6 +379,19 @@ Outputs { } ``` +**LESS_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is less + than or equal to the corresponding element of the second tensor. +} +``` + **LOCAL_RESPONSE_NORMALIZATION** ``` @@ -387,6 +464,17 @@ Options { } ``` +**NEG** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: elementwise negation of the input tensor +} +``` + **PAD** ``` @@ -463,6 +551,19 @@ Options { } ``` +**SLICE** + +``` +Inputs { + 0: tensor + 1: 1D tensor + 2: 1D tensor +} +Outputs { + 0: slice of the input tensor of the given size from the given begin index. +} +``` + **SOFTMAX** ``` @@ -506,6 +607,21 @@ Outputs { } ``` +**SPARSE_TO_DENSE** + +``` +Inputs { + 0: 0D or 1D or 2D tensor + 1: 1D tensor + 2: 0D or 1D tensor + 3: 0D tensor + 4: a boolean value +} +Outputs { + 0: Dense Tensor of shape output_shape. Has the same type as sparse_values. +} +``` + **SPLIT** ``` @@ -548,7 +664,7 @@ Outputs { 0: slice of the input tensor of the given size } Options { - begin_mask: mask for begin indicies + begin_mask: mask for begin indices end_mask: mask for end indices shrink_axis_mask: mask that indicates which dimensions to remove } @@ -563,7 +679,7 @@ Inputs { } Outputs { 0: k largest element along each last dimensional slice - 1: indicies of values within the last dimension of the input ensor + 1: indices of values within the last dimension of the input ensor } ``` @@ -579,6 +695,20 @@ Outputs { } ``` +**SELECT** + +``` +Inputs { + 0: tensor + 1: tensor + 2: tensor +} +Outputs { + 0: tensor that contains the elementwise values of 'tensor 1' if the + corresponding value of 'tensor 0' is true or the value of 'tensor 2' if false. 
+} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 9d8ea55fd1edc0..ebb0aedc2001a8 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -125,7 +125,8 @@ Interpreter::~Interpreter() { for (int i = 0; i < context_.tensors_size; i++) { TfLiteTensor* tensor = &context_.tensors[i]; - if (tensor->buffer_handle != kTfLiteNullBufferHandle) { + if (tensor->buffer_handle != kTfLiteNullBufferHandle && + tensor->delegate->FreeBufferHandle != nullptr) { tensor->delegate->FreeBufferHandle(tensor->delegate, &tensor->buffer_handle); } diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 6f3433abcf71b6..7315d8360680ca 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -201,7 +201,7 @@ class Interpreter { // Overrides execution plan. This bounds checks indices sent in. TfLiteStatus SetExecutionPlan(const std::vector& new_plan); - // Get a tensor data structure. + // Get a mutable tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { @@ -210,9 +210,14 @@ class Interpreter { return &context_.tensors[tensor_index]; } + // Get an immutable tensor data structure. + const TfLiteTensor* tensor(int tensor_index) const { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + // Get a pointer to an operation and registration data structure if in bounds. - // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this - // read/write access to structure const std::pair* node_and_registration( int node_index) const { if (node_index >= nodes_and_registration_.size() || node_index < 0) @@ -220,7 +225,8 @@ class Interpreter { return &nodes_and_registration_[node_index]; } - // Perform a checked cast to the appropriate tensor type. + // Perform a checked cast to the appropriate tensor type (mutable pointer + // version). template T* typed_tensor(int tensor_index) { if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { @@ -231,20 +237,46 @@ class Interpreter { return nullptr; } - // Return a pointer into the data of a given input tensor. The given index - // must be between 0 and inputs().size(). + // Perform a checked cast to the appropriate tensor type (immutable pointer + // version). + template + const T* typed_tensor(int tensor_index) const { + if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + + // Return a mutable pointer into the data of a given input tensor. The given + // index must be between 0 and inputs().size(). template T* typed_input_tensor(int index) { return typed_tensor(inputs_[index]); } - // Return a pointer into the data of a given output tensor. The given index - // must be between 0 and outputs().size(). + // Return an immutable pointer into the data of a given input tensor. The + // given index must be between 0 and inputs().size(). + template + const T* typed_input_tensor(int index) const { + return typed_tensor(inputs_[index]); + } + + // Return a mutable pointer into the data of a given output tensor. The given + // index must be between 0 and outputs().size(). 
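+  // For example (a usage sketch, assuming the model's first output is a
+  // float tensor):
+  //   float* out = interpreter->typed_output_tensor<float>(0);
+  //   float value = out[0];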
template T* typed_output_tensor(int index) { return typed_tensor(outputs_[index]); } + // Return an immutable pointer into the data of a given output tensor. The + // given index must be between 0 and outputs().size(). + template + const T* typed_output_tensor(int index) const { + return typed_tensor(outputs_[index]); + } + // Change the dimensionality of a given tensor. Note, this is only acceptable // for tensor indices that are inputs. // Returns status of failure or success. @@ -325,9 +357,7 @@ class Interpreter { void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; } - profiling::Profiler* GetProfiler(profiling::Profiler* profiler) { - return profiler_; - } + profiling::Profiler* GetProfiler() { return profiler_; } // The default capacity of `tensors_` vector. static constexpr int kTensorsReservedCapacity = 128; diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 1dda55b8edf8f8..593af81a18a1e2 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -1,7 +1,9 @@ # Description: # TensorFlow Lite Java API. -package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//tensorflow/contrib/lite/java/ovic:__pkg__", +]) licenses(["notice"]) # Apache 2.0 @@ -46,23 +48,6 @@ android_library( ], ) -java_library( - name = "ovicbenchmarkerlib", - srcs = [ - "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", - ], - javacopts = JAVACOPTS, - visibility = ["//visibility:public"], - deps = [ - ":libtensorflowlite_jni.so", - ":tensorflowlite_java", - "//tensorflow/contrib/lite/java/src/main/native", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", - "@org_checkerframework_qual", - ], -) - java_library( name = "tensorflowlitelib", srcs = glob( @@ -165,28 +150,6 @@ java_test( ], ) -java_test( - name = "OvicClassifierTest", - size = "medium", - srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], - data = [ - "ovic/src/testdata/float_model.lite", - "ovic/src/testdata/labels.txt", - "ovic/src/testdata/low_res_model.lite", - "ovic/src/testdata/quantized_model.lite", - "ovic/src/testdata/test_image_128.jpg", - "ovic/src/testdata/test_image_224.jpg", - ], - javacopts = JAVACOPTS, - test_class = "org.tensorflow.ovic.OvicClassifierTest", - visibility = ["//visibility:public"], - deps = [ - ":ovicbenchmarkerlib", - "@com_google_truth", - "@junit", - ], -) - filegroup( name = "libtensorflowlite_jni", srcs = select({ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml index ba63dce5d9a719..95b6b7016f2818 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml @@ -31,6 +31,7 @@ android:theme="@style/MaterialTheme"> diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml index 20f520814d7154..ef8a9e08450d72 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml @@ -13,51 +13,55 @@ See the License for the specific language governing permissions and limitations under the License. 
--> - + android:layout_height="match_parent" + android:background="#bb7700" + android:orientation="horizontal"> + + + + + + - + + + - - - - - - - - - - - - - - + android:paddingTop="20dp" + android:textColor="#FFF" + android:textSize="20sp"/> + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml index 72a229ecdb19f5..ddb099a950c2f8 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml @@ -28,7 +28,7 @@ + - + android:id="@+id/bottom_info_view" + android:layout_marginBottom="10dp" + android:layout_height="50dp"> + + + android:layout_marginLeft="10dp" + android:background="#0000000f" + android:textColor="@android:color/white" /> + + - - diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml index d12435d5abda45..e567009a424ed7 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -15,101 +15,80 @@ --> + android:layout_height="match_parent" + android:background="#bb7700"> - + android:layout_weight="1" /> - - - - - + android:layout_alignParentTop="false" + android:background="#bb7700" + android:orientation="vertical" + android:weightSum="100"> + + + + - - + - + android:layout_height="match_parent" + android:textColor="@android:color/white" + android:textAlignment="center" + android:gravity="center" + android:text="@string/threads" /> - - - - - - - + android:layout_marginLeft="10dp" + android:background="#0000000f" + android:textColor="@android:color/white" /> + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml index 0a71dbd0e8010f..7af8f3a98c6319 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -16,7 +16,7 @@ --> - TfLiteCameraDemo + TfLite Camera Demo + Threads: diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml index 3f3bdfb49480e7..1752b3b5f97e28 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml @@ -14,5 +14,10 @@ limitations under the License. --> - + diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD new file mode 100644 index 00000000000000..362d93636f7220 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -0,0 +1,68 @@ +# Description: +# OVIC Benchmarker Java API. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +java_test( + name = "OvicClassifierTest", + size = "medium", + srcs = ["src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.ovic.OvicClassifierTest", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + "@com_google_truth", + "@junit", + ], +) + +java_binary( + name = "ovic_validator", + srcs = ["src/main/java/org/tensorflow/ovic/OvicValidator.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + main_class = "org.tensorflow.ovic.OvicValidator", + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + ], +) + +android_library( + name = "ovicbenchmarkerlib", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) + +java_library( + name = "ovicbenchmarkerlib_java", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java:tensorflowlite_java", + "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md index 76c33838bfe5b8..26349347faebac 100644 --- a/tensorflow/contrib/lite/java/ovic/README.md +++ b/tensorflow/contrib/lite/java/ovic/README.md @@ -2,11 +2,11 @@ This folder contains building code for track one of the [Low Power ImageNet Recognition Challenge workshop at CVPR 2018.](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018) -## Pre-requesits +## Pre-requisite Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK. -## To test the benchmarker: +## Test the benchmarker: The testing utilities helps the developers (you) to make sure that your submissions in TfLite format will be processed as expected in the competition's benchmarking system. @@ -37,47 +37,122 @@ unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/ You can run test with Bazel as below. This helps to ensure that the installation is correct. ```sh -bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --test_output=all +bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:OvicClassifierTest --cxxopt=-Wno-all --test_output=all ``` ### Test your submissions -Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it as below. 
+Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it in two ways:
-* Move your submission to the testdata folder:
+#### Validate using randomly generated images
-Let say the submission file is located at `/tmp/my_model.lite`, then
+You can call the validator binary below to verify that your model fits the format requirements. This often helps you catch size mismatches (e.g. output should be [1, 1001] instead of [1,1,1,1001]). Let's say the submission file is located at `/path/to/my_model.lite`. Then call:
```sh -cp /tmp/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ +bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:ovic_validator --cxxopt=-Wno-all +bazel-bin/tensorflow/contrib/lite/java/ovic/ovic_validator /path/to/my_model.lite +``` +
+Successful validation should print the following message to the terminal: + +``` +Successfully validated /path/to/my_model.lite. + +``` +
+#### Test that the model produces sensible outcomes +
+You can go a step further to verify that the model produces results as expected. This helps you catch bugs during TOCO conversion (e.g. using the wrong mean and std values). +
+* Move your submission to the testdata folder: + +```sh +cp /path/to/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ ```
* Resize the test image to the resolutions that are expected by your submission: The test images can be found at `tensorflow/contrib/lite/java/ovic/src/testdata/test_image_*.jpg`. You may reuse these images if your image resolutions are 128x128 or 224x224.
-* Add your model and test image to the BUILD rule: +* Add your model and test image to the BUILD rule at `tensorflow/contrib/lite/java/ovic/src/testdata/BUILD`:
```JSON -java_test( - name = "OvicClassifierTest", - size = "medium", - srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], - data = [ - "ovic/src/testdata/float_model.lite", - "ovic/src/testdata/labels.txt", - "ovic/src/testdata/low_res_model.lite", - "ovic/src/testdata/quantized_model.lite", - "ovic/src/testdata/test_image_128.jpg", - "ovic/src/testdata/test_image_224.jpg", - "ovic/src/testdata/my_model.lite", # <--- Your submission. - "ovic/src/testdata/my_test_image.jpg", # <--- Your test image. - ], - ... +filegroup( + name = "ovic_testdata", + srcs = [ + "@tflite_ovic_testdata//:float_model.lite", + "@tflite_ovic_testdata//:low_res_model.lite", + "@tflite_ovic_testdata//:quantized_model.lite", + "@tflite_ovic_testdata//:test_image_128.jpg", + "@tflite_ovic_testdata//:test_image_224.jpg", + "my_model.lite", # <--- Your submission. + "my_test_image.jpg", # <--- Your test image. + ], + ... ```
* Modify `OvicClassifierTest.java` to test your model.
-Change `TEST_IMAGE_PATH` to `testdata/my_test_image.jpg`. If your model runs inference in floating point, change `FLOAT_MODEL_PATH` to `testdata/my_model.lite`. If your model runs [quantized inference](https://www.tensorflow.org/performance/quantization), change `QUANTIZED_MODEL_PATH` to `testdata/my_model.lite`.
+Change `TEST_IMAGE_PATH` to `my_test_image.jpg`. Change either `FLOAT_MODEL_PATH` or `QUANTIZED_MODEL_PATH` to `my_model.lite` depending on whether your model runs inference in float or [8-bit](https://www.tensorflow.org/performance/quantization) (a rough sketch of this change follows). Now you can run the bazel tests to catch any runtime issues with the submission.
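For illustration, the edit described in the last bullet might look roughly like the snippet below. The field names follow the instructions above, but the exact modifiers and surrounding declarations in `OvicClassifierTest.java` are assumptions; match them to the real test source.

```java
// Sketch only: names follow the README instructions above; the actual
// declarations in OvicClassifierTest.java may differ in modifiers or context.
private static final String TEST_IMAGE_PATH = "my_test_image.jpg";  // your test image
private static final String FLOAT_MODEL_PATH = "my_model.lite";     // float submission
// For an 8-bit quantized submission, point QUANTIZED_MODEL_PATH at it instead:
// private static final String QUANTIZED_MODEL_PATH = "my_model.lite";
```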
+
+Note: Please make sure that your submission passes the test. If a submission fails to pass the test, it will not be processed by the submission server.
+
+## Measure on-device latency
+
+We provide two ways to measure the on-device latency of your submission. The first is through our competition server, which is reliable and repeatable, but is limited to a few trials per day. The second is through the benchmarker APK, which requires a device and may not be as accurate as the server, but has a fast turnaround and no access limitations. We recommend that participants use the benchmarker APK for early development and reserve the competition server for evaluating promising submissions.
+
+### Running the benchmarker app
+
+Make sure that you have followed the instructions in [Test your submissions](#test-your-submissions) to add your model to the testdata folder and to the corresponding build rules.
+
+Modify `tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java`:
+
+* Add your model to the benchmarker APK by changing `MODEL_PATH` and `TEST_IMAGE_PATH` below to your submission and test image.
+
+``` + private static final String TEST_IMAGE_PATH = "my_test_image.jpg"; + private static final String MODEL_PATH = "my_model.lite"; +``` +
+* Adjust the benchmark parameters when needed:
+
+You can change the length of each experiment and the processor affinity below. `BIG_CORE_MASK` is an integer whose binary encoding represents the set of used cores. This number is phone-specific. For example, the Pixel 2 has 8 cores: the 4 little cores are represented by the 4 less significant bits, and the 4 big cores by the 4 more significant bits. Therefore a mask value of 16, or `00010000` in binary, represents using only the first big core. The mask 32, or `00100000` in binary, uses the second big core and should deliver results identical to mask 16 because the big cores are interchangeable (a short arithmetic sketch follows the sample latencies section below).
+
+``` + /** Wall time for each benchmarking experiment. */ + private static final double WALL_TIME = 3000; + /** Maximum number of iterations in each benchmarking experiment. */ + private static final int MAX_ITERATIONS = 100; + /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */ + private static final int BIG_CORE_MASK = 16; +``` +
+Note: You'll need root access to the phone to change processor affinity.
+
+* Build and install the app.
+
+``` +bazel build -c opt --cxxopt=--std=c++11 --cxxopt=-Wno-all //tensorflow/contrib/lite/java/ovic/demo/app:ovic_benchmarker_binary +adb install -r bazel-bin/tensorflow/contrib/lite/java/ovic/demo/app/ovic_benchmarker_binary.apk +``` +
+Start the app and click the dark green `Start` button. The button should turn bright green, signaling that the experiment is running. The benchmarking results will be displayed after about the `WALL_TIME` you specified above. For example:
+
+``` +my_model.lite: Average latency=158.6ms after 20 runs. +``` +
+### Sample latencies
+
+Note: The benchmarking results can vary considerably depending on the background processes running on the phone. A few things that help stabilize the app's readings are placing the phone on a cooling plate, restarting the phone, and shutting down internet access.
+
+| Model | Pixel 1 latency (ms) | Pixel 2 latency (ms) | +| -------------------- |:---------------------:| --------------------:| +| float_model.lite | 120 | 155 | +| quantized_model.lite | 85 | 74 | +| low_res_model.lite | 4.2 | 4.0 |
+
+Since the Pixel 2 has excellent support for 8-bit quantized models, we strongly recommend that you check out the [quantization training tutorial](https://www.tensorflow.org/performance/quantization).
+
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml new file mode 100644 index 00000000000000..55f2961fd717bd --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + +
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD new file mode 100644 index 00000000000000..83974f4b337bae --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -0,0 +1,29 @@ +# Sample app for OVIC benchmarking. +licenses(["notice"]) # Apache 2.0 + +android_binary( + name = "ovic_benchmarker_binary", + srcs = [ + "OvicBenchmarker.java", + "OvicBenchmarkerActivity.java", + ], + assets = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + assets_dir = "", + custom_package = "ovic.demo.app", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".lite", + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +)
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java similarity index 97% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java rename to tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java index d0102883e6b41f..113ab74a20dabc 100644 --- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java @@ -1,4 +1,4 @@ -/*Copyright 2018 Google LLC +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -package org.tensorflow.ovic; +package ovic.demo.app; import android.graphics.Bitmap; import android.os.SystemClock; @@ -22,6 +22,8 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; +import org.tensorflow.ovic.OvicClassifier; +import org.tensorflow.ovic.OvicSingleImageResult; /** * Class that benchmarks image classifier models.
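As a side note on the `BIG_CORE_MASK` values discussed in the README above and used by `OvicBenchmarkerActivity.java` below: the mask is plain bit arithmetic over CPU indices. A minimal sketch follows, assuming the 4-little/4-big core layout of the Pixel 2 example; other phones differ, and the helper name here is purely illustrative.

```java
// Minimal sketch of the affinity-mask arithmetic described in the README above.
// Assumes little cores occupy the low bits and big cores the high bits (Pixel 2
// example); the core layout is phone-specific.
public final class CoreMaskSketch {
  /** Mask selecting one big core on a phone with `littleCores` little cores. */
  static int singleBigCoreMask(int littleCores, int bigCoreIndex) {
    return 1 << (littleCores + bigCoreIndex);
  }

  public static void main(String[] args) {
    System.out.println(singleBigCoreMask(4, 0)); // 16 (0b00010000): first big core
    System.out.println(singleBigCoreMask(4, 1)); // 32 (0b00100000): second big core
  }
}
```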
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java new file mode 100644 index 00000000000000..59457c308ad7ca --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java @@ -0,0 +1,247 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package ovic.demo.app; + +import android.app.Activity; +import android.content.res.AssetFileDescriptor; +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.os.Bundle; +import android.os.Process; +import android.os.SystemClock; +import android.util.Log; +import android.view.View; +import android.widget.TextView; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStream; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.text.DecimalFormat; +import org.tensorflow.ovic.OvicSingleImageResult; + +/** Class that benchmark image classifier models. */ +public class OvicBenchmarkerActivity extends Activity { + /** Tag for the {@link Log}. */ + private static final String TAG = "OvicBenchmarkerActivity"; + + /** Name of the label file stored in Assets. */ + private static final String LABEL_PATH = "labels.txt"; + + private static final String TEST_IMAGE_PATH = "test_image_224.jpg"; + private static final String MODEL_PATH = "float_model.lite"; + /** + * Each bottom press will launch a benchmarking experiment. The experiment stops when either the + * total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS, + * whichever comes first. + */ + /** Wall time for each benchmarking experiment. */ + private static final double WALL_TIME = 3000; + /** Maximum number of iterations in each benchmarking experiment. */ + private static final int MAX_ITERATIONS = 100; + /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */ + private static final int BIG_CORE_MASK = 16; + /** Amount of time in milliseconds to wait for affinity to set. */ + private static final int WAIT_TIME_FOR_AFFINITY = 1000; + + /* The model to be benchmarked. */ + private MappedByteBuffer model = null; + private InputStream labelInputStream = null; + private OvicBenchmarker benchmarker; + /** Inference result of each iteration. */ + OvicSingleImageResult iterResult = null; + + private TextView textView = null; + // private Button startButton = null; + private static final DecimalFormat df2 = new DecimalFormat(".##"); + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + // TextView used to display the progress, for information purposes only. 
+ textView = (TextView) findViewById(R.id.textView); + } + + private Bitmap loadTestBitmap() throws IOException { + InputStream imageStream = getAssets().open(TEST_IMAGE_PATH); + return BitmapFactory.decodeStream(imageStream); + } + + public void initializeTest() throws IOException { + Log.i(TAG, "Initializing benchmarker."); + benchmarker = new OvicBenchmarker(WALL_TIME); + AssetManager am = getAssets(); + AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH); + FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = modelInputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + labelInputStream = am.open(LABEL_PATH); + } + + public Boolean doTestIteration() throws IOException, InterruptedException { + if (benchmarker == null) { + throw new RuntimeException("Benchmarker has not been initialized."); + } + if (benchmarker.shouldStop()) { + return false; + } + if (!benchmarker.readyToTest()) { + Log.i(TAG, "getting ready to test."); + benchmarker.getReadyToTest(labelInputStream, model); + if (!benchmarker.readyToTest()) { + throw new RuntimeException("Failed to get the benchmarker ready."); + } + } + Log.i(TAG, "Going to do test iter."); + // Start testing. + Bitmap testImageBitmap = loadTestBitmap(); + iterResult = benchmarker.doTestIteration(testImageBitmap); + testImageBitmap.recycle(); + if (iterResult == null) { + throw new RuntimeException("Inference failed to produce a result."); + } + Log.i(TAG, iterResult.toString()); + return true; + } + + public void startPressed(View view) throws IOException { + Log.i(TAG, "Start pressed"); + try { + initializeTest(); + } catch (IOException e) { + Log.e(TAG, "Can't initialize benchmarker.", e); + throw e; + } + String displayText = ""; + try { + setProcessorAffinity(BIG_CORE_MASK); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + displayText = e.getMessage() + "\n"; + } + Log.i(TAG, "Successfully initialized benchmarker."); + int testIter = 0; + Boolean iterSuccess = false; + double totalLatency = 0.0f; + while (testIter < MAX_ITERATIONS) { + try { + iterSuccess = doTestIteration(); + } catch (IOException e) { + Log.e(TAG, "Error during iteration " + testIter); + throw e; + } catch (InterruptedException e) { + Log.e(TAG, "Interrupted at iteration " + testIter); + } + if (!iterSuccess) { + break; + } + testIter++; + totalLatency += (double) iterResult.latency; + } + ; + Log.i(TAG, "Benchmarking finished"); + + if (textView != null) { + if (testIter > 0) { + textView.setText( + displayText + + MODEL_PATH + + ": Average latency=" + + df2.format(totalLatency / testIter) + + "ms after " + + testIter + + " runs."); + } else { + textView.setText("Benchmarker failed to run on more than one images."); + } + } + } + + private static void setProcessorAffinity(int mask) throws IOException { + int myPid = Process.myPid(); + Log.i(TAG, String.format("Setting processor affinity to 0x%02x", mask)); + + String command = String.format("taskset -a -p %x %d", mask, myPid); + try { + Runtime.getRuntime().exec(command).waitFor(); + } catch (InterruptedException e) { + throw new IOException("Interrupted: " + e); + } + + // Make sure set took effect - try for a second to confirm the change took. If not then fail. 
+ long startTimeMs = SystemClock.elapsedRealtime(); + while (true) { + int readBackMask = readCpusAllowedMask(); + if (readBackMask == mask) { + Log.i(TAG, String.format("Successfully set affinity to 0x%02x", mask)); + return; + } + if (SystemClock.elapsedRealtime() > startTimeMs + WAIT_TIME_FOR_AFFINITY) { + throw new IOException( + String.format( + "Core-binding failed: affinity set to 0x%02x but read back as 0x%02x\n" + + "please root device.", + mask, readBackMask)); + } + + try { + Thread.sleep(50); + } catch (InterruptedException e) { + // Ignore sleep interrupted, will sleep again and compare is final cross-check. + } + } + } + + public static int readCpusAllowedMask() throws IOException { + // Determine how many CPUs there are total + final String pathname = "/proc/self/status"; + final String resultPrefix = "Cpus_allowed:"; + File file = new File(pathname); + String line = ""; + String allowedCPU = ""; + Integer allowedMask = null; + BufferedReader bufReader = null; + try { + bufReader = new BufferedReader(new FileReader(file)); + while ((line = bufReader.readLine()) != null) { + if (line.startsWith(resultPrefix)) { + allowedMask = Integer.valueOf(line.substring(resultPrefix.length()).trim(), 16); + allowedCPU = bufReader.readLine(); + break; + } + } + } catch (RuntimeException e) { + throw new IOException( + "Invalid number in " + pathname + " line: \"" + line + "\": " + e.getMessage()); + } finally { + if (bufReader != null) { + bufReader.close(); + } + } + if (allowedMask == null) { + throw new IOException(pathname + " missing " + resultPrefix + " line"); + } + Log.i(TAG, allowedCPU); + return allowedMask; + } +} diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle new file mode 100644 index 00000000000000..c5d19bad89a939 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle @@ -0,0 +1,58 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "android.example.com.ovicbenchmarker" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. 
+ jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "lite", "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'com.android.support:appcompat-v7:25.2.0' + compile 'com.android.support.constraint:constraint-layout:1.0.2' + compile 'com.android.support:design:25.2.0' + compile 'com.android.support:support-annotations:25.3.1' + compile 'com.android.support:support-v13:25.2.0' + + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 00000000000000..715d1b6d69c0f4 Binary files /dev/null and b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 00000000000000..9beff0885fd4c8 Binary files /dev/null and b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml new file mode 100644 index 00000000000000..93f5c6a016b499 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml new file mode 100644 index 00000000000000..e9d83bae543ae6 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml @@ -0,0 +1,54 @@ + + + + + + +