diff --git a/.bazelrc b/.bazelrc
index 316949455a0114..a93862aa78a302 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -299,9 +299,11 @@ common:cuda --@local_config_cuda//:enable_cuda
 common:cuda --config=cuda_version
 # This flag is needed to include CUDA libraries.
 common:cuda --@local_config_cuda//cuda:include_cuda_libs=true
+common:cuda --@cuda_driver//:include_cuda_umd_libs=true
 
 # This configuration is used for building the wheels.
 common:cuda_wheel --@local_config_cuda//cuda:include_cuda_libs=false
+common:cuda_wheel --@cuda_driver//:include_cuda_umd_libs=false
 
 # CUDA: This config refers to building CUDA op kernels with clang.
 common:cuda_clang --config=cuda
@@ -612,7 +614,6 @@ common:use_tar_archive_files --repo_env=USE_LLVM_TAR_ARCHIVE_FILES=1
 common:use_tar_archive_files --repo_env=USE_MIRRORED_TAR_ARCHIVE_FILES=1
 
 # Make Bazel not try to probe the host system for a C++ toolchain.
-common:rbe_base --config=use_tar_archive_files
 common:rbe_base --config=resultstore
 common:rbe_base --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1
 common:rbe_base --define=EXECUTOR=remote
@@ -655,8 +656,8 @@ common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instance
 # Download CUDA/CUDNN redistributions to preserve the repositories cache between
 # CPU and GPU builds.
 # TODO(ybaturina): Uncomment when RBE is ready to support this.
-commonld:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1
-commonld:rbe_linux_cpu --config=cuda_version
+common:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1
+common:rbe_linux_cpu --config=cuda_version
 
 # Deprecated RBE config with non-hermetic toolchains.
 common:rbe_linux_cpu_clang_local --config=rbe_linux_cpu
@@ -682,9 +683,6 @@ common:rbe_linux_cuda --config=cuda_clang_official
 common:rbe_linux_cuda --config=rbe_linux_cpu
 # For Remote build execution -- GPU configuration
 common:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1
-# Enable forward compatibility for CUDA builds because RBE docker image doesn't
-# have latest CUDA drivers installed.
-common:rbe_linux_cuda --@cuda_driver//:enable_forward_compatibility=true
 
 common:rbe_linux_cuda_nvcc --config=rbe_linux_cuda
 common:rbe_linux_cuda_nvcc --config=cuda_nvcc
@@ -877,7 +875,7 @@ test:linux_cpu_wheel_test --@local_xla//third_party/py:wheel_dependency=true --c
 test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310,-no_oss_py313
 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310,-no_oss_py313
 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium
-test:linux_cuda_wheel_test --@local_xla//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
+test:linux_cuda_wheel_test --repo_env=HERMETIC_CUDA_UMD_VERSION=12.8.1 --@local_xla//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/...
 
 # ARM64 WHEEL
 test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310,-no_oss_py313
 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310,-no_oss_py313
diff --git a/.bazelversion b/.bazelversion
index 5c733d6c13a497..26c75fe8ad4fc9 100644
--- a/.bazelversion
+++ b/.bazelversion
@@ -1,2 +1,2 @@
-7.4.1
+7.7.0
 # NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss
\ No newline at end of file
diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml
index c0682a4cac7035..07896a48470753 100644
--- a/.github/workflows/osv-scanner-scheduled.yml
+++ b/.github/workflows/osv-scanner-scheduled.yml
@@ -28,7 +28,7 @@ permissions:
 jobs:
   scan-scheduled:
     if: github.repository == 'tensorflow/tensorflow'
-    uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.2.3"
+    uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.2.4"
     with:
       scan-args: |-
        --lockfile=requirements.txt:./requirements_lock_3_9.txt
diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml
index 75339c6b4f6bd7..e635c4cd8ccc88 100644
--- a/.github/workflows/scorecards-analysis.yml
+++ b/.github/workflows/scorecards-analysis.yml
@@ -55,7 +55,7 @@ jobs:
       # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
       # format to the repository Actions tab.
       - name: "Upload artifact"
-        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
+        uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
        with:
          name: SARIF file
          path: results.sarif
@@ -64,6 +64,6 @@ jobs:
       # Upload the results to GitHub's code scanning dashboard (optional).
       # Commenting out will disable upload of results to your repo's Code Scanning dashboard
       - name: "Upload to code-scanning"
-        uses: github/codeql-action/upload-sarif@3599b3baa15b485a2e49ef411a7a4bb2452e7f93 # v3.29.5
+        uses: github/codeql-action/upload-sarif@0499de31b99561a6d14a36a5f662c2a54f91beee # v3.29.5
        with:
          sarif_file: results.sarif
diff --git a/.github/workflows/stale-issues.yml b/.github/workflows/stale-issues.yml
index d9408810eb32ac..53f272bd5b9d8a 100644
--- a/.github/workflows/stale-issues.yml
+++ b/.github/workflows/stale-issues.yml
@@ -31,7 +31,7 @@ jobs:
       pull-requests: write
     steps:
       - name: Awaiting response issues
-        uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
+        uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
        with:
          #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale
          exempt-issue-labels: 'override-stale'
@@ -59,7 +59,7 @@ jobs:
       close-pr-message: "This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further."
       repo-token: ${{ secrets.GITHUB_TOKEN }}
      - name: Contribution issues
-       uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
+       uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
       with:
         #Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale
         exempt-issue-labels: 'override-stale'
diff --git a/RELEASE.md b/RELEASE.md
index 4ce8bbb371728e..1c34667494f477 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -23,7 +23,9 @@
 * Adds int8 and int16x8 support for SQRT operator.
 * Adds int16x8 support for EQUAL and NOT_EQUAL operators.
 * Adds support for int2 type.
-    * Adds support for int2/int4 in tfl.cast.
+    * Adds support for int2/int4 in tfl.cast .
+    * Adds support for SRQ int2 in tfl.fully_connected.
+    * Adds support for int4 in tfl.slice.
 
 ### Bug Fixes and Other Changes
 
diff --git a/ci/official/containers/ml_build/Dockerfile b/ci/official/containers/ml_build/Dockerfile
index d12c886cc6d57a..a4fb0cd9b1640a 100644
--- a/ci/official/containers/ml_build/Dockerfile
+++ b/ci/official/containers/ml_build/Dockerfile
@@ -12,14 +12,6 @@ COPY builder.packages.txt /builder.packages.txt
 RUN /setup.sources.sh && /setup.packages.sh /builder.packages.txt
 
-# Install devtoolset-9 in /dt9 with glibc 2.17 and libstdc++ 4.8, for building
-# manylinux2014-compatible packages.
-COPY builder.devtoolset/fixlinks.sh /fixlinks.sh
-COPY builder.devtoolset/rpm-patch.sh /rpm-patch.sh
-COPY builder.devtoolset/build_devtoolset.sh /build_devtoolset.sh
-COPY builder.devtoolset/glibc2.17-inline.patch /glibc2.17-inline.patch
-RUN /build_devtoolset.sh devtoolset-9 /dt9
-
 # Setup Python
 COPY setup.python.sh /setup.python.sh
 COPY builder.requirements.txt /builder.requirements.txt
 
@@ -56,9 +48,6 @@ RUN ln -sf /usr/bin/python3.12 /usr/bin/python3
 RUN ln -sf /usr/bin/python3.12 /usr/bin/python
 RUN ln -sf /usr/lib/python3.12 /usr/lib/tf_python
 
-# Make sure clang is on the path
-RUN ln -s /usr/lib/llvm-18/bin/clang /usr/bin/clang
-
 # Link the compat driver to the location if available.
 RUN if [ -e "/usr/local/cuda/compat/libcuda.so.1" ]; then ln -s /usr/local/cuda/compat/libcuda.so.1 /usr/lib/x86_64-linux-gnu/libcuda.so.1; fi
 
diff --git a/ci/official/containers/ml_build/builder.packages.txt b/ci/official/containers/ml_build/builder.packages.txt
index 8dbbf4196440da..cf914a0425ef11 100644
--- a/ci/official/containers/ml_build/builder.packages.txt
+++ b/ci/official/containers/ml_build/builder.packages.txt
@@ -1,28 +1,9 @@
-# Packages to be installed for the new Docker image.
-
-# Packages needed to build devtoolset
-file
-flex
-g++
-make
-patch
-rpm2cpio
-unar
-wget
-xz-utils
-cpio
-
 # Other build-related tools
 apt-transport-https
 autoconf
 automake
 build-essential
 ca-certificates
-llvm-18
-clang-18
-clang-tidy-18
-lld-18
-clang-format-12
 curl
 git
 parallel
@@ -32,4 +13,6 @@ unzip
 zip
 openjdk-21-jdk
 vim
+wget
 jq
+file
diff --git a/ci/official/containers/ml_build/builder.requirements.txt b/ci/official/containers/ml_build/builder.requirements.txt
index 114efaf9dc9757..ae113c68c2f03c 100644
--- a/ci/official/containers/ml_build/builder.requirements.txt
+++ b/ci/official/containers/ml_build/builder.requirements.txt
@@ -5,6 +5,9 @@ id
 urllib3
 requests
 
+# For XLA
+pyyaml
+
 # For JAX
 build ~= 1.2.2
 # uv is faster than pip for installing Python packages.
diff --git a/ci/official/containers/ml_build/setup.python.sh b/ci/official/containers/ml_build/setup.python.sh
index cd56f3ca552d0f..b849457420f522 100755
--- a/ci/official/containers/ml_build/setup.python.sh
+++ b/ci/official/containers/ml_build/setup.python.sh
@@ -45,16 +45,6 @@ fi
 
 /setup.packages.sh pythons.txt
 
-# Re-link pyconfig.h from x86_64-linux-gnu into the devtoolset directory
-# for any Python version present
-pushd /usr/include/x86_64-linux-gnu
-for f in $(ls | grep python); do
-  # set up symlink for devtoolset-9
-  rm -f /dt9/usr/include/x86_64-linux-gnu/$f
-  ln -s /usr/include/x86_64-linux-gnu/$f /dt9/usr/include/x86_64-linux-gnu/$f
-done
-popd
-
 # Python 3.10 include headers fix:
 # sysconfig.get_path('include') incorrectly points to /usr/local/include/python
 # map /usr/include/python3.10 to /usr/local/include/python3.10
diff --git a/ci/official/envs/windows_x86_2022 b/ci/official/envs/windows_x86_2022
index 56187ad78eca17..3c57bcfb8114ee 100644
--- a/ci/official/envs/windows_x86_2022
+++ b/ci/official/envs/windows_x86_2022
@@ -15,7 +15,7 @@ TFCI_DOCKER_ENABLE=1
 TFCI_DOCKER_PULL_ENABLE=1
 TFCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2022@sha256:915cb093630432c38b028f56bd31116a5559ebbc688d427b6092d86828ae03bc"
-TFCI_BAZEL_BAZELRC_ARGS="--output_user_root=C:/t"
+TFCI_BAZEL_BAZELRC_ARGS="--output_user_root=C:/x"
 TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config=windows_x86_cpu_2022"
 TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=windows_x86_cpu_2022
 TFCI_BUILD_PIP_PACKAGE_WHEEL_NAME_ARG="--repo_env=WHEEL_NAME=tensorflow"
diff --git a/ci/official/utilities/cleanup_docker.sh b/ci/official/utilities/cleanup_docker.sh
index 178da9310969ca..3be4a5f418172e 100755
--- a/ci/official/utilities/cleanup_docker.sh
+++ b/ci/official/utilities/cleanup_docker.sh
@@ -26,4 +26,5 @@ $ docker exec -it tf bash
 EOF
 
 docker ps
-docker rm -f tf-${TFCI_PYTHON_VERSION}
+echo "Removing container tf-$TFCI_PYTHON_VERSION-$TFCI_DOCKER_CONTAINER_POSTFIX"
+docker rm -f tf-$TFCI_PYTHON_VERSION-$TFCI_DOCKER_CONTAINER_POSTFIX
diff --git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh
index 89318aa4ec78dc..01e549d02dfffc 100755
--- a/ci/official/utilities/setup_docker.sh
+++ b/ci/official/utilities/setup_docker.sh
@@ -51,7 +51,7 @@ if ! docker container inspect tf >/dev/null 2>&1 ; then
     echo "GCE_METADATA_HOST=$IP_ADDR" >> $env_file
   fi
 
-  docker run $TFCI_DOCKER_ARGS --name tf-$TFCI_PYTHON_VERSION -w "$WORKING_DIR" -itd --rm \
+  docker run $TFCI_DOCKER_ARGS --name tf-$TFCI_PYTHON_VERSION-$TFCI_DOCKER_CONTAINER_POSTFIX -w "$WORKING_DIR" -itd --rm \
     -v "$TFCI_GIT_DIR:$WORKING_DIR" \
     --env-file "$env_file" \
     "$TFCI_DOCKER_IMAGE" \
@@ -65,4 +65,4 @@ if ! docker container inspect tf >/dev/null 2>&1 ; then
   fi
 fi
 
-tfrun() { docker exec tf-$TFCI_PYTHON_VERSION "$@"; }
+tfrun() { docker exec tf-$TFCI_PYTHON_VERSION-$TFCI_DOCKER_CONTAINER_POSTFIX "$@"; }
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index f000821983b779..558b59368e615b 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -1033,6 +1033,7 @@ package_group(
         "//tensorflow_models/google/recml/...",
         "//third_party/cloud_tpu/convergence_tools/sdc_monitoring/...",
         "//third_party/cloud_tpu/inference_converter/...",
+        "//third_party/pathways/...",
        "//third_party/py/cloud_ml_autoflow/...",
        "//third_party/py/envlogger/...",
        "//third_party/py/gldm/...",
@@ -1180,38 +1181,31 @@ tf_cc_shared_library(
     linkstatic = 1,
     per_os_targets = True,
     roots = [
-        "//tensorflow/c/experimental/filesystem:filesystem_interface",
-        "//tensorflow/c/experimental/stream_executor:stream_executor",
-        "//tensorflow/c:env",
-        "//tensorflow/c:kernels",
-        "//tensorflow/c:kernels_experimental",
-        "//tensorflow/c:logging",
-        "//tensorflow/c:ops",
-        "//tensorflow/cc/saved_model:fingerprinting_impl",
-        "//tensorflow/cc/saved_model:loader_lite_impl",
-        "//tensorflow/cc/saved_model:metrics_impl",
-        "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl",
-        "//tensorflow/core/common_runtime:core_cpu_impl",
-        "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
-        "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl",
-        "//tensorflow/core:framework_internal_impl",
-        "//tensorflow/core/framework:tensor",
-        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
-        "//tensorflow/core:lib_internal_impl",
-        "//tensorflow/core/profiler:profiler_impl",
-        "//tensorflow/core/util:determinism",  # Must be linked and exported to libtensorflow_framework.so.
-        "//tensorflow/lite/kernels/shim:tf_kernel_shim",
-        "@local_xla//xla/stream_executor:stream_executor_impl",
-        "@local_xla//xla/tsl/framework:bfc_allocator",
-        "@local_xla//xla/tsl/framework:metrics",
-    ] + tf_additional_binary_deps() +
-    # TODO(b/259305727): Remove this select and include captured_function in macos builds.
-    select({
-        "//tensorflow:macos": [],
-        "//conditions:default": [
-            "//tensorflow/core/data:captured_function",
-        ],
-    }),
+        "//tensorflow/c/experimental/filesystem:filesystem_interface",
+        "//tensorflow/c/experimental/stream_executor:stream_executor",
+        "//tensorflow/c:env",
+        "//tensorflow/c:kernels",
+        "//tensorflow/c:kernels_experimental",
+        "//tensorflow/c:ops",
+        "//tensorflow/cc/saved_model:fingerprinting_impl",
+        "//tensorflow/cc/saved_model:loader_lite_impl",
+        "//tensorflow/cc/saved_model:metrics_impl",
+        "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl",
+        "//tensorflow/core/common_runtime:core_cpu_impl",
+        "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl",
+        "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl",
+        "//tensorflow/core:framework_internal_impl",
+        "//tensorflow/core/framework:tensor",
+        "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
+        "//tensorflow/core:lib_internal_impl",
+        "//tensorflow/core/profiler:profiler_impl",
+        "//tensorflow/core/util:determinism",  # Must be linked and exported to libtensorflow_framework.so.
+ "//tensorflow/lite/kernels/shim:tf_kernel_shim", + "@local_xla//xla/stream_executor:stream_executor_impl", + "@local_xla//xla/tsl/framework:bfc_allocator", + "@local_xla//xla/tsl/framework:metrics", + "//tensorflow/core/data:captured_function", + ] + tf_additional_binary_deps(), soversion = VERSION, static_deps = PACKAGE_STATIC_DEPS, visibility = ["//visibility:public"], diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 726433bafded24..3f4ec98028e8c3 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -298,7 +298,6 @@ tf_cuda_library( ], "//conditions:default": [ ":env", - ":logging", ":tf_status", ":tf_tensor", "//tensorflow/c/experimental/filesystem:modular_filesystem", @@ -325,18 +324,6 @@ tf_cuda_library( alwayslink = 1, ) -cc_library( - name = "logging", - srcs = ["logging.cc"], - hdrs = ["logging.h"], - visibility = ["//visibility:public"], - deps = [ - ":c_api_macros", - "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:stringprintf", - ], -) - tf_cuda_library( name = "tf_status_internal", hdrs = [ diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index b919be52b0bf68..4dd78e4cd7bbb1 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1171,7 +1171,7 @@ TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 " "from function 'MyFunc'"), - string(TF_Message(s_))); + std::string(TF_Message(s_))); } TEST_F(CApiFunctionTest, NodeMissingInput) { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index e3e7d812b15838..f59a73a0871945 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -2478,7 +2478,7 @@ TEST_F(CApiAttributesTest, Names) { TF_OperationGetAttrName(oper, 0, value.get(), s_); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ("v", string(static_cast(value.get()), 1)); + EXPECT_EQ("v", std::string(static_cast(value.get()), 1)); } TEST_F(CApiAttributesTest, Errors) { diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 97a5bbd4b6077a..9dae0d3afd46fe 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -119,8 +119,7 @@ CheckpointReader::BuildV2VarMaps() { BundleEntryProto entry; v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - CHECK(entry.ParseFromArray(v2_reader_->value().data(), - v2_reader_->value().size())) + CHECK(entry.ParseFromString(v2_reader_->value())) << entry.InitializationErrorString(); for (int i = 0; i < entry.slices_size(); ++i) { const auto& slice_proto = entry.slices(i); @@ -140,8 +139,7 @@ CheckpointReader::BuildV2VarMaps() { v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { if (filtered_keys.count(string(v2_reader_->key())) > 0) continue; - CHECK(entry.ParseFromArray(v2_reader_->value().data(), - v2_reader_->value().size())) + CHECK(entry.ParseFromString(v2_reader_->value())) << entry.InitializationErrorString(); string key(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index ccde2ba3d9b769..91f83b3f88967d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -939,7 +939,8 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, const char* serialized_function_def, size_t size, 
                                TF_Status* status) {
   tensorflow::FunctionDef function_def;
-  if (!function_def.ParseFromArray(serialized_function_def, size)) {
+  if (!function_def.ParseFromString(
+          absl::string_view(serialized_function_def, size))) {
     status->status =
         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
     return;
diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc
index 03dd862f95cb0f..7d25709df2dfc7 100644
--- a/tensorflow/c/env.cc
+++ b/tensorflow/c/env.cc
@@ -34,7 +34,7 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 
 struct TF_StringStream {
-  std::vector<::tensorflow::string>* list;
+  std::vector<std::string>* list;
   size_t position;
 };
 
@@ -134,7 +134,7 @@ void TF_StringStreamDone(TF_StringStream* list) { delete list; }
 
 TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) {
-  auto* children = new std::vector<::tensorflow::string>;
+  auto* children = new std::vector<std::string>;
 
   TF_SetStatus(status, TF_OK, "");
   ::tensorflow::Set_TF_Status_from_Status(
@@ -147,7 +147,7 @@ TF_StringStream* TF_GetChildren(const char* dirname, TF_Status* status) {
 }
 
 TF_StringStream* TF_GetLocalTempDirectories() {
-  auto* tmpdirs = new std::vector<::tensorflow::string>;
+  auto* tmpdirs = new std::vector<std::string>;
 
   ::tensorflow::Env::Default()->GetLocalTempDirectories(tmpdirs);
diff --git a/tensorflow/c/env_test.cc b/tensorflow/c/env_test.cc
index d4c9bfce3c2127..3d338d4377366b 100644
--- a/tensorflow/c/env_test.cc
+++ b/tensorflow/c/env_test.cc
@@ -35,14 +35,12 @@ TEST(TestEnv, TestDirHandling) {
   TF_Status* s = TF_NewStatus();
 
-  ::tensorflow::string dirpath =
-      ::tensorflow::io::JoinPath(tempdir, "somedir");
+  std::string dirpath = ::tensorflow::io::JoinPath(tempdir, "somedir");
   TF_CreateDir(dirpath.c_str(), s);
   ASSERT_TF_OK(s) << "TF_CreateDir failed for " << dirpath << ": "
                   << TF_Message(s);
 
-  ::tensorflow::string filepath =
-      ::tensorflow::io::JoinPath(dirpath, "somefile.txt");
+  std::string filepath = ::tensorflow::io::JoinPath(dirpath, "somefile.txt");
 
   TF_WritableFileHandle* handle;
   TF_NewWritableFile(filepath.c_str(), &handle, s);
   ASSERT_TF_OK(s) << "NewWritableFile failed for " << filepath << ": "
@@ -61,7 +59,7 @@ TEST(TestEnv, TestDirHandling) {
   ASSERT_TF_OK(s) << "TF_GetChildren failed for " << dirpath;
   const char* childpath;
   ASSERT_TRUE(TF_StringStreamNext(children, &childpath));
-  ASSERT_EQ(::tensorflow::string(childpath), "somefile.txt");
+  ASSERT_EQ(std::string(childpath), "somefile.txt");
   // There should only be one file in this directory.
   ASSERT_FALSE(TF_StringStreamNext(children, &childpath));
   ASSERT_EQ(childpath, nullptr);
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
index 8fa3e726e6a837..f0f6e5351372e1 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD
@@ -31,10 +31,10 @@ cc_library(
         ":gcs_helper",
         ":ram_file_block_cache",
         "//tensorflow/c:env",
-        "//tensorflow/c:logging",
         "//tensorflow/c:tf_status",
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
         "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
+        "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud:google_cloud_cpp_common",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/strings",
@@ -65,7 +65,6 @@ cc_library(
     deps = [
         ":cleanup",
         "//tensorflow/c:env",
-        "//tensorflow/c:logging",
         "//tensorflow/c:tf_status",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/log",
@@ -86,6 +85,7 @@ tf_cc_test(
         "//tensorflow/core:test_main",
         "//tensorflow/core/platform/cloud:now_seconds_env",
         "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/time",
        "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc",
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc
index b0d283fff82d9b..e639f9a7dda476 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache_test.cc
@@ -27,7 +27,7 @@ namespace tensorflow {
 namespace {
 
 TEST(ExpiringLRUCacheTest, MaxAge) {
-  const string key = "a";
+  const std::string key = "a";
   std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
   tf_gcs_filesystem::ExpiringLRUCache<int> cache(
       1, 0, [&env]() { return env->NowSeconds(); });
@@ -95,9 +95,10 @@ TEST(ExpiringLRUCacheTest, MaxEntries) {
 
 TEST(ExpiringLRUCacheTest, LookupOrCompute) {
   // max_age of 0 means we should always compute.
-  uint64 num_compute_calls = 0;
+  uint64_t num_compute_calls = 0;
   tf_gcs_filesystem::ExpiringLRUCache<int>::ComputeFunc compute_func =
-      [&num_compute_calls](const string& key, int* value, TF_Status* status) {
+      [&num_compute_calls](const std::string& key, int* value,
+                           TF_Status* status) {
         *value = num_compute_calls;
         num_compute_calls++;
         return TF_SetStatus(status, TF_OK, "");
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
index 3b9650b7416315..f61208c7b4a174 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc
@@ -40,7 +40,6 @@ limitations under the License.
 #include "google/cloud/storage/client.h"
 #include "tensorflow/c/env.h"
 #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
-#include "tensorflow/c/logging.h"
 #include "tensorflow/c/tf_status.h"
 
 // Implementation of a filesystem for GCS environments.
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h
index 0060abc76699c3..3e972fa6292995 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h
@@ -33,7 +33,6 @@ limitations under the License.
 #include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
 #include "tensorflow/c/env.h"
-#include "tensorflow/c/logging.h"
 #include "tensorflow/c/tf_status.h"
 
 namespace tf_gcs_filesystem {
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc
index 4ad4a8ea1868f3..23645ed8e878bf 100644
--- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc
+++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc
@@ -16,7 +16,6 @@ limitations under the License.
 #include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
 
 #include
-#include <cctype>
 #include
 #include
 #include
@@ -25,6 +24,7 @@ limitations under the License.
 #include
 
 #include "absl/status/status.h"
+#include "absl/strings/ascii.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/synchronization/notification.h"
 #include "absl/time/time.h"
@@ -39,7 +39,7 @@ namespace tensorflow {
 namespace {
 
 absl::Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache,
-                       const string& filename, size_t offset, size_t n,
+                       const std::string& filename, size_t offset, size_t n,
                        std::vector<char>* out) {
   out->clear();
   out->resize(n, 0);
@@ -54,7 +54,7 @@ absl::Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache,
 }
 
 TEST(RamFileBlockCacheTest, IsCacheEnabled) {
-  auto fetcher = [](const string& filename, size_t offset, size_t n,
+  auto fetcher = [](const std::string& filename, size_t offset, size_t n,
                     char* buffer, TF_Status* status) -> int64_t {
     // Do nothing.
     TF_SetStatus(status, TF_OK, "");
@@ -73,14 +73,14 @@ TEST(RamFileBlockCacheTest, IsCacheEnabled) {
 
 TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) {
   int calls = 0;
-  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
+  auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n,
                           char* buffer, TF_Status* status) -> int64_t {
     calls++;
     memset(buffer, 'x', n);
     TF_SetStatus(status, TF_OK, "");
     return n;
   };
-  string filename = "file";
+  std::string filename = "file";
   tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
   std::vector<char> out;
 
@@ -101,12 +101,12 @@ TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) {
 }
 
 TEST(RamFileBlockCacheTest, PassThrough) {
-  const string want_filename = "foo/bar";
+  const std::string want_filename = "foo/bar";
   const size_t want_offset = 42;
   const size_t want_n = 1024;
   int calls = 0;
   auto fetcher = [&calls, want_filename, want_offset, want_n](
-                     const string& got_filename, size_t got_offset,
+                     const std::string& got_filename, size_t got_offset,
                      size_t got_n, char* buffer, TF_Status* status) -> int64_t {
     EXPECT_EQ(got_filename, want_filename);
     EXPECT_EQ(got_offset, want_offset);
@@ -143,7 +143,7 @@ TEST(RamFileBlockCacheTest, BlockAlignment) {
     buf.push_back(i);
   }
   // The fetcher just fetches slices of the buffer.
-  auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
+  auto fetcher = [&buf](const std::string& filename, size_t offset, size_t n,
                         char* buffer, TF_Status* status) -> int64_t {
     int64_t bytes_transferred;
     if (offset < buf.size()) {
@@ -191,8 +191,8 @@ TEST(RamFileBlockCacheTest, BlockAlignment) {
 TEST(RamFileBlockCacheTest, CacheHits) {
   const size_t block_size = 16;
   std::set<size_t> calls;
-  auto fetcher = [&calls, block_size](const string& filename, size_t offset,
-                                      size_t n, char* buffer,
+  auto fetcher = [&calls, block_size](const std::string& filename,
+                                      size_t offset, size_t n, char* buffer,
                                       TF_Status* status) -> int64_t {
     EXPECT_EQ(n, block_size);
     EXPECT_EQ(offset % block_size, 0);
@@ -202,7 +202,7 @@ TEST(RamFileBlockCacheTest, CacheHits) {
     TF_SetStatus(status, TF_OK, "");
     return n;
   };
-  const uint32 block_count = 256;
+  const uint32_t block_count = 256;
   tf_gcs_filesystem::RamFileBlockCache cache(
       block_size, block_count * block_size, 0, fetcher);
   std::vector<char> out;
@@ -225,7 +225,7 @@ TEST(RamFileBlockCacheTest, OutOfRange) {
   bool first_block = false;
   bool second_block = false;
   auto fetcher = [block_size, file_size, &first_block, &second_block](
-                     const string& filename, size_t offset, size_t n,
+                     const std::string& filename, size_t offset, size_t n,
                      char* buffer, TF_Status* status) -> int64_t {
     EXPECT_EQ(n, block_size);
     EXPECT_EQ(offset % block_size, 0);
@@ -269,8 +269,9 @@ TEST(RamFileBlockCacheTest, Inconsistent) {
   // where we expected complete blocks.
   const size_t block_size = 16;
   // This fetcher returns OK but only fills in one byte for any offset.
-  auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
-                              char* buffer, TF_Status* status) -> int64_t {
+  auto fetcher = [block_size](const std::string& filename, size_t offset,
+                              size_t n, char* buffer,
+                              TF_Status* status) -> int64_t {
     EXPECT_EQ(n, block_size);
     EXPECT_EQ(offset % block_size, 0);
     EXPECT_GE(n, 1);
@@ -293,8 +294,8 @@
 TEST(RamFileBlockCacheTest, LRU) {
   const size_t block_size = 16;
   std::list<size_t> calls;
-  auto fetcher = [&calls, block_size](const string& filename, size_t offset,
-                                      size_t n, char* buffer,
+  auto fetcher = [&calls, block_size](const std::string& filename,
+                                      size_t offset, size_t n, char* buffer,
                                       TF_Status* status) -> int64_t {
     EXPECT_EQ(n, block_size);
     EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
@@ -306,7 +307,7 @@ TEST(RamFileBlockCacheTest, LRU) {
     TF_SetStatus(status, TF_OK, "");
     return n;
   };
-  const uint32 block_count = 2;
+  const uint32_t block_count = 2;
   tf_gcs_filesystem::RamFileBlockCache cache(
       block_size, block_count * block_size, 0, fetcher);
   std::vector<char> out;
@@ -342,7 +343,7 @@ TEST(RamFileBlockCacheTest, LRU) {
 
 TEST(RamFileBlockCacheTest, MaxStaleness) {
   int calls = 0;
-  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
+  auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n,
                           char* buffer, TF_Status* status) -> int64_t {
     calls++;
     memset(buffer, 'x', n);
@@ -386,13 +387,13 @@ TEST(RamFileBlockCacheTest, MaxStaleness) {
 
 TEST(RamFileBlockCacheTest, RemoveFile) {
   int calls = 0;
-  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
+  auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n,
                           char* buffer, TF_Status* status) -> int64_t {
     calls++;
     char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
     if (offset > 0) {
       // The first block is lower case and all subsequent blocks are upper case.
-      c = toupper(c);
+      c = absl::ascii_toupper(c);
     }
     memset(buffer, c, n);
     TF_SetStatus(status, TF_OK, "");
     return n;
@@ -448,7 +449,7 @@ TEST(RamFileBlockCacheTest, RemoveFile) {
 
 TEST(RamFileBlockCacheTest, Prune) {
   int calls = 0;
-  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
+  auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n,
                           char* buffer, TF_Status* status) -> int64_t {
     calls++;
     memset(buffer, 'x', n);
@@ -458,7 +459,7 @@ TEST(RamFileBlockCacheTest, Prune) {
   std::vector<char> out;
   // Our fake environment is initialized with the current timestamp.
   std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
-  uint64 now = Env::Default()->NowSeconds();
+  uint64_t now = Env::Default()->NowSeconds();
   env->SetNowSeconds(now);
   tf_gcs_filesystem::RamFileBlockCache cache(
       8, 32, 1 /* max staleness */, fetcher,
@@ -487,7 +488,7 @@ TEST(RamFileBlockCacheTest, Prune) {
   // timestamp of `now` + 2, file "a" is stale because its first block is stale,
   // but file "b" is not stale yet. Thus, once the pruning thread wakes up (in
   // one second of wall time), it should remove "a" and leave "b" alone.
-  uint64 start = Env::Default()->NowSeconds();
+  uint64_t start = Env::Default()->NowSeconds();
   do {
     Env::Default()->SleepForMicroseconds(100000);
   } while (cache.CacheSize() == 24 && Env::Default()->NowSeconds() - start < 3);
@@ -515,7 +516,7 @@ TEST(RamFileBlockCacheTest, ParallelReads) {
   absl::BlockingCounter counter(callers);
   absl::Notification notification;
   auto fetcher = [&counter, &notification](
-                     const string& filename, size_t offset, size_t n,
+                     const std::string& filename, size_t offset, size_t n,
                      char* buffer, TF_Status* status) -> int64_t {
     if (counter.DecrementCount()) {
       notification.Notify();
@@ -560,7 +561,7 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
   int num_requests = 0;
   absl::Notification notification;
   auto fetcher = [&num_requests, &notification, block_size](
-                     const string& filename, size_t offset, size_t n,
+                     const std::string& filename, size_t offset, size_t n,
                      char* buffer, TF_Status* status) -> int64_t {
     EXPECT_EQ(n, block_size);
     EXPECT_EQ(offset, 0);
@@ -591,7 +592,7 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
 
 TEST(RamFileBlockCacheTest, Flush) {
   int calls = 0;
-  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
+  auto fetcher = [&calls](const std::string& filename, size_t offset, size_t n,
                           char* buffer, TF_Status* status) -> int64_t {
     calls++;
     memset(buffer, 'x', n);
diff --git a/tensorflow/c/experimental/grappler/grappler_test.cc b/tensorflow/c/experimental/grappler/grappler_test.cc
index 32ac04832551c1..205aeec55ebf8c 100644
--- a/tensorflow/c/experimental/grappler/grappler_test.cc
+++ b/tensorflow/c/experimental/grappler/grappler_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include
 #include
 #include
+#include <string>
 #include
 #include
 
@@ -70,11 +71,11 @@ TEST(Grappler, SuccessfulRegistration) {
   TF_ASSERT_OK(InitGraphPlugin(plugin_init));
 
   ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers(
-                std::set<string>{"Success"})
+                std::set<std::string>{"Success"})
                 .size(),
             1);
   ConfigList config = PluginGraphOptimizerRegistry::GetPluginConfigs(
-      true, std::set<string>{"Success"});
+      true, std::set<std::string>{"Success"});
   ASSERT_EQ(config.toggle_config["remapping"], RewriterConfig::OFF);
 }
 
@@ -95,7 +96,7 @@ TEST(Grappler, MultiplePluginRegistration) {
   TF_ASSERT_OK(InitGraphPlugin(plugin_init_0));
   TF_ASSERT_OK(InitGraphPlugin(plugin_init_1));
   ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers(
-                std::set<string>{"Device0", "Device1"})
+                std::set<std::string>{"Device0", "Device1"})
                 .size(),
             2);
 }
@@ -132,12 +133,12 @@ TEST(Grappler, OptimizeFuncNotSet) {
 
 TEST(TF_GrapplerItem, NodesToPreserve) {
   GrapplerItem item;
-  item.fetch = std::vector<string>{"Conv", "BiasAdd"};
-  std::unordered_set<string> nodes_preserved = item.NodesToPreserve();
+  item.fetch = std::vector<std::string>{"Conv", "BiasAdd"};
+  std::unordered_set<std::string> nodes_preserved = item.NodesToPreserve();
   TF_GrapplerItem* c_item = reinterpret_cast<TF_GrapplerItem*>(&item);
 
   int list_total_size = 0;
-  for (const string& s : nodes_preserved) {
+  for (const std::string& s : nodes_preserved) {
     list_total_size += s.size();
   }
 
@@ -158,20 +159,21 @@ TEST(TF_GrapplerItem, NodesToPreserve) {
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
 
   for (size_t i = 0; i < nodes_preserved.size(); ++i) {
-    EXPECT_EQ(nodes_preserved.find(string(static_cast<const char*>(values[i]),
-                                          lens[i])) != nodes_preserved.end(),
-              true);
+    EXPECT_EQ(
+        nodes_preserved.find(std::string(static_cast<const char*>(values[i]),
+                                         lens[i])) != nodes_preserved.end(),
+        true);
   }
   TF_DeleteStatus(status);
 }
 
 TEST(TF_GrapplerItem, FetchNodes) {
   GrapplerItem item;
-  item.fetch = std::vector<string>{"Conv", "BiasAdd"};
+  item.fetch = std::vector<std::string>{"Conv", "BiasAdd"};
   TF_GrapplerItem* c_item = reinterpret_cast<TF_GrapplerItem*>(&item);
 
   int list_total_size = 0;
-  for (const string& s : item.fetch) {
+  for (const std::string& s : item.fetch) {
     list_total_size += s.size();
   }
 
@@ -193,7 +195,7 @@ TEST(TF_GrapplerItem, FetchNodes) {
   for (size_t i = 0; i < item.fetch.size(); ++i) {
     EXPECT_EQ(item.fetch[i].size(), lens[i]) << i;
     EXPECT_EQ(item.fetch[i],
-              string(static_cast<const char*>(values[i]), lens[i]))
+              std::string(static_cast<const char*>(values[i]), lens[i]))
         << i;
   }
   TF_DeleteStatus(status);
@@ -307,13 +309,13 @@ TEST(TF_FunctionLibraryDefinition, LookUpOpDef) {
       TF_NewFunctionLibraryDefinition(g_buf, status);
 
   TF_LookUpOpDef(func, "Add", op_buf, status);
-  string actual_string(reinterpret_cast<const char*>(op_buf->data),
-                       op_buf->length);
+  std::string actual_string(reinterpret_cast<const char*>(op_buf->data),
+                            op_buf->length);
   ASSERT_EQ(TF_OK, TF_GetCode(status));
 
   const OpDef* expected_op_def;
   TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def));
-  string expected_serialized;
+  std::string expected_serialized;
   expected_op_def->SerializeToString(&expected_serialized);
   EXPECT_EQ(expected_serialized, actual_string);
   TF_DeleteBuffer(g_buf);
diff --git a/tensorflow/c/experimental/ops/gen/common/case_format.cc b/tensorflow/c/experimental/ops/gen/common/case_format.cc
index 82acc32f623fd8..52808e9900ca49 100644
--- a/tensorflow/c/experimental/ops/gen/common/case_format.cc
+++ b/tensorflow/c/experimental/ops/gen/common/case_format.cc
@@ -31,14 +31,14 @@ enum CaseFormatType {
   UPPER_SNAKE,
 };
 
-string FormatStringCase(const string &str, CaseFormatType to,
-                        const char delimiter = '_') {
+std::string FormatStringCase(const std::string& str, CaseFormatType to,
+                             const char delimiter = '_') {
   const bool from_snake =
       (str == absl::AsciiStrToUpper(str)) || (str == absl::AsciiStrToLower(str));
   const bool toUpper = (to == UPPER_CAMEL || to == UPPER_SNAKE);
   const bool toSnake = (to == LOWER_SNAKE || to == UPPER_SNAKE);
 
-  string result;
+  std::string result;
 
   bool inputStart = true;
   bool wordStart = true;
@@ -52,7 +52,7 @@ string FormatStringCase(const string &str, CaseFormatType to,
       wordStart = true;
       continue;
     }
-    if (!from_snake && isupper(c)) {
+    if (!from_snake && absl::ascii_isupper(c)) {
       wordStart = true;
     }
@@ -65,9 +65,9 @@
     const bool shouldCapIfSnake = toUpper;
     const bool shouldCapIfCamel = wordStart && (toUpper || !inputStart);
     if ((toSnake && shouldCapIfSnake) || (!toSnake && shouldCapIfCamel)) {
-      result += toupper(c);
+      result += absl::ascii_toupper(c);
     } else {
-      result += tolower(c);
+      result += absl::ascii_tolower(c);
    }
 
     // at this point we are no longer at the start of a word:
@@ -90,16 +90,16 @@
 // Public interface
 //
 
-string toLowerCamel(const string &s, const char delimiter) {
+std::string toLowerCamel(const std::string& s, const char delimiter) {
   return FormatStringCase(s, LOWER_CAMEL, delimiter);
 }
-string toLowerSnake(const string &s, const char delimiter) {
+std::string toLowerSnake(const std::string& s, const char delimiter) {
   return FormatStringCase(s, LOWER_SNAKE, delimiter);
 }
-string toUpperCamel(const string &s, const char delimiter) {
+std::string toUpperCamel(const std::string& s, const char delimiter) {
   return FormatStringCase(s, UPPER_CAMEL, delimiter);
 }
-string toUpperSnake(const string &s, const char delimiter) {
+std::string toUpperSnake(const std::string& s, const char delimiter) {
   return FormatStringCase(s, UPPER_SNAKE, delimiter);
 }
diff --git a/tensorflow/c/experimental/ops/gen/common/case_format.h b/tensorflow/c/experimental/ops/gen/common/case_format.h
index f8255f6aa21c17..880f286788e0a2 100644
--- a/tensorflow/c/experimental/ops/gen/common/case_format.h
+++ b/tensorflow/c/experimental/ops/gen/common/case_format.h
@@ -35,10 +35,10 @@ namespace generator {
 // "__OneTwo__" (in camel case) <==> "__ONE_TWO__" (in snake case)
 //
 // Note: performance not yet tested.
-string toLowerCamel(const string &s, const char delimiter = '_');
-string toLowerSnake(const string &s, const char delimiter = '_');
-string toUpperCamel(const string &s, const char delimiter = '_');
-string toUpperSnake(const string &s, const char delimiter = '_');
+std::string toLowerCamel(const std::string& s, const char delimiter = '_');
+std::string toLowerSnake(const std::string& s, const char delimiter = '_');
+std::string toUpperCamel(const std::string& s, const char delimiter = '_');
+std::string toUpperSnake(const std::string& s, const char delimiter = '_');
 
 }  // namespace generator
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/ops/gen/common/case_format_test.cc b/tensorflow/c/experimental/ops/gen/common/case_format_test.cc
index 302bcc42453169..e60473fca7896d 100644
--- a/tensorflow/c/experimental/ops/gen/common/case_format_test.cc
+++ b/tensorflow/c/experimental/ops/gen/common/case_format_test.cc
@@ -25,13 +25,13 @@ namespace {
 
 // For each test case, we manually construct the 4 variations in string case and
 // test all 16 conversions: from and to each of the 4 string case variations.
 struct Variations {
-  string lower_camel;
-  string lower_snake;
-  string upper_camel;
-  string upper_snake;
+  std::string lower_camel;
+  std::string lower_snake;
+  std::string upper_camel;
+  std::string upper_snake;
 };
 
-void TestSingleVariation(const string &str, Variations expected,
+void TestSingleVariation(const std::string& str, Variations expected,
                          char delimiter = '_') {
   EXPECT_EQ(expected.lower_camel, toLowerCamel(str, delimiter));
   EXPECT_EQ(expected.lower_snake, toLowerSnake(str, delimiter));
diff --git a/tensorflow/c/experimental/ops/gen/common/controller.cc b/tensorflow/c/experimental/ops/gen/common/controller.cc
index fb3e321714b108..ae3be0379ff254 100644
--- a/tensorflow/c/experimental/ops/gen/common/controller.cc
+++ b/tensorflow/c/experimental/ops/gen/common/controller.cc
@@ -43,7 +43,7 @@ Controller::Controller(PathConfig path_config, Env* env)
 }
 
 Controller::~Controller() { delete api_def_map_; }
 
-const void Controller::WriteFile(const string& file_path,
+const void Controller::WriteFile(const std::string& file_path,
                                  const SourceCode& code) const {
   TF_CHECK_OK(WriteStringToFile(env_, file_path, code.Render())) << file_path;
 }
@@ -60,8 +60,9 @@ void Controller::InitializeOpApi() {
   api_def_map_ = new ApiDefMap(op_list_);
   for (const auto& op : op_list_.op()) {
     for (const auto& dir : path_config_.api_dirs) {
-      const string file_name = absl::Substitute("api_def_$0.pbtxt", op.name());
-      const string file_path = io::JoinPath(dir, file_name);
+      const std::string file_name =
+          absl::Substitute("api_def_$0.pbtxt", op.name());
+      const std::string file_path = io::JoinPath(dir, file_name);
       if (env_->FileExists(file_path).ok()) {
         TF_CHECK_OK(api_def_map_->LoadFile(env_, file_path)) << file_path;
       } else {
diff --git a/tensorflow/c/experimental/ops/gen/common/controller.h b/tensorflow/c/experimental/ops/gen/common/controller.h
index e152efeb6d8f9f..c33891f963d7a6 100644
--- a/tensorflow/c/experimental/ops/gen/common/controller.h
+++ b/tensorflow/c/experimental/ops/gen/common/controller.h
@@ -32,7 +32,8 @@ class Controller {
  public:
   explicit Controller(PathConfig path_config, Env* env = Env::Default());
   virtual ~Controller();
-  const void WriteFile(const string& file_path, const SourceCode& code) const;
+  const void WriteFile(const std::string& file_path,
+                       const SourceCode& code) const;
   const std::vector<OpSpec>& GetModelOps() const;
 
  private:
diff --git a/tensorflow/c/experimental/ops/gen/common/path_config.cc b/tensorflow/c/experimental/ops/gen/common/path_config.cc
index 2ec57d67c9d6f7..74b4c3e327223d 100644
--- a/tensorflow/c/experimental/ops/gen/common/path_config.cc
+++ b/tensorflow/c/experimental/ops/gen/common/path_config.cc
@@ -24,9 +24,10 @@ limitations under the License.
 namespace tensorflow {
 namespace generator {
 
-PathConfig::PathConfig(const string& output_dir, const string& source_dir,
-                       const string& api_dir_list,
-                       const std::vector<string> op_names)
+PathConfig::PathConfig(const std::string& output_dir,
+                       const std::string& source_dir,
+                       const std::string& api_dir_list,
+                       const std::vector<std::string> op_names)
     : output_path(output_dir), op_names(op_names) {
   api_dirs = str_util::Split(api_dir_list, ",", str_util::SkipEmpty());
@@ -39,7 +40,7 @@ PathConfig::PathConfig(
   tf_root_dir = "tensorflow";
 
   // Prefix, e.g. "third_party" given root_dir "third_party/tensorflow/...."
-  std::vector<string> source_path_components =
+  std::vector<std::string> source_path_components =
       tensorflow::str_util::Split(source_dir, "/");
   auto source_tfroot_pos = std::find(source_path_components.begin(),
                                      source_path_components.end(),
                                      tf_root_dir);
@@ -51,7 +52,7 @@ PathConfig::PathConfig(
   }
 
   // TF subdir, e.g. "c/ops" given output_dir "blah/blah/tensorflow/c/ops"
-  std::vector<string> output_path_components =
+  std::vector<std::string> output_path_components =
       tensorflow::str_util::Split(output_dir, "/");
   auto output_tfroot_pos = std::find(output_path_components.begin(),
                                      output_path_components.end(),
                                      tf_root_dir);
diff --git a/tensorflow/c/experimental/ops/gen/common/path_config.h b/tensorflow/c/experimental/ops/gen/common/path_config.h
index ce29063be5f682..d47266f86e38ef 100644
--- a/tensorflow/c/experimental/ops/gen/common/path_config.h
+++ b/tensorflow/c/experimental/ops/gen/common/path_config.h
@@ -23,17 +23,18 @@ namespace tensorflow {
 namespace generator {
 
 struct PathConfig {
-  string output_path;
-  std::vector<string> op_names;
-  std::vector<string> api_dirs;
-  string tf_prefix_dir;
-  string tf_root_dir;
-  string tf_output_dir;
+  std::string output_path;
+  std::vector<std::string> op_names;
+  std::vector<std::string> api_dirs;
+  std::string tf_prefix_dir;
+  std::string tf_root_dir;
+  std::string tf_output_dir;
 
   explicit PathConfig() = default;
-  explicit PathConfig(const string &output_dir, const string &source_dir,
-                      const string &api_dir_list,
-                      const std::vector<string> op_names);
+  explicit PathConfig(const std::string& output_dir,
+                      const std::string& source_dir,
+                      const std::string& api_dir_list,
+                      const std::vector<std::string> op_names);
 };
 
 }  // namespace generator
diff --git a/tensorflow/c/experimental/ops/gen/common/source_code.cc b/tensorflow/c/experimental/ops/gen/common/source_code.cc
index 2b7bce6a263184..b12949cd1dc12b 100644
--- a/tensorflow/c/experimental/ops/gen/common/source_code.cc
+++ b/tensorflow/c/experimental/ops/gen/common/source_code.cc
@@ -25,20 +25,20 @@ limitations under the License.
 namespace tensorflow {
 namespace generator {
 
-string SourceCode::Render() const {
-  string code;
+std::string SourceCode::Render() const {
+  std::string code;
   for (const Line& line : lines_) {
-    absl::StrAppend(&code, string(line.indent * spaces_per_indent_, ' '),
+    absl::StrAppend(&code, std::string(line.indent * spaces_per_indent_, ' '),
                     line.text, "\n");
   }
   return code;
 }
 
-void SourceCode::AddLineWithIndent(const string& line) {
+void SourceCode::AddLineWithIndent(const std::string& line) {
   ValidateAndAddLine(current_indent_, line);
 }
 
-void SourceCode::AddLineWithoutIndent(const string& line) {
+void SourceCode::AddLineWithoutIndent(const std::string& line) {
   ValidateAndAddLine(0, line);
 }
 
@@ -48,7 +48,7 @@ void SourceCode::IncreaseIndent() { current_indent_++; }
 
 void SourceCode::DecreaseIndent() { current_indent_--; }
 
-void SourceCode::ValidateAndAddLine(int indent, const string& raw_line) {
+void SourceCode::ValidateAndAddLine(int indent, const std::string& raw_line) {
   absl::string_view line(raw_line);
   bool had_trailing_newline = absl::ConsumeSuffix(&line, "\n");
@@ -57,7 +57,8 @@ void SourceCode::ValidateAndAddLine(int indent, const string& raw_line) {
   } else if (had_trailing_newline) {
     LOG(WARNING) << "Superfluous trailing newline in '" << line << "'";
   }
-  lines_.push_back({indent, string(absl::StripTrailingAsciiWhitespace(line))});
+  lines_.push_back(
+      {indent, std::string(absl::StripTrailingAsciiWhitespace(line))});
 }
 
 }  // namespace generator
diff --git a/tensorflow/c/experimental/ops/gen/common/source_code.h b/tensorflow/c/experimental/ops/gen/common/source_code.h
index df1aa90acf7b8c..9fd7f7eec5e174 100644
--- a/tensorflow/c/experimental/ops/gen/common/source_code.h
+++ b/tensorflow/c/experimental/ops/gen/common/source_code.h
@@ -24,13 +24,13 @@ namespace generator {
 
 class SourceCode {
  public:
-  string Render() const;
+  std::string Render() const;
 
   void SetSpacesPerIndent(int spaces_per_indent) {
     spaces_per_indent_ = spaces_per_indent;
   }
 
-  void AddLineWithIndent(const string &line);
-  void AddLineWithoutIndent(const string &line);
+  void AddLineWithIndent(const std::string& line);
+  void AddLineWithoutIndent(const std::string& line);
   void AddBlankLine();
   void IncreaseIndent();
   void DecreaseIndent();
@@ -38,10 +38,10 @@ class SourceCode {
  private:
   struct Line {
     int indent;
-    string text;
+    std::string text;
   };
 
-  void ValidateAndAddLine(int indent_level, const string &raw_line);
+  void ValidateAndAddLine(int indent_level, const std::string& raw_line);
 
   int spaces_per_indent_ = 2;
   int current_indent_ = 0;
diff --git a/tensorflow/c/experimental/ops/gen/common/view_util.cc b/tensorflow/c/experimental/ops/gen/common/view_util.cc
index 388aa0646db82b..5ca9b59c9841e9 100644
--- a/tensorflow/c/experimental/ops/gen/common/view_util.cc
+++ b/tensorflow/c/experimental/ops/gen/common/view_util.cc
@@ -23,17 +23,20 @@ limitations under the License.
 namespace tensorflow {
 namespace generator {
 
-string Call(const string& object, const string& method,
-            std::vector<string> arguments, const char* oper) {
+std::string Call(const std::string& object, const std::string& method,
+                 std::vector<std::string> arguments, const char* oper) {
   return absl::Substitute("$0$1$2($3)", object, oper, method,
                           absl::StrJoin(arguments, ", "));
 }
 
-string Call(const string& function, std::vector<string> arguments) {
+std::string Call(const std::string& function,
+                 std::vector<std::string> arguments) {
   return absl::Substitute("$0($1)", function, absl::StrJoin(arguments, ", "));
 }
 
-string Quoted(const string& s) { return absl::Substitute("\"$0\"", s); }
+std::string Quoted(const std::string& s) {
+  return absl::Substitute("\"$0\"", s);
+}
 
 }  // namespace generator
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/ops/gen/common/view_util.h b/tensorflow/c/experimental/ops/gen/common/view_util.h
index 7ab437a90e4fd8..f23831ce8a07dd 100644
--- a/tensorflow/c/experimental/ops/gen/common/view_util.h
+++ b/tensorflow/c/experimental/ops/gen/common/view_util.h
@@ -22,10 +22,11 @@ limitations under the License.
 namespace tensorflow {
 namespace generator {
 
-string Call(const string &function, std::vector<string> arguments);
-string Call(const string &object, const string &method,
-            std::vector<string> arguments, const char *oper = "->");
-string Quoted(const string &s);
+std::string Call(const std::string& function,
+                 std::vector<std::string> arguments);
+std::string Call(const std::string& object, const std::string& method,
+                 std::vector<std::string> arguments, const char* oper = "->");
+std::string Quoted(const std::string& s);
 
 }  // namespace generator
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc
index 3fe5c059ca4e70..45e7b87069e361 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc
+++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc
@@ -52,11 +52,11 @@ SourceCode CppGenerator::SourceFileContents() const {
   return GenerateOneFile(cpp::RendererContext::kSource);
 }
 
-string CppGenerator::HeaderFileName() const {
+std::string CppGenerator::HeaderFileName() const {
   return io::JoinPath(path_config_.output_path, cpp_config_.unit + "_ops.h");
 }
 
-string CppGenerator::SourceFileName() const {
+std::string CppGenerator::SourceFileName() const {
   return io::JoinPath(path_config_.output_path, cpp_config_.unit + "_ops.cc");
 }
 
diff --git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h
index 0a7b08cd9b171f..b4d016e0ecca44 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h
+++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h
@@ -30,8 +30,8 @@ class CppGenerator {
   explicit CppGenerator(cpp::CppConfig cpp_config, PathConfig path_config);
   SourceCode HeaderFileContents() const;
   SourceCode SourceFileContents() const;
-  string HeaderFileName() const;
-  string SourceFileName() const;
+  std::string HeaderFileName() const;
+  std::string SourceFileName() const;
   void WriteHeaderFile() const;
   void WriteSourceFile() const;
 
diff --git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc
index f4a4d82bbce423..e1db2c9b8ce14b 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc
+++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc
@@ -30,12 +30,12 @@ namespace generator {
 namespace {
 
 TEST(CppGeneratorTest, typical_usage) {
-  string category = "testing";
-  string name_space = "tensorflow::ops";
-  string output_dir = "tensorflow/c/experimental/ops/gen/cpp/golden";
-  string source_dir = "tensorflow";
-  string api_dirs = "";
-  std::vector<string> ops = {
+  std::string category = "testing";
+  std::string name_space = "tensorflow::ops";
+  std::string output_dir = "tensorflow/c/experimental/ops/gen/cpp/golden";
+  std::string source_dir = "tensorflow";
+  std::string api_dirs = "";
+  std::vector<std::string> ops = {
       "Neg",        // Simple unary Op
       "MatMul",     // 2 inputs & attrs with default values
       "IdentityN",  // Variadic input+output
@@ -50,17 +50,19 @@ TEST(CppGeneratorTest, typical_usage) {
   CppGenerator generator(cpp_config, controller_config);
   Env *env = Env::Default();
 
-  string golden_dir = io::JoinPath(testing::TensorFlowSrcRoot(),
-                                   controller_config.tf_output_dir);
+  std::string golden_dir = io::JoinPath(testing::TensorFlowSrcRoot(),
+                                        controller_config.tf_output_dir);
 
-  string generated_header = generator.HeaderFileContents().Render();
-  string generated_source = generator.SourceFileContents().Render();
-  string expected_header;
-  string header_file_name = io::JoinPath(golden_dir, "testing_ops.h.golden");
+  std::string generated_header = generator.HeaderFileContents().Render();
+  std::string generated_source = generator.SourceFileContents().Render();
+  std::string expected_header;
+  std::string header_file_name =
+      io::JoinPath(golden_dir, "testing_ops.h.golden");
   TF_CHECK_OK(ReadFileToString(env, header_file_name, &expected_header));
 
-  string expected_source;
-  string source_file_name = io::JoinPath(golden_dir, "testing_ops.cc.golden");
+  std::string expected_source;
+  std::string source_file_name =
+      io::JoinPath(golden_dir, "testing_ops.cc.golden");
   TF_CHECK_OK(ReadFileToString(env, source_file_name, &expected_source));
 
   // Remove carriage returns (for Windows)
diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc
index 4f0e64e3b0f8eb..7c8231a71133f5 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc
+++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc
@@ -22,7 +22,7 @@ namespace tensorflow {
 namespace generator {
 namespace cpp {
 
-CppConfig::CppConfig(const string &category, const string &name_space)
+CppConfig::CppConfig(const std::string& category, const std::string& name_space)
     : category(category),
       unit(absl::AsciiStrToLower(category)),
       namespaces(absl::StrSplit(name_space, "::")) {}
diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h
index fa7571d98a1214..eec5888e17e7cf 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h
+++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h
@@ -24,13 +24,13 @@ namespace generator {
 namespace cpp {
 
 struct CppConfig {
-  string category;
-  string unit;
-  std::vector<string> namespaces;
+  std::string category;
+  std::string unit;
+  std::vector<std::string> namespaces;
 
   explicit CppConfig() = default;
-  explicit CppConfig(const string &category,
-                     const string &name_space = "tensorflow::ops");
+  explicit CppConfig(const std::string& category,
+                     const std::string& name_space = "tensorflow::ops");
 };
 
 }  // namespace cpp
diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc
index 1a685cac0c405c..50db08df1db988 100644
--- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc
+++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc
b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.cc @@ -27,10 +27,10 @@ namespace generator { namespace cpp { GuardRenderer::GuardRenderer(RendererContext context) : Renderer(context) { - string self_path = io::JoinPath(context_.path_config.tf_root_dir, - context_.path_config.tf_output_dir, - context_.cpp_config.unit + "_ops.h"); - string with_underscores(self_path); + std::string self_path = io::JoinPath(context_.path_config.tf_root_dir, + context_.path_config.tf_output_dir, + context_.cpp_config.unit + "_ops.h"); + std::string with_underscores(self_path); std::replace(with_underscores.begin(), with_underscores.end(), '/', '_'); std::replace(with_underscores.begin(), with_underscores.end(), '.', '_'); guard_ = toUpperSnake(with_underscores) + "_"; diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h index a45fe89a7a011c..bbd29e4620e2c2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h @@ -31,7 +31,7 @@ class GuardRenderer : public Renderer { void Close(); private: - string guard_; + std::string guard_; }; } // namespace cpp diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc index 38f31209f6da24..0ec8108bee7aaf 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.cc @@ -30,13 +30,13 @@ void IncludeRenderer::SelfHeader() { BlankLine(); } -string IncludeRenderer::SelfHeaderPath() const { +std::string IncludeRenderer::SelfHeaderPath() const { return io::JoinPath(context_.path_config.tf_root_dir, context_.path_config.tf_output_dir, context_.cpp_config.unit + "_ops.h"); } -void IncludeRenderer::Include(const string &tf_file_path) { +void IncludeRenderer::Include(const std::string& tf_file_path) { CodeLine("#include \"$0\"", io::JoinPath(context_.path_config.tf_prefix_dir, tf_file_path)); } diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h index e43715a62e45b0..4178f0da5beeb9 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h @@ -27,12 +27,12 @@ class IncludeRenderer : public Renderer { public: explicit IncludeRenderer(RendererContext context); - string SelfHeaderPath() const; + std::string SelfHeaderPath() const; void SelfHeader(); void Headers(); private: - void Include(const string &tf_file_path); + void Include(const std::string& tf_file_path); }; } // namespace cpp diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc index db28ab303ae5c6..b490cc7fe9e86a 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.cc @@ -26,7 +26,7 @@ NamespaceRenderer::NamespaceRenderer(RendererContext context) : Renderer(context) {} void NamespaceRenderer::Open() { - for (const string& ns : context_.cpp_config.namespaces) { + for (const std::string& ns : context_.cpp_config.namespaces) { CodeLine("namespace " + ns + " {"); } } diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc 
b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc index c459d239ca699f..63cb5f30eb1d9d 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.cc @@ -31,11 +31,11 @@ namespace tensorflow { namespace generator { namespace cpp { -string OpRenderer::Signature() const { - std::vector<string> args_with_default_val; - std::vector<string> args_without_default_val; +std::string OpRenderer::Signature() const { + std::vector<std::string> args_with_default_val; + std::vector<std::string> args_without_default_val; for (OpArgumentView const& argument : op_.AllArguments()) { - string text = argument.Declaration(); + std::string text = argument.Declaration(); if (context_.mode == RendererContext::kHeader) { absl::StrAppend(&text, argument.Initializer()); } @@ -45,7 +45,7 @@ string OpRenderer::Signature() const { args_without_default_val.push_back(text); } } - std::vector<string> arguments; + std::vector<std::string> arguments; arguments.reserve(args_without_default_val.size() + args_with_default_val.size()); arguments.insert(arguments.end(), diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h index 3360e14e672e3a..1ea161f55bdad9 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h @@ -34,7 +34,7 @@ class OpRenderer : public Renderer { OpView op_; OpCommentRenderer comment_; - string Signature() const; + std::string Signature() const; }; } // namespace cpp diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc index a9efb94335c0a6..6a608d759a3753 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc @@ -34,21 +34,21 @@ Renderer& Renderer::BlankLine() { return *this; } -Renderer& Renderer::CodeLine(const string& text) { +Renderer& Renderer::CodeLine(const std::string& text) { context_.code.AddLineWithoutIndent(text); return *this; } -Renderer& Renderer::CodeLines(const string& text) { +Renderer& Renderer::CodeLines(const std::string& text) { absl::string_view trimmed_text(text); str_util::RemoveWhitespaceContext(&trimmed_text); - for (const string& line : str_util::Split(trimmed_text, '\n')) { + for (const std::string& line : str_util::Split(trimmed_text, '\n')) { context_.code.AddLineWithoutIndent(line); } return *this; } -Renderer& Renderer::Statement(const string& text) { +Renderer& Renderer::Statement(const std::string& text) { if (absl::EndsWith(text, ";")) { LOG(WARNING) << "Superfluous terminating ';' in '" << text << "'"; context_.code.AddLineWithIndent(text); @@ -58,22 +58,22 @@ Renderer& Renderer::Statement(const string& text) { return *this; } -Renderer& Renderer::TFStatement(const string& text) { +Renderer& Renderer::TFStatement(const std::string& text) { return Statement(absl::Substitute("TF_RETURN_IF_ERROR($0)", text)); } -Renderer& Renderer::CommentLine(const string& text) { +Renderer& Renderer::CommentLine(const std::string& text) { context_.code.AddLineWithIndent(absl::StrCat("// ", text)); return *this; } -Renderer& Renderer::BlockOpen(const string& text) { +Renderer& Renderer::BlockOpen(const std::string& text) { context_.code.AddLineWithIndent(absl::StrCat(text, " {")); context_.code.IncreaseIndent(); return *this; } -Renderer& Renderer::BlockClose(const string& text) { +Renderer& Renderer::BlockClose(const 
std::string& text) { context_.code.DecreaseIndent(); context_.code.AddLineWithIndent(absl::StrCat("}", text)); return *this; diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h index b6168b196b35b2..f41923651f44e2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h @@ -34,7 +34,7 @@ class Renderer { // Append a line of source code, left-justified (not indented). // Use for preprocessors directives ("#include"), namespaces, etc. - Renderer &CodeLine(const string &text); + Renderer& CodeLine(const std::string& text); template <typename... Args> Renderer CodeLine(absl::string_view text, const Args &...args) { return CodeLine(absl::Substitute(text, args...)); } @@ -44,7 +44,7 @@ class Renderer { // Note: Trims leading/trailing whitespace including newlines, making this // method convenient for multiline raw strings. // Newlines ('\n') are allowed/expected. - Renderer &CodeLines(const string &text); + Renderer& CodeLines(const std::string& text); template <typename... Args> Renderer CodeLines(absl::string_view text, const Args &...args) { return CodeLines(absl::Substitute(text, args...)); } @@ -52,7 +52,7 @@ class Renderer { // Indent and append a C++ statement. // Note: do *not* include a trailing semicolon in the statement text. - Renderer &Statement(const string &text); + Renderer& Statement(const std::string& text); template <typename... Args> Renderer Statement(absl::string_view text, const Args &...args) { return Statement(absl::Substitute(text, args...)); } @@ -60,14 +60,14 @@ class Renderer { // Indent and append a call to a TF method returning a Status to check. // Note: do *not* include a trailing semicolon in the statement text. - Renderer &TFStatement(const string &text); + Renderer& TFStatement(const std::string& text); template <typename... Args> Renderer TFStatement(absl::string_view text, const Args &...args) { return TFStatement(absl::Substitute(text, args...)); } // Indent and append a C++ single-line style comment (using '//'). - Renderer &CommentLine(const string &text = ""); + Renderer& CommentLine(const std::string& text = ""); template <typename... Args> Renderer CommentLine(absl::string_view text, const Args &...args) { return CommentLine(absl::Substitute(text, args...)); } @@ -75,7 +75,7 @@ class Renderer { // Append a line of code which starts a new block: trailing with '{') and // indenting. - Renderer &BlockOpen(const string &text); + Renderer& BlockOpen(const std::string& text); template <typename... Args> Renderer BlockOpen(absl::string_view text, const Args &...args) { return BlockOpen(absl::Substitute(text, args...)); } @@ -83,7 +83,7 @@ class Renderer { // Append a line of code ending a block: unindenting and adding '}'. // Note: optional trailing text is often a comment, e.g. '// namespace xyz'. - Renderer &BlockClose(const string &text = ""); + Renderer& BlockClose(const std::string& text = ""); template <typename... Args> Renderer BlockClose(absl::string_view text, const Args &...args) { return BlockClose(absl::Substitute(text, args...)); diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc index eff654c5938160..6621d1aea2c217 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_test.cc @@ -57,7 +57,7 @@ TEST(Renderer, typical_usage) { SourceCode code; TestRenderer(code).Render(); - string expected = R"(// File level comment.
+ std::string expected = R"(// File level comment. #include "header.h" void TestFunction() { diff --git a/tensorflow/c/experimental/ops/gen/generate_cpp_main.cc b/tensorflow/c/experimental/ops/gen/generate_cpp_main.cc index 18a506942de5b7..cb922d0a06b7ae 100644 --- a/tensorflow/c/experimental/ops/gen/generate_cpp_main.cc +++ b/tensorflow/c/experimental/ops/gen/generate_cpp_main.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include "absl/log/check.h" diff --git a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc index c2bf61d785e6b2..417a0f26d70b92 100644 --- a/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc +++ b/tensorflow/c/experimental/saved_model/core/object_graph_traversal_test.cc @@ -26,8 +26,7 @@ namespace { SavedObjectGraph ParseSavedObjectGraph(absl::string_view text_proto) { SavedObjectGraph value; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), - &value)); + CHECK(tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &value)); return value; } diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc index 6250af6dba1359..1796c99dc79f17 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc @@ -178,8 +178,7 @@ tuple_value: { StructuredValue ParseStructuredValue(absl::string_view text_proto) { StructuredValue value; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), - &value)); + CHECK(tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &value)); return value; } diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc index 64ff3dab035e8c..c44bc832547dab 100644 --- a/tensorflow/c/kernels/bitcast_op_test.cc +++ b/tensorflow/c/kernels/bitcast_op_test.cc @@ -86,13 +86,13 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type, TEST(BitcastOpTest, TestUpcast) { Tensor int8_input(DT_UINT8, {8}); for (int i = 0; i < 8; i++) { - int8_input.vec<uint8>()(i) = static_cast<uint8>(1); + int8_input.vec<uint8_t>()(i) = static_cast<uint8_t>(1); } TestBitcastOp(&int8_input, DT_UINT64, TensorShape(), error::OK); } TEST(BitcastOpTest, TestDowncast) { - Tensor int64_input(static_cast<uint64>(1)); + Tensor int64_input(static_cast<uint64_t>(1)); TestBitcastOp(&int64_input, DT_UINT8, TensorShape({8}), error::OK); } diff --git a/tensorflow/c/kernels/histogram_summary_op.cc b/tensorflow/c/kernels/histogram_summary_op.cc index 7f34e5217c20ba..35340baa5749ce 100644 --- a/tensorflow/c/kernels/histogram_summary_op.cc +++ b/tensorflow/c/kernels/histogram_summary_op.cc @@ -151,13 +151,13 @@ void RegisterHistogramSummaryOpKernel() { TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() { if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary")) { RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - RegisterHistogramSummaryOpKernel(); - 
RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); + RegisterHistogramSummaryOpKernel(); RegisterHistogramSummaryOpKernel(); RegisterHistogramSummaryOpKernel(); RegisterHistogramSummaryOpKernel(); diff --git a/tensorflow/c/kernels/merge_summary_op.cc b/tensorflow/c/kernels/merge_summary_op.cc index 339267d094a554..ddbc3440d47dc1 100644 --- a/tensorflow/c/kernels/merge_summary_op.cc +++ b/tensorflow/c/kernels/merge_summary_op.cc @@ -50,7 +50,7 @@ void MergeSummaryOp_Delete(void* kernel) {} void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { tensorflow::Summary s; - std::unordered_set<tensorflow::string> tags; + std::unordered_set<std::string> tags; Safe_TF_StatusPtr status(TF_NewStatus()); for (int input_num = 0; input_num < TF_NumInputs(ctx); ++input_num) { TF_Tensor* input; @@ -74,7 +74,7 @@ void MergeSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { for (int v = 0; v < summary_in.value_size(); ++v) { // This tag is unused by the TensorSummary op, so no need to check for // duplicates. - const tensorflow::string& tag = summary_in.value(v).tag(); + const std::string& tag = summary_in.value(v).tag(); if ((!tag.empty()) && !tags.insert(tag).second) { std::ostringstream err; err << "Duplicate tag " << tag << " found in summary inputs "; diff --git a/tensorflow/c/kernels/summary_op.cc b/tensorflow/c/kernels/summary_op.cc index d158a429433c40..5688d00fa8fa7c 100644 --- a/tensorflow/c/kernels/summary_op.cc +++ b/tensorflow/c/kernels/summary_op.cc @@ -155,13 +155,13 @@ void RegisterScalarSummaryOpKernel() { TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() { if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) { RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); - RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); RegisterScalarSummaryOpKernel(); RegisterScalarSummaryOpKernel(); RegisterScalarSummaryOpKernel(); diff --git a/tensorflow/c/kernels/summary_op_test.cc b/tensorflow/c/kernels/summary_op_test.cc index da7b92f99491df..43de49bc39419d 100644 --- a/tensorflow/c/kernels/summary_op_test.cc +++ b/tensorflow/c/kernels/summary_op_test.cc @@ -45,13 +45,15 @@ class DummyDevice : public DeviceBase { }; // Helper for comparing output and expected output -void ExpectSummaryMatches(const Summary& actual, const string& expected_str) { +void ExpectSummaryMatches(const Summary& actual, + const std::string& expected_str) { Summary expected; ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected)); EXPECT_EQ(expected.DebugString(), actual.DebugString()); } -void TestScalarSummaryOp(Tensor* tags, Tensor* values, string expected_output, +void TestScalarSummaryOp(Tensor* tags, Tensor* values, + std::string expected_output, error::Code expected_code) { // Initialize node used to fetch OpKernel absl::Status status; diff --git a/tensorflow/c/kernels/tensor_shape_utils.cc b/tensorflow/c/kernels/tensor_shape_utils.cc index db0cfefedcbc86..ba54dc4eda4df9 100644 --- 
a/tensorflow/c/kernels/tensor_shape_utils.cc +++ b/tensorflow/c/kernels/tensor_shape_utils.cc @@ -26,7 +26,7 @@ namespace tensorflow { std::string ShapeDebugString(TF_Tensor* tensor) { // A TF_Tensor cannot have an unknown rank. CHECK_GE(TF_NumDims(tensor), 0); - tensorflow::string s = "["; + std::string s = "["; for (int i = 0; i < TF_NumDims(tensor); ++i) { if (i > 0) absl::StrAppend(&s, ","); int64_t dim = TF_Dim(tensor, i); diff --git a/tensorflow/c/logging.cc b/tensorflow/c/logging.cc deleted file mode 100644 index 13c9e6ac208a14..00000000000000 --- a/tensorflow/c/logging.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/c/logging.h" - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stringprintf.h" - -static ::tensorflow::string BuildMessage(const char* fmt, va_list args) { - ::tensorflow::string message; - ::tensorflow::strings::Appendv(&message, fmt, args); - return message; -} - -void TF_Log(TF_LogLevel level, const char* fmt, ...) { - if (level < TF_INFO || level > TF_FATAL) return; - va_list args; - va_start(args, fmt); - auto message = BuildMessage(fmt, args); - va_end(args); - switch (level) { - case TF_INFO: - LOG(INFO) << message; - break; - case TF_WARNING: - LOG(WARNING) << message; - break; - case TF_ERROR: - LOG(ERROR) << message; - break; - case TF_FATAL: - LOG(FATAL) << message; - break; - } -} - -void TF_VLog(int level, const char* fmt, ...) { - va_list args; - va_start(args, fmt); - auto message = BuildMessage(fmt, args); - va_end(args); - VLOG(level) << message; -} - -void TF_DVLog(int level, const char* fmt, ...) { - va_list args; - va_start(args, fmt); - auto message = BuildMessage(fmt, args); - va_end(args); - DVLOG(level) << message; -} diff --git a/tensorflow/c/logging.h b/tensorflow/c/logging.h deleted file mode 100644 index 9583777b661122..00000000000000 --- a/tensorflow/c/logging.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_C_LOGGING_H_ -#define TENSORFLOW_C_LOGGING_H_ - -#include "tensorflow/c/c_api_macros.h" - -// -------------------------------------------------------------------------- -// C API for tensorflow::Logging. 
- -#ifdef __cplusplus -extern "C" { -#endif - -typedef enum TF_LogLevel { - TF_INFO = 0, - TF_WARNING = 1, - TF_ERROR = 2, - TF_FATAL = 3, -} TF_LogLevel; - -TF_CAPI_EXPORT extern void TF_Log(TF_LogLevel level, const char* fmt, ...); -TF_CAPI_EXPORT extern void TF_VLog(int level, const char* fmt, ...); -TF_CAPI_EXPORT extern void TF_DVLog(int level, const char* fmt, ...); - -#ifdef __cplusplus -} -#endif - -#endif // TENSORFLOW_C_LOGGING_H_ diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h index 02a38e9b164eb3..c991fc1f74f2e8 100644 --- a/tensorflow/c/tf_datatype.h +++ b/tensorflow/c/tf_datatype.h @@ -65,6 +65,7 @@ typedef enum TF_DataType { TF_UINT4 = 30, TF_INT2 = 31, TF_UINT2 = 32, + TF_FLOAT4_E2M1FN = 33 // 2 exponent bits, 1 mantissa bit, finite-only } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding diff --git a/tensorflow/cc/framework/cc_op_gen_util.cc b/tensorflow/cc/framework/cc_op_gen_util.cc index 45c88283a47a6c..048378e68f4525 100644 --- a/tensorflow/cc/framework/cc_op_gen_util.cc +++ b/tensorflow/cc/framework/cc_op_gen_util.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/cc/framework/cc_op_gen_util.h" -#include #include #include #include @@ -29,6 +28,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -107,10 +107,10 @@ string ToGuard(absl::string_view path) { string guard; guard.reserve(path.size() + 1); // + 1 -> trailing _ for (const char c : path) { - if (c >= 'A' && c <= 'Z') { + if (absl::ascii_isupper(c)) { guard += c; - } else if (c >= 'a' && c <= 'z') { - guard += c + 'A' - 'a'; + } else if (absl::ascii_islower(c)) { + guard += absl::ascii_toupper(c); } else { guard += '_'; } @@ -306,7 +306,7 @@ string ToCamelCase(absl::string_view str) { } else if (c == joiner) { cap = true; } else if (cap) { - result += toupper(c); + result += absl::ascii_toupper(c); cap = false; } else { result += c; diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc index dcac1e4c0373bd..cd332ed1791849 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc @@ -42,7 +42,7 @@ namespace tensorflow { namespace cc_op { namespace { -string DefaultValue(OpDef_AttrDef attr) { +std::string DefaultValue(OpDef_AttrDef attr) { static const auto* attr_default_value_map = new absl::flat_hash_map{ @@ -80,19 +80,19 @@ string DefaultValue(OpDef_AttrDef attr) { return std::string(entry->second); } -string WriteClassFuzzDef(const OpInfo& op_info) { - string class_signature_str = absl::Substitute( +std::string WriteClassFuzzDef(const OpInfo& op_info) { + std::string class_signature_str = absl::Substitute( "class Fuzz$0 : public FuzzSession<$1> {\n", op_info.op_name, absl::StrJoin(op_info.graph_op_def.input_arg(), ", ", - [](string* out, const auto arg) { + [](std::string* out, const auto arg) { absl::StrAppend(out, "Tensor"); if (ArgIsList(arg)) absl::StrAppend(out, ", Tensor"); })); - string build_graph_body = absl::StrCat( + std::string build_graph_body = absl::StrCat( absl::StrJoin( op_info.graph_op_def.input_arg(), "", - [op_info](string* out, const OpDef_ArgDef arg) { + [op_info](std::string* out, const OpDef_ArgDef arg) { std::string type = "DT_UINT8"; if (arg.type() != DT_INVALID) { @@ 
-130,7 +130,7 @@ string WriteClassFuzzDef(const OpInfo& op_info) { } }), absl::StrJoin(op_info.graph_op_def.attr(), "", - [op_info](string* out, const OpDef_AttrDef attr) { + [op_info](std::string* out, const OpDef_AttrDef attr) { if (op_info.inferred_input_attrs.count(attr.name()) == 0 && !attr.has_default_value()) { @@ -139,22 +139,22 @@ string WriteClassFuzzDef(const OpInfo& op_info) { } })); - string constructor_call_str = absl::Substitute( + std::string constructor_call_str = absl::Substitute( " tensorflow::ops::$0(scope.WithOpName(\"output\")$1);\n", op_info.op_name, absl::StrCat( op_info.api_def.arg_order().empty() ? absl::StrJoin(op_info.api_def.in_arg(), "", - [](string* out, const auto api_def_arg) { + [](std::string* out, const auto api_def_arg) { strings::StrAppend(out, ", ", api_def_arg.name()); }) : absl::StrJoin(op_info.api_def.arg_order(), "", - [](string* out, const auto name) { + [](std::string* out, const auto name) { strings::StrAppend(out, ", ", name); }), absl::StrJoin(op_info.graph_op_def.attr(), "", - [op_info](string* out, const OpDef_AttrDef attr) { + [op_info](std::string* out, const OpDef_AttrDef attr) { if (op_info.inferred_input_attrs.count(attr.name()) == 0 && !attr.has_default_value()) { @@ -162,20 +162,20 @@ string WriteClassFuzzDef(const OpInfo& op_info) { } }))); - string fuzz_impl_signature_str = absl::Substitute( + std::string fuzz_impl_signature_str = absl::Substitute( " void FuzzImpl($0) final {\n", absl::StrJoin( op_info.graph_op_def.input_arg(), ", ", - [](string* out, const auto arg) { + [](std::string* out, const auto arg) { strings::StrAppend(out, "const Tensor& ", arg.name(), "_0"); if (ArgIsList(arg)) strings::StrAppend(out, ", const Tensor& ", arg.name(), "_1"); })); - string run_inputs_str = absl::Substitute( + std::string run_inputs_str = absl::Substitute( " RunInputs({$0});\n", absl::StrJoin(op_info.graph_op_def.input_arg(), ", ", - [](string* out, const auto arg) { + [](std::string* out, const auto arg) { if (ArgIsList(arg)) { strings::StrAppend( out, "{\"", arg.name(), "\", ", arg.name(), "_0}, ", @@ -186,7 +186,7 @@ string WriteClassFuzzDef(const OpInfo& op_info) { } })); - string fuzz_class_def = strings::StrCat( + std::string fuzz_class_def = strings::StrCat( class_signature_str, " void BuildGraph(const Scope& scope) override {\n", build_graph_body, constructor_call_str, " }\n", fuzz_impl_signature_str, run_inputs_str, " }\n", "};\n"); @@ -194,24 +194,24 @@ string WriteClassFuzzDef(const OpInfo& op_info) { return fuzz_class_def; } -string WriteFuzzTest(const OpInfo& op_info) { +std::string WriteFuzzTest(const OpInfo& op_info) { return absl::Substitute( "FUZZ_TEST_F(Fuzz$0, Fuzz).WithDomains($1);\n", op_info.op_name, absl::StrJoin(op_info.graph_op_def.input_arg(), ", ", - [](string* out, const auto arg) { + [](std::string* out, const auto arg) { absl::StrAppend(out, "AnyTensor()"); if (ArgIsList(arg)) absl::StrAppend(out, ", AnyTensor()"); })); } -string FuzzerFileStart() { - const string fuzz_namespace_begin = R"namespace( +std::string FuzzerFileStart() { + const std::string fuzz_namespace_begin = R"namespace( namespace tensorflow { namespace fuzzing { )namespace"; - const string fuzz_header = + const std::string fuzz_header = absl::StrCat(R"include(// This file is MACHINE GENERATED! Do not edit. 
#include "tensorflow/cc/ops/const_op.h" @@ -224,8 +224,8 @@ namespace fuzzing { return fuzz_header; } -string FuzzerFileEnd() { - const string fuzz_footer = R"footer( +std::string FuzzerFileEnd() { + const std::string fuzz_footer = R"footer( } // namespace fuzzing } // namespace tensorflow )footer"; @@ -258,7 +258,7 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { } // TODO(unda) : zero input ops - std::set zero_input_ops = {"Placeholder", "ImmutableConst"}; + std::set zero_input_ops = {"Placeholder", "ImmutableConst"}; if (zero_input_ops.find(op_info.op_name) != zero_input_ops.end()) { std::cout << "NOT fuzzing: " << op_info.graph_op_def.name() << " takes zero inputs.\n"; @@ -266,19 +266,19 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { } // TODO(unda, 253431636): constrained kernel - std::set constrained_kernel = {"Diag", - "DiagPart", - "GatherNd", - "GatherV2", - "QuantizeAndDequantizeV2", - "QuantizeAndDequantizeV3", - "QuantizeAndDequantizeV4", - "QuantizeAndDequantizeV4Grad", - "QuantizedConcat", - "QuantizedInstanceNorm", - "QuantizedReshape", - "ScatterNd", - "TensorScatterUpdate"}; + std::set constrained_kernel = {"Diag", + "DiagPart", + "GatherNd", + "GatherV2", + "QuantizeAndDequantizeV2", + "QuantizeAndDequantizeV3", + "QuantizeAndDequantizeV4", + "QuantizeAndDequantizeV4Grad", + "QuantizedConcat", + "QuantizedInstanceNorm", + "QuantizedReshape", + "ScatterNd", + "TensorScatterUpdate"}; // TODO(unda, b/253431636): constrained kernel if (constrained_kernel.find(op_info.op_name) != constrained_kernel.end()) { @@ -297,7 +297,7 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { } } - std::set unhandled_attr_types = { + std::set unhandled_attr_types = { "list(type)", "func", "float", "bool", "tensor", "list(string)", "list(bool)", "list(shape)", "list(tensor)", "list(attr)"}; @@ -321,7 +321,7 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { return true; } -string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable) { +std::string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable) { return absl::StrCat( FuzzerFileStart(), is_fuzzable ? WriteClassFuzzDef(op_info) : "", is_fuzzable ? WriteFuzzTest(op_info) : "", FuzzerFileEnd()); diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h index c11c9635d6d149..9dfee93e55e2e1 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h @@ -25,7 +25,7 @@ namespace tensorflow { namespace cc_op { // String with single fuzzer file content. 
-string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable); +std::string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable); // Do we have all we need to create a fuzzer bool OpFuzzingIsOk(const OpInfo& op_info); diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc index f4a1eb642557de..6da6e2af6c3445 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc @@ -39,8 +39,9 @@ namespace tensorflow { namespace cc_op { namespace { -void WriteAllFuzzers(string root_location, std::vector<string> api_def_dirs, - std::vector<string> op_names) { +void WriteAllFuzzers(std::string root_location, + std::vector<std::string> api_def_dirs, + std::vector<std::string> op_names) { OpList ops; absl::StatusOr<ApiDefMap> api_def_map = LoadOpsAndApiDefs(ops, false, api_def_dirs); @@ -60,7 +61,7 @@ void WriteAllFuzzers(string root_location, std::vector<string> api_def_dirs, continue; } - OpInfo op_info(op_def, *api_def, std::vector<string>()); + OpInfo op_info(op_def, *api_def, std::vector<std::string>()); status.Update(env->NewWritableFile( root_location + "/" + op_def.name() + "_fuzz.cc", &fuzz_file)); status.Update( @@ -87,9 +88,9 @@ int main(int argc, char* argv[]) { for (int i = 1; i < argc; ++i) { fprintf(stdout, "Arg %d = %s\n", i, argv[i]); } - std::vector<string> api_def_srcs = tensorflow::str_util::Split( + std::vector<std::string> api_def_srcs = tensorflow::str_util::Split( argv[2], ",", tensorflow::str_util::SkipEmpty()); - std::vector<string> op_names = tensorflow::str_util::Split( + std::vector<std::string> op_names = tensorflow::str_util::Split( argv[3], ",", tensorflow::str_util::SkipEmpty()); tensorflow::cc_op::WriteAllFuzzers(argv[1], api_def_srcs, op_names); return 0; diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 357515a5dccb00..f3c3fd045a3d6f 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -218,9 +218,9 @@ REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad); absl::Status CheckNumericsGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { - string message; + std::string message; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message)); - string err_msg = absl::StrCat( + std::string err_msg = absl::StrCat( "Not a number (NaN) or infinity (Inf) values detected in gradient. 
", message); grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg)); @@ -411,7 +411,7 @@ REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad); absl::Status MirrorPadGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string mode; + std::string mode; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad( scope, grad_inputs[0], op.input(1), mode)); @@ -424,7 +424,7 @@ REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad); absl::Status MirrorPadGradGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string mode; + std::string mode; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode)); grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode)); grad_outputs->push_back(NoGradient()); diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index 77e2a3bfc38476..deb90eec264ee7 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -95,7 +95,7 @@ absl::Status ScaleAndTranslateGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - string kernel_type; + std::string kernel_type; TF_RETURN_IF_ERROR( GetNodeAttr(op.node()->attrs(), "kernel_type", &kernel_type)); bool antialias; @@ -117,7 +117,7 @@ absl::Status CropAndResizeGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { DataType input_type; - string method; + std::string method; TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "method", &method)); TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "T", &input_type)); auto image_shape = Shape(scope, op.input(0)); diff --git a/tensorflow/cc/gradients/image_grad_test.cc b/tensorflow/cc/gradients/image_grad_test.cc index f7a39f39cfc42a..b77f5512237024 100644 --- a/tensorflow/cc/gradients/image_grad_test.cc +++ b/tensorflow/cc/gradients/image_grad_test.cc @@ -203,7 +203,7 @@ class ScaleAndTranslateGradTest : public ::testing::Test { template void MakeOp(const Tensor& x_data, const Input& y_shape, Input scale, - Input translation, const string& kernel_type, bool antialias, + Input translation, const std::string& kernel_type, bool antialias, Output* x, Output* y) { *x = Const(scope_, x_data); *y = ScaleAndTranslate(scope_, *x, y_shape, scale, translation, @@ -216,7 +216,7 @@ class ScaleAndTranslateGradTest : public ::testing::Test { template void TestScaleAndTranslate(const TensorShape x_shape, const int out_height, const int out_width, Input scale, - Input translation, const string& kernel_type, + Input translation, const std::string& kernel_type, bool antialias) { Tensor x_data = MakeData(x_shape); Output x, y; diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index bf6f509c21ee8a..c785af15f95447 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -1070,8 +1070,8 @@ absl::Status MatMulGradHelper(const Scope& scope, const bool is_batch, absl::Status MatMulGradCommon(const Scope& scope, const Operation& op, const bool is_batch, const std::vector& grad_inputs, - const string& attr_adj_x, - const string& attr_adj_y, + const std::string& attr_adj_x, + const std::string& attr_adj_y, std::vector* grad_outputs) { auto a = op.input(0); auto b = op.input(1); diff --git 
a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 34c0a8fd54b4c4..6309080492c1da 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -54,7 +54,7 @@ absl::Status SoftmaxGrad(const Scope& scope, const Operation& op, REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); bool IsZero(const Scope& scope, const Output& grad) { - string op_type_name = grad.op().node()->type_string(); + std::string op_type_name = grad.op().node()->type_string(); if (op_type_name == "ZerosLike" || op_type_name == "Zeros") { return true; } @@ -204,7 +204,7 @@ REGISTER_GRADIENT_OP("L2Loss", L2LossGrad); absl::Status BiasAddGradHelper(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { - string data_format; + std::string data_format; TF_RETURN_IF_ERROR( GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format)); auto dx_1 = @@ -218,9 +218,9 @@ REGISTER_GRADIENT_OP("BiasAdd", BiasAddGradHelper); absl::Status Conv2DGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { - string data_format; - string padding; - std::vector<int32> strides; + std::string data_format; + std::string padding; + std::vector<int32_t> strides; bool use_cudnn_on_gpu; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); @@ -245,10 +245,10 @@ REGISTER_GRADIENT_OP("Conv2D", Conv2DGrad); absl::Status MaxPoolGradHelper(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { - string data_format; - string padding; - std::vector<int32> strides; - std::vector<int32> ksize; + std::string data_format; + std::string padding; + std::vector<int32_t> strides; + std::vector<int32_t> ksize; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); @@ -265,8 +265,8 @@ REGISTER_GRADIENT_OP("MaxPool", MaxPoolGradHelper); absl::Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { - string data_format; - string padding; + std::string data_format; + std::string padding; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); @@ -283,10 +283,10 @@ REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); absl::Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { - std::vector<int32> ksize; - std::vector<int32> strides; - string padding; - string data_format; + std::vector<int32_t> ksize; + std::vector<int32_t> strides; + std::string padding; + std::string data_format; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); @@ -304,10 +304,10 @@ REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper); absl::Status AvgPoolGradHelper(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { - std::vector<int32> ksize; - std::vector<int32> strides; - string padding; - string data_format; + std::vector<int32_t> ksize; + std::vector<int32_t> strides; + std::string padding; + std::string data_format; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); @@ -325,10 +325,10 @@ 
REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper); absl::Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - std::vector ksize; - std::vector strides; - string padding; - string data_format; + std::vector ksize; + std::vector strides; + std::string padding; + std::string data_format; auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 56ac37e86b7168..1d23f9d87e2d7d 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include #include +#include #include #include "absl/log/log.h" @@ -70,7 +72,7 @@ absl::Status QueueRunner::Init(const QueueRunnerDef& queue_runner_def) { queue_runner_def.enqueue_op_name().begin(), queue_runner_def.enqueue_op_name().end()); size_t op_names_size = enqueue_op_names_.size(); - if (op_names_size > kint32max) { + if (op_names_size > std::numeric_limits::max()) { return absl::Status(absl::StatusCode::kInvalidArgument, "Enqueue ops to run cannot exceed kint32max"); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index ddcf94fbc07951..1722da0d390915 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -51,8 +51,8 @@ cc_library( "@local_xla//xla:status_macros", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/backends/cpu/runtime:convolution_lib", - "@local_xla//xla/backends/cpu/runtime:dot_lib", + "@local_xla//xla/backends/cpu/runtime:convolution_dims", + "@local_xla//xla/backends/cpu/runtime:dot_dims", "@local_xla//xla/backends/cpu/runtime:thunk_proto_cc", "@local_xla//xla/service/cpu:executable_proto_cc", "@local_xla//xla/tsl/platform:statusor", @@ -96,6 +96,7 @@ cc_library( ":thunk_proto_execution_deserializer", "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:allocator", + "//tensorflow/compiler/tf2xla:encoded_buffer_allocation_info", "//tensorflow/compiler/tf2xla:mlir_tf2xla", # fixdeps: keep "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", "//tensorflow/compiler/tf2xla:tf2xla_util", diff --git a/tensorflow/compiler/aot/aot_only_var_handle_op.cc b/tensorflow/compiler/aot/aot_only_var_handle_op.cc index 86666b073b0f71..f6293e0a2063bb 100644 --- a/tensorflow/compiler/aot/aot_only_var_handle_op.cc +++ b/tensorflow/compiler/aot/aot_only_var_handle_op.cc @@ -31,7 +31,7 @@ class XlaAotOnlyVarHandleOp : public XlaOpKernel { void Compile(XlaOpKernelContext* context) override; private: - string name_; + std::string name_; }; XlaAotOnlyVarHandleOp::XlaAotOnlyVarHandleOp(OpKernelConstruction* c) diff --git a/tensorflow/compiler/aot/benchmark.cc b/tensorflow/compiler/aot/benchmark.cc index 43b9c06418c2e1..ee4af4ca65a20f 100644 --- a/tensorflow/compiler/aot/benchmark.cc +++ b/tensorflow/compiler/aot/benchmark.cc @@ -37,10 +37,10 @@ namespace benchmark { // // TODO(b/33546473): Refactor tensorflow::Env::NowMicros() so that we can re-use // the implementation without pulling in all of the Env dependencies. 
-static uint64 NowMicros() { +static uint64_t NowMicros() { struct timeval tv; gettimeofday(&tv, nullptr); - return static_cast<uint64>(tv.tv_sec) * 1000000 + tv.tv_usec; + return static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec; } void DumpStatsToStdout(const Stats& stats) { diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index f4969b93353e42..783dc69b6ad5c2 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -42,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/aot/thunk_proto_execution_deserializer.h" #include "tensorflow/compiler/tf2xla/allocator.h" +#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "xla/backends/cpu/buffer_allocation_info.h" @@ -67,41 +69,35 @@ namespace { using xla::cpu::BufferAllocationInfo; -bool IsAlpha(char c) { - return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); -} - -bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); } - // Convert an XLA type into a C++ type. -absl::Status XLATypeToCpp(xla::PrimitiveType type, string* str) { +absl::Status XLATypeToCpp(xla::PrimitiveType type, std::string* str) { switch (type) { case xla::PRED: *str = "bool"; break; case xla::S8: - *str = "tensorflow::int8"; + *str = "int8_t"; break; case xla::S16: - *str = "tensorflow::int16"; + *str = "int16_t"; break; case xla::S32: - *str = "tensorflow::int32"; + *str = "int32_t"; break; case xla::S64: *str = "int64_t"; break; case xla::U8: - *str = "tensorflow::uint8"; + *str = "uint8_t"; break; case xla::U16: - *str = "tensorflow::uint16"; + *str = "uint16_t"; break; case xla::U32: - *str = "tensorflow::uint32"; + *str = "uint32_t"; break; case xla::U64: - *str = "tensorflow::uint64"; + *str = "uint64_t"; break; case xla::F32: *str = "float"; @@ -155,11 +151,11 @@ std::vector<BufferAllocationInfo> ExtractTempBufferAllocationInfos( // are used to generate methods for args and results. absl::Status AddRewritesForShape( int i, const xla::Shape& shape, - std::vector<std::pair<string, string>>* rewrites) { - string type; + std::vector<std::pair<std::string, std::string>>* rewrites) { + std::string type; TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type)); - std::vector<string> dim_vars; - string dim_sizes, indices; + std::vector<std::string> dim_vars; + std::string dim_sizes, indices; int count = 1; if (shape.dimensions().size() == 0 || (shape.dimensions().size() == 1 && shape.dimensions(0) == 1)) { @@ -168,8 +164,8 @@ absl::Status AddRewritesForShape( } else { for (int dim = 0; dim < shape.dimensions().size(); ++dim) { dim_vars.push_back(absl::StrCat("size_t dim", dim)); - dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); - indices += absl::StrCat("[dim", dim, "]"); + absl::StrAppend(&dim_sizes, "[", shape.dimensions(dim), "]"); + absl::StrAppend(&indices, "[dim", dim, "]"); count *= shape.dimensions(dim); } } @@ -190,8 +186,9 @@ absl::Status AddRewritesForShape( // TODO(toddw): If this becomes a problem, we should be able to change the // algorithm to O(N) by using a state machine, e.g. regexps or a real // text-templating mechanism. 
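To make the token-rewrite scheme concrete: AddRewritesForShape fills a token table ({{TYPE}}, {{DIM_VARS}}, {{DIM_SIZES}}, {{INDICES}}, {{COUNT}}, plus the positional {{I}}), and RewriteWithName, defined next, substitutes that table and a {{NAME}} token into a code template. A hedged sketch of the effect, using a made-up mini template rather than the real generated-header text:

// Illustrative only: the substitution performed by RewriteWithName below.
std::vector<std::pair<std::string, std::string>> rewrites = {
    {"{{I}}", "0"}, {"{{TYPE}}", "float"}};
std::string code = "{{TYPE}}* arg{{NAME}}_data() { return arg_data({{I}}); }";
// RewriteWithName("_x", code, rewrites) then yields:
//   "float* arg_x_data() { return arg_data(0); }"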
-string RewriteWithName(const string& name, string code, - const std::vector<std::pair<string, string>>& rewrites) { +std::string RewriteWithName( + const std::string& name, std::string code, + const std::vector<std::pair<std::string, std::string>>& rewrites) { absl::StrReplaceAll(rewrites, &code); absl::StrReplaceAll({{"{{NAME}}", name}}, &code); return code; } @@ -201,7 +198,7 @@ string RewriteWithName(const string& name, string code, absl::Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, const CompileResult& compile_result, - string* methods) { + std::string* methods) { const int num_args = ps.parameters_size(); // feed_size() + variable_size() is the maximum number of args as an // implementation may not create an argument for an unused variable. @@ -211,11 +208,11 @@ absl::Status GenArgMethods(const tf2xla::Config& config, config.variable_size(), ") and num_args(", num_args, ")"); } for (int i = 0; i < config.feed_size(); ++i) { - std::vector<std::pair<string, string>> rewrites; + std::vector<std::pair<std::string, std::string>> rewrites; TF_ASSIGN_OR_RETURN(xla::Shape shape, xla::Shape::FromProto(ps.parameters(i))); TF_RETURN_IF_ERROR(AddRewritesForShape(i, shape, &rewrites)); - const string code = R"( + const std::string code = R"( void set_arg{{NAME}}_data(const void* data) { set_arg_data({{I}}, data); } @@ -251,7 +248,7 @@ absl::Status GenArgMethods(const tf2xla::Config& config, // Generate methods for results (outputs). absl::Status GenResultMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, - string* methods) { + std::string* methods) { if (ps.result().element_type() != xla::TUPLE) { // The XlaCompiler we use to build the xla computation always generates a // tuple result, and we rely on this to simplify code generation. @@ -270,11 +267,11 @@ absl::Status GenResultMethods(const tf2xla::Config& config, ps.result().tuple_shapes_size(), ")"); } for (int i = 0; i < config.fetch_size(); ++i) { - std::vector<std::pair<string, string>> rewrites; + std::vector<std::pair<std::string, std::string>> rewrites; TF_ASSIGN_OR_RETURN(xla::Shape shape, xla::Shape::FromProto(ps.result().tuple_shapes(i))); TF_RETURN_IF_ERROR(AddRewritesForShape(i, shape, &rewrites)); - string code = R"( + std::string code = R"( {{TYPE}}* result{{NAME}}_data() { return static_cast<{{TYPE}}*>(result_data({{I}})); } @@ -307,14 +304,14 @@ absl::Status GenResultMethods(const tf2xla::Config& config, // Generate methods for variables. absl::Status GenVariableMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, - string* methods) { + std::string* methods) { const int num_args = ps.parameters_size(); for (int i = config.feed_size(); i < num_args; ++i) { - std::vector<std::pair<string, string>> rewrites; + std::vector<std::pair<std::string, std::string>> rewrites; TF_ASSIGN_OR_RETURN(xla::Shape shape, xla::Shape::FromProto(ps.parameters(i))); TF_RETURN_IF_ERROR(AddRewritesForShape(i, shape, &rewrites)); - const string code = R"( + const std::string code = R"( void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) { set_arg_data({{I}}, data); } @@ -348,7 +345,8 @@ absl::Status GenVariableMethods(const tf2xla::Config& config, } // Generate shape infos for args (inputs). -absl::Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { +absl::Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, + std::string* infos) { for (int i = 0; i < ps.parameters_size(); ++i) { const xla::ShapeProto& shape = ps.parameters(i); if (shape.element_type() == xla::TUPLE) { @@ -386,7 +384,7 @@ absl::Status GenArgShapeInfos(const xla::ProgramShapeProto& ps, string* infos) { // Generate shape infos for results. 
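Before moving on to the shape infos, here is a concrete instance of the result-method template above: for a fetch of element type F32 at tuple index 0, the {{TYPE}}/{{NAME}}/{{I}} rewrites would expand it along these lines in the generated header (a sketch, not verbatim generator output):

// Hypothetical expansion for fetch 0 of element type F32.
float* result0_data() { return static_cast<float*>(result_data(0)); }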
absl::Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, - string* infos) { + std::string* infos) { if (ps.result().element_type() != xla::TUPLE) { return absl::InternalError("codegen requires the XLA result to be a tuple"); } @@ -420,7 +418,7 @@ absl::Status GenResultShapeInfos(const xla::ProgramShapeProto& ps, // tf2xla::{Feed,Fetch,Variable}. Each feed or fetch name results in a C-style // string literal in the array, with nullptr terminating the array. template <class T> -string GenNameToIndexCode(const T& entries, bool generate) { +std::string GenNameToIndexCode(const T& entries, bool generate) { // No need for a static array if we're not supposed to generate the data. if (!generate) { return "{\n return nullptr;\n }"; } @@ -435,7 +433,7 @@ string GenNameToIndexCode(const T& entries, bool generate) { end = i; } // Emit string literals up to the last non-empty name. - string code = "{\n static const char* kNames[] = {"; + std::string code = "{\n static const char* kNames[] = {"; for (int i = 0; i < end; ++i) { if (i > 0) { code += ", "; @@ -704,13 +702,13 @@ absl::Status ExtendRewrites( if (HasThunkKind(aot_thunks->proto().thunk_sequence(), xla::cpu::ThunkProto::kSortThunk)) { runtime_specific_includes.push_back( - R"(#include "xla/service/cpu/runtime_key_value_sort.h")"); + R"(#include "xla/backends/cpu/runtime/sort_lib.h")"); } if (HasThunkKind(aot_thunks->proto().thunk_sequence(), xla::cpu::ThunkProto::kTopKThunk)) { runtime_specific_includes.push_back( - R"(#include "xla/service/cpu/runtime_topk.h")"); + R"(#include "xla/backends/cpu/runtime/topk_lib.h")"); } TF_ASSIGN_OR_RETURN( @@ -836,18 +834,19 @@ absl::Status ExtendRewrites( absl::Status GenerateHeader( const CodegenOpts& opts, const tf2xla::Config& config, const CompileResult& compile_result, const MetadataResult& metadata_result, - const EmbeddedConstantBuffers& embedded_constant_buffers, string* header) { + const EmbeddedConstantBuffers& embedded_constant_buffers, + std::string* header) { TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); absl::Span<const BufferAllocationInfo> buffer_infos = compile_result.aot->buffer_allocation_infos(); - const std::vector<int32> arg_index_table = + const std::vector<int32_t> arg_index_table = ::xla::cpu::CreateArgIndexTable(buffer_infos); - const std::vector<int32> result_index_table = + const std::vector<int32_t> result_index_table = ::xla::cpu::CreateResultIndexTable(buffer_infos); - std::vector<string> buffer_infos_as_strings = + std::vector<std::string> buffer_infos_as_strings = BufferAllocationInfosToCppExpression(buffer_infos); // Compute sizes and generate methods. @@ -856,11 +855,11 @@ absl::Status GenerateHeader( std::vector<BufferAllocationInfo> buffer_infos_for_temps = ExtractTempBufferAllocationInfos(buffer_infos); const xla::ProgramShapeProto& ps = compile_result.program_shape; - string methods_arg, methods_result, methods_variable; + std::string methods_arg, methods_result, methods_variable; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable)); - string arg_shape_infos, result_shape_infos; + std::string arg_shape_infos, result_shape_infos; TF_RETURN_IF_ERROR(GenArgShapeInfos(ps, &arg_shape_infos)); TF_RETURN_IF_ERROR( CheckEqual(ps.parameters_size(), arg_index_table.size(), @@ -880,19 +879,19 @@ absl::Status GenerateHeader( const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps); // Create rewrite strings for namespace start and end. 
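For example, with opts.namespaces = {"foo", "bar"}, the loops below produce the opening string "namespace foo {\nnamespace bar {\n\n" and the closing string "\n} // end namespace bar\n} // end namespace foo\n". A minimal standalone sketch of the same open-side logic, under those assumed namespace values:

#include <string>
#include <vector>
#include "absl/strings/str_cat.h"

// Mirrors the ns_start loop below; assumes namespaces = {"foo", "bar"}.
std::string NamespaceOpen(const std::vector<std::string>& namespaces) {
  std::string ns_start;
  for (const std::string& n : namespaces) {
    ns_start += absl::StrCat("namespace ", n, " {\n");
  }
  return ns_start + "\n";  // "namespace foo {\nnamespace bar {\n\n"
}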
- string ns_start; - for (const string& n : opts.namespaces) { + std::string ns_start; + for (const std::string& n : opts.namespaces) { ns_start += absl::StrCat("namespace ", n, " {\n"); } ns_start += "\n"; - string ns_end("\n"); + std::string ns_end("\n"); for (int i = opts.namespaces.size() - 1; i >= 0; --i) { - const string& n = opts.namespaces[i]; + const std::string& n = opts.namespaces[i]; ns_end += absl::StrCat("} // end namespace ", n, "\n"); } // Generate metadata. - const string arg_names_code = + const std::string arg_names_code = GenNameToIndexCode(config.feed(), opts.gen_name_to_index); auto variable_copy = config.variable(); @@ -901,12 +900,12 @@ absl::Status GenerateHeader( var.set_name(var.node_name()); } } - const string variable_names_code = + const std::string variable_names_code = GenNameToIndexCode(variable_copy, opts.gen_name_to_index); - const string result_names_code = + const std::string result_names_code = GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); - const string include_xla_data_proto = + const std::string include_xla_data_proto = opts.gen_program_shape ? R"(#include "xla/xla_data.pb.h")" : ""; @@ -1155,7 +1154,7 @@ class {{CLASS}} final : public tensorflow::{{COMPUTATION_CLASS_BASE}} { } // The replacement strategy is naive, but good enough for our purposes. - std::vector<std::pair<string, string>> rewrites = { + std::vector<std::pair<std::string, std::string>> rewrites = { {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, @@ -1194,10 +1193,10 @@ class {{CLASS}} final : public tensorflow::{{COMPUTATION_CLASS_BASE}} { return absl::OkStatus(); } -static string CreateUniqueIdentifier(const CodegenOpts& opts, - absl::string_view suffix) { - string result = "__tfcompile"; - for (const string& n : opts.namespaces) { +static std::string CreateUniqueIdentifier(const CodegenOpts& opts, + absl::string_view suffix) { + std::string result = "__tfcompile"; + for (const std::string& n : opts.namespaces) { absl::StrAppend(&result, "_", n); } @@ -1303,14 +1302,15 @@ absl::Status GenerateMetadata(const CodegenOpts& opts, return absl::OkStatus(); } -absl::Status ParseCppClass(const string& cpp_class, string* class_name, - std::vector<string>* namespaces) { +absl::Status ParseCppClass(const std::string& cpp_class, + std::string* class_name, + std::vector<std::string>* namespaces) { class_name->clear(); namespaces->clear(); if (cpp_class.empty()) { return errors::InvalidArgument("empty cpp_class: " + cpp_class); } - std::vector<string> parts = absl::StrSplit(cpp_class, "::"); + std::vector<std::string> parts = absl::StrSplit(cpp_class, "::"); if (parts.front().empty()) { // Allow a fully qualified name that starts with "::". parts.erase(parts.begin()); @@ -1343,11 +1343,11 @@ absl::Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) { // implementation-defined characters`. We disallow those here to give // better error messages, at the expense of being more restrictive than // the standard. 
- if (ident[0] != '_' && !IsAlpha(ident[0])) { + if (ident[0] != '_' && !absl::ascii_isalpha(ident[0])) { return errors::InvalidArgument("illegal leading char: ", msg); } for (size_t pos = 1; pos < ident.size(); ++pos) { - if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) { + if (ident[pos] != '_' && !absl::ascii_isalnum(ident[pos])) { return errors::InvalidArgument("illegal char: ", msg); } } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 77300b0fde4e3d..ff7d96720b4eba 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -32,14 +32,14 @@ namespace tfcompile { // and the generated metadata object file. struct CodegenOpts { // The name of the generated C++ class, wrapping the generated function. - string class_name; + std::string class_name; // Target triple for the architecture we're targeting. - string target_triple; + std::string target_triple; // Namespaces specifies a list of C++ namespaces to add to the generated // header. If empty, all symbols will be in the global namespace. - std::vector<string> namespaces; + std::vector<std::string> namespaces; // If true, generate name-to-index data for Lookup{Arg,Result}Index methods. bool gen_name_to_index = false; @@ -62,27 +62,27 @@ struct CodegenOpts { struct MetadataResult { // These are top level "extern C" declarations that are expected to be visible // wherever program_shape_access_shim is emitted. - std::vector<string> header_variable_decls; + std::vector<std::string> header_variable_decls; // program_shape_access_shim is a C++ expression that constructs the // xla::ProgramShapeProto instance for the CompileResult passed to // GenerateMetadata. - string program_shape_access_shim; + std::string program_shape_access_shim; // hlo_profile_printer_data_access_shim is a C++ expression that constructs // the xla::HloProfilePrinterData instance for the CompileResult passed to // GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a // C++ expression that evaluates to nullptr at runtime. // This is set only for AOT legacy. - string hlo_profile_printer_data_access_shim; + std::string hlo_profile_printer_data_access_shim; // cpu_executable_access_shim is a C++ expression that constructs // a protobuf required to construct a CpuExecutable. // This is set only for AOT thunks. - string cpu_executable_access_shim; + std::string cpu_executable_access_shim; // The contents of the object (".o") file. - string object_file_data; + std::string object_file_data; }; // Generates a set of constant buffers embedded into an object file. @@ -105,14 +105,16 @@ absl::Status GenerateMetadata(const CodegenOpts& opts, absl::Status GenerateHeader( const CodegenOpts& opts, const tf2xla::Config& config, const CompileResult& compile_result, const MetadataResult& metadata_result, - const EmbeddedConstantBuffers& embedded_constant_buffers, string* header); + const EmbeddedConstantBuffers& embedded_constant_buffers, + std::string* header); // ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` // components. The syntax is [[<namespace>::],...]<class_name>. This // mirrors the C++ syntax for referring to a class, where multiple namespaces // may precede the class name, separated by double-colons. -absl::Status ParseCppClass(const string& cpp_class, string* class_name, - std::vector<string>* namespaces); +absl::Status ParseCppClass(const std::string& cpp_class, + std::string* class_name, + std::vector<std::string>* namespaces); // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is // appended to error messages. 
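A minimal usage sketch of ParseCppClass, consistent with the ParseCppClassTest expectations in the test diff below (error handling beyond the status check elided):

std::string class_name;
std::vector<std::string> namespaces;
TF_CHECK_OK(ParseCppClass("::foo::bar::MyClass", &class_name, &namespaces));
// class_name == "MyClass"; namespaces == {"foo", "bar"}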
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index a4f18482db7f32..ec0f336d87f716 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -53,7 +53,7 @@ TEST(ValidateCppIdent, Simple) {
   TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
   TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
   // Make sure we didn't skip a valid letter or digit
-  string ident;
+  std::string ident;
   for (char c = 'a'; c <= 'z'; c++) {
     ident.append(1, c);
   }
@@ -78,18 +78,19 @@ TEST(ValidateCppIdent, Simple) {

 class ParseCppClassTest : public ::testing::Test {
  protected:
-  void ExpectOK(const string& cpp_class, const string& want_class_name,
-                const std::vector<string>& want_namespaces) {
-    string class_name;
-    std::vector<string> namespaces;
+  void ExpectOK(const std::string& cpp_class,
+                const std::string& want_class_name,
+                const std::vector<std::string>& want_namespaces) {
+    std::string class_name;
+    std::vector<std::string> namespaces;
     TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces));
     EXPECT_EQ(class_name, want_class_name);
     EXPECT_EQ(namespaces, want_namespaces);
   }

-  void ExpectFail(const string& cpp_class) {
-    string class_name;
-    std::vector<string> namespaces;
+  void ExpectFail(const std::string& cpp_class) {
+    std::string class_name;
+    std::vector<std::string> namespaces;
     EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces),
               absl::OkStatus())
         << cpp_class;
@@ -110,7 +111,7 @@ TEST_F(ParseCppClassTest, ParseOK) {
   ExpectOK("::_foo::MyClass", "MyClass", {"_foo"});
   ExpectOK("::_foo::_MyClass", "_MyClass", {"_foo"});
   // Make sure we didn't skip a valid letter or digit
-  string ident;
+  std::string ident;
   for (char c = 'a'; c <= 'z'; c++) {
     ident.append(1, c);
   }
@@ -143,10 +144,10 @@ TEST_F(ParseCppClassTest, ParseFail) {
 }

 static void CompareWithGoldenFile(
-    const string& tensorflow_relative_golden_file_name,
-    const string& expected_contents, bool ignore_cr) {
+    const std::string& tensorflow_relative_golden_file_name,
+    const std::string& expected_contents, bool ignore_cr) {
   // Get rid of all CR characters, we may be running under windows.
-  string sanitized_expected_contents(expected_contents);
+  std::string sanitized_expected_contents(expected_contents);
   if (ignore_cr) {
     sanitized_expected_contents.erase(
         std::remove(sanitized_expected_contents.begin(),
@@ -159,7 +160,7 @@ static void CompareWithGoldenFile(
   //   blaz test --test_strategy=local \
   //     "third_party/tensorflow/compiler/aot:codegen_test"
   const bool update_golden = false;
-  string golden_file_name =
+  std::string golden_file_name =
       GetDataDependencyFilepath(tensorflow_relative_golden_file_name);

   if (update_golden) {
@@ -167,7 +168,7 @@ static void CompareWithGoldenFile(
         WriteStringToFile(Env::Default(), golden_file_name, expected_contents));
   }

-  string golden_file_contents;
+  std::string golden_file_contents;
   TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name,
                                 &golden_file_contents));
   if (ignore_cr) {
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 7d0897829b98ca..48c92bf346926f 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -212,7 +212,7 @@ absl::Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
   return CompileXla(client, computation, aot_opts, compile_result);
 }

-static absl::Status ReadProtoFile(const string& fname,
+static absl::Status ReadProtoFile(const std::string& fname,
                                   protobuf::Message* proto) {
   if (absl::EndsWith(fname, ".pbtxt")) {
     return ReadTextProto(Env::Default(), fname, proto);
@@ -297,7 +297,7 @@ absl::Status Main(const MainFlags& flags) {
   TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
   TF_RETURN_IF_ERROR(ValidateConfig(config));
   if (flags.dump_fetch_nodes) {
-    std::set<string> nodes;
+    std::set<std::string> nodes;
     for (const tf2xla::Fetch& fetch : config.fetch()) {
       nodes.insert(fetch.id().node_name());
     }
@@ -368,7 +368,7 @@ absl::Status Main(const MainFlags& flags) {
       GenerateMetadata(codegen_opts, compile_result, &metadata_result));
   TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                        metadata_result.object_file_data));
-  string header;
+  std::string header;
   TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                     metadata_result, embedded_constant_buffers,
                                     &header));
diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h
index 303854f40ed88c..2a0418126b8aaf 100644
--- a/tensorflow/compiler/aot/compile.h
+++ b/tensorflow/compiler/aot/compile.h
@@ -38,7 +38,7 @@ struct CompileResult {
   // Contains object file and meta-info.
   std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
   xla::ProgramShapeProto program_shape;  // Static shape of args and results.
-  string entry_point;                    // Name of generated function.
+  std::string entry_point;               // Name of generated function.
   int pointer_size = 0;                  // Size of a pointer in bytes.
 };
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index b6a6e4cfc2c8d9..1626686ba465ad 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -41,9 +41,9 @@ using xla::llvm_ir::AsStringRef;

 static void AddEmbeddedProtocolBufferToLlvmModule(
     llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
-    absl::string_view unique_identifier, string* protobuf_array_symbol_name,
-    int64_t* protobuf_array_size) {
-  string protobuf_array_contents = proto.SerializeAsString();
+    absl::string_view unique_identifier,
+    std::string* protobuf_array_symbol_name, int64_t* protobuf_array_size) {
+  std::string protobuf_array_contents = proto.SerializeAsString();
   *protobuf_array_symbol_name =
       absl::StrCat(unique_identifier, "_protobuf_array_contents");
   *protobuf_array_size = protobuf_array_contents.size();
@@ -58,10 +58,10 @@ static void AddEmbeddedProtocolBufferToLlvmModule(
       protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
 }

-static string CreateCPPShimExpression(
+static std::string CreateCPPShimExpression(
     absl::string_view qualified_cpp_protobuf_name,
     absl::string_view protobuf_array_symbol_name, int64_t protobuf_array_size) {
-  string code =
+  std::string code =
       "[]() {\n"
       "    {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n"
       "    proto->ParseFromArray(&{{ARRAY_SYMBOL}}[0], {{ARRAY_SIZE}});\n"
@@ -77,7 +77,7 @@ static string CreateCPPShimExpression(
       });
 }

-static absl::StatusOr<string> CodegenModule(
+static absl::StatusOr<std::string> CodegenModule(
     llvm::TargetMachine* target_machine, std::unique_ptr<llvm::Module> module) {
   llvm::SmallVector<char, 0> stream_buffer;
   llvm::raw_svector_ostream ostream(stream_buffer);
@@ -91,7 +91,7 @@ static absl::StatusOr<string> CodegenModule(

   codegen_passes.run(*module);

-  return string(stream_buffer.begin(), stream_buffer.end());
+  return std::string(stream_buffer.begin(), stream_buffer.end());
 }

 static absl::StatusOr<std::unique_ptr<llvm::TargetMachine>>
@@ -124,9 +124,9 @@ absl::StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
   EmbeddedProtocolBuffers result;

   for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
-    string cpp_shim, cpp_variable_decl;
+    std::string cpp_shim, cpp_variable_decl;
     if (protobuf_to_embed.message) {
-      string protobuf_array_symbol_name;
+      std::string protobuf_array_symbol_name;
       int64_t protobuf_array_size;

       AddEmbeddedProtocolBufferToLlvmModule(
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h
index 0af4d4a3362f8c..aa3553f3b6a85b 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.h
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h
@@ -37,11 +37,11 @@ struct EmbeddedProtocolBuffers {
   struct CPPShim {
     // `expression` is a C++ expression that creates an instance of said
     // protocol buffer when executed.
-    string expression;
+    std::string expression;

     // `variable_decl` is an "extern C" array declaration that is used in
     // `expression`. It must be visible wherever `expression` is emitted.
-    string variable_decl;
+    std::string variable_decl;
   };

   // Each cpp_shim corresponds to one embedded protocol buffer.
@@ -50,20 +50,20 @@ struct EmbeddedProtocolBuffers {
   // The contents of the object (".o") file the protocol buffers are embbed in.
   // This needs to be linked in to any program that wants to execute any of the
   // expressions in `cpp_shims`.
-  string object_file_data;
+  std::string object_file_data;
 };

 // Describes a protocol buffer to embed into an object file.
struct ProtobufToEmbed { // `symbol_prefix` is prefix that is guaranteed to be unique across the binary // or DSO the generated object file will be linked into. - string symbol_prefix; + std::string symbol_prefix; // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ // namespace qualified) protocol buffer name. This is only used in // CPPShim::expression so relatively qualified names are fine as long as // they're valid wherever CPPShim::expression is emitted. - string qualified_cpp_protobuf_name; + std::string qualified_cpp_protobuf_name; // `message` is the protocol buffer to be embedded. It is allowed to be // nullptr, in which case the generated C++ shim expression is just `nullptr`, diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index 9a3f2900dbafe4..5d0f93f7d67b88 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -27,27 +27,27 @@ namespace tfcompile { // Flags for the tfcompile binary. See *.cc file for descriptions. struct MainFlags { - string graph; - string debug_info; - string debug_info_path_begin_marker; - string config; + std::string graph; + std::string debug_info; + std::string debug_info_path_begin_marker; + std::string config; bool dump_fetch_nodes = false; - string target_triple; - string target_cpu; - string target_features; - string entry_point; - string cpp_class; - string out_function_object; - string out_metadata_object; - string out_header; - string out_constant_buffers_object; - string out_session_module; - string mlir_components; + std::string target_triple; + std::string target_cpu; + std::string target_features; + std::string entry_point; + std::string cpp_class; + std::string out_function_object; + std::string out_metadata_object; + std::string out_header; + std::string out_constant_buffers_object; + std::string out_session_module; + std::string mlir_components; bool experimental_quantize = false; // Sanitizer pass options bool sanitize_dataflow = false; - string sanitize_abilists_dataflow; + std::string sanitize_abilists_dataflow; // C++ codegen options bool gen_name_to_index = false; diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 8caeec32b7bc5e..e2509d653974e7 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -63,6 +63,7 @@ def _tfcompile_model_library_rule_impl(ctx): "--xla_cpu_fast_math_honor_functions=false " + "--xla_cpu_fast_math_honor_division=false " + "--xla_cpu_enable_fast_min_max=true " + + "--xla_cpu_experimental_ynn_fusion_type= " + additional_xla_flags + " " + "$${XLA_FLAGS:-}' "), "CUDA_VISIBLE_DEVICES": "", @@ -335,11 +336,10 @@ def _tf_library( ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually # needed. 
+ "@local_xla//xla/backends/cpu/runtime:sort_lib", + "@local_xla//xla/backends/cpu/runtime:topk_lib", "@local_xla//xla/service/cpu:runtime_conv2d", - "@local_xla//xla/service/cpu:runtime_custom_call_status", - "@local_xla//xla/service/cpu:runtime_key_value_sort", "@local_xla//xla/service/cpu:runtime_matmul", - "@local_xla//xla/service/cpu:runtime_topk", "@local_xla//xla/service/cpu:runtime_single_threaded_conv2d", "@local_xla//xla/service/cpu:runtime_single_threaded_matmul", "@eigen_archive//:eigen3", diff --git a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc index d2ced20a8d5eec..0c4edc85f99d19 100644 --- a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc +++ b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc @@ -28,8 +28,8 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" -#include "xla/backends/cpu/runtime/convolution_lib.h" -#include "xla/backends/cpu/runtime/dot_lib.h" +#include "xla/backends/cpu/runtime/convolution_dims.h" +#include "xla/backends/cpu/runtime/dot_dims.h" #include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/layout_util.h" #include "xla/service/cpu/executable.pb.h" @@ -594,35 +594,46 @@ ThunkProtoExecutionDeserializer::GetSortThunkRunImpl( std::vector buffers_to_sort; buffers_to_sort.reserve(sort_thunk.inputs_shapes_size()); - std::vector values_primitive_type_size_in_bytes; - values_primitive_type_size_in_bytes.reserve(sort_thunk.inputs_shapes_size()); + std::vector primitive_sizes; + primitive_sizes.reserve(sort_thunk.inputs_shapes_size()); for (const auto& buffer_proto : sort_thunk.inputs_shapes()) { buffers_to_sort.push_back( - absl::StrCat("reinterpret_cast(", + absl::StrCat("reinterpret_cast(", GetBufferAllocationString(buffer_proto.slice()), ")")); - values_primitive_type_size_in_bytes.push_back( - xla::ShapeUtil::ByteSizeOfPrimitiveType( - buffer_proto.shape().element_type())); + primitive_sizes.push_back(xla::ShapeUtil::ByteSizeOfPrimitiveType( + buffer_proto.shape().element_type())); } absl::string_view sort_thunk_invocation_format = R"( // Sort Thunk { - std::vector values = { + std::vector values = { {{BUFFERS_TO_SORT}} }; - std::vector values_primitive_type_size_in_bytes = { + std::vector primitive_sizes = { {{VALUES_PRIMITIVE_TYPE_SIZE_IN_BYTES}} }; - __xla_cpu_runtime_KeyValueSort( - {{HIGHER_DIMENSIONS}}, {{SORT_DIMENSION_ELEMENTS}}, {{LOWER_DIMENSIONS}}, - values.data(), - int32_t(values.size()), - values_primitive_type_size_in_bytes.data(), - /*is_stable=*/{{IS_STABLE}}, - reinterpret_cast(run_options), - /*prof_counters=*/nullptr, - reinterpret_cast({{SORT_FUNCTION_NAME}})); + // Type alias compatible with `FunctionLibrary::Comparator`. 
+ using Comparator = void(bool* result, const void* run_options, + const void** params, const void* buffer_table, + const void* status, const void* prof_counters); + Comparator* comparator = reinterpret_cast( + {{SORT_FUNCTION_NAME}}); + + absl::AnyInvocable less_than = + [comparator](const void** data) { + bool result; + (*comparator)(&result, nullptr, data, nullptr, nullptr, nullptr); + return result; + }; + + xla::cpu::internal::SortInplace( + { + {{HIGHER_DIMENSIONS}}, + {{SORT_DIMENSION_ELEMENTS}}, + {{LOWER_DIMENSIONS}} + }, + values, primitive_sizes, {{IS_STABLE}}, &less_than); })"; TF_ASSIGN_OR_RETURN( @@ -660,7 +671,7 @@ ThunkProtoExecutionDeserializer::GetSortThunkRunImpl( {"{{SORT_FUNCTION_NAME}}", sort_thunk.comparator_name()}, {"{{BUFFERS_TO_SORT}}", absl::StrJoin(buffers_to_sort, ", ")}, {"{{VALUES_PRIMITIVE_TYPE_SIZE_IN_BYTES}}", - absl::StrJoin(values_primitive_type_size_in_bytes, ", ")}, + absl::StrJoin(primitive_sizes, ", ")}, {"{{IS_STABLE}}", sort_thunk.is_stable() ? "true" : "false"}, }); } @@ -677,7 +688,7 @@ ThunkProtoExecutionDeserializer::GetTopKThunkRunImpl( absl::string_view topk_thunk_invocation_format = R"( // TopK Thunk { - __xla_cpu_runtime_TopKF32({{BATCH_SIZE}}, {{INPUT_SIZE}}, {{K}}, + ::xla::cpu::internal::TopK({{BATCH_SIZE}}, {{INPUT_SIZE}}, {{K}}, reinterpret_cast({{VALUES_PTR}}), reinterpret_cast({{OUTPUT_PTR}}), reinterpret_cast({{INDICES_PTR}})); diff --git a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h index 1e5e47f140020e..a5adeff3917b46 100644 --- a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h +++ b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "xla/backends/cpu/runtime/convolution_lib.h" +#include "xla/backends/cpu/runtime/convolution_dims.h" #include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/debug_options_flags.h" #include "xla/service/cpu/executable.pb.h" diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index c65bb6c44b1079..7c1772c084750c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -654,6 +654,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_xla//xla:future", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla:util", @@ -662,7 +663,6 @@ cc_library( "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt:pjrt_common", "@local_xla//xla/pjrt:pjrt_executable", - "@local_xla//xla/pjrt:pjrt_future", "@local_xla//xla/service:executable", "@local_xla//xla/service:maybe_owning_device_memory", "@local_xla//xla/service:shaped_buffer", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index bed899bfed2f3e..31f1aeedd9850e 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -132,7 +132,7 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node, if (merged_output.node() == nullptr) { Output new_output(new_node, oidx); if (debugging_opts.print_outputs) { - string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + std::string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0"; ops::Print print_op(s.WithOpName("print_", oidx) .WithDevice(cpu_device) .WithAssignedDevice(cpu_device), @@ -298,7 +298,8 @@ 
absl::StatusOr<Node*> ReplaceFunctionCallWithPartitionedCall(
     const GraphOptimizationPassOptions& options,
     const FunctionLibraryDefinition& flib_def, Node* n, Graph* g,
     const NameAttrList& func, const Scope& root) {
-  string config_string = options.session_options->config.SerializeAsString();
+  std::string config_string =
+      options.session_options->config.SerializeAsString();

   int input_count = absl::c_count_if(
       n->in_edges(), [](const Edge* e) { return !e->IsControlEdge(); });
@@ -346,7 +347,8 @@ absl::StatusOr<Node*> ReplaceFunctionCallWithPartitionedCall(

 absl::StatusOr<jit::DeviceId> InferDeviceForCluster(
     jit::DeviceInfoCache* device_info_cache, Node* n,
-    const string& function_name, const FunctionLibraryDefinition& flib_def) {
+    const std::string& function_name,
+    const FunctionLibraryDefinition& flib_def) {
   const FunctionDef* func_def = flib_def.Find(function_name);
   TF_RET_CHECK(func_def) << "Could not find " << function_name;

@@ -485,7 +487,8 @@ absl::Status ReplaceNodeWithXlaCompileAndXlaRun(
     requires_compilation = true;
   }

-  string device_name_str = string(device_info_cache->GetNameFor(device));
+  std::string device_name_str =
+      std::string(device_info_cache->GetNameFor(device));

   absl::Status status;
   Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index c3b5ba5521ee65..6b90557df4b86f 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -85,8 +85,8 @@ absl::Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib,
   return absl::OkStatus();
 }

-absl::Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
-                                   const string& node_name,
+absl::Status MakeXlaCompiledKernel(Graph* graph, const std::string& callee_name,
+                                   const std::string& node_name,
                                    int num_constant_args, int num_resource_args,
                                    Node** result) {
   NodeDef call_node;
@@ -99,14 +99,16 @@ absl::Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
   return absl::OkStatus();
 }

-absl::Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
-                                   const string& node_name, Node** result) {
+absl::Status MakeXlaCompiledKernel(Graph* graph, const std::string& callee_name,
+                                   const std::string& node_name,
+                                   Node** result) {
   return MakeXlaCompiledKernel(graph, callee_name, node_name,
                                /*num_constant_args=*/0, /*num_resource_args=*/0,
                                result);
 }

-Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) {
+Node* MakeWrite(const Scope& scope, Output value_to_write,
+                const std::string& id) {
   Output var_handle = ops::VarHandleOp(scope.WithOpName("Var_" + id), DT_FLOAT,
                                        TensorShape({}));
   ops::AssignVariableOp assign_op(scope.WithOpName("Assignee_" + id),
@@ -114,12 +116,13 @@ Node* MakeWrite(const Scope& scope, Output value_to_write, const string& id) {
   return assign_op.operation.node();
 }

-Node* MakeWrite(const Scope& scope, const string& id) {
+Node* MakeWrite(const Scope& scope, const std::string& id) {
   return MakeWrite(
       scope, ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f), id);
 }

-FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
+FunctionDefLibrary CreateFunctionDefLibWithConstFunction(
+    const std::string& name) {
   FunctionDefLibrary fdef_lib;
   FunctionDef func = FunctionDefHelper::Create(
       /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc
index bb8dce848cfbc9..4164efc65a8f4c 100644
--- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc
+++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc
@@ -36,19 +36,21 @@ class CloneConstantsForBetterClusteringPassImpl {

  private:
   absl::Status CloneSmallConstantInputs(
-      const absl::flat_hash_set<string>& name_set, Node* n);
-  string GenerateUniqueName(const absl::flat_hash_set<string>& name_set,
-                            absl::string_view prefix);
-  absl::StatusOr<Node*> CloneNode(const absl::flat_hash_set<string>& name_set,
-                                  Node* n);
+      const absl::flat_hash_set<std::string>& name_set, Node* n);
+  std::string GenerateUniqueName(
+      const absl::flat_hash_set<std::string>& name_set,
+      absl::string_view prefix);
+  absl::StatusOr<Node*> CloneNode(
+      const absl::flat_hash_set<std::string>& name_set, Node* n);

   Graph* graph_;
   int unique_name_counter_;
 };

-string CloneConstantsForBetterClusteringPassImpl::GenerateUniqueName(
-    const absl::flat_hash_set<string>& name_set, absl::string_view prefix) {
-  string candidate;
+std::string CloneConstantsForBetterClusteringPassImpl::GenerateUniqueName(
+    const absl::flat_hash_set<std::string>& name_set,
+    absl::string_view prefix) {
+  std::string candidate;
   do {
     candidate = absl::StrCat(prefix, "/clone_", unique_name_counter_++);
   } while (name_set.contains(candidate));
@@ -56,7 +58,7 @@ string CloneConstantsForBetterClusteringPassImpl::GenerateUniqueName(
 }

 absl::StatusOr<Node*> CloneConstantsForBetterClusteringPassImpl::CloneNode(
-    const absl::flat_hash_set<string>& name_set, Node* n) {
+    const absl::flat_hash_set<std::string>& name_set, Node* n) {
   NodeDef new_in_def = n->def();
   new_in_def.clear_input();
   new_in_def.set_name(GenerateUniqueName(name_set, new_in_def.name()));
@@ -112,7 +114,7 @@ bool IsInPlaceOp(absl::string_view op_name) {

 absl::Status
 CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs(
-    const absl::flat_hash_set<string>& name_set, Node* n) {
+    const absl::flat_hash_set<std::string>& name_set, Node* n) {
   std::vector<const Edge*> in_edges;
   // Get the edges and sort them so we clone in a deterministic order.
   absl::c_copy(n->in_edges(), std::back_inserter(in_edges));
@@ -142,7 +144,7 @@ CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs(
 }

 absl::Status CloneConstantsForBetterClusteringPassImpl::Run() {
-  absl::flat_hash_set<string> name_set;
+  absl::flat_hash_set<std::string> name_set;
   absl::c_transform(graph_->nodes(), std::inserter(name_set, name_set.begin()),
                     [](Node* n) { return n->name(); });
   std::vector<Node*> nodes;
diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.cc b/tensorflow/compiler/jit/cluster_scoping_pass.cc
index e70be48f0b7341..20a3d98be1d0f2 100644
--- a/tensorflow/compiler/jit/cluster_scoping_pass.cc
+++ b/tensorflow/compiler/jit/cluster_scoping_pass.cc
@@ -51,8 +51,8 @@ class ClusterScopingPassImpl {
   size_t unique_scope_id_;
 };

-std::optional<string> GetXlaInternalScope(Node* node) {
-  string scope;
+std::optional<std::string> GetXlaInternalScope(Node* node) {
+  std::string scope;
   if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) {
     return scope;
   }
@@ -85,8 +85,8 @@ void SetXlaInternalScope(Node* node, absl::string_view scope) {
 //   Node_X (scope "stage") -> Stage
 //
 void AddOrAppendXlaInternalScope(Node* node, absl::string_view suffix) {
-  string updated_scope;
-  std::optional<string> cur_scope = GetXlaInternalScope(node);
+  std::string updated_scope;
+  std::optional<std::string> cur_scope = GetXlaInternalScope(node);
   if (cur_scope == std::nullopt) {
     updated_scope = std::string(suffix);
   } else {
@@ -96,7 +96,7 @@ void AddOrAppendXlaInternalScope(Node* node, absl::string_view suffix) {
 }

 void ClusterScopingPassImpl::AddScopeToAllTransitivePredecessors(Node* start) {
-  const string unique_suffix = absl::StrCat("_", GetUniqueScopeId());
+  const std::string unique_suffix = absl::StrCat("_", GetUniqueScopeId());

   std::vector<Node*> starts;
   starts.push_back(start);
@@ -106,7 +106,7 @@ void ClusterScopingPassImpl::AddScopeToAllTransitivePredecessors(Node* start) {
 }

 void ClusterScopingPassImpl::AddScopeToAllTransitiveSuccessors(Node* start) {
-  const string unique_suffix = absl::StrCat("_", GetUniqueScopeId());
+  const std::string unique_suffix = absl::StrCat("_", GetUniqueScopeId());

   std::vector<Node*> starts;
   starts.push_back(start);
diff --git a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc
index b09cb2c12fa297..66cc10775992a3 100644
--- a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc
+++ b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc
@@ -45,10 +45,11 @@ absl::Status ClusterScoping(std::unique_ptr<Graph>* graph) {
   return pass.Run(opt_options);
 }

-absl::flat_hash_map<string, string> GetXlaInternalScopes(const Graph& graph) {
-  absl::flat_hash_map<string, string> scopes;
+absl::flat_hash_map<std::string, std::string> GetXlaInternalScopes(
+    const Graph& graph) {
+  absl::flat_hash_map<std::string, std::string> scopes;
   for (Node* node : graph.nodes()) {
-    string scope;
+    std::string scope;
     if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) {
       scopes[node->name()] = scope;
     }
@@ -63,7 +64,7 @@ absl::flat_hash_map<string, string> GetXlaInternalScopes(const Graph& graph) {
   return scopes;
 }

-Node* BuildStageNode(GraphDefBuilder& builder, string name,
+Node* BuildStageNode(GraphDefBuilder& builder, std::string name,
                      std::initializer_list<DataType> dtypes,
                      absl::Span<const ops::NodeOut> values) {
   auto opts = builder.opts()
diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc
index 50b26371698877..6c77648817f808 100644
--- a/tensorflow/compiler/jit/compilability_check_util.cc
+++ b/tensorflow/compiler/jit/compilability_check_util.cc
@@ -172,7 +172,7 @@ RecursiveCompilabilityChecker::FindUncompilableNodes(
 }
bool RecursiveCompilabilityChecker::HasXLAKernel( - const Node& node, string* uncompilable_reason) const { + const Node& node, std::string* uncompilable_reason) const { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient // is really a kind of function call and will be handled by // IsCompilableCall(). @@ -424,7 +424,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } - string uncompilable_reason; + std::string uncompilable_reason; if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) { if (!IsCompilableCall(node.def(), lib_runtime, stack_trace, encapsulating_function, uncompilable_nodes)) { diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 0d86c22de11a22..7d6741529ebd08 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -262,7 +262,7 @@ class RecursiveCompilabilityChecker { } bool HasXLAKernel(const Node& node, - string* uncompilable_reason = nullptr) const; + std::string* uncompilable_reason = nullptr) const; static void MaybeMarkUncompilableNode( const absl::string_view reason, diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 2b2db07642d1ab..fa546e3543e358 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -123,7 +123,7 @@ class Predicate { public: enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol, kIntSymbol }; - virtual string ToString() const = 0; + virtual std::string ToString() const = 0; // An ID assigned to the Predicate at construction time. Conceptually like a // pointer, except that it is stable across runs. 
@@ -156,12 +156,12 @@ class AndPredicate : public Predicate {
   explicit AndPredicate(int64_t id, std::vector<Predicate*> operands)
       : Predicate(id), operands_(std::move(operands)) {}

-  string ToString() const override {
+  std::string ToString() const override {
     if (operands().empty()) {
       return "#true";
     }

-    std::vector<string> operands_str;
+    std::vector<std::string> operands_str;
     std::transform(operands().begin(), operands().end(),
                    std::back_inserter(operands_str),
                    [](Predicate* pred) { return pred->ToString(); });
@@ -186,12 +186,12 @@ class OrPredicate : public Predicate {
   explicit OrPredicate(int64_t id, std::vector<Predicate*> operands)
       : Predicate(id), operands_(std::move(operands)) {}

-  string ToString() const override {
+  std::string ToString() const override {
     if (operands().empty()) {
       return "#false";
     }

-    std::vector<string> operands_str;
+    std::vector<std::string> operands_str;
     std::transform(operands().begin(), operands().end(),
                    std::back_inserter(operands_str),
                    [](Predicate* pred) { return pred->ToString(); });
@@ -215,7 +215,7 @@ class NotPredicate : public Predicate {
   explicit NotPredicate(int64_t id, Predicate* operand)
       : Predicate(id), operands_({operand}) {}

-  string ToString() const override {
+  std::string ToString() const override {
     return absl::StrCat("~", operand()->ToString());
   }

@@ -251,14 +251,14 @@ class NotPredicate : public Predicate {
 class AndRecurrencePredicate : public Predicate {
  public:
   explicit AndRecurrencePredicate(int64_t id, Predicate* start, Predicate* step,
-                                  std::vector<string> frame)
+                                  std::vector<std::string> frame)
       : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {}

   Predicate* start() const { return operands_[0]; }
   Predicate* step() const { return operands_[1]; }
-  absl::Span<const string> frame() const { return frame_; }
+  absl::Span<const std::string> frame() const { return frame_; }

-  string ToString() const override {
+  std::string ToString() const override {
     return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
                         "}<", absl::StrJoin(frame(), ";"), ">");
   }
@@ -271,7 +271,7 @@ class AndRecurrencePredicate : public Predicate {

  private:
   std::array<Predicate*, 2> operands_;
-  std::vector<string> frame_;
+  std::vector<std::string> frame_;
 };

 // Represents an uninterpreted symbol in a logical predicate.
@@ -286,7 +286,7 @@ class SymbolPredicate : public Predicate {
         tensor_id_(std::move(tensor_id)),
         must_be_true_(must_be_true) {}

-  string ToString() const override {
+  std::string ToString() const override {
     return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
                           : tensor_id_.ToString();
   }
@@ -320,7 +320,7 @@ class IntSymbolPredicate : public Predicate {
         tensor_id_(std::move(tensor_id)),
         must_have_value_(must_have_value) {}

-  string ToString() const override {
+  std::string ToString() const override {
     return must_have_value().has_value()
                ? absl::StrCat(tensor_id_.ToString(), "=", *must_have_value_)
                : tensor_id_.ToString();
   }
@@ -396,7 +396,7 @@ class PredicateFactory {
   }

   Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step,
-                                        std::vector<string> frame) {
+                                        std::vector<std::string> frame) {
     SignatureForAndRec signature(start, step, std::move(frame));
     auto it = interned_and_rec_instances_.find(signature);
     if (it != interned_and_rec_instances_.end()) {
@@ -463,8 +463,8 @@ class PredicateFactory {
       Tensor tensor(proto->dtype());
       TF_RET_CHECK(tensor.FromProto(*proto));

-      *predicate = tensor.scalar<int32>()() == *must_have_value ? MakeTrue()
-                                                                : MakeFalse();
+      *predicate = tensor.scalar<int32_t>()() == *must_have_value ? MakeTrue()
+                                                                  : MakeFalse();
       return absl::OkStatus();
     }
     SignatureForIntSymbol signature = {tensor_id, must_have_value};
@@ -559,9 +559,9 @@ class PredicateFactory {
       std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
   using SignatureForNot = Predicate*;
   using SignatureForAndRec =
-      std::tuple<Predicate*, Predicate*, std::vector<string>>;
+      std::tuple<Predicate*, Predicate*, std::vector<std::string>>;
   using SignatureForSymbol = std::pair<SafeTensorId, bool>;
-  using SignatureForIntSymbol = std::pair<SafeTensorId, std::optional<int32>>;
+  using SignatureForIntSymbol = std::pair<SafeTensorId, std::optional<int32_t>>;

   struct HashSignatureForAndOr {
     size_t operator()(const SignatureForAndOr& signature) const {
@@ -586,7 +586,7 @@ class PredicateFactory {
           SafeTensorId::Hasher()(signature.first),
           Hash64Combine(
               ::tensorflow::hash<bool>()(signature.second.has_value()),
-              ::tensorflow::hash<int32>()(
+              ::tensorflow::hash<int32_t>()(
                   signature.second.has_value() ? *signature.second : 0)));
     }
   };
@@ -830,8 +830,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
   absl::StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
       Node* n, int oidx) const override;
   void Print() const override;
-  absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
-      const;
+  absl::flat_hash_map<TensorId, std::string, TensorId::Hasher>
+  PredicateMapAsString() const;

  private:
   enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
@@ -958,7 +958,7 @@ absl::Status DeadnessAnalysisImpl::HandleSwitch(
   for (int i = 0; i < n->num_outputs() - 1; i++) {
     TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
         pred_edge->src(), pred_edge->src_output(),
-        /*must_have_value=*/std::optional<int32>(i), &branch_pred));
+        /*must_have_value=*/std::optional<int32_t>(i), &branch_pred));
     input_preds.push_back(branch_pred);
     SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds),
                  should_revisit);
@@ -982,7 +982,7 @@ absl::Status DeadnessAnalysisImpl::HandleSwitch(
 namespace {

 absl::Status CreateMultipleNextIterationInputsError(Node* merge) {
-  std::vector<string> backedges;
+  std::vector<std::string> backedges;
   for (const Edge* backedge : merge->in_edges()) {
     if (backedge->src()->IsNextIteration()) {
       backedges.push_back(absl::StrCat("  ", SummarizeNode(*backedge->src())));
@@ -1058,7 +1058,7 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,

 absl::Status GetFullFrame(const Node* n,
                           absl::Span<const ControlFlowInfo> cfi_infos,
-                          std::vector<string>* frame) {
+                          std::vector<std::string>* frame) {
   int depth = 0;
   for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource();
        n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) {
@@ -1174,7 +1174,7 @@ absl::Status DeadnessAnalysisImpl::HandleMerge(
     Predicate* start =
         predicate_factory_.MakeOrPredicate(non_recurrent_inputs);

-    std::vector<string> frame;
+    std::vector<std::string> frame;
     TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame));

     Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate(
         start, step, std::move(frame));
@@ -1358,7 +1358,7 @@ absl::Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
 // nested while, as there is no clean cut for separating them in the topological
 // order.
 absl::Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
-  std::vector<string> unreachable_nodes;
+  std::vector<std::string> unreachable_nodes;
   // Compute the loop structure of the graph.
   TF_RETURN_IF_ERROR(
       BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes));
@@ -1582,9 +1582,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
   return absl::OkStatus();
 }

-absl::flat_hash_map<TensorId, string, TensorId::Hasher>
+absl::flat_hash_map<TensorId, std::string, TensorId::Hasher>
 DeadnessAnalysisImpl::PredicateMapAsString() const {
-  absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
+  absl::flat_hash_map<TensorId, std::string, TensorId::Hasher> result;
   for (const auto& kv_pair : predicate_map_) {
     CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
   }
@@ -1603,7 +1603,7 @@ absl::Status ComputePredicates(const Graph& graph,
 }  // namespace deadness_analysis_internal

-string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
+std::string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
   return static_cast<Predicate*>(predicate.pred_)->ToString();
 }
diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h
index 80fa9a20faef41..1cd394154faf36 100644
--- a/tensorflow/compiler/jit/deadness_analysis.h
+++ b/tensorflow/compiler/jit/deadness_analysis.h
@@ -81,7 +81,7 @@ class DeadnessAnalysis {
   virtual void Print() const = 0;
   virtual ~DeadnessAnalysis();

-  string DebugString(DeadnessPredicate predicate) const;
+  std::string DebugString(DeadnessPredicate predicate) const;

   // Run the deadness analysis over `graph` and returns an error or a populated
   // instance of DeadnessAnalysis in `result`.
diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h
index 0dc18d3e129d79..569cdeadae735e 100644
--- a/tensorflow/compiler/jit/deadness_analysis_internal.h
+++ b/tensorflow/compiler/jit/deadness_analysis_internal.h
@@ -24,7 +24,8 @@ namespace deadness_analysis_internal {

 // Returns a map describing the predicate each Tensor was mapped to. For
 // testing purposes only.
-using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
+using PredicateMapTy =
+    absl::flat_hash_map<TensorId, std::string, TensorId::Hasher>;
 absl::Status ComputePredicates(const Graph& graph,
                                PredicateMapTy* out_predicate_map,
                                bool enable_optimistic = true);
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 894ee659121e25..fd7d93b3772f5f 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -61,7 +61,7 @@ absl::Status AnalyzeDeadness(Graph* graph,
   return DeadnessAnalysis::Run(*graph, result);
 }

-ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
+ops::Switch CreateSwitch(const Scope& root, const std::string& prefix) {
   Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
   Output predicate =
       ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
@@ -76,7 +76,7 @@ void VLogGraphIfAsked(const Graph& graph) {
   if (VLOG_IS_ON(3)) {
     GraphDef graph_def;
     graph.ToGraphDef(&graph_def);
-    string serialized;
+    std::string serialized;
     ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized);
     LOG(INFO) << serialized;
   }
@@ -127,8 +127,8 @@ struct InductionVarInfo {
 //                              +-----> | Exit  |
 //                                      +---------------+
 InductionVarInfo CreateInductionVariable(const Scope& root,
-                                         const string& prefix,
-                                         const string& frame_name,
+                                         const std::string& prefix,
+                                         const std::string& frame_name,
                                          const Output& initial_value) {
   Output enter_initial_value = ops::internal::Enter(
       root.WithOpName(prefix + "/enter"), initial_value, frame_name);
@@ -158,8 +158,8 @@ InductionVarInfo CreateInductionVariable(const Scope& root,
 }

 InductionVarInfo CreateInductionVariable(const Scope& root,
-                                         const string& prefix,
-                                         const string& frame_name,
+                                         const std::string& prefix,
+                                         const std::string& frame_name,
                                          int32_t init) {
   return CreateInductionVariable(
       root, prefix, frame_name,
@@ -201,7 +201,7 @@ struct DependentInductionVar {
 };

 DependentInductionVar CreateDependentLoopInvariantValue(
-    const Scope& root, const string& prefix, const string& frame_name,
+    const Scope& root, const std::string& prefix, const std::string& frame_name,
     const Output& loop_cond, const Output& value) {
   Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"),
                                             value, frame_name);
@@ -218,7 +218,7 @@ DependentInductionVar CreateDependentLoopInvariantValue(
 }

 DependentInductionVar CreateDependentLoopInvariantValue(
-    const Scope& root, const string& prefix, const string& frame_name,
+    const Scope& root, const std::string& prefix, const std::string& frame_name,
     const Output& loop_cond, int32_t value) {
   return CreateDependentLoopInvariantValue(
       root, prefix, frame_name, loop_cond,
diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc
index 9ec02d92d37cd6..8288b44e7f1c1d 100644
--- a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc
+++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc
@@ -65,9 +65,9 @@ struct SignatureNotEqual {

 // Functor that incrementally computes a Signature's hash given its current hash
 // and one of its args.
 struct SignatureHashCombiner {
-  explicit SignatureHashCombiner(const uint64 h) : h(h) {}
-  uint64 h;
-  uint64 operator()(const Tensor& arg) {
+  explicit SignatureHashCombiner(const uint64_t h) : h(h) {}
+  uint64_t h;
+  uint64_t operator()(const Tensor& arg) {
     h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.dtype())));
     h = Hash64Combine(
         h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
     for (const int64_t dim : arg.shape().dim_sizes()) {
       h = Hash64Combine(h, std::hash<int64_t>()(dim));
     }
     return h;
   }
-  uint64 operator()(const TensorTypeAndShape& arg) {
+  uint64_t operator()(const TensorTypeAndShape& arg) {
     h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
     h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
     for (int dim : arg.second) {
@@ -108,8 +108,8 @@ bool Signature::operator==(const Signature& other) const {
   return true;
 }

-uint64 Signature::Hash::operator()(const Signature& signature) const {
-  uint64 h = std::hash<string>()(signature.name);
+uint64_t Signature::Hash::operator()(const Signature& signature) const {
+  uint64_t h = std::hash<std::string>()(signature.name);
   for (const auto& arg : signature.args) {
     h = std::visit(SignatureHashCombiner(h), arg);
   }
diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.h b/tensorflow/compiler/jit/device_compilation_cluster_signature.h
index b4c2840eedee59..721c1d3b78c50e 100644
--- a/tensorflow/compiler/jit/device_compilation_cluster_signature.h
+++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.h
@@ -58,7 +58,8 @@ struct DeviceCompilationClusterSignature {
   bool operator==(const DeviceCompilationClusterSignature& other) const;

   struct Hash {
-    uint64 operator()(const DeviceCompilationClusterSignature& signature) const;
+    uint64_t operator()(
+        const DeviceCompilationClusterSignature& signature) const;
   };

   // Returns a human-readable description of the signature.
diff --git a/tensorflow/compiler/jit/device_compilation_profiler.cc b/tensorflow/compiler/jit/device_compilation_profiler.cc
index 5e1b3b26e8ecb5..ec161293b7643d 100644
--- a/tensorflow/compiler/jit/device_compilation_profiler.cc
+++ b/tensorflow/compiler/jit/device_compilation_profiler.cc
@@ -107,7 +107,7 @@ absl::Status DeviceCompilationProfiler::RegisterCompilation(
           cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
               .first;

-  const uint64 compile_time_s = compile_time_us / 1.0e6;
+  const uint64_t compile_time_s = compile_time_us / 1.0e6;
   it->second.compile_count++;
   it->second.cumulative_compile_time_us += compile_time_us;
   VLOG(1) << "Compiled " << function_name << " " << it->second.compile_count
diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h
index 0fae07abd22897..a9f2418282c414 100644
--- a/tensorflow/compiler/jit/device_compiler.h
+++ b/tensorflow/compiler/jit/device_compiler.h
@@ -137,7 +137,7 @@ class DeviceCompiler : public ResourceBase {
     return compiler_client_.get();
   }

-  string DebugString() const override;
+  std::string DebugString() const override;

  private:
  // Common implementation of Compile and CompileSingleOp. The `OpKernelContext`
@@ -259,7 +259,7 @@ DeviceCompiler<ExecutableType, ClientType>::~DeviceCompiler() {
 }

 template <typename ExecutableType, typename ClientType>
-string DeviceCompiler<ExecutableType, ClientType>::DebugString() const {
+std::string DeviceCompiler<ExecutableType, ClientType>::DebugString() const {
   return "DeviceCompiler";
 }

@@ -331,7 +331,7 @@ DeviceCompiler<ExecutableType, ClientType>::CompileStrict(
     CompileScope scope, OpKernelContext* ctx,
     DeviceCompilationProfiler* profiler, mutex* mu) {
   tensorflow::Env* env = tensorflow::Env::Default();
-  const uint64 compile_start_us = env->NowMicros();
+  const uint64_t compile_start_us = env->NowMicros();

   TfGraphToHloCompiler compiler(options);
   cache_value.compile_state = DeviceCompileState::kCompiled;
@@ -385,8 +385,8 @@ DeviceCompiler<ExecutableType, ClientType>::CompileStrict(
   // Finalize the cache to release the XlaComputation after it was compiled.
   cache_->Finalize();

-  const uint64 compile_end_us = env->NowMicros();
-  const uint64 compile_time_us = compile_end_us - compile_start_us;
+  const uint64_t compile_end_us = env->NowMicros();
+  const uint64_t compile_time_us = compile_end_us - compile_start_us;

   device_compiler_internal::LogOnceXlaCompiledFirstCluster();
   TF_RETURN_IF_ERROR(profiler->RegisterCompilation(
@@ -496,7 +496,7 @@ absl::Status DeviceCompiler<ExecutableType, ClientType>::CompileImpl(

   profiler->RegisterExecution(function);

-  string human_signature;
+  std::string human_signature;
   if (VLOG_IS_ON(2)) {
     human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name();
     VLOG(2) << "DeviceCompilationClusterSignature: " << human_signature;
diff --git a/tensorflow/compiler/jit/device_compiler_test.cc b/tensorflow/compiler/jit/device_compiler_test.cc
index 64e286bff55b07..749110be186311 100644
--- a/tensorflow/compiler/jit/device_compiler_test.cc
+++ b/tensorflow/compiler/jit/device_compiler_test.cc
@@ -139,7 +139,7 @@ class MockXlaDeviceExecutablePersistor
             Config{testing::TmpDir(), false, "xla"},
             DeviceType(DEVICE_CPU_XLA_JIT)) {}
   MOCK_METHOD(absl::Status, TryToPersistExecutable,
-              (uint64, const std::string&, const XlaCompiler::Options&,
+              (uint64_t, const std::string&, const XlaCompiler::Options&,
                const XlaCompiler::CompilationResult&,
                const xla::LocalExecutable&,
                (DeviceCompilerClient<xla::LocalExecutable, xla::LocalClient>*)),
              (override));
@@ -425,7 +425,7 @@ TEST_F(DeviceCompilerTest, CompileFailedToLoadFromPersistentCache) {
                              &xla_executable));

   // Corrupt the file which contains the serialized executable.
-  std::vector<string> files;
+  std::vector<std::string> files;
   TF_ASSERT_OK(Env::Default()->GetChildren(testing::TmpDir(), &files));
   std::string const* serialized_executable_filename = nullptr;
   for (const auto& file : files) {
diff --git a/tensorflow/compiler/jit/device_context_test.cc b/tensorflow/compiler/jit/device_context_test.cc
index 34a0c3d5ea067b..33bba30f3db3e1 100644
--- a/tensorflow/compiler/jit/device_context_test.cc
+++ b/tensorflow/compiler/jit/device_context_test.cc
@@ -38,7 +38,7 @@ static bool Initialized = [] {

 class DeviceContextTest : public ::testing::Test {
  public:
-  void SetDevice(const string& device_type) {
+  void SetDevice(const std::string& device_type) {
     auto& rollout_config = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
     rollout_config.AllowForDeviceInXlaLaunch(DeviceType(device_type));
     rollout_config.AllowForDeviceInXlaCompileOnDemand(DeviceType(device_type));
diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h
index 458441c86b5c43..5a64b078e1a93c 100644
--- a/tensorflow/compiler/jit/device_executable_persistor.h
+++ b/tensorflow/compiler/jit/device_executable_persistor.h
@@ -96,7 +96,7 @@ class DeviceExecutablePersistor {
   // TODO(b/255826209): Take in Signature instead of hash and string once cache
   // is refactored.
   std::optional<absl::StatusOr<std::unique_ptr<ExecutableType>>> TryToLoadExecutable(
-      uint64 signature_hash, const std::string& signature_str,
+      uint64_t signature_hash, const std::string& signature_str,
       const XlaCompiler::Options& options,
       const XlaCompiler::CompilationResult& compilation_result,
       DeviceCompilerClient<ExecutableType, ClientType>* client) const;
@@ -107,7 +107,7 @@ class DeviceExecutablePersistor {
   // TODO(b/255826209): Take in Signature instead hash and string once cache
   // is refactored.
   virtual absl::Status TryToPersistExecutable(
-      uint64 signature_hash, const std::string& signature_str,
+      uint64_t signature_hash, const std::string& signature_str,
       const XlaCompiler::Options& options,
       const XlaCompiler::CompilationResult& compilation_result,
       const ExecutableType& executable,
@@ -123,15 +123,15 @@ class DeviceExecutablePersistor {
   // Returns a cache key proto that identifies an entry in the compilation
   // cache.
   XlaSerializedCacheKey BuildSerializedCacheKey(
-      uint64 signature_hash, const xla::HloModuleProto& hlo_module) const;
+      uint64_t signature_hash, const xla::HloModuleProto& hlo_module) const;

   XlaSerializedCacheKey BuildSerializedCacheKey(
-      uint64 signature_hash, const xla::HloModuleProto& hlo_module,
+      uint64_t signature_hash, const xla::HloModuleProto& hlo_module,
       bool compiled_using_pjrt) const;

   // Serializes the signature and its corresponding entry to a proto message.
   absl::StatusOr<XlaSerializedCacheEntry> SerializeEntry(
-      uint64 signature_hash, const XlaCompiler::Options& options,
+      uint64_t signature_hash, const XlaCompiler::Options& options,
       const XlaCompiler::CompilationResult& compilation_result,
       const ExecutableType& executable,
       DeviceCompilerClient<ExecutableType, ClientType>* compiler_client) const;
@@ -189,7 +189,7 @@ std::string DeviceExecutablePersistor<ExecutableType, ClientType>::GetFilePath(

 template <typename ExecutableType, typename ClientType>
 XlaSerializedCacheKey
 DeviceExecutablePersistor<ExecutableType, ClientType>::BuildSerializedCacheKey(
-    uint64 signature_hash, const xla::HloModuleProto& hlo_module,
+    uint64_t signature_hash, const xla::HloModuleProto& hlo_module,
     bool compiled_using_pjrt) const {
   XlaSerializedCacheKey key;
   key.set_signature_fingerprint(signature_hash);
@@ -203,7 +203,7 @@ DeviceExecutablePersistor<ExecutableType, ClientType>::BuildSerializedCacheKey(

 template <typename ExecutableType, typename ClientType>
 XlaSerializedCacheKey
 DeviceExecutablePersistor<ExecutableType, ClientType>::BuildSerializedCacheKey(
-    uint64 signature_hash, const xla::HloModuleProto& hlo_module) const {
+    uint64_t signature_hash, const xla::HloModuleProto& hlo_module) const {
   return BuildSerializedCacheKey(signature_hash, hlo_module, false);
 }

@@ -212,7 +212,7 @@ DeviceExecutablePersistor<ExecutableType, ClientType>::BuildSerializedCacheKey(
 template <>
 inline XlaSerializedCacheKey
 DeviceExecutablePersistor<xla::PjRtLoadedExecutable, xla::PjRtClient>::
-    BuildSerializedCacheKey(uint64 signature_hash,
+    BuildSerializedCacheKey(uint64_t signature_hash,
                             const xla::HloModuleProto& hlo_module) const {
   return BuildSerializedCacheKey(signature_hash, hlo_module, true);
 }

@@ -305,7 +305,7 @@ DeviceExecutablePersistor<ExecutableType, ClientType>::SaveSerializedEntry(

 template <typename ExecutableType, typename ClientType>
 absl::StatusOr<XlaSerializedCacheEntry>
 DeviceExecutablePersistor<ExecutableType, ClientType>::SerializeEntry(
-    uint64 signature_hash, const XlaCompiler::Options& options,
+    uint64_t signature_hash, const XlaCompiler::Options& options,
     const XlaCompiler::CompilationResult& compilation_result,
     const ExecutableType& executable,
     DeviceCompilerClient<ExecutableType, ClientType>* compiler_client) const {
@@ -340,7 +340,7 @@ DeviceExecutablePersistor<ExecutableType, ClientType>::SerializeEntry(

 template <typename ExecutableType, typename ClientType>
 std::optional<absl::StatusOr<std::unique_ptr<ExecutableType>>>
 DeviceExecutablePersistor<ExecutableType, ClientType>::TryToLoadExecutable(
-    uint64 signature_hash, const std::string& signature_str,
+    uint64_t signature_hash, const std::string& signature_str,
     const XlaCompiler::Options& options,
     const XlaCompiler::CompilationResult& compilation_result,
     DeviceCompilerClient<ExecutableType, ClientType>* compiler_client) const {
@@ -376,7 +376,7 @@ DeviceExecutablePersistor<ExecutableType, ClientType>::TryToLoadExecutable(

 template <typename ExecutableType, typename ClientType>
 absl::Status
 DeviceExecutablePersistor<ExecutableType, ClientType>::TryToPersistExecutable(
-    uint64 signature_hash, const std::string& signature_str,
+    uint64_t signature_hash, const std::string& signature_str,
     const XlaCompiler::Options& options,
     const XlaCompiler::CompilationResult& compilation_result,
     const ExecutableType& executable,
diff --git a/tensorflow/compiler/jit/device_executable_persistor_test.cc b/tensorflow/compiler/jit/device_executable_persistor_test.cc
index 7779f1112e7b9e..62cfd4c1b8e0b7 100644
--- a/tensorflow/compiler/jit/device_executable_persistor_test.cc
+++ b/tensorflow/compiler/jit/device_executable_persistor_test.cc
@@ -222,7 +222,7 @@ absl::StatusOr<XlaSerializedCacheEntry> ReadCacheEntryFromFile(
 }

 XlaSerializedCacheKey CreateCacheKey(
-    uint64 signature_hash,
+    uint64_t signature_hash,
     const XlaCompiler::CompilationResult& compilation_result,
     const DeviceType& device_type, const std::string& persistence_prefix,
     bool compiled_using_pjrt = false) {
diff --git a/tensorflow/compiler/jit/device_util.cc b/tensorflow/compiler/jit/device_util.cc
index 828da0b08c2590..1979aec5bcf0c3 100644
--- a/tensorflow/compiler/jit/device_util.cc
+++ b/tensorflow/compiler/jit/device_util.cc
@@ -44,7 +44,7 @@ void DeviceSet::UnionWith(const DeviceSet& other) {
 }

 bool DeviceSet::IsEmpty() const {
const { - return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; }); + return absl::c_all_of(storage_, [&](uint64_t val) { return val == 0; }); } absl::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { @@ -56,7 +56,7 @@ absl::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { } int new_id = names_.size(); - names_.push_back(string(name)); + names_.push_back(std::string(name)); id_to_device_type_.push_back(std::make_unique("")); DeviceType* device_type = id_to_device_type_.back().get(); TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type)); @@ -64,7 +64,7 @@ absl::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { is_cpu_.push_back(device_type->type_string() == DEVICE_CPU); is_gpu_.push_back(device_type->type_string() == DEVICE_GPU); - name_to_id_.emplace(string(name), DeviceId(new_id)); + name_to_id_.emplace(std::string(name), DeviceId(new_id)); const XlaOpRegistry::DeviceRegistration* compilation_device; if (!XlaOpRegistry::GetCompilationDevice(device_type->type(), @@ -76,10 +76,10 @@ absl::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { return DeviceId(new_id); } -string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { - std::vector names; +std::string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { + std::vector names; device_set.ForEach([&](DeviceId device_id) { - names.push_back(string(GetNameFor(device_id))); + names.push_back(std::string(GetNameFor(device_id))); return true; }); @@ -87,7 +87,7 @@ string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { } } // namespace jit -absl::Status DeviceNameToDeviceType(const string& device, +absl::Status DeviceNameToDeviceType(const std::string& device, DeviceType* device_type) { DeviceNameUtils::ParsedName parsed; if (!DeviceNameUtils::ParseFullName(device, &parsed)) { diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h index 745f87309501d8..fa862aac88c394 100644 --- a/tensorflow/compiler/jit/device_util.h +++ b/tensorflow/compiler/jit/device_util.h @@ -75,9 +75,9 @@ class DeviceSet { // iterator if this ends up being used widely. for (int word_index = 0, end = storage_.size(); word_index < end; word_index++) { - uint64 word = storage_[word_index]; + uint64_t word = storage_[word_index]; while (word != 0) { - uint64 only_lowest_bit_set = word & -word; + uint64_t only_lowest_bit_set = word & -word; // The number of trailing zeros in a non-zero word is the index of the // least significant 1. int bit_index = absl::countr_zero(word); @@ -90,7 +90,7 @@ class DeviceSet { } private: - absl::InlinedVector storage_; + absl::InlinedVector storage_; const int kWordSize = 64; }; @@ -131,17 +131,17 @@ class DeviceInfoCache { return std::cref(*id_to_device_type_[device_id.id()]); } - string DebugString(const DeviceSet& device_set) const; + std::string DebugString(const DeviceSet& device_set) const; private: - absl::flat_hash_map name_to_id_; + absl::flat_hash_map name_to_id_; // These fields are populated for a device in GetIdFor, *before* we give out a // DeviceId. std::vector id_to_compilation_device_; std::vector> id_to_device_type_; - std::vector names_; + std::vector names_; std::vector is_cpu_; std::vector is_gpu_; }; @@ -149,7 +149,7 @@ class DeviceInfoCache { } // namespace jit // Returns the DeviceType corresponding to 'device'. 
-absl::Status DeviceNameToDeviceType(const string& device, +absl::Status DeviceNameToDeviceType(const std::string& device, DeviceType* device_type); // Picks the device for which XLA should compile a cluster that contains diff --git a/tensorflow/compiler/jit/device_util_test.cc b/tensorflow/compiler/jit/device_util_test.cc index cef39df6283f2b..be58292f931686 100644 --- a/tensorflow/compiler/jit/device_util_test.cc +++ b/tensorflow/compiler/jit/device_util_test.cc @@ -23,7 +23,7 @@ namespace { absl::Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu, absl::Span device_names, - string* result) { + std::string* result) { jit::DeviceInfoCache cache; jit::DeviceSet device_set; for (absl::string_view name : device_names) { @@ -34,14 +34,14 @@ absl::Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu, TF_ASSIGN_OR_RETURN( jit::DeviceId result_id, PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu)); - *result = string(cache.GetNameFor(result_id)); + *result = std::string(cache.GetNameFor(result_id)); return absl::OkStatus(); } void CheckPickDeviceResult(absl::string_view expected_result, bool allow_mixing_unknown_and_cpu, absl::Span inputs) { - string result; + std::string result; TF_ASSERT_OK(PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result)) << "inputs = [" << absl::StrJoin(inputs, ", ") << "], allow_mixing_unknown_and_cpu=" << allow_mixing_unknown_and_cpu @@ -51,7 +51,7 @@ void CheckPickDeviceResult(absl::string_view expected_result, void CheckPickDeviceHasError(bool allow_mixing_unknown_and_cpu, absl::Span inputs) { - string result; + std::string result; EXPECT_FALSE( PickDeviceHelper(allow_mixing_unknown_and_cpu, inputs, &result).ok()); } @@ -110,10 +110,10 @@ void SimpleRoundTripTestForDeviceSet(int num_devices) { jit::DeviceSet device_set; jit::DeviceInfoCache device_info_cache; - std::vector expected_devices, actual_devices; + std::vector expected_devices, actual_devices; for (int i = 0; i < num_devices; i++) { - string device_name = + std::string device_name = absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i); TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id, device_info_cache.GetIdFor(device_name)); @@ -122,7 +122,8 @@ void SimpleRoundTripTestForDeviceSet(int num_devices) { } device_set.ForEach([&](jit::DeviceId device_id) { - actual_devices.push_back(string(device_info_cache.GetNameFor(device_id))); + actual_devices.push_back( + std::string(device_info_cache.GetNameFor(device_id))); return true; }); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 3e8a43ce08ed58..6e7d16de16a4f6 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -115,7 +115,7 @@ void MarkGuaranteedConstants( } struct OutputInputTensorPairHasher { - uint64 operator()(std::pair const& s) const { + uint64_t operator()(std::pair const& s) const { return Hash64Combine(OutputTensor::Hash()(s.first), InputTensor::Hash()(s.second)); } @@ -128,7 +128,7 @@ static const char* const kRetValOp = "_Retval"; class Encapsulator { public: - Encapsulator(string group_attribute, Graph const* graph_in) + Encapsulator(std::string group_attribute, Graph const* graph_in) : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {} // Find subgraphs marked with 'group_attribute', and build a new @@ -182,7 +182,7 @@ class Encapsulator { // 'reuse_existing_functions' is set, use an existing function with the same // name, if 
any. If 'rewrite_subgraph_fn' is set, it is applied to the // subgraph before function conversion. - absl::Status BuildFunctionDef(const string& name_in, + absl::Status BuildFunctionDef(const std::string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library); @@ -226,7 +226,7 @@ class Encapsulator { const absl::flat_hash_map& node_images); // Creates the sequencer node if it doesn't exist, adding it to graph_out. - absl::Status MakeSequencingNode(const string& subgraph_name, + absl::Status MakeSequencingNode(const std::string& subgraph_name, Graph* graph_out); // If there is a sequencer node, adds a control edge from the sequencer to @@ -243,14 +243,14 @@ class Encapsulator { // Which device are these nodes on? Used to assign a device to the call // node. - string device_; + std::string device_; // NodeDef for the function call node. NodeDef call_node_def_; // Name that is used for the call node. This may not be // call_node_def_.name() if the client supplies a rewrite lambda. - string function_def_name_; + std::string function_def_name_; // Placeholder node simulating the host compute key in the output graph. // Not owned. @@ -275,7 +275,7 @@ class Encapsulator { // Set of node names that are the source of a control output of the // subgraph. We store strings here so that we can tolerate nodes being // removed from the graph. - absl::flat_hash_set control_output_nodes_; + absl::flat_hash_set control_output_nodes_; // NoOp node in the output graph that is sequenced after the call node. Node* sequencer_ = nullptr; @@ -283,7 +283,7 @@ class Encapsulator { // Returns the key attribute associated with a node in attr. Sets either // result to the empty string if the respective attribute is not found. - absl::Status GetFunctionNameAttr(Node const* node, string* attr) const; + absl::Status GetFunctionNameAttr(Node const* node, std::string* attr) const; // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to // subgraphs for data edges that cross subgraph boundaries. @@ -308,36 +308,35 @@ class Encapsulator { // a subgraph boundary it is the output of a call node, otherwise it is a node // in the output graph. absl::Status FindOutputImageOfEdgeSrc( - const string& src_func_id, const string& dst_func_id, + const std::string& src_func_id, const std::string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_src_node, Node** src_image); // Finds an edge source slot in the output graph. If the edge crosses a // subgraph boundary it is a slot on the output of a call node, otherwise it // is a slot on a node in the output graph. - int FindOutputSlotOfEdgeSrc(const string& src_func_id, - const string& dst_func_id, - const Edge* edge); + int FindOutputSlotOfEdgeSrc(const std::string& src_func_id, + const std::string& dst_func_id, const Edge* edge); // Finds the image of an edge destination in the output graph. If the edge // crosses a subgraph boundary it is the input of a call node, otherwise it is // a node in the output graph. absl::Status FindOutputImageOfEdgeDst( - const string& src_func_id, const string& dst_func_id, + const std::string& src_func_id, const std::string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_dst_node, Node** dst_image); // Finds an edge destination slot in the output graph. If the edge crosses a // subgraph boundary it is a slot on the input of a call node, otherwise it is // a slot on a node in the output graph. 
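Returning briefly to the DeviceInfoCache::GetIdFor hunk earlier: beyond the string migration, the pattern it implements is name interning, where each distinct device name gets a dense integer id so later passes (like the DeviceSet bitset) can avoid hash lookups. A minimal hypothetical sketch of that shape, not TF's actual class:

#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"

class NameInterner {
 public:
  int GetIdFor(absl::string_view name) {
    auto it = name_to_id_.find(name);  // heterogeneous string_view lookup
    if (it != name_to_id_.end()) return it->second;
    int new_id = names_.size();
    names_.push_back(std::string(name));  // copy only on first sight
    name_to_id_.emplace(names_.back(), new_id);
    return new_id;
  }
  absl::string_view GetNameFor(int id) const { return names_[id]; }

 private:
  absl::flat_hash_map<std::string, int> name_to_id_;
  std::vector<std::string> names_;
};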
- int FindOutputSlotOfEdgeDst(const string& src_func_id, - const string& dst_func_id, - const Edge* edge); + int FindOutputSlotOfEdgeDst(const std::string& src_func_id, + const std::string& dst_func_id, const Edge* edge); // Copies a single edge to the output graph. The edge is either entirely // within the output graph, or crosses into or out of a compiled subgraph. absl::Status CopyEdgeToOutputGraph( - const Edge* edge, const string& src_func_id, const string& dst_func_id, + const Edge* edge, const std::string& src_func_id, + const std::string& dst_func_id, const absl::flat_hash_map& node_images, Graph* graph_out, absl::flat_hash_set, @@ -358,10 +357,10 @@ class Encapsulator { absl::flat_hash_map* node_images, FunctionLibraryDefinition* library); - const string group_attribute_; + const std::string group_attribute_; const Graph* graph_in_; - absl::flat_hash_map subgraphs_; + absl::flat_hash_map subgraphs_; Encapsulator(const Encapsulator&) = delete; void operator=(const Encapsulator&) = delete; @@ -374,19 +373,20 @@ namespace { // including clusters that are not present in the ancestors map. has_successors // is the set of clusters that are ancestors of some other cluster. void TopologicalClusterSort( - const absl::flat_hash_set& clusters, - const absl::flat_hash_set& has_successors, - const absl::flat_hash_map>& ancestors, - std::vector* sorted) { + const absl::flat_hash_set& clusters, + const absl::flat_hash_set& has_successors, + const absl::flat_hash_map>& + ancestors, + std::vector* sorted) { // The nodes are placed in 'sorted' in topological order. sorted->clear(); // We don't use the standard DFS because we are not operating on Node* // objects. struct Work { - string cluster; + std::string cluster; bool leave; }; - std::set visited; + std::set visited; std::vector stack; // Seed the processing list with clusters that have no successors. for (const auto& cluster : clusters) { @@ -523,7 +523,7 @@ absl::Status Encapsulator::Subgraph::RecordResult( } absl::Status Encapsulator::Subgraph::MakeSequencingNode( - const string& subgraph_name, Graph* graph_out) { + const std::string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; // TODO(shikharagarwal): What source node should we use for errors? @@ -547,11 +547,11 @@ void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { } absl::Status Encapsulator::Subgraph::BuildFunctionDef( - const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + const std::string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { // name_in is copied here because name may be modified below if // rewrite_subgraph_fn is true. 
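TopologicalClusterSort in the hunk above avoids recursive DFS because it operates on cluster names rather than Node pointers; the Work{cluster, leave} record drives an explicit stack, and a cluster is emitted only on its "leave" visit, after all of its ancestors. A self-contained sketch of that traversal shape, inferred from the visible declarations rather than copied from the TF implementation:

#include <map>
#include <set>
#include <string>
#include <vector>

// Emits ancestors before descendants, seeded from the given roots.
void TopoSort(const std::map<std::string, std::vector<std::string>>& ancestors,
              const std::vector<std::string>& roots,
              std::vector<std::string>* sorted) {
  struct Work {
    std::string node;
    bool leave;  // false: entering the node; true: all ancestors are done
  };
  std::set<std::string> visited;
  std::vector<Work> stack;
  for (const auto& r : roots) stack.push_back({r, /*leave=*/false});
  while (!stack.empty()) {
    Work w = stack.back();
    stack.pop_back();
    if (w.leave) {
      sorted->push_back(w.node);  // all ancestors already emitted
      continue;
    }
    if (!visited.insert(w.node).second) continue;  // already processed
    stack.push_back({w.node, /*leave=*/true});
    auto it = ancestors.find(w.node);
    if (it != ancestors.end())
      for (const auto& a : it->second) stack.push_back({a, /*leave=*/false});
  }
}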
- string name = name_in; + std::string name = name_in; call_node_def_.set_op(name); call_node_def_.set_name(name); call_node_def_.set_device(device_); @@ -596,7 +596,7 @@ absl::Status Encapsulator::Subgraph::BuildFunctionDef( function_def_name_ = name; FunctionDef fdef; - auto lookup = [this](const Node* node) -> std::optional { + auto lookup = [this](const Node* node) -> std::optional { if (control_output_nodes_.contains(node->name())) { return std::make_optional(node->name()); } @@ -625,7 +625,7 @@ absl::Status Encapsulator::Subgraph::BuildFunctionDef( absl::Status Encapsulator::Subgraph::ReplaceFunctionDef( FunctionLibraryDefinition* library) { - const string& name = function_def_name_; + const std::string& name = function_def_name_; FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -654,7 +654,7 @@ absl::Status Encapsulator::Subgraph::AddFunctionCallNode( } absl::Status Encapsulator::GetFunctionNameAttr(Node const* node, - string* attr) const { + std::string* attr) const { AttrSlice attrs = node->attrs(); attr->clear(); for (const auto& node_attr : attrs) { @@ -667,12 +667,12 @@ absl::Status Encapsulator::GetFunctionNameAttr(Node const* node, return absl::OkStatus(); } -bool IsInSubgraph(const string& func_id) { return !func_id.empty(); } +bool IsInSubgraph(const std::string& func_id) { return !func_id.empty(); } absl::Status Encapsulator::CopySubgraphNodes( absl::flat_hash_map* node_images) { for (Node* node : graph_in_->op_nodes()) { - string func_id; + std::string func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); if (!IsInSubgraph(func_id)) continue; @@ -688,9 +688,9 @@ absl::Status Encapsulator::CopySubgraphEdges( const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs) { for (const Edge* edge : graph_in_->edges()) { - string src_func_id; + std::string src_func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); - string dst_func_id; + std::string dst_func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); @@ -793,7 +793,7 @@ absl::Status Encapsulator::BuildFunctionDefs( const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { for (auto& subgraph_entry : subgraphs_) { - string name = subgraph_entry.first; + std::string name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef( name, rewrite_subgraph_fn, reuse_existing_functions, library)); @@ -804,7 +804,7 @@ absl::Status Encapsulator::BuildFunctionDefs( absl::Status Encapsulator::CopyNodesToOutputGraph( Graph* graph_out, absl::flat_hash_map* node_images) { for (Node* node : graph_in_->op_nodes()) { - string func_id; + std::string func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id)); // Don't copy nodes that are going to be encapsulated. 
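The lookup lambda passed to graph-to-function conversion in BuildFunctionDef above uses std::optional as its contract: a name means "this node is a control output", std::nullopt means it is not. A tiny standalone illustration of that callback pattern, with hypothetical node names:

#include <iostream>
#include <optional>
#include <set>
#include <string>

int main() {
  std::set<std::string> control_output_nodes = {"send_x", "assert_y"};
  auto lookup =
      [&](const std::string& node_name) -> std::optional<std::string> {
    if (control_output_nodes.count(node_name) > 0)
      return std::make_optional(node_name);
    return std::nullopt;  // not a control output
  };
  for (std::string n : {"send_x", "mul"})
    std::cout << n << " -> " << lookup(n).value_or("<not a control ret>")
              << "\n";
}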
@@ -829,7 +829,7 @@ absl::Status Encapsulator::AddFunctionCallNodes( } absl::Status Encapsulator::FindOutputImageOfEdgeSrc( - const string& src_func_id, const string& dst_func_id, + const std::string& src_func_id, const std::string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_src_node, Node** src_image) { if (IsInSubgraph(src_func_id)) { @@ -844,8 +844,8 @@ absl::Status Encapsulator::FindOutputImageOfEdgeSrc( return absl::OkStatus(); } -int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, - const string& dst_func_id, +int Encapsulator::FindOutputSlotOfEdgeSrc(const std::string& src_func_id, + const std::string& dst_func_id, const Edge* edge) { if (IsInSubgraph(src_func_id)) { const Subgraph& src_subgraph = subgraphs_.at(src_func_id); @@ -860,7 +860,7 @@ int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, } absl::Status Encapsulator::FindOutputImageOfEdgeDst( - const string& src_func_id, const string& dst_func_id, + const std::string& src_func_id, const std::string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_dst_node, Node** dst_image) { if (IsInSubgraph(dst_func_id)) { @@ -875,8 +875,8 @@ absl::Status Encapsulator::FindOutputImageOfEdgeDst( return absl::OkStatus(); } -int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, - const string& dst_func_id, +int Encapsulator::FindOutputSlotOfEdgeDst(const std::string& src_func_id, + const std::string& dst_func_id, const Edge* edge) { if (IsInSubgraph(dst_func_id)) { const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); @@ -891,7 +891,8 @@ int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, } absl::Status Encapsulator::CopyEdgeToOutputGraph( - const Edge* edge, const string& src_func_id, const string& dst_func_id, + const Edge* edge, const std::string& src_func_id, + const std::string& dst_func_id, const absl::flat_hash_map& node_images, Graph* graph_out, absl::flat_hash_set, @@ -943,9 +944,9 @@ absl::Status Encapsulator::AddEdgesToOutputGraph( edges_added; for (const Edge* edge : graph_in_->edges()) { - string src_func_id; + std::string src_func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id)); - string dst_func_id; + std::string dst_func_id; TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id)); // Ignore edges that are strictly contained within one subgraph, unless @@ -1091,7 +1092,7 @@ absl::Status Encapsulator::BuildOutputGraph( } // anonymous namespace absl::Status EncapsulateSubgraphsInFunctions( - string group_attribute, const Graph& graph_in, + std::string group_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { Encapsulator encapsulator(std::move(group_attribute), diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 0c7729f67349b5..ed2c9ef45a2c16 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -73,7 +73,7 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 1e05ad067def7f..94b136a02b99cf 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -46,7 +46,7 @@ 
const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; absl::Status AddGraphDefToFunctionLibrary( - const GraphDefBuilder& graphdef_builder, const string& name_suffix, + const GraphDefBuilder& graphdef_builder, const std::string& name_suffix, FunctionDefLibrary* library) { GraphDef graphdef; TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef)); @@ -64,13 +64,14 @@ absl::Status AddGraphDefToFunctionLibrary( } template -bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, - const ::tensorflow::protobuf::Map& b, - const std::function& key_to_string, - const std::function& value_to_string, - const std::function& compare, - const string& map_name, string* diff) { +bool EqualProtoMap( + const ::tensorflow::protobuf::Map& a, + const ::tensorflow::protobuf::Map& b, + const std::function& key_to_string, + const std::function& value_to_string, + const std::function& + compare, + const std::string& map_name, std::string* diff) { for (const auto& elt_a : a) { const auto iter = b.find(elt_a.first); if (iter == b.end()) { @@ -106,7 +107,7 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, } bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, - const string& diff_preamble, string* diff) { + const std::string& diff_preamble, std::string* diff) { if (a.op() != b.op()) { if (diff) { *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), @@ -131,8 +132,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, } return false; } - std::unordered_set control_input_a; - std::unordered_set control_input_b; + std::unordered_set control_input_a; + std::unordered_set control_input_b; for (int i = 0; i < a.input_size(); ++i) { if (absl::StartsWith(a.input(i), "^")) { if (!absl::StartsWith(b.input(i), "^")) { @@ -164,17 +165,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, } return false; } - return EqualProtoMap( - a.attr(), b.attr(), [](const string& s) { return s; }, + return EqualProtoMap( + a.attr(), b.attr(), [](const std::string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, - [](const string& key, const AttrValue& av, const AttrValue& bv) { + [](const std::string& key, const AttrValue& av, const AttrValue& bv) { if (key == "ancestors") { // The ancestors are added from a set so the order is unpredictable; // just compare set equality not list equality. 
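EqualProtoMap above parameterizes map comparison over key/value printers and a comparator, so one helper serves both attr maps and ret maps and can special-case keys like "ancestors". A minimal sketch of the same shape over std::map rather than protobuf maps:

#include <functional>
#include <map>
#include <string>

// Returns true if the maps are equal under `compare`; on mismatch, writes a
// human-readable description of the first difference into *diff.
template <typename K, typename V>
bool EqualMap(const std::map<K, V>& a, const std::map<K, V>& b,
              const std::function<std::string(const K&)>& key_to_string,
              const std::function<std::string(const V&)>& value_to_string,
              const std::function<bool(const K&, const V&, const V&)>& compare,
              std::string* diff) {
  for (const auto& elt_a : a) {
    auto it = b.find(elt_a.first);
    if (it == b.end()) {
      if (diff) *diff = "missing key " + key_to_string(elt_a.first);
      return false;
    }
    if (!compare(elt_a.first, elt_a.second, it->second)) {
      if (diff)
        *diff = "value mismatch for " + key_to_string(elt_a.first) + ": " +
                value_to_string(elt_a.second) + " vs " +
                value_to_string(it->second);
      return false;
    }
  }
  if (a.size() != b.size()) {
    if (diff) *diff = "maps differ in size";
    return false;
  }
  return true;
}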
- std::unordered_set a_set(av.list().s().begin(), - av.list().s().end()); - std::unordered_set b_set(bv.list().s().begin(), - bv.list().s().end()); + std::unordered_set a_set(av.list().s().begin(), + av.list().s().end()); + std::unordered_set b_set(bv.list().s().begin(), + bv.list().s().end()); return a_set == b_set; } else { return av.DebugString() == bv.DebugString(); @@ -184,7 +185,7 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, } bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, - string* diff) { + std::string* diff) { if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { *diff = @@ -194,22 +195,21 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } return false; } - if (!EqualProtoMap( - a.attr(), b.attr(), [](const string& s) { return s; }, + if (!EqualProtoMap( + a.attr(), b.attr(), [](const std::string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, - [](const string& key, const AttrValue& av, const AttrValue& bv) { + [](const std::string& key, const AttrValue& av, const AttrValue& bv) { return av.DebugString() == bv.DebugString(); }, absl::StrCat("attr mismatch for function ", a.signature().name()), diff)) { return false; } - if (!EqualProtoMap( - a.ret(), b.ret(), [](const string& s) { return s; }, - [](const string& s) { return s; }, - [](const string& key, const string& av, const string& bv) { - return av == bv; - }, + if (!EqualProtoMap( + a.ret(), b.ret(), [](const std::string& s) { return s; }, + [](const std::string& s) { return s; }, + [](const std::string& key, const std::string& av, + const std::string& bv) { return av == bv; }, absl::StrCat("ret mismatch for function ", a.signature().name()), diff)) { return false; @@ -257,8 +257,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, - const FunctionDefLibrary& actual, string* diff) { - std::unordered_map actual_index; + const FunctionDefLibrary& actual, + std::string* diff) { + std::unordered_map actual_index; for (const FunctionDef& function : actual.function()) { actual_index[function.signature().name()] = &function; } @@ -343,7 +344,7 @@ REGISTER_OP("AddNLikeTest") .SetIsAggregate(); Node* Sequencer(const GraphDefBuilder::Options& opts, - const string& call_node_name) { + const std::string& call_node_name) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("NoOp"), "NoOp", opts.op_registry()); @@ -383,7 +384,7 @@ Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) { return KnownShapeBase(DT_STRING, {2}, opts); } -Node* KeyPlaceholder(const string& call_node, +Node* KeyPlaceholder(const std::string& call_node, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(absl::StrCat(call_node, "_key_placeholder"), @@ -396,15 +397,16 @@ Node* KeyPlaceholder(const string& call_node, .FinalizeBuilder(&node_builder); } -Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, - const string& new_func_name, const string& oc_cluster, +Node* RecvAtHost(ops::NodeOut key_input, const std::string& cluster, + const std::string& new_func_name, + const std::string& oc_cluster, absl::Span dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = absl::StrCat("host_compute_channel_", cluster, "_", - new_func_name, "_", oc_cluster); - string name = absl::StrCat("outside_compilation_", cluster, "_", - new_func_name, 
"_", oc_cluster, "_recv"); + std::string key = absl::StrCat("host_compute_channel_", cluster, "_", + new_func_name, "_", oc_cluster); + std::string name = absl::StrCat("outside_compilation_", cluster, "_", + new_func_name, "_", oc_cluster, "_recv"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); node_builder.Input(std::move(key_input)); @@ -416,15 +418,16 @@ Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, .FinalizeBuilder(&node_builder); } -Node* SendFromHost(ops::NodeOut key_input, const string& cluster, - const string& new_func_name, const string& oc_cluster, +Node* SendFromHost(ops::NodeOut key_input, const std::string& cluster, + const std::string& new_func_name, + const std::string& oc_cluster, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = absl::StrCat("host_compute_channel_", cluster, "_", - new_func_name, "_", oc_cluster); - string name = absl::StrCat("outside_compilation_", cluster, "_", - new_func_name, "_", oc_cluster, "_send"); + std::string key = absl::StrCat("host_compute_channel_", cluster, "_", + new_func_name, "_", oc_cluster); + std::string name = absl::StrCat("outside_compilation_", cluster, "_", + new_func_name, "_", oc_cluster, "_send"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); @@ -477,8 +480,9 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) { return opts.FinalizeBuilder(&node_builder); } -absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, - const std::vector& encapsulated_functions) { +absl::Status Encapsulate( + GraphDef* graphdef, FunctionDefLibrary* library, + const std::vector& encapsulated_functions) { absl::Status s; // Convert the GraphDef to a Graph std::unique_ptr lib_def( @@ -512,7 +516,7 @@ absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, &graph_out, lib_def.get()); if (!s.ok()) return s; - std::unordered_map clusters; + std::unordered_map clusters; for (const auto& func : encapsulated_functions) { Node* xla_computation_node; for (Node* n : graph_out->nodes()) { @@ -527,7 +531,7 @@ absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, func_name_attrs.set_name(func); clusters.emplace(func, XlaClusterInfo{func, func_name_attrs, xla_computation_node, - std::map{}}); + std::map{}}); } bool modified; s = ExtractOutsideCompilation("_encapsulate", "_outside", clusters, @@ -551,7 +555,7 @@ absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, } absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { - std::vector encapsulated_functions; + std::vector encapsulated_functions; return Encapsulate(graphdef, library, encapsulated_functions); } @@ -698,8 +702,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { } // Returns a vector of node names in 'graph', sorted by name. -std::vector GraphNodes(const Graph& graph) { - std::vector nodes; +std::vector GraphNodes(const Graph& graph) { + std::vector nodes; for (const auto& node : graph.nodes()) { if (!node->IsSource() && !node->IsSink()) { nodes.push_back(node->name()); @@ -710,8 +714,9 @@ std::vector GraphNodes(const Graph& graph) { } // Returns a sorted vector of (src, dst) edges in 'graph'. 
-std::vector> GraphEdges(const Graph& graph) { - std::vector> edges; +std::vector> GraphEdges( + const Graph& graph) { + std::vector> edges; for (const Edge* edge : graph.edges()) { if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; edges.emplace_back( @@ -742,10 +747,11 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); - std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; + std::vector expected_nodes = {"cluster1", "cluster2", "mul", + "x"}; EXPECT_EQ(expected_nodes, GraphNodes(*graph)); - std::vector> expected_edges = { + std::vector> expected_edges = { {"cluster1:0", "cluster2:0"}, {"cluster1:0", "mul:0"}, {"cluster2:0", "mul:1"}, @@ -753,7 +759,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } -const Node* FindNodeByName(const Graph& graph, const string& name) { +const Node* FindNodeByName(const Graph& graph, const std::string& name) { for (const Node* node : graph.nodes()) { if (node->name() == name) return node; } @@ -889,7 +895,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -931,7 +937,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"C:o:0", "c:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -941,7 +947,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"c"}}, @@ -1025,7 +1031,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1102,7 +1108,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"F:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"send_key", ""}, {"recv_key", ""}, @@ -1112,8 +1118,9 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node", - "outside_compilation_O1_host_compute"})}, + absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}}, {"F", "outside_compilation_O1_host_compute"}}, @@ -1122,7 +1129,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"C:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1132,7 +1139,7 @@ 
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"D"}}, @@ -1235,7 +1242,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1", "F2"}; + std::vector encapsulated_functions{"F1", "F2"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1262,7 +1269,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"C:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1273,7 +1280,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"D"}}, @@ -1295,7 +1302,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"d_0_arg", "G:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_F2_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1306,7 +1313,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1409,7 +1416,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1", "F2"}; + std::vector encapsulated_functions{"F1", "F2"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1432,7 +1439,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"C:o:0", "D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1443,7 +1450,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"D"}}, @@ -1462,7 +1469,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"G:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_F2_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1473,7 +1480,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", 
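The repeated absl::Span<const std::string>({...}) expressions in these expectations construct a span over a temporary backing array that lives for the duration of the enclosing call. A minimal sketch of that idiom:

#include <cstddef>
#include <iostream>
#include <string>
#include "absl/types/span.h"

std::size_t CountTokens(absl::Span<const std::string> tokens) {
  return tokens.size();
}

int main() {
  // The braced list materializes a temporary array backing the span.
  std::cout << CountTokens({"_xla_token_arg_node",
                            "outside_compilation_O1_host_compute"})
            << "\n";  // prints 2
}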
"O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1556,7 +1563,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1578,7 +1585,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1589,7 +1596,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1652,7 +1659,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1674,7 +1681,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1685,7 +1692,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"D"}}, @@ -1748,7 +1755,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1785,7 +1792,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1795,7 +1802,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1858,7 +1865,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -1899,7 +1906,7 @@ TEST(EncapsulateSubgraphsTest, 
OutsideCompilationControlOutput) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -1909,7 +1916,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -1978,7 +1985,7 @@ TEST(EncapsulateSubgraphsTest, TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2037,7 +2044,7 @@ TEST(EncapsulateSubgraphsTest, {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2047,7 +2054,7 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, {{"outside_compilation_O2_host_compute"}, @@ -2055,7 +2062,7 @@ TEST(EncapsulateSubgraphsTest, {"F:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"send_key", ""}, {"recv_key", ""}, @@ -2065,8 +2072,9 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node", - "outside_compilation_O1_host_compute"})}, + absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}}, {"outside_compilation_O1_host_compute"}}, @@ -2149,7 +2157,7 @@ TEST(EncapsulateSubgraphsTest, TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2189,7 +2197,7 @@ TEST(EncapsulateSubgraphsTest, {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"send_key", ""}, {"recv_key", ""}, @@ -2199,8 +2207,9 @@ TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node", - "outside_compilation_O1_host_compute"})}, + absl::Span( + {"_xla_token_arg_node", + "outside_compilation_O1_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}}, {"outside_compilation_O1_host_compute"}}, @@ -2209,7 +2218,7 @@ TEST(EncapsulateSubgraphsTest, {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2219,7 +2228,7 @@ 
TEST(EncapsulateSubgraphsTest, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -2303,7 +2312,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2340,7 +2349,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2350,7 +2359,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, {{"outside_compilation_O2_host_compute"}, @@ -2358,7 +2367,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"send_key", ""}, {"recv_key", ""}, @@ -2368,7 +2377,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", - absl::Span( + absl::Span( {"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O2_host_compute"}}, {"outside_compilation_O1_host_compute"}}, @@ -2377,7 +2386,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"D:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O3"}, {"send_key", ""}, {"recv_key", ""}, @@ -2387,9 +2396,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node", - "outside_compilation_O1_host_compute", - "outside_compilation_O2_host_compute"})}, + absl::Span( + {"_xla_token_arg_node", "outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}, {"_xla_original_oc_node_name", "outside_compilation_O3_host_compute"}}, {"outside_compilation_O1_host_compute", "outside_compilation_O2_host_compute"}}}, @@ -2470,7 +2479,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2507,7 +2516,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { {"a_0_arg"}, {{"Tinputs", absl::Span({DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, 
{"send_key", ""}, {"recv_key", ""}, @@ -2517,7 +2526,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}}, }, @@ -2586,7 +2595,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } - std::vector encapsulated_functions{"F1"}; + std::vector encapsulated_functions{"F1"}; TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions)); FunctionDefLibrary library_expected; @@ -2627,7 +2636,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"c_0_arg", "c:o:0"}, {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"Toutputs", absl::Span({DT_FLOAT})}, - {"ancestors", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"send_key", ""}, {"recv_key", ""}, @@ -2637,7 +2646,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", - absl::Span({"_xla_token_arg_node"})}, + absl::Span({"_xla_token_arg_node"})}, {"_xla_original_oc_node_name", "outside_compilation_O1_host_compute"}}, {"c"}}, diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index fa94a341bbabc6..445ca63c05ad66 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -36,7 +36,8 @@ namespace { // Returns string attribute value for the node if the attribute is present, // otherwise returns empty optional value. -std::optional GetStringAttr(const Node& n, const string& attr_name) { +std::optional GetStringAttr(const Node& n, + const std::string& attr_name) { auto attr = n.attrs().Find(attr_name); if (!attr) { return std::nullopt; @@ -47,8 +48,8 @@ std::optional GetStringAttr(const Node& n, const string& attr_name) { // Adds a value to the node's list attribute. template -absl::Status AppendToListAttr(Node* n, const string& attr_name, - const string& value) { +absl::Status AppendToListAttr(Node* n, const std::string& attr_name, + const std::string& value) { std::vector attr_value; absl::Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value); if (!s.ok() && s.code() != error::NOT_FOUND) { @@ -63,7 +64,7 @@ absl::Status AppendToListAttr(Node* n, const string& attr_name, // Replaces attribute value. template -void ReplaceAttr(Node* n, const string& attr_name, const T& value) { +void ReplaceAttr(Node* n, const std::string& attr_name, const T& value) { n->ClearAttr(attr_name); n->AddAttr(attr_name, value); } @@ -71,7 +72,7 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) { // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. absl::Status PreprocessControlEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { // Gather edges to remove. We should not remove the edge while iterating. std::vector edges_to_remove; for (const Edge* e : g->edges()) { @@ -89,7 +90,7 @@ absl::Status PreprocessControlEdgesBetweenOutsideCompilations( // Case 1a: outside compilation to outside compilation control edge. 
edges_to_remove.push_back(e); - TF_RETURN_IF_ERROR(AppendToListAttr( + TF_RETURN_IF_ERROR(AppendToListAttr( e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName, e->src()->name())); } @@ -111,7 +112,7 @@ absl::Status PreprocessControlEdgesBetweenOutsideCompilations( // Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. absl::Status PreprocessDataEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { // Gather edges between outside compilation and host computation. Notice that // we do not store `Edge*` directly because we remove some nodes while adding // Identity nodes, and those Edge pointers might be invalidated. @@ -138,7 +139,7 @@ absl::Status PreprocessDataEdgesBetweenOutsideCompilations( // Remove the edge from host to outside compilation. Add a placeholder as // outside compilation node input. - std::map, Node*> placeholders; + std::map, Node*> placeholders; for (int i = 0, end = edges.size(); i < end; i++) { Node* dst = g->FindNodeId(edges[i].dst_node_id); const Edge* e; @@ -148,7 +149,7 @@ absl::Status PreprocessDataEdgesBetweenOutsideCompilations( g->RemoveEdge(e); // Find or create placeholder node. - string new_name = + std::string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output); auto placeholder_index = std::make_pair(src->name(), src_output); auto iter = placeholders.find(placeholder_index); @@ -156,7 +157,7 @@ absl::Status PreprocessDataEdgesBetweenOutsideCompilations( if (iter == placeholders.end()) { NodeDefBuilder placeholder_builder(new_name, "Placeholder"); placeholder_builder.Attr("dtype", src->output_type(src_output)); - string outside_compilation_attr; + std::string outside_compilation_attr; TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(), outside_compilation_attr_name, &outside_compilation_attr)); @@ -195,7 +196,7 @@ absl::Status PreprocessDataEdgesBetweenOutsideCompilations( // Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of // `PostprocessEdgesBetweenOutsideCompilations` for details. absl::Status PostprocessDataEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { // Gather all outside compilation to outside compilation nodes. std::vector placeholder_nodes; for (Node* n : g->nodes()) { @@ -208,7 +209,7 @@ absl::Status PostprocessDataEdgesBetweenOutsideCompilations( // Remove the placeholder nodes, and reconnect original edge. auto node_name_index = g->BuildNodeNameIndex(); for (auto n : placeholder_nodes) { - string node_name; + std::string node_name; int node_src_output; TF_RETURN_IF_ERROR(GetNodeAttr( n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name)); @@ -271,12 +272,12 @@ absl::Status PostprocessDataEdgesBetweenOutsideCompilations( // Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of // `PostprocessEdgesBetweenOutsideCompilations` for details. absl::Status PostprocessControlEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { auto node_name_index = g->BuildNodeNameIndex(); // Reconnect outside compilation to outside compilation control edge. 
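The data-edge preprocessing above creates at most one Placeholder per (source node, output slot) pair, memoized in a std::map keyed by that pair. A sketch of the find-or-create step over a hypothetical Node stand-in; in TF the graph owns the node, so the raw new here is purely illustrative:

#include <map>
#include <string>
#include <utility>
#include "absl/strings/str_cat.h"

struct Node { std::string name; };  // stand-in for tensorflow::Node

Node* FindOrCreatePlaceholder(
    std::map<std::pair<std::string, int>, Node*>& placeholders,
    const std::string& src_name, int src_output) {
  auto key = std::make_pair(src_name, src_output);
  auto it = placeholders.find(key);
  if (it != placeholders.end()) return it->second;  // reuse existing node
  // Same naming scheme as the hunk above.
  Node* n =
      new Node{absl::StrCat(src_name, "_oc_to_oc_placeholder_", src_output)};
  placeholders.emplace(key, n);
  return n;
}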
for (Node* n : g->nodes()) { - std::vector control_deps; + std::vector control_deps; absl::Status s = GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName, &control_deps); @@ -288,7 +289,7 @@ absl::Status PostprocessControlEdgesBetweenOutsideCompilations( } } else { n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName); - for (const string& control_input : control_deps) { + for (const std::string& control_input : control_deps) { auto iter = node_name_index.find(control_input); if (iter == node_name_index.end()) { return errors::Internal("Cannot find original node for ", @@ -342,11 +343,11 @@ absl::Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { } absl::StatusOr< - std::unique_ptr>>> + std::unique_ptr>>> OutsideCompilationClusterDependencies( - const Graph* g, const string& outside_compilation_attr_name) { + const Graph* g, const std::string& outside_compilation_attr_name) { auto cluster_deps = std::make_unique< - absl::flat_hash_map>>(); + absl::flat_hash_map>>(); for (const Edge* e : g->edges()) { auto src_outside_compilation = @@ -360,18 +361,18 @@ OutsideCompilationClusterDependencies( if (dst_deps_it == cluster_deps->end()) { cluster_deps->insert(std::make_pair( *dst_outside_compilation, - absl::flat_hash_set({*src_outside_compilation}))); + absl::flat_hash_set({*src_outside_compilation}))); } else { dst_deps_it->second.insert(*src_outside_compilation); } } } - auto cluster_deps_ordered = - std::make_unique>>(); + auto cluster_deps_ordered = std::make_unique< + absl::flat_hash_map>>(); for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) { - std::vector ordered_deps(it->second.begin(), it->second.end()); + std::vector ordered_deps(it->second.begin(), it->second.end()); std::sort(ordered_deps.begin(), ordered_deps.end()); cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps)); } @@ -380,7 +381,7 @@ OutsideCompilationClusterDependencies( } absl::Status PreprocessEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges // from outside compilation nodes to sink node. std::vector edges_to_remove; @@ -406,7 +407,7 @@ absl::Status PreprocessEdgesBetweenOutsideCompilations( } absl::Status PostprocessEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations( g, outside_compilation_attr_name)); TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations( diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index 7c99763c770728..81ab31c79dcda2 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -95,21 +95,21 @@ struct XlaClusterInfo { // without losing aggregate initialization, which allows us to get rid of // the constructor definitions again. XlaClusterInfo() {} - XlaClusterInfo(const string& cluster_name, + XlaClusterInfo(const std::string& cluster_name, const NameAttrList& func_name_attrs, Node* node, - const std::map& host_compute_core) + const std::map& host_compute_core) : cluster_name(cluster_name), func_name_attrs(func_name_attrs), node(node), host_compute_core(host_compute_core) {} // XLA cluster name. It might be different from `func_name`. 
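OutsideCompilationClusterDependencies above finishes by copying each dependency set into a sorted vector: flat_hash_set iteration order is unspecified, and downstream graph rewrites need deterministic output. A compact sketch of that determinism step:

#include <algorithm>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"

absl::flat_hash_map<std::string, std::vector<std::string>> OrderDeps(
    const absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>>&
        deps) {
  absl::flat_hash_map<std::string, std::vector<std::string>> ordered;
  for (const auto& [cluster, dep_set] : deps) {
    std::vector<std::string> v(dep_set.begin(), dep_set.end());
    std::sort(v.begin(), v.end());  // fix an order before anyone consumes it
    ordered.emplace(cluster, std::move(v));
  }
  return ordered;
}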
- const string cluster_name; + const std::string cluster_name; // Name and attributes of XLA computation function. const NameAttrList func_name_attrs; // The XLA computation node in the graph. Node* node; // A mapping from outside compilation cluster name to its device assignment. - const std::map host_compute_core; + const std::map host_compute_core; }; // Finds dependencies between outside compilation clusters, including both data @@ -117,9 +117,9 @@ struct XlaClusterInfo { // outside compilation cluster to a set of names of outside compilation clusters // that it depends on. absl::StatusOr< - std::unique_ptr>>> + std::unique_ptr>>> OutsideCompilationClusterDependencies( - const Graph* g, const string& outside_compilation_attr_name); + const Graph* g, const std::string& outside_compilation_attr_name); // Preprocesses edges within the same XLA cluster. It will perform the following // operations in order: @@ -135,7 +135,7 @@ OutsideCompilationClusterDependencies( // 2. For data edges between different outside compilations, remove the edge // and create a Placeholder node as dst node's input. absl::Status PreprocessEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name); + Graph* g, const std::string& outside_compilation_attr_name); // Postprocesses edges within the same XLA cluster. This function reverts what // `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the @@ -149,7 +149,7 @@ absl::Status PreprocessEdgesBetweenOutsideCompilations( // `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. // They are handled in `RewriteOutsideCompilationSubgraphFn`. absl::Status PostprocessEdgesBetweenOutsideCompilations( - Graph* g, const string& outside_compilation_attr_name); + Graph* g, const std::string& outside_compilation_attr_name); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 0e59bf0c19d93e..8ba11404010363 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -46,7 +46,7 @@ const char* const kXlaClusterOutput = "XlaClusterOutput"; bool IsCpuGpuCompile(const Graph* graph) { for (Node* n : graph->nodes()) { - string name; + std::string name; // Only consider nodes being compiled. if (!TryGetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name)) continue; // Early return for any node with a device that is not a CPU or GPU. @@ -185,7 +185,7 @@ absl::Status RewriteSubgraph( // Uniquify the function name by computing a fingerprint of the function. // Nondeterminism in serialization would not lead to incorrect results, but // may cause spurious cache misses. 
- TF_ASSIGN_OR_RETURN(uint64 fingerprint, FingerprintGraph(*graph)); + TF_ASSIGN_OR_RETURN(uint64_t fingerprint, FingerprintGraph(*graph)); VLOG(1) << "Subgraph fingerprint:" << fingerprint; call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint)); return absl::OkStatus(); @@ -360,7 +360,8 @@ absl::Status RewriteSubgraph( /*static*/ absl::Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( Graph* graph) { const auto is_xla_launch_node = [](const Node& node) -> absl::StatusOr { - const string& name = GetNodeAttrString(node.attrs(), kXlaClusterIdAttr); + const std::string& name = + GetNodeAttrString(node.attrs(), kXlaClusterIdAttr); return !name.empty(); }; diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index 16a17c3c2a03a6..acd5319cf8ed16 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -34,7 +34,7 @@ limitations under the License. namespace tensorflow { static std::unique_ptr MakeOuterGraph( - const FunctionLibraryDefinition& flib_def, const string& function) { + const FunctionLibraryDefinition& flib_def, const std::string& function) { Scope scope = Scope::NewRootScope().ExitOnError(); TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto())); @@ -143,7 +143,7 @@ TEST(EncapsulateXlaComputations, DeterministicEncapsulate) { // Test that control edge insertion order doesn't affect the cache key // (cluster name) generated by TPU encapsulate pass. auto get_serialized_graph = [](bool control_input_reversed, - bool operand_reversed) -> string { + bool operand_reversed) -> std::string { FunctionLibraryDefinition flib_def(OpRegistry::Global(), FunctionDefLibrary()); std::unique_ptr graph(new Graph(&flib_def)); @@ -250,8 +250,8 @@ TEST(EncapsulateXlaComputations, Encapsulate) { TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def)); - std::unordered_map index = graph->BuildNodeNameIndex(); - string function = index.at("launch0")->type_string(); + std::unordered_map index = graph->BuildNodeNameIndex(); + std::string function = index.at("launch0")->type_string(); // Tests the outer graph is as expected. { @@ -285,9 +285,9 @@ TEST(EncapsulateXlaComputations, Encapsulate) { // function. Encapsulation should be deterministic to avoid recompilation. TF_ASSERT_OK( EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def)); - std::unordered_map index_copy = + std::unordered_map index_copy = graph_copy->BuildNodeNameIndex(); - string function_copy = index_copy.at("launch0")->type_string(); + std::string function_copy = index_copy.at("launch0")->type_string(); EXPECT_EQ(function, function_copy); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 140c47dbcac804..05514f00bd29d5 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -42,7 +42,7 @@ namespace { // Control return mapping function for outside compilation host graphs. // All nodes with kXlaHasHostTransfer attribute are control outputs. 
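RewriteSubgraph above appends a graph fingerprint to the function name so that identical subgraphs map to the same compilation cache key, while nondeterministic serialization costs at most a spurious cache miss. A loose sketch of the naming step, with std::hash standing in for TF's FingerprintGraph:

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include "absl/strings/str_cat.h"

std::string UniquifiedName(const std::string& base,
                           const std::string& serialized_graph) {
  // std::hash is only a stand-in for a real graph fingerprint.
  uint64_t fingerprint = std::hash<std::string>{}(serialized_graph);
  return absl::StrCat(base, "_", fingerprint);
}

int main() {
  std::cout << UniquifiedName("cluster_0", "node{a->b}") << "\n";
}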
-std::optional HostGraphControlRetMapping(const Node* n) { +std::optional HostGraphControlRetMapping(const Node* n) { if (HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { return n->name(); } @@ -52,7 +52,7 @@ std::optional HostGraphControlRetMapping(const Node* n) { // Add a key placeholder node to the graph. The key placeholder node will be // used as input for XlaRecvAtHost/XlaSendFromHost nodes. absl::StatusOr AddHostComputeKeyPlaceholder( - const string& xla_cluster_name, Graph* g) { + const std::string& xla_cluster_name, Graph* g) { NodeDef key_def; NodeDefBuilder builder(absl::StrCat(xla_cluster_name, "_key_placeholder"), "Placeholder"); @@ -74,7 +74,8 @@ bool IsKeyPlaceholderNode(const Node& n) { } // Returns nodes with given type. -std::vector GatherNodesWithType(const Graph& g, const string& type) { +std::vector GatherNodesWithType(const Graph& g, + const std::string& type) { std::vector result; for (Node* n : g.nodes()) { if (n->type_string() == type) { @@ -105,7 +106,7 @@ absl::Status GetArgDataTypes(const std::vector& arg_nodes, // Builds XlaRecvAtHost node. absl::StatusOr BuildRecvAtHostNode( - Graph* g, const string& oc_cluster_name, + Graph* g, const std::string& oc_cluster_name, const std::vector& recv_at_host_dtypes, Node* key_placeholder) { NodeDefBuilder recv_at_host_builder( absl::StrCat("outside_compilation_", oc_cluster_name, "_recv"), @@ -128,7 +129,7 @@ absl::StatusOr BuildRecvAtHostNode( // Builds XlaRecvAtHost node, and replaces all _Arg nodes with it. absl::StatusOr ReplaceArgNodesWithRecvAtHostNode( - Graph* g, const string& oc_cluster_name, + Graph* g, const std::string& oc_cluster_name, std::vector* recv_at_host_dtypes, Node* key_placeholder) { // TODO(b/77601805): use out nodes for source node, instead of traversing all // nodes. @@ -205,7 +206,7 @@ absl::Status GetRetDataTypes(const std::vector& ret_nodes, // Builds XlaSendFromHost node. absl::StatusOr BuildSendFromHostNode( - Graph* g, const string& oc_cluster_name, + Graph* g, const std::string& oc_cluster_name, const std::vector& ret_nodes, const std::vector& send_from_host_dtypes, Node* key_placeholder) { NodeDefBuilder send_from_host_builder( @@ -245,7 +246,7 @@ absl::StatusOr BuildSendFromHostNode( // Builds XlaSendFromHost node, and replaces all _Retval nodes with it. absl::StatusOr ReplaceRetNodesWithSendFromHostNode( - Graph* g, const string& oc_cluster_name, + Graph* g, const std::string& oc_cluster_name, std::vector* send_from_host_dtypes, Node* key_placeholder) { // TODO(b/77601805): use in nodes for sink node, instead of traversing all // nodes. @@ -299,16 +300,17 @@ std::optional> GetInferredInputShapes( return results; } -string host_compute_node_name(const string& original_oc_name) { +std::string host_compute_node_name(const std::string& original_oc_name) { return absl::StrCat("outside_compilation_", original_oc_name, "_host_compute"); } // Builds XlaHostCompute NodeDef from the outside compilation call node. 
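GatherNodesWithType above is a plain filter over the graph's nodes. A minimal sketch with a stub Node type, just to make the contract concrete:

#include <string>
#include <vector>

struct Node {
  std::string type_string;  // stand-in for Node::type_string()
};

std::vector<Node*> GatherNodesWithType(std::vector<Node>& graph,
                                       const std::string& type) {
  std::vector<Node*> result;
  for (Node& n : graph) {
    if (n.type_string == type) result.push_back(&n);
  }
  return result;
}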
absl::StatusOr BuildXlaHostComputeNodeDef( - const Node* call_node, const std::map& host_compute_core, - const absl::flat_hash_map>& cluster_deps) { - string original_oc_name; + const Node* call_node, const std::map& host_compute_core, + const absl::flat_hash_map>& + cluster_deps) { + std::string original_oc_name; TF_RETURN_IF_ERROR(GetNodeAttr( call_node->attrs(), "_outside_compilation_subgraph", &original_oc_name)); NodeDefBuilder host_compute_builder(host_compute_node_name(original_oc_name), @@ -341,7 +343,7 @@ absl::StatusOr BuildXlaHostComputeNodeDef( // according to their host-side graph dependency. This can cause deadlock. // Therefore, we hint XLA what the correct ordering of these clusters should // be to avoid deadlocks. - std::vector xla_token_input_nodes; + std::vector xla_token_input_nodes; xla_token_input_nodes.emplace_back(kXlaTokenArgNodeName); auto cluster_deps_it = cluster_deps.find(original_oc_name); if (cluster_deps_it != cluster_deps.end()) { @@ -376,8 +378,10 @@ absl::StatusOr BuildXlaHostComputeNodeDef( // Replace outside compilation function call node with XlaHostCompute node. TF_ATTRIBUTE_NOINLINE absl::StatusOr ReplaceOutsideCompilationCallNode( - Graph* g, Node* call_node, const std::map& host_compute_core, - const absl::flat_hash_map>& cluster_deps) { + Graph* g, Node* call_node, + const std::map& host_compute_core, + const absl::flat_hash_map>& + cluster_deps) { // Build XlaHostCompute NodeDef. TF_ASSIGN_OR_RETURN( NodeDef node_def, @@ -405,8 +409,8 @@ absl::Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->ClearAttr("device_ordinal"); n->AddAttr("device_ordinal", device_ordinal_value); } else if (n->IsIfNode()) { - for (const string& attr_name : - std::vector{"then_branch", "else_branch"}) { + for (const std::string& attr_name : + std::vector{"then_branch", "else_branch"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; @@ -414,7 +418,8 @@ absl::Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->AddAttr(attr_name, branch_func); } } else if (n->IsWhileNode()) { - for (const string& attr_name : std::vector{"cond", "body"}) { + for (const std::string& attr_name : + std::vector{"cond", "body"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; @@ -448,11 +453,12 @@ bool HasLiftedArgs(const FunctionDef& function_def) { absl::StatusOr>> LiftedArgsAndOutsideCompilationNodesInFunctionBody( const FunctionBody& function_body, - const std::unordered_map& outside_compilation_attr_to_node) { + const std::unordered_map& + outside_compilation_attr_to_node) { std::vector> lifted_arg_nodes_and_outside_compilation_nodes; for (Node* n : function_body.graph->op_nodes()) { - string oc_cluster; + std::string oc_cluster; if (n->type_string() == "Placeholder" && GetNodeAttr(n->def(), kXlaLiftedArgOutsideCompilationAttrName, &oc_cluster) @@ -471,7 +477,7 @@ LiftedArgsAndOutsideCompilationNodesInFunctionBody( absl::StatusOr> UpdateTypesAttribute( const std::vector>& lifted_arg_nodes_and_outside_compilation_nodes, - const string& type_attr_name, Node* n) { + const std::string& type_attr_name, Node* n) { std::vector data_types; data_types.reserve(lifted_arg_nodes_and_outside_compilation_nodes.size()); TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), type_attr_name, &data_types)); @@ -578,7 +584,8 @@ absl::Status 
AddFunctionWithNewName(const std::string& new_name, // Reconnect outside compilation lifted arguments in a functional While node to // its outside compilation tensor sources. absl::Status PostprocessLiftedArgsForWhile( - const std::unordered_map& outside_compilation_attr_to_node, + const std::unordered_map& + outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { TF_RET_CHECK(n->IsWhileNode()); @@ -687,7 +694,8 @@ absl::Status PostprocessLiftedArgsForWhile( } absl::Status PostprocessLiftedArgsForIf( - const std::unordered_map& outside_compilation_attr_to_node, + const std::unordered_map& + outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { TF_RET_CHECK(n->IsIfNode()); @@ -826,7 +834,8 @@ absl::Status PostprocessLiftedArgsForIf( } absl::Status PostprocessLiftedArgsForCall( - const std::unordered_map& outside_compilation_attr_to_node, + const std::unordered_map& + outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { const FunctionDef* fdef = fld->Find(n->type_string()); TF_RET_CHECK(fdef); @@ -924,12 +933,12 @@ absl::Status PostprocessLiftedArgsForCall( // Creates a mapping from outside compilation cluster name to lifted argument // placeholder. -absl::StatusOr> OutsideCompilationAttrToNode( - const Graph& g) { - std::unordered_map outside_compilation_attr_to_node; +absl::StatusOr> +OutsideCompilationAttrToNode(const Graph& g) { + std::unordered_map outside_compilation_attr_to_node; for (Node* n : g.op_nodes()) { bool is_lifted_arg; - string outside_compilation_attr; + std::string outside_compilation_attr; if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) && TryGetNodeAttr(n->def(), "_xla_outside_compilation", &outside_compilation_attr)) { @@ -988,8 +997,9 @@ absl::Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { // replace this node with compilation result node. // 3) all outside compilation graphs. absl::Status ConstructHostGraph( - const string& xla_cluster_name, const string& outside_compilation_attr_name, - const std::vector& outside_compilation_host_graphs, + const std::string& xla_cluster_name, + const std::string& outside_compilation_attr_name, + const std::vector& outside_compilation_host_graphs, FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { host_graph->reset(new Graph(fld)); @@ -1013,7 +1023,7 @@ absl::Status ConstructHostGraph( // XlaSendFromHost, If/While nodes containing // XlaRecvAtHost/XlaSendFromHost) to sequencer node. // c) Clear node_def.device(), so device placer won't get confused. - for (const string& host_func : outside_compilation_host_graphs) { + for (const std::string& host_func : outside_compilation_host_graphs) { VLOG(4) << "Expanding host graph " << host_func; // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder // value after we expanded all host graphs. We cannot just use placeholder @@ -1021,7 +1031,7 @@ absl::Status ConstructHostGraph( // value for attributes. AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_attr; std::unique_ptr host_fbody; const FunctionDef* host_fdef = fld->Find(host_func); @@ -1123,18 +1133,17 @@ absl::Status ConstructHostGraph( // Expand XLA computation's outside compilation host side graph into main graph. // Add a control edge between sequencer node and the XLA computation node. 
-absl::Status ExpandHostGraphIntoMainGraph(Graph* main_graph, - FunctionLibraryDefinition* fld, - const string& host_graph_func_name, - Node* xla_computation_node, - Node* pivot_node) { +absl::Status ExpandHostGraphIntoMainGraph( + Graph* main_graph, FunctionLibraryDefinition* fld, + const std::string& host_graph_func_name, Node* xla_computation_node, + Node* pivot_node) { // Temporarily use "0" as "_device_ordinal". It will be rewritten with the // correct value in a later pass. We cannot just use placeholder value here // because FunctionDef instantiation does not allow placeholder value for // attributes. AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_attr; std::unique_ptr fbody; const FunctionDef* host_graph_func = fld->Find(host_graph_func_name); @@ -1207,12 +1216,12 @@ absl::Status ExpandHostGraphIntoMainGraph(Graph* main_graph, // 2) Remove control edges. // 3) Prune nodes that are not useful for shape inference. absl::Status RewriteShapeInferenceGraph( - const string& shape_inference_graph_name, Graph* host_graph, + const std::string& shape_inference_graph_name, Graph* host_graph, Node* pivot_node, FunctionLibraryDefinition* fld) { // Use "0" as "_device_ordinal". It does not matter for shape inference. AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_attr; std::unique_ptr fbody; const FunctionDef* shape_inference_graph = @@ -1338,13 +1347,13 @@ void SetMaximalSharding(NodeDefBuilder& node_builder) { // Builds XlaSendToHost node which sends cond predicate to host. TF_ATTRIBUTE_NOINLINE absl::StatusOr BuildSendIfPredNode( - const string& name, const string& host_transfer_key, Node* pred_node, - Graph* g) { + const std::string& name, const std::string& host_transfer_key, + Node* pred_node, Graph* g) { NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); send_pred_builder.Attr("Tinput", DT_BOOL); send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); send_pred_builder.Attr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + std::vector{kXlaTokenArgNodeName}); send_pred_builder.Attr(kXlaOriginalOutsideCompilationNodeName, name); SetMaximalSharding(send_pred_builder); send_pred_builder.Input(pred_node->name(), 0, DT_BOOL); @@ -1356,14 +1365,14 @@ TF_ATTRIBUTE_NOINLINE absl::StatusOr BuildSendIfPredNode( } // Replaces key placeholder node with an _Arg node. -absl::Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, - const string& func_name, - FunctionLibraryDefinition* fld) { +absl::Status ReplaceKeyPlaceholderWithArgNode( + const std::string& xla_cluster_name, const std::string& func_name, + FunctionLibraryDefinition* fld) { // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder // value after rewriting. AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_attr; std::unique_ptr fbody; const FunctionDef* func = fld->Find(func_name); @@ -1404,14 +1413,15 @@ absl::Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, // Builds host side graph for If node. 
TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForIfNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const string& if_node_name, const string& host_transfer_key, - const string& host_graph_func_name, FunctionLibraryDefinition* fld, - const string& then_branch_host_func_name, - const string& else_branch_host_func_name) { + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const std::string& if_node_name, + const std::string& host_transfer_key, + const std::string& host_graph_func_name, FunctionLibraryDefinition* fld, + const std::string& then_branch_host_func_name, + const std::string& else_branch_host_func_name) { Graph host_graph(fld); - string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); + std::string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); AttrValue device_ordinal_value; device_ordinal_value.set_placeholder("_device_ordinal"); @@ -1484,7 +1494,7 @@ TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForIfNode( // Rewrites loop cond to add a node which sends loop cond to host. TF_ATTRIBUTE_NOINLINE absl::Status AddSendLoopPredToLoopCond( - const string& cond_xla_func_name, const string& host_transfer_key, + const std::string& cond_xla_func_name, const std::string& host_transfer_key, NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld, Node* while_node) { // Instantiate the loop cond function. @@ -1523,7 +1533,7 @@ TF_ATTRIBUTE_NOINLINE absl::Status AddSendLoopPredToLoopCond( send_loop_cond_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); send_loop_cond_builder.Attr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + std::vector{kXlaTokenArgNodeName}); send_loop_cond_builder.Attr(kXlaOriginalOutsideCompilationNodeName, send_loop_cond_builder.node_name()); SetMaximalSharding(send_loop_cond_builder); @@ -1560,10 +1570,13 @@ TF_ATTRIBUTE_NOINLINE absl::Status AddSendLoopPredToLoopCond( // Rewrites while loop cond function for host. absl::Status RewriteHostWhileLoopCond( - const string& cond_host_func_name, const string& while_node_name, - const string& host_transfer_key, const string& xla_cluster_attr_name, - const string& xla_cluster_name, const string& outside_compilation_attr_name, - const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + const std::string& cond_host_func_name, const std::string& while_node_name, + const std::string& host_transfer_key, + const std::string& xla_cluster_attr_name, + const std::string& xla_cluster_name, + const std::string& outside_compilation_attr_name, + const std::string& outside_compilation_name, + FunctionLibraryDefinition* fld) { // Replace key placeholder node with _Arg node. TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( xla_cluster_name, cond_host_func_name, fld)); @@ -1571,7 +1584,7 @@ absl::Status RewriteHostWhileLoopCond( // Instantiate cond function. AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_temp_value; std::unique_ptr cond_fbody; const FunctionDef* cond_host_func = fld->Find(cond_host_func_name); @@ -1634,10 +1647,13 @@ absl::Status RewriteHostWhileLoopCond( // Rewrites while loop body function for host. 
absl::Status RewriteHostWhileLoopBody( - const string& body_host_func_name, const string& while_node_name, - const string& host_transfer_key, const string& xla_cluster_attr_name, - const string& xla_cluster_name, const string& outside_compilation_attr_name, - const string& outside_compilation_name, FunctionLibraryDefinition* fld) { + const std::string& body_host_func_name, const std::string& while_node_name, + const std::string& host_transfer_key, + const std::string& xla_cluster_attr_name, + const std::string& xla_cluster_name, + const std::string& outside_compilation_attr_name, + const std::string& outside_compilation_name, + FunctionLibraryDefinition* fld) { // Replace key placeholder node with _Arg node. TF_RETURN_IF_ERROR(ReplaceKeyPlaceholderWithArgNode( xla_cluster_name, body_host_func_name, fld)); @@ -1645,7 +1661,7 @@ absl::Status RewriteHostWhileLoopBody( // Instantiate body function. AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map attrs; + protobuf::Map attrs; attrs["_device_ordinal"] = device_ordinal_temp_value; std::unique_ptr body_fbody; const FunctionDef* body_host_func = fld->Find(body_host_func_name); @@ -1692,13 +1708,16 @@ absl::Status RewriteHostWhileLoopBody( // Builds host side graph for while node. TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForWhileNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const string& while_node_name, const string& host_transfer_key, - const string& host_graph_func_name, FunctionLibraryDefinition* fld, - const string& cond_host_func_name, const string& body_host_func_name) { + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const std::string& while_node_name, + const std::string& host_transfer_key, + const std::string& host_graph_func_name, FunctionLibraryDefinition* fld, + const std::string& cond_host_func_name, + const std::string& body_host_func_name) { Graph host_graph(fld); - string outside_compilation_name = absl::StrCat("oc_while_", while_node_name); + std::string outside_compilation_name = + absl::StrCat("oc_while_", while_node_name); // Step 1: add key placeholder node. TF_ASSIGN_OR_RETURN( @@ -1759,10 +1778,12 @@ TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForWhileNode( // Builds host graph for func call nodes. 
absl::Status BuildHostGraphForFuncCallNode( - const string& xla_cluster_attr_name, const string& xla_cluster_name, - const string& outside_compilation_attr_name, - const string& func_call_node_name, const string& func_call_host_func_name, - const string& host_graph_func_name, FunctionLibraryDefinition* fld) { + const std::string& xla_cluster_attr_name, + const std::string& xla_cluster_name, + const std::string& outside_compilation_attr_name, + const std::string& func_call_node_name, + const std::string& func_call_host_func_name, + const std::string& host_graph_func_name, FunctionLibraryDefinition* fld) { Graph host_graph(fld); AttrValue device_ordinal_value; device_ordinal_value.set_placeholder("_device_ordinal"); @@ -1807,18 +1828,19 @@ absl::Status BuildHostGraphForFuncCallNode( } TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, Graph* g, Node* n, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, - std::vector* host_graphs, - std::vector* shape_inference_graphs, + std::vector* host_graphs, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { bool func_has_outside_compilation = false; NameAttrList func; if (fld->Contains(n->type_string())) { func.set_name(n->type_string()); - typedef protobuf::Map AttrMap; + typedef protobuf::Map AttrMap; *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end()); } else if (n->IsPartitionedCall()) { TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func)); @@ -1827,7 +1849,7 @@ TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( func.set_name(FunctionLibraryDefinition::kGradientOp); *func.mutable_attr() = n->def().attr(); } - string canonical_func_name; + std::string canonical_func_name; if (func.name() == FunctionLibraryDefinition::kGradientOp) { NameAttrList forward_func; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &forward_func)); @@ -1835,8 +1857,8 @@ TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( } else { canonical_func_name = func.name(); } - string new_func_name = absl::StrCat(canonical_func_name, "_oc"); - string host_func_name = + std::string new_func_name = absl::StrCat(canonical_func_name, "_oc"); + std::string host_func_name = absl::StrCat("oc_func_call_host_", canonical_func_name); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, @@ -1876,11 +1898,11 @@ TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get())); TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def)); replace->AddAttr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + std::vector{kXlaTokenArgNodeName}); replace->AddAttr(kXlaOriginalOutsideCompilationNodeName, replace->name()); // Build host side graph for the function call. 
- string oc_host_graph_name = + std::string oc_host_graph_name = absl::StrCat("oc_func_host_graph_", replace->name()); TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode( xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, @@ -1893,12 +1915,13 @@ TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( } absl::Status ExtractOutsideCompilationForIfNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, Graph* g, Node* n, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, - std::vector* host_graphs, - std::vector* shape_inference_graphs, + std::vector* host_graphs, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { // Instantiate "then_branch" and "else_branch". NameAttrList then_branch, else_branch; @@ -1908,12 +1931,14 @@ absl::Status ExtractOutsideCompilationForIfNode( // Extract outside compilation for then_branch and else_branch. bool then_branch_has_outside_compilation = false; bool else_branch_has_outside_compilation = false; - string then_branch_host_func_name = - absl::StrCat("oc_then_branch_host_if_", then_branch.name()), - else_branch_host_func_name = - absl::StrCat("oc_else_branch_host_if_", else_branch.name()); - string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), - else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); + std::string then_branch_host_func_name = + absl::StrCat("oc_then_branch_host_if_", then_branch.name()), + else_branch_host_func_name = + absl::StrCat("oc_else_branch_host_if_", else_branch.name()); + std::string then_branch_xla_func_name = + absl::StrCat(then_branch.name(), "_oc"), + else_branch_xla_func_name = + absl::StrCat(else_branch.name(), "_oc"); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, then_branch, then_branch_xla_func_name, then_branch_host_func_name, @@ -1946,7 +1971,7 @@ absl::Status ExtractOutsideCompilationForIfNode( } n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name()); - string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); + std::string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); // XLA computation: add a SendToHost node to send cond predicate. Node* pred_node; @@ -1956,7 +1981,7 @@ absl::Status ExtractOutsideCompilationForIfNode( BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), host_transfer_key, pred_node, g)); n->AddAttr(kXlaTokenInputNodesAttrName, - std::vector{send_pred_node->name()}); + std::vector{send_pred_node->name()}); // Add a control edge from `send_pred_node` to If node, so XlaCompiler will // visit If node after `send_pred_node`, thus the token output for @@ -1969,7 +1994,7 @@ absl::Status ExtractOutsideCompilationForIfNode( // we need to create a no-op host graph. 
if (!then_branch_has_outside_compilation) { std::unique_ptr then_branch_host_graph(new Graph(fld)); - std::vector then_branch_host_graphs; + std::vector then_branch_host_graphs; TF_RETURN_IF_ERROR(ConstructHostGraph( xla_cluster_name, outside_compilation_attr_name, then_branch_host_graphs, fld, &then_branch_host_graph)); @@ -1986,7 +2011,7 @@ absl::Status ExtractOutsideCompilationForIfNode( } if (!else_branch_has_outside_compilation) { std::unique_ptr else_branch_host_graph(new Graph(fld)); - std::vector else_branch_host_graphs; + std::vector else_branch_host_graphs; TF_RETURN_IF_ERROR(ConstructHostGraph( xla_cluster_name, outside_compilation_attr_name, else_branch_host_graphs, fld, &else_branch_host_graph)); @@ -2001,7 +2026,7 @@ absl::Status ExtractOutsideCompilationForIfNode( TF_RETURN_IF_ERROR(fld->AddFunctionDef(else_branch_host_fdef)); } } - string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); + std::string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, n->name(), host_transfer_key, oc_host_graph_name, fld, @@ -2012,12 +2037,13 @@ absl::Status ExtractOutsideCompilationForIfNode( } absl::Status ExtractOutsideCompilationForWhileNode( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, Graph* g, Node* n, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, - std::vector* host_graphs, - std::vector* shape_inference_graphs, + std::vector* host_graphs, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { // Instantiate "cond" and "body". NameAttrList cond, body; @@ -2027,10 +2053,12 @@ absl::Status ExtractOutsideCompilationForWhileNode( // Extract outside compilation for cond and body. bool cond_has_outside_compilation = false; bool body_has_outside_compilation = false; - string cond_host_func_name = absl::StrCat("oc_cond_host_while_", cond.name()), - body_host_func_name = absl::StrCat("oc_body_host_while_", body.name()); - string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), - body_xla_func_name = absl::StrCat(body.name(), "_oc"); + std::string cond_host_func_name = + absl::StrCat("oc_cond_host_while_", cond.name()), + body_host_func_name = + absl::StrCat("oc_body_host_while_", body.name()); + std::string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), + body_xla_func_name = absl::StrCat(body.name(), "_oc"); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, @@ -2060,19 +2088,19 @@ absl::Status ExtractOutsideCompilationForWhileNode( } n->AddAttr(kXlaOriginalOutsideCompilationNodeName, n->name()); - string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); + std::string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); // XLA computation: rewrite cond function to add a SendToHost node to send // loop predicate. 
TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond( cond_xla_func_name, host_transfer_key, &cond, fld, n)); n->AddAttr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); + std::vector{kXlaTokenArgNodeName}); // Build host side graph for the "While" node. if (!cond_has_outside_compilation) { std::unique_ptr cond_host_graph(new Graph(fld)); - std::vector host_graphs; + std::vector host_graphs; TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, host_graphs, fld, &cond_host_graph)); @@ -2088,7 +2116,7 @@ absl::Status ExtractOutsideCompilationForWhileNode( } if (!body_has_outside_compilation) { std::unique_ptr body_host_graph(new Graph(fld)); - std::vector host_graphs; + std::vector host_graphs; TF_RETURN_IF_ERROR(ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, host_graphs, fld, &body_host_graph)); @@ -2102,7 +2130,8 @@ absl::Status ExtractOutsideCompilationForWhileNode( TF_RETURN_IF_ERROR(fld->AddFunctionDef(body_host_fdef)); } } - string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); + std::string oc_host_graph_name = + absl::StrCat("oc_while_host_graph_", n->name()); TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, n->name(), host_transfer_key, oc_host_graph_name, fld, @@ -2113,11 +2142,13 @@ absl::Status ExtractOutsideCompilationForWhileNode( } absl::Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( - Graph* g, const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const std::map& host_compute_core, FunctionLibraryRuntime* flr, - FunctionLibraryDefinition* fld, std::vector* host_graphs, - std::vector* shape_inference_graphs, + Graph* g, const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, + const std::map& host_compute_core, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* host_graphs, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { std::vector if_nodes, while_nodes, func_call_nodes; for (Node* n : g->nodes()) { @@ -2155,7 +2186,7 @@ absl::Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( } absl::Status CopyOutsideCompilationConstNodes( - Graph* g, const string& outside_compilation_attr_name) { + Graph* g, const std::string& outside_compilation_attr_name) { for (Node* n : g->op_nodes()) { if (!n->IsConstant() || !HasNodeAttr(n->def(), outside_compilation_attr_name)) { @@ -2205,8 +2236,8 @@ absl::Status RewriteOutsideCompilationSubgraphFn::operator()( const std::vector& arg_source_tensors, std::unique_ptr* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def) { - string old_name = node_def->op(); - string new_name = + std::string old_name = node_def->op(); + std::string new_name = absl::StrCat(xla_cluster_name_, "_", new_function_name_, "_", old_name); node_def->set_op(new_name); node_def->set_name(new_name); @@ -2290,14 +2321,14 @@ absl::Status RewriteOutsideCompilationSubgraphFn::operator()( AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", *shapes, node_def); } else { - string shape_inference_func_name = + std::string shape_inference_func_name = absl::StrCat("_outside_compilation_shape_inference_", new_name); NameAttrList shape_inference_graph; shape_inference_graph.set_name(shape_inference_func_name); 
AddNodeAttr("shape_inference_graph", shape_inference_graph, node_def); AddNodeAttr("shapes", std::vector{}, node_def); } - AddNodeAttr("ancestors", std::vector{}, node_def); + AddNodeAttr("ancestors", std::vector{}, node_def); AddNodeAttr("Tinputs", recv_at_host_dtypes, node_def); AddNodeAttr("Toutputs", send_from_host_dtypes, node_def); AddNodeAttr("key", absl::StrCat("host_compute_channel_", new_name), node_def); @@ -2306,15 +2337,16 @@ absl::Status RewriteOutsideCompilationSubgraphFn::operator()( } absl::Status ExtractOutsideCompilationForFunction( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const NameAttrList& func_name_attrs, const string& new_func_name, - const string& host_graph_func_name, - const std::map& host_compute_core, FunctionLibraryRuntime* flr, - FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const NameAttrList& func_name_attrs, + const std::string& new_func_name, const std::string& host_graph_func_name, + const std::map& host_compute_core, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* shape_inference_graphs, bool* has_outside_compilation) { // Convert the function to graph. - const string& func_name = func_name_attrs.name(); + const std::string& func_name = func_name_attrs.name(); FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR( flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); @@ -2345,8 +2377,8 @@ absl::Status ExtractOutsideCompilationForFunction( } std::unique_ptr graph_out; - std::vector outside_compilation_host_graphs; - std::vector shape_inference_graphs_to_rewrite; + std::vector outside_compilation_host_graphs; + std::vector shape_inference_graphs_to_rewrite; if (*has_outside_compilation) { // Copy outside compilation Const nodes with non outside compilation users. TF_RETURN_IF_ERROR(CopyOutsideCompilationConstNodes( @@ -2404,7 +2436,7 @@ absl::Status ExtractOutsideCompilationForFunction( } } } - std::map host_compute_nodes; + std::map host_compute_nodes; for (Node* n : outside_compilation_nodes) { auto host_compute_node_or = ReplaceOutsideCompilationCallNode( graph_out.get(), n, host_compute_core, *cluster_deps); @@ -2416,11 +2448,11 @@ absl::Status ExtractOutsideCompilationForFunction( // them so XlaCompiler can handle them in correct order. for (const auto& iter : host_compute_nodes) { Node* host_compute_node = iter.second; - std::vector token_input_node_names; + std::vector token_input_node_names; TF_RETURN_IF_ERROR(GetNodeAttr(host_compute_node->def(), kXlaTokenInputNodesAttrName, &token_input_node_names)); - for (const string& node_name : token_input_node_names) { + for (const std::string& node_name : token_input_node_names) { if (node_name == kXlaTokenArgNodeName) { continue; } @@ -2459,7 +2491,7 @@ absl::Status ExtractOutsideCompilationForFunction( // Shape inference graphs might contain Placeholder nodes for outside // compilation to outside compilation edges. Rewrite shape inference graphs // to remove such nodes. 
- for (const string& shape_inference_graph : + for (const std::string& shape_inference_graph : shape_inference_graphs_to_rewrite) { TF_RETURN_IF_ERROR( RewriteShapeInferenceGraph(shape_inference_graph, host_graph.get(), @@ -2467,7 +2499,7 @@ absl::Status ExtractOutsideCompilationForFunction( } // Remove the outside compilation graphs from function library. - for (const string& func : outside_compilation_host_graphs) { + for (const std::string& func : outside_compilation_host_graphs) { TF_RETURN_IF_ERROR(fld->RemoveFunction(func)); } @@ -2499,9 +2531,9 @@ absl::Status ExtractOutsideCompilationForFunction( } absl::Status ExtractOutsideCompilation( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters, Graph* g, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::unordered_map& clusters, Graph* g, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, bool* modified) { if (VLOG_IS_ON(4)) { @@ -2511,14 +2543,14 @@ absl::Status ExtractOutsideCompilation( *modified = false; auto node_name_index = g->BuildNodeNameIndex(); for (auto& iter : clusters) { - string xla_cluster_name = iter.first; + std::string xla_cluster_name = iter.first; Node* n = iter.second.node; auto const& func_name_attrs = iter.second.func_name_attrs; auto const& host_compute_core = iter.second.host_compute_core; - std::vector shape_inference_graphs; + std::vector shape_inference_graphs; bool has_outside_compilation; - string host_graph_func_name = + std::string host_graph_func_name = absl::StrCat("oc_host_graph_", xla_cluster_name); TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, @@ -2528,7 +2560,7 @@ absl::Status ExtractOutsideCompilation( *modified |= has_outside_compilation; if (has_outside_compilation) { - string pivot_name = absl::StrCat(xla_cluster_name, "/pivot"); + std::string pivot_name = absl::StrCat(xla_cluster_name, "/pivot"); Node* pivot_node = node_name_index[pivot_name]; TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph( g, fld, host_graph_func_name, n, pivot_node)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index 7631ccd0bc6ab0..c1697fcb4cde0d 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -44,9 +44,9 @@ namespace tensorflow { class RewriteOutsideCompilationSubgraphFn { public: RewriteOutsideCompilationSubgraphFn( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, - const string& xla_cluster_name, const string& new_function_name) + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const std::string& new_function_name) : xla_cluster_attr_name_(xla_cluster_attr_name), outside_compilation_attr_name_(outside_compilation_attr_name), xla_cluster_name_(xla_cluster_name), @@ -59,10 +59,10 @@ class RewriteOutsideCompilationSubgraphFn { NodeDef* node_def); private: - string xla_cluster_attr_name_; - string outside_compilation_attr_name_; - string xla_cluster_name_; - string new_function_name_; + std::string xla_cluster_attr_name_; + std::string outside_compilation_attr_name_; + std::string xla_cluster_name_; + std::string new_function_name_; }; // For an XLA computation function, replace all outside compilations with 
@@ -88,12 +88,13 @@ class RewriteOutsideCompilationSubgraphFn { // has_outside_compilation: a bool indicating whether this function has any // outside compilation nodes. absl::Status ExtractOutsideCompilationForFunction( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, const string& xla_cluster_name, - const NameAttrList& func_name_attrs, const string& new_func_name, - const string& host_graph_func_name, - const std::map& host_compute_core, FunctionLibraryRuntime* flr, - FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const NameAttrList& func_name_attrs, + const std::string& new_func_name, const std::string& host_graph_func_name, + const std::map& host_compute_core, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* shape_inference_graphs, bool* has_outside_compilation); // Rewrites XLA computation in `clusters` to replace outside compilation nodes @@ -101,9 +102,9 @@ absl::Status ExtractOutsideCompilationForFunction( // of outside compilation outputs cannot be determined now, we will store shape // inference graph into `fld`. absl::Status ExtractOutsideCompilation( - const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, - const std::unordered_map& clusters, Graph* g, + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::unordered_map& clusters, Graph* g, FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, bool* modified); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index 4d007d07504939..1a6441a80726a0 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -236,14 +236,14 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test { } absl::Status ExtractOutsideCompilationTest( - const string &xla_cluster_attr_name, - const string &outside_compilation_attr_name, - const string &xla_cluster_name, const NameAttrList &func_name_attrs, - const string &new_func_name, const string &host_graph_func_name, - const std::map &host_compute_core, - FunctionLibraryDefinition *fld, - std::vector *shape_inference_graphs, - bool *has_outside_compilation) { + const std::string& xla_cluster_attr_name, + const std::string& outside_compilation_attr_name, + const std::string& xla_cluster_name, const NameAttrList& func_name_attrs, + const std::string& new_func_name, const std::string& host_graph_func_name, + const std::map& host_compute_core, + FunctionLibraryDefinition* fld, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { OptimizerOptions opts; pflr_ = std::make_unique( device_mgr_.get(), Env::Default(), /*config=*/nullptr, @@ -288,9 +288,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -342,7 +342,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, Basic) { std::unique_ptr 
host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map host_func_attrs; + protobuf::Map host_func_attrs; host_func_attrs["_device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper( *fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, &host_fbody)); @@ -406,9 +406,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -481,9 +481,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -498,7 +498,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map host_func_attrs; + protobuf::Map host_func_attrs; host_func_attrs["_device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, @@ -568,7 +568,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { // _xla_token_input_nodes. Node *if_node = node_name_index["if"]; EXPECT_NE(if_node, nullptr); - std::vector token_inputs; + std::vector token_inputs; TF_CHECK_OK( GetNodeAttr(if_node->def(), "_xla_token_input_nodes", &token_inputs)); EXPECT_THAT(token_inputs, ::testing::ElementsAre("send_oc_if_pred_if")); @@ -631,9 +631,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -648,7 +648,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInWhile) { std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map host_func_attrs; + protobuf::Map host_func_attrs; host_func_attrs["_device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, @@ -767,9 +767,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { TF_CHECK_OK(fld.AddFunctionDef(*xla_fdef)); } - protobuf::Map attrs; - std::map host_compute_core; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -784,7 +784,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { std::unique_ptr host_fbody; AttrValue device_ordinal_temp_value; device_ordinal_temp_value.set_i(0); - protobuf::Map host_func_attrs; + protobuf::Map 
host_func_attrs; host_func_attrs["_device_ordinal"] = device_ordinal_temp_value; TF_CHECK_OK(FunctionDefToBodyHelper(*fld.Find("host_graph"), AttrSlice(&host_func_attrs), &fld, @@ -873,9 +873,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -898,14 +898,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, EXPECT_NE(host_compute_1, nullptr); // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr. - std::vector token_input_nodes; + std::vector token_input_nodes; TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()), "_xla_token_input_nodes", &token_input_nodes)); - std::vector expected_token_input_nodes_0({"_xla_token_arg_node"}); + std::vector expected_token_input_nodes_0( + {"_xla_token_arg_node"}); EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0); token_input_nodes.clear(); - std::vector expected_token_input_nodes_1( + std::vector expected_token_input_nodes_1( {"_xla_token_arg_node", "outside_compilation_0_host_compute"}); TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), "_xla_token_input_nodes", &token_input_nodes)); @@ -955,9 +956,9 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, } FunctionLibraryDefinition fld(OpRegistry::Global(), fdl); - protobuf::Map attrs; - std::map host_compute_core = {{"0", 1}, {"1", 0}}; - std::vector shape_inference_graphs; + protobuf::Map attrs; + std::map host_compute_core = {{"0", 1}, {"1", 0}}; + std::vector shape_inference_graphs; bool has_outside_compilation; NameAttrList name_attrs; name_attrs.set_name("cluster"); @@ -980,14 +981,15 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, EXPECT_NE(host_compute_1, nullptr); // Check XlaHostCompute nodes' "_xla_token_input_nodes" attr. - std::vector token_input_nodes; + std::vector token_input_nodes; TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_0->attrs()), "_xla_token_input_nodes", &token_input_nodes)); - std::vector expected_token_input_nodes_0({"_xla_token_arg_node"}); + std::vector expected_token_input_nodes_0( + {"_xla_token_arg_node"}); EXPECT_EQ(token_input_nodes, expected_token_input_nodes_0); token_input_nodes.clear(); - std::vector expected_token_input_nodes_1( + std::vector expected_token_input_nodes_1( {"_xla_token_arg_node", "outside_compilation_0_host_compute"}); TF_CHECK_OK(GetNodeAttr(AttrSlice(host_compute_1->attrs()), "_xla_token_input_nodes", &token_input_nodes)); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 446df9cac70e2d..a0a0d45736f1e8 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -46,7 +46,7 @@ std::vector* jitrt_flag_list; std::vector* flag_list; absl::once_flag flags_init; -bool SetterForXlaAutoJitFlag(const string& value) { +bool SetterForXlaAutoJitFlag(const std::string& value) { int32_t opt_level; // We need to use the mark_for_compilation_flags directly here instead of // going via GetMarkForCompilationPassFlags() to avoid infinite recursion. 
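Note on the setter being patched above: SetterForXlaAutoJitFlag accepts either a bare integer or the single-gpu(<number>) form documented in flags.h below. The following is a minimal, self-contained sketch of that parsing contract using absl string helpers; ParseAutoJitFlag and its out-parameters are illustrative names, not the code in this patch.

#include <cstdint>
#include <string>

#include "absl/strings/numbers.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"

// Sketch: a bare integer ("2") sets both optimization levels; the
// "single-gpu(2)" form sets only the single-GPU level. Returns false for
// anything else, mirroring how the real setter rejects malformed values.
bool ParseAutoJitFlag(const std::string& value, int32_t* general,
                      int32_t* single_gpu) {
  int32_t opt_level;
  if (absl::SimpleAtoi(value, &opt_level)) {
    *general = opt_level;
    *single_gpu = opt_level;
    return true;
  }
  absl::string_view rest = value;
  if (absl::ConsumePrefix(&rest, "single-gpu(") &&
      absl::ConsumeSuffix(&rest, ")") && absl::SimpleAtoi(rest, &opt_level)) {
    *single_gpu = opt_level;
    return true;
  }
  return false;
}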
@@ -81,7 +81,7 @@ bool SetterForXlaAutoJitFlag(const string& value) {
   return true;
 }

-bool SetterForXlaCallModuleDisabledChecks(const string& value) {
+bool SetterForXlaCallModuleDisabledChecks(const std::string& value) {
   auto directives = absl::StrSplit(value, ',', absl::SkipEmpty());
   call_module_flags->disabled_checks.insert(directives.begin(),
                                             directives.end());
@@ -231,7 +231,7 @@ void AllocateAndParseFlags() {
   mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0;
   mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
   mark_for_compilation_flags->tf_xla_max_cluster_size =
-      std::numeric_limits<int32>::max();
+      std::numeric_limits<int32_t>::max();
   mark_for_compilation_flags->tf_xla_clustering_debug = false;
   mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
   mark_for_compilation_flags->tf_xla_clustering_fuel =
@@ -463,7 +463,7 @@ void ResetFlags() {
 }  // namespace

-bool SetXlaAutoJitFlagFromFlagString(const string& value) {
+bool SetXlaAutoJitFlagFromFlagString(const std::string& value) {
   absl::call_once(flags_init, &AllocateAndParseFlags);
   return SetterForXlaAutoJitFlag(value);
 }
diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h
index 3561551f363ac6..96154b892ae5b0 100644
--- a/tensorflow/compiler/jit/flags.h
+++ b/tensorflow/compiler/jit/flags.h
@@ -41,15 +41,15 @@ struct XlaAutoJitFlag {
   // `optimization_level_general` applies.
   //
   // Experimental.
-  int32 optimization_level_single_gpu;
-  int32 optimization_level_general;
+  int32_t optimization_level_single_gpu;
+  int32_t optimization_level_general;
 };

 // Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
 // is:
 // <number>: sets general and single_gpu setting to the provided number.
 // single-gpu(<number>): sets the single_gpu setting to the provided number.
-bool SetXlaAutoJitFlagFromFlagString(const string& value);
+bool SetXlaAutoJitFlagFromFlagString(const std::string& value);

 // Flags associated with the XLA bridge's mark_for_compilation_pass module.
 struct MarkForCompilationPassFlags {
@@ -57,16 +57,16 @@ struct MarkForCompilationPassFlags {
   // Minimum number of operators in an XLA compilation. Ignored for operators
   // placed on an XLA device or operators explicitly marked for compilation.
-  int32 tf_xla_min_cluster_size;
+  int32_t tf_xla_min_cluster_size;

   // Maximum number of operators in an XLA compilation.
-  int32 tf_xla_max_cluster_size;
+  int32_t tf_xla_max_cluster_size;

   // If non-empty, limit XLA clustering to the following TF operations.
-  string tf_xla_ops_to_cluster;
+  std::string tf_xla_ops_to_cluster;

   // If non-empty, remove following operations from XLA clustering excludelist.
-  string tf_xla_cluster_exclude_ops;
+  std::string tf_xla_cluster_exclude_ops;

   // Dump graphs during XLA compilation.
   bool tf_xla_clustering_debug;
@@ -110,7 +110,7 @@ struct MarkForCompilationPassFlags {
   bool tf_xla_disable_strict_signature_checks;

   // Specifies the persistance cache prefix. Default is "xla_compile_cache"
-  string tf_xla_persistent_cache_prefix;
+  std::string tf_xla_persistent_cache_prefix;
 };

 // Flags associated with XLA Sparse Core.
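Note: the flags.h hunk above is representative of the whole patch, which spells out the standard types behind TensorFlow's legacy platform aliases, so no behavior changes. A minimal before/after sketch of the pattern; the alias definitions shown are approximately how tensorflow/core/platform/types.h has historically declared them and should be treated as an assumption here.

#include <cstdint>
#include <string>

// Legacy platform aliases (approximate):
//   using int32 = std::int32_t;
//   using uint64 = std::uint64_t;
//   using string = std::string;
//
// Before (with the aliases in scope):
//   int32 tf_xla_min_cluster_size;
//   string tf_xla_ops_to_cluster;
//
// After: the same members, spelled with the underlying standard types.
struct MarkForCompilationFlagsSketch {
  int32_t tf_xla_min_cluster_size = 4;  // default set in AllocateAndParseFlags
  std::string tf_xla_ops_to_cluster;    // empty means "no restriction"
};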
diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc index 75bd1d7310a295..1b0239c3550970 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc @@ -95,7 +95,7 @@ TEST(ForceXlaConstantsOnHostPassTest, Simple) { if (CanCreateXlaKernel(node->def())) { EXPECT_FALSE(found); found = true; - std::vector hostmem_attr; + std::vector hostmem_attr; EXPECT_TRUE(TryGetNodeAttr(node->def(), "_input_hostmem", &hostmem_attr)); EXPECT_EQ(hostmem_attr.size(), 1); EXPECT_EQ(hostmem_attr[0], 1); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 8317d222928200..03a7d1081b8b53 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -93,7 +93,7 @@ std::vector IntTensorAsVector(const Tensor& t) { result.reserve(t.NumElements()); for (int i = 0; i < t.NumElements(); i++) { int64_t element = t.dtype() == DT_INT32 - ? static_cast(t.flat()(i)) + ? static_cast(t.flat()(i)) : t.flat()(i); result.push_back(element); } @@ -251,14 +251,14 @@ absl::Status ComputeSliceSize(const Scope& host_scope, absl::Status ConvertTensorFlowSliceToStaticShapedSlice( Graph* g, Node* slice, const SliceInputs& slice_inputs, absl::string_view cluster_name, Node** result) { - string host_name; + std::string host_name; TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( slice->assigned_device_name(), &host_name)); absl::Status status; Scope main_scope = NewInternalScope(g, &status, /*refiner=*/nullptr) - .WithXlaCluster(string(cluster_name)) + .WithXlaCluster(std::string(cluster_name)) .NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice")); Scope host_scope = main_scope.WithAssignedDevice(host_name); @@ -286,7 +286,7 @@ absl::Status ConvertTensorFlowSliceToStaticShapedSlice( TF_RETURN_IF_ERROR(main_scope.status()); - std::vector compile_time_const_inputs; + std::vector compile_time_const_inputs; compile_time_const_inputs.push_back("size"); (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr, compile_time_const_inputs); diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index 411f761995483a..6a8523a7d4c893 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -66,7 +66,8 @@ class FakeDevice : public Device { Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } - static std::unique_ptr Make(const string& name, const string& type) { + static std::unique_ptr Make(const std::string& name, + const std::string& type) { DeviceAttributes device_attributes; device_attributes.set_name(name); device_attributes.set_device_type(DeviceType(type).type()); @@ -100,7 +101,7 @@ absl::Status IncreaseDynamismForAutoJit(const Scope& s, // Scope::ToGraph seems to drop assigned devices, probably because it goes // through a GraphDef. So explicitly maintain the device assignment. 
- std::unordered_map assigned_device_names; + std::unordered_map assigned_device_names; for (Node* n : s.graph()->nodes()) { assigned_device_names[n->name()] = n->assigned_device_name(); } @@ -149,7 +150,7 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) { Inputs(m_slice_size_0, Const(static_cast(500)), Const(zero_32)))); - std::vector compile_time_constant_inputs; + std::vector compile_time_constant_inputs; compile_time_constant_inputs.push_back("size"); auto m_dynamic_slice = NodeWith( Op("Slice"), AssignedDevice(kDeviceName), diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index c3a24f3e0f7163..340cdbe8032c63 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -151,7 +151,7 @@ class MarkForCompilationPassImpl { std::optional resource_op_device, std::optional resource_var_operation_node_id, std::optional deadness_predicate, - bool is_xla_compile_attr_true, std::optional xla_scope) + bool is_xla_compile_attr_true, std::optional xla_scope) : cycles_graph_node_id_(tf_graph_node_id), effective_cluster_size_(effective_cluster_size), has_functional_control_flow_(has_functional_control_flow), @@ -220,7 +220,7 @@ class MarkForCompilationPassImpl { // If not nullopt then the all nodes in the cluster either do not have the // XlaScope attribute set or have it set to the value returned. - const std::optional& xla_scope() const { return xla_scope_; } + const std::optional& xla_scope() const { return xla_scope_; } // Returns the TF graph node IDs for the resource variable operations in // this cluster. @@ -228,7 +228,7 @@ class MarkForCompilationPassImpl { return resource_var_operation_node_ids_; } - string DebugString(const Graph& graph) const { + std::string DebugString(const Graph& graph) const { Node* node = graph.FindNodeId(cycles_graph_node_id()); if (!node) { // This should never happen but we try to be resilient because this is a @@ -254,7 +254,7 @@ class MarkForCompilationPassImpl { std::optional resource_op_device_; std::optional deadness_predicate_; bool is_xla_compile_attr_true_; - std::optional xla_scope_; + std::optional xla_scope_; std::vector resource_var_operation_node_ids_; Cluster(const Cluster&) = delete; @@ -365,7 +365,7 @@ class MarkForCompilationPassImpl { std::optional resource_var_operation_node_id, std::optional deadness_predicate, bool is_xla_compile_attr_true, - std::optional xla_scope) { + std::optional xla_scope) { cluster_storage_.push_back(std::make_unique( cycles_graph_node_id, effective_cluster_size, has_functional_control_flow, device_set, resource_op_device, @@ -374,7 +374,7 @@ class MarkForCompilationPassImpl { return cluster_storage_.back().get(); } - std::optional GetXlaScope(Node* n); + std::optional GetXlaScope(Node* n); // Returns the cluster for node `n`. If two nodes, N1 and N2, are placed in // the same cluster by the clustering algorithm then this function will return @@ -417,7 +417,8 @@ class MarkForCompilationPassImpl { // Returns a string representing `cycles_graph_node_id`. If the node is // unclusterable (either it is a phatom "frame" node or is not a compilation // candidate) then set `*found_unclustered` to true. - string DebugStringForCyclesGraphNode(int node_id, bool* found_unclustered); + std::string DebugStringForCyclesGraphNode(int node_id, + bool* found_unclustered); // We could not contract the edge from `from` to `to`. 
Return a string // describing an alternate path from `from` to `to` (besides the direct edge @@ -429,7 +430,7 @@ class MarkForCompilationPassImpl { // contracted because of the path [P,Q,R]" where P, Q and R are all clusters // since in that case a natural question is why we could not form a {A, P, Q, // R, B} cluster. - string DescribePotentialCycle(int from, int to); + std::string DescribePotentialCycle(int from, int to); // Merge the clusters `cluster_from` and `cluster_to`. After this step the // larger combined cluster is represented by `cluster_from`, but can have @@ -459,8 +460,8 @@ class MarkForCompilationPassImpl { return true; } - string EdgeContractionFailureMsg(Cluster* from, Cluster* to, - absl::string_view reason) { + std::string EdgeContractionFailureMsg(Cluster* from, Cluster* to, + absl::string_view reason) { return absl::StrCat("Could not contract ", from->DebugString(*graph_), " -> ", to->DebugString(*graph_), " because ", reason, "."); @@ -468,7 +469,7 @@ class MarkForCompilationPassImpl { DebugOptions debug_options_; Graph* graph_; - uint64 graph_fingerprint_; + uint64_t graph_fingerprint_; FunctionLibraryDefinition* flib_def_; Env* env_; OptimizerOptions::GlobalJitLevel global_jit_level_; @@ -547,7 +548,7 @@ std::vector MarkForCompilationPassImpl::FindAlternatePathForDebugging( return path; } -string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode( +std::string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode( int cycles_graph_node_id, bool* found_unclustered) { Cluster* cluster = GetClusterForCyclesGraphNode(cycles_graph_node_id); if (cluster) { @@ -567,8 +568,9 @@ string MarkForCompilationPassImpl::DebugStringForCyclesGraphNode( return node->name(); } -string MarkForCompilationPassImpl::DescribePotentialCycle(int from, int to) { - std::vector path_str; +std::string MarkForCompilationPassImpl::DescribePotentialCycle(int from, + int to) { + std::vector path_str; bool found_unclustered = false; absl::c_transform(FindAlternatePathForDebugging(from, to), std::back_inserter(path_str), [&](int node_id) { @@ -701,7 +703,7 @@ absl::StatusOr MarkForCompilationPassImpl::ForEachEdgeInPostOrder( // Make a copy of the set of successors because we may modify the graph in // TryToContractEdge. - std::vector successors_copy = + std::vector successors_copy = cycles_graph_.SuccessorsCopy(cluster_from->cycles_graph_node_id()); for (int to : successors_copy) { @@ -974,7 +976,7 @@ class ClusterSequenceNumberGenerator { sequence_numbers_.clear(); } - int64 GetNext(uint64 key) { + int64_t GetNext(uint64_t key) { mutex_lock lock(mu_); return sequence_numbers_[key]++; } @@ -987,13 +989,13 @@ class ClusterSequenceNumberGenerator { private: mutex mu_; - absl::flat_hash_map sequence_numbers_; + absl::flat_hash_map sequence_numbers_; }; // Get a monotonic sequence numbers for a graph identified by its `fingerprint`. // The sequence number is necessary to disambiguate clusters extracted from the // same graph and when duplicate graphs exist within the same process. -int64_t GetNextClusterSequenceNumber(uint64 fingerprint) { +int64_t GetNextClusterSequenceNumber(uint64_t fingerprint) { return ClusterSequenceNumberGenerator::Global().GetNext(fingerprint); } @@ -1002,7 +1004,7 @@ absl::Status MarkForCompilationPassImpl::CreateClusters() { clusters_created_ = true; // Names for each cluster. 
- std::unordered_map<int, string> cluster_names; + std::unordered_map<int, std::string> cluster_names; if (debug_options_.dump_graphs) { DumpGraphToFile("before_mark_for_compilation", *graph_, flib_def_); @@ -1030,7 +1032,7 @@ absl::Status MarkForCompilationPassImpl::CreateClusters() { if (cluster->effective_cluster_size() >= debug_options_.min_cluster_size || cluster->has_functional_control_flow() || cluster->is_xla_compile_attr_true()) { - string& name = cluster_names[cluster->cycles_graph_node_id()]; + std::string& name = cluster_names[cluster->cycles_graph_node_id()]; if (name.empty()) { if (!cluster_name_prefix_.empty()) { @@ -1099,7 +1101,7 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency( return false; } -std::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) { +std::optional<std::string> MarkForCompilationPassImpl::GetXlaScope(Node* node) { // Look for either _XlaScope or _XlaInternalScope on both nodes to guide // clustering. If both nodes have a scope and the scopes do not match, do // not cluster along this edge. If even one of the nodes lacks a scope @@ -1118,14 +1120,14 @@ std::optional<string> MarkForCompilationPassImpl::GetXlaScope(Node* node) { if (global_jit_level_ != OptimizerOptions::OFF) { // If global_jit_level_ is ON, respect only _XlaInternalScope. - const string& scope = + const std::string& scope = GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr); if (!scope.empty()) { return scope; } } else { // If global_jit_level_ is OFF, respect only _XlaScope. - const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); + const std::string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); if (!scope.empty()) { return scope; } @@ -1186,9 +1188,9 @@ absl::Status MarkForCompilationPassImpl::BuildInitialClusterSet() { deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot)); } - const string& device_name_str = !node->assigned_device_name().empty() - ? node->assigned_device_name() - : node->requested_device(); + const std::string& device_name_str = !node->assigned_device_name().empty() + ?
node->assigned_device_name() + : node->requested_device(); TF_ASSIGN_OR_RETURN(DeviceId device, device_info_cache_.GetIdFor(device_name_str)); @@ -1258,16 +1260,17 @@ absl::StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) { return true; } -absl::flat_hash_set<string> CreateClusterExcludeList() { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - absl::flat_hash_set<string> excludelist; +absl::flat_hash_set<std::string> CreateClusterExcludeList() { + absl::flat_hash_set<std::string> excludelist; for (auto s : absl::StrSplit(flags->tf_xla_cluster_exclude_ops, ',')) { if (!s.empty()) { - excludelist.insert(string(s)); + excludelist.insert(std::string(s)); } } if (VLOG_IS_ON(2) && !excludelist.empty()) { - std::vector<string> vexcludelist(excludelist.begin(), excludelist.end()); + std::vector<std::string> vexcludelist(excludelist.begin(), + excludelist.end()); absl::c_sort(vexcludelist); VLOG(2) << "XLA clustering will exclude following TF operations from auto " "clustering: " @@ -1276,11 +1279,11 @@ absl::flat_hash_set<string> CreateClusterExcludeList() { return excludelist; } -absl::flat_hash_set<string> GetOrCreateAllowlist() { - absl::flat_hash_map<string, std::vector<string>>* allowlist_table = +absl::flat_hash_set<std::string> GetOrCreateAllowlist() { + absl::flat_hash_map<std::string, std::vector<std::string>>* allowlist_table = tensorflow::GetAllowlistTable(); MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - absl::flat_hash_set<string> allowlist; + absl::flat_hash_set<std::string> allowlist; for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) { if (s == "FUSIBLE") { @@ -1292,12 +1295,12 @@ absl::flat_hash_set<string> GetOrCreateAllowlist() { allowlist.insert(v.begin(), v.end()); } else if (!s.empty()) { // Should be a user provided TF operation. - allowlist.insert(string(s)); + allowlist.insert(std::string(s)); } } if (VLOG_IS_ON(2) && !allowlist.empty()) { - std::vector<string> vallowlist(allowlist.begin(), allowlist.end()); + std::vector<std::string> vallowlist(allowlist.begin(), allowlist.end()); absl::c_sort(vallowlist); VLOG(2) << "XLA clustering will only consider the following TF operations: " << absl::StrJoin(vallowlist, " "); @@ -1338,8 +1341,8 @@ absl::Status MarkForCompilationPassImpl::FindCompilationCandidates() { auto allowlist = GetOrCreateAllowlist(); - std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps(); - absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end()); + std::vector<std::string> vall_ops = XlaOpRegistry::GetAllRegisteredOps(); + absl::flat_hash_set<std::string> all_ops(vall_ops.begin(), vall_ops.end()); // Check that the user-provided TF operations really exist.
for (const auto& s : allowlist) { if (!all_ops.contains(s)) { @@ -1674,7 +1677,7 @@ void MarkForCompilationPassImpl::DumpPostClusteringGraphs() { DumpGraphToFile("mark_for_compilation_annotated", new_graph, flib_def_); } -string RatioToString(int numerator, int denominator) { +std::string RatioToString(int numerator, int denominator) { return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, (100.0 * numerator) / denominator); } @@ -1985,10 +1988,11 @@ absl::Status MarkForCompilationPass::RunForTest( return MarkForCompilation(options, debug_options); } -absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() { +absl::flat_hash_map<std::string, std::vector<std::string>>* +GetAllowlistTable() { // Table format: category name: {list of TF operations in that category} - static absl::flat_hash_map<string, std::vector<string>>* result = - new absl::flat_hash_map<string, std::vector<string>>{ + static absl::flat_hash_map<std::string, std::vector<std::string>>* result = + new absl::flat_hash_map<std::string, std::vector<std::string>>{ // Unary {"PW", {"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", @@ -2056,8 +2060,8 @@ void ResetClusterSequenceNumber() { ClusterSequenceNumberGenerator::Global().Reset(); } -absl::flat_hash_set<string> GetKnownXLAAllowlistOp() { - absl::flat_hash_set<string> result{ +absl::flat_hash_set<std::string> GetKnownXLAAllowlistOp() { + absl::flat_hash_set<std::string> result{ "AdjustContrastv2", "AdjustHue", "AdjustSaturation", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index 558912f2eee2e0..d6a2814ed33982 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -47,7 +47,7 @@ class MarkForCompilationPass : public GraphOptimizationPass { friend class MarkForCompilationPassTestHelper; }; -absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable(); +absl::flat_hash_map<std::string, std::vector<std::string>>* GetAllowlistTable(); namespace testing { // DO NOT USE IN PRODUCTION. @@ -56,7 +56,7 @@ namespace testing { void ResetClusterSequenceNumber(); // Return a list of operations that we choose not to put into the allowlist.
-absl::flat_hash_set<string> GetKnownXLAAllowlistOp(); +absl::flat_hash_set<std::string> GetKnownXLAAllowlistOp(); } // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 1a120791206369..1d4031a4ffc926 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -67,10 +67,10 @@ static bool Initialized = [] { REGISTER_OP("UncompilableNullary").Output("o: float"); REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); -std::unordered_map<string, string> GetClusters(const Graph& graph) { - std::unordered_map<string, string> ids; +std::unordered_map<std::string, std::string> GetClusters(const Graph& graph) { + std::unordered_map<std::string, std::string> ids; for (Node* node : graph.nodes()) { - string cluster; + std::string cluster; if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) { CHECK(!cluster.empty()); ids[node->name()] = cluster; @@ -86,10 +86,10 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) { return ids; } -std::set<string> GetClusterNames(const Graph& graph) { - std::set<string> names; +std::set<std::string> GetClusterNames(const Graph& graph) { + std::set<std::string> names; for (Node* node : graph.nodes()) { - string cluster; + std::string cluster; if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) { CHECK(!cluster.empty()); names.insert(cluster); @@ -98,10 +98,10 @@ std::set<string> GetClusterNames(const Graph& graph) { return names; } -absl::flat_hash_map<string, std::vector<string>> GetClusterSets( - const Graph& g, std::vector<string>* cluster_names = nullptr) { +absl::flat_hash_map<std::string, std::vector<std::string>> GetClusterSets( + const Graph& g, std::vector<std::string>* cluster_names = nullptr) { CHECK(cluster_names == nullptr || cluster_names->empty()); - absl::flat_hash_map<string, std::vector<string>> cluster_sets; + absl::flat_hash_map<std::string, std::vector<std::string>> cluster_sets; for (const auto& p : GetClusters(g)) { cluster_sets[p.second].push_back(p.first); } @@ -357,7 +357,7 @@ TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + std::string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; testing::FindNodeByName(graph.get(), "A") ->set_assigned_device_name(xla_cpu_device); testing::FindNodeByName(graph.get(), "tanh0") @@ -694,7 +694,7 @@ TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) { } namespace { -Node* MakeRead(const Scope& scope, const string& id, +Node* MakeRead(const Scope& scope, const std::string& id, Node** var_handle_op = nullptr) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); @@ -706,7 +706,7 @@ Node* MakeRead(const Scope& scope, const string& id, return read.node(); } -Node* MakeWrite(const Scope& scope, const string& id) { +Node* MakeWrite(const Scope& scope, const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output value_to_write = @@ -716,7 +716,7 @@ Node* MakeWrite(const Scope& scope, const string& id) { return assign_op.operation.node(); } -Node* MakeNeutral(const Scope& scope, const string& id) { +Node* MakeNeutral(const Scope& scope, const std::string& id) { return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); } } // namespace @@ -733,11 +733,11 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); -
absl::flat_hash_map<string, std::vector<string>> cluster_sets = + absl::flat_hash_map<std::string, std::vector<std::string>> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); - std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR", - "ValueToAssignW"}; + std::vector<std::string> expected_clustered_nodes = {"AssignmentW", "ReadR", + "ValueToAssignW"}; ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); } @@ -753,7 +753,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - absl::flat_hash_map<string, std::vector<string>> cluster_sets = + absl::flat_hash_map<std::string, std::vector<std::string>> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 0); } @@ -779,13 +779,13 @@ TEST(XlaCompilationTest, ChainOfOps) { TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::vector<string> cluster_names; - absl::flat_hash_map<string, std::vector<string>> cluster_sets = + std::vector<std::string> cluster_names; + absl::flat_hash_map<std::string, std::vector<std::string>> cluster_sets = GetClusterSets(*graph, &cluster_names); ASSERT_EQ(cluster_sets.size(), 1); - std::vector<string> expected_clustered_nodes_a = { + std::vector<std::string> expected_clustered_nodes_a = { "AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"}; ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); } @@ -881,7 +881,7 @@ TEST(XlaCompilationTest, ConstOp) { { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); - auto c = ops::Const(root.WithOpName("const"), string("string")); + auto c = ops::Const(root.WithOpName("const"), std::string("string")); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); @@ -901,12 +901,12 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; + std::string cluster_name = clusters.begin()->second; - std::unordered_map<string, string> expected_clusters( + std::unordered_map<std::string, std::string> expected_clusters( {{"negate", cluster_name}, {"add", cluster_name}}); EXPECT_EQ(clusters, expected_clusters); } @@ -924,12 +924,12 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; + std::string cluster_name = clusters.begin()->second; - std::unordered_map<string, string> expected_clusters( + std::unordered_map<std::string, std::string> expected_clusters( {{"negate", cluster_name}, {"identity", cluster_name}, {"add", cluster_name}}); @@ -956,7 +956,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so // it won't be clustered.
ctrl_trigger_b is okay to cluster but we don't @@ -982,7 +982,7 @@ TEST(XlaCompilationTest, RandomShape) { TF_ASSERT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["shape"], ""); } @@ -1028,7 +1028,7 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) { TF_ASSERT_OK( MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["fn_call"], ""); } @@ -1054,12 +1054,12 @@ TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { - n->set_assigned_device_name(string(xla_gpu_device)); + n->set_assigned_device_name(std::string(xla_gpu_device)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/shape_rng"], ""); EXPECT_EQ(clusters["test/reshape"], ""); } @@ -1087,12 +1087,12 @@ TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { - n->set_assigned_device_name(string(xla_gpu_device)); + n->set_assigned_device_name(std::string(xla_gpu_device)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["test/read"], ""); EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); } @@ -1133,15 +1133,15 @@ TEST(XlaCompilationTest, DontClusterMergingNodes) { for (Node* n : graph->nodes()) { if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { - n->set_assigned_device_name(string(xla_gpu_dev0)); + n->set_assigned_device_name(std::string(xla_gpu_dev0)); } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { - n->set_assigned_device_name(string(xla_gpu_dev1)); + n->set_assigned_device_name(std::string(xla_gpu_dev1)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); // Each of the MatMuls should be in a separate cluster. - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]); EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]); @@ -1170,17 +1170,17 @@ TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) { for (Node* n : graph->nodes()) { if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) { - n->set_assigned_device_name(string(xla_cpu_dev0)); + n->set_assigned_device_name(std::string(xla_cpu_dev0)); } else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { - n->set_assigned_device_name(string(xla_gpu_dev0)); + n->set_assigned_device_name(std::string(xla_gpu_dev0)); } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { - n->set_assigned_device_name(string(xla_gpu_dev1)); + n->set_assigned_device_name(std::string(xla_gpu_dev1)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); // Each of the MatMuls should be in a separate cluster.
- std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]); EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]); @@ -1223,14 +1223,14 @@ TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) { TF_ASSERT_OK(root.ToGraph(graph.get())); for (Node* n : graph->nodes()) { if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) { - n->set_assigned_device_name(string(xla_gpu_dev0)); + n->set_assigned_device_name(std::string(xla_gpu_dev0)); } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) { - n->set_assigned_device_name(string(xla_gpu_dev1)); + n->set_assigned_device_name(std::string(xla_gpu_dev1)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]); EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]); EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]); @@ -1254,12 +1254,12 @@ TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) { for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { - n->set_assigned_device_name(string(xla_cpu_device)); + n->set_assigned_device_name(std::string(xla_cpu_device)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["test/a"], ""); EXPECT_NE(clusters["test/b"], ""); EXPECT_NE(clusters["test/c"], ""); @@ -1277,7 +1277,7 @@ TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/a"], ""); EXPECT_EQ(clusters["test/b"], ""); } @@ -1299,12 +1299,12 @@ TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) { for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { - n->set_assigned_device_name(string(xla_cpu_device)); + n->set_assigned_device_name(std::string(xla_cpu_device)); } } TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["test/check"], ""); EXPECT_NE(clusters["test/greaterequal"], ""); EXPECT_NE(clusters["test/assert"], ""); @@ -1324,7 +1324,7 @@ TEST(XlaCompilationTest, DontAutoClusterDummyOps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/assert"], ""); EXPECT_EQ(clusters["test/check"], ""); } @@ -1345,7 +1345,7 @@ TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/tensor_list_reserve"], ""); } @@ -1373,7 +1373,7 @@ TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters =
GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/tensor_list_element_shape"], ""); } @@ -1391,7 +1391,7 @@ TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; + std::string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0"; for (Node* n : graph->nodes()) { if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { n->set_assigned_device_name(xla_cpu_device); @@ -1400,7 +1400,7 @@ TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["test/tensor_list_reserve"], ""); } @@ -1427,7 +1427,7 @@ TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["test/x"], ""); @@ -1451,7 +1451,7 @@ TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/x"], ""); EXPECT_EQ(clusters["test/y"], ""); @@ -1473,7 +1473,7 @@ TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/x"], ""); EXPECT_EQ(clusters["test/y"], ""); @@ -1486,8 +1486,8 @@ TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) { Node* resource_read = MakeRead(root, "read", &var_handle); Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a); - string resource_read_name = resource_read->name(); - string var_handle_name = var_handle->name(); + std::string resource_read_name = resource_read->name(); + std::string var_handle_name = var_handle->name(); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); @@ -1499,7 +1499,7 @@ TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["test/b"], ""); EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]); @@ -1512,8 +1512,8 @@ TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) { Node* resource_read = MakeRead(root, "read", &var_handle); Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a); - string resource_read_name = resource_read->name(); - string var_handle_name = var_handle->name(); + std::string resource_read_name = resource_read->name(); + std::string var_handle_name = var_handle->name(); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); @@ -1525,7 +1525,7 @@ TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters =
GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/b"], ""); EXPECT_EQ(clusters[resource_read_name], ""); @@ -1555,7 +1555,7 @@ TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/z"], ""); } @@ -1580,7 +1580,7 @@ TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["test/z"], ""); } @@ -1610,7 +1610,7 @@ TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["test/y"], ""); EXPECT_EQ(clusters["test/x"], clusters["test/y"]); @@ -1632,7 +1632,7 @@ TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["test/y"], ""); EXPECT_EQ(clusters["test/y"], clusters["test/x"]); @@ -1705,7 +1705,7 @@ TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["some_ctrl_input"], ""); EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_0_update"]); @@ -1875,19 +1875,19 @@ TEST(XlaCompilationTest, ClusterSessionName) { TF_ASSERT_OK( MarkForCompilationPassTestHelper::MarkForCompilation(&graph, options)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); ASSERT_FALSE(clusters.empty()); - string cluster_name = clusters.begin()->second; + std::string cluster_name = clusters.begin()->second; - std::unordered_map<string, string> expected_clusters( + std::unordered_map<std::string, std::string> expected_clusters( {{"negate", cluster_name}, {"add", cluster_name}}); EXPECT_EQ(clusters, expected_clusters); EXPECT_THAT(cluster_name, ::testing::StartsWith("test_session_name")); } namespace { -Node* MakeStageNode(GraphDefBuilder& builder, string name, +Node* MakeStageNode(GraphDefBuilder& builder, std::string name, std::initializer_list<DataType> dtypes, absl::Span<const ops::NodeOut> values) { auto opts = builder.opts() @@ -1949,7 +1949,7 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { &graph, MarkForCompilationPassTestHelper::Options().WithNoClusterScoping())); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_EQ(clusters["add0"], clusters["add1"]); EXPECT_EQ(clusters["add0"], clusters["relu1"]); EXPECT_EQ(clusters["relu0"], clusters["add1"]); @@ -1964,7 +1964,7 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - std::unordered_map<string, string> clusters = GetClusters(*graph); + std::unordered_map<std::string, std::string> clusters = GetClusters(*graph); EXPECT_NE(clusters["add0"], clusters["add1"]); EXPECT_NE(clusters["add0"], clusters["relu1"]); EXPECT_NE(clusters["relu0"], clusters["add1"]); @@ -1973,9 +1973,9 @@ TEST(XlaCompilationTest,
StagePipelinePreservedByClusterScopingPass) { } TEST(XlaCompilationTest, XLALiteAllowlist) { auto* allowlist_table = tensorflow::GetAllowlistTable(); - absl::flat_hash_set<string> hallowlist; - std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps(); - absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end()); + absl::flat_hash_set<std::string> hallowlist; + std::vector<std::string> vall_ops = XlaOpRegistry::GetAllRegisteredOps(); + absl::flat_hash_set<std::string> all_ops(vall_ops.begin(), vall_ops.end()); // Check that all the operations in the table are existing TF operations for (auto pair : *allowlist_table) { @@ -1988,10 +1988,10 @@ TEST(XlaCompilationTest, XLALiteAllowlist) { // Check that all registered XLA operations are in the allowlist // table or are known to not be in it. - absl::flat_hash_set<string> known_not_in_list = + absl::flat_hash_set<std::string> known_not_in_list = tensorflow::testing::GetKnownXLAAllowlistOp(); - std::vector<string> unknow_op; - for (string op : vall_ops) { + std::vector<std::string> unknow_op; + for (std::string op : vall_ops) { if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) { unknow_op.push_back(op); } diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc index ce1f2cd5bcd671..db158fc84a0173 100644 --- a/tensorflow/compiler/jit/node_matchers.cc +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -35,7 +35,7 @@ namespace { using impl::NodeMatcherProperties; using impl::OutEdge; -string IndentAllButFirstLine(absl::string_view text) { +std::string IndentAllButFirstLine(absl::string_view text) { std::vector<string> lines = absl::StrSplit(text, '\n'); for (int i = 1; i < lines.size(); i++) { lines[i].insert(0, " "); @@ -86,21 +86,21 @@ bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, case DT_DOUBLE: return CompareTensor<double>(tensor, expected_tensor, listener); case DT_INT8: - return CompareTensor<int8>(tensor, expected_tensor, listener); + return CompareTensor<int8_t>(tensor, expected_tensor, listener); case DT_INT16: - return CompareTensor<int16>(tensor, expected_tensor, listener); + return CompareTensor<int16_t>(tensor, expected_tensor, listener); case DT_INT32: - return CompareTensor<int32>(tensor, expected_tensor, listener); + return CompareTensor<int32_t>(tensor, expected_tensor, listener); case DT_INT64: return CompareTensor<int64_t>(tensor, expected_tensor, listener); case DT_UINT8: - return CompareTensor<uint8>(tensor, expected_tensor, listener); + return CompareTensor<uint8_t>(tensor, expected_tensor, listener); case DT_UINT16: - return CompareTensor<uint16>(tensor, expected_tensor, listener); + return CompareTensor<uint16_t>(tensor, expected_tensor, listener); case DT_UINT32: - return CompareTensor<uint32>(tensor, expected_tensor, listener); + return CompareTensor<uint32_t>(tensor, expected_tensor, listener); case DT_UINT64: - return CompareTensor<uint64>(tensor, expected_tensor, listener); + return CompareTensor<uint64_t>(tensor, expected_tensor, listener); default: LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly.
<< DataType_Name(tensor.dtype()); @@ -188,7 +188,7 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> { if (control_dep_set && !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) { if (listener->IsInterested()) { - string explanation = inner_listener.str(); + std::string explanation = inner_listener.str(); if (!explanation.empty()) { explanation = absl::StrCat(", ", explanation, ","); } @@ -225,7 +225,7 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> { } void DescribeTo(::std::ostream* os) const override { - std::vector<string> predicates; + std::vector<std::string> predicates; if (name) { predicates.push_back(absl::StrCat("name: ", *name)); @@ -282,10 +282,11 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> { if (!attrs.empty()) { printed_something = true; - std::vector<string> attrs_str; + std::vector<std::string> attrs_str; absl::c_transform( attrs, std::back_inserter(attrs_str), - [](const std::pair<string, std::optional<AttrValue>>& attr_kv_pair) { + [](const std::pair<std::string, std::optional<AttrValue>>& + attr_kv_pair) { return absl::StrCat(attr_kv_pair.first, "->", attr_kv_pair.second ? SummarizeAttrValue(*attr_kv_pair.second) @@ -319,7 +320,7 @@ struct NodeMatcher : public ::testing::MatcherInterface<const Node*> { if (listener->IsInterested()) { *listener << "\ninput " << input_idx << " does not match expected:\n"; (*input_matchers)[input_idx].DescribeTo(listener->stream()); - string explanation = inner_listener.str(); + std::string explanation = inner_listener.str(); if (!explanation.empty()) { *listener << ", " << explanation; } @@ -327,14 +328,14 @@ return false; } - std::optional<string> op; - std::optional<string> name; - std::optional<string> assigned_device; + std::optional<std::string> op; + std::optional<std::string> name; + std::optional<std::string> assigned_device; std::optional<Tensor> constant_value; std::optional<std::vector<::testing::Matcher<OutEdge>>> input_matchers; std::optional<::testing::Matcher<absl::Span<const Node* const>>> control_dep_set; - std::map<string, std::optional<AttrValue>> attrs; + std::map<std::string, std::optional<AttrValue>> attrs; }; // Matches a dst and dst_output on an input edge. Today we only use this with @@ -352,7 +353,7 @@ class OutEdgeMatcher : public ::testing::MatcherInterface<OutEdge> { if (listener->IsInterested()) { *listener << "\nsource does not match expected "; src_matcher_.DescribeTo(listener->stream()); - string explanation = inner_listener.str(); + std::string explanation = inner_listener.str(); if (!explanation.empty()) { *listener << "\n\t" << explanation; } @@ -432,21 +433,21 @@ ::testing::Matcher<const Node*> impl::NodeWith( return ::testing::MakeMatcher(matcher); } -impl::NodeMatcherProperties Name(string name) { +impl::NodeMatcherProperties Name(std::string name) { impl::NodeMatcherProperties props; props.set_name(std::move(name)); return props; } // Matches a node with op `op`. -impl::NodeMatcherProperties Op(string op) { +impl::NodeMatcherProperties Op(std::string op) { impl::NodeMatcherProperties props; props.set_op(std::move(op)); return props; } // Matches a node with assigned device `assigned_device`.
-impl::NodeMatcherProperties AssignedDevice(string assigned_device) { +impl::NodeMatcherProperties AssignedDevice(std::string assigned_device) { impl::NodeMatcherProperties props; props.set_assigned_device(std::move(assigned_device)); return props; @@ -472,15 +473,15 @@ impl::NodeMatcherProperties impl::CtrlDeps( return props; } -std::pair<string, AttrValue> impl::AttrLiteralHelper( - const std::pair<string, bool>& bool_attr) { +std::pair<std::string, AttrValue> impl::AttrLiteralHelper( + const std::pair<std::string, bool>& bool_attr) { AttrValue attr_value; attr_value.set_b(bool_attr.second); return {bool_attr.first, attr_value}; } -std::pair<string, AttrValue> impl::AttrLiteralHelper( - const std::pair<string, absl::Span<const int>>& int_list_attr) { +std::pair<std::string, AttrValue> impl::AttrLiteralHelper( + const std::pair<std::string, absl::Span<const int>>& int_list_attr) { AttrValue attr_value; AttrValue::ListValue* list = attr_value.mutable_list(); for (int i : int_list_attr.second) { @@ -489,23 +490,24 @@ std::pair<string, AttrValue> impl::AttrLiteralHelper( return {int_list_attr.first, attr_value}; } -std::pair<string, AttrValue> impl::AttrLiteralHelper( - const std::pair<string, absl::Span<const string>>& string_list_attr) { +std::pair<std::string, AttrValue> impl::AttrLiteralHelper( + const std::pair<std::string, absl::Span<const std::string>>& + string_list_attr) { AttrValue attr_value; AttrValue::ListValue* list = attr_value.mutable_list(); - for (const string& s : string_list_attr.second) { + for (const std::string& s : string_list_attr.second) { list->add_s(s); } return {string_list_attr.first, attr_value}; } -impl::NodeMatcherProperties impl::Attr(std::pair<string, AttrValue> attr) { +impl::NodeMatcherProperties impl::Attr(std::pair<std::string, AttrValue> attr) { impl::NodeMatcherProperties props; props.set_attr(std::move(attr)); return props; } -impl::NodeMatcherProperties impl::Attr(string name) { +impl::NodeMatcherProperties impl::Attr(std::string name) { impl::NodeMatcherProperties props; props.set_attr({std::move(name), std::nullopt}); return props; diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h index bb2c1875306185..1391df3743bd4c 100644 --- a/tensorflow/compiler/jit/node_matchers.h +++ b/tensorflow/compiler/jit/node_matchers.h @@ -84,11 +84,11 @@ class NodeMatcherProperties { public: using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>; using InputSeqMatcher = std::vector<::testing::Matcher<OutEdge>>; - using AttrKeyValuePair = std::pair<string, std::optional<AttrValue>>; + using AttrKeyValuePair = std::pair<std::string, std::optional<AttrValue>>; - const std::optional<string>& name() const { return name_; } - const std::optional<string>& op() const { return op_; } - const std::optional<string>& assigned_device() const { + const std::optional<std::string>& name() const { return name_; } + const std::optional<std::string>& op() const { return op_; } + const std::optional<std::string>& assigned_device() const { return assigned_device_; } const std::optional<Tensor>& constant_value() const { @@ -102,17 +102,17 @@ class NodeMatcherProperties { } const std::optional<AttrKeyValuePair>& attr() const { return attr_; } - void set_name(string name) { + void set_name(std::string name) { DCHECK(IsEmpty()); name_ = std::move(name); } - void set_op(string op) { + void set_op(std::string op) { DCHECK(IsEmpty()); op_ = std::move(op); } - void set_assigned_device(string assigned_device) { + void set_assigned_device(std::string assigned_device) { DCHECK(IsEmpty()); assigned_device_ = std::move(assigned_device); } @@ -144,9 +144,9 @@ class NodeMatcherProperties { } private: - std::optional<string> name_; - std::optional<string> op_; - std::optional<string> assigned_device_; + std::optional<std::string> name_; + std::optional<std::string> op_; + std::optional<std::string> assigned_device_; std::optional<Tensor> constant_value_; std::optional<InputSeqMatcher> input_matchers_; std::optional<NodeSeqMatcher> control_deps_; @@ -162,39 +162,40 @@ impl::NodeMatcherProperties Inputs( impl::NodeMatcherProperties CtrlDeps( absl::Span<const ::testing::Matcher<const Node*>> control_deps);
-impl::NodeMatcherProperties Attr(std::pair<string, AttrValue> attrs); -impl::NodeMatcherProperties Attr(string name); +impl::NodeMatcherProperties Attr(std::pair<std::string, AttrValue> attrs); +impl::NodeMatcherProperties Attr(std::string name); -std::pair<string, AttrValue> AttrLiteralHelper( - const std::pair<string, bool>& bool_attr); +std::pair<std::string, AttrValue> AttrLiteralHelper( + const std::pair<std::string, bool>& bool_attr); -std::pair<string, AttrValue> AttrLiteralHelper( - const std::pair<string, absl::Span<const int>>& int_list_attr); +std::pair<std::string, AttrValue> AttrLiteralHelper( + const std::pair<std::string, absl::Span<const int>>& int_list_attr); -std::pair<string, AttrValue> AttrLiteralHelper( - const std::pair<string, absl::Span<const string>>& string_list_attr); +std::pair<std::string, AttrValue> AttrLiteralHelper( + const std::pair<std::string, absl::Span<const std::string>>& + string_list_attr); } // namespace impl // ----------------------------------------------------------------------------- // Public interface. // Matches a node with name `name`. -impl::NodeMatcherProperties Name(string name); +impl::NodeMatcherProperties Name(std::string name); // Matches a node with op `op`. -impl::NodeMatcherProperties Op(string op); +impl::NodeMatcherProperties Op(std::string op); // Matches a node with assigned device `assigned_device`. -impl::NodeMatcherProperties AssignedDevice(string assigned_device); +impl::NodeMatcherProperties AssignedDevice(std::string assigned_device); // Matches a node with a boolean typed attribute named `name` and with value // `value`. template <typename ValueTy> -impl::NodeMatcherProperties Attr(const string& name, ValueTy value) { +impl::NodeMatcherProperties Attr(const std::string& name, ValueTy value) { return impl::Attr({impl::AttrLiteralHelper({name, value})}); } -inline impl::NodeMatcherProperties Attr(const string& name) { +inline impl::NodeMatcherProperties Attr(const std::string& name) { return impl::Attr(name); } diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc index 6f37d5617b6ce6..ac1d9ce3468df1 100644 --- a/tensorflow/compiler/jit/node_matchers_test.cc +++ b/tensorflow/compiler/jit/node_matchers_test.cc @@ -41,7 +41,7 @@ using testing::matchers::Op; using testing::matchers::Out; template <typename T, typename M> -string Explain(const T& t, const M& m) { +std::string Explain(const T& t, const M& m) { ::testing::StringMatchResultListener listener; EXPECT_THAT(t, ::testing::Not(m)); // For the error message.
EXPECT_FALSE(m.MatchAndExplain(t, &listener)); diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index c8bbcee20e3829..9539a14d060f42 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -100,7 +100,7 @@ absl::Status PartiallyDecluster(std::unique_ptr<Graph>* graph) { return pass.Run(opt_options); } -Node* FindNodeByName(const Graph& graph, const string& name) { +Node* FindNodeByName(const Graph& graph, const std::string& name) { for (Node* node : graph.nodes()) { if (node->name() == name) { return node; @@ -109,7 +109,7 @@ Node* FindNodeByName(const Graph& graph, const string& name) { return nullptr; } -bool GetInputsForNode(const Graph& graph, const string& node_name, +bool GetInputsForNode(const Graph& graph, const std::string& node_name, std::vector<Node*>* inputs) { const Node* node = FindNodeByName(graph, node_name); if (node == nullptr) { @@ -292,7 +292,7 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { void AddToCluster(absl::Span<Node* const> nodes, absl::string_view cluster_name) { for (Node* n : nodes) { - n->AddAttr(kXlaClusterAttr, string(cluster_name)); + n->AddAttr(kXlaClusterAttr, std::string(cluster_name)); } } diff --git a/tensorflow/compiler/jit/pjrt_base_device.cc b/tensorflow/compiler/jit/pjrt_base_device.cc index ce7ed954575040..d25d77d6cff22b 100644 --- a/tensorflow/compiler/jit/pjrt_base_device.cc +++ b/tensorflow/compiler/jit/pjrt_base_device.cc @@ -17,8 +17,8 @@ limitations under the License. namespace tensorflow { namespace { -DeviceAttributes BuildPjRtBaseDeviceAttributes(const string& name_prefix, - const string& device_name, +DeviceAttributes BuildPjRtBaseDeviceAttributes(const std::string& name_prefix, + const std::string& device_name, int device_ordinal) { return Device::BuildDeviceAttributes( absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 2fee2b0b898890..33f09704d7c72b 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -143,7 +143,7 @@ bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { using ResourceOp = std::pair<int, XlaResourceOpKind>; -string ResourceOpToString(const ResourceOp& resource_op) { +std::string ResourceOpToString(const ResourceOp& resource_op) { return absl::StrCat( resource_op.first, ": ", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); @@ -233,14 +233,14 @@ class ResourceOpSet { void operator=(const ResourceOpSet&) = delete; }; -string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { - std::vector<string> elements_debug_string; +std::string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { + std::vector<std::string> elements_debug_string; std::transform(resource_op_set.begin(), resource_op_set.end(), std::back_inserter(elements_debug_string), ResourceOpToString); return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); } -string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { +std::string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { return absl::StrCat( "[", n.name(), ": ", n.type_string(), "(", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc index 8a80b8ae9b3497..6b038c992f1715 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -38,7 +38,7 @@ limitations under the License. namespace tensorflow { namespace { -Node* MakeRead(const Scope& scope, const string& id) { +Node* MakeRead(const Scope& scope, const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output read = @@ -46,7 +46,7 @@ Node* MakeRead(const Scope& scope, const string& id) { return read.node(); } -Node* MakeWrite(const Scope& scope, const string& id) { +Node* MakeWrite(const Scope& scope, const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output value_to_write = @@ -56,7 +56,7 @@ Node* MakeWrite(const Scope& scope, const string& id) { return assign_op.operation.node(); } -Node* MakeModify(const Scope& scope, const string& id) { +Node* MakeModify(const Scope& scope, const std::string& id) { Output var_handle = ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f); @@ -65,7 +65,7 @@ Node* MakeModify(const Scope& scope, const string& id) { return assign_add_op.operation.node(); } -Node* MakeNeutral(const Scope& scope, const string& id) { +Node* MakeNeutral(const Scope& scope, const std::string& id) { return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); } @@ -238,7 +238,8 @@ TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) { EXPECT_EQ(incompatible_pairs[1], write_modify_pair); } -FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { +FunctionDefLibrary CreateFunctionDefLibWithConstFunction( + const std::string& name) { FunctionDefLibrary flib_def; FunctionDef func = FunctionDefHelper::Create( /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, @@ -249,8 +250,8 @@ FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { return flib_def; } -Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name, - absl::Status* status) { +Node* MakeCall(Graph* graph, const std::string& callee_name, + const std::string& node_name, absl::Status* status) { NodeDef call_node; call_node.set_name(node_name); call_node.set_op(callee_name); diff --git a/tensorflow/compiler/jit/shape_inference.h b/tensorflow/compiler/jit/shape_inference.h index 467ecb83a74aae..b1469d2d699bf1 100644 --- a/tensorflow/compiler/jit/shape_inference.h +++ b/tensorflow/compiler/jit/shape_inference.h @@ -35,7 +35,8 @@ struct InferredShape { DataType handle_type = DT_INVALID; PartialTensorShape handle_shape; }; -typedef std::unordered_map<string, std::vector<InferredShape>> GraphShapeInfo; +typedef std::unordered_map<std::string, std::vector<InferredShape>> + GraphShapeInfo; // Infer shapes for all Tensors in a graph, and save them in a map. The vector // for a Node contains the information about each of its outputs.
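Side note on the GraphShapeInfo typedef retyped above: it maps each node name to one InferredShape per node output. A hedged consumption sketch (it assumes InferredShape's primary `shape` member alongside the handle fields shown; the helper name is invented for illustration and is not part of the patch):

    // Assumes TensorFlow's InferredShape, PartialTensorShape, and LOG macros.
    void LogInferredShapes(const GraphShapeInfo& shape_info) {
      for (const auto& [node_name, outputs] : shape_info) {
        for (size_t i = 0; i < outputs.size(); ++i) {
          // One InferredShape per node output, as the typedef implies.
          LOG(INFO) << node_name << ":" << i << " -> "
                    << outputs[i].shape.DebugString();
        }
      }
    }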
diff --git a/tensorflow/compiler/jit/shape_inference_test.cc b/tensorflow/compiler/jit/shape_inference_test.cc index eaabf18c79603c..599d442de4b092 100644 --- a/tensorflow/compiler/jit/shape_inference_test.cc +++ b/tensorflow/compiler/jit/shape_inference_test.cc @@ -61,7 +61,7 @@ TEST(ShapeInferenceTest, Basics) { TF_ASSERT_OK(InferShapes(graph.get(), /*arg_shapes=*/{}, /*fnlib_def=*/nullptr, &shape_info)); - std::map<string, std::vector<PartialTensorShape>> expected = { + std::map<std::string, std::vector<PartialTensorShape>> expected = { {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({3})}}, {"C", {PartialTensorShape()}}, {"D", {PartialTensorShape({2, 3})}}, {"E", {PartialTensorShape()}}, {"F", {PartialTensorShape()}}, @@ -94,7 +94,7 @@ TEST(ShapeInferenceTest, UseArgShapesForVariableBatchSize) { TF_ASSERT_OK(InferShapes(graph.get(), arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); - std::map<string, std::vector<PartialTensorShape>> expected = { + std::map<std::string, std::vector<PartialTensorShape>> expected = { {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({2, 3})}}, {"C", {PartialTensorShape({2, 3})}}, @@ -127,7 +127,7 @@ TEST(ShapeInferenceTest, UseArgShapesForVariableBatchSizeIncompleteUserArgs) { TF_ASSERT_OK(InferShapes(graph.get(), arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); - std::map<string, std::vector<PartialTensorShape>> expected = { + std::map<std::string, std::vector<PartialTensorShape>> expected = { {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({2, 3})}}, {"C", {PartialTensorShape({2, 3})}}, @@ -156,7 +156,7 @@ TEST(ShapeInferenceTest, WhileLoop) { ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); auto merge = ops::Merge(scope.WithOpName("while/Merge"), std::initializer_list<Input>{enter, dummy}); - auto ten = ops::Const<int32>( + auto ten = ops::Const<int32_t>( scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), 10); auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); @@ -168,11 +168,11 @@ TEST(ShapeInferenceTest, WhileLoop) { auto identity = ops::Identity(scope.WithOpName("while/Identity"), switch_node.output_true); auto identity_shape = - ops::Const<int32>(scope.WithOpName("while/Identity/shape"), {}); + ops::Const<int32_t>(scope.WithOpName("while/Identity/shape"), {}); auto identity_reshaped = ops::Reshape( scope.WithOpName("while/Identity/reshaped"), identity, identity_shape); - auto one = ops::Const<int32>( + auto one = ops::Const<int32_t>( scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); auto add = ops::Add(scope.WithOpName("while/add"), identity_reshaped, one); auto next_iteration = @@ -190,7 +190,7 @@ TEST(ShapeInferenceTest, WhileLoop) { GraphShapeInfo shape_info; TF_ASSERT_OK(InferShapes(&graph, /*arg_shapes=*/{}, /*fnlib_def=*/nullptr, &shape_info)); - std::map<string, std::vector<PartialTensorShape>> expected = { + std::map<std::string, std::vector<PartialTensorShape>> expected = { {"while/Identity", {PartialTensorShape()}}, {"while/add", {PartialTensorShape({})}}, }; diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc index 81ab1d8d05f96e..30a9ab51faf105 100644 --- a/tensorflow/compiler/jit/test_util.cc +++ b/tensorflow/compiler/jit/test_util.cc @@ -29,7 +29,7 @@ namespace tensorflow { absl::Status ShapeAnnotationsMatch( const Graph& graph, const GraphShapeInfo& shape_info, - std::map<string, std::vector<PartialTensorShape>> expected_shapes) { + std::map<std::string, std::vector<PartialTensorShape>> expected_shapes) { for (Node* node : graph.op_nodes()) { auto sit = shape_info.find(node->name()); TF_RET_CHECK(sit != shape_info.end()) @@ -50,7 +50,7 @@ absl::Status ShapeAnnotationsMatch( } } if (!expected_shapes.empty()) { - std::vector<string> missing; + std::vector<std::string> missing; missing.reserve(expected_shapes.size()); for (const auto& entry : expected_shapes) { missing.push_back(entry.first); @@ -88,12 +88,12 @@ void DeviceSetup::AddDevicesAndSetUp( flr_ =
pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); } -Device* DeviceSetup::GetDevice(const string& device_name) { +Device* DeviceSetup::GetDevice(const std::string& device_name) { if (device_mgr_ == nullptr) { return nullptr; } - string full_device_name = absl::StrCat( + std::string full_device_name = absl::StrCat( "/job:localhost/replica:0/task:0/device:", device_name, ":0"); Device* device; TF_CHECK_OK(device_mgr_->LookupDevice(full_device_name, &device)); diff --git a/tensorflow/compiler/jit/test_util.h b/tensorflow/compiler/jit/test_util.h index ec694662297399..ba7d2533ef7c74 100644 --- a/tensorflow/compiler/jit/test_util.h +++ b/tensorflow/compiler/jit/test_util.h @@ -44,7 +44,7 @@ namespace tensorflow { // `expected_shapes` entries. absl::Status ShapeAnnotationsMatch( const Graph& graph, const GraphShapeInfo& shape_info, - std::map<string, std::vector<PartialTensorShape>> expected_shapes); + std::map<std::string, std::vector<PartialTensorShape>> expected_shapes); // A helper object to create GraphOptimizationPassOptions. struct GraphOptimizationPassWrapper { @@ -74,7 +74,7 @@ class DeviceSetup { void AddDevicesAndSetUp( const std::vector<string>& device_names, const std::optional<FunctionDef>& fdef = std::nullopt); - Device* GetDevice(const string& device_name); + Device* GetDevice(const std::string& device_name); FunctionLibraryRuntime* flr() { return flr_; } private: diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test.cc b/tensorflow/compiler/jit/tests/auto_clustering_test.cc index 90e73c23d210d7..d108bc51b5ee33 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test.cc @@ -23,7 +23,7 @@ class AutoClusteringTestImpl : public AutoClusteringTest { protected: // Test auto-clustering with a proto text file ${key}.pbtxt. absl::Status RunAutoClusteringTestWithPbtxt(absl::string_view key) { - string file_name_without_extension = + std::string file_name_without_extension = absl::StrCat(testing::TensorFlowSrcRoot(), "/compiler/jit/tests/", key); return AutoClusteringTest::RunAutoClusteringTestWithPbtxt( @@ -33,7 +33,7 @@ class AutoClusteringTestImpl : public AutoClusteringTest { // Test auto-clustering with a gzipped proto text file ${key}.pbtxt.gz. absl::Status RunAutoClusteringTestWithGzippedPbtxt(absl::string_view key) { - string file_name_without_extension = + std::string file_name_without_extension = absl::StrCat(testing::TensorFlowSrcRoot(), "/compiler/jit/tests/", key); return AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index dee77ac750ee54..258449e91120e1 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -33,7 +33,7 @@ limitations under the License.
namespace tensorflow { namespace { -absl::StatusOr<string> SummarizeClustering( +absl::StatusOr<std::string> SummarizeClustering( const GraphDef& auto_clustered_graph_def) { testing::ResetClusterSequenceNumber(); Graph graph(OpRegistry::Global()); @@ -45,7 +45,7 @@ absl::StatusOr<string> SummarizeClustering( // cluster_id -> (operation name -> # of operations) const int kNoCluster = -1; - std::map<int, std::map<string, int>> clusters; + std::map<int, std::map<std::string, int>> clusters; std::map<int, int> cluster_size; int clustered_nodes = 0; for (Node* n : graph.op_nodes()) { @@ -60,7 +60,7 @@ absl::StatusOr<string> SummarizeClustering( cluster_size[cluster]++; } - string result = + std::string result = absl::StrCat("Clustered nodes: ", clustered_nodes, "\nUnclustered nodes: ", cluster_size[kNoCluster], "\nNumber of clusters: ", clusters.size() - 1, "\n\n"); @@ -99,7 +99,7 @@ absl::Status AssertGraphDefIsUnclustered(const GraphDef& graphdef) { return absl::OkStatus(); } -absl::Status ReadTextProtoFromString(Env* env, const string& data, +absl::Status ReadTextProtoFromString(Env* env, const std::string& data, ::tensorflow::protobuf::Message* proto) { if (!::tensorflow::protobuf::TextFormat::ParseFromString(data, proto)) { return errors::DataLoss("Can't parse input data as text proto"); @@ -141,7 +141,8 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestImpl( graphdef = std::move(next); } - TF_ASSIGN_OR_RETURN(string clustering_summary, SummarizeClustering(graphdef)); + TF_ASSIGN_OR_RETURN(std::string clustering_summary, + SummarizeClustering(graphdef)); // To update golden files flip this to true and run // @@ -149,13 +150,15 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestImpl( // tensorflow/compiler/jit/tests:auto_clustering_test bool update_golden = false; if (update_golden) { - TF_RETURN_IF_ERROR(WriteStringToFile( - Env::Default(), string(golden_summary_file_path), clustering_summary)); + TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), + std::string(golden_summary_file_path), + clustering_summary)); } - string golden_file_contents; - TF_RETURN_IF_ERROR(ReadFileToString( - Env::Default(), string(golden_summary_file_path), &golden_file_contents)); + std::string golden_file_contents; + TF_RETURN_IF_ERROR(ReadFileToString(Env::Default(), + std::string(golden_summary_file_path), + &golden_file_contents)); EXPECT_EQ(golden_file_contents, clustering_summary); @@ -167,7 +170,7 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestWithPbtxt( absl::string_view golden_summary_file_path) { GraphDef graphdef; TF_RETURN_IF_ERROR( - ReadTextProto(Env::Default(), string(pbtxt_file_path), &graphdef)); + ReadTextProto(Env::Default(), std::string(pbtxt_file_path), &graphdef)); return RunAutoClusteringTestImpl(std::move(graphdef), golden_summary_file_path); } @@ -177,8 +180,8 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( absl::string_view golden_summary_file_path) { Env* env = Env::Default(); std::unique_ptr<RandomAccessFile> file_reader; - TF_RETURN_IF_ERROR( - env->NewRandomAccessFile(string(gzipped_pbtxt_file_path), &file_reader)); + TF_RETURN_IF_ERROR(env->NewRandomAccessFile( + std::string(gzipped_pbtxt_file_path), &file_reader)); std::unique_ptr<io::RandomAccessInputStream> input_stream( new io::RandomAccessInputStream(file_reader.get())); constexpr int k_buffer_size = 256 << 10; // 256kb @@ -206,7 +209,7 @@ absl::Status BenchmarkMarkForCompilation(absl::string_view graph_def_path, benchmark::State& state) { GraphDef graph_def; TF_RETURN_IF_ERROR( - ReadTextProto(Env::Default(), string(graph_def_path), &graph_def)); + ReadTextProto(Env::Default(), std::string(graph_def_path), &graph_def));
OptimizationPassRunner runner; TF_RETURN_IF_ERROR(runner.SetJitLevel(tensorflow::OptimizerOptions::ON_2)); diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc index e4be1a1f641656..33e2daf941eafb 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc @@ -29,7 +29,7 @@ namespace { // Creates a float tensor of linearly increasing values, starting from offset. Tensor CreateInputTensor(const TensorShape& shape, float offset) { Tensor tensor(DT_FLOAT, shape); - for (int64 i = 0; i < tensor.flat<float>().size(); ++i) { + for (int64_t i = 0; i < tensor.flat<float>().size(); ++i) { tensor.flat<float>()(i) = offset + i; } return tensor; @@ -127,7 +127,7 @@ absl::Status DeviceCompilerSerializeTest::ExecuteWithBatch( } Tensor f32_input(DT_FLOAT, shape); - for (int64 i = 0; i < f32_input.NumElements(); ++i) { + for (int64_t i = 0; i < f32_input.NumElements(); ++i) { EXPECT_NEAR(golden_output_tensors[0].flat<float>()(i), output_tensors[0].flat<float>()(i), 1e-3); } @@ -139,7 +139,7 @@ DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( absl::string_view persistent_cache_dir_path, absl::string_view file_prefix) { Env* env = Env::Default(); - std::vector<string> file_names; + std::vector<std::string> file_names; TF_RETURN_IF_ERROR( env->GetChildren(tensorflow::testing::TmpDir(), &file_names)); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index ce0285c2e797d2..8ccb236897ce39 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -45,11 +45,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "xla/client/local_client.h" +#include "xla/future.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/pjrt_future.h" #include "xla/service/executable.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/service/shaped_buffer.h" @@ -809,8 +809,6 @@ xla::ExecuteOptions GetPjRtExecuteOptions( const DeviceType& device_type, absl::flat_hash_set<int> non_donatable_input_indices) { xla::ExecuteOptions options; - options.arguments_are_tupled = false; - options.untuple_result = true; // Hardcode run id to always be one: TF distributed strategy // differentiates between subsequent runs using dependency edges. This // is safe, as only TF dist-strat can produce distributed ops, and we diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc index 9e71286dc95df8..d8ed5feac79f12 100644 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -207,8 +207,6 @@ class PjRtExecutionUtilTest : public OpsTestBase { &executable_args, /*owned_args=*/{}, &non_donatable_input_indices)); xla::ExecuteOptions exe_options; - exe_options.arguments_are_tupled = false; - exe_options.untuple_result = true; // TODO(b/257548614): currently PJRT is compiled as portable (num_replica = // 1 and num_partition = 1). Support multiple partitions case.
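Side note on CreateInputTensor above: it fills a float tensor with offset, offset+1, offset+2, ... so that batched outputs can later be compared elementwise with EXPECT_NEAR. The same pattern in a self-contained form (a plain std::vector instead of tensorflow::Tensor; names are illustrative, not from the patch):

    #include <cstdint>
    #include <vector>

    std::vector<float> MakeLinearBuffer(int64_t num_elements, float offset) {
      std::vector<float> buffer(static_cast<size_t>(num_elements));
      for (int64_t i = 0; i < num_elements; ++i) {
        // offset, offset + 1, offset + 2, ...
        buffer[static_cast<size_t>(i)] = offset + static_cast<float>(i);
      }
      return buffer;
    }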
@@ -520,8 +518,6 @@ TEST_F(PjRtExecutionUtilTest, PopulateCtxOutputsResourceUpdates) { TEST(XlaLaunchUtilTest, GetPjRtExecuteOptions) { xla::ExecuteOptions options = GetPjRtExecuteOptions(DeviceType(DEVICE_GPU), {}); - EXPECT_FALSE(options.arguments_are_tupled); - EXPECT_TRUE(options.untuple_result); EXPECT_FALSE(options.strict_shape_checking); EXPECT_TRUE(options.use_major_to_minor_data_layout_for_callbacks); } diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 7f200aa186a466..ab6c5abeca86f0 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1990,7 +1990,6 @@ cc_library( ":tf_tfl_passes", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/lite/core:macros", "//tensorflow/compiler/mlir/lite/debug", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:converter_error_data_proto_cc", @@ -2212,10 +2211,8 @@ tf_proto_library( srcs = ["converter_flags.proto"], make_default_target_header_only = True, protodeps = [ - "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_options_proto", - "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto", - "//tensorflow/compiler/mlir/lite/debug:debug_options_proto", ":types_proto", + "//tensorflow/compiler/mlir/lite/debug:debug_options_proto", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/compiler/mlir/lite/converter_flags.proto b/tensorflow/compiler/mlir/lite/converter_flags.proto index 1c1a1ad00aea74..49795ad8337d9a 100644 --- a/tensorflow/compiler/mlir/lite/converter_flags.proto +++ b/tensorflow/compiler/mlir/lite/converter_flags.proto @@ -17,8 +17,6 @@ package tflite; import "tensorflow/compiler/mlir/lite/debug/debug_options.proto"; import "tensorflow/compiler/mlir/lite/types.proto"; -import "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto"; -import "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto"; // Supported I/O file formats. Some formats may be input-only or output-only. enum FileFormat { @@ -43,6 +41,8 @@ enum FileFormat { // // Next ID to use: 69. message ConverterFlags { + reserved 54, 61; + // Input file format optional FileFormat input_format = 1; @@ -312,12 +312,6 @@ message ConverterFlags { // If true, disable folding mul->fc as in layer norm during optimize pass. optional bool disable_fuse_mul_and_fc = 53 [default = false]; - // Indicates the quantization specs. Quantization spec can be set to either - // a preset method or a custom method. - // Note: This is deprecated; use `quantization_config` instead. - optional stablehlo.quantization.QuantizationOptions quantization_options = 54 - [deprecated = true]; - // Flag to enable hlo to tf conversion. // This is useful to exercise StableHLO -> HLO -> TF -> TFLite path. optional bool enable_hlo_to_tf_conversion = 55 @@ -346,11 +340,6 @@ message ConverterFlags { // WARNING: Experimental interface, subject to change. optional string qdq_conversion_mode = 60 [default = "NONE"]; - // Configures quantization behavior. This config is fed to the StableHLO - // Quantizer integrated in the converter. - // WARNING: Experimental interface, subject to change. - optional stablehlo.quantization.QuantizationConfig quantization_config = 61; - // Disables per channel weights quantization for Dense layers and enables // legacy per tensor quantization. 
The legacy quantization for Dense layers is
  // inconsistent with Conv 1x1 which always performs per channel quantization.
diff --git a/tensorflow/compiler/mlir/lite/debug/debug_test.cc b/tensorflow/compiler/mlir/lite/debug/debug_test.cc
index 6c26865757950a..b82d5725182745 100644
--- a/tensorflow/compiler/mlir/lite/debug/debug_test.cc
+++ b/tensorflow/compiler/mlir/lite/debug/debug_test.cc
@@ -120,7 +120,7 @@ class InitPassManagerTest : public testing::Test {
   }
 
   absl::Status GetDumpDir(std::string* dump_dir) {
-    std::vector<string> files;
+    std::vector<std::string> files;
     if (auto status = tsl::Env::Default()->GetChildren(path_, &files);
         !status.ok()) {
       return status;
diff --git a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc
index 0d83e1971072c3..80975abd3e9a7a 100644
--- a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc
+++ b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "absl/strings/string_view.h"
 #include "llvm/Support/Casting.h"
 #include "mlir-c/IR.h"  // from @llvm-project
 #include "mlir/Bindings/Python/NanobindAdaptors.h"  // from @llvm-project  // IWYU pragma: keep
diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 05abe12b6ebf58..d7027e91f480ef 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -4962,7 +4962,7 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
                                SmallVectorImpl<RegionSuccessor>& regions) {
   // The `then` and the `else` region branch back to the parent operation.
   if (!point.isParent()) {
-    regions.push_back(RegionSuccessor(getResults()));
+    regions.push_back(RegionSuccessor(getOperation(), getResults()));
     return;
   }
@@ -5233,6 +5233,22 @@ int64_t SoftmaxOp::GetArithmeticCount(Operation* op) {
 // TanhOp
 //===----------------------------------------------------------------------===//
 
+OpFoldResult TanhOp::fold(FoldAdaptor adaptor) {
+  if (!ShouldFoldOperation(this->getOperation())) return {};
+
+  auto operands = adaptor.getOperands();
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+ if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::tanh(f); + return APFloat(result); + }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + int64_t TanhOp::GetArithmeticCount(Operation* op) { int64_t count; // As a very rough ballpark, the cost of evaluating a math function diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 64fc866b2be055..c90859cd6accfe 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1100,7 +1100,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input, - TFL_TensorOf<[F32, QI4, QI8, QUI8, QI16]>:$filter, + TFL_TensorOf<[F32, QI2, QI4, QI8, QUI8, QI16]>:$filter, TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, TFL_AFAttr:$fused_activation_function, @@ -2477,13 +2477,13 @@ equivalent to setting: }]; let arguments = (ins - TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, + TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$input, TFL_I32OrI64Tensor:$begin, TFL_I32OrI64Tensor:$size ); let results = (outs - TFL_TensorOf<[F32, I32, I64, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output + TFL_TensorOf<[F32, I32, I64, QI4, I8, UI8, UI32, I1, TFL_Str, QI8, QUI8, TFL_Quint8, QI16]>:$output ); let hasVerifier = 1; @@ -3575,6 +3575,8 @@ def TFL_TanhOp: TFL_Op<"tanh", [ /*scale=*/1.0 / (1<<(bit_width-1)), /*zero_point=*/0); } }]; + + let hasFolder = 1; } def TFL_TileOp: TFL_Op<"tile", [ @@ -4279,7 +4281,7 @@ def TFL_DequantizeOp: TFL_Op<"dequantize", [NoMemoryEffect]> { quantization parameters. }]; - let arguments = (ins TFL_TensorOf<[QI4, QI8, QUI8, QI16, F16]>:$input); + let arguments = (ins TFL_TensorOf<[QI2, QI4, QI8, QUI8, QI16, F16]>:$input); let results = (outs TFL_FpTensor:$output); diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape_test.cc b/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape_test.cc index a3ae7f73b24f24..b5a3319ba13362 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape_test.cc +++ b/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape_test.cc @@ -19,9 +19,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h" #include -#include #include -#include #include #include diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h index aa700dc166e046..29ed664e7ae78f 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h @@ -31,7 +31,7 @@ namespace tensorflow { // error status if it fails to convert the input. 
 absl::Status ConvertJaxToTFLiteFlatBuffer(
     const std::string& input, const tflite::ModelFlags& model_flags,
-    tflite::ConverterFlags& converter_flags, string* result);
+    tflite::ConverterFlags& converter_flags, std::string* result);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index fa94cd3b5b8120..c334f24442b491 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -140,8 +140,8 @@ absl::Status ConvertSavedModelToTFLiteFlatBuffer(
   mlir::TFL::QuantizationSpecs quant_specs;
 
   // Parse input arrays.
-  std::vector<string> node_names;
-  std::vector<string> node_dtypes;
+  std::vector<std::string> node_names;
+  std::vector<std::string> node_dtypes;
   std::vector<std::optional<std::vector<int>>> node_shapes;
   std::vector<std::optional<double>> node_mins;
   std::vector<std::optional<double>> node_maxs;
@@ -210,8 +210,6 @@ absl::Status ConvertSavedModelToTFLiteFlatBuffer(
       converter_flags.convert_to_stablehlo();
   pass_config.legalize_custom_tensor_list_ops =
       converter_flags.legalize_custom_tensor_list_ops();
-  pass_config.enable_stablehlo_quantizer =
-      converter_flags.has_quantization_config();
   pass_config.enable_composite_direct_lowering =
       converter_flags.enable_composite_direct_lowering();
   pass_config.model_origin_framework = converter_flags.model_origin_framework();
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h
index 33b9bacf2dfdeb..446652ccb8da05 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h
@@ -32,7 +32,7 @@ namespace tensorflow {
 // error status if it fails to convert the input.
 absl::Status ConvertSavedModelToTFLiteFlatBuffer(
     const tflite::ModelFlags& model_flags,
-    tflite::ConverterFlags& converter_flags, string* result,
+    tflite::ConverterFlags& converter_flags, std::string* result,
     const quantization::PyFunctionLibrary* quantization_py_function_lib);
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h
index f837a6f0140e7b..de75080ab5da82 100644
--- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h
+++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h
@@ -46,8 +46,8 @@ absl::Status RegisterAllCustomOps(
 absl::Status PopulateQuantizationSpecs(
     const tflite::ModelFlags& model_flags,
     tflite::ConverterFlags& converter_flags,
-    mlir::TFL::QuantizationSpecs* quant_specs, std::vector<string>* node_names,
-    std::vector<string>* node_dtypes,
+    mlir::TFL::QuantizationSpecs* quant_specs,
+    std::vector<std::string>* node_names, std::vector<std::string>* node_dtypes,
     std::vector<std::optional<std::vector<int>>>* node_shapes,
     std::vector<std::optional<double>>* node_mins,
     std::vector<std::optional<double>>* node_maxs);
@@ -60,7 +60,8 @@ absl::Status ConvertMLIRToTFLiteFlatBuffer(
     std::unique_ptr<mlir::MLIRContext>&& context,
     mlir::OwningOpRef<mlir::ModuleOp> module,
     const mlir::TFL::PassConfig& pass_config,
-    const std::unordered_set<string>& saved_model_tags, string* result,
+    const std::unordered_set<std::string>& saved_model_tags,
+    std::string* result,
     const quantization::PyFunctionLibrary* quantization_py_function_lib);
 
 // Give a warning for any unused flags that have been specified.
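The three headers above all migrate their out-parameters from the legacy TF `string` alias to `std::string`. A hedged sketch of driving the retyped SavedModel entry point (the wrapper name below is a placeholder, not a real TensorFlow symbol; flag population is elided):

#include <string>

#include "absl/status/status.h"
#include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h"

// Placeholder wrapper for illustration: converts the SavedModel described by
// `model_flags` and writes the serialized .tflite bytes into an ordinary
// std::string, which plain C++ callers can now pass directly.
absl::Status ConvertOnce(const tflite::ModelFlags& model_flags,
                         tflite::ConverterFlags& converter_flags,
                         std::string* flatbuffer_out) {
  return tensorflow::ConvertSavedModelToTFLiteFlatBuffer(
      model_flags, converter_flags, flatbuffer_out,
      /*quantization_py_function_lib=*/nullptr);
}

Passing a null quantization library here is an assumption for the sketch; the real callers thread the Python function library through from the converter bindings.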
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index ae3b6233f8e959..1e1f79af16cbd6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -93,7 +93,7 @@ std::vector GetAsVector(const flatbuffers::Vector* vec) { class QuantizeWeightsTest : public testing::Test { protected: - QuantizeWeightsTest() {} + QuantizeWeightsTest() = default; void LoadBasicModel() { input_model_ = ReadTestModel(); diff --git a/tensorflow/compiler/mlir/lite/schema/schema.fbs b/tensorflow/compiler/mlir/lite/schema/schema.fbs index 01a214ab2c03bf..6cd1c51fb0cf9e 100644 --- a/tensorflow/compiler/mlir/lite/schema/schema.fbs +++ b/tensorflow/compiler/mlir/lite/schema/schema.fbs @@ -24,6 +24,8 @@ // Version 3c: Move constant tensor buffers & custom op buffers outside from // Flatbuffers. Has backward compatibility with version 3, 3a and // 3b. +// Version 3d: Add ExternalBuffer tables and tensor.external_buffer field for +// referencing immutable data stored in external files. namespace tflite; @@ -263,6 +265,11 @@ table Tensor { // Currently only 1 subtype is supported. The field is defined as an array for // flexibility of supporting multiple subtypes in the future. variant_tensors:[VariantSubType]; + + // Optional reference to an ExternalBuffer entry that stores constant tensor + // data outside of the FlatBuffer. A value of 0 indicates that the tensor uses + // the traditional embedded buffer field instead. + external_buffer:uint; } // A list of builtin operators. Builtin operators are slightly faster than custom @@ -1613,6 +1620,22 @@ table Buffer { size: ulong; } +// Groups external buffers by file/URI. +table ExternalBufferGroup { + name:string; +} + +// Describes an immutable data slice stored in an external file. +table ExternalBuffer { + // Unique identifier for this external buffer. + id:uint; + // Index into the external_buffer_groups array. + group:uint; + offset:ulong; + length:ulong; + packing:string; +} + table Metadata { // A human readable string to uniquely identify a Metadata. name:string; @@ -1680,6 +1703,12 @@ table Model { // Optional SignatureDefs for the model. signature_defs:[SignatureDef]; + + // Optional groups for external weight buffers. + external_buffer_groups:[ExternalBufferGroup]; + + // Optional list of external weight buffers referenced by tensors. 
+ external_buffers:[ExternalBuffer]; } root_type Model; diff --git a/tensorflow/compiler/mlir/lite/schema/schema_generated.h b/tensorflow/compiler/mlir/lite/schema/schema_generated.h index b04076af12a074..2b1701a8b9c0b9 100755 --- a/tensorflow/compiler/mlir/lite/schema/schema_generated.h +++ b/tensorflow/compiler/mlir/lite/schema/schema_generated.h @@ -681,6 +681,14 @@ struct Buffer; struct BufferBuilder; struct BufferT; +struct ExternalBufferGroup; +struct ExternalBufferGroupBuilder; +struct ExternalBufferGroupT; + +struct ExternalBuffer; +struct ExternalBufferBuilder; +struct ExternalBufferT; + struct Metadata; struct MetadataBuilder; struct MetadataT; @@ -5952,6 +5960,7 @@ struct TensorT : public ::flatbuffers::NativeTable { std::vector shape_signature{}; bool has_rank = false; std::vector> variant_tensors{}; + uint32_t external_buffer = 0; TensorT() = default; TensorT(const TensorT &o); TensorT(TensorT&&) FLATBUFFERS_NOEXCEPT = default; @@ -5971,7 +5980,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_SPARSITY = 16, VT_SHAPE_SIGNATURE = 18, VT_HAS_RANK = 20, - VT_VARIANT_TENSORS = 22 + VT_VARIANT_TENSORS = 22, + VT_EXTERNAL_BUFFER = 24 }; const ::flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -6003,6 +6013,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector<::flatbuffers::Offset> *variant_tensors() const { return GetPointer> *>(VT_VARIANT_TENSORS); } + uint32_t external_buffer() const { + return GetField(VT_EXTERNAL_BUFFER, 0); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -6022,6 +6035,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyOffset(verifier, VT_VARIANT_TENSORS) && verifier.VerifyVector(variant_tensors()) && verifier.VerifyVectorOfTables(variant_tensors()) && + VerifyField(verifier, VT_EXTERNAL_BUFFER, 4) && verifier.EndTable(); } TensorT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -6063,6 +6077,9 @@ struct TensorBuilder { void add_variant_tensors(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> variant_tensors) { fbb_.AddOffset(Tensor::VT_VARIANT_TENSORS, variant_tensors); } + void add_external_buffer(uint32_t external_buffer) { + fbb_.AddElement(Tensor::VT_EXTERNAL_BUFFER, external_buffer, 0); + } explicit TensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -6085,8 +6102,10 @@ inline ::flatbuffers::Offset CreateTensor( ::flatbuffers::Offset sparsity = 0, ::flatbuffers::Offset<::flatbuffers::Vector> shape_signature = 0, bool has_rank = false, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> variant_tensors = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> variant_tensors = 0, + uint32_t external_buffer = 0) { TensorBuilder builder_(_fbb); + builder_.add_external_buffer(external_buffer); builder_.add_variant_tensors(variant_tensors); builder_.add_shape_signature(shape_signature); builder_.add_sparsity(sparsity); @@ -6111,7 +6130,8 @@ inline ::flatbuffers::Offset CreateTensorDirect( ::flatbuffers::Offset sparsity = 0, const std::vector *shape_signature = nullptr, bool has_rank = false, - const std::vector<::flatbuffers::Offset> *variant_tensors = nullptr) { + const std::vector<::flatbuffers::Offset> *variant_tensors = nullptr, + uint32_t external_buffer = 0) { auto shape__ = shape ? 
_fbb.CreateVector(*shape) : 0; auto name__ = name ? _fbb.CreateString(name) : 0; auto shape_signature__ = shape_signature ? _fbb.CreateVector(*shape_signature) : 0; @@ -6127,7 +6147,8 @@ inline ::flatbuffers::Offset CreateTensorDirect( sparsity, shape_signature__, has_rank, - variant_tensors__); + variant_tensors__, + external_buffer); } ::flatbuffers::Offset CreateTensor(::flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -16531,6 +16552,182 @@ inline ::flatbuffers::Offset CreateBufferDirect( ::flatbuffers::Offset CreateBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ExternalBufferGroupT : public ::flatbuffers::NativeTable { + typedef ExternalBufferGroup TableType; + std::string name{}; +}; + +struct ExternalBufferGroup FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ExternalBufferGroupT NativeTableType; + typedef ExternalBufferGroupBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4 + }; + const ::flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + verifier.EndTable(); + } + ExternalBufferGroupT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExternalBufferGroupT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferGroupT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExternalBufferGroupBuilder { + typedef ExternalBufferGroup Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(ExternalBufferGroup::VT_NAME, name); + } + explicit ExternalBufferGroupBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateExternalBufferGroup( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0) { + ExternalBufferGroupBuilder builder_(_fbb); + builder_.add_name(name); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateExternalBufferGroupDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr) { + auto name__ = name ? 
_fbb.CreateString(name) : 0; + return tflite::CreateExternalBufferGroup( + _fbb, + name__); +} + +::flatbuffers::Offset CreateExternalBufferGroup(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferGroupT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ExternalBufferT : public ::flatbuffers::NativeTable { + typedef ExternalBuffer TableType; + uint32_t id = 0; + uint32_t group = 0; + uint64_t offset = 0; + uint64_t length = 0; + std::string packing{}; +}; + +struct ExternalBuffer FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ExternalBufferT NativeTableType; + typedef ExternalBufferBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ID = 4, + VT_GROUP = 6, + VT_OFFSET = 8, + VT_LENGTH = 10, + VT_PACKING = 12 + }; + uint32_t id() const { + return GetField(VT_ID, 0); + } + uint32_t group() const { + return GetField(VT_GROUP, 0); + } + uint64_t offset() const { + return GetField(VT_OFFSET, 0); + } + uint64_t length() const { + return GetField(VT_LENGTH, 0); + } + const ::flatbuffers::String *packing() const { + return GetPointer(VT_PACKING); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ID, 4) && + VerifyField(verifier, VT_GROUP, 4) && + VerifyField(verifier, VT_OFFSET, 8) && + VerifyField(verifier, VT_LENGTH, 8) && + VerifyOffset(verifier, VT_PACKING) && + verifier.VerifyString(packing()) && + verifier.EndTable(); + } + ExternalBufferT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExternalBufferT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExternalBufferBuilder { + typedef ExternalBuffer Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_id(uint32_t id) { + fbb_.AddElement(ExternalBuffer::VT_ID, id, 0); + } + void add_group(uint32_t group) { + fbb_.AddElement(ExternalBuffer::VT_GROUP, group, 0); + } + void add_offset(uint64_t offset) { + fbb_.AddElement(ExternalBuffer::VT_OFFSET, offset, 0); + } + void add_length(uint64_t length) { + fbb_.AddElement(ExternalBuffer::VT_LENGTH, length, 0); + } + void add_packing(::flatbuffers::Offset<::flatbuffers::String> packing) { + fbb_.AddOffset(ExternalBuffer::VT_PACKING, packing); + } + explicit ExternalBufferBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateExternalBuffer( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t id = 0, + uint32_t group = 0, + uint64_t offset = 0, + uint64_t length = 0, + ::flatbuffers::Offset<::flatbuffers::String> packing = 0) { + ExternalBufferBuilder builder_(_fbb); + builder_.add_length(length); + builder_.add_offset(offset); + builder_.add_packing(packing); + builder_.add_group(group); + builder_.add_id(id); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateExternalBufferDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t id = 0, + uint32_t group = 0, + uint64_t offset = 0, + uint64_t length = 0, + const char *packing = nullptr) { + auto packing__ = packing ? 
_fbb.CreateString(packing) : 0; + return tflite::CreateExternalBuffer( + _fbb, + id, + group, + offset, + length, + packing__); +} + +::flatbuffers::Offset CreateExternalBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct MetadataT : public ::flatbuffers::NativeTable { typedef Metadata TableType; std::string name{}; @@ -16802,6 +16999,8 @@ struct ModelT : public ::flatbuffers::NativeTable { std::vector metadata_buffer{}; std::vector> metadata{}; std::vector> signature_defs{}; + std::vector> external_buffer_groups{}; + std::vector> external_buffers{}; ModelT() = default; ModelT(const ModelT &o); ModelT(ModelT&&) FLATBUFFERS_NOEXCEPT = default; @@ -16819,7 +17018,9 @@ struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_BUFFERS = 12, VT_METADATA_BUFFER = 14, VT_METADATA = 16, - VT_SIGNATURE_DEFS = 18 + VT_SIGNATURE_DEFS = 18, + VT_EXTERNAL_BUFFER_GROUPS = 20, + VT_EXTERNAL_BUFFERS = 22 }; uint32_t version() const { return GetField(VT_VERSION, 0); @@ -16845,6 +17046,12 @@ struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector<::flatbuffers::Offset> *signature_defs() const { return GetPointer> *>(VT_SIGNATURE_DEFS); } + const ::flatbuffers::Vector<::flatbuffers::Offset> *external_buffer_groups() const { + return GetPointer> *>(VT_EXTERNAL_BUFFER_GROUPS); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *external_buffers() const { + return GetPointer> *>(VT_EXTERNAL_BUFFERS); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_VERSION, 4) && @@ -16867,6 +17074,12 @@ struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyOffset(verifier, VT_SIGNATURE_DEFS) && verifier.VerifyVector(signature_defs()) && verifier.VerifyVectorOfTables(signature_defs()) && + VerifyOffset(verifier, VT_EXTERNAL_BUFFER_GROUPS) && + verifier.VerifyVector(external_buffer_groups()) && + verifier.VerifyVectorOfTables(external_buffer_groups()) && + VerifyOffset(verifier, VT_EXTERNAL_BUFFERS) && + verifier.VerifyVector(external_buffers()) && + verifier.VerifyVectorOfTables(external_buffers()) && verifier.EndTable(); } ModelT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -16902,6 +17115,12 @@ struct ModelBuilder { void add_signature_defs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> signature_defs) { fbb_.AddOffset(Model::VT_SIGNATURE_DEFS, signature_defs); } + void add_external_buffer_groups(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> external_buffer_groups) { + fbb_.AddOffset(Model::VT_EXTERNAL_BUFFER_GROUPS, external_buffer_groups); + } + void add_external_buffers(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> external_buffers) { + fbb_.AddOffset(Model::VT_EXTERNAL_BUFFERS, external_buffers); + } explicit ModelBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -16922,8 +17141,12 @@ inline ::flatbuffers::Offset CreateModel( ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> buffers = 0, ::flatbuffers::Offset<::flatbuffers::Vector> metadata_buffer = 0, ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> metadata = 0, - ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> signature_defs = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> signature_defs = 0, + 
::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> external_buffer_groups = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> external_buffers = 0) { ModelBuilder builder_(_fbb); + builder_.add_external_buffers(external_buffers); + builder_.add_external_buffer_groups(external_buffer_groups); builder_.add_signature_defs(signature_defs); builder_.add_metadata(metadata); builder_.add_metadata_buffer(metadata_buffer); @@ -16944,7 +17167,9 @@ inline ::flatbuffers::Offset CreateModelDirect( const std::vector<::flatbuffers::Offset> *buffers = nullptr, const std::vector *metadata_buffer = nullptr, const std::vector<::flatbuffers::Offset> *metadata = nullptr, - const std::vector<::flatbuffers::Offset> *signature_defs = nullptr) { + const std::vector<::flatbuffers::Offset> *signature_defs = nullptr, + const std::vector<::flatbuffers::Offset> *external_buffer_groups = nullptr, + const std::vector<::flatbuffers::Offset> *external_buffers = nullptr) { auto operator_codes__ = operator_codes ? _fbb.CreateVector<::flatbuffers::Offset>(*operator_codes) : 0; auto subgraphs__ = subgraphs ? _fbb.CreateVector<::flatbuffers::Offset>(*subgraphs) : 0; auto description__ = description ? _fbb.CreateString(description) : 0; @@ -16952,6 +17177,8 @@ inline ::flatbuffers::Offset CreateModelDirect( auto metadata_buffer__ = metadata_buffer ? _fbb.CreateVector(*metadata_buffer) : 0; auto metadata__ = metadata ? _fbb.CreateVector<::flatbuffers::Offset>(*metadata) : 0; auto signature_defs__ = signature_defs ? _fbb.CreateVector<::flatbuffers::Offset>(*signature_defs) : 0; + auto external_buffer_groups__ = external_buffer_groups ? _fbb.CreateVector<::flatbuffers::Offset>(*external_buffer_groups) : 0; + auto external_buffers__ = external_buffers ? 
_fbb.CreateVector<::flatbuffers::Offset>(*external_buffers) : 0; return tflite::CreateModel( _fbb, version, @@ -16961,7 +17188,9 @@ inline ::flatbuffers::Offset CreateModelDirect( buffers__, metadata_buffer__, metadata__, - signature_defs__); + signature_defs__, + external_buffer_groups__, + external_buffers__); } ::flatbuffers::Offset CreateModel(::flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -17215,7 +17444,7 @@ inline void SparsityParameters::UnPackTo(SparsityParametersT *_o, const ::flatbu (void)_resolver; { auto _e = traversal_order(); if (_e) { _o->traversal_order.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->traversal_order[_i] = _e->Get(_i); } } else { _o->traversal_order.resize(0); } } { auto _e = block_map(); if (_e) { _o->block_map.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->block_map[_i] = _e->Get(_i); } } else { _o->block_map.resize(0); } } - { auto _e = dim_metadata(); if (_e) { _o->dim_metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->dim_metadata[_i]) { _e->Get(_i)->UnPackTo(_o->dim_metadata[_i].get(), _resolver); } else { _o->dim_metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->dim_metadata.resize(0); } } + { auto _e = dim_metadata(); if (_e) { _o->dim_metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->dim_metadata[_i]) { _e->Get(_i)->UnPackTo(_o->dim_metadata[_i].get(), _resolver); } else { _o->dim_metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->dim_metadata.resize(0); } } } inline ::flatbuffers::Offset SparsityParameters::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityParametersT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -17277,7 +17506,8 @@ inline TensorT::TensorT(const TensorT &o) is_variable(o.is_variable), sparsity((o.sparsity) ? new tflite::SparsityParametersT(*o.sparsity) : nullptr), shape_signature(o.shape_signature), - has_rank(o.has_rank) { + has_rank(o.has_rank), + external_buffer(o.external_buffer) { variant_tensors.reserve(o.variant_tensors.size()); for (const auto &variant_tensors_ : o.variant_tensors) { variant_tensors.emplace_back((variant_tensors_) ? 
new tflite::VariantSubTypeT(*variant_tensors_) : nullptr); } } @@ -17293,6 +17523,7 @@ inline TensorT &TensorT::operator=(TensorT o) FLATBUFFERS_NOEXCEPT { std::swap(shape_signature, o.shape_signature); std::swap(has_rank, o.has_rank); std::swap(variant_tensors, o.variant_tensors); + std::swap(external_buffer, o.external_buffer); return *this; } @@ -17314,7 +17545,8 @@ inline void Tensor::UnPackTo(TensorT *_o, const ::flatbuffers::resolver_function { auto _e = sparsity(); if (_e) { if(_o->sparsity) { _e->UnPackTo(_o->sparsity.get(), _resolver); } else { _o->sparsity = std::unique_ptr(_e->UnPack(_resolver)); } } else if (_o->sparsity) { _o->sparsity.reset(); } } { auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } else { _o->shape_signature.resize(0); } } { auto _e = has_rank(); _o->has_rank = _e; } - { auto _e = variant_tensors(); if (_e) { _o->variant_tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->variant_tensors[_i]) { _e->Get(_i)->UnPackTo(_o->variant_tensors[_i].get(), _resolver); } else { _o->variant_tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->variant_tensors.resize(0); } } + { auto _e = variant_tensors(); if (_e) { _o->variant_tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->variant_tensors[_i]) { _e->Get(_i)->UnPackTo(_o->variant_tensors[_i].get(), _resolver); } else { _o->variant_tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->variant_tensors.resize(0); } } + { auto _e = external_buffer(); _o->external_buffer = _e; } } inline ::flatbuffers::Offset Tensor::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -17335,6 +17567,7 @@ inline ::flatbuffers::Offset CreateTensor(::flatbuffers::FlatBufferBuild auto _shape_signature = _o->shape_signature.size() ? _fbb.CreateVector(_o->shape_signature) : 0; auto _has_rank = _o->has_rank; auto _variant_tensors = _o->variant_tensors.size() ? 
_fbb.CreateVector<::flatbuffers::Offset> (_o->variant_tensors.size(), [](size_t i, _VectorArgs *__va) { return CreateVariantSubType(*__va->__fbb, __va->__o->variant_tensors[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _external_buffer = _o->external_buffer; return tflite::CreateTensor( _fbb, _shape, @@ -17346,7 +17579,8 @@ inline ::flatbuffers::Offset CreateTensor(::flatbuffers::FlatBufferBuild _sparsity, _shape_signature, _has_rank, - _variant_tensors); + _variant_tensors, + _external_buffer); } inline StablehloGatherOptionsT *StablehloGatherOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { @@ -21575,10 +21809,10 @@ inline SubGraphT *SubGraph::UnPack(const ::flatbuffers::resolver_function_t *_re inline void SubGraph::UnPackTo(SubGraphT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->tensors[_i]) { _e->Get(_i)->UnPackTo(_o->tensors[_i].get(), _resolver); } else { _o->tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->tensors.resize(0); } } + { auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->tensors[_i]) { _e->Get(_i)->UnPackTo(_o->tensors[_i].get(), _resolver); } else { _o->tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->tensors.resize(0); } } { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } else { _o->inputs.resize(0); } } { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } else { _o->outputs.resize(0); } } - { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operators[_i]) { _e->Get(_i)->UnPackTo(_o->operators[_i].get(), _resolver); } else { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->operators.resize(0); } } + { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operators[_i]) { _e->Get(_i)->UnPackTo(_o->operators[_i].get(), _resolver); } else { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->operators.resize(0); } } { auto _e = name(); if (_e) _o->name = _e->str(); } { auto _e = debug_metadata_index(); _o->debug_metadata_index = _e; } } @@ -21640,6 +21874,70 @@ inline ::flatbuffers::Offset CreateBuffer(::flatbuffers::FlatBufferBuild _size); } +inline ExternalBufferGroupT *ExternalBufferGroup::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ExternalBufferGroupT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ExternalBufferGroup::UnPackTo(ExternalBufferGroupT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = name(); if (_e) _o->name = _e->str(); } +} + +inline ::flatbuffers::Offset ExternalBufferGroup::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferGroupT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateExternalBufferGroup(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset 
CreateExternalBufferGroup(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferGroupT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ExternalBufferGroupT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + return tflite::CreateExternalBufferGroup( + _fbb, + _name); +} + +inline ExternalBufferT *ExternalBuffer::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ExternalBufferT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ExternalBuffer::UnPackTo(ExternalBufferT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = id(); _o->id = _e; } + { auto _e = group(); _o->group = _e; } + { auto _e = offset(); _o->offset = _e; } + { auto _e = length(); _o->length = _e; } + { auto _e = packing(); if (_e) _o->packing = _e->str(); } +} + +inline ::flatbuffers::Offset ExternalBuffer::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateExternalBuffer(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateExternalBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const ExternalBufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ExternalBufferT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _id = _o->id; + auto _group = _o->group; + auto _offset = _o->offset; + auto _length = _o->length; + auto _packing = _o->packing.empty() ? 
0 : _fbb.CreateString(_o->packing); + return tflite::CreateExternalBuffer( + _fbb, + _id, + _group, + _offset, + _length, + _packing); +} + inline MetadataT *Metadata::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { auto _o = std::unique_ptr(new MetadataT()); UnPackTo(_o.get(), _resolver); @@ -21724,8 +22022,8 @@ inline SignatureDefT *SignatureDef::UnPack(const ::flatbuffers::resolver_functio inline void SignatureDef::UnPackTo(SignatureDefT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->inputs[_i]) { _e->Get(_i)->UnPackTo(_o->inputs[_i].get(), _resolver); } else { _o->inputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->inputs.resize(0); } } - { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->outputs[_i]) { _e->Get(_i)->UnPackTo(_o->outputs[_i].get(), _resolver); } else { _o->outputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->outputs.resize(0); } } + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->inputs[_i]) { _e->Get(_i)->UnPackTo(_o->inputs[_i].get(), _resolver); } else { _o->inputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->inputs.resize(0); } } + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->outputs[_i]) { _e->Get(_i)->UnPackTo(_o->outputs[_i].get(), _resolver); } else { _o->outputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->outputs.resize(0); } } { auto _e = signature_key(); if (_e) _o->signature_key = _e->str(); } { auto _e = subgraph_index(); _o->subgraph_index = _e; } } @@ -21764,6 +22062,10 @@ inline ModelT::ModelT(const ModelT &o) for (const auto &metadata_ : o.metadata) { metadata.emplace_back((metadata_) ? new tflite::MetadataT(*metadata_) : nullptr); } signature_defs.reserve(o.signature_defs.size()); for (const auto &signature_defs_ : o.signature_defs) { signature_defs.emplace_back((signature_defs_) ? new tflite::SignatureDefT(*signature_defs_) : nullptr); } + external_buffer_groups.reserve(o.external_buffer_groups.size()); + for (const auto &external_buffer_groups_ : o.external_buffer_groups) { external_buffer_groups.emplace_back((external_buffer_groups_) ? new tflite::ExternalBufferGroupT(*external_buffer_groups_) : nullptr); } + external_buffers.reserve(o.external_buffers.size()); + for (const auto &external_buffers_ : o.external_buffers) { external_buffers.emplace_back((external_buffers_) ? 
new tflite::ExternalBufferT(*external_buffers_) : nullptr); } } inline ModelT &ModelT::operator=(ModelT o) FLATBUFFERS_NOEXCEPT { @@ -21775,6 +22077,8 @@ inline ModelT &ModelT::operator=(ModelT o) FLATBUFFERS_NOEXCEPT { std::swap(metadata_buffer, o.metadata_buffer); std::swap(metadata, o.metadata); std::swap(signature_defs, o.signature_defs); + std::swap(external_buffer_groups, o.external_buffer_groups); + std::swap(external_buffers, o.external_buffers); return *this; } @@ -21788,13 +22092,15 @@ inline void Model::UnPackTo(ModelT *_o, const ::flatbuffers::resolver_function_t (void)_o; (void)_resolver; { auto _e = version(); _o->version = _e; } - { auto _e = operator_codes(); if (_e) { _o->operator_codes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operator_codes[_i]) { _e->Get(_i)->UnPackTo(_o->operator_codes[_i].get(), _resolver); } else { _o->operator_codes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->operator_codes.resize(0); } } - { auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->subgraphs[_i]) { _e->Get(_i)->UnPackTo(_o->subgraphs[_i].get(), _resolver); } else { _o->subgraphs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->subgraphs.resize(0); } } + { auto _e = operator_codes(); if (_e) { _o->operator_codes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operator_codes[_i]) { _e->Get(_i)->UnPackTo(_o->operator_codes[_i].get(), _resolver); } else { _o->operator_codes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->operator_codes.resize(0); } } + { auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->subgraphs[_i]) { _e->Get(_i)->UnPackTo(_o->subgraphs[_i].get(), _resolver); } else { _o->subgraphs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->subgraphs.resize(0); } } { auto _e = description(); if (_e) _o->description = _e->str(); } - { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->buffers[_i]) { _e->Get(_i)->UnPackTo(_o->buffers[_i].get(), _resolver); } else { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->buffers.resize(0); } } + { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->buffers[_i]) { _e->Get(_i)->UnPackTo(_o->buffers[_i].get(), _resolver); } else { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->buffers.resize(0); } } { auto _e = metadata_buffer(); if (_e) { _o->metadata_buffer.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->metadata_buffer[_i] = _e->Get(_i); } } else { _o->metadata_buffer.resize(0); } } - { auto _e = metadata(); if (_e) { _o->metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->metadata[_i]) { _e->Get(_i)->UnPackTo(_o->metadata[_i].get(), _resolver); } else { _o->metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->metadata.resize(0); } } - { auto _e = signature_defs(); if (_e) { _o->signature_defs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->signature_defs[_i]) { 
_e->Get(_i)->UnPackTo(_o->signature_defs[_i].get(), _resolver); } else { _o->signature_defs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->signature_defs.resize(0); } } + { auto _e = metadata(); if (_e) { _o->metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->metadata[_i]) { _e->Get(_i)->UnPackTo(_o->metadata[_i].get(), _resolver); } else { _o->metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->metadata.resize(0); } } + { auto _e = signature_defs(); if (_e) { _o->signature_defs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->signature_defs[_i]) { _e->Get(_i)->UnPackTo(_o->signature_defs[_i].get(), _resolver); } else { _o->signature_defs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->signature_defs.resize(0); } } + { auto _e = external_buffer_groups(); if (_e) { _o->external_buffer_groups.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->external_buffer_groups[_i]) { _e->Get(_i)->UnPackTo(_o->external_buffer_groups[_i].get(), _resolver); } else { _o->external_buffer_groups[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->external_buffer_groups.resize(0); } } + { auto _e = external_buffers(); if (_e) { _o->external_buffers.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->external_buffers[_i]) { _e->Get(_i)->UnPackTo(_o->external_buffers[_i].get(), _resolver); } else { _o->external_buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } } else { _o->external_buffers.resize(0); } } } inline ::flatbuffers::Offset Model::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -21813,6 +22119,8 @@ inline ::flatbuffers::Offset CreateModel(::flatbuffers::FlatBufferBuilder auto _metadata_buffer = _o->metadata_buffer.size() ? _fbb.CreateVector(_o->metadata_buffer) : 0; auto _metadata = _o->metadata.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->metadata.size(), [](size_t i, _VectorArgs *__va) { return CreateMetadata(*__va->__fbb, __va->__o->metadata[i].get(), __va->__rehasher); }, &_va ) : 0; auto _signature_defs = _o->signature_defs.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->signature_defs.size(), [](size_t i, _VectorArgs *__va) { return CreateSignatureDef(*__va->__fbb, __va->__o->signature_defs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _external_buffer_groups = _o->external_buffer_groups.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->external_buffer_groups.size(), [](size_t i, _VectorArgs *__va) { return CreateExternalBufferGroup(*__va->__fbb, __va->__o->external_buffer_groups[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _external_buffers = _o->external_buffers.size() ? 
_fbb.CreateVector<::flatbuffers::Offset> (_o->external_buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateExternalBuffer(*__va->__fbb, __va->__o->external_buffers[i].get(), __va->__rehasher); }, &_va ) : 0; return tflite::CreateModel( _fbb, _version, @@ -21822,7 +22130,9 @@ inline ::flatbuffers::Offset CreateModel(::flatbuffers::FlatBufferBuilder _buffers, _metadata_buffer, _metadata, - _signature_defs); + _signature_defs, + _external_buffer_groups, + _external_buffers); } inline bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj, QuantizationDetails type) { diff --git a/tensorflow/compiler/mlir/lite/schema/schema_utils.cc b/tensorflow/compiler/mlir/lite/schema/schema_utils.cc index a173380940d600..cb61ce6243f3ad 100644 --- a/tensorflow/compiler/mlir/lite/schema/schema_utils.cc +++ b/tensorflow/compiler/mlir/lite/schema/schema_utils.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include +#include +#include +#include #include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" namespace tflite { @@ -59,4 +63,51 @@ BuiltinOperator GetBuiltinCode(const OperatorCodeT* op_code) { op_code->deprecated_builtin_code)); } +size_t TensorTypeGetSize(::tflite::TensorType data_type) { + switch (data_type) { + case ::tflite::TensorType_FLOAT32: + static_assert(sizeof(float) == 4, ""); + return 4; + case ::tflite::TensorType_FLOAT16: + static_assert(sizeof(int16_t) == 2, ""); + return 2; + case ::tflite::TensorType_INT32: + static_assert(sizeof(int32_t) == 4, ""); + return 4; + case ::tflite::TensorType_UINT8: + static_assert(sizeof(uint8_t) == 1, ""); + return 1; + case ::tflite::TensorType_INT64: + static_assert(sizeof(int64_t) == 8, ""); + return 8; + case ::tflite::TensorType_BOOL: + return sizeof(bool); + case ::tflite::TensorType_INT16: + static_assert(sizeof(int16_t) == 2, ""); + return 2; + case ::tflite::TensorType_COMPLEX64: + static_assert(sizeof(std::complex) == 8, ""); + return 8; + case ::tflite::TensorType_INT8: + static_assert(sizeof(int8_t) == 1, ""); + return 1; + case ::tflite::TensorType_FLOAT64: + static_assert(sizeof(double) == 8, ""); + return 8; + case ::tflite::TensorType_COMPLEX128: + static_assert(sizeof(std::complex) == 16, ""); + return 16; + case ::tflite::TensorType_UINT64: + static_assert(sizeof(uint64_t) == 8, ""); + return 8; + case ::tflite::TensorType_UINT32: + static_assert(sizeof(uint32_t) == 4, ""); + return 4; + case ::tflite::TensorType_UINT16: + static_assert(sizeof(uint16_t) == 2, ""); + return 2; + default: + return 0; + } +} } // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/schema/schema_utils.h b/tensorflow/compiler/mlir/lite/schema/schema_utils.h index 7498aa02ebe5c2..9c32680b85117f 100644 --- a/tensorflow/compiler/mlir/lite/schema/schema_utils.h +++ b/tensorflow/compiler/mlir/lite/schema/schema_utils.h @@ -15,6 +15,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ +#include + #include "flatbuffers/flatbuffers.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" @@ -28,6 +30,11 @@ BuiltinOperator GetBuiltinCode(const OperatorCode *op_code); BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code); +// Returns the size of the given TensorType in bytes, or 0 if the TensorType is +// not supported, this function should be aligned with TfLiteTypeGetSize in +// lite/kernels/kernel_util.h. +size_t TensorTypeGetSize(::tflite::TensorType data_type); + } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 43fada7b0d0b62..cd553040786c72 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -539,6 +539,7 @@ cc_library( ":passes_inc_gen", ":unfold_splat_constant_pass", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:case", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:custom_call", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:dot_general", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index ae672381bacafd..9a0a185443ebc0 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -3073,6 +3073,13 @@ func.func @convert_iota_ui64() -> tensor<123xui64> { func.return %0 : tensor<123xui64> } +// CHECK-LABEL: func @no_convert_iota_ui8 +func.func @no_convert_iota_ui8() -> tensor<123xui8> { + // CHECK: "mhlo.iota" + %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<123xui8> + func.return %0 : tensor<123xui8> +} + // CHECK-LABEL: func @convert_avgpool_valid( // CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { // CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) <{data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index a77d02e78c1dce..1d8a63130ac1d9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -3721,14 +3721,43 @@ func.func @dynamic_broadcast_in_dim_general_case_expand_back_dims(%arg0: tensor< // CHECK: %2 = "tfl.broadcast_to"(%1, %arg1) : (tensor, tensor<4xi32>) -> tensor +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.case +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: case_func +func.func @case_func(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { + %0 = "mhlo.case"(%arg0) ({ + %2 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + %2 = mhlo.multiply %arg1, %arg1 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0: tensor +} + +// CHECK: %[[CST:.*]] = 
arith.constant dense<0> : tensor +// CHECK: %[[PRED:.*]] = tfl.not_equal(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: %[[IF:.*]] = "tfl.if"(%[[PRED]]) ({ +// CHECK: %[[MUL:.*]] = tfl.mul %arg1, %arg1 {fused_activation_function = "NONE"} : tensor +// CHECK: "tfl.yield"(%[[MUL]]) : (tensor) -> () +// CHECK: }, { +// CHECK: %[[ADD:.*]] = tfl.add %arg1, %arg2 {fused_activation_function = "NONE"} : tensor +// CHECK: "tfl.yield"(%[[ADD]]) : (tensor) -> () +// CHECK: }) : (tensor) -> tensor +// CHECK: return %[[IF]] : tensor + // ----- //===----------------------------------------------------------------------===// // mhlo.if //===----------------------------------------------------------------------===// -// CHECK-LABEL: if -func.func @if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { +// CHECK-LABEL: if_label +func.func @if_label(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { %0 = mhlo.add %arg1, %arg2 : tensor %1 = "mhlo.if"(%arg0) ({ "mhlo.return"(%0) : (tensor) -> () diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 3891d0f3fe4db3..7608ff985f1eb9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -2081,8 +2081,10 @@ class ConvertIotaOpToTfRange : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final { RankedTensorType type = mlir::dyn_cast_or_null(iota_op.getType()); - // TF::RangeOp doesn't support UI16. - if (!type || type.getElementType().isUnsignedInteger(16)) return failure(); + // TF::RangeOp doesn't support UI16 and UI8. + if (!type || type.getElementType().isUnsignedInteger(16) || + type.getElementType().isUnsignedInteger(8)) + return failure(); const uint64_t dimension = iota_op.getIotaDimension(); Type element_type = type.getElementType(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index 9e2f1cf33f495f..16c194df28f591 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -320,6 +320,21 @@ cc_library( ], ) +cc_library( + name = "case", + srcs = ["case.cc"], + hdrs = ["case.h"], + deps = [ + ":util", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) + cc_library( name = "if", srcs = ["if.cc"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.cc new file mode 100644 index 00000000000000..b50a5e7fd83195 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.cc @@ -0,0 +1,100 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { +namespace { + +// Legalizes mhlo.case op to tfl.if op. +// This pattern only supports mhlo.case ops with exactly two branches. +class LegalizeCaseOp : public OpConversionPattern<mhlo::CaseOp> { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::CaseOp case_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + // mhlo.case can have N branches, but tfl.if only supports two. + if (case_op.getBranches().size() != 2) { + return rewriter.notifyMatchFailure( + case_op, "can only convert mhlo.case with 2 branches"); + } + + // `mhlo.case` takes an index, `tfl.if` takes a boolean predicate. + // For a 2-branch `mhlo.case` (branch 0 and branch 1), we need to map + // the index to a boolean. + // According to the mhlo.case spec, an out-of-bounds index defaults to the + // index of the last branch, which is 1 in this case. + // So, index 0 maps to branch 0, and any other index (1, or out of bounds) + // maps to branch 1. + // This can be expressed as a predicate `index != 0` for branch 1. + + auto loc = case_op->getLoc(); + auto index = case_op.getIndex(); + auto index_type = mlir::cast<ShapedType>(index.getType()); + + // Create a constant tensor of the same shape as the index, filled with + // zeros. + auto const_zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(index_type)); + + // Create the predicate `index != 0`. + auto pred_type = index_type.clone(rewriter.getI1Type()); + auto pred = mhlo::CompareOp::create( + rewriter, loc, pred_type, index, const_zero, + mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), + mhlo::ComparisonDirection::NE), + mhlo::ComparisonTypeAttr{}); // Default comparison type is fine for + // integers. + + // Create the tfl.if op. + auto tfl_if = + TFL::IfOp::create(rewriter, loc, case_op.getResultTypes(), pred); + + // Branch 1 of mhlo.case becomes the `then_region` of tfl.if. + tfl_if.getThenRegion().takeBody(case_op.getBranches()[1]); + ReplaceTerminatorWithYield(tfl_if.getThenRegion(), rewriter); + + // Branch 0 of mhlo.case becomes the `else_region` of tfl.if.
+ tfl_if.getElseRegion().takeBody(case_op.getBranches()[0]); + ReplaceTerminatorWithYield(tfl_if.getElseRegion(), rewriter); + + rewriter.replaceOp(case_op, tfl_if.getResults()); + return success(); + } +}; + +} // namespace + +void PopulateCasePatterns(MLIRContext* context, RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add<LegalizeCaseOp>(context); + // Mark mhlo.case as dynamically legal: it's legal if it does NOT have + // exactly 2 branches, as those are the ones we want to convert. + target.addDynamicallyLegalOp<mhlo::CaseOp>( + [](mhlo::CaseOp op) { return op.getBranches().size() != 2; }); +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h new file mode 100644 index 00000000000000..11c470a1492630 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h @@ -0,0 +1,31 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace odml { + +void PopulateCasePatterns(MLIRContext* context, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CASE_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index 9518b960f17442..0c43a5c4047a64 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/case.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h" // IWYU pragma: keep @@ -479,6 +480,7 @@ void LegalizeHloToTfLitePass::runOnOperation() { PopulateWhilePatterns(context, patterns, target); PopulateGetDimensionSizePatterns(context, patterns, target); PopulateIfPatterns(context, patterns, target); + PopulateCasePatterns(context, patterns, target); PopulateLegalizeFftPatterns(context, patterns, target); PopulateCustomCallPatterns(context, patterns, target); @@ -493,7 +495,6 @@ void LegalizeHloToTfLitePass::runOnOperation() { } // namespace - // Creates an instance of the pass. std::unique_ptr> CreateLegalizeHloToTfLitePass() { return std::make_unique(); diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 6043e26cb757d8..2fcdfb80b6a0ad 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -261,7 +261,7 @@ func.func @mul_one_quant(%arg0: tensor<32x!quant.uniform>) -> tenso // CHECK-LABEL: @elementwise_unary_ops -func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor) { +func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) { %0 = arith.constant dense<-1.0> : tensor %1 = arith.constant dense<1.0> : tensor %2 = arith.constant dense<1.0> : tensor @@ -269,6 +269,7 @@ func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, te %4 = arith.constant dense<4.0> : tensor %5 = arith.constant dense<4.0> : tensor %6 = arith.constant dense<2.0> : tensor + %one = arith.constant dense<1.0> : tensor // CHECK-DAG: [[cst0:%.*]] = arith.constant dense<1.000000e+00> : tensor // CHECK-DAG: [[cst1:%.*]] = arith.constant dense<0.841470957> : tensor @@ -277,7 +278,8 @@ func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, te // CHECK-DAG: [[cst4:%.*]] = arith.constant dense<2.000000e+00> : tensor // CHECK-DAG: [[cst5:%.*]] = arith.constant dense<5.000000e-01> : tensor // CHECK-DAG: [[cst6:%.*]] = arith.constant dense<4.000000e+00> : tensor - // CHECK: return [[cst0]], [[cst1]], [[cst2]], [[cst3]], [[cst4]], [[cst5]], [[cst6]] + // CHECK-DAG: [[cst7:%.*]] = arith.constant dense<0.761594176> : tensor + // CHECK: return [[cst0]], [[cst1]], [[cst2]], [[cst3]], [[cst4]], [[cst5]], [[cst6]], [[cst7]] %7 = "tfl.abs"(%0) : (tensor) -> tensor %8 = "tfl.sin"(%1) : (tensor) -> tensor @@ -286,8 +288,9 @@ func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, te %11 = "tfl.sqrt"(%4) : (tensor) -> tensor %12 = "tfl.rsqrt"(%5) : (tensor) -> tensor %13 = "tfl.square"(%6) : (tensor) -> tensor + %14 = "tfl.tanh"(%one) : (tensor) -> tensor - func.return %7, %8, %9, %10, %11, %12, %13 : tensor, tensor, tensor, tensor, tensor, tensor, tensor + func.return %7, %8, %9, %10, %11, %12, %13, %14 : tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor } // CHECK-LABEL: @max_with_neg_f32_max_val diff --git 
a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index e950a5d91b9876..2ce933112a0a43 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -323,21 +323,19 @@ absl::Status ConvertTFExecutorToStablehloFlatbuffer( // TODO: b/264218457 - Refactor the component below once StableHLO Quantizer // can run DRQ. Temporarily using TF Quantization for StableHLO DRQ. - if (!converter_flags.has_quantization_options()) { - // The default minimum number of elements a weights array must have to be - // quantized by this transformation. - const int kWeightsMinNumElementsDefault = 1024; - - quantization::QuantizationOptions quantization_options; - - quantization_options.mutable_quantization_method()->set_preset_method( - quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8); - quantization_options.set_op_set(quantization::UNIFORM_QUANTIZED); - quantization_options.set_min_num_elements_for_weights( - kWeightsMinNumElementsDefault); - quantization::AddQuantizePtqDynamicRangePasses(pass_manager, - quantization_options); - } + // The default minimum number of elements a weights array must have to be + // quantized by this transformation. + const int kWeightsMinNumElementsDefault = 1024; + + quantization::QuantizationOptions quantization_options; + + quantization_options.mutable_quantization_method()->set_preset_method( + quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8); + quantization_options.set_op_set(quantization::UNIFORM_QUANTIZED); + quantization_options.set_min_num_elements_for_weights( + kWeightsMinNumElementsDefault); + quantization::AddQuantizePtqDynamicRangePasses(pass_manager, + quantization_options); if (failed(pass_manager.run(module))) { return status_handler.ConsumeStatus(); } @@ -350,10 +348,6 @@ absl::Status ConvertTFExecutorToStablehloFlatbuffer( pass_manager.addPass(mlir::odml::createPrintOpStatsPass( mlir::odml::GetAcceptedStableHLODialects())); mlir::odml::AddStablehloOptimizationPasses(pass_manager); - if (converter_flags.has_quantization_options()) { - stablehlo::quantization::AddQuantizationPasses( - pass_manager, converter_flags.quantization_options()); - } if (failed(pass_manager.run(module))) { return status_handler.ConsumeStatus(); } diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc index 4d6d46e55bb5be..9ccda1d0c95e69 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_version.cc @@ -177,6 +177,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { reinterpret_cast(op_sig.builtin_data); TFLITE_DCHECK(fully_connected_params != nullptr); + if (op_sig.inputs.at(1).type == kTfLiteInt2) { + return 14; + } + if (op_sig.inputs.at(0).type == kTfLiteInt16 && op_sig.inputs.at(1).type == kTfLiteInt4 && op_sig.outputs.at(0).type == kTfLiteInt16) { @@ -464,6 +468,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_SLICE: + if (op_sig.inputs.at(0).type == kTfLiteInt4) { + return 7; + } if (op_sig.inputs.at(0).type == kTfLiteUInt32) { return 6; } @@ -473,7 +480,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { if (op_sig.inputs.at(0).type == kTfLiteInt16) { return 4; } - // Version 3 supports string input types. 
if (op_sig.inputs.at(0).type == kTfLiteString) { return 3; } @@ -499,6 +505,9 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_DEQUANTIZE: + if (op_sig.inputs.at(0).type == kTfLiteInt2) { + return 7; + } if (op_sig.inputs.at(0).type == kTfLiteInt4) { return 6; } diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc b/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc index 641a2e45fb8c24..87313665d1811f 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/op_version_test.cc @@ -733,6 +733,15 @@ TEST(OpVersionTest, VersioningFullyConnectedTest) { }; fake_op_sig.ext_options.fully_connected.is_per_channel_quantized = true; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 12); + + fake_op_sig = { + .op = BuiltinOperator_FULLY_CONNECTED, + .inputs = CreateOpSignatureTensorSpecs( + std::vector{kTfLiteInt8, kTfLiteInt2}), + .outputs = CreateOpSignatureTensorSpecs(kTfLiteInt8), + .builtin_data = reinterpret_cast(&fully_connected_params), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 14); } TEST(OpVersionTest, VersioningDequantizeTest) { @@ -757,6 +766,12 @@ TEST(OpVersionTest, VersioningDequantizeTest) { fake_op_sig.ext_options.dequantize.is_per_channel_quantized = true; EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5); + fake_op_sig = { + .op = BuiltinOperator_DEQUANTIZE, + .inputs = CreateOpSignatureTensorSpecs(kTfLiteInt2), + }; + EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 7); + fake_op_sig = { .op = BuiltinOperator_DEQUANTIZE, .inputs = CreateOpSignatureTensorSpecs(kTfLiteFloat32), diff --git a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc index d7e6b7c9a2064c..aca1b463878966 100644 --- a/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.cc @@ -139,6 +139,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_FULLY_CONNECTED, 11}, "2.15.0"}, {{BuiltinOperator_FULLY_CONNECTED, 12}, "2.17.0"}, {{BuiltinOperator_FULLY_CONNECTED, 13}, "2.18.0"}, + {{BuiltinOperator_FULLY_CONNECTED, 14}, "2.21.0"}, {{BuiltinOperator_GATHER, 1}, "1.6.0"}, {{BuiltinOperator_GATHER, 2}, "1.14.0"}, {{BuiltinOperator_GATHER, 3}, "1.15.0"}, @@ -294,6 +295,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_SLICE, 4}, "2.4.0"}, {{BuiltinOperator_SLICE, 5}, "2.5.0"}, {{BuiltinOperator_SLICE, 6}, "2.14.0"}, + {{BuiltinOperator_SLICE, 7}, "2.21.0"}, {{BuiltinOperator_TANH, 1}, "1.14.0"}, {{BuiltinOperator_TANH, 2}, "1.14.0"}, {{BuiltinOperator_TANH, 3}, "2.3.0"}, @@ -326,6 +328,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_DEQUANTIZE, 4}, "2.2.0"}, {{BuiltinOperator_DEQUANTIZE, 5}, "2.7.0"}, {{BuiltinOperator_DEQUANTIZE, 6}, "2.18.0"}, + {{BuiltinOperator_DEQUANTIZE, 7}, "2.21.0"}, {{BuiltinOperator_REVERSE_SEQUENCE, 1}, "1.14.0"}, {{BuiltinOperator_EQUAL, 1}, "1.14.0"}, {{BuiltinOperator_EQUAL, 2}, "1.14.0"}, diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc index 96412f20633f6a..7453ed54975a5a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc @@ 
-43,7 +43,7 @@ limitations under the License. namespace mlir { namespace TFL { namespace { -#define GEN_PASS_CLASSES +#define GEN_PASS_DEF_QUANTIZEVARIABLESPASS #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" using ResourceIdMap = @@ -80,7 +80,7 @@ Type GetDequantizedTypeFromAssigneVariableOp(VarHandleOp var_handle_op) { } class QuantizeVariablesPass - : public QuantizeVariablesPassBase<QuantizeVariablesPass> { + : public impl::QuantizeVariablesPassBase<QuantizeVariablesPass> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeVariablesPass) explicit QuantizeVariablesPass() = default; diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index 12fab673d6e43b..1b82ca5b0e61dc 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -74,7 +74,7 @@ llvm::SmallVector ReadAsHostEndian(ArrayRef bytes) { ret.reserve(elem_count); const char* data_ptr = reinterpret_cast<const char*>(bytes.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { T val = llvm::support::endian::readNext<T, llvm::endianness::native, llvm::support::unaligned>(data_ptr); ret.push_back(mlir::APInt(sizeof(T) * 8, val)); @@ -362,7 +362,7 @@ StatusOr ConvertFloatBuffer( assert(bytes_len % 2 == 0); // Supports both BF16 and F16. assert(elem_type.isF16() || elem_type.isBF16()); - int elem_count = bytes_len / 2; + size_t elem_count = bytes_len / 2; if (elem_type.isF16()) { std::vector values; @@ -370,7 +370,7 @@ StatusOr ConvertFloatBuffer( const char* data = reinterpret_cast<const char*>(buffer.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { uint16_t bit_repr = llvm::support::endian::readNext< uint16_t, llvm::endianness::native, llvm::support::unaligned>( data); @@ -385,7 +385,7 @@ StatusOr ConvertFloatBuffer( const char* data = reinterpret_cast<const char*>(buffer.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { uint16_t bit_repr = llvm::support::endian::readNext< uint16_t, llvm::endianness::native, llvm::support::unaligned>( data); @@ -398,13 +398,13 @@ StatusOr ConvertFloatBuffer( } case 32: { assert(bytes_len % 4 == 0); - int elem_count = bytes_len / 4; + size_t elem_count = bytes_len / 4; std::vector values; values.reserve(elem_count); const char* data = reinterpret_cast<const char*>(buffer.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { uint32_t bit_repr = llvm::support::endian::readNext<uint32_t, llvm::endianness::native, llvm::support::unaligned>(data); @@ -415,13 +415,13 @@ StatusOr ConvertFloatBuffer( case 64: { assert(bytes_len % 8 == 0); - int elem_count = bytes_len / 8; + size_t elem_count = bytes_len / 8; std::vector values; values.reserve(elem_count); const char* data = reinterpret_cast<const char*>(buffer.data()); - for (int i = 0; i < elem_count; i++) { + for (size_t i = 0; i < elem_count; i++) { uint64_t bit_repr = llvm::support::endian::readNext<uint64_t, llvm::endianness::native, llvm::support::unaligned>(data); diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc index 2acb4dccb88a18..0ae1247e2a156a 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc @@ -43,13 +43,13 @@ void Register(const std::string& op_name, OpRegistry* registry) { } // namespace TEST(TfTextUtilsTest, TestTfTextRegistered) { - std::unique_ptr<OpRegistry> registry(new OpRegistry); + std::unique_ptr<OpRegistry> registry = std::make_unique<OpRegistry>(); Register("WhitespaceTokenizeWithOffsets", registry.get());
EXPECT_TRUE(IsTFTextRegistered(registry.get())); } TEST(TfTextUtilsTest, TestTfTextNotRegistered) { - std::unique_ptr<OpRegistry> registry(new OpRegistry); + std::unique_ptr<OpRegistry> registry = std::make_unique<OpRegistry>(); Register("Test", registry.get()); EXPECT_FALSE(IsTFTextRegistered(registry.get())); } diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index 9d7e689f3b6a3c..0c6a636d38b822 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -124,7 +124,7 @@ class ModifyMlirModulePass : public MlirOptimizationPass { }; FunctionDef XTimesTwo() { - const Tensor kTwo = test::AsScalar<int64>(2); + const Tensor kTwo = test::AsScalar<int64_t>(2); return FunctionDefHelper::Define( // Name "XTimesTwo", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 5eaf5d736262ca..4f2384347a7802 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -251,7 +251,7 @@ std::string ExperimentalConvertSavedModelToMlir( // Convert the SavedModelV2Bundle to an MLIR module. - std::vector<string> exported_names = + std::vector<std::string> exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); @@ -270,10 +270,10 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite( const std::string& saved_model_path, const std::string& exported_names_str, const std::string& tags, bool upgrade_legacy, bool show_debug_info, TF_Status* status) { - std::unordered_set<string> tag_set = + std::unordered_set<std::string> tag_set = absl::StrSplit(tags, ',', absl::SkipEmpty()); - std::vector<string> exported_names = + std::vector<std::string> exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); @@ -299,7 +299,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( bool show_debug_info, TF_Status* status) { // Load the saved model into a SavedModelBundle. - std::unordered_set<string> tag_set = + std::unordered_set<std::string> tag_set = absl::StrSplit(tags, ',', absl::SkipEmpty()); tensorflow::SavedModelBundle bundle; @@ -311,7 +311,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( } // Convert the SavedModelBundle to an MLIR module.
- std::vector<string> exported_names = + std::vector<std::string> exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc index ae93231d4d336b..5d6d36ed3a6c7d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset_test.cc @@ -36,8 +36,6 @@ using ::testing::HasSubstr; using ::testing::Key; using ::testing::SizeIs; using ::testing::StrEq; -using ::tsl::testing::IsOk; -using ::tsl::testing::StatusIs; TEST(CreateRepresentativeDatasetFileMapTest, ConfigWithoutExplicitSignatureKeyMappedToServingDefault) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc index a3a09bdb35daaa..2fb8f11a4e4349 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/io_test.cc @@ -49,23 +49,23 @@ class TestEnvBrokenFileSystem : public tsl::Env { public: TestEnvBrokenFileSystem() = default; - bool MatchPath(const tsl::string& path, const tsl::string& pattern) override { + bool MatchPath(const std::string& path, const std::string& pattern) override { return false; } void SleepForMicroseconds(int64_t micros) override {} - tsl::string GetRunfilesDir() override { return tsl::string("dummy_path"); } + std::string GetRunfilesDir() override { return std::string("dummy_path"); } int64_t GetCurrentThreadId() override { return 0; } tsl::Thread* StartThread(const tsl::ThreadOptions& thread_options, - const tsl::string& name, + const std::string& name, absl::AnyInvocable<void()> fn) override { return nullptr; } - bool GetCurrentThreadName(tsl::string* name) override { return false; } + bool GetCurrentThreadName(std::string* name) override { return false; } void SchedClosure(absl::AnyInvocable<void()> closure) override {} @@ -82,9 +82,9 @@ class TestEnvBrokenFileSystem : public tsl::Env { return absl::OkStatus(); } - tsl::string FormatLibraryFileName(const tsl::string& name, - const tsl::string& version) override { - return tsl::string("dummy_path"); + std::string FormatLibraryFileName(const std::string& name, + const std::string& version) override { + return std::string("dummy_path"); } // This is the part that would break the `CreateTmpDir` function because it @@ -95,7 +95,7 @@ class TestEnvBrokenFileSystem : public tsl::Env { } private: - void GetLocalTempDirectories(std::vector* list) override { + void GetLocalTempDirectories(std::vector* list) override { list->push_back("/tmp"); } }; @@ -107,7 +107,7 @@ class TestEnvBrokenFileSystemAndNoLocalTempDirs private: // This is the part that essentially breaks the `GetLocalTmpFileName` function // because it doesn't provide any available temp dirs.
- void GetLocalTempDirectories(std::vector* list) override {} + void GetLocalTempDirectories(std::vector* list) override {} }; TEST(IoTest, GetLocalTmpFileNameGivesValidFileName) { diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc index c2e91c5da16e93..1f6464d85f5ef4 100644 --- a/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc @@ -4822,7 +4822,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ + std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; @@ -5022,7 +5022,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ + std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index e23f510182259f..4104cf412acfd8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -335,7 +335,6 @@ def TF_IfRegionOp : TF_Op<"IfRegion", "areTypesCompatible", "getEntrySuccessorOperands", "getRegionInvocationBounds", - "getSuccessorRegions" ]> ]> { let summary = "output = cond ? then_branch output : else_branch output"; @@ -395,7 +394,6 @@ def TF_GeneratorDatasetRegionOp : TF_Op<"GeneratorDatasetRegion", "areTypesCompatible", "getEntrySuccessorOperands", "getRegionInvocationBounds", - "getSuccessorRegions" ]>, SingleBlockImplicitTerminator<"YieldOp">, TF_GeneratorOpSideEffect, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 59ba13e326a02f..6382f325a47505 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -3003,14 +3003,14 @@ void GeneratorDatasetRegionOp::getRegionInvocationBounds( } OperandRange GeneratorDatasetRegionOp::getEntrySuccessorOperands( - RegionBranchPoint point) { + RegionSuccessor successor) { auto end = this->getOperation()->operand_end(); - if (point.isParent()) { + if (successor.isParent()) { // The op itself doesn't branch back to itself. return ::mlir::OperandRange(end, end); - } else if (point.getRegionOrNull() == &getInit()) { + } else if (successor.getSuccessor() == &getInit()) { return getInitFuncOtherArgs(); - } else if (point.getRegionOrNull() == &getNext()) { + } else if (successor.getSuccessor() == &getNext()) { return getNextFuncOtherArgs(); } else /* finalize region */ { return getFinalizeFuncOtherArgs(); @@ -3024,13 +3024,15 @@ void GeneratorDatasetRegionOp::getSuccessorRegions( // The op itself branches to `init` first. regions.push_back( RegionSuccessor(&getInit(), getInit().front().getArguments())); - } else if (point.getRegionOrNull() == &getInit()) { + } else if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getInit()) { // `init` branches to `next`, passing along the arguments given to `init`'s // yield. Said arguments precede the "other args". 
n = getInitFuncOtherArgs().size(); regions.push_back(RegionSuccessor( &getNext(), getNext().front().getArguments().drop_back(n))); - } else if (point.getRegionOrNull() == &getNext()) { + } else if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getNext()) { // `next` branches to itself, or to `finalize`, passing all arguments given // to `next`s yield. @@ -3045,7 +3047,8 @@ void GeneratorDatasetRegionOp::getSuccessorRegions( &getFinalize(), getFinalize().front().getArguments().slice(0, num))); } else { // `finalize` branches back to the op itself, not passing any arguments. - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + point.getTerminatorPredecessorOrNull()->getParentRegion())); } } @@ -3261,11 +3264,12 @@ void IfRegionOp::getRegionInvocationBounds( invocationBounds.assign(2, {0, 1}); } -OperandRange IfRegionOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange IfRegionOp::getEntrySuccessorOperands(RegionSuccessor successor) { // IfRegionOp currently only allows one op (the condition), so there are no // remaining operands for the successor. - assert((point.isParent() || - (point == (*this)->getRegion(0) || point == (*this)->getRegion(1))) && + assert((successor.isParent() || + (successor.getSuccessor() == &(*this)->getRegion(0) || + successor.getSuccessor() == &(*this)->getRegion(1))) && "Invalid IfRegionOp region index."); auto end = this->getOperation()->operand_end(); return ::mlir::OperandRange(end, end); @@ -3275,16 +3279,20 @@ void IfRegionOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl& regions) { if (!point.isParent()) { // The `then` and the `else` region branch back to the parent operation. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back( + RegionSuccessor(point.getTerminatorPredecessorOrNull(), getResults())); return; } else { // The parent can branch to either `then` or `else`. - regions.push_back(RegionSuccessor(&getThenBranch())); + regions.push_back( + RegionSuccessor(&getThenBranch(), getThenBranch().getArguments())); Region* elseRegion = &this->getElseBranch(); if (!elseRegion->empty()) - regions.push_back(RegionSuccessor(elseRegion)); + regions.push_back( + RegionSuccessor(elseRegion, elseRegion->getArguments())); else - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + point.getTerminatorPredecessorOrNull()->getParentRegion())); } } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 2b839d963fe2e4..23683673fe189a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -3611,8 +3611,8 @@ SmallVector WhileRegionOp::getLoopRegions() { return {&getBody()}; } //===----------------------------------------------------------------------===// OperandRange WhileRegionOp::getEntrySuccessorOperands( - RegionBranchPoint point) { - if (point.isParent()) { + RegionSuccessor successor) { + if (successor.isParent()) { // WhileRegionOp branches to the condition, which branches to the body. But // the op itself doesn't branch back to itself. So this range is empty. 
auto end = this->getOperation()->operand_end(); @@ -3628,21 +3628,28 @@ OperandRange WhileRegionOp::getEntrySuccessorOperands( void WhileRegionOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { - if (!point.isParent() && point == (*this)->getRegion(0)) { + if (!point.isParent() && + (point.getTerminatorPredecessorOrNull() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &(*this)->getRegion(0))) { // 'cond' branches to the body or returns. Operation *yield = getCond().front().getTerminator(); if (yield->getOperands().size() == 1 + this->getOperation()->getOperands().size()) { regions.push_back( RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.push_back(getResults()); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } else { // For compatibility with older code, we allow the "yield" in a condition // to only yield a single boolean. In that case we can't forward any args. regions.push_back(RegionSuccessor(&getBody())); - regions.push_back(RegionSuccessor()); // branch back to parent, no args + regions.push_back( + RegionSuccessor(getOperation(), getResults().take_front(0))); } - } else if (!point.isParent() && point == (*this)->getRegion(1)) { + } else if (!point.isParent() && + (point.getTerminatorPredecessorOrNull() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &(*this)->getRegion(1))) { // 'body' branches back to 'cond'. regions.push_back( RegionSuccessor(&getCond(), getCond().front().getArguments())); @@ -4510,7 +4517,7 @@ LogicalResult UniformQuantizedClipByValueOp::verify() { //===----------------------------------------------------------------------===// MutableOperandRange YieldOp::getMutableSuccessorOperands( - RegionBranchPoint point) { + RegionSuccessor successor) { if (auto whileOp = llvm::dyn_cast(this->getOperation()->getParentOp())) { if (&whileOp.getCond() == this->getOperation()->getParentRegion()) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 765ed1171a8449..a3305eef8a0819 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1317,7 +1317,7 @@ func.func @testIfRegionElseTerminator(%arg0: tensor, %arg1: tensor<2xf32>) - // tf.Region yield number of results should match op number of results func.func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Region #0 to parent results: source has 2 operands, but target successor needs 1}} + // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Operation tf.Yield to parent results: source has 2 operands, but target successor needs 1}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> () @@ -1332,7 +1332,7 @@ func.func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) // ----- func.func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Region #1 to parent results: source has 2 operands, but target successor needs 1}} + // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Operation tf.Yield to parent results: source has 2 operands, but target successor needs 1}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : 
(tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index bc4487a4e3fd7d..954c318b416150 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" #include -#include #include #include #include @@ -29,6 +28,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" @@ -230,7 +230,7 @@ std::optional> EquationToMap( llvm::StringRef equation) { llvm::SmallDenseMap map; for (int64_t i = 0; i < equation.size(); ++i) { - if (!std::isalpha(equation[i])) { + if (!llvm::isAlpha(equation[i])) { // Unsupported character in the equation. return std::nullopt; } @@ -263,7 +263,7 @@ std::optional> GetAvailableLabels( const int lhs_size = lhs.size(); for (int i = 0; i < lhs_size; ++i) { const char label = lhs[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { labels.remove(label); ++lhs_count; } else if (label == '.') { @@ -280,7 +280,7 @@ std::optional> GetAvailableLabels( const int rhs_size = rhs.size(); for (int i = 0; i < rhs_size; ++i) { const char label = rhs[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { labels.remove(label); ++rhs_count; } else if (label == '.') { @@ -318,7 +318,7 @@ std::tuple FlattenEllipsis( std::string new_lhs; for (int i = 0; i < lhs.size(); ++i) { const char label = lhs[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { new_lhs.push_back(label); } else { // Encounter ellipsis: generate unnamed labels then insert to the new @@ -333,7 +333,7 @@ std::tuple FlattenEllipsis( std::string new_rhs, new_rhs_labels; for (int i = 0; i < rhs.size(); ++i) { const char label = rhs[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { new_rhs.push_back(label); } else { // Encounter ellipsis: generate unnamed labels then insert to the new @@ -352,7 +352,7 @@ std::tuple FlattenEllipsis( std::string new_output; for (int i = 0; i < out.size(); ++i) { const char label = out[i]; - if (std::isalpha(label)) { + if (llvm::isAlpha(label)) { new_output.push_back(label); } else { // Encounter ellipsis: we will just insert the generated labels to the new diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc index a2c4a7031ed14b..0cdb563a45eed7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc @@ -49,7 +49,7 @@ static constexpr int kTextFileIndex_LineNumber = -1; class InitTextFileToImportPass : public impl::InitTextFileToImportPassBase { public: - InitTextFileToImportPass() {} + InitTextFileToImportPass() = default; InitTextFileToImportPass(const InitTextFileToImportPass&) {} explicit InitTextFileToImportPass(std::string saved_model_dir) { saved_model_dir_ = saved_model_dir; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc index 
a985cdc11611b4..41c5cd4234f1cc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc @@ -46,7 +46,7 @@ class InitTextFileToImportTestPass : public impl::InitTextFileToImportTestPassBase< InitTextFileToImportTestPass> { public: - explicit InitTextFileToImportTestPass() {} + explicit InitTextFileToImportTestPass() = default; StringRef getArgument() const final { return "tf-init-text-file-to-import-test"; @@ -115,7 +115,7 @@ class InitTextFileToImportSavedModelTestPass : public impl::InitTextFileToImportSavedModelTestPassBase< InitTextFileToImportSavedModelTestPass> { public: - explicit InitTextFileToImportSavedModelTestPass() {} + explicit InitTextFileToImportSavedModelTestPass() = default; private: void runOnOperation() override; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 2e023e3e057096..57a41f538f277f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -36,6 +35,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -289,8 +289,10 @@ ObjectNames::ObjectNames(const SavedObjectGraph& object_graph, // - `model.variables.0` // - `model.keras_api.layers.1.keras_api.trainable_variables.0` // - ... 10 more long aliases ending in digits ... - return std::make_tuple(isdigit(a.back()), a.size(), a) < - std::make_tuple(isdigit(b.back()), b.size(), b); + return std::make_tuple(absl::ascii_isdigit(a.back()), a.size(), + a) < + std::make_tuple(absl::ascii_isdigit(b.back()), b.size(), + b); }); for (const std::string& name : kv.second) { if (IsExported(name)) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.cc b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.cc index c48f52576df4e3..0288006ee4d105 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.cc @@ -39,13 +39,13 @@ limitations under the License. 
namespace tensorflow { absl::Status ParseOutputArrayInfo(absl::string_view array_names, - std::vector* outputs) { + std::vector* outputs) { TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs)); return absl::OkStatus(); } -absl::Status ParseOutputArrayInfo(const std::vector& output_names, - std::vector* outputs) { +absl::Status ParseOutputArrayInfo(const std::vector& output_names, + std::vector* outputs) { for (auto& output_name : output_names) { if (output_name.empty()) continue; outputs->push_back(output_name); @@ -57,8 +57,8 @@ absl::Status ParseInputArrayInfo(absl::string_view array_names, absl::string_view data_types, absl::string_view shapes, GraphImportConfig::InputArrays* inputs) { - std::vector node_names; - std::vector node_dtypes; + std::vector node_names; + std::vector node_dtypes; std::vector>> node_shapes; TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names)); TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes)); @@ -113,8 +113,8 @@ static absl::Status HandleSubtype(absl::string_view subtype, } absl::Status ParseInputArrayInfo( - const std::vector& node_names, - const std::vector& node_dtypes, + const std::vector& node_names, + const std::vector& node_dtypes, const std::vector>>& node_shapes, GraphImportConfig::InputArrays* inputs) { std::vector used_node_dtypes; @@ -148,7 +148,7 @@ absl::Status ParseInputArrayInfo( // StringMap doesn't support reserve else reserve input map size here. for (int i = 0, end = node_names.size(); i < end; i++) { auto& name = node_names[i]; - const string& type = used_node_dtypes[i]; + const std::string& type = used_node_dtypes[i]; if (name.empty()) continue; auto it_inserted_pair = inputs->insert({name, {}}); @@ -193,7 +193,7 @@ absl::Status ParseNodeShapes( std::vector>>& shapes_vector) { shapes_vector.clear(); if (!shapes_str.empty()) { - std::vector node_shapes_str = absl::StrSplit(shapes_str, ':'); + std::vector node_shapes_str = absl::StrSplit(shapes_str, ':'); for (int i = 0; i < node_shapes_str.size(); i++) { if (node_shapes_str[i] == "*") { shapes_vector.push_back(std::nullopt); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h index 1119d4e2b33c4f..176773da45fcbc 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h @@ -35,10 +35,10 @@ namespace tensorflow { // Parses the command line flag strings to the specification of nodes in // the Graph. absl::Status ParseOutputArrayInfo(absl::string_view array_names, - std::vector* outputs); + std::vector* outputs); -absl::Status ParseOutputArrayInfo(const std::vector& output_names, - std::vector* outputs); +absl::Status ParseOutputArrayInfo(const std::vector& output_names, + std::vector* outputs); // Parses the command line flag strings to the specification of nodes in // the Graph. `data_types` input string can be empty since the flag is optional. 
@@ -48,8 +48,8 @@ absl::Status ParseInputArrayInfo(absl::string_view array_names, GraphImportConfig::InputArrays* inputs); absl::Status ParseInputArrayInfo( - const std::vector& node_names, - const std::vector& node_dtypes, + const std::vector& node_names, + const std::vector& node_dtypes, const std::vector>>& node_shapes, GraphImportConfig::InputArrays* inputs); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index 858c70a54a58d6..3706b8afe34d78 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include -#include #include #include #include #include "absl/log/log.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_split.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" @@ -99,8 +99,7 @@ std::vector BridgeLoggerConfig::GetFilter( bool BridgeLoggerConfig::ShouldOnlyDumpTopLevelPasses() { const char* env_var = getenv(kEnableOnlyTopLevelPassesEnvVar); - std::string value(env_var); - std::transform(value.begin(), value.end(), value.begin(), ::tolower); + std::string value = absl::AsciiStrToLower(env_var); // Return true if value is "1" or "true"; otherwise, false. return value == "1" || value == "true"; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index b0ad4e265633d8..550ab547498f45 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -249,14 +249,14 @@ absl::StatusOr ConvertTensor(const Tensor& input_tensor, CONVERT_FLAT(DT_BOOL, bool) CONVERT_FLAT(DT_FLOAT, float) CONVERT_FLAT(DT_DOUBLE, double) - CONVERT_FLAT(DT_INT8, int8) - CONVERT_FLAT(DT_INT16, int16) - CONVERT_FLAT(DT_INT32, int32) + CONVERT_FLAT(DT_INT8, int8_t) + CONVERT_FLAT(DT_INT16, int16_t) + CONVERT_FLAT(DT_INT32, int32_t) CONVERT_FLAT(DT_INT64, int64_t) - CONVERT_FLAT(DT_UINT8, uint8) - CONVERT_FLAT(DT_UINT16, uint16) - CONVERT_FLAT(DT_UINT32, uint32) - CONVERT_FLAT(DT_UINT64, uint64) + CONVERT_FLAT(DT_UINT8, uint8_t) + CONVERT_FLAT(DT_UINT16, uint16_t) + CONVERT_FLAT(DT_UINT32, uint32_t) + CONVERT_FLAT(DT_UINT64, uint64_t) CONVERT_FLAT(DT_COMPLEX64, std::complex) CONVERT_FLAT(DT_COMPLEX128, std::complex) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index a34553623408d8..b120b6c786edb6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -162,11 +162,11 @@ TEST_F(ConvertTensorTest, Simple) { ASSERT_NO_FATAL_FAILURE(VerifyConversion( {static_cast(1), static_cast(-1)}, DT_INT4, mlir::IntegerType::get(&context, 4))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32))); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64))); @@ -175,19 +175,19 @@ TEST_F(ConvertTensorTest, 
Simple) { {static_cast(1), static_cast(2)}, DT_UINT4, mlir::IntegerType::get( &context, 4, mlir::IntegerType::SignednessSemantics::Unsigned))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT8, mlir::IntegerType::get( &context, 8, mlir::IntegerType::SignednessSemantics::Unsigned))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT16, mlir::IntegerType::get( &context, 16, mlir::IntegerType::SignednessSemantics::Unsigned))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT32, mlir::IntegerType::get( &context, 32, mlir::IntegerType::SignednessSemantics::Unsigned))); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT64, mlir::IntegerType::get( &context, 64, mlir::IntegerType::SignednessSemantics::Unsigned))); @@ -222,11 +222,11 @@ TEST_F(ConvertTensorTest, SimpleDenseResourceElements) { ASSERT_NO_FATAL_FAILURE(VerifyConversion( {static_cast(1), static_cast(-1)}, DT_INT4, mlir::IntegerType::get(&context, 4), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT16, mlir::IntegerType::get(&context, 16), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT32, mlir::IntegerType::get(&context, 32), true)); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT64, mlir::IntegerType::get(&context, 64), true)); @@ -236,22 +236,22 @@ TEST_F(ConvertTensorTest, SimpleDenseResourceElements) { mlir::IntegerType::get(&context, 4, mlir::IntegerType::SignednessSemantics::Unsigned), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT8, mlir::IntegerType::get(&context, 8, mlir::IntegerType::SignednessSemantics::Unsigned), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT16, mlir::IntegerType::get(&context, 16, mlir::IntegerType::SignednessSemantics::Unsigned), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT32, mlir::IntegerType::get(&context, 32, mlir::IntegerType::SignednessSemantics::Unsigned), true)); - ASSERT_NO_FATAL_FAILURE(VerifyConversion( + ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, 2}, DT_UINT64, mlir::IntegerType::get(&context, 64, mlir::IntegerType::SignednessSemantics::Unsigned), diff --git a/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config_test.cc index 09a76102557c4f..a4f2861276a9bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config_test.cc @@ -59,9 +59,9 @@ TEST(DataDumperLoggerConfig, TestPassFilter) { 1); setenv("TF_DUMP_GRAPH_PREFIX", "sponge", 1); - const string kTestFilename = "test.txt"; + const std::string kTestFilename = "test.txt"; int print_callback_count = 0; - auto get_filename_fn = [](const string &filename, mlir::Operation *op) { + auto get_filename_fn = [](const std::string& filename, mlir::Operation* op) { return filename; }; auto print_callback = [&](llvm::raw_ostream &out) { diff --git 
a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index d9249d472b334c..3329bff4c02737 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -126,7 +126,8 @@ void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) { // For device that do not have any metadata, or if we failed to parse metadata // from the DeviceSet, we add a unit attribute to the `tf.devices` attribute. for (Device* device : device_set->devices()) { - string name = DeviceNameUtils::ParsedNameToString(device->parsed_name()); + std::string name = + DeviceNameUtils::ParsedNameToString(device->parsed_name()); if (device->device_type() == DEVICE_GPU) { auto metadata = ParseGpuDeviceMetadata(*device, &builder); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index c3e7ae75022348..abf357873a6153 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -52,8 +52,8 @@ class FakeDevice : public Device { return errors::Unimplemented("FakeDevice::Sync()"); } - static std::unique_ptr Make(const string& name, - const string& desc = "") { + static std::unique_ptr Make(const std::string& name, + const std::string& desc = "") { DeviceNameUtils::ParsedName parsed_name; DeviceNameUtils::ParseFullName(name, &parsed_name); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc index 7e92860e5ff03e..9d9780d231523f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph_test.cc @@ -26,12 +26,12 @@ limitations under the License. namespace tensorflow { namespace { -void ExpectHasSubstr(const string& s, const string& expected) { +void ExpectHasSubstr(const std::string& s, const std::string& expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } -void ExpectHasNoSubstr(const string& s, const string& expected) { +void ExpectHasNoSubstr(const std::string& s, const std::string& expected) { EXPECT_FALSE(absl::StrContains(s, expected)) << "'" << s << "' should not contain '" << expected << "'"; } @@ -39,7 +39,7 @@ void ExpectHasNoSubstr(const string& s, const string& expected) { // WritableFile that simply concats into string. 
class StringWritableFile : public WritableFile { public: - explicit StringWritableFile(string* str) : str_(*str) {} + explicit StringWritableFile(std::string* str) : str_(*str) {} absl::Status Append(absl::string_view data) override { absl::StrAppend(&str_, data); @@ -62,7 +62,7 @@ class StringWritableFile : public WritableFile { } private: - string& str_; + std::string& str_; }; TEST(Dump, TextualIrToFileSuccess) { @@ -72,10 +72,10 @@ TEST(Dump, TextualIrToFileSuccess) { setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1); UseMlirForGraphDump(MlirDumpConfig()); - string ret = DumpGraphToFile("tir", graph); + std::string ret = DumpGraphToFile("tir", graph); ASSERT_EQ(ret, io::JoinPath(testing::TmpDir(), "tir.mlir")); - string actual; + std::string actual; TF_ASSERT_OK(ReadFileToString(Env::Default(), ret, &actual)); } @@ -86,12 +86,12 @@ TEST(Dump, TextualIrWithOptions) { .Attr("dtype", DT_FLOAT) .Finalize(&graph, &node)); - string actual; + std::string actual; StringWritableFile file(&actual); TF_ASSERT_OK(DumpTextualIRToFile(MlirDumpConfig().emit_location_information(), graph, /*flib_def=*/nullptr, &file)); - string expected_substr = R"(loc(#loc))"; + std::string expected_substr = R"(loc(#loc))"; ExpectHasSubstr(actual, expected_substr); } @@ -100,17 +100,17 @@ TEST(Dump, DumpToTFG) { Node* node; TF_CHECK_OK(NodeBuilder("A", "NoOp").Finalize(&graph, &node)); - string actual; + std::string actual; StringWritableFile file(&actual); TF_ASSERT_OK(DumpTextualIRToFile( MlirDumpConfig().emit_dialect(MlirDumpConfig::Dialect::kTFG), graph, /*flib_def=*/nullptr, &file)); - string expected_substr("tfg.graph"); + std::string expected_substr("tfg.graph"); ExpectHasSubstr(actual, expected_substr); - string not_expected_substr("tf_executor.island"); + std::string not_expected_substr("tf_executor.island"); ExpectHasNoSubstr(actual, not_expected_substr); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index b970ca84b326cf..138e13e3719328 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -44,7 +44,7 @@ struct NameCounts { llvm::StringMap counts; }; -std::string MakeUniqueFilename(string name) { +std::string MakeUniqueFilename(std::string name) { static NameCounts& instance = *new NameCounts; // Remove illegal characters from `name`. @@ -274,7 +274,7 @@ void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path) { // Output dirs "sponge" (case-insensitive) have a special meaning: Dump into // the directory specified by the environment variable // TEST_UNDECLARED_OUTPUTS_DIR. 
- string lower_path = absl::AsciiStrToLower(path); + std::string lower_path = absl::AsciiStrToLower(path); if (lower_path == "sponge") { if (!tensorflow::io::GetTestUndeclaredOutputsDir(&path)) { LOG(ERROR) << "MLIR crash reproducer is set to '" << dir_path.str() diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 9ec1b9970ae777..9e07ece4e0999e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -400,12 +400,12 @@ absl::Status ConvertAttributes( if (auto symbol_ref = mlir::dyn_cast(attr)) { TF_RETURN_IF_ERROR(ConvertAttribute( mlir::cast(symbol_ref), &value)); - func_call_attrs[string(name)] = std::move(value); + func_call_attrs[std::string(name)] = std::move(value); continue; } if (auto func_attr = mlir::dyn_cast(attr)) { TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value)); - func_call_attrs[string(name)] = std::move(value); + func_call_attrs[std::string(name)] = std::move(value); continue; } if (mlir::isa(attr)) { @@ -434,12 +434,12 @@ absl::Status ConvertAttributes( // input TensorFlow GraphDef shouldn't contain '.'. If it does appear in // the attribute from MLIR, it is treated as an attribute from function // calls. - std::vector name_tokens = + std::vector name_tokens = absl::StrSplit(name, '.', absl::SkipEmpty()); TF_RET_CHECK(name_tokens.size() <= 2); auto it = func_call_attrs.find(name_tokens[0]); if (it == func_call_attrs.end()) { - (*values)[string(name)] = std::move(value); + (*values)[std::string(name)] = std::move(value); } else { (*it->second.mutable_func()->mutable_attr())[name_tokens[1]] = std::move(value); @@ -457,7 +457,7 @@ absl::Status SetShapeAttribute(absl::string_view name, AttrValue value; SetTensorShapeProto(shaped_type, value.mutable_list()->add_shape()); - auto result = values->insert({string(name), value}); + auto result = values->insert({std::string(name), value}); if (!result.second) { // This should be extremely rare as it means we are adding the same // attribute multiple times/have some redundancy in representing this diff --git a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc index 50306edb28b067..fa2ff3c8a281fa 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/import_utils.cc @@ -59,7 +59,7 @@ absl::Status LoadProtoFromFileImpl(absl::string_view input_filename, T* proto) { if (std::error_code error = file_or_err.getError()) { return errors::InvalidArgument( "Could not open input file ", - string(input_filename.data(), input_filename.size()).c_str()); + std::string(input_filename.data(), input_filename.size()).c_str()); } const auto& input_file = *file_or_err; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc index a189cc14555143..fbcdc9e894fbd9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc @@ -41,7 +41,7 @@ const char kTensorPrefix[] = "tftensor$"; } // namespace -string MangleAttributeName(absl::string_view str) { +std::string MangleAttributeName(absl::string_view str) { return absl::StrCat(kAttributePrefix, str); } @@ -66,7 +66,7 @@ MangledKind GetMangledKind(absl::string_view str) { } } -string MangleShape(const TensorShapeProto& shape) { +std::string 
MangleShape(const TensorShapeProto& shape) { return absl::StrCat(kTensorShapePrefix, PrintShortTextProto(shape)); } @@ -74,7 +74,7 @@ absl::Status DemangleShape(absl::string_view str, TensorShapeProto* proto) { return ParseTextProto(str, kTensorShapePrefix, proto); } -string MangleTensor(const TensorProto& tensor) { +std::string MangleTensor(const TensorProto& tensor) { return absl::StrCat(kTensorPrefix, PrintShortTextProto(tensor)); } @@ -82,7 +82,7 @@ absl::Status DemangleTensor(absl::string_view str, TensorProto* proto) { return ParseTextProto(str, kTensorPrefix, proto); } -string MangleDataType(const DataType& dtype) { +std::string MangleDataType(const DataType& dtype) { return absl::StrCat(kDataTypePrefix, DataType_Name(dtype)); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h index a0c14f27b5b38f..7e95a27f0290f9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h @@ -28,7 +28,7 @@ namespace mangling_util { enum class MangledKind { kUnknown, kDataType, kTensorShape, kTensor }; // Mangles an attribute name, marking the attribute as a TensorFlow attribute. -string MangleAttributeName(absl::string_view str); +std::string MangleAttributeName(absl::string_view str); // Returns true if 'str' was mangled with MangleAttributeName. bool IsMangledAttributeName(absl::string_view str); @@ -41,17 +41,17 @@ absl::string_view DemangleAttributeName(absl::string_view str); MangledKind GetMangledKind(absl::string_view str); // Return a TensorShapeProto mangled as a string. -string MangleShape(const TensorShapeProto& shape); +std::string MangleShape(const TensorShapeProto& shape); // Demangle a string mangled with MangleShape. absl::Status DemangleShape(absl::string_view str, TensorShapeProto* proto); // Return a TensorProto mangled as a string. -string MangleTensor(const TensorProto& tensor); +std::string MangleTensor(const TensorProto& tensor); // Demangle a string mangled with MangleTensor. absl::Status DemangleTensor(absl::string_view str, TensorProto* proto); // Return a DataType mangled as a string. -string MangleDataType(const DataType& dtype); +std::string MangleDataType(const DataType& dtype); // Demangle a string mangled with MangleDataType. absl::Status DemangleDataType(absl::string_view str, DataType* proto); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc index c9a6f6e85c9d4d..c1479fead3a595 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc @@ -133,7 +133,7 @@ absl::Status SetTypeAttribute(absl::string_view name, ContainerT types, type_list.add_type(dtype); } - auto result = values->insert({string(name), value}); + auto result = values->insert({std::string(name), value}); assert(result.second && "cannot have multiple attributes with the same name"); (void)result; @@ -164,7 +164,7 @@ void SetShapeAttribute(absl::string_view name, ContainerT shapes, // If shape is already set, override it. This can happen if we import // without shape inference enabled and so couldn't be removed on import and // are not explicitly dropped later. - (*values)[string(name)] = value; + (*values)[std::string(name)] = value; } // Collects all the unregistered attributes for an TF dialect operation. 
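// A minimal round-trip sketch (an illustration for review, not part of the
// patch) of the mangling_util API whose declarations the hunks above migrate
// from the `string` alias to `std::string`. Only names visible in the diff
// are used; the include paths are assumptions about the usual TensorFlow
// headers.
//
//   #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
//
//   absl::Status RoundTripShape() {
//     tensorflow::TensorShapeProto shape;
//     shape.add_dim()->set_size(2);
//     shape.add_dim()->set_size(3);
//     // MangleShape prepends kTensorShapePrefix to a short text proto and,
//     // after this change, spells its return type as std::string.
//     std::string mangled = tensorflow::mangling_util::MangleShape(shape);
//     tensorflow::TensorShapeProto parsed;
//     // DemangleShape strips the prefix and parses the text proto back.
//     return tensorflow::mangling_util::DemangleShape(mangled, &parsed);
//   }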
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc index 8cb797a9a9b214..b13e099fde3557 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc @@ -214,7 +214,7 @@ absl::StatusOr> BuildConstOpGraphWithOutputShapes() { std::initializer_list dims = {2, 3, 4, 5}; Tensor tensor(data_type, TensorShape(dims)); for (int i = 0; i < 2 * 3 * 4 * 5; ++i) { - tensor.flat()(i) = i; + tensor.flat()(i) = i; } NodeDef node; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index 46f7f5de1d0856..74b7304b745033 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -106,9 +106,9 @@ namespace { // Time the execution of kernels (in CPU cycles). Meant to be used as RAII. struct CompilationTimer { - uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); + uint64_t start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); - uint64 ElapsedCycles() { + uint64_t ElapsedCycles() { return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles; } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 243f4333a88525..2ab0c3c619b292 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -4864,7 +4864,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ + std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; @@ -5064,7 +5064,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; auto strides_attr = GetI64ElementsAttr(op.getStrides()); - std::vector strides{ + std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 159bc8b17bc36b..a6ee4c3e1ffbd0 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -308,7 +308,7 @@ py_strict_library( "//tensorflow/python/framework:op_def_registry", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:tf_inspect", - "@pypi_gast//:pkg", + "@pypi//gast", ], ) @@ -339,7 +339,7 @@ py_strict_library( "//tensorflow/python/autograph/pyct:transpiler", "//tensorflow/python/framework:op_def_registry", "//tensorflow/python/util:tf_inspect", - "@pypi_gast//:pkg", + "@pypi//gast", ], ) diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.cc b/tensorflow/compiler/mlir/tfr/utils/utils.cc index f9e70b228c0b71..ddff766c789450 100644 --- a/tensorflow/compiler/mlir/tfr/utils/utils.cc +++ b/tensorflow/compiler/mlir/tfr/utils/utils.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" #include "mlir/IR/Block.h" // from @llvm-project @@ -92,9 +93,9 @@ std::string GetComposeFuncName(StringRef tf_op_name) { } if (tf_op_name[i] == '.') { compose_func_name.push_back('_'); - } else if (tf_op_name[i] >= 'A' && tf_op_name[i] <= 'Z') { + } else if (llvm::isUpper(tf_op_name[i])) { compose_func_name.push_back('_'); - compose_func_name.push_back(tf_op_name[i] + 'a' - 'A'); + compose_func_name.push_back(llvm::toLower(tf_op_name[i])); } else { compose_func_name.push_back(tf_op_name[i]); } @@ -106,13 +107,13 @@ std::string GetTFOpName(StringRef compose_func_name) { std::string tf_op_name; bool after_underscore = false; for (int i = 0; i < compose_func_name.size(); ++i) { - if (compose_func_name[i] >= 'A' && compose_func_name[i] <= 'Z') { + if (llvm::isUpper(compose_func_name[i])) { // The field name must not contain uppercase letters. return {}; } if (after_underscore) { - if (compose_func_name[i] >= 'a' && compose_func_name[i] <= 'z') { - tf_op_name.push_back(compose_func_name[i] + 'A' - 'a'); + if (llvm::isLower(compose_func_name[i])) { + tf_op_name.push_back(llvm::toUpper(compose_func_name[i])); after_underscore = false; } else { // The character after a "_" must be a lowercase letter. diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index cc59c9150da769..7f4a602b1330a6 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -906,10 +906,6 @@ void CreateFallbackInitializationFunction( builder.create( func_op.getLoc(), /*resultTypes=*/mlir::TypeRange{}, /*operands=*/mlir::ValueRange{}, op->getAttrs()); - } else { - // TODO: b/381849919 - Remove this log once the bug is fixed. - LOG_FIRST_N(WARNING, 100) - << "Skip creation of fallback kernel for op index " << op_index; } } diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc index 077d662ef4ed1c..6ce41c7f4fe829 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc @@ -45,8 +45,6 @@ namespace { using ::testing::ElementsAreArray; using ::testing::FloatEq; using ::testing::IsEmpty; -using ::tsl::testing::IsOkAndHolds; -using ::tsl::testing::StatusIs; TEST(MlirToByteCodeTest, Basic) { constexpr char kBasicMlir[] = diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc index 8f06eb691551ef..6d6d572a79e9f2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc @@ -127,17 +127,17 @@ absl::StatusOr EmitToBinary(llvm::StringRef host_triple, return ostream.str().str(); } -absl::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, - llvm::StringRef host_triple, - llvm::ArrayRef architectures, - llvm::ArrayRef tile_sizes, - llvm::ArrayRef unroll_factors, bool print_ptx, +absl::Status Run(std::string input_file, std::string output_file, + std::string host_triple, + std::vector architectures, + std::vector tile_sizes, + std::vector unroll_factors, bool print_ptx, bool print_llvmir, bool enable_ftz, bool index_64bit, bool jit_compile, bool jit_i64_indexed_for_large_tensors) { // Read TF code. 
std::string hlo_code; TF_RETURN_IF_ERROR( - ReadFileToString(Env::Default(), input_file.str(), &hlo_code)); + ReadFileToString(Env::Default(), input_file, &hlo_code)); // Compile. mlir::DialectRegistry registry; @@ -160,7 +160,7 @@ absl::Status Run(llvm::StringRef input_file, llvm::StringRef output_file, // Write .a file. TF_RETURN_IF_ERROR( - WriteStringToFile(Env::Default(), output_file.str(), binary)); + WriteStringToFile(Env::Default(), output_file, binary)); return absl::OkStatus(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index 0c0bfee0e9407e..15a697ddf75807 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -118,7 +118,7 @@ class GpuKernelToBlobPass auto llvm_module_copy = llvm::CloneModule(*llvmModule); auto hsaco_or = xla::gpu::amdgpu::CompileToHsaco( llvm_module_copy.get(), - tensorflow::se::RocmComputeCapability{arch_str}, options, + stream_executor::GpuComputeCapability(tensorflow::se::RocmComputeCapability{arch_str}), options, options.DebugString()); if (!hsaco_or.ok()) { return tensorflow::errors::Internal("Failure when generating HSACO"); diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.h b/tensorflow/compiler/mlir/tosa/tfl_passes.h index 96d3cabf0c1f1f..02bd007f6fa36c 100644 --- a/tensorflow/compiler/mlir/tosa/tfl_passes.h +++ b/tensorflow/compiler/mlir/tosa/tfl_passes.h @@ -42,8 +42,8 @@ struct TOSATFLLegalizationPipelineOptions llvm::cl::desc("Dequantize the TFLite softmax"), llvm::cl::init(false)}; TOSATFLLegalizationPipelineOptions() { - disabled_patterns = std::nullopt; - enabled_patterns = std::nullopt; + disabled_patterns = {}; + enabled_patterns = {}; } }; diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h index de0872b660d4ec..0475d46a37a091 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.h +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h @@ -53,8 +53,8 @@ std::unique_ptr> createFuseBiasTFPass(); // `enabledPatterns` is a set of labels used to filter out input patterns that // do not have one of the labels in this set. std::unique_ptr> createLegalizeTFLPass( - ArrayRef disabled_patterns = std::nullopt, - ArrayRef enabled_patterns = std::nullopt); + ArrayRef disabled_patterns = {}, + ArrayRef enabled_patterns = {}); std::unique_ptr> createRetainCallOnceFuncsPass(); std::unique_ptr> createStripModuleMetadataPass(); diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc index fd50116ba7d1a7..fb5bb77644c211 100644 --- a/tensorflow/compiler/mlir/utils/name_utils.cc +++ b/tensorflow/compiler/mlir/utils/name_utils.cc @@ -31,8 +31,8 @@ namespace { // Checks if a character is legal for a TensorFlow node name, with special // handling if a character is at the beginning. 
bool IsLegalChar(char c, bool first_char) { - if (isalpha(c)) return true; - if (isdigit(c)) return true; + if (llvm::isAlpha(c)) return true; + if (llvm::isDigit(c)) return true; if (c == '.') return true; if (c == '_') return true; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 94ca1645435a2a..e44cfddd144a12 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -2,7 +2,14 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_cuda_cc_test") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/tests:build_combined_defs.bzl", "tf_xla_combined_py_test") -load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites", "tf_xla_py_strict_test") +load( + "//tensorflow/compiler/tests:build_defs.bzl", + "generate_backend_suites", + "tf_xla_py_strict_test", + # copybara:uncomment_begin(google-only) + # "tpu_backends", + # copybara:uncomment_end +) load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", @@ -214,9 +221,8 @@ tf_xla_combined_py_test( name = "combined_ops_test_f", size = "medium", timeout = "long", - # copybara:uncomment_begin - # #TODO(b/286470564): Remove once the bug is fixed. - # disable_tpu_tfrt = True, + # copybara:uncomment_begin(google-only) + # disabled_backends = tpu_backends(), # copybara:uncomment_end exec_properties = { "cpp_link.mem": "16g", @@ -341,10 +347,6 @@ tf_xla_py_strict_test( name = "add_n_test", size = "small", srcs = ["add_n_test.py"], - # copybara:uncomment_begin - # #TODO(b/286470564): Remove once the bug is fixed. - # disable_tpu_tfrt = True, - # copybara:uncomment_end tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", @@ -496,10 +498,6 @@ tf_xla_py_strict_test( name = "cond_test", size = "small", srcs = ["cond_test.py"], - # copybara:uncomment_begin - # #TODO(b/286470564): Remove once the bug is fixed. - # disable_tpu_tfrt = True, - # copybara:uncomment_end tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", @@ -1743,12 +1741,8 @@ tf_xla_py_strict_test( name = "tensor_list_ops_test", size = "small", srcs = ["tensor_list_ops_test.py"], - # copybara:uncomment_begin - # #TODO(b/286470564): Remove once the bug is fixed. - # disable_tpu_tfrt = True, - # copybara:uncomment_end - # TensorList ops are not implemented in the on-demand compilation model yet. - disabled_backends = ["cpu_ondemand"], + # TensorList ops are only implemented on CPU. + enabled_backends = ["cpu"], tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -1905,10 +1899,6 @@ tf_xla_py_strict_test( name = "while_test", size = "small", srcs = ["while_test.py"], - # copybara:uncomment_begin - # #TODO(b/291130193): Remove once the bug is fixed. 
- # disable_tpu_tfrt = True, - # copybara:uncomment_end tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "notap", @@ -2165,7 +2155,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - env = {"XLA_FLAGS": "--xla_backend_extra_options=xla_cpu_disable_new_fusion_emitters=true"}, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -2429,9 +2418,6 @@ tf_xla_py_strict_test( name = "where_op_tpu_test", size = "small", srcs = ["where_op_test.py"], - args = [ - "--tpu_use_tfrt=true", - ], disabled_backends = [ "cpu", "cpu_ondemand", diff --git a/tensorflow/compiler/tests/cast_test.py b/tensorflow/compiler/tests/cast_test.py index bc35db4e05f7d5..453cbeb1312648 100644 --- a/tensorflow/compiler/tests/cast_test.py +++ b/tensorflow/compiler/tests/cast_test.py @@ -35,9 +35,10 @@ def test_cast(self): dtypes.uint32, dtypes.uint64, } - for src_type in types: - for dst_type in types: - self._test_cast(src_type, dst_type) + with self.session() as session: + for src_type in types: + for dst_type in types: + self._test_cast(src_type, dst_type, session) def test_cast_fp8(self): if platform.system() == "Darwin": @@ -61,12 +62,13 @@ def test_cast_fp8(self): dtypes.uint32, dtypes.uint64, } - for fp8_type in fp8_types: - for other_type in other_types | fp8_types: - self._test_cast(fp8_type, other_type) - self._test_cast(other_type, fp8_type) + with self.session() as session: + for fp8_type in fp8_types: + for other_type in other_types | fp8_types: + self._test_cast(fp8_type, other_type, session) + self._test_cast(other_type, fp8_type, session) - def _test_cast(self, src_type, dst_type): + def _test_cast(self, src_type, dst_type, session): with self.subTest(src_type=src_type, dst_type=dst_type): shapes = [[], [4], [2, 3], [2, 0, 4]] src_np_dtype = src_type.as_numpy_dtype @@ -83,6 +85,7 @@ def _test_cast(self, src_type, dst_type): lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), src, expected=dst, + local_session=session, ) # Check special values. 
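# A condensed sketch (an illustration for review, not part of the patch) of
# the session-reuse refactor that the cast_test.py hunks around this point
# apply: the test opens one session and threads it through every assertion as
# `local_session`, instead of each assertion opening its own session.
# `_check` and `test_cast_sketch` are hypothetical stand-ins for the real
# `_test_cast` helper and test method above.
#
#   def test_cast_sketch(self):
#     types = {dtypes.bool, dtypes.float32, dtypes.int32}
#     with self.session() as session:  # one session for the whole sweep
#       for src_type in types:
#         for dst_type in types:
#           self._check(src_type, dst_type, session)
#
#   def _check(self, src_type, dst_type, session):
#     with self.subTest(src_type=src_type, dst_type=dst_type):
#       self.assert_op_output_matches_expected(
#           lambda x: math_ops.cast(x, dst_type),
#           np.zeros((2, 3), src_type.as_numpy_dtype),
#           expected=np.zeros((2, 3), dst_type.as_numpy_dtype),
#           local_session=session,  # reuse the shared session
#       )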
@@ -112,6 +115,7 @@ def _test_cast(self, src_type, dst_type): lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), src, expected=dst, + local_session=session, ) def test_give_me_a_name(self): diff --git a/tensorflow/compiler/tests/float_ops_test.py b/tensorflow/compiler/tests/float_ops_test.py index d8743016c20756..67a1ecc967f24c 100644 --- a/tensorflow/compiler/tests/float_ops_test.py +++ b/tensorflow/compiler/tests/float_ops_test.py @@ -23,449 +23,522 @@ class FloatOpsTest(xla_test.XLATestCase): def test_float_ops(self): - for dtype in self.float_types: - x = np.arange(-0.90, 0.90, 0.25) - self.assert_op_output_matches_expected( - math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype) - ) - self.assert_op_output_matches_expected( - math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype) - ) - x = np.arange(-3, 3).reshape(1, 3, 2) - self.assert_op_output_matches_expected( - math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype) - ) - - self.assert_op_output_matches_expected( - math_ops.acosh, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.asinh, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.atanh, - np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), - expected=np.array( - [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.ceil, - np.array([[-1.7, 1.2]], dtype=dtype), - expected=np.array([[-1, 2]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.cosh, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype - ), - ) - - # Disable float16 testing for now - if dtype != np.float16: - x = np.arange(-10, 10, 1).astype(dtype) - with self.session() as session: + with self.session() as session: + for dtype in self.float_types: + x = np.arange(-0.90, 0.90, 0.25) + self.assert_op_output_matches_expected( + math_ops.acos, + x.astype(dtype), + expected=np.arccos(x).astype(dtype), + local_session=session, + ) + self.assert_op_output_matches_expected( + math_ops.asin, + x.astype(dtype), + expected=np.arcsin(x).astype(dtype), + local_session=session, + ) + x = np.arange(-3, 3).reshape(1, 3, 2) + self.assert_op_output_matches_expected( + math_ops.atan, + x.astype(dtype), + expected=np.arctan(x).astype(dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.acosh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.asinh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.atanh, + np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), + expected=np.array( + [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.ceil, + np.array([[-1.7, 1.2]], dtype=dtype), + expected=np.array([[-1, 2]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.cosh, + np.array([1, 2, 3, 4], 
dtype=dtype), + expected=np.array( + [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype + ), + local_session=session, + ) + + # Disable float16 testing for now + if dtype != np.float16: + x = np.arange(-10, 10, 1).astype(dtype) erf_x = session.run(math_ops.erf(x)) erfc_x = session.run(math_ops.erfc(x)) - self.assert_op_output_matches_expected(math_ops.erf, x, expected=erf_x) - self.assert_op_output_matches_expected( - math_ops.erfc, x, expected=erfc_x - ) - - self.assert_op_output_matches_expected( - math_ops.exp, - np.array([[-1, 1]], dtype=dtype), - expected=np.array([[0.36787945, 2.7182817]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.expm1, - np.array([[-1, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), - rtol=1e-5, - ) - - self.assert_op_output_matches_expected( - math_ops.floor, - np.array([[-1.7, 1.2]], dtype=dtype), - expected=np.array([[-2, 1]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.is_finite, - np.array( - [[-np.inf, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype - ), - expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool_), - ) - - # Tests for tf.nn ops. - self.assert_op_output_matches_expected( - nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0) - ) - - self.assert_op_output_matches_expected(nn_ops.l2_loss, dtype(4), dtype(8)) - - self.assert_op_output_matches_expected( - nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10) - ) - - self.assert_op_output_matches_expected( - math_ops.reciprocal, - np.array([[1, 2]], dtype=dtype), - expected=np.array([[1, 0.5]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.log, - np.array([[1, 2]], dtype=dtype), - expected=np.array([[0, 0.69314718]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.sin, - np.array([[1, 2]], dtype=dtype), - expected=np.array([[0.841478, 0.909302]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.cos, - np.array([[1, 2]], dtype=dtype), - expected=np.array([[0.540297, -0.41614]], dtype=dtype), - ) - - # Confirm that log1p will remain precise across a range of small values. 
- self.assert_op_output_matches_expected( - math_ops.log1p, - np.array( - [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], - dtype=dtype, - ), - expected=np.log1p( - np.array( - [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], - dtype=dtype, - ) - ).astype(dtype), - rtol=1e-15 if dtype == np.float64 else 1e-4, - atol=1e-15 if dtype == np.float64 else 1e-4, - ) - - self.assert_op_output_matches_expected( - math_ops.rint, - np.array( - [ - [-1.7, 1.2, 4.0, 0.0], - [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5], - ], - dtype=dtype, - ), - expected=np.array( - [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype - ), - ) - self.assert_op_output_matches_expected( - math_ops.round, - np.array( - [ - [-1.7, 1.2, 4.0, 0.0], - [-3.5, -2.5, -1.5, -0.5], - [0.5, 1.5, 2.5, 3.5], - ], - dtype=dtype, - ), - expected=np.array( - [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.rsqrt, - np.array([[4, 16]], dtype=dtype), - expected=np.array([[0.5, 0.25]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.sigmoid, - np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), - expected=np.array( - [ - [0.7310586, 0.7310586, 0.7310586, 0.7310586], - [0.7310586, 0.880797, 0.95257413, 0.98201376], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - math_ops.sigmoid, - np.array([-300, -150, 0, 150, 300], dtype=dtype), - expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.sinh, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.sqrt, - np.array([[4, 9]], dtype=dtype), - expected=np.array([[2, 3]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.tan, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.tanh, - np.array( - [[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], [19, -19, 22, -22]], - dtype=dtype, - ), - expected=np.array( - [ - [0.76159418, 0.96402758, 0.99505478, 0.99932933], - [1.0, -1.0, np.nan, 1.0], - [1.0, -1.0, 1.0, -1.0], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.log_softmax, - np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), - expected=np.array( - [ - [-1.3862944, -1.3862944, -1.3862944, -1.3862944], - [-3.4401896, -2.4401896, -1.4401897, -0.44018969], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.elu, - np.array([[-1, 0, 1, -1e-6]], dtype=dtype), - expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), - rtol=1e-5, - atol=1e-6, - ) - - self.assert_op_output_matches_expected( - nn_ops.selu, - np.array([[-1, 0, 1, -1e-5]], dtype=dtype), - expected=np.array( - [[-1.11133074, 0.0, 1.05070099, -1.758090550379974e-05]], - dtype=dtype, - ), - rtol=1e-5, - atol=1e-6, - ) - - self.assert_op_output_matches_expected( - nn_ops.relu, - np.array([[-1, 1]], dtype=dtype), - expected=np.array([[0, 1]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - nn_ops.relu6, - np.array([[-0.05, 6.05, 5]], dtype=dtype), - expected=np.array([[0, 6, 5]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - nn_ops.leaky_relu, - np.array([[-2, -1, 0, 1, 2]], dtype=dtype), - expected=np.array([[-0.4, -0.2, 0.0, 1.0, 
2.0]], dtype=dtype), - ) - - self.assert_op_output_matches_expected( - nn_ops.softmax, - np.array([1, 2, 3, 4], dtype=dtype), - expected=np.array( - [0.032058604, 0.087144323, 0.23688284, 0.64391428], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.softmax, - np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), - expected=np.array( - [ - [0.25, 0.25, 0.25, 0.25], - [0.032058604, 0.087144323, 0.23688284, 0.64391428], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.softmax, - np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), - expected=np.array( - [ - [[0.5, 0.5], [0.5, 0.5]], - [[0.26894142, 0.73105858], [0.26894142, 0.73105858]], - ], - dtype=dtype, - ), - ) - - self.assert_op_output_matches_expected( - nn_ops.softsign, - np.array([[-2, -1, 0, 1, 2]], dtype=dtype), - expected=np.array( - [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.sign, - np.array( - [[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0, float("nan")]], dtype=dtype - ), - expected=np.array( - [[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0, float("nan")]], dtype=dtype - ), - ) - - self.assert_op_output_matches_expected( - math_ops.is_finite, - np.array( - [[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype - ), - expected=np.array( - [[True, False, True], [False, True, True]], dtype=np.bool_ - ), - ) - - self.assert_op_output_matches_expected( - math_ops.lgamma, - np.array(0.5, dtype=dtype), - expected=np.array(np.log(np.pi) / 2, dtype=dtype), - ) - - self.assert_op_output_matches_expected( - math_ops.lgamma, - np.array( - [ - [1, 2, 3], - [4, 5, 6], - [1 / 2, 3 / 2, 5 / 2], - [-3 / 2, -7 / 2, -11 / 2], - ], - dtype=dtype, - ), - expected=np.array( - [ - [0, 0, np.log(2.0)], - [np.log(6.0), np.log(24.0), np.log(120)], - [ - np.log(np.pi) / 2, - np.log(np.pi) / 2 - np.log(2), - np.log(np.pi) / 2 - np.log(4) + np.log(3), - ], - [ - np.log(np.pi) / 2 - np.log(3) + np.log(4), - np.log(np.pi) / 2 - np.log(105) + np.log(16), - np.log(np.pi) / 2 - np.log(10395) + np.log(64), - ], - ], - dtype=dtype, - ), - ) - - # The actual result is complex. Take the real part. 
- self.assert_op_output_matches_expected( - math_ops.lgamma, - np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype), - expected=np.array( - [ - np.log(np.pi) / 2 + np.log(2), - np.log(np.pi) / 2 - np.log(15) + np.log(8), - np.log(np.pi) / 2 - np.log(945) + np.log(32), - ], - dtype=dtype, - ), - atol=1e-4, - ) - - self.assert_op_output_matches_expected( - math_ops.digamma, - np.array( - [ - [1.0, 0.5, 1 / 3.0], - [0.25, 1 / 6.0, 0.125], - [2.0, 3.0, 4.0], - [6.0, 8.0, 9.0], - ], - dtype=dtype, - ), - expected=np.array( - [ - [ - -np.euler_gamma, - -2 * np.log(2) - np.euler_gamma, - -np.pi / 2 / np.sqrt(3) - - 3 * np.log(3) / 2 - - np.euler_gamma, - ], - [ - -np.pi / 2 - 3 * np.log(2) - np.euler_gamma, - -np.pi * np.sqrt(3) / 2 - - 2 * np.log(2) - - 3 * np.log(3) / 2 - - np.euler_gamma, - -np.pi / 2 - - 4 * np.log(2) - - ( - np.pi - + np.log(2 + np.sqrt(2)) - - np.log(2 - np.sqrt(2)) - ) - / np.sqrt(2) - - np.euler_gamma, - ], - [ - 1 - np.euler_gamma, - 1.5 - np.euler_gamma, - 11 / 6.0 - np.euler_gamma, - ], - [ - 137 / 60.0 - np.euler_gamma, - 363 / 140.0 - np.euler_gamma, - 761 / 280.0 - np.euler_gamma, - ], - ], - dtype=dtype, - ), - ) + self.assert_op_output_matches_expected( + math_ops.erf, + x, + expected=erf_x, + local_session=session, + ) + self.assert_op_output_matches_expected( + math_ops.erfc, + x, + expected=erfc_x, + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.exp, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[0.36787945, 2.7182817]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.expm1, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), + local_session=session, + rtol=1e-5, + ) + + self.assert_op_output_matches_expected( + math_ops.floor, + np.array([[-1.7, 1.2]], dtype=dtype), + expected=np.array([[-2, 1]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.is_finite, + np.array( + [[-np.inf, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype + ), + expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool_), + local_session=session, + ) + + # Tests for tf.nn ops. + self.assert_op_output_matches_expected( + nn_ops.l2_loss, + np.array([[[]]], dtype=dtype), + expected=dtype(0), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.l2_loss, + dtype(4), + dtype(8), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.l2_loss, + np.array([[-2, 4]], dtype=dtype), + expected=dtype(10), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.reciprocal, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[1, 0.5]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.log, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0, 0.69314718]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sin, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0.841478, 0.909302]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.cos, + np.array([[1, 2]], dtype=dtype), + expected=np.array([[0.540297, -0.41614]], dtype=dtype), + local_session=session, + ) + + # Confirm that log1p will remain precise across a range of small values. 
+ self.assert_op_output_matches_expected( + math_ops.log1p, + np.array( + [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], + dtype=dtype, + ), + expected=np.log1p( + np.array( + [[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]], + dtype=dtype, + ) + ).astype(dtype), + local_session=session, + rtol=1e-15 if dtype == np.float64 else 1e-4, + atol=1e-15 if dtype == np.float64 else 1e-4, + ) + + self.assert_op_output_matches_expected( + math_ops.rint, + np.array( + [ + [-1.7, 1.2, 4.0, 0.0], + [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5], + ], + dtype=dtype, + ), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype + ), + local_session=session, + ) + self.assert_op_output_matches_expected( + math_ops.round, + np.array( + [ + [-1.7, 1.2, 4.0, 0.0], + [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5], + ], + dtype=dtype, + ), + expected=np.array( + [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.rsqrt, + np.array([[4, 16]], dtype=dtype), + expected=np.array([[0.5, 0.25]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sigmoid, + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), + expected=np.array( + [ + [0.7310586, 0.7310586, 0.7310586, 0.7310586], + [0.7310586, 0.880797, 0.95257413, 0.98201376], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sigmoid, + np.array([-300, -150, 0, 150, 300], dtype=dtype), + expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sinh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sqrt, + np.array([[4, 9]], dtype=dtype), + expected=np.array([[2, 3]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.tan, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.tanh, + np.array( + [ + [1, 2, 3, 4], + [np.inf, -np.inf, np.nan, 20], + [19, -19, 22, -22], + ], + dtype=dtype, + ), + expected=np.array( + [ + [0.76159418, 0.96402758, 0.99505478, 0.99932933], + [1.0, -1.0, np.nan, 1.0], + [1.0, -1.0, 1.0, -1.0], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.log_softmax, + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), + expected=np.array( + [ + [-1.3862944, -1.3862944, -1.3862944, -1.3862944], + [-3.4401896, -2.4401896, -1.4401897, -0.44018969], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.elu, + np.array([[-1, 0, 1, -1e-6]], dtype=dtype), + expected=np.array( + [[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype + ), + rtol=1e-5, + atol=1e-6, + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.selu, + np.array([[-1, 0, 1, -1e-5]], dtype=dtype), + expected=np.array( + [[-1.11133074, 0.0, 1.05070099, -1.758090550379974e-05]], + dtype=dtype, + ), + rtol=1e-5, + atol=1e-6, + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.relu, + np.array([[-1, 1]], dtype=dtype), + 
expected=np.array([[0, 1]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.relu6, + np.array([[-0.05, 6.05, 5]], dtype=dtype), + expected=np.array([[0, 6, 5]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.leaky_relu, + np.array([[-2, -1, 0, 1, 2]], dtype=dtype), + expected=np.array([[-0.4, -0.2, 0.0, 1.0, 2.0]], dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.softmax, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array( + [0.032058604, 0.087144323, 0.23688284, 0.64391428], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.softmax, + np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), + expected=np.array( + [ + [0.25, 0.25, 0.25, 0.25], + [0.032058604, 0.087144323, 0.23688284, 0.64391428], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.softmax, + np.array([[[1, 1], [1, 1]], [[1, 2], [3, 4]]], dtype=dtype), + expected=np.array( + [ + [[0.5, 0.5], [0.5, 0.5]], + [[0.26894142, 0.73105858], [0.26894142, 0.73105858]], + ], + dtype=dtype, + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + nn_ops.softsign, + np.array([[-2, -1, 0, 1, 2]], dtype=dtype), + expected=np.array( + [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.sign, + np.array( + [[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0, float("nan")]], dtype=dtype + ), + expected=np.array( + [[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0, float("nan")]], dtype=dtype + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.is_finite, + np.array( + [[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype + ), + expected=np.array( + [[True, False, True], [False, True, True]], dtype=np.bool_ + ), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.lgamma, + np.array(0.5, dtype=dtype), + expected=np.array(np.log(np.pi) / 2, dtype=dtype), + local_session=session, + ) + + self.assert_op_output_matches_expected( + math_ops.lgamma, + np.array( + [ + [1, 2, 3], + [4, 5, 6], + [1 / 2, 3 / 2, 5 / 2], + [-3 / 2, -7 / 2, -11 / 2], + ], + dtype=dtype, + ), + expected=np.array( + [ + [0, 0, np.log(2.0)], + [np.log(6.0), np.log(24.0), np.log(120)], + [ + np.log(np.pi) / 2, + np.log(np.pi) / 2 - np.log(2), + np.log(np.pi) / 2 - np.log(4) + np.log(3), + ], + [ + np.log(np.pi) / 2 - np.log(3) + np.log(4), + np.log(np.pi) / 2 - np.log(105) + np.log(16), + np.log(np.pi) / 2 - np.log(10395) + np.log(64), + ], + ], + dtype=dtype, + ), + local_session=session, + ) + + # The actual result is complex. Take the real part. 
+ self.assert_op_output_matches_expected( + math_ops.lgamma, + np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype), + expected=np.array( + [ + np.log(np.pi) / 2 + np.log(2), + np.log(np.pi) / 2 - np.log(15) + np.log(8), + np.log(np.pi) / 2 - np.log(945) + np.log(32), + ], + dtype=dtype, + ), + local_session=session, + atol=1e-4, + ) + + self.assert_op_output_matches_expected( + math_ops.digamma, + np.array( + [ + [1.0, 0.5, 1 / 3.0], + [0.25, 1 / 6.0, 0.125], + [2.0, 3.0, 4.0], + [6.0, 8.0, 9.0], + ], + dtype=dtype, + ), + expected=np.array( + [ + [ + -np.euler_gamma, + -2 * np.log(2) - np.euler_gamma, + -np.pi / 2 / np.sqrt(3) + - 3 * np.log(3) / 2 + - np.euler_gamma, + ], + [ + -np.pi / 2 - 3 * np.log(2) - np.euler_gamma, + -np.pi * np.sqrt(3) / 2 + - 2 * np.log(2) + - 3 * np.log(3) / 2 + - np.euler_gamma, + -np.pi / 2 + - 4 * np.log(2) + - ( + np.pi + + np.log(2 + np.sqrt(2)) + - np.log(2 - np.sqrt(2)) + ) + / np.sqrt(2) + - np.euler_gamma, + ], + [ + 1 - np.euler_gamma, + 1.5 - np.euler_gamma, + 11 / 6.0 - np.euler_gamma, + ], + [ + 137 / 60.0 - np.euler_gamma, + 363 / 140.0 - np.euler_gamma, + 761 / 280.0 - np.euler_gamma, + ], + ], + dtype=dtype, + ), + local_session=session, + ) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index b96608ed392263..fcd3aadbe10c9a 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -110,12 +110,12 @@ namespace { int64_t tf_xla_random_seed = 0; int32_t tf_xla_test_repetitions = 20; int64_t tf_xla_max_tensor_size = 10000LL; -string* tf_xla_test_device_ptr; // initial value set in main() -string* tf_xla_reference_device_ptr; // initial value set in main() +std::string* tf_xla_test_device_ptr; // initial value set in main() +std::string* tf_xla_reference_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; bool tf_xla_test_use_mlir = false; -string LocalDeviceToFullDeviceName(const string& device) { +std::string LocalDeviceToFullDeviceName(const std::string& device) { return absl::StrCat("/job:localhost/replica:0/task:0/device:", device); } @@ -129,7 +129,7 @@ constexpr std::array kAllNumberTypes = { // operator. class OpTestBuilder { public: - explicit OpTestBuilder(const string& op_name); + explicit OpTestBuilder(const std::string& op_name); // Adds an input 'tensor' as a Placeholder node. OpTestBuilder& Input(const Tensor& tensor); @@ -161,10 +161,11 @@ class OpTestBuilder { // sets it to the NodeDef of the operator under test. Fills 'inputs' and // 'outputs' with the names of the input placeholder nodes and the output // identity nodes, respectively. 
- absl::Status BuildGraph(const string& name_prefix, const string& device, - bool use_jit, GraphDef* graphdef, - NodeDef** test_node_def, std::vector* inputs, - std::vector* outputs) const; + absl::Status BuildGraph(const std::string& name_prefix, + const std::string& device, bool use_jit, + GraphDef* graphdef, NodeDef** test_node_def, + std::vector* inputs, + std::vector* outputs) const; struct InputDescription { Tensor tensor; @@ -182,7 +183,7 @@ class OpTestBuilder { std::vector inputs_; }; -OpTestBuilder::OpTestBuilder(const string& op_name) { +OpTestBuilder::OpTestBuilder(const std::string& op_name) { node_def_.set_op(op_name); } @@ -247,12 +248,10 @@ OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, return *this; } -absl::Status OpTestBuilder::BuildGraph(const string& name_prefix, - const string& device, bool use_jit, - GraphDef* graphdef, - NodeDef** test_node_def, - std::vector* inputs, - std::vector* outputs) const { +absl::Status OpTestBuilder::BuildGraph( + const std::string& name_prefix, const std::string& device, bool use_jit, + GraphDef* graphdef, NodeDef** test_node_def, + std::vector* inputs, std::vector* outputs) const { OpRegistryInterface* op_registry = OpRegistry::Global(); const OpDef* op_def; @@ -275,7 +274,7 @@ absl::Status OpTestBuilder::BuildGraph(const string& name_prefix, // Build feed and fetch nodes. for (int i = 0; i < input_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = absl::StrCat(name_prefix, "_input_", i); + std::string name = absl::StrCat(name_prefix, "_input_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder") .Device(device) .Attr("dtype", input_types[i]) @@ -286,7 +285,7 @@ absl::Status OpTestBuilder::BuildGraph(const string& name_prefix, for (int i = 0; i < output_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = absl::StrCat(name_prefix, "_output_", i); + std::string name = absl::StrCat(name_prefix, "_output_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity") .Device(device) .Attr("T", output_types[i]) @@ -494,7 +493,7 @@ class OpTest : public ::testing::Test { const std::vector& spatial_dims); // Converts an int64 vector to an int32 vector. 
- std::vector AsInt32s(const std::vector& int64s); + std::vector AsInt32s(const std::vector& int64s); std::mt19937& generator() { return *generator_; } @@ -664,16 +663,16 @@ class TensorGeneratorComplex64 : public TensorGenerator { } }; -class TensorGeneratorInt32 : public TensorGenerator { +class TensorGeneratorInt32 : public TensorGenerator { public: explicit TensorGeneratorInt32(OpTest& test) : TensorGenerator(test) {} DataType dtype() override { return DT_INT32; } - void RandomVals(std::optional lo, std::optional hi, + void RandomVals(std::optional lo, std::optional hi, bool needs_unique_values, - absl::FixedArray& vals) override { - absl::flat_hash_set already_generated; - std::uniform_int_distribution distribution(lo.value_or(-(1 << 20)), - hi.value_or(1 << 20)); + absl::FixedArray& vals) override { + absl::flat_hash_set already_generated; + std::uniform_int_distribution distribution(lo.value_or(-(1 << 20)), + hi.value_or(1 << 20)); for (int64_t i = 0; i < vals.size(); ++i) { int32_t generated; do { @@ -685,13 +684,13 @@ class TensorGeneratorInt32 : public TensorGenerator { } }; -class TensorGeneratorInt64 : public TensorGenerator { +class TensorGeneratorInt64 : public TensorGenerator { public: explicit TensorGeneratorInt64(OpTest& test) : TensorGenerator(test) {} DataType dtype() override { return DT_INT64; } - void RandomVals(std::optional lo, std::optional hi, + void RandomVals(std::optional lo, std::optional hi, bool needs_unique_values, - absl::FixedArray& vals) override { + absl::FixedArray& vals) override { absl::flat_hash_set already_generated; std::uniform_int_distribution distribution( lo.value_or(-(1LL << 40)), hi.value_or(1LL << 40)); @@ -928,18 +927,19 @@ Tensor OpTest::RandomBoundedTensor(DataType dtype, Tensor lo, Tensor hi) { break; } case DT_INT32: { - auto lo_flat = lo.flat(); - auto hi_flat = hi.flat(); - test::FillFn(&tensor, [this, &lo_flat, &hi_flat](int i) -> int32 { - std::uniform_int_distribution distribution(lo_flat(i), - hi_flat(i)); - return distribution(generator()); - }); + auto lo_flat = lo.flat(); + auto hi_flat = hi.flat(); + test::FillFn( + &tensor, [this, &lo_flat, &hi_flat](int i) -> int32_t { + std::uniform_int_distribution distribution(lo_flat(i), + hi_flat(i)); + return distribution(generator()); + }); break; } case DT_INT64: { - auto lo_flat = lo.flat(); - auto hi_flat = hi.flat(); + auto lo_flat = lo.flat(); + auto hi_flat = hi.flat(); test::FillFn( &tensor, [this, &lo_flat, &hi_flat](int i) -> int64_t { std::uniform_int_distribution distribution(lo_flat(i), @@ -1021,21 +1021,21 @@ OpTest::BroadcastableDims() { Tensor OpTest::RandomReductionIndices(int rank) { std::bernoulli_distribution random_bool; - std::vector indices; + std::vector indices; for (int i = 0; i < rank; ++i) { if (random_bool(generator())) { indices.push_back(i); } } - return test::AsTensor(indices); + return test::AsTensor(indices); } // Helper that converts 'values' to an int32 or int64 Tensor. static Tensor AsIntTensor(DataType dtype, const std::vector& values) { switch (dtype) { case DT_INT32: { - std::vector values32(values.begin(), values.end()); - return test::AsTensor(values32); + std::vector values32(values.begin(), values.end()); + return test::AsTensor(values32); } case DT_INT64: return test::AsTensor(values); @@ -1092,9 +1092,9 @@ OpTest::ConcatArguments OpTest::ChooseConcatArguments(bool int64_idx_allowed) { std::vector dims = RandomDims(1, 4, 0, 64); int axis = - std::uniform_int_distribution(0, dims.size() - 1)(generator()); - a.axis = - use_int64_idx ? 
test::AsScalar(axis) : test::AsScalar(axis); + std::uniform_int_distribution(0, dims.size() - 1)(generator()); + a.axis = use_int64_idx ? test::AsScalar(axis) + : test::AsScalar(axis); for (int i = 0; i < a.n; ++i) { std::vector shape = dims; @@ -1113,7 +1113,7 @@ OpTest::EinsumArguments OpTest::ChooseEinsumArguments() { switch (op_kind) { case matmul: case batchmatmul: { - std::vector dims; + std::vector dims; if (op_kind == matmul) { a.equation = "ij,jk->ik"; dims = RandomDims(2, 2); @@ -1131,7 +1131,7 @@ OpTest::EinsumArguments OpTest::ChooseEinsumArguments() { } case dot: { a.equation = "i,i->"; - std::vector dims = RandomDims(1, 1); + std::vector dims = RandomDims(1, 1); a.lhs_dims = dims; a.rhs_dims = dims; break; @@ -1166,11 +1166,11 @@ OpTest::GatherArguments OpTest::ChooseGatherArguments(bool axis_0) { a.batch_dims, kDefaultMaxRank - 1); axis = axis_distribution(generator()); } - a.axis = test::AsScalar((int32)axis); + a.axis = test::AsScalar((int32_t)axis); a.params_shape = RandomDims(axis + 1, kDefaultMaxRank, 1, 16); std::vector indices_shape = RandomDims(0, 3, 0, 16); - a.indices = RandomBoundedTensor(DT_INT32, 0, a.params_shape[axis] - 1, - false, indices_shape); + a.indices = RandomBoundedTensor( + DT_INT32, 0, a.params_shape[axis] - 1, false, indices_shape); return a; } @@ -1209,7 +1209,7 @@ OpTest::ScatterArguments OpTest::ChooseScatterArguments() { a.indices_type = DT_INT32; a.shape = RandomDims(1, kDefaultMaxRank, 1); int rank = a.shape.size(); - std::uniform_int_distribution index_len_dist(1, rank); + std::uniform_int_distribution index_len_dist(1, rank); int index_len = index_len_dist(generator()); std::vector indices_first = RandomDims(1, kDefaultMaxRank - 1, 1); std::vector indices_shape(indices_first); @@ -1219,9 +1219,9 @@ OpTest::ScatterArguments OpTest::ChooseScatterArguments() { updates_shape.push_back(a.shape[index_len + i]); } Tensor indices_lo(a.indices_type, TensorShape(indices_shape)); - test::FillFn(&indices_lo, [](int i) -> int32 { return 0; }); + test::FillFn(&indices_lo, [](int i) -> int32_t { return 0; }); Tensor indices_hi(a.indices_type, TensorShape(indices_shape)); - test::FillFn(&indices_hi, [index_len, &a](int i) -> int32 { + test::FillFn(&indices_hi, [index_len, &a](int i) -> int32_t { int idx_dim = i % index_len; return a.shape[idx_dim] - 1; }); @@ -1239,16 +1239,16 @@ OpTest::SliceArguments OpTest::ChooseSliceArguments(bool neg_one_size) { a.shape = RandomDims(); int rank = a.shape.size(); - std::vector indices(rank); + std::vector indices(rank); a.size.resize(rank); for (int i = 0; i < rank; ++i) { indices[i] = - std::uniform_int_distribution(0, a.shape[i])(generator()); + std::uniform_int_distribution(0, a.shape[i])(generator()); int64_t low = neg_one_size ? -1 : 0; a.size[i] = std::uniform_int_distribution( low, a.shape[i] - indices[i])(generator()); } - a.indices = test::AsTensor(indices); + a.indices = test::AsTensor(indices); return a; } @@ -1341,8 +1341,8 @@ std::vector OpTest::ImageDims( return dims; } -std::vector OpTest::AsInt32s(const std::vector& int64s) { - return std::vector(int64s.begin(), int64s.end()); +std::vector OpTest::AsInt32s(const std::vector& int64s) { + return std::vector(int64s.begin(), int64s.end()); } // Functions for comparing tensors. 
@@ -1382,11 +1382,11 @@ bool IsClose(const complex64& x, const complex64& y, double atol, } template -string Str(T x) { +std::string Str(T x) { return absl::StrCat(x); } template <> -string Str(complex64 x) { +std::string Str(complex64 x) { return absl::StrCat("(", x.real(), ", ", x.imag(), ")"); } @@ -1460,7 +1460,7 @@ absl::Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, case DT_COMPLEX64: return TensorsAreCloseImpl(a, b, atol, rtol); case DT_INT32: - return TensorsAreEqualImpl(a, b); + return TensorsAreEqualImpl(a, b); case DT_INT64: return TensorsAreEqualImpl(a, b); case DT_BOOL: @@ -1499,9 +1499,10 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( VLOG(1) << "Input: " << input_tensors.back().DebugString(); } - string reference_device = + std::string reference_device = LocalDeviceToFullDeviceName(*tf_xla_reference_device_ptr); - string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); + std::string test_device = + LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; if (!DeviceNameUtils::ParseLocalName(*tf_xla_test_device_ptr, &parsed_name)) { @@ -1512,8 +1513,8 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( ++num_tests_; GraphDef graph; - std::vector expected_inputs, test_inputs; - std::vector expected_fetches, test_fetches; + std::vector expected_inputs, test_inputs; + std::vector expected_fetches, test_fetches; absl::Status status = builder.BuildGraph( absl::StrCat("test", num_tests_, "_expected"), reference_device, /*use_jit=*/false, &graph, /*test_node_def=*/nullptr, &expected_inputs, @@ -1550,8 +1551,9 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( return kFatalError; } - std::vector> expected_feeds(expected_inputs.size()); - std::vector> test_feeds(test_inputs.size()); + std::vector> expected_feeds( + expected_inputs.size()); + std::vector> test_feeds(test_inputs.size()); CHECK_EQ(input_tensors.size(), expected_inputs.size()); CHECK_EQ(input_tensors.size(), test_inputs.size()); @@ -1707,12 +1709,12 @@ TEST_F(OpTest, ArgMax) { auto type = Choose({DT_BOOL, DT_FLOAT}); std::vector dims = RandomDims(1, 5, 1); int num_dims = dims.size(); - int reduce_dim = - std::uniform_int_distribution(-num_dims, num_dims)(generator()); + int reduce_dim = std::uniform_int_distribution( + -num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMax") .RandomInput(type, dims) - .Input(test::AsScalar(reduce_dim)) + .Input(test::AsScalar(reduce_dim)) .Attr("T", type) .Attr("Tidx", DT_INT32) .Attr("output_type", DT_INT32)); @@ -1724,12 +1726,12 @@ TEST_F(OpTest, ArgMin) { auto type = Choose({DT_BOOL, DT_FLOAT}); std::vector dims = RandomDims(1, 5, 1); int num_dims = dims.size(); - int reduce_dim = - std::uniform_int_distribution(-num_dims, num_dims)(generator()); + int reduce_dim = std::uniform_int_distribution( + -num_dims, num_dims)(generator()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ArgMin") .RandomInput(type, dims) - .Input(test::AsScalar(reduce_dim)) + .Input(test::AsScalar(reduce_dim)) .Attr("T", type) .Attr("Tidx", DT_INT32) .Attr("output_type", DT_INT32)); @@ -1786,7 +1788,7 @@ TEST_F(OpTest, AvgPool) { std::uniform_int_distribution(1, dims[2])(generator()); int stride_rows = random_int(generator()), stride_cols = random_int(generator()); - string padding = Choose({"SAME", "VALID"}); + std::string padding = Choose({"SAME", "VALID"}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool") .RandomInput(DT_FLOAT, dims) @@ 
-1817,7 +1819,7 @@ TEST_F(OpTest, AvgPool3D) { int64_t batch = dims[3]; int64_t feature = dims[4]; - string padding = Choose({"SAME", "VALID"}); + std::string padding = Choose({"SAME", "VALID"}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool3D") .RandomInput(DT_FLOAT, @@ -1837,13 +1839,13 @@ TEST_F(OpTest, AvgPoolGrad) { Repeatedly([this]() { int batch = RandomDim(1), features = RandomDim(1); WindowedSpatialDims d = ChooseWindowedSpatialDims(2); - std::vector input_dims = + std::vector input_dims = AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims)); std::vector output_dims = ImageDims(FORMAT_NHWC, batch, features, d.output_dims); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPoolGrad") - .Input(test::AsTensor(input_dims)) + .Input(test::AsTensor(input_dims)) .RandomInput(DT_FLOAT, output_dims) .Attr("T", DT_FLOAT) .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims)) @@ -1859,13 +1861,13 @@ TEST_F(OpTest, AvgPool3DGrad) { Repeatedly([this]() { int batch = RandomDim(1), features = RandomDim(1); WindowedSpatialDims d = ChooseWindowedSpatialDims(3); - std::vector input_dims = + std::vector input_dims = AsInt32s(ImageDims(FORMAT_NHWC, batch, features, d.input_dims)); std::vector output_dims = ImageDims(FORMAT_NHWC, batch, features, d.output_dims); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("AvgPool3DGrad") - .Input(test::AsTensor(input_dims)) + .Input(test::AsTensor(input_dims)) .RandomInput(DT_FLOAT, output_dims) .Attr("T", DT_FLOAT) .Attr("ksize", ImageDims(FORMAT_NHWC, 1, 1, d.kernel_dims)) @@ -1976,8 +1978,8 @@ TEST_F(OpTest, BatchToSpaceND) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("BatchToSpaceND") .RandomInput(type, input_dims) - .Input(test::AsTensor( - std::vector(block_dims.begin(), block_dims.end()))) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) .Input(crops) .Attr("T", type)); }); @@ -2202,15 +2204,15 @@ TEST_F(OpTest, ConcatOffset) { std::vector dims = RandomDims(1); int concat_dim = - std::uniform_int_distribution(0, dims.size() - 1)(generator()); + std::uniform_int_distribution(0, dims.size() - 1)(generator()); OpTestBuilder builder("ConcatOffset"); - builder.Input(test::AsScalar(concat_dim)); + builder.Input(test::AsScalar(concat_dim)); builder.Attr("N", n); for (int i = 0; i < n; ++i) { - std::vector shape(dims.begin(), dims.end()); + std::vector shape(dims.begin(), dims.end()); shape[concat_dim] = RandomDim(); - builder.Input(test::AsTensor(shape)); + builder.Input(test::AsTensor(shape)); } return ExpectTfAndXlaOutputsAreClose(builder); }); @@ -2284,7 +2286,8 @@ TEST_F(OpTest, IFFT3D) { TEST_F(OpTest, RFFT) { Repeatedly([this]() { std::vector dims = RandomDims(1, kDefaultMaxRank, 3); - Tensor fft_shape = test::AsTensor(AsInt32s({dims[dims.size() - 1]})); + Tensor fft_shape = + test::AsTensor(AsInt32s({dims[dims.size() - 1]})); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("RFFT").RandomInput(DT_FLOAT, dims).Input(fft_shape)); }); @@ -2293,7 +2296,7 @@ TEST_F(OpTest, RFFT) { TEST_F(OpTest, RFFT2D) { Repeatedly([this]() { std::vector dims = RandomDims(2, kDefaultMaxRank, 3); - Tensor fft_shape = test::AsTensor( + Tensor fft_shape = test::AsTensor( AsInt32s({dims[dims.size() - 2], dims[dims.size() - 1]})); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("RFFT2D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); @@ -2303,7 +2306,7 @@ TEST_F(OpTest, RFFT2D) { TEST_F(OpTest, RFFT3D) { Repeatedly([this]() { std::vector dims = RandomDims(3, kDefaultMaxRank, 3); - Tensor 
fft_shape = test::AsTensor(AsInt32s( + Tensor fft_shape = test::AsTensor(AsInt32s( {dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]})); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("RFFT3D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); @@ -2315,7 +2318,7 @@ TEST_F(OpTest, IRFFT) { std::vector dims = RandomDims(1, kDefaultMaxRank, 3); int64_t orig_size = dims[dims.size() - 1]; dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; - Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT") .RandomInput(DT_COMPLEX64, dims) .Input(fft_shape)); @@ -2328,7 +2331,7 @@ TEST_F(OpTest, IRFFT2D) { std::vector orig_size = {dims[dims.size() - 2], dims[dims.size() - 1]}; dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; - Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT2D") .RandomInput(DT_COMPLEX64, dims) .Input(fft_shape)); @@ -2341,7 +2344,7 @@ TEST_F(OpTest, IRFFT3D) { std::vector orig_size = { dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]}; dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; - Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT3D") .RandomInput(DT_COMPLEX64, dims) .Input(fft_shape)); @@ -2387,7 +2390,7 @@ TEST_F(OpTest, Conv2DBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); std::vector backprop = ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); - Tensor kernel_shape = test::AsTensor(AsInt32s( + Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out})); DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( @@ -2409,7 +2412,7 @@ TEST_F(OpTest, Conv2DBackpropInput) { int features_in = random_int(generator()); int features_out = random_int(generator()); int32_t batch = RandomDim(); - Tensor in_shape = test::AsTensor( + Tensor in_shape = test::AsTensor( AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); std::vector backprop = ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); @@ -2465,7 +2468,7 @@ TEST_F(OpTest, Conv3DBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); std::vector backprop = ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); - Tensor kernel_shape = test::AsTensor( + Tensor kernel_shape = test::AsTensor( AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out})); DataType type = DT_FLOAT; @@ -2489,7 +2492,7 @@ TEST_F(OpTest, Conv3DBackpropInput) { int features_in = random_int(generator()); int features_out = random_int(generator()); int32_t batch = RandomDim(1); - Tensor in_shape = test::AsTensor( + Tensor in_shape = test::AsTensor( AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); std::vector backprop = ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); @@ -2587,7 +2590,7 @@ TEST_F(OpTest, DepthwiseConv2DNativeBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims); std::vector backprop = ImageDims( FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims); - Tensor kernel_shape = test::AsTensor(AsInt32s( + Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], 
features_in, depth_multiplier})); std::vector strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims); strides[2] = strides[1]; // Current impl only supports equal strides @@ -2612,7 +2615,7 @@ TEST_F(OpTest, DepthwiseConv2DBackpropInput) { int features_in = random_int(generator()); int depth_multiplier = random_int(generator()); int32_t batch = RandomDim(); - Tensor in_shape = test::AsTensor( + Tensor in_shape = test::AsTensor( AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims))); std::vector backprop = ImageDims( FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims); @@ -2717,15 +2720,15 @@ TEST_F(OpTest, DynamicStitch) { // implementation does so require. However, the native TF implementation // leaves undefined values if we don't cover everything, so we can't // really test that case anyway. - std::vector indices(size); + std::vector indices(size); std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), generator()); int pos = 0; for (int i = 0; i < n; ++i) { TensorShape shape(index_dims[i]); - Tensor t = test::AsTensor( - absl::Span(indices).subspan(pos, shape.num_elements()), + Tensor t = test::AsTensor( + absl::Span(indices).subspan(pos, shape.num_elements()), shape); builder.Input(t); pos += t.NumElements(); @@ -2785,8 +2788,8 @@ TEST_F(OpTest, EluGrad) { TEST_F(OpTest, ScatterNd) { Repeatedly([this]() { auto a = ChooseScatterArguments(); - auto shape = test::AsTensor( - std::vector(a.shape.begin(), a.shape.end())); + auto shape = test::AsTensor( + std::vector(a.shape.begin(), a.shape.end())); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ScatterNd") .Input(a.indices) .Input(a.updates) @@ -2859,8 +2862,9 @@ TEST_F(OpTest, ExpandDims) { auto type = Choose(kAllXlaTypes); std::vector in_dims = RandomDims(); Tensor dim(DT_INT32, TensorShape()); - std::uniform_int_distribution d(-1 - in_dims.size(), in_dims.size()); - dim.scalar()() = d(generator()); + std::uniform_int_distribution d(-1 - in_dims.size(), + in_dims.size()); + dim.scalar()() = d(generator()); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ExpandDims") .RandomInput(type, in_dims) .Input(dim) @@ -2872,10 +2876,10 @@ TEST_F(OpTest, Fill) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(); - std::vector shape(dims.begin(), dims.end()); + std::vector shape(dims.begin(), dims.end()); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Fill") - .Input(test::AsTensor(shape)) + .Input(test::AsTensor(shape)) .RandomInput(type, {}) .Attr("T", type)); }); @@ -2953,9 +2957,9 @@ TEST_F(OpTest, GatherNd) { std::vector output_shape(output_outer_shape); output_shape.push_back(index_len); Tensor lo(indices_type, TensorShape(output_shape)); - test::FillFn(&lo, [](int i) -> int32 { return 0; }); + test::FillFn(&lo, [](int i) -> int32_t { return 0; }); Tensor hi(indices_type, TensorShape(output_shape)); - test::FillFn(&hi, [index_len, &params_shape](int i) -> int32 { + test::FillFn(&hi, [index_len, &params_shape](int i) -> int32_t { int idx_dim = i % index_len; return params_shape[idx_dim] - 1; }); @@ -3020,7 +3024,7 @@ TEST_F(OpTest, InplaceUpdate) { x_dims.insert(x_dims.end(), common_dims.begin(), common_dims.end()); std::vector i_shape{v_dims[0]}; Tensor i = - RandomBoundedTensor(DT_INT32, 0, x_dims[0] - 1, true, i_shape); + RandomBoundedTensor(DT_INT32, 0, x_dims[0] - 1, true, i_shape); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("InplaceUpdate") .RandomInput(type, x_dims) .Input(i) @@ -3050,7 +3054,7 @@ TEST_F(OpTest,
InvertPermutation) { // TODO(b/211012712): Once needs_unique_values case is linear instead of // quadratic time, use default Dim max instead of 8. int64_t len = RandomDim(0, 8); - Tensor x = RandomBoundedTensor(DT_INT32, 0, len - 1, true, {len}); + Tensor x = RandomBoundedTensor(DT_INT32, 0, len - 1, true, {len}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("InvertPermutation").Input(x).Attr("T", DT_INT32)); }); @@ -3155,7 +3159,7 @@ TEST_F(OpTest, Lgamma) { TEST_F(OpTest, LinSpace) { Repeatedly([this]() { auto ToScalar = [](DataType type, int x) { - if (type == DT_INT32) return test::AsScalar(x); + if (type == DT_INT32) return test::AsScalar(x); return test::AsScalar(x); }; std::uniform_int_distribution distribution(-50, 50); @@ -3294,11 +3298,11 @@ TEST_F(OpTest, MatrixBandPart) { auto type = Choose(kAllXlaTypes); auto index_type = Choose({DT_INT32, DT_INT64}); auto num_lower = - RandomBoundedTensor(index_type, -2 * kDefaultMaxDimensionSize, - 2 * kDefaultMaxDimensionSize, false, {}); + RandomBoundedTensor(index_type, -2 * kDefaultMaxDimensionSize, + 2 * kDefaultMaxDimensionSize, false, {}); auto num_upper = - RandomBoundedTensor(index_type, -2 * kDefaultMaxDimensionSize, - 2 * kDefaultMaxDimensionSize, false, {}); + RandomBoundedTensor(index_type, -2 * kDefaultMaxDimensionSize, + 2 * kDefaultMaxDimensionSize, false, {}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixBandPart") .RandomInput(type) .Input(num_lower) @@ -3334,12 +3338,12 @@ TEST_F(OpTest, MatrixDiagPartV3) { auto type = Choose(kAllXlaTypes); auto align = Choose( {"LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT"}); - auto k0 = std::uniform_int_distribution( + auto k0 = std::uniform_int_distribution( -2 * kDefaultMaxDimensionSize, 2 * kDefaultMaxDimensionSize)(generator()); - auto k1 = std::uniform_int_distribution( + auto k1 = std::uniform_int_distribution( k0, 2 * kDefaultMaxDimensionSize)(generator()); - auto k = test::AsTensor({k0, k1}); + auto k = test::AsTensor({k0, k1}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPartV3") .RandomInput(type) .Input(k) @@ -3373,10 +3377,10 @@ TEST_F(OpTest, MatrixSetDiagV2) { int64_t max_num_diags = shape[rank - 2] + shape[rank - 1] - 1; int64_t num_diags = std::uniform_int_distribution(2, max_num_diags)(generator()); - int32 k0 = std::uniform_int_distribution( + int32_t k0 = std::uniform_int_distribution( -shape[rank - 2] + 1, shape[rank - 1] - num_diags)(generator()); - int32 k1 = k0 + num_diags - 1; - Tensor k = test::AsTensor({k0, k1}); + int32_t k1 = k0 + num_diags - 1; + Tensor k = test::AsTensor({k0, k1}); int64_t max_diag_len = std::min(shape[rank - 2] + std::min(k1, 0), shape[rank - 1] + std::min(-k0, 0)); std::vector diagonal_shape(shape); @@ -3428,7 +3432,7 @@ TEST_F(OpTest, MaxPool) { int stride_rows = random_int(generator()), stride_cols = random_int(generator()); - string padding = Choose({"SAME", "VALID"}); + std::string padding = Choose({"SAME", "VALID"}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("MaxPool") .RandomInput(DT_FLOAT, dims) @@ -3462,7 +3466,7 @@ TEST_F(OpTest, MaxPool3D) { int64_t batch = dims[3]; int64_t feature = dims[4]; - string padding = Choose({"SAME", "VALID"}); + std::string padding = Choose({"SAME", "VALID"}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("MaxPool3D") .RandomInput(DT_FLOAT, @@ -3589,20 +3593,20 @@ TEST_F(OpTest, OneHot) { int32_t depth = RandomDim(); Tensor indices(DT_INT32, TensorShape(dims)); - std::uniform_int_distribution distribution(-depth * 2, depth * 2); - 
test::FillFn(&indices, [this, &distribution](int i) -> int32 { + std::uniform_int_distribution distribution(-depth * 2, depth * 2); + test::FillFn(&indices, [this, &distribution](int i) -> int32_t { return distribution(generator()); }); - int axis = std::uniform_int_distribution(-num_dims - 5, - num_dims + 5)(generator()); + int axis = std::uniform_int_distribution( + -num_dims - 5, num_dims + 5)(generator()); OpTestBuilder builder("OneHot"); builder.Attr("T", type); builder.Attr("TI", DT_INT32); builder.Attr("axis", axis); builder.Input(indices); - builder.Input(test::AsScalar(depth)); + builder.Input(test::AsScalar(depth)); builder.RandomInput(type, {}); builder.RandomInput(type, {}); return ExpectTfAndXlaOutputsAreClose(builder); @@ -3625,8 +3629,8 @@ TEST_F(OpTest, Pack) { std::vector dims = RandomDims(); int num_dims = dims.size(); - int axis = std::uniform_int_distribution(-num_dims - 1, - num_dims)(generator()); + int axis = std::uniform_int_distribution(-num_dims - 1, + num_dims)(generator()); OpTestBuilder builder("Pack"); builder.Attr("T", type); @@ -3768,7 +3772,7 @@ TEST_F(OpTest, RandomUniform) { TEST_F(OpTest, Range) { Repeatedly([this]() { auto ToScalar = [](DataType type, int x) { - if (type == DT_INT32) return test::AsScalar(x); + if (type == DT_INT32) return test::AsScalar(x); if (type == DT_INT64) return test::AsScalar(x); if (type == DT_FLOAT) return test::AsScalar(x); if (type == DT_DOUBLE) return test::AsScalar(x); @@ -3885,8 +3889,8 @@ TEST_F(OpTest, Reshape) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Reshape") .RandomInput(type, dims_before) - .Input(test::AsTensor( - std::vector(dims_after.begin(), dims_after.end()))) + .Input(test::AsTensor( + std::vector(dims_after.begin(), dims_after.end()))) .Attr("T", type)); }); } @@ -3912,8 +3916,8 @@ TEST_F(OpTest, ResizeBilinear) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ResizeBilinear") .RandomInput(DT_FLOAT, in_dims) - .Input(test::AsTensor( - std::vector(out_dims.begin(), out_dims.end()))) + .Input(test::AsTensor( + std::vector(out_dims.begin(), out_dims.end()))) .Attr("T", DT_FLOAT) .Attr("align_corners", true)); }); @@ -3965,14 +3969,14 @@ TEST_F(OpTest, ReverseSequence) { int batch_size = dims[batch_dim]; int max_seq_len = dims[seq_dim]; - std::vector seq_lens(batch_size); - std::uniform_int_distribution d(0, max_seq_len); + std::vector seq_lens(batch_size); + std::uniform_int_distribution d(0, max_seq_len); absl::c_generate(seq_lens, [&]() { return d(generator()); }); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ReverseSequence") .RandomInput(type, dims) - .Input(test::AsTensor(seq_lens)) + .Input(test::AsTensor(seq_lens)) .Attr("seq_dim", seq_dim) .Attr("batch_dim", batch_dim) .Attr("T", type) @@ -4161,14 +4165,15 @@ TEST_F(OpTest, Size) { TEST_F(OpTest, Slice) { Repeatedly([this]() { SliceArguments a = ChooseSliceArguments(true); - std::vector size; + std::vector size; size.insert(size.end(), a.size.begin(), a.size.end()); - return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Slice") - .RandomInput(a.type, a.shape) - .Input(a.indices) - .Input(test::AsTensor(size)) - .Attr("T", a.type) - .Attr("Index", a.indices_type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Slice") + .RandomInput(a.type, a.shape) + .Input(a.indices) + .Input(test::AsTensor(size)) + .Attr("T", a.type) + .Attr("Index", a.indices_type)); }); } @@ -4302,8 +4307,8 @@ TEST_F(OpTest, SpaceToBatchND) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SpaceToBatchND") .RandomInput(type, input_dims) - 
.Input(test::AsTensor( - std::vector(block_dims.begin(), block_dims.end()))) + .Input(test::AsTensor( + std::vector(block_dims.begin(), block_dims.end()))) .Input(paddings) .Attr("T", type)); }); @@ -4360,16 +4365,16 @@ TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) { int64_t batch_size = dims[0]; int64_t num_classes = dims[1]; - std::vector indices(batch_size); + std::vector indices(batch_size); for (int64_t i = 0; i < batch_size; ++i) { - indices[i] = - std::uniform_int_distribution(0, num_classes - 1)(generator()); + indices[i] = std::uniform_int_distribution( + 0, num_classes - 1)(generator()); } return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SparseSoftmaxCrossEntropyWithLogits") .RandomInput(DT_FLOAT, dims) - .Input(test::AsTensor(indices)) + .Input(test::AsTensor(indices)) .Attr("T", DT_FLOAT) .Attr("Tlabels", DT_INT32)); }); @@ -4383,18 +4388,19 @@ TEST_F(OpTest, Split) { auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(1); std::uniform_int_distribution ud; - int32_t dim = std::uniform_int_distribution( - -static_cast(dims.size()), - static_cast(dims.size()) - 1)(generator()); + int32_t dim = std::uniform_int_distribution( + -static_cast(dims.size()), + static_cast(dims.size()) - 1)(generator()); int n = std::uniform_int_distribution(1, 5)(generator()); // Ensure 'dim' is evenly divisible by 'n'. dims[dim] /= n; dims[dim] *= n; - return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split") - .Input(test::AsScalar(dim)) - .RandomInput(type, dims) - .Attr("T", type) - .Attr("num_split", n)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Split") + .Input(test::AsScalar(dim)) + .RandomInput(type, dims) + .Attr("T", type) + .Attr("num_split", n)); }); } @@ -4405,12 +4411,12 @@ TEST_F(OpTest, SplitV) { Repeatedly([this]() { // NOLINT: due to GTEST_SKIP auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(1, kDefaultMaxRank, 1); - int32_t dim = std::uniform_int_distribution( - -static_cast(dims.size()), - static_cast(dims.size()) - 1)(generator()); + int32_t dim = std::uniform_int_distribution( + -static_cast(dims.size()), + static_cast(dims.size()) - 1)(generator()); int n = std::uniform_int_distribution( 1, std::min(5, static_cast(dims[dim])))(generator()); - std::vector size_splits(n); + std::vector size_splits(n); for (int i = 0; i < n - 1; ++i) { size_splits.push_back(dims[dim] / n); } @@ -4418,8 +4424,8 @@ TEST_F(OpTest, SplitV) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SplitV") .RandomInput(type, dims) - .Input(test::AsTensor(size_splits)) - .Input(test::AsScalar(dim)) + .Input(test::AsTensor(size_splits)) + .Input(test::AsScalar(dim)) .Attr("T", type) .Attr("num_split", n) .Attr("Tlen", DT_INT32)); @@ -4519,12 +4525,12 @@ TEST_F(OpTest, StridedSlice) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); - std::vector begin(data_dims.size()), end(data_dims.size()); - std::vector strides(data_dims.size()); + std::vector begin(data_dims.size()), end(data_dims.size()); + std::vector strides(data_dims.size()); for (int i = 0; i < data_dims.size(); ++i) { - begin[i] = std::uniform_int_distribution( + begin[i] = std::uniform_int_distribution( -2 * data_dims[i], 2 * data_dims[i])(generator()); - end[i] = std::uniform_int_distribution( + end[i] = std::uniform_int_distribution( -2 * data_dims[i], 2 * data_dims[i])(generator()); // TODO(b/31360685): support strides other than 1 or -1 strides[i] = std::bernoulli_distribution()(generator()) ? 
1 : -1; @@ -4547,9 +4553,9 @@ TEST_F(OpTest, StridedSlice) { return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("StridedSlice") .RandomInput(type, data_dims) - .Input(test::AsTensor(begin)) - .Input(test::AsTensor(end)) - .Input(test::AsTensor(strides)) + .Input(test::AsTensor(begin)) + .Input(test::AsTensor(end)) + .Input(test::AsTensor(strides)) .Attr("T", type) .Attr("Index", DT_INT32) .Attr("begin_mask", begin_mask) @@ -4660,14 +4666,14 @@ TEST_F(OpTest, Tile) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(1); - std::vector multiples(t_dims.size()); + std::vector multiples(t_dims.size()); for (int i = 0; i < t_dims.size(); ++i) { multiples[i] = std::uniform_int_distribution(1, 3)(generator()); } return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Tile") .RandomInput(type, t_dims) - .Input(test::AsTensor(multiples)) + .Input(test::AsTensor(multiples)) .Attr("T", type)); }); } @@ -4678,10 +4684,11 @@ TEST_F(OpTest, TopKV2) { Repeatedly([this]() { // NOLINT: due to GTEST_SKIP auto type = Choose({DT_INT32, DT_FLOAT, DT_INT64}); auto shape = RandomDims(1); - int32 k = std::uniform_int_distribution(1, shape[0])(generator()); + int32_t k = + std::uniform_int_distribution(1, shape[0])(generator()); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TopKV2") .RandomInput(type, shape) - .Input(test::AsScalar(k)) + .Input(test::AsScalar(k)) .Attr("sorted", RandomBool()) .Attr("T", type)); }); @@ -4691,13 +4698,14 @@ TEST_F(OpTest, Transpose) { Repeatedly([this]() { auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); - std::vector perm(data_dims.size()); + std::vector perm(data_dims.size()); std::iota(perm.begin(), perm.end(), 0); std::shuffle(perm.begin(), perm.end(), generator()); - return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose") - .RandomInput(type, data_dims) - .Input(test::AsTensor(perm)) - .Attr("T", type)); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Transpose") + .RandomInput(type, data_dims) + .Input(test::AsTensor(perm)) + .Attr("T", type)); }); } @@ -4887,8 +4895,8 @@ TEST_F(OpTest, FusedBatchNormTraining) { } // namespace tensorflow int main(int argc, char** argv) { - tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU:0"); - tensorflow::tf_xla_reference_device_ptr = new tensorflow::string("CPU:0"); + tensorflow::tf_xla_test_device_ptr = new std::string("GPU:0"); + tensorflow::tf_xla_reference_device_ptr = new std::string("CPU:0"); std::vector flag_list = { tensorflow::Flag( "tf_xla_random_seed", &tensorflow::tf_xla_random_seed, @@ -4913,7 +4921,7 @@ int main(int argc, char** argv) { "tf_xla_test_use_mlir", &tensorflow::tf_xla_test_use_mlir, "Use MLIR legalization kernels for the operator under test"), }; - tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + std::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { LOG(ERROR) << "\n" << usage; diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 5b41a8108573ac..938277324f1de6 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -149,8 +149,6 @@ def testSimple3(self): expected = np.array([[0., 0.], [11., 12.], [0., 0.]]) self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2])) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test 
fails on ROCm.") #TODO(rocm): weekly sync 25-05-14 def testVariableRankUpdate(self): self._VariableRankTests(_NumpyUpdate, self._runScatterNd) diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 61b9b2c25f0291..36a1fe43db8109 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -78,8 +78,6 @@ def _unsortedSegmentMax(self, data, indices, num_segments): return self._segmentReduction(math_ops.unsorted_segment_max, data, indices, num_segments) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 25-05-14 def testSegmentSum(self): for dtype in self.numeric_types: self.assertAllClose( @@ -88,8 +86,6 @@ def testSegmentSum(self): np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([0, 0, 2, 3, 3, 3], dtype=np.int32), 4)) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 def testSegmentProd(self): for dtype in self.numeric_types: self.assertAllClose( @@ -98,8 +94,6 @@ def testSegmentProd(self): np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([0, 0, 2, 3, 3, 3], dtype=np.int32), 4)) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 def testSegmentProdNumSegmentsLess(self): for dtype in self.numeric_types: self.assertAllClose( @@ -108,8 +102,6 @@ def testSegmentProdNumSegmentsLess(self): np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([0, 0, 2, 3, 3, 3], dtype=np.int32), 3)) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 def testSegmentProdNumSegmentsMore(self): for dtype in self.numeric_types: self.assertAllClose( @@ -194,8 +186,6 @@ def testUnsortedSegmentSum0DIndices1DData(self): self._unsortedSegmentSum( np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4)) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 25-05-14 def testUnsortedSegmentSum1DIndices1DData(self): for dtype in self.numeric_types: self.assertAllClose( @@ -204,8 +194,6 @@ def testUnsortedSegmentSum1DIndices1DData(self): np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 25-05-14 def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self): for dtype in self.numeric_types: self.assertAllClose( @@ -214,8 +202,6 @@ def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self): np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 25-05-14 def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): for dtype in self.numeric_types: data = np.array( @@ -232,8 +218,6 @@ def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): [50, 51, 52, 53], [0, 1, 2, 3], [0, 0, 0, 0]], dtype=dtype), y) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 25-05-14 def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self): for dtype in self.numeric_types: data = np.array( @@ -249,8 +233,6 @@ 
def testUnsortedSegmentSum1DIndices2DDataNonDisjoint(self): [0, 0, 0, 0]], dtype=dtype), y) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 25-05-14 def testUnsortedSegmentSum2DIndices3DData(self): for dtype in self.numeric_types: data = np.array( @@ -268,8 +250,6 @@ def testUnsortedSegmentSum2DIndices3DData(self): ], [0, 0, 0.], [90, 92, 94], [103, 104, 105], [0, 0, 0]], dtype=dtype), y) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 25-05-14 def testUnsortedSegmentSum1DIndices3DData(self): for dtype in self.numeric_types: data = np.array( @@ -298,8 +278,6 @@ def testUnsortedSegmentSumShapeError(self): math_ops.unsorted_segment_sum, data, indices, num_segments)) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self): """Tests for min, max, and prod ops. diff --git a/tensorflow/compiler/tests/unary_ops_composition_test.cc b/tensorflow/compiler/tests/unary_ops_composition_test.cc index 641af606bb24d1..c27b8070bbb450 100644 --- a/tensorflow/compiler/tests/unary_ops_composition_test.cc +++ b/tensorflow/compiler/tests/unary_ops_composition_test.cc @@ -48,9 +48,9 @@ static bool Initialized = [] { class UnaryOpsCompositionTest : public OpsTestBase { protected: template - void RunComposedOp(const std::vector op_names, T input_scalar_value, - T expected_scalar_value) { - string xla_device_name = + void RunComposedOp(const std::vector op_names, + T input_scalar_value, T expected_scalar_value) { + std::string xla_device_name = tensorflow::IsGoogleCudaEnabled() ? DEVICE_XLA_GPU : DEVICE_XLA_CPU; SetDevice(DeviceType(xla_device_name), std::unique_ptr(DeviceFactory::NewDevice( diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 7e9069dacfbaca..037560a142998d 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -215,8 +215,6 @@ def testCos(self): math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5 ) - @test.disable_with_predicate( - pred=test.is_built_with_rocm, skip_message="Test fails on ROCm.") #TODO(rocm): weekly sync 24-11-05 def testSigmoidNumericalStability(self): for dtype in self.float_types: if dtype != np.float16: diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 20f93d86adfad1..d642418a44c2f5 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -308,7 +308,8 @@ def device_scope(self): yield def assert_op_output_matches_expected( - self, op, inp, expected, equality_test=None, rtol=1e-3, atol=1e-5 + self, op, inp, expected, local_session, + equality_test=None, rtol=1e-3, atol=1e-5 ): """Verifies that 'op' produces 'expected' when fed input 'inp' . @@ -316,25 +317,25 @@ def assert_op_output_matches_expected( op: operator to test inp: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. + local_session: The session to use for the test. equality_test: either None, or a function that tests two numpy arrays for equality. If None, self.assertAllClose is used. rtol: relative tolerance for equality test. atol: absolute tolerance for equality test. 
""" - with self.session() as local_session: - with self.test_scope(): - pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name='a' - ) - output = op(pinp) - result = local_session.run(output, {pinp: inp}) - if equality_test is None: - self.assertEqual(output.dtype, expected.dtype) - self.assertAllCloseAccordingToType( - expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03 - ) - else: - equality_test(result, expected, rtol=rtol, atol=atol) + with self.test_scope(): + pinp = array_ops.placeholder( + dtypes.as_dtype(inp.dtype), inp.shape, name='a' + ) + output = op(pinp) + result = local_session.run(output, {pinp: inp}) + if equality_test is None: + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03 + ) + else: + equality_test(result, expected, rtol=rtol, atol=atol) def test_scope(self): """Deprecated alias of `device_scope`. diff --git a/tensorflow/compiler/tf2tensorrt/common/datavec.h b/tensorflow/compiler/tf2tensorrt/common/datavec.h index eff32f1f521af4..34b419d1d20d62 100644 --- a/tensorflow/compiler/tf2tensorrt/common/datavec.h +++ b/tensorflow/compiler/tf2tensorrt/common/datavec.h @@ -27,7 +27,7 @@ namespace tensorrt { // Input/output data format for OpConverterTest::BuildAndRun(). struct InputOutputData { size_t TotalBytes() const { return tensor.TotalBytes(); } - string name; + std::string name; Tensor tensor; }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc index c8eb3db2e0b9e4..b4c3052953c677 100755 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/einsum.cc @@ -739,16 +739,16 @@ class ReIndexer { // Initializes the index map with existing lowercase labels. ReIndexer(std::string eq) { for (char c : eq) { - if (islower(c)) { + if (absl::ascii_islower(c)) { idx_map_[c] = c; } } } // Finds new character for uppercase character c. char operator()(char c) { - if (!std::isupper(c)) return c; + if (!absl::ascii_isupper(c)) return c; if (idx_map_.count(c) > 0) return idx_map_[c]; - char new_idx = std::tolower(c); + char new_idx = absl::ascii_tolower(c); // If lower(c) is not used in the equation, use it to replace c. 
if (idx_map_.count(new_idx) == 0) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.cc index faedcf3de8c427..000c32df25d253 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.cc @@ -81,9 +81,7 @@ string ProfileStrategyToName(const ProfileStrategy strategy) { } Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy) { - string name_lowercase(name); - std::transform(name.begin(), name.end(), name_lowercase.begin(), - [](unsigned char c) { return std::tolower(c); }); + std::string name_lowercase = absl::AsciiStrToLower(name); if (name_lowercase == "range") { *strategy = ProfileStrategy::kRange; } else if (name_lowercase == "optimal") { diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc index 30aff91a76d3b1..d1bf00a53d1cc3 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc @@ -99,7 +99,7 @@ string TRTEngineCacheResource::DebugString() const { EngineContext* TRTEngineCacheResource::GetEngineContext( const std::vector& input_shapes) { EngineContext* engine_context = nullptr; - int64 min_matched_batch_size = kint64max; + int64 min_matched_batch_size = std::numeric_limits::max(); for (const auto& pair : cache_) { const std::vector& cached_input_shapes = pair.first; // This should not happen, but just for safety. diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 254a1e85c35192..e5545445817ec2 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -138,6 +138,25 @@ cc_library( ], ) +cc_library( + name = "encoded_buffer_allocation_info", + hdrs = ["encoded_buffer_allocation_info.h"], + visibility = [":friends"], + deps = [ + "@local_xla//xla/backends/cpu:buffer_allocation_info", + ], +) + +tf_cc_test( + name = "encoded_buffer_allocation_info_test", + srcs = ["encoded_buffer_allocation_info_test.cc"], + deps = [ + ":encoded_buffer_allocation_info", + "@com_google_googletest//:gtest_main", + "@local_xla//xla/backends/cpu:buffer_allocation_info", + ], +) + cc_library( name = "tf2xla", srcs = ["tf2xla.cc"], @@ -218,6 +237,7 @@ filegroup( name = "xla_compiled_cpu_runtime_hdrs", srcs = [ "allocator.h", + "encoded_buffer_allocation_info.h", "xla_compiled_cpu_function.h", "//tensorflow/core/kernels:xla_cpu_runtime_hdrs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", @@ -355,6 +375,7 @@ cc_library( # "@local_tsl//tsl/platform:context", # "@local_tsl//tsl/platform:cord", # "@local_tsl//tsl/platform:env_time", +# "@local_tsl//tsl/platform:refcount", # "@local_tsl//tsl/platform:ml_dtypes", # "@local_tsl//tsl/platform:logging", # "@local_tsl//tsl/platform:macros", @@ -437,6 +458,7 @@ cc_library( ":allocator", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + ":encoded_buffer_allocation_info", "@local_xla//xla/service:custom_call_status_internal", "@local_xla//xla/backends/cpu/runtime:rng_state_lib", "@local_xla//xla/backends/cpu:alignment", @@ -502,6 +524,7 @@ cc_library( hdrs = ["xla_jit_compiled_cpu_function.h"], visibility = ["//visibility:public"], deps = [ + ":encoded_buffer_allocation_info", ":tf2xla", ":tf2xla_proto_cc", ":xla_compiled_cpu_function", diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 
c7c8702b49b774..d9f6927c09ecd6 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -180,7 +180,7 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_0) { // not need to be a constant. Output reshape = ops::Reshape(root, arg1, add); reshape.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, - std::vector<string>()); + std::vector<std::string>()); Graph graph(OpRegistry::Global()); TF_ASSERT_OK(root.ToGraph(&graph)); @@ -203,7 +203,7 @@ TEST(ConstAnalysisTest, RespectExplicitAttr_1) { // Force const analysis to pretend that the first argument to `add` needs to // be a constant. - std::vector<string> add_constant_inputs; + std::vector<std::string> add_constant_inputs; add_constant_inputs.push_back("x"); add.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, add_constant_inputs); diff --git a/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h b/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h new file mode 100644 index 00000000000000..5981751259967a --- /dev/null +++ b/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h @@ -0,0 +1,99 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_ +#define TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_ + +#include <cstdint> + +#include "xla/backends/cpu/buffer_allocation_info.h" + +namespace xla { +namespace cpu { + +// Encoded version of `BufferAllocationInfo`, which can be used to reconstruct +// the `BufferAllocationInfo` later. It's used in the AOT compiler, to +// represent buffer allocation info as a lightweight struct. +struct EncodedBufferAllocationInfo { + EncodedBufferAllocationInfo(uint64_t packed_kind_and_size, + uint32_t entry_param_number, + uint32_t result_number) + : packed_kind_and_size(packed_kind_and_size), + entry_param_number(entry_param_number), + result_number(result_number) {} + + // Encodes BufferAllocationInfo into the struct that can be used to + // reconstruct the BufferAllocationInfo later using the constructor. We need + // this because we use BufferAllocationInfo in places where using protocol + // buffers would negatively impact binary size. + explicit EncodedBufferAllocationInfo( + const BufferAllocationInfo& buffer_info) { + packed_kind_and_size = Pack(buffer_info.kind(), buffer_info.size()); + entry_param_number = buffer_info.is_entry_parameter() + ? buffer_info.entry_parameter_number() + : -1; + result_number = buffer_info.is_result() ?
buffer_info.result_number() : -1; + } + + explicit operator BufferAllocationInfo() const { + auto kind = UnpackKind(packed_kind_and_size); + auto size = UnpackSize(packed_kind_and_size); + int32_t entry_param_number = static_cast<int32_t>(this->entry_param_number); + int32_t result_number = static_cast<int32_t>(this->result_number); + + switch (kind) { + case BufferAllocationInfo::Kind::kConstant: + return BufferAllocationInfo::Constant(size); + case BufferAllocationInfo::Kind::kTemp: + return BufferAllocationInfo::Temp(size); + case BufferAllocationInfo::Kind::kParameter: + if (entry_param_number >= 0 && result_number >= 0) { + return BufferAllocationInfo::InOutParameter(size, entry_param_number, + result_number); + } + if (entry_param_number >= 0) { + return BufferAllocationInfo::EntryParameter(size, entry_param_number); + } + return BufferAllocationInfo::Result(size, result_number); + case BufferAllocationInfo::Kind::kThreadLocal: + return BufferAllocationInfo::ThreadLocal(size); + } + } + + static uint64_t Pack(BufferAllocationInfo::Kind kind, uint64_t size) { + return (static_cast<uint64_t>(size) << 2) | static_cast<uint64_t>(kind); + } + + static constexpr BufferAllocationInfo::Kind UnpackKind(uint64_t packed) { + return static_cast<BufferAllocationInfo::Kind>((packed << 62) >> 62); + } + + static constexpr uint64_t UnpackSize(uint64_t packed) { return packed >> 2; } + + uint64_t packed_kind_and_size = 0; + uint32_t entry_param_number = -1; + uint32_t result_number = -1; +}; +} // namespace cpu + +// TODO(ezhulenev): This is a temporary hack to keep `tfcompile` code working. +namespace cpu_function_runtime { +using BufferInfo = ::xla::cpu::BufferAllocationInfo; +using EncodedBufferInfo = ::xla::cpu::EncodedBufferAllocationInfo; +} // namespace cpu_function_runtime + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_ diff --git a/third_party/xla/xla/backends/cpu/buffer_allocation_info_test.cc b/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info_test.cc similarity index 88% rename from third_party/xla/xla/backends/cpu/buffer_allocation_info_test.cc rename to tensorflow/compiler/tf2xla/encoded_buffer_allocation_info_test.cc index 3848bb6c4db313..c9fc52100abb33 100644 --- a/third_party/xla/xla/backends/cpu/buffer_allocation_info_test.cc +++ b/tensorflow/compiler/tf2xla/encoded_buffer_allocation_info_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2025 The OpenXLA Authors. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License.
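(The packing scheme in the header above keeps the allocation kind in the two low bits of a `uint64_t` and the byte size in the remaining 62. A self-contained round-trip illustration of that bit layout; the enumerator values here are assumed for the demo, since the real `BufferAllocationInfo::Kind` lives in the XLA header.)

```cpp
#include <cstdint>

// Same bit layout as EncodedBufferAllocationInfo: kind in bits [0, 2),
// size in bits [2, 64). Enumerator values are illustrative assumptions.
enum class Kind : uint64_t { kConstant, kTemp, kParameter, kThreadLocal };

constexpr uint64_t Pack(Kind kind, uint64_t size) {
  return (size << 2) | static_cast<uint64_t>(kind);
}
// Shifting left then right clears the 62 size bits, leaving only the kind.
constexpr Kind UnpackKind(uint64_t packed) {
  return static_cast<Kind>((packed << 62) >> 62);
}
constexpr uint64_t UnpackSize(uint64_t packed) { return packed >> 2; }

static_assert(UnpackKind(Pack(Kind::kParameter, 1024)) == Kind::kParameter);
static_assert(UnpackSize(Pack(Kind::kParameter, 1024)) == 1024);
```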
==============================================================================*/ -#include "xla/backends/cpu/buffer_allocation_info.h" +#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h" #include <gtest/gtest.h> +#include "xla/backends/cpu/buffer_allocation_info.h" namespace xla::cpu { namespace { -TEST(BufferAllocationInfoTest, RoundTrip) { +TEST(EncodedBufferAllocationInfoTest, RoundTrip) { auto round_trip = [](const BufferAllocationInfo& buffer_info) { EncodedBufferAllocationInfo encoded(buffer_info); BufferAllocationInfo round_trip(encoded); diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index ba297127eae117..2adc83512c6617 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -83,11 +83,11 @@ struct ClusterTupleLessThan { }; // TODO(jpienaar): Move to OutputTensor. -string DebugString(const OutputTensor& tensor) { +std::string DebugString(const OutputTensor& tensor) { return absl::StrCat(tensor.node->name(), ":", tensor.index); } -string Branch_Name(BranchType b) { +std::string Branch_Name(BranchType b) { switch (b) { case BranchType::kElseBranch: return "else"; @@ -100,13 +100,13 @@ } -string DebugString(StateMap::CondId cond_state) { +std::string DebugString(StateMap::CondId cond_state) { if (cond_state == nullptr || cond_state->empty()) return "{}"; using value_type = StateMap::CondState::value_type; return absl::StrCat( "{", absl::StrJoin(*cond_state, ", ", - [](string* output, const value_type& pred_branch) { + [](std::string* output, const value_type& pred_branch) { const OutputTensor& pred = pred_branch.first; const BranchType& branch = pred_branch.second; if (branch == BranchType::kNeither) @@ -200,7 +200,7 @@ struct CondArgNode { explicit CondArgNode(Node* src, int src_output) : src(src), src_output(src_output) {} - string ToString() const { + std::string ToString() const { return absl::StrCat("src=", src->name(), ":", src_output, " switches=", NodesToString(switches)); } @@ -212,11 +212,11 @@ }; using CondArgNodes = std::vector<CondArgNode>; -string DebugString(const CondArgNodes& nodes) { +std::string DebugString(const CondArgNodes& nodes) { return absl::StrCat( "[", absl::StrJoin(nodes, ", ", - [](string* output, const CondArgNode& node) { + [](std::string* output, const CondArgNode& node) { absl::StrAppend(output, node.ToString()); }), "]"); @@ -263,20 +263,20 @@ void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } -string StateMap::CondStateToString(const Node* node) const { +std::string StateMap::CondStateToString(const Node* node) const { return CondStateToString(LookupCondId(node)); } -string StateMap::CondStateToString(StateMap::CondId id) const { +std::string StateMap::CondStateToString(StateMap::CondId id) const { return DebugString(id); } -string StateMap::AncestorStateToString(const Node* node) const { +std::string StateMap::AncestorStateToString(const Node* node) const { if (auto id = LookupAncestorId(node)) { return absl::StrCat( "{", absl::StrJoin(*id, ",", - [](string* output, const AncestorNode& ancestor) { + [](std::string* output, const AncestorNode& ancestor) { absl::StrAppend(output, ancestor.output_tensor.node->name(), ":", ancestor.output_tensor.index); @@ -340,7 +340,7 @@ class Conditional { // Internal name of conditional. The name is based on the first merge node // added.
- string name() const; + std::string name() const; // The FunctionalizeCond instance that created this. FunctionalizeCond* parent_; @@ -751,7 +751,7 @@ absl::Status Conditional::BuildIfNode(Graph* graph, VLOG(2) << "Build cond function for " << name(); NodeDebugInfo debug_info((*merges_.begin())->def()); NodeDefBuilder builder(name(), "If", library, &debug_info); - const string branch_name[] = {"else_branch", "then_branch"}; + const std::string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -817,7 +817,7 @@ absl::Status Conditional::BuildIfNode(Graph* graph, builder.Attr("Tcond", DT_BOOL); // Add some internal attributes which need to be propagated. for (absl::string_view attr_name : kAttrsToPropagate) { - string attr_val; + std::string attr_val; if (GetNodeAttr(predicate_.node->def(), attr_name, &attr_val).ok()) { builder.Attr(attr_name, attr_val); } @@ -949,7 +949,7 @@ absl::Status Conditional::BuildAndReplace( return absl::OkStatus(); } -string Conditional::name() const { +std::string Conditional::name() const { CHECK(!merges_.empty()); return absl::StrCat((*merges_.begin())->name(), "_if"); } @@ -958,7 +958,7 @@ absl::Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, int port) { NodeBuilder id_builder(replacee->name(), "Identity"); id_builder.Input(if_node, port); - string outside_compilation; + std::string outside_compilation; if (GetNodeAttr(if_node->def(), kXlaOutsideCompilationAttr, &outside_compilation) .ok()) { @@ -1580,7 +1580,7 @@ absl::Status FunctionalizeCond::FunctionalizeInternal() { return absl::OkStatus(); } -void FunctionalizeCond::DumpGraphWithCondState(const string& name) { +void FunctionalizeCond::DumpGraphWithCondState(const std::string& name) { const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup"; for (Node* n : graph_->nodes()) { diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index e37555b053d7ed..25d773ad50a105 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -136,11 +136,11 @@ class StateMap { BranchType FindBranchOf(CondId id, OutputTensor predicate) const; // Returns textual representation of node's CondState. - string CondStateToString(const Node* node) const; - string CondStateToString(CondId id) const; + std::string CondStateToString(const Node* node) const; + std::string CondStateToString(CondId id) const; // Returns textual representation of node's AncestorState. - string AncestorStateToString(const Node* node) const; + std::string AncestorStateToString(const Node* node) const; // Returns whether the cond state is the dead state. bool IsDead(CondId id) const; @@ -201,7 +201,7 @@ class FunctionalizeCond { absl::Status PropagateUpdatedState(const Node* replacee); // Dump graph with the CondState annotated. - void DumpGraphWithCondState(const string& name); + void DumpGraphWithCondState(const std::string& name); // Adds `switch_id` to the list of Switch node ids. 
void AddSwitchId(int switch_id); diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index 50bd47ad73e77e..edb2a7e0ea1b33 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -48,7 +48,7 @@ class FunctionalizeCondTest : public ::testing::Test { return fc_->state_map_.GetCondId(state); } - string GetString(const StateMap::StateMap::CondId id) { + std::string GetString(const StateMap::StateMap::CondId id) { return fc_->state_map_.CondStateToString(id); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index ac38725269bfd9..22b9b9187ecd7d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -51,8 +51,9 @@ namespace tensorflow { // Maps function name to // - new function name, if the function body was functionalized // - std::nullopt, if not -using FuncMap = std::map>; -using FuncMapIter = std::map>::const_iterator; +using FuncMap = std::map>; +using FuncMapIter = + std::map>::const_iterator; // Returns whether function has been processed before. bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) { @@ -65,8 +66,8 @@ bool FunctionHasBeenModified(FuncMapIter func_iter) { } // Returns a name for the new functionalized version of a function. -string GetNewFunctionName( - const string& func_name, Node* n, +std::string GetNewFunctionName( + const std::string& func_name, Node* n, AssociatedFunctionInfo::AssociatedFunctionType func_type, FunctionLibraryDefinition* fld) { // For SymbolicGradient, `func_name` is always "SymbolicGradient" which @@ -79,14 +80,15 @@ string GetNewFunctionName( } // Returns name to which a modified function has been mapped. -const string& GetMappedFunctionName(FuncMapIter func_iter) { +const std::string& GetMappedFunctionName(FuncMapIter func_iter) { DCHECK(func_iter->second.has_value()); return func_iter->second.value(); } // Updates `func_map` with function given by `canonicalized_name`. -void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name, - const string& new_func_name, bool function_modified) { +void UpdateFunctionMap(FuncMap* func_map, const std::string& canonicalized_name, + const std::string& new_func_name, + bool function_modified) { // If function was modified store its new name, otherwise add empty entry to // record that function has been processed and does not need to be rewritten. (*func_map)[canonicalized_name] = @@ -95,8 +97,9 @@ void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name, // Adds new function def to graph's function library if necessary. absl::Status AddFunctionDefToGraphLibrary( - const string& func_name, const AssociatedFunctionInfo& associated_function, - Graph* graph, FunctionLibraryDefinition* fld) { + const std::string& func_name, + const AssociatedFunctionInfo& associated_function, Graph* graph, + FunctionLibraryDefinition* fld) { const OpRegistrationData* op_reg_data; // We have to be careful with adding the function def since there are three // different `OpRegistryInterface`s involved here: @@ -129,8 +132,8 @@ absl::Status AddFunctionDefToGraphLibrary( // Functionalizes function given by `func_name`. Update `func_map` accordingly. 
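(The `FuncMap` introduced above memoizes functionalization results: the presence of an entry means the function was processed, and a populated `std::optional` means its body was also rewritten under a new name. A small sketch of that contract; the helper names here are mine, not from the patch.)

```cpp
#include <map>
#include <optional>
#include <string>

// Canonicalized function name -> new name if the body was rewritten,
// std::nullopt if it was processed but left unchanged.
using FuncMap = std::map<std::string, std::optional<std::string>>;

bool HasBeenProcessed(const FuncMap& m, const std::string& name) {
  return m.find(name) != m.end();
}

bool HasBeenModified(const FuncMap& m, const std::string& name) {
  auto it = m.find(name);
  return it != m.end() && it->second.has_value();
}

void Update(FuncMap& m, const std::string& name, const std::string& new_name,
            bool modified) {
  m[name] = modified ? std::make_optional(new_name) : std::nullopt;
}
```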
absl::Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, + const std::string& func_name, const std::string& new_func_name, + const protobuf::Map& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter = {}); @@ -165,11 +168,11 @@ absl::Status FunctionalizeControlFlowForNodeAssociatedFunctions( associated_functions.size() == 1); // Process one node-function-pair. - string func_name = associated_function.func_name(); - string canonicalized_name = + std::string func_name = associated_function.func_name(); + std::string canonicalized_name = Canonicalize(func_name, AttrSlice(&associated_function.attrs())); auto func_iter = func_map->find(canonicalized_name); - string new_func_name; + std::string new_func_name; if (FunctionHasBeenProcessed(func_iter, func_map)) { if (FunctionHasBeenModified(func_iter)) { *any_function_modified = true; @@ -202,8 +205,8 @@ absl::Status FunctionalizeControlFlowForNodeAssociatedFunctions( } absl::Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, + const std::string& func_name, const std::string& new_func_name, + const protobuf::Map& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) { *function_modified = false; @@ -341,8 +344,8 @@ absl::Status FunctionalizeControlFlowForXlaPass::Run( // Find XLA compile ops and its corresponding FunctionDef. // TPUCompile op is not in the map because graph rewriting might happen // multiple times, and we want to avoid functionalize it again. - static std::map* kNodeTypeToFunctionAttrMapping = - new std::map{ + static std::map* kNodeTypeToFunctionAttrMapping = + new std::map{ // _TPUReplicate ops are generated by EncapsulateTPUComputationsPass. {"_TPUReplicate", "computation"}, // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. @@ -355,12 +358,12 @@ absl::Status FunctionalizeControlFlowForXlaPass::Run( if (it == kNodeTypeToFunctionAttrMapping->end()) { continue; } - const string func_attr = it->second; + const std::string func_attr = it->second; NameAttrList func; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); VLOG(2) << "Graph has node " << n->type_string() << ". Corresponding function: " << func.name(); - string new_func_name = options.flib_def->UniqueFunctionName( + std::string new_func_name = options.flib_def->UniqueFunctionName( absl::StrCat(func.name(), "_f15n_")); bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 7727853a8c4233..24fe7f5e13e7e0 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -46,7 +46,7 @@ namespace { // Returns the names of the "then" and "else" functions for the If node in a // graph. 
-absl::Status FindIfThenAndElse(const GraphDef& graph, string* op_name, +absl::Status FindIfThenAndElse(const GraphDef& graph, std::string* op_name, NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { if (node.op() == "If") { @@ -97,7 +97,7 @@ INSTANTIATE_TEST_SUITE_P( info) { bool restrict_to_tpu_nodes = std::get<0>(info.param); bool wrap_cond_in_function = std::get<1>(info.param); - string name = + std::string name = absl::StrCat(restrict_to_tpu_nodes ? "with_filter" : "without_filter", wrap_cond_in_function ? "_in_function" : "_in_graph"); return name; @@ -114,7 +114,7 @@ void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { auto identity_t = ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true); - auto seventeen = ops::Const( + auto seventeen = ops::Const( scope.WithOpName("cond").WithControlDependencies(identity_t), 17); auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less); auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true, @@ -122,7 +122,7 @@ void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { auto identity_f = ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false); - auto twenty_three = ops::Const( + auto twenty_three = ops::Const( scope.WithOpName("cond").WithControlDependencies(identity_f), 23); auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); auto add = ops::Add(scope.WithOpName("cond/false/add"), @@ -146,7 +146,7 @@ void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { void ConditionalTestFixture::CheckGraphDef( const GraphDef& graph_def, const FunctionLibraryDefinition& library) { - string op_name; + std::string op_name; NameAttrList then_fn; NameAttrList else_fn; TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); @@ -285,7 +285,7 @@ void ConditionalTestFixture::RunTest() { FunctionLibraryRuntime::Handle handle; // Functionalized function name is the type string of `cond_node`. 
-    string func_name;
+    std::string func_name;
     for (Node* n : graph.nodes()) {
       if (n->name() == "cond_node") {
         func_name = n->type_string();
@@ -341,7 +341,7 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
         ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
                             std::initializer_list<Input>{enter, dummy});
-    auto ten = ops::Const<int32>(
+    auto ten = ops::Const<int32_t>(
         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
         10);
     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
@@ -352,7 +352,7 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
                                     switch_.output_false);
     auto identity =
         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
-    auto one = ops::Const<int32>(
+    auto one = ops::Const<int32_t>(
         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
     auto next_iteration =
@@ -405,7 +405,7 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
   {
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
-    auto ten = ops::Const<int32>(
+    auto ten = ops::Const<int32_t>(
         scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
     auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
@@ -427,7 +427,7 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
     auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
-    auto one = ops::Const<int32>(
+    auto one = ops::Const<int32_t>(
         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
@@ -463,7 +463,8 @@ FunctionDef GetNoinlineFunctionDef() {
 //   return [x + 1]
 // Define the above function, and add it to the given graph. It's used as the
 // while loop body in NoinlineLoopBody test.
-absl::Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
+absl::Status AddNoinlineFunctionToGraph(const std::string& node_name,
+                                        Graph* graph) {
   FunctionDefLibrary fdef_lib;
   *(fdef_lib.add_function()) = GetNoinlineFunctionDef();
   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
@@ -481,7 +482,7 @@ absl::Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
 //   x = array_ops.placeholder(dtypes.int32)
 //   y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
 TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
-  const string& noinline_node_name = "while/increment_fn";
+  const std::string& noinline_node_name = "while/increment_fn";
   Graph graph(OpRegistry::Global());
   {
     Scope scope = Scope::NewRootScope().ExitOnError();
@@ -491,7 +492,7 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
                                        "while/while_context");
     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
                             std::initializer_list<Input>{enter, dummy});
-    auto ten = ops::Const<int32>(
+    auto ten = ops::Const<int32_t>(
         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
         10);
     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
@@ -585,7 +586,7 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
 }
 TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
-  const string& noinline_node_name = "while/increment_fn";
+  const std::string& noinline_node_name = "while/increment_fn";
   Graph graph(OpRegistry::Global());
   {
     Scope scope = Scope::NewRootScope().ExitOnError();
@@ -622,7 +623,7 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
                             std::initializer_list<Input>{enter, dummy});
-    auto ten = ops::Const<int32>(
+    auto ten = ops::Const<int32_t>(
         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
         10);
     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
@@ -631,7 +632,7 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
     auto identity =
         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
-    auto one = ops::Const<int32>(
+    auto one = ops::Const<int32_t>(
         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
     auto next_iteration =
@@ -673,7 +674,7 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
   {
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
-    auto ten = ops::Const<int32>(
+    auto ten = ops::Const<int32_t>(
         scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
     auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
@@ -695,7 +696,7 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
     auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
-    auto one = ops::Const<int32>(
+    auto one = ops::Const<int32_t>(
         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
@@ -739,14 +740,15 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) {
                               std::initializer_list<Input>{enter_y, dummy});
     // Loop condition
-    auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
-                                       .WithControlDependencies(merge_x.output),
-                                   3);
+    auto three =
+        ops::Const<int32_t>(scope.WithOpName("while/cond/three")
+                                .WithControlDependencies(merge_x.output),
+                            3);
     auto cond_add =
         ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
-    auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
-                                     .WithControlDependencies(merge_x.output),
-                                 10);
+    auto ten = ops::Const<int32_t>(scope.WithOpName("while/cond/ten")
+                                       .WithControlDependencies(merge_x.output),
+                                   10);
     auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
@@ -765,10 +767,10 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) {
     auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
                                     switch_y.output_true);
-    auto one = ops::Const<int32>(
+    auto one = ops::Const<int32_t>(
         scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
         1);
-    auto two = ops::Const<int32>(
+    auto two = ops::Const<int32_t>(
         scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
         2);
@@ -825,14 +827,15 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) {
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
-    auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
-                                       .WithControlDependencies(arg0.output),
-                                   3);
+    auto three =
+        ops::Const<int32_t>(scope.WithOpName("while/cond/three")
+                                .WithControlDependencies(arg0.output),
+                            3);
     auto cond_add =
         ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
-    auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
-                                     .WithControlDependencies(arg0.output),
-                                 10);
+    auto ten = ops::Const<int32_t>(scope.WithOpName("while/cond/ten")
+                                       .WithControlDependencies(arg0.output),
+                                   10);
     auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
@@ -859,10 +862,10 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) {
     auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
-    auto one = ops::Const<int32>(
+    auto one = ops::Const<int32_t>(
         scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
         1);
-    auto two = ops::Const<int32>(
+    auto two = ops::Const<int32_t>(
         scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
         2);
@@ -922,7 +925,7 @@ INSTANTIATE_TEST_SUITE_P(
      bool mark_inner_loop_tpu = std::get<1>(info.param);
      bool mark_outer_loop_tpu = std::get<2>(info.param);
-      string node_string;
+      std::string node_string;
      if (mark_inner_loop_tpu && mark_outer_loop_tpu)
        node_string = "both_loops_tpu";
      else if (!mark_inner_loop_tpu && !mark_outer_loop_tpu)
@@ -930,7 +933,7 @@ INSTANTIATE_TEST_SUITE_P(
      else
        node_string = mark_inner_loop_tpu ? "inner_loop_tpu" : "outer_loop_tpu";
-      string name = absl::StrCat(
+      std::string name = absl::StrCat(
          restrict_to_tpu_nodes ? "restricted_" : "unrestricted_", node_string);
      return name;
    });
@@ -961,21 +964,21 @@ void ComplexTestFixture::RunTest() {
     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
-    auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
+    auto three = ops::Const<int32_t>(scope.WithOpName("three"), 3);
     auto y = ops::Add(scope.WithOpName("y"), x, three);
     auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
                                 TensorShape({}));
     // Outer loop
-    auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
+    auto zero = ops::Const<int32_t>(scope.WithOpName("outer/Const"), 0);
     auto enter_i =
         ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
     auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
                               std::initializer_list<Input>{enter_i, dummy});
-    auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
-                                     .WithControlDependencies(merge_i.output),
-                                 10);
+    auto ten = ops::Const<int32_t>(scope.WithOpName("outer/Less/y")
+                                       .WithControlDependencies(merge_i.output),
+                                   10);
     auto less_i =
         ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
     auto outer_loop_cond =
@@ -998,7 +1001,7 @@ void ComplexTestFixture::RunTest() {
                              ops::internal::Enter::Attrs().IsConstant(true));
     // Inner loop
-    auto one_j = ops::Const<int32>(
+    auto one_j = ops::Const<int32_t>(
         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
     auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
                                         one_j, "inner");
@@ -1018,9 +1021,10 @@ void ComplexTestFixture::RunTest() {
     auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
                               std::initializer_list<Input>{enter_k, dummy});
-    auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
-                                      .WithControlDependencies(merge_j.output),
-                                  5);
+    auto five =
+        ops::Const<int32_t>(scope.WithOpName("outer/inner/Five")
+                                .WithControlDependencies(merge_j.output),
+                            5);
     auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"),
                             merge_j.output, five);
     auto loop_cond =
@@ -1047,7 +1051,7 @@ void ComplexTestFixture::RunTest() {
     auto assign = ops::AssignAddVariableOp(
         scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
-    auto one = ops::Const<int32>(
+    auto one = ops::Const<int32_t>(
         scope.WithOpName("outer/inner/One")
             .WithControlDependencies(
                 absl::Span<const Operation>{assign.operation}),
@@ -1061,7 +1065,7 @@ void ComplexTestFixture::RunTest() {
         scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
     // Body and backedge for outer loop.
-    auto one_outer = ops::Const<int32>(
+    auto one_outer = ops::Const<int32_t>(
         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
     auto add_i =
         ops::Add(scope.WithOpName("outer/add")
@@ -1086,9 +1090,10 @@ void ComplexTestFixture::RunTest() {
   }
   // Add '_tpu_replicate' attributes as specified.
   for (Node* n : graph.nodes()) {
-    string name = n->name();
-    bool is_inner_node = name.find("outer/inner/") != string::npos;
-    bool is_outer_node = !is_inner_node && name.find("outer/") != string::npos;
+    std::string name = n->name();
+    bool is_inner_node = name.find("outer/inner/") != std::string::npos;
+    bool is_outer_node =
+        !is_inner_node && name.find("outer/") != std::string::npos;
     if ((is_inner_node && mark_inner_loop_tpu_) ||
         (is_outer_node && mark_outer_loop_tpu_)) {
       n->AddAttr("_tpu_replicate", "cluster");
@@ -1159,13 +1164,13 @@ void ComplexTestFixture::CheckOuterNodesFunctionalized(
   {
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
-    auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
+    auto three = ops::Const<int32_t>(scope.WithOpName("three"), 3);
     auto y = ops::Add(scope.WithOpName("y"), x, three);
     auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
                                 TensorShape({}));
-    auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
+    auto zero = ops::Const<int32_t>(scope.WithOpName("outer/Const"), 0);
     auto while_op =
         ops::While(scope.WithOpName("outer/LoopCond"),
                    std::initializer_list<Input>{zero, y, x, var},
@@ -1184,7 +1189,7 @@ void ComplexTestFixture::CheckOuterNodesFunctionalized(
     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
-    auto ten = ops::Const<int32>(
+    auto ten = ops::Const<int32_t>(
         scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
         10);
     auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
@@ -1220,14 +1225,14 @@ void ComplexTestFixture::CheckOuterNodesFunctionalized(
     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
     auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
-    auto one_j = ops::Const<int32>(
+    auto one_j = ops::Const<int32_t>(
         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
     auto while_op =
         ops::While(scope.WithOpName("outer/inner/LoopCond"),
                    std::initializer_list<Input>{one_j, arg1, arg2, arg3},
                    inner_cond_fn, inner_body_fn);
-    auto one_outer = ops::Const<int32>(
+    auto one_outer = ops::Const<int32_t>(
         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
     auto add_i =
         ops::Add(scope.WithOpName("outer/add")
@@ -1262,7 +1267,7 @@ void ComplexTestFixture::CheckInnerNodesFunctionalized(
     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
-    auto five = ops::Const<int32>(
+    auto five = ops::Const<int32_t>(
         scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
     auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0);
@@ -1299,7 +1304,7 @@ void ComplexTestFixture::CheckInnerNodesFunctionalized(
     auto assign = ops::AssignAddVariableOp(
         scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
-    auto one = ops::Const<int32>(
+    auto one = ops::Const<int32_t>(
         scope.WithOpName("outer/inner/One")
             .WithControlDependencies(
                 absl::Span<const Operation>{assign.operation}),
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
index cf3413154b8baa..d8558e7fb2b5fe 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
@@ -42,7 +42,7 @@ absl::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
 absl::Status ExtractWhileLoopFrames(
     const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
-    std::unordered_map<string, WhileLoopFrame>* frames,
+    std::unordered_map<std::string, WhileLoopFrame>* frames,
     const NodeFilter& node_filter) {
   for (Node* node : graph->op_nodes()) {
     const ControlFlowInfo& cf = cf_info[node->id()];
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
index 970f62daa42af3..90c50f75e36387 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
@@ -47,7 +47,7 @@ struct WhileLoopArg {
 // Information about a loop frame.
 struct WhileLoopFrame {
-  string name;
+  std::string name;
   // Pointer to the parent frame. The root frame has a pointer to itself.
   WhileLoopFrame* parent = nullptr;
@@ -76,7 +76,7 @@ struct WhileLoopFrame {
 // `FunctionalizeControlFlow` for more details about node filters).
 absl::Status ExtractWhileLoopFrames(
     const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
-    std::unordered_map<string, WhileLoopFrame>* frames,
+    std::unordered_map<std::string, WhileLoopFrame>* frames,
     const NodeFilter& node_filter = {});
 // Check that the graph has no cycle containing the given node.
@@ -97,10 +97,10 @@ absl::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index);
 // Returns a textual representation of the names of the nodes in the input.
 template <typename T>
-string NodesToString(const T& nodes) {
+std::string NodesToString(const T& nodes) {
   return absl::StrCat("{",
                       absl::StrJoin(nodes, ",",
-                                    [](string* output, const Node* node) {
+                                    [](std::string* output, const Node* node) {
                                       absl::StrAppend(output, node->name());
                                     }),
                       "}");
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index 2c02379c36cd45..b8183afd59481a 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -438,7 +438,7 @@ absl::Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
   builder.Attr("body", body_name);
   // Add some internal attributes which need to be propagated.
   for (absl::string_view attr_name : kAttrsToPropagate) {
-    string attr_val;
+    std::string attr_val;
     if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
       builder.Attr(attr_name, attr_val);
     }
@@ -513,7 +513,7 @@ absl::Status FunctionalizeWhileLoop(Graph* graph,
   // connected to all source nodes in the graph. Many graphs violate this
   // invariant.
   std::vector<ControlFlowInfo> cf_info;
-  std::vector<string> unreachable_nodes;
+  std::vector<std::string> unreachable_nodes;
   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
   if (!unreachable_nodes.empty()) {
     return errors::InvalidArgument(
@@ -522,7 +522,7 @@ absl::Status FunctionalizeWhileLoop(Graph* graph,
   }
   // Builds Frames, indexed by name.
-  std::unordered_map<string, WhileLoopFrame> frames;
+  std::unordered_map<std::string, WhileLoopFrame> frames;
   TF_RETURN_IF_ERROR(
       ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter));
diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc
index 2759ad8384cd81..b331272a2c9504 100644
--- a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc
+++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc
@@ -42,7 +42,7 @@ limitations under the License.
 namespace tensorflow {
 namespace {
-absl::Status GetTestDevice(Session* session, string* test_device) {
+absl::Status GetTestDevice(Session* session, std::string* test_device) {
   std::vector<DeviceAttributes> devices;
   TF_RETURN_IF_ERROR(session->ListDevices(&devices));
@@ -85,7 +85,7 @@ TEST(FusedBatchnormReserveSpaceTest, Test) {
   std::unique_ptr<Session> session(
       tensorflow::NewSession(tensorflow::SessionOptions{}));
-  string test_device;
+  std::string test_device;
   TF_ASSERT_OK(GetTestDevice(session.get(), &test_device));
   Scope root = tensorflow::Scope::NewRootScope();
@@ -108,8 +108,8 @@ TEST(FusedBatchnormReserveSpaceTest, Test) {
   Output variance =
       Const(root.WithOpName("variance"), Input::Initializer(variance_data));
-  string tf_device = absl::StrCat("/device:", test_device, ":0");
-  string xla_device = absl::StrCat("/device:XLA_", test_device, ":0");
+  std::string tf_device = absl::StrCat("/device:", test_device, ":0");
+  std::string xla_device = absl::StrCat("/device:XLA_", test_device, ":0");
   FusedBatchNorm fused_batch_norm_tf(
       root.WithOpName("fused_batch_norm_tf").WithDevice(tf_device), input,
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index f23c423fbb2632..5f794005b7c7c0 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -292,12 +292,12 @@ absl::Status GraphCompiler::CompileFunctionalNode(Node* n,
     }
   }
   if (add_token_input_output) {
-    std::vector<string> token_input_nodes;
+    std::vector<std::string> token_input_nodes;
     TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
                                    kXlaTokenInputNodesAttrName,
                                    &token_input_nodes));
     std::vector<xla::XlaOp> token_inputs;
-    for (const string& node_name : token_input_nodes) {
+    for (const std::string& node_name : token_input_nodes) {
       auto token_or = compiler->GetNodeToken(node_name);
       TF_RETURN_IF_ERROR(token_or.status());
       token_inputs.push_back(std::move(token_or).value());
diff --git a/tensorflow/compiler/tf2xla/graph_compiler_test.cc b/tensorflow/compiler/tf2xla/graph_compiler_test.cc
index 3010ac7f0b026b..2dcb2ea0b52d45 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler_test.cc
@@ -104,8 +104,8 @@ class GraphCompilerTest : public ::testing::Test {
     core::ScopedUnref context_unref(xla_context);
     xla_context->Ref();
-    auto step_container =
-        std::make_unique<ScopedStepContainer>(0, [this](const string& name) {
+    auto step_container = std::make_unique<ScopedStepContainer>(
+        0, [this](const std::string& name) {
          absl::Status status = this->device_->resource_manager()->Cleanup(name);
        });
diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc
index d1c984e26f390a..116c1e68f66fe6 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc
@@ -44,7 +44,7 @@ const char* const kFetchIdAttr = "_fetch_id";
 const char* const kShapeAttr = "_shape";
 const char* const kDebugNameAttr = "_debug_name";
-typedef std::unordered_map<string, Node*> NodeMap;
+typedef std::unordered_map<std::string, Node*> NodeMap;
 // Each feed id identifies the positional output of some node, which may consist
 // of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
@@ -54,14 +54,14 @@ typedef std::unordered_map<string, Node*> NodeMap;
 absl::Status AddArgNodes(
     Graph* graph, const NodeMap& node_map,
     const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
-    const std::unordered_map<string, string>& feed_remapping,
+    const std::unordered_map<std::string, std::string>& feed_remapping,
     std::unordered_set<const Node*>* arg_nodes) {
   for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
     const tf2xla::Feed& feed = feeds[arg_index];
     // All feeds have been replaced by placeholders.
     const int output_index = 0;
-    const string key = TensorIdToString(feed.id());
+    const std::string key = TensorIdToString(feed.id());
     const auto remap_it = feed_remapping.find(key);
     auto node_it = node_map.find(remap_it->second);
     if (node_it == node_map.end()) {
@@ -149,7 +149,7 @@ absl::Status AddRetvalNodes(
 // execution to know the input and output args for the generated function.
 absl::Status RewriteAndPruneGraph(
     Graph* graph, const tf2xla::Config& config,
-    const std::unordered_map<string, string>& feed_remapping) {
+    const std::unordered_map<std::string, std::string>& feed_remapping) {
   NodeMap node_map;
   for (Node* n : graph->nodes()) {
     node_map[n->name()] = n;
@@ -164,7 +164,7 @@ absl::Status RewriteAndPruneGraph(
   FixupSourceAndSinkEdges(graph);
   VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
   // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
-  std::set<string> missing_feeds, missing_fetches;
+  std::set<std::string> missing_feeds, missing_fetches;
   for (const tf2xla::Feed& feed : config.feed()) {
     missing_feeds.insert(TensorIdToString(feed.id()));
   }
@@ -173,14 +173,14 @@ absl::Status RewriteAndPruneGraph(
   }
   for (const Node* n : graph->op_nodes()) {
     if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
-      string feed_id;
+      std::string feed_id;
       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
       if (missing_feeds.erase(feed_id) == 0) {
         return errors::Aborted(FunctionLibraryDefinition::kArgOp,
                                " node found with unknown feed id: ", feed_id);
       }
     } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) {
-      string fetch_id;
+      std::string fetch_id;
       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
       if (missing_fetches.erase(fetch_id) == 0) {
         return errors::Aborted(FunctionLibraryDefinition::kRetOp,
@@ -277,7 +277,7 @@ absl::Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
   GraphDef first_copy_def = graph_def;
   // Maps from name:port of a feed to the name:port of the placeholder to use.
-  std::unordered_map<string, string> feed_remapping;
+  std::unordered_map<std::string, std::string> feed_remapping;
   TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
                                              &feed_remapping, &first_copy_def));
diff --git a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc
index a6ddbfd3a01fef..74c888d37de784 100644
--- a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc
@@ -94,9 +94,9 @@ class CollectiveReduceV2Op : public XlaOpKernel {
  private:
   DataType dtype_ = DT_INVALID;
-  string merge_op_name_;
-  string final_op_name_;
-  string communication_hint_;
+  std::string merge_op_name_;
+  std::string final_op_name_;
+  std::string communication_hint_;
   CollectiveReduceV2Op(const CollectiveReduceV2Op&) = delete;
   void operator=(const CollectiveReduceV2Op&) = delete;
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 0dd528e3dea173..240a099f075aa2 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -48,7 +48,7 @@ class FusedBatchNormOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
     OP_REQUIRES_OK(
         ctx, ctx->GetAttr("exponential_avg_factor", &exponential_avg_factor_));
-    string data_format_str;
+    std::string data_format_str;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
     OP_REQUIRES(
         ctx, FormatFromString(data_format_str, &data_format_),
@@ -61,7 +61,7 @@ class FusedBatchNormOp : public XlaOpKernel {
                 errors::InvalidArgument(
                     "FusedBatchNormEx supports at most 1 side input."));
     add_side_input_ = (num_side_inputs == 1);
-    string activation_mode;
+    std::string activation_mode;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("activation_mode", &activation_mode));
     OP_REQUIRES(ctx,
                 activation_mode == "Identity" || activation_mode == "Relu",
@@ -249,7 +249,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
   explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_));
-    string data_format_str;
+    std::string data_format_str;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
     OP_REQUIRES(
         ctx, FormatFromString(data_format_str, &data_format_),
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index 7c89720292b0a7..94486a104152ea 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -66,9 +66,11 @@ class BCastArgsOp : public XlaOpKernel {
     Tensor output(val_type, TensorShape({len}));
     for (int64_t i = 0; i < len; ++i) {
       if (val_type == DT_INT32) {
-        output.flat<int32>()(i) = static_cast<int32>(bcast.output_shape()[i]);
+        output.flat<int32_t>()(i) =
+            static_cast<int32_t>(bcast.output_shape()[i]);
       } else {
-        output.flat<int64>()(i) = static_cast<int64>(bcast.output_shape()[i]);
+        output.flat<int64_t>()(i) =
+            static_cast<int64_t>(bcast.output_shape()[i]);
       }
     }
     ctx->SetConstantOutput(0, output);
@@ -129,9 +131,9 @@ class BCastGradArgsOp : public XlaOpKernel {
     Tensor constant(val_type, TensorShape({len}));
     for (int64_t i = 0; i < len; ++i) {
       if (val_type == DT_INT32) {
-        constant.flat<int32>()(i) = static_cast<int32>(v[i]);
+        constant.flat<int32_t>()(i) = static_cast<int32_t>(v[i]);
       } else {
-        constant.flat<int64>()(i) = static_cast<int64>(v[i]);
+        constant.flat<int64_t>()(i) = static_cast<int64_t>(v[i]);
       }
     }
     ctx->SetConstantOutput(idx, constant);
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
index 2bf4ab52c8b59e..bf428711664d76 100644
--- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -28,7 +28,7 @@ namespace {
 class BiasOp : public XlaOpKernel {
  public:
   explicit BiasOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
-    string data_format;
+    std::string data_format;
     if (ctx->GetAttr("data_format", &data_format).ok()) {
       OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                   errors::InvalidArgument("Invalid data format"));
diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
index 510d5225d6f04b..7d323b16d8856e 100644
--- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
@@ -55,7 +55,7 @@ class BucketizeOp : public XlaOpKernel {
                      /*broadcast_dimensions=*/{0}),
         xla::S32);
     xla::XlaOp buckets = xla::Reduce(
-        comparison, /*init_value=*/xla::ConstantR0<int32>(builder, 0),
+        comparison, /*init_value=*/xla::ConstantR0<int32_t>(builder, 0),
        /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder),
        /*dimensions_to_reduce=*/{0});
     context->SetOutput(0, buckets);
diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc
index cead6d10c2a0eb..da40d84e73f063 100644
--- a/tensorflow/compiler/tf2xla/kernels/case_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc
@@ -66,7 +66,7 @@ XlaCaseOp::GetPrunedBranchesAndIndex(XlaOpKernelContext* ctx) {
     return {unpruned_branches_, ctx->Input(0)};
   }
-  int32_t branch_index = branch_index_literal.Get<int32>({});
+  int32_t branch_index = branch_index_literal.Get<int32_t>({});
   if (branch_index < 0 || branch_index >= unpruned_branches_.size()) {
     branch_index = unpruned_branches_.size() - 1;
   }
@@ -187,7 +187,8 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
     // Add any TensorArray gradients touched by the then/else computation to
     // the enclosing graph.
-    for (const string& grad_source : update.tensor_array_gradients_accessed) {
+    for (const std::string& grad_source :
+         update.tensor_array_gradients_accessed) {
       VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
               << grad_source;
       XlaResource* gradient;
@@ -289,7 +290,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
   // Set token input for this "case" op.
   std::vector<xla::XlaOp> token_inputs;
   token_inputs.reserve(token_input_nodes_.size());
-  for (const string& node_name : token_input_nodes_) {
+  for (const std::string& node_name : token_input_nodes_) {
     auto token_or = compiler->GetNodeToken(node_name);
     OP_REQUIRES_OK(ctx, token_or.status());
     token_inputs.push_back(token_or.value());
diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h
index a4c01bea65a04d..6574fb4aac4c5e 100644
--- a/tensorflow/compiler/tf2xla/kernels/case_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/case_op.h
@@ -65,8 +65,8 @@ class XlaCaseOp : public XlaOpKernel {
   DataTypeVector input_types_;
   DataTypeVector output_types_;
   bool has_token_input_output_;
-  std::vector<string> token_input_nodes_;
-  string original_node_name_;
+  std::vector<std::string> token_input_nodes_;
+  std::string original_node_name_;
   // Whether to propagate compile time consts into the cond branches.
   // This is not supported by default now since it may cause HBM memory
   // overheads.
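For context, the kernel-constructor pattern these hunks keep touching looks like the minimal sketch below; the op name ExampleFormatOp is hypothetical, and under this migration only the spelling of the string type changes (the legacy tensorflow::string alias becomes std::string):

    #include <string>

    #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
    #include "tensorflow/core/util/tensor_format.h"

    namespace tensorflow {

    class ExampleFormatOp : public XlaOpKernel {
     public:
      explicit ExampleFormatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
        // Attrs arrive as strings; read into std::string (was: string).
        std::string data_format;
        OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
        // Convert the textual attr into the enum used by the kernel.
        OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
                    errors::InvalidArgument("Invalid data format"));
      }

     private:
      TensorFormat data_format_ = FORMAT_NHWC;
    };

    }  // namespace tensorflow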
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index e8c804791299a7..2c69974d8373dc 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -185,7 +185,7 @@ class StatelessCategoricalOp : public CategoricalOp {
  private:
   DataType dtype_;
-  string device_type_string_;
+  std::string device_type_string_;
   StatelessCategoricalOp(const StatelessCategoricalOp&) = delete;
   void operator=(const StatelessCategoricalOp&) = delete;
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index d2463a9974b1bb..7ab53f7ad89e75 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -38,7 +38,7 @@ template ::value>::type* = nullptr>
 DstT CastTo(int32_t src) {
-  return absl::bit_cast<DstT>(static_cast<uint16>(src));
+  return absl::bit_cast<DstT>(static_cast<uint16_t>(src));
 }
 // Returns scalar constant with the value in the tensor, if the given proto has
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
index 3fe22dcb4441e7..59f72e630c0f75 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -163,8 +163,8 @@ absl::Status CheckConvAttrs(const ConvOpAttrs& attrs) {
 absl::Status ConvBackpropComputeDimensionsV2XlaShapes(
     absl::string_view label, int num_spatial_dims,
     const xla::Shape& input_shape, const xla::Shape& filter_shape,
-    const xla::Shape& out_backprop_shape, absl::Span<const int32> dilations,
-    const std::vector<int32>& strides, Padding padding,
+    const xla::Shape& out_backprop_shape, absl::Span<const int32_t> dilations,
+    const std::vector<int32_t>& strides, Padding padding,
     TensorFormat data_format, ConvBackpropDimensions* dims,
     absl::Span<const int64_t> explicit_paddings) {
   TensorShape input_tensor_shape, filter_tensor_shape,
@@ -203,7 +203,7 @@ absl::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
                        ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
   }
-  string data_format;
+  std::string data_format;
   TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
   if (!FormatFromString(data_format, &attrs.data_format)) {
     return errors::InvalidArgument("Invalid data format: ", data_format);
@@ -231,7 +231,7 @@ absl::StatusOr<ConvNDOpAttrs> ConvNDOpAttrs::Create(OpKernelConstruction* ctx) {
                        ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
   }
-  string data_format_str;
+  std::string data_format_str;
   TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format_str));
   if (!(data_format_str == "CHANNELS_LAST" ||
         data_format_str == "CHANNELS_FIRST")) {
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
index 94e454df205df2..e64cebe3970cd8 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
@@ -54,8 +54,8 @@ struct ConvOpAttrs {
   bool depthwise;
   int num_spatial_dims;
-  std::vector<int32> dilations;
-  std::vector<int32> strides;
+  std::vector<int32_t> dilations;
+  std::vector<int32_t> strides;
   Padding padding;
   std::vector<int64_t> explicit_paddings;
   TensorFormat data_format;
@@ -68,8 +68,8 @@ struct ConvNDOpAttrs {
   int groups;
   int batch_dims;
-  std::vector<int32> dilations;
-  std::vector<int32> strides;
+  std::vector<int32_t> dilations;
+  std::vector<int32_t> strides;
   Padding padding;
   std::vector<int64_t> explicit_paddings;
   TensorFormat data_format;
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index b1da0acd61608f..82fdf8ea577e39 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -92,9 +92,9 @@ class ConvNDOp : public XlaOpKernel {
     ConvOpAttrs forward_attrs;
     forward_attrs.depthwise = false;
     forward_attrs.num_spatial_dims = num_spatial_dims;
-    forward_attrs.dilations = attrs_.dilations.empty()
-                                  ? std::vector<int32>(num_spatial_dims + 2, 1)
-                                  : attrs_.dilations;
+    forward_attrs.dilations =
+        attrs_.dilations.empty() ? std::vector<int32_t>(num_spatial_dims + 2, 1)
+                                 : attrs_.dilations;
     forward_attrs.strides = attrs_.strides;
     forward_attrs.padding = attrs_.padding;
     forward_attrs.explicit_paddings = attrs_.explicit_paddings;
diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
index 226d6248bd00d8..27818415169dbe 100644
--- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
@@ -36,9 +36,9 @@ class DataFormatDimMapOp : public XlaOpKernel {
  public:
   explicit DataFormatDimMapOp(OpKernelConstruction* context)
       : XlaOpKernel(context) {
-    string src_format;
+    std::string src_format;
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
-    string dst_format;
+    std::string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
     OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
                 errors::InvalidArgument(
@@ -69,9 +69,9 @@ class DataFormatDimMapOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* context) override {
     auto builder = context->builder();
     xla::XlaOp dst_indices =
-        xla::ConstantR1(builder, absl::Span<const int32>(dst_idx_));
+        xla::ConstantR1(builder, absl::Span<const int32_t>(dst_idx_));
     const int dims = dst_idx_.size();
-    xla::XlaOp rank = xla::ConstantR0<int32>(builder, dims);
+    xla::XlaOp rank = xla::ConstantR0<int32_t>(builder, dims);
     xla::XlaOp src_indices =
         (xla::ConvertElementType(context->Input(0), xla::S32) + rank) % rank;
     xla::XlaOp output =
@@ -81,7 +81,7 @@ class DataFormatDimMapOp : public XlaOpKernel {
   }
  private:
-  std::vector<int32> dst_idx_;
+  std::vector<int32_t> dst_idx_;
   DataFormatDimMapOp(const DataFormatDimMapOp&) = delete;
   void operator=(const DataFormatDimMapOp&) = delete;
@@ -146,13 +146,13 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
                       input_tensor_shape.DebugString()));
     }
-    string src_format_str = src_format_;
-    string dst_format_str = dst_format_;
+    std::string src_format_str = src_format_;
+    std::string dst_format_str = dst_format_;
     if (input_tensor_shape.dim_size(0) == spatial_dim_count) {
       // If the input is a vector of size spatial_dim_count, treat the elements
       // as spatial dimensions.
       auto keep_only_spatial_dimensions =
-          [spatial_dim_count](string* format_str) -> void {
+          [spatial_dim_count](std::string* format_str) -> void {
         auto new_end =
             std::remove_if(format_str->begin(), format_str->end(),
                            [spatial_dim_count](const char dim) {
@@ -164,7 +164,7 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
       keep_only_spatial_dimensions(&src_format_str);
       keep_only_spatial_dimensions(&dst_format_str);
     }
-    std::vector<int32> dst_indices(dim0);
+    std::vector<int32_t> dst_indices(dim0);
     for (int i = 0; i < dim0; ++i) {
       for (int j = 0; j < dim0; ++j) {
         if (src_format_str[i] == dst_format_str[j]) {
@@ -174,14 +174,14 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
       }
     }
     xla::XlaOp indices =
-        xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
+        xla::ConstantR1(builder, absl::Span<const int32_t>(dst_indices));
     xla::XlaOp output = xla::TorchIndexSelect(ctx->Input(0), indices, 0);
     ctx->SetOutput(0, output);
   }
  private:
-  string src_format_;
-  string dst_format_;
+  std::string src_format_;
+  std::string dst_format_;
   DataFormatVecPermuteOp(const DataFormatVecPermuteOp&) = delete;
   void operator=(const DataFormatVecPermuteOp&) = delete;
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index e8e2babffd529c..7e93ed9c32e126 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -31,7 +31,7 @@ namespace {
 class DepthToSpaceOp : public XlaOpKernel {
  public:
   explicit DepthToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
-    string data_format_str;
+    std::string data_format_str;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                 errors::InvalidArgument("Invalid data format"));
diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc
index d383c7d0ab4aa3..bc03e14556f9cb 100644
--- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc
@@ -42,7 +42,7 @@ float get_fullrange() {
 class DequantizeOp : public XlaOpKernel {
  public:
   explicit DequantizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
-    string mode_string;
+    std::string mode_string;
     int axis;
     bool narrow_range;
diff --git a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc
index 141415bcd0d8c0..a5665baa6e3dc5 100644
--- a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc
@@ -39,11 +39,11 @@ class DeviceIndexOp : public XlaOpKernel {
     // When compiling we are not executing on any physical device, so we return
     // a sentinel value (size of the list of devices).
     ctx->SetOutput(
-        0, xla::ConstantR0<int32>(ctx->builder(), device_names_.size()));
+        0, xla::ConstantR0<int32_t>(ctx->builder(), device_names_.size()));
   }
  private:
-  std::vector<string> device_names_;
+  std::vector<std::string> device_names_;
 };
 REGISTER_XLA_OP(Name("DeviceIndex"), DeviceIndexOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc
index ceeea010ee7858..ae7488ad1e1cbd 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc
@@ -54,8 +54,8 @@ class DynamicPartitionOp : public XlaOpKernel {
   xla::XlaOp CountS32(XlaOpKernelContext* ctx, xla::XlaOp input,
                       int64_t target) {
     xla::XlaOp equal_dim =
-        xla::Compare(input, xla::ConstantR0<int32>(ctx->builder(), target), {},
-                     xla::ComparisonDirection::kEq);
+        xla::Compare(input, xla::ConstantR0<int32_t>(ctx->builder(), target),
+                     {}, xla::ComparisonDirection::kEq);
     xla::XlaOp casted = xla::ConvertElementType(equal_dim, xla::S32);
     return xla::ReduceAll(
         casted, xla::Zero(ctx->builder(), xla::S32),
@@ -178,8 +178,9 @@ class DynamicPartitionOp : public XlaOpKernel {
       } else {
         xla::XlaOp length;
         if (count_diff != 0) {
-          length = xla::Div(partition_length[i],
-                            xla::ConstantR0<int32>(ctx->builder(), count_diff));
+          length =
+              xla::Div(partition_length[i],
+                       xla::ConstantR0<int32_t>(ctx->builder(), count_diff));
         } else {
           length = CountS32(ctx, ctx->Input(1), /*target=*/i);
         }
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index cb7e4f6f96437e..edf9afb5ae14fb 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -145,8 +145,8 @@ class DynamicStitchOp : public XlaOpKernel {
     // Construct the reverse mapping, for each index, of which slice of which
     // input it comes from.
-    std::vector<int32> src_input_vector(number_of_indices);
-    std::vector<int32> src_slice_vector(number_of_indices);
+    std::vector<int32_t> src_input_vector(number_of_indices);
+    std::vector<int32_t> src_slice_vector(number_of_indices);
     std::vector src_index_used(number_of_indices);
     int index_used_count = 0;
     for (int input_num = 0; input_num < indices.size(); input_num++) {
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
index 4a1de78d9371b3..b9ca65cfbd6371 100644
--- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -179,9 +179,9 @@ class ExtractImagePatchesOp : public XlaOpKernel {
   }
  protected:
-  std::vector<int32> ksizes_;
-  std::vector<int32> dilations_;
-  std::vector<int32> strides_;
+  std::vector<int32_t> ksizes_;
+  std::vector<int32_t> dilations_;
+  std::vector<int32_t> strides_;
   Padding padding_;
  private:
diff --git a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc
index b2b1eb3343e698..8075982c766a97 100644
--- a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc
@@ -154,7 +154,7 @@ class FusedConv2DInt8Op : public XlaOpKernel {
     // Un-vectorize NCHW_VECT_C to NCHW.
     TensorFormat orig_data_format = conv_attrs_.data_format;
-    int64 vect_width = -1;
+    int64_t vect_width = -1;
     switch (conv_attrs_.data_format) {
       case FORMAT_NCHW_VECT_C:
         vect_width = conv_input_shape.dimensions(4);
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index 2783951e1b6b0f..e94f74d1fed8ef 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -275,7 +275,7 @@ class GatherOp : public XlaOpKernel {
   // The number of batch dimensions, as passed in the batch_dims attribute.
   // It must be less than or equal to rank(indices).
-  int32 batch_dims_ = 0;
+  int32_t batch_dims_ = 0;
 };
 REGISTER_XLA_OP(Name("Gather"), MlirXlaOpKernel);
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
index 033144e9f308e4..2aec21a6db5888 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc
@@ -28,7 +28,7 @@ namespace {
 class GatherOp : public XlaOpKernel {
  public:
   explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {
-    string dnums_attr;
+    std::string dnums_attr;
     OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
     OP_REQUIRES(
         context, dnums_.ParsePartialFromString(dnums_attr),
@@ -60,7 +60,7 @@ class ScatterOp : public XlaOpKernel {
   explicit ScatterOp(OpKernelConstruction* context) : XlaOpKernel(context) {
     OP_REQUIRES_OK(
         context, context->GetAttr("update_computation", &update_computation_));
-    string dnums_attr;
+    std::string dnums_attr;
     OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
     OP_REQUIRES(
         context, dnums_.ParsePartialFromString(dnums_attr),
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 17db09722ba954..56c86d3d597227 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -84,7 +84,8 @@ static absl::StatusOr<bool> PopulateTensorArrayGradients(
     // Add any TensorArray gradients touched by the then/else computation to
     // the enclosing graph.
-    for (const string& grad_source : update.tensor_array_gradients_accessed) {
+    for (const std::string& grad_source :
+         update.tensor_array_gradients_accessed) {
       VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
               << grad_source;
       XlaResource* gradient;
@@ -318,7 +319,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
     if (has_token_input_output_ && i == num_inputs - 1) {
       // Set token input for this "if" op.
       std::vector<xla::XlaOp> token_inputs;
-      for (const string& node_name : token_input_nodes_) {
+      for (const std::string& node_name : token_input_nodes_) {
         auto token_or = compiler->GetNodeToken(node_name);
         OP_REQUIRES_OK(ctx, token_or.status());
         token_inputs.push_back(token_or.value());
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h
index fc6dd2e08bf41f..c11cfcb08e0b09 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.h
@@ -61,8 +61,8 @@ class XlaIfOp : public XlaOpKernel {
   DataTypeVector output_types_;
   std::vector<PartialTensorShape> output_shapes_;
   bool has_token_input_output_;
-  std::vector<string> token_input_nodes_;
-  string original_node_name_;
+  std::vector<std::string> token_input_nodes_;
+  std::string original_node_name_;
 };
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index a8eb7bbf794268..a2676e095b91b7 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -352,10 +352,11 @@ struct WhileCondFn {
                                     xla::XlaBuilder* cond_builder) const {
     xla::XlaOp row_idx = values[0];
     xla::XlaOp row_in_bounds =
-        xla::Lt(row_idx, xla::ConstantR0<int32>(cond_builder, num_boxes));
+        xla::Lt(row_idx, xla::ConstantR0<int32_t>(cond_builder, num_boxes));
     xla::XlaOp num_outputs_so_far = values[1];
-    xla::XlaOp results_not_full = xla::Lt(
-        num_outputs_so_far, xla::ConstantR0<int32>(cond_builder, output_size));
+    xla::XlaOp results_not_full =
+        xla::Lt(num_outputs_so_far,
+                xla::ConstantR0<int32_t>(cond_builder, output_size));
     return xla::And(row_in_bounds, results_not_full);
   }
 };
@@ -375,7 +376,7 @@ struct SuppressBodyFn {
     auto num_outputs_so_far = values[1];
     auto iou_mask = values[2];
     auto included_iou = values[3];
-    auto zero = xla::ConstantR0<int32>(builder, 0);
+    auto zero = xla::ConstantR0<int32_t>(builder, 0);
     // Determine if current elem is active using a slice.
     // TODO(b/118437727): The only reason we need an explicit vector is because
     // some old GCCs can't deduce the right type for MakeConstSpan, and
@@ -386,7 +387,7 @@ struct SuppressBodyFn {
     active_elem = xla::Reshape(active_elem, {});
     // Increment output count iff current elem is not suppressed.
     num_outputs_so_far = xla::Select(
-        active_elem, num_outputs_so_far + xla::ConstantR0<int32>(builder, 1),
+        active_elem, num_outputs_so_far + xla::ConstantR0<int32_t>(builder, 1),
         num_outputs_so_far);
     // Slice out the row_idx.
auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); @@ -412,7 +413,7 @@ struct SuppressBodyFn { } included_iou = xla::Select(cond, xla::And(included_iou, supp_mask), included_iou); - row_idx = row_idx + xla::ConstantR0(builder, 1); + row_idx = row_idx + xla::ConstantR0(builder, 1); return std::vector{row_idx, num_outputs_so_far, iou_mask, included_iou}; } @@ -456,7 +457,7 @@ class NonMaxSuppressionOp : public XlaOpKernel { errors::InvalidArgument( "scores size ", std::to_string(scores_shape.dim_size(0)), " must equal number of boxes ", std::to_string(num_boxes))); - OP_REQUIRES(context, num_boxes <= kint32max, + OP_REQUIRES(context, num_boxes <= std::numeric_limits::max(), errors::InvalidArgument("XLA compilation requires number of " "boxes to be <= kint32max, got ", num_boxes)); @@ -477,7 +478,7 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES( context, output_size >= 0, errors::InvalidArgument("Need output_size >= 0, got ", output_size)); - OP_REQUIRES(context, output_size <= kint32max, + OP_REQUIRES(context, output_size <= std::numeric_limits::max(), errors::InvalidArgument("Need output_size <= kint32Max, got ", output_size)); const xla::XlaOp score_thresh = context->Input("score_threshold"); @@ -564,8 +565,8 @@ class NonMaxSuppressionOp : public XlaOpKernel { std::vector init_values; init_values.reserve(4); - init_values.push_back(xla::ConstantR0(builder, 0)); // col_idx - init_values.push_back(xla::ConstantR0(builder, 0)); // num_outputs + init_values.push_back(xla::ConstantR0(builder, 0)); // col_idx + init_values.push_back(xla::ConstantR0(builder, 0)); // num_outputs init_values.push_back(iou_thresh_mask); init_values.push_back(included_iou); @@ -595,8 +596,8 @@ class NonMaxSuppressionOp : public XlaOpKernel { // can be suppressed by score threshold. xla::XlaOp ones_included = xla::Select( included, - xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), - xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); + xla::Broadcast(xla::ConstantR0(builder, 1), {num_boxes}), + xla::Broadcast(xla::ConstantR0(builder, 0), {num_boxes})); // num_valid is scalar. Value should be bound by output_size. xla::XlaOp num_valid_total = xla::Reduce( @@ -604,8 +605,8 @@ class NonMaxSuppressionOp : public XlaOpKernel { /*init_value=*/xla::ConstantR0(builder, 0), /*computation=*/CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); - xla::XlaOp num_valid = - xla::Min(num_valid_total, xla::ConstantR0(builder, output_size)); + xla::XlaOp num_valid = xla::Min( + num_valid_total, xla::ConstantR0(builder, output_size)); // Re-index into the original scores input tensor, using a Gather. // Boxes were suppressed in the sorted domain. diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 58811c10744131..9959f8d4e44be6 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -120,8 +120,8 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( const int64_t out_size_factor = align_corners ? 
out_size[i] - 1 : out_size[i]; - int64_t gcd = MathUtil::GCD(static_cast(in_size_factor), - static_cast(out_size_factor)); + int64_t gcd = MathUtil::GCD(static_cast(in_size_factor), + static_cast(out_size_factor)); dims.stride[i] = in_size_factor / gcd; dims.kernel_size[i] = out_size_factor / gcd; } diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc index f357262a39c35b..5b730cc0a9076d 100644 --- a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -96,7 +96,7 @@ class InTopKOp : public XlaOpKernel { xla::CreateScalarAddComputation(xla::S32, xla_builder), {1}); xla::XlaOp result = - xla::And(xla::Lt(num_gt_r1, xla::ConstantR0(xla_builder, k)), + xla::And(xla::Lt(num_gt_r1, xla::ConstantR0(xla_builder, k)), xla::IsFinite(targets_values_r1)); context->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index 718f59e1227dc1..899c0063035b82 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -469,7 +469,7 @@ class TfCallbackDevice : public DeviceBase { set_tensorflow_accelerator_device_info(&accelerator_device_info_); } - const string& name() const override { return name_; } + const std::string& name() const override { return name_; } PerOpGpuDevice* MakeGpuDevice() override { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index dfe8a36005b837..aabbd8d3b0514e 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -60,7 +60,7 @@ class ListDiffOp : public XlaOpKernel { absl::Status status; switch (val_type) { case DT_INT32: - status = ListDiffWithIndexType(context, idx_type); + status = ListDiffWithIndexType(context, idx_type); break; case DT_INT64: status = ListDiffWithIndexType(context, idx_type); @@ -111,7 +111,7 @@ class ListDiffOp : public XlaOpKernel { DataType idx_type) { switch (idx_type) { case DT_INT32: - return ListDiff(context); + return ListDiff(context); case DT_INT64: return ListDiff(context); default: diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index 48e8f976cc67bb..8e7c966bdf35fc 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -57,7 +57,7 @@ static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal, void ReadAlignment(OpKernelConstruction* context, bool* left_align_superdiagonal, bool* left_align_subdiagonal) { - string align; + std::string align; OP_REQUIRES_OK(context, context->GetAttr("align", &align)); *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT"; diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 82dbfb3839312c..215de2bc5067e4 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -78,7 +78,7 @@ class OneHotOp : public XlaOpKernel { } private: - int32 axis_; + int32_t axis_; OneHotOp(const OneHotOp&) = delete; void operator=(const OneHotOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc 
b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 1758451faf469f..15b2b5f9d2ebbb 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -113,7 +113,7 @@ class PadOp : public XlaOpKernel { high_pad_size = xla::Reshape(high_pad_size, {}); high_pad_size = xla::ConvertElementType(high_pad_size, xla::S32); // Low pad has to be static. - xla::XlaOp low_pad_size = xla::ConstantR0( + xla::XlaOp low_pad_size = xla::ConstantR0( ctx->builder(), pad_literal.Get({i, 0})); xla::XlaOp input_size = xla::GetDimensionSize(input, i); xla::XlaOp total_size = low_pad_size + input_size + high_pad_size; @@ -122,7 +122,7 @@ class PadOp : public XlaOpKernel { total_size, xla::ValueInferenceMode::kUpperBound); OP_REQUIRES_OK(ctx, size_upper_bound_status_or.status()); auto size_upper_bound = - size_upper_bound_status_or.value().Get({}); + size_upper_bound_status_or.value().Get({}); OP_REQUIRES( ctx, size_upper_bound.has_value(), errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index aa7c78b8b8f97a..77db609d997614 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -88,8 +88,8 @@ class PoolingOp : public XlaOpKernel { num_spatial_dims_(num_spatial_dims), reduction_type_(reduction_type) { if (ctx->num_inputs() == 1) { - std::vector ksize_int; - std::vector stride_int; + std::vector ksize_int; + std::vector stride_int; OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); OP_REQUIRES(ctx, ksize_int.size() == num_dims(), errors::InvalidArgument("Sliding window ksize field must " @@ -255,15 +255,15 @@ class MaxPoolOp : public PoolingOp { ctx->builder()->GetShape(pooling); OP_REQUIRES_OK(ctx, result_shape.status()); - int64 num_channels = result_shape->dimensions(1); + int64_t num_channels = result_shape->dimensions(1); OP_REQUIRES( ctx, num_channels % *vect_width == 0, errors::FailedPrecondition("Result of NCHW_VECT_C op must have " "channels multiple of ", *vect_width, ", but was ", num_channels)); - absl::InlinedVector new_dims(result_shape->dimensions().begin(), - result_shape->dimensions().end()); + absl::InlinedVector new_dims( + result_shape->dimensions().begin(), result_shape->dimensions().end()); new_dims[1] /= *vect_width; new_dims.insert(new_dims.begin() + 2, *vect_width); pooling = @@ -298,7 +298,7 @@ class AvgPoolOp : public PoolingOp { : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) { - string data_format_str; + std::string data_format_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); @@ -466,7 +466,7 @@ class MaxPool2DGradOp : public MaxPoolGradOp { public: explicit MaxPool2DGradOp(OpKernelConstruction* ctx) : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; + std::string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); @@ -505,7 +505,7 @@ class AvgPoolGradOp : public XlaOpKernel { errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); - string data_format; + std::string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, 
&data_format_), errors::InvalidArgument("Invalid data format")); @@ -561,7 +561,7 @@ class AvgPoolGradOp : public XlaOpKernel { protected: const int num_spatial_dims_; std::vector<int64_t> ksize_; - std::vector<int32> stride_; + std::vector<int32_t> stride_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; }; @@ -677,7 +677,7 @@ class MaxPoolGradGradOp : public XlaOpKernel { auto b = ctx->builder(); - auto sixteen = xla::ConstantR0<int32>(b, 16); + auto sixteen = xla::ConstantR0<int32_t>(b, 16); // in (f32) -> round to 7 mantissa bits (bf16)-> 16-high-bit u32. // // NOTE: Use a ReducePrecision operation instead of a cast to BF16 and back @@ -702,7 +702,7 @@ class MaxPoolGradGradOp : public XlaOpKernel { const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs"); auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs"); - auto sixteen = xla::ConstantR0<int32>(rb.get(), 16); + auto sixteen = xla::ConstantR0<int32_t>(rb.get(), 16); auto lhs_criteria = xla::ShiftLeft(xla::ShiftRightLogical( xla::BitcastConvertType(lhs, xla::S32), sixteen), @@ -749,7 +749,7 @@ class MaxPool2DGradGradOp : public MaxPoolGradGradOp { public: explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx) : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) { - string data_format; + std::string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); @@ -767,7 +767,7 @@ class MaxPool3DGradGradOp : public MaxPoolGradGradOp { public: explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx) : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) { - string data_format; + std::string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index cac9f8a68f234e..961fce9caa7728 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -113,7 +113,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { errors::Internal("Expected 4 inputs to QuantizeAndDequantize")); num_bits = ctx->Input(3); } else { - num_bits = xla::ConstantR0<int32>(b, num_bits_); + num_bits = xla::ConstantR0<int32_t>(b, num_bits_); } const xla::XlaOp zero = XlaHelpers::Zero(b, data_type); @@ -129,17 +129,17 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { xla::XlaOp min_quantized, max_quantized; if (signed_input_) { if (narrow_range_) { - min_quantized = - -Pow(two, ConvertElementType( - num_bits - xla::ConstantR0<int32>(b, 1), xla_type)) + - one; + min_quantized = -Pow(two, ConvertElementType( + num_bits - xla::ConstantR0<int32_t>(b, 1), + xla_type)) + + one; } else { min_quantized = -Pow(two, ConvertElementType( - num_bits - xla::ConstantR0<int32>(b, 1), xla_type)); + num_bits - xla::ConstantR0<int32_t>(b, 1), xla_type)); } max_quantized = - Pow(two, ConvertElementType(num_bits - xla::ConstantR0<int32>(b, 1), + Pow(two, ConvertElementType(num_bits - xla::ConstantR0<int32_t>(b, 1), xla_type)) - one; } else { @@ -222,7 +222,7 @@ class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp { OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 
62 : 63), errors::InvalidArgument("num_bits is out of range: ", num_bits_, " with signed_input_ ", signed_input_)); - string round_mode_string; + std::string round_mode_string; OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string)); OP_REQUIRES( ctx, diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc index 8f2350f26861c4..dea3ecf85af7b8 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc @@ -140,7 +140,7 @@ absl::StatusOr<int64_t> GetAlgId(XlaOpKernelContext* ctx, int alg_input_idx) { if (alg_dtype == DT_INT32) { return alg_literal.Get<int32_t>({}); } else { - return alg_literal.Get<int64>({}); + return alg_literal.Get<int64_t>({}); } } @@ -172,7 +172,7 @@ DataType MaybeConvertBF16ToF32(DataType const& dtype) { } absl::StatusOr<xla::XlaOp> BuildUniformRandoms( - XlaOpKernelContext* ctx, DataType dtype, string device_type_string, + XlaOpKernelContext* ctx, DataType dtype, std::string device_type_string, TensorShape shape, std::function<xla::XlaOp(xla::XlaBuilder*)> lo_fn, std::function<xla::XlaOp(xla::XlaBuilder*)> hi_fn) { @@ -190,7 +190,7 @@ absl::StatusOr<xla::XlaOp> BuildUniformRandoms( absl::StatusOr<xla::XlaOp> BuildUniformRandoms(XlaOpKernelContext* ctx, DataType dtype, - string device_type_string, + std::string device_type_string, xla::Shape xla_shape, xla::XlaOp lo, xla::XlaOp hi) { xla::XlaOp key = ctx->Input(kRandomKeyInputIdx); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h index 11ff44602f1900..5fb7aa4822834c 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -73,7 +73,7 @@ DataType MaybeConvertBF16ToF32(DataType const& dtype); // type, in the given low and high range, where low and high are expressed in // XLA functions. absl::StatusOr<xla::XlaOp> BuildUniformRandoms( - XlaOpKernelContext* ctx, DataType dtype, string device_type_string, + XlaOpKernelContext* ctx, DataType dtype, std::string device_type_string, TensorShape shape, std::function<xla::XlaOp(xla::XlaBuilder*)> lo, std::function<xla::XlaOp(xla::XlaBuilder*)> hi); @@ -82,7 +82,7 @@ absl::StatusOr<xla::XlaOp> BuildUniformRandoms( // ops. absl::StatusOr<xla::XlaOp> BuildUniformRandoms(XlaOpKernelContext* ctx, DataType dtype, - string device_type_string, + std::string device_type_string, xla::Shape xla_shape, xla::XlaOp lo, xla::XlaOp hi); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 6a8a98342c1123..3bfe9e384405b2 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -119,7 +119,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { } } - string desc = ctx->op_kernel().name(); + std::string desc = ctx->op_kernel().name(); xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. 
diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index c54c4613d29e44..a1dd0164e73fc7 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -311,7 +311,7 @@ XlaOp CalculateGradData(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::Pad(grad_data, xla::Zero(ctx->builder(), warp_type), xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); - auto shifting_value = xla::ConstantR1<int32>( + auto shifting_value = xla::ConstantR1<int32_t>( ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); auto shifted_gather_indices = xla::Add(gather_indices, shifting_value, {last_warp_dim}); @@ -384,7 +384,7 @@ XlaOp CalculateGradWarp(XlaOpKernelContext* ctx, XlaOp grad_output, XlaOp ratio, xla::Pad(data, xla::Zero(ctx->builder(), data_type), xla::MakeEdgePaddingConfig({{0, 0}, {1, 1}, {1, 1}, {0, 0}})); - auto shifting_value = xla::ConstantR1<int32>( + auto shifting_value = xla::ConstantR1<int32_t>( ctx->builder(), {/*batch=*/0, /*x(width)=*/1, /*y(height)=*/1}); auto shifted_gather_indices = xla::Add(gather_indices, shifting_value, {last_warp_dim}); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 5cecbf37706283..5c77a4dfe29934 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -134,8 +134,8 @@ class ReverseSequenceOp : public XlaOpKernel { } private: - int32 batch_dim_; - int32 seq_dim_; + int32_t batch_dim_; + int32_t seq_dim_; }; REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp); diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index e1e93d614286a3..32b75c26c70212 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -35,7 +35,7 @@ class SendOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override; private: - string tensor_name_; + std::string tensor_name_; SendOp(const SendOp&) = delete; void operator=(const SendOp&) = delete; @@ -60,7 +60,7 @@ class RecvOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override; private: - string tensor_name_; + std::string tensor_name_; xla::Shape shape_; RecvOp(const RecvOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 108bf3848aae93..d24d1688d188a6 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -104,7 +104,8 @@ class RangeOp : public XlaOpKernel { absl::StatusOr<xla::XlaOp> output; switch (type) { case DT_INT32: - output = CreateRangeTensor<int32>(start, limit, delta, ctx->builder()); + output = + CreateRangeTensor<int32_t>(start, limit, delta, ctx->builder()); break; case DT_INT64: output = diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 7e8889cb2ccee6..07bf81e9d76b58 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -109,7 +109,7 @@ class XlaSetBoundOp : public XlaOpKernel { bound_shape.DebugString())); int64_t bound; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound)); - xla::Literal bound_literal = xla::LiteralUtil::CreateR0<int32>(bound); + xla::Literal bound_literal = 
xla::LiteralUtil::CreateR0<int32_t>(bound); xla::XlaOp result = xla::CustomCall( ctx->builder(), "SetBound", {ctx->Input("input")}, ctx->InputXlaShape("input").value(), "", false, {}, &bound_literal); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index 57825657b205ab..beb38ce9a273ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -33,15 +33,15 @@ absl::Status TensorShapeToConstant(const TensorShape& input_shape, Tensor* shape_constant) { const int dims = input_shape.dims(); if (shape_constant->dtype() == DT_INT32) { - auto vec = shape_constant->vec<int32>(); + auto vec = shape_constant->vec<int32_t>(); for (int i = 0; i < dims; ++i) { int64_t dim_size = input_shape.dim_size(i); - if (!FastBoundsCheck(dim_size, std::numeric_limits<int32>::max())) { + if (!FastBoundsCheck(dim_size, std::numeric_limits<int32_t>::max())) { return errors::InvalidArgument( "Shape with out_type=int32 does not support tensors > int32max", " but dim ", i, " is ", dim_size); } - vec(i) = static_cast<int32>(dim_size); + vec(i) = static_cast<int32_t>(dim_size); } } else { auto vec = shape_constant->vec<int64_t>(); diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc index 74e04e035ef3be..0ee9173cda69e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc @@ -101,8 +101,8 @@ absl::Status GetAndValidateAttributes(OpKernelConstruction* ctx, return absl::OkStatus(); } -std::vector<int64> GetSliceIndices(absl::Span<const int64> num_partitions, -                                   absl::Span<const int64> slice_shape, +std::vector<int64_t> GetSliceIndices(absl::Span<const int64_t> num_partitions, +                                     absl::Span<const int64_t> slice_shape, const int index) { DCHECK_EQ(num_partitions.size(), slice_shape.size()); @@ -213,7 +213,7 @@ class XlaSplitNDBaseOp : public XlaOpKernel { // Calculate paddings necessary for slice instead of padding input and // slicing subsequently to reduce temporary memory allocation. for (int dim = 0; dim < rank; ++dim) { - const int64 dim_size = input_shape.dim_size(dim); + const int64_t dim_size = input_shape.dim_size(dim); if (slice_start_indices[dim] >= dim_size) { // Complete padding. 
slice_start_indices[dim] = dim_size; @@ -387,9 +387,9 @@ class XlaConcatNDBaseOp : public XlaOpKernel { std::vector<xla::XlaOp> update_slice_start_indices; update_slice_start_indices.reserve(rank); - for (int64 start_index : slice_start_indices) { + for (int64_t start_index : slice_start_indices) { update_slice_start_indices.push_back( - xla::ConstantR0<int32>(ctx->builder(), start_index)); + xla::ConstantR0<int32_t>(ctx->builder(), start_index)); } output = xla::DynamicUpdateSlice(output, input_slice, update_slice_start_indices); diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 844a31f97990fc..b0e337cec20c33 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -180,8 +180,8 @@ class SliceOp : public XlaOpKernel { xla::Reshape(xla::Slice(ctx->Input(2), {i}, {i + 1}, {1}), {}); if (constant_size_is_minus_one && size[i] == -1) { // size = input_.dim_size(i) - begin[i] - dynamic_size = xla::ConstantR0<int32>(ctx->builder(), - input_shape.dim_size(i)) - + dynamic_size = xla::ConstantR0<int32_t>(ctx->builder(), + input_shape.dim_size(i)) - begin_indices[i]; } auto constant_size = ctx->value_inference().AnalyzeConstant( @@ -192,7 +192,7 @@ class SliceOp : public XlaOpKernel { // triggered when some dimensions's slice sizes are constant while // some are dynamic. sliced = xla::SliceInDim( - sliced, 0, constant_size->Get<int32>({}).value(), 1, i); + sliced, 0, constant_size->Get<int32_t>({}).value(), 1, i); } else { // We gave a generous bound (same as input) to the output, try reset // the bound if a tighter one can be found. diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index ac33e0877200dc..180ba322f0fdd0 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -34,7 +34,7 @@ namespace { class SpaceToDepthOp : public XlaOpKernel { public: explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format_str; + std::string data_format_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); diff --git a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc index 124e36557f1429..f6d468131ac94e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc @@ -69,8 +69,8 @@ class XlaSpmdFullToShardShapeOp : public XlaOpKernel { } private: - string manual_sharding_str_; - int32 single_dim_; + std::string manual_sharding_str_; + int32_t single_dim_; std::vector<int64_t> unspecified_dims_; XlaSpmdFullToShardShapeOp(const XlaSpmdFullToShardShapeOp&) = delete; void operator=(const XlaSpmdFullToShardShapeOp&) = delete; @@ -120,8 +120,8 @@ class XlaSpmdShardToFullShapeOp : public XlaOpKernel { private: TensorShape full_shape_; - string manual_sharding_str_; - int32 single_dim_; + std::string manual_sharding_str_; + int32_t single_dim_; std::vector<int64_t> unspecified_dims_; XlaSpmdShardToFullShapeOp(const XlaSpmdShardToFullShapeOp&) = delete; void operator=(const XlaSpmdShardToFullShapeOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 3c99ad63565266..4672477be3534b 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -120,7 +120,7 @@ class StackOp : public XlaOpKernel { private: DataType dtype_; - string stack_name_; + std::string stack_name_; StackOp(const StackOp&) = delete; void operator=(const StackOp&) = delete; @@ -152,7 +152,7 @@ class StackPushOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. std::vector<xla::XlaOp> start_indices(elem_shape.dims() + 1, - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = index; TensorShape slice_shape = elem_shape; @@ -164,7 +164,7 @@ class StackPushOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple( b, {xla::DynamicUpdateSlice(ta, update, start_indices), - xla::Add(index, xla::ConstantR0<int32>(b, 1))}))); + xla::Add(index, xla::ConstantR0<int32_t>(b, 1))}))); ctx->SetOutput(0, value); } @@ -204,12 +204,12 @@ class StackPopOp : public XlaOpKernel { xla::XlaOp ta = xla::GetTupleElement(state, 0); xla::XlaOp index = xla::GetTupleElement(state, 1); - index = Sub(index, xla::ConstantR0<int32>(b, 1)); + index = Sub(index, xla::ConstantR0<int32_t>(b, 1)); OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. std::vector<xla::XlaOp> start_indices(stack_shape.dims(), - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = index; auto slice_shape = stack_shape.dim_sizes(); diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index e7ff8194b96ce8..80047c5f17cc98 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -511,7 +511,7 @@ class RngSkipOp : public XlaOpKernel { REGISTER_XLA_OP(Name("RngSkip").CompileTimeConstantInput("algorithm"), RngSkipOp<>); -using RngReadAndSkipOp = RngSkipOp<int32, true>; +using RngReadAndSkipOp = RngSkipOp<int32_t, true>; REGISTER_XLA_OP(Name("RngReadAndSkip").CompileTimeConstantInput("alg"), RngReadAndSkipOp); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index aa71c5c34d2e1a..246981c3465ef1 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -76,7 +76,7 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { // `BitcastConvertType(ConvertElementType(u32, U16), BF16)`, to avoid the // unclear `ConvertElementType(f32, BF16)` behavior. 
xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) & - xla::ConstantR0<uint32>(builder, 0xFFFF0000); + xla::ConstantR0<uint32_t>(builder, 0xFFFF0000); return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32), xla::BF16); } else { @@ -184,7 +184,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformOp(const StatelessRandomUniformOp&) = delete; void operator=(const StatelessRandomUniformOp&) = delete; @@ -240,7 +240,7 @@ class StatelessRandomUniformIntOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformIntOp(const StatelessRandomUniformIntOp&) = delete; void operator=(const StatelessRandomUniformIntOp&) = delete; @@ -283,7 +283,7 @@ class StatelessRandomUniformFullIntOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformFullIntOp(const StatelessRandomUniformFullIntOp&) = delete; @@ -336,7 +336,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomNormalOp(const StatelessRandomNormalOp&) = delete; void operator=(const StatelessRandomNormalOp&) = delete; @@ -384,7 +384,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessTruncatedNormalOp(const StatelessTruncatedNormalOp&) = delete; void operator=(const StatelessTruncatedNormalOp&) = delete; @@ -449,7 +449,7 @@ class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessParameterizedTruncatedNormalOp( const StatelessParameterizedTruncatedNormalOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc index ce1fee91ae6a51..689e6ca3f7bf41 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -128,7 +128,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformOp(const StatelessRandomUniformOp&) = delete; void operator=(const StatelessRandomUniformOp&) = delete; @@ -177,7 +177,7 @@ class StatelessRandomUniformIntOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformIntOp(const StatelessRandomUniformIntOp&) = delete; void operator=(const StatelessRandomUniformIntOp&) = delete; @@ -225,7 +225,7 @@ class StatelessRandomUniformFullIntOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomUniformFullIntOp(const StatelessRandomUniformFullIntOp&) = delete; @@ -295,7 +295,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string device_type_string_; StatelessRandomNormalOp(const StatelessRandomNormalOp&) = delete; void operator=(const StatelessRandomNormalOp&) = delete; @@ -330,7 +330,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { private: DataType dtype_; - string device_type_string_; + std::string 
device_type_string_; StatelessTruncatedNormalOp(const StatelessTruncatedNormalOp&) = delete; void operator=(const StatelessTruncatedNormalOp&) = delete; @@ -369,7 +369,7 @@ class GetKeyCounterOp : public XlaOpKernel { } private: - string device_type_string_; + std::string device_type_string_; GetKeyCounterOp(const GetKeyCounterOp&) = delete; void operator=(const GetKeyCounterOp&) = delete; @@ -392,7 +392,7 @@ class GetAlgOp : public XlaOpKernel { } private: - string device_type_string_; + std::string device_type_string_; GetAlgOp(const GetAlgOp&) = delete; void operator=(const GetAlgOp&) = delete; @@ -430,7 +430,7 @@ class GetKeyCounterAlgOp : public XlaOpKernel { } private: - string device_type_string_; + std::string device_type_string_; GetKeyCounterAlgOp(const GetKeyCounterAlgOp&) = delete; void operator=(const GetKeyCounterAlgOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index e15196bd756462..1b44d1e07c4bd8 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -142,7 +142,7 @@ class StridedSliceOp : public XlaOpKernel { // Pad input to 2x to avoid OOB access. slice = xla::Pad(slice, xla::Zero(ctx->builder(), ctx->input_xla_type(0)), padding_config); - for (int64 i = 0; i < result_dims_are_dynamic.size(); ++i) { + for (int64_t i = 0; i < result_dims_are_dynamic.size(); ++i) { if (result_dims_are_dynamic[i]) { slice = xla::RemoveDynamicDimension(slice, i); } @@ -178,7 +178,7 @@ class StridedSliceOp : public XlaOpKernel { // Can't infer a lower bound. return false; } - return lower_bound->Get<int32>({}) >= 0; + return lower_bound->Get<int32_t>({}) >= 0; }; if (begin_mask) { begin_index = zero; @@ -220,7 +220,7 @@ class StridedSliceOp : public XlaOpKernel { // size 1 dims of a shape. slice = xla::Reshape(slice, final_shape.dim_sizes()); for (int64_t i = 0; i < final_shape.dims(); ++i) { - int64 processing_shape_dim = shape_spec.output_to_processing_mapping[i]; + int64_t processing_shape_dim = shape_spec.output_to_processing_mapping[i]; // If processing_shape_dim is -1, it means the output dimension was newly // added by new_axis_mask_, which doesn't show up in input. if (processing_shape_dim != -1) { @@ -341,9 +341,9 @@ class StridedSliceOp : public XlaOpKernel { int64_t sparse_index = shape_spec.output_to_sparse_mapping[i]; bool end_is_dynamic = sparse_index == -1 ? false : ends_are_dynamic[sparse_index]; - bool backward_slice = sparse_index == -1 - ? false - : end_literal.Get<int32>({sparse_index}) < 0; + bool backward_slice = + sparse_index == -1 ? false + : end_literal.Get<int32_t>({sparse_index}) < 0; if (input_is_dynamic || end_is_dynamic) { OP_REQUIRES( ctx, strides[input_index] == 1, @@ -363,8 +363,8 @@ class StridedSliceOp : public XlaOpKernel { "sized slice with dynamic negative index %lld. ")); operand_size = xla::Add( operand_size, - xla::ConstantR0<int32>(ctx->builder(), - end_literal.Get<int32>({sparse_index}))); + xla::ConstantR0<int32_t>( + ctx->builder(), end_literal.Get<int32_t>({sparse_index}))); } else { // The end of slice with dynamic slice size is the min of operand // shape and slice size. 
E.g., t[:end_size], result size is @@ -376,13 +376,13 @@ class StridedSliceOp : public XlaOpKernel { {}); } else { end_size = - xla::ConstantR0<int32>(ctx->builder(), end[input_index]); + xla::ConstantR0<int32_t>(ctx->builder(), end[input_index]); } operand_size = xla::Min(operand_size, end_size); } slice = xla::SetDimensionSize( slice, - xla::Sub(operand_size, xla::ConstantR0<int32>( + xla::Sub(operand_size, xla::ConstantR0<int32_t>( ctx->builder(), begin[input_index])), i); } @@ -397,8 +397,8 @@ class StridedSliceOp : public XlaOpKernel { } private: - int32 begin_mask_, end_mask_; - int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + int32_t begin_mask_, end_mask_; + int32_t ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; }; @@ -634,8 +634,8 @@ class StridedSliceGradOp : public XlaOpKernel { } private: - int32 begin_mask_, end_mask_; - int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + int32_t begin_mask_, end_mask_; + int32_t ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; }; @@ -751,8 +751,8 @@ class StridedSliceAssignOp : public XlaOpKernel { } private: - int32 begin_mask_, end_mask_; - int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; + int32_t begin_mask_, end_mask_; + int32_t ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; DataType dtype_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 888908e30b2331..e89c3e3b4f837b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -94,7 +94,7 @@ absl::Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, // Checks that the TensorArray 'resource' has been initialized, and has type // 'dtype'. Sets 'shape' to the shape -absl::Status CheckTensorArrayIsInitialized(const string& op_name, +absl::Status CheckTensorArrayIsInitialized(const std::string& op_name, const XlaResource* resource, DataType dtype) { if (resource->kind() != XlaResource::kTensorArray) { @@ -184,7 +184,7 @@ class TensorArrayOp : public XlaOpKernel { private: PartialTensorShape element_shape_; DataType dtype_; - string tensor_array_name_; + std::string tensor_array_name_; TensorArrayOp(const TensorArrayOp&) = delete; void operator=(const TensorArrayOp&) = delete; @@ -218,7 +218,7 @@ class TensorArrayWriteOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. std::vector<xla::XlaOp> start_indices(elem_shape.dims() + 1, - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = index; TensorShape slice_shape = elem_shape; @@ -270,7 +270,7 @@ class TensorArrayReadOp : public XlaOpKernel { // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. std::vector<xla::XlaOp> start_indices(ta_shape.dims(), - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = index; auto slice_shape = ta_shape.dim_sizes(); @@ -430,7 +430,7 @@ class TensorArrayScatterOp : public XlaOpKernel { // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
auto index = xla::Reshape(xla::Slice(indices, {i}, {i + 1}, {1}), {}); std::vector<xla::XlaOp> start_indices(elem_shape.dims() + 1, - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = index; ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices, dtype_); } @@ -570,7 +570,8 @@ class TensorArraySizeOp : public XlaOpKernel { XlaResource* var; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); - size_tensor.scalar<int32>()() = static_cast<int32>(var->max_array_size()); + size_tensor.scalar<int32_t>()() = + static_cast<int32_t>(var->max_array_size()); ctx->SetConstantOutput(0, size_tensor); } @@ -609,7 +610,7 @@ class TensorArrayGradOp : public XlaOpKernel { } private: - string source_; + std::string source_; TensorArrayGradOp(const TensorArrayGradOp&) = delete; void operator=(const TensorArrayGradOp&) = delete; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index a1f58d5ae9b40e..f128c96c570e6c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -70,7 +70,7 @@ absl::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims( dynamic_dims.push_back(ctx->Input(1)); } else { dynamic_dims.push_back( - xla::ConstantR0<int32>(ctx->builder(), num_elements)); + xla::ConstantR0<int32_t>(ctx->builder(), num_elements)); } for (int64_t dim = 0; dim < element_shape.dimensions().size(); ++dim) { if (dims_are_dynamic[dim]) { @@ -80,7 +80,7 @@ absl::StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims( dynamic_dims.push_back(dynamic_dim_size); } else { dynamic_dims.push_back( - xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim])); + xla::ConstantR0<int32_t>(ctx->builder(), dynamic_sizes[dim])); } } list_dynamic_dims.push_back(std::move(dynamic_dims)); @@ -191,7 +191,7 @@ class TensorListReserveOp : public XlaOpKernel { OP_REQUIRES_OK( ctx, SetTensorListPushIndex( - new_list, xla::ConstantR0<int32>(ctx->builder(), num_elements), + new_list, xla::ConstantR0<int32_t>(ctx->builder(), num_elements), &result)); ctx->SetTensorListOutput(0, result); return; @@ -324,13 +324,13 @@ class TensorListElementShapeOp : public XlaOpKernel { ctx->SetOutput(0, xla::ConstantR1(b, list_shape.dimensions())); break; case DT_INT32: { - std::vector<int32> size; + std::vector<int32_t> size; const auto& dimensions = list_shape.dimensions(); size.reserve(dimensions.size()); for (int64_t s : dimensions) { size.push_back(s); } - ctx->SetOutput(0, xla::ConstantR1<int32>(b, size)); + ctx->SetOutput(0, xla::ConstantR1<int32_t>(b, size)); break; } default: diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 683dc4737e6dab..0a7297456fce8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -393,7 +393,7 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, std::vector<xla::XlaOp> start_indices( element_part_shape.dimensions().size() + 1, - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = push_index; xla::XlaOp list_part = xla::GetTupleElement(list, i); @@ -409,7 +409,7 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, xla::XlaOp update = xla::Reshape(element, element_dims); std::vector<xla::XlaOp> start_indices(element_shape.dimensions().size() + 1, - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = push_index; xla::XlaOp list_part = xla::GetTupleElement(list, 0); @@ -418,7 +418,7 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, 
result_parts.push_back(updated_list_part); } - xla::XlaOp updated_push_index = push_index + xla::ConstantR0<int32>(b, 1); + xla::XlaOp updated_push_index = push_index + xla::ConstantR0<int32_t>(b, 1); result_parts.push_back(updated_push_index); *result = xla::Tuple(b, result_parts); @@ -441,14 +441,14 @@ absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list)); int list_tuple_size = xla::ShapeUtil::TupleElementCount(list_shape); xla::XlaOp push_index = xla::GetTupleElement(list, list_tuple_size - 1); - push_index = push_index - xla::ConstantR0<int32>(b, 1); + push_index = push_index - xla::ConstantR0<int32_t>(b, 1); std::vector<xla::XlaOp> list_result_parts, element_result_parts; for (int i = 0; i < list_tuple_size - 1; i++) { const xla::Shape& list_part_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, i); std::vector<xla::XlaOp> start_indices(list_part_shape.dimensions().size(), - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = push_index; std::vector<int64_t> slice_shape = @@ -496,7 +496,7 @@ absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, xla::XlaOp update = xla::Reshape(element, element_dims); std::vector<xla::XlaOp> start_indices(element_shape.dimensions().size() + 1, - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = index; xla::XlaOp list_part = xla::GetTupleElement(list, 0); @@ -550,7 +550,7 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, const xla::Shape& buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); std::vector<xla::XlaOp> start_indices(buffer_shape.dimensions().size(), - xla::ConstantR0<int32>(b, 0)); + xla::ConstantR0<int32_t>(b, 0)); start_indices[0] = index; std::vector<int64_t> slice_shape = @@ -585,7 +585,7 @@ absl::Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, } std::vector<xla::XlaOp> result_parts{tensor, - xla::ConstantR0<int32>(b, push_index)}; + xla::ConstantR0<int32_t>(b, push_index)}; *result = xla::Tuple(b, result_parts); return absl::OkStatus(); } diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 039320573f4558..9c4e0b63490205 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -137,7 +137,7 @@ class InvertPermutationOp : public XlaOpKernel { absl::Status status; switch (dtype) { case DT_INT32: - InvertPermutation<int32>(ctx); + InvertPermutation<int32_t>(ctx); break; case DT_INT64: InvertPermutation<int64_t>(ctx); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc index dbd6cda9d950d0..1d487f70d09d21 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc @@ -36,7 +36,7 @@ namespace tensorflow { namespace { using XlaUnaryOpGenerator = std::function<xla::XlaOp(xla::XlaOp)>; -using XlaOpGeneratorMap = absl::flat_hash_map<string, XlaUnaryOpGenerator>; +using XlaOpGeneratorMap = absl::flat_hash_map<std::string, XlaUnaryOpGenerator>; void PopulateXlaOpGeneratorMap(XlaOpGeneratorMap* op_generator_map) { auto add_xla_op_generator = [&](std::string name, @@ -120,7 +120,7 @@ class UnaryOpsCompositionOp : public XlaOpKernel { } private: - std::vector<string> op_names_; + std::vector<std::string> op_names_; }; REGISTER_XLA_OP(Name("_UnaryOpsComposition"), UnaryOpsCompositionOp); diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index a7a1a438f95b9e..c9ddab9efb6e22 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -165,7 +165,7 @@ class ResourceGatherOp : public XlaOpKernel { } private: - int32 batch_dims_; + int32_t batch_dims_; }; REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 415f465f0b5088..57821f74e97024 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -449,7 +449,8 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Add any TensorArray gradients touched by the body to the enclosing // graph. - for (const string& grad_source : update.tensor_array_gradients_accessed) { + for (const std::string& grad_source : + update.tensor_array_gradients_accessed) { VLOG(4) << "TensorArray " << resource->name() << " accessed gradient " << grad_source; XlaResource* gradient; @@ -553,7 +554,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Set token input for this "while" op. std::vector<xla::XlaOp> token_inputs; token_inputs.reserve(token_input_nodes_.size()); - for (const string& node_name : token_input_nodes_) { + for (const std::string& node_name : token_input_nodes_) { auto token_or = compiler->GetNodeToken(node_name); OP_REQUIRES_OK(ctx, token_or.status()); token_inputs.push_back(token_or.value()); @@ -590,7 +591,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } else { int32_t dim_size = shape.dimensions(0); dynamic_dims.push_back( - xla::ConstantR0<int32>(ctx->builder(), dim_size)); + xla::ConstantR0<int32_t>(ctx->builder(), dim_size)); } // Set dynamic dimension size to 0 for element value. Inside the while diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 8e9f317ac4f3fe..b1937c14f0bebc 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -61,8 +61,8 @@ class XlaWhileOp : public XlaOpKernel { NameAttrList cond_name_attr_; NameAttrList body_name_attr_; bool has_token_input_output_; - std::vector<string> token_input_nodes_; - string original_node_name_; + std::vector<std::string> token_input_nodes_; + std::string original_node_name_; // Whether to propagate compile time consts into the loop body. // This is not supported by default now since it may cause HBM memory // overheads. 
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index 9a2a00c58732f3..e06c0b09ba9938 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -166,13 +166,13 @@ class XlaCallModuleOp : public XlaOpKernel { explicit XlaCallModuleOp(OpKernelConstruction *ctx) : XlaOpKernel(ctx) { int version; OP_REQUIRES_OK(ctx, ctx->GetAttr("version", &version)); - string module_str; + std::string module_str; OP_REQUIRES_OK(ctx, ctx->GetAttr("module", &module_str)); std::vector<PartialTensorShape> expected_output_shapes; OP_REQUIRES_OK(ctx, ctx->GetAttr("Sout", &expected_output_shapes)); std::vector<DataType> expected_output_dtypes; OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &expected_output_dtypes)); - std::vector<string> dim_args_spec; + std::vector<std::string> dim_args_spec; OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_args_spec", &dim_args_spec)); OP_REQUIRES(ctx, dim_args_spec.empty(), absl::UnimplementedError( @@ -183,9 +183,9 @@ class XlaCallModuleOp : public XlaOpKernel { "The size of Sout (", expected_output_shapes.size(), ") must match the size of Tout (", expected_output_dtypes.size(), ")"))); - std::vector<string> disabled_checks; + std::vector<std::string> disabled_checks; OP_REQUIRES_OK(ctx, ctx->GetAttr("disabled_checks", &disabled_checks)); - std::vector<string> platforms; + std::vector<std::string> platforms; OP_REQUIRES_OK(ctx, ctx->GetAttr("platforms", &platforms)); // TODO(necula): change this to OP_REQUIRES_OK when 6 months have passed // since we added the function_list and has_token_input_output @@ -222,7 +222,7 @@ class XlaCallModuleOp : public XlaOpKernel { }) << "])"; } - string compilation_device_type = ctx->device_type().type_string(); + std::string compilation_device_type = ctx->device_type().type_string(); compilation_platform_ = ""; if (compilation_device_type == DEVICE_CPU_XLA_JIT) { compilation_platform_ = "CPU"; @@ -293,7 +293,7 @@ class XlaCallModuleOp : public XlaOpKernel { xla::XlaOp token_input; if (!op_token_input_nodes_.empty()) { std::vector<xla::XlaOp> token_inputs; - for (const string &node_name : op_token_input_nodes_) { + for (const std::string& node_name : op_token_input_nodes_) { auto token = compiler->GetNodeToken(node_name); OP_REQUIRES_OK(ctx, token.status()); token_inputs.push_back(token.value()); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc index 139ac17b35c637..99a0ec6d9e38dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc @@ -55,8 +55,8 @@ class XlaCustomCallOp : public XlaOpKernel { } private: - string target_name_; - string backend_config_; + std::string target_name_; + std::string backend_config_; DataType output_type_; TensorShape output_shape_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc index 7b0ea597c63488..6889c093a11201 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc @@ -42,7 +42,7 @@ class XlaDequantizeOp : public XlaOpKernel { xla::QuantizedRange range(min_range_, max_range_); xla::XlaOp output = - xla::Dequantize<uint8>(input, range, mode_, transpose_output_); + xla::Dequantize<uint8_t>(input, range, mode_, transpose_output_); context->SetOutput(0, output); } @@ -50,7 +50,7 @@ class XlaDequantizeOp : public XlaOpKernel { float min_range_; float max_range_; bool transpose_output_; 
- string mode_; + std::string mode_; XlaDequantizeOp(const XlaDequantizeOp&) = delete; void operator=(const XlaDequantizeOp&) = delete; }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 8236e67eeded01..f77cb46c44de8c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -34,12 +34,12 @@ namespace { class XlaDotOp : public XlaOpKernel { public: explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) { - string dnums_attr; + std::string dnums_attr; OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); OP_REQUIRES( context, dnums_.ParsePartialFromString(dnums_attr), errors::InvalidArgument("Error parsing convolution dimension numbers")); - string precision_config_attr; + std::string precision_config_attr; OP_REQUIRES_OK( context, context->GetAttr("precision_config", &precision_config_attr)); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc index 0cfd247bdd1de6..7765de131e865c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -41,7 +41,7 @@ class XlaSelfAdjointEigOp : public XlaOpKernel { private: bool lower_; - int32 max_iter_; + int32_t max_iter_; float epsilon_; }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc index f3bd088ced826a..6639c8003e1a15 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc @@ -37,7 +37,7 @@ class XlaSvdOp : public XlaOpKernel { explicit XlaSvdOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("max_iter", &max_iter_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); - string precision_config_attr; + std::string precision_config_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("precision_config", &precision_config_attr)); OP_REQUIRES(ctx, @@ -57,7 +57,7 @@ class XlaSvdOp : public XlaOpKernel { } private: - int32 max_iter_; + int32_t max_iter_; float epsilon_; xla::PrecisionConfig precision_config_; }; diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 6a67cfa237af70..0028f8e61cbd11 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -222,7 +222,7 @@ static absl::Status XlaDotShapeFunction(shape_inference::InferenceContext* c) { return shape_inference::UnknownShape(c); } - string dimension_numbers_string; + std::string dimension_numbers_string; TF_RETURN_IF_ERROR( c->GetAttr("dimension_numbers", &dimension_numbers_string)); @@ -1027,7 +1027,7 @@ REGISTER_OP("XlaEinsum") .Attr("equation: string") .Attr("T: {complex64, bfloat16, float}") .SetShapeFn([](shape_inference::InferenceContext* context) { - string equation; + std::string equation; TF_RETURN_IF_ERROR(context->GetAttr("equation", &equation)); // XlaEinsum supports only two-input einsum equations. 
if (!absl::StrContains(equation, ",")) { @@ -1057,9 +1057,9 @@ REGISTER_OP("XlaSpmdFullToShardShape") if (!c->RankKnown(input_handle)) { return shape_inference::UnknownShape(c); } - string sharding_attr; + std::string sharding_attr; TF_RETURN_IF_ERROR(c->GetAttr("manual_sharding", &sharding_attr)); - int32 single_dim; + int32_t single_dim; TF_RETURN_IF_ERROR(c->GetAttr("dim", &single_dim)); xla::OpSharding sharding; sharding.ParseFromString(sharding_attr); diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index 84ed56a468df8e..47e76f81a0328c 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -304,7 +304,7 @@ absl::Status MaybeRewriteWhileNode( resource_input_count, index_mapping)); // Modify cond and body functions. - for (auto const& attr_name : std::vector<string>{"cond", "body"}) { + for (auto const& attr_name : std::vector<std::string>{"cond", "body"}) { NameAttrList attr_value; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &attr_value)); const FunctionBody* fbody; @@ -363,7 +363,7 @@ absl::Status MaybeRewriteWhileNode( // Save the new FunctionDef. FunctionDef new_fdef; - string new_name = + std::string new_name = fld->UniqueFunctionName(absl::StrCat(attr_value.name(), "_rearrange_")); TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef)); @@ -435,7 +435,7 @@ absl::Status MaybeRewriteIfNode( std::map<int, int> resource_retval_to_arg, retval_index_mapping; for (auto const& attr_name : - std::vector<string>{"then_branch", "else_branch"}) { + std::vector<std::string>{"then_branch", "else_branch"}) { NameAttrList f; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &f)); const FunctionBody* fbody; @@ -459,7 +459,7 @@ absl::Status MaybeRewriteIfNode( // Save the new FunctionDef. 
FunctionDef new_fdef; - string new_name = + std::string new_name = fld->UniqueFunctionName(absl::StrCat(f.name(), "_rearrange_")); TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, new_name, &new_fdef)); diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc index 956f597301d28d..39efe2d682eb12 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -34,15 +34,16 @@ bool HasResourceInputOrOutput(const OpDef& op_def) { } TEST(ResourceOperationTableTest, HaveAllResourceOps) { - absl::flat_hash_map<string, bool> known_resource_ops; + absl::flat_hash_map<std::string, bool> known_resource_ops; for (absl::string_view known_resource_op : resource_op_table_internal::GetKnownResourceOps()) { ASSERT_TRUE( - known_resource_ops.insert({string(known_resource_op), false}).second); + known_resource_ops.insert({std::string(known_resource_op), false}) + .second); } - std::vector<string> xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); - for (const string& xla_op_name : xla_op_names) { + std::vector<std::string> xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); + for (const std::string& xla_op_name : xla_op_names) { const OpDef* op_def; TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def)); if (HasResourceInputOrOutput(*op_def)) { @@ -52,7 +53,7 @@ TEST(ResourceOperationTableTest, HaveAllResourceOps) { } } - std::vector<string> unnecessary_resource_ops; + std::vector<std::string> unnecessary_resource_ops; for (const auto& pair : known_resource_ops) { if (!pair.second) { unnecessary_resource_ops.push_back(pair.first); diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 7e0b70e4df270a..4b285078f94d21 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -50,7 +50,8 @@ xla::OpMetadata CreateOpMetadata(const std::string& op_type, } void AssignOpMetadataToSharding(xla::OpSharding& sharding, - const string& op_type, const string& op_name) { + const std::string& op_type, + const std::string& op_name) { auto metadata = CreateOpMetadata(op_type, op_name); if (sharding.type() == xla::OpSharding::TUPLE) { for (auto& sharding_element : *sharding.mutable_tuple_shardings()) { @@ -69,7 +70,7 @@ absl::Status CoreOutOfRangeError(int core, int num_cores_per_replica) { } // namespace absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice( - const string& device_name, int num_cores_per_replica, + const std::string& device_name, int num_cores_per_replica, std::optional<xla::OpSharding> explicit_sharding, std::optional<xla::OpMetadata> metadata) { if (device_name.empty()) { @@ -102,7 +103,7 @@ absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice( absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice( const NodeDef& node_def, int num_cores_per_replica, bool add_metadata) { - const string& device_name = node_def.device(); + const std::string& device_name = node_def.device(); TF_ASSIGN_OR_RETURN(std::optional<xla::OpSharding> sharding, GetShardingFromNodeDef(node_def, add_metadata)); return ParseShardingFromDevice( @@ -114,7 +115,7 @@ absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice( absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice( const Node& node, int num_cores_per_replica, bool add_metadata) { - string device_name = node.assigned_device_name(); + std::string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } @@ -152,7 +153,7 @@ absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromEdgeSource( } void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { - string 
device_name = src.assigned_device_name(); + std::string device_name = src.assigned_device_name(); if (device_name.empty()) { device_name = src.requested_device(); } @@ -169,7 +170,7 @@ absl::StatusOr<std::optional<xla::OpSharding>> GetShardingFromNodeDefInternal( if (!HasNodeAttr(node_def, attribute)) { return std::optional<xla::OpSharding>(); } - string value; + std::string value; xla::OpSharding sharding; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attribute, &value)); if (tensorflow::DecodeShardingAttribute(value, sharding).failed()) { diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index e579f3ee0ff397..85259e0c729883 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -36,7 +36,7 @@ namespace tensorflow { // - a non-value if there is no assigned core or // - a sharding set as per xla::sharding_builder::AssignDevice. absl::StatusOr<std::optional<xla::OpSharding>> ParseShardingFromDevice( - const string& device_name, int num_cores_per_replica, + const std::string& device_name, int num_cores_per_replica, std::optional<xla::OpSharding> explicit_sharding = std::nullopt, std::optional<xla::OpMetadata> metadata = std::nullopt); diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc index 585e3887fe686c..c987e8f167422f 100644 --- a/tensorflow/compiler/tf2xla/sharding_util_test.cc +++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc @@ -33,7 +33,7 @@ TEST(CoreUtilTest, ParseShardingFromDevice) { Graph graph(OpRegistry::Global()); auto core_from_sharding = - [](std::optional<xla::OpSharding> sharding) -> int64 { + [](std::optional<xla::OpSharding> sharding) -> int64_t { if (sharding.has_value() && sharding.value().type() == xla::OpSharding::MAXIMAL) { return sharding.value().tile_assignment_devices(0); diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index afe82e0de40f62..e8b2a56cdf64d2 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -48,8 +48,8 @@ absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { } else if (node->IsIfNode()) { AttrValue device_ordinal_value; device_ordinal_value.set_i(device_ordinal); - for (const string& attr_name : - std::vector<string>{"then_branch", "else_branch"}) { + for (const std::string& attr_name : + std::vector<std::string>{"then_branch", "else_branch"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; @@ -59,7 +59,8 @@ absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { } else if (node->IsWhileNode()) { AttrValue device_ordinal_value; device_ordinal_value.set_i(device_ordinal); - for (const string& attr_name : std::vector<string>{"cond", "body"}) { + for (const std::string& attr_name : + std::vector<std::string>{"cond", "body"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), attr_name, &branch_func)); (*branch_func.mutable_attr())["_device_ordinal"] = device_ordinal_value; @@ -80,39 +81,40 @@ absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) { std::set<std::string> results; Node* first_side_effecting_node_on_path = nullptr; - ReverseDFS(g, - [&](Node* n) { - std::vector<string> token_input_nodes; - if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, - &token_input_nodes) - .ok() || - token_input_nodes.empty()) { - return; - } - - if (first_side_effecting_node_on_path != nullptr) { - 
return; - } - - first_side_effecting_node_on_path = n; - string original_node_name; - TF_CHECK_OK(GetNodeAttr(n->def(), - kXlaOriginalOutsideCompilationNodeName, - &original_node_name)); - results.insert(original_node_name); - }, - [&](Node* n) { - if (first_side_effecting_node_on_path == n) { - first_side_effecting_node_on_path = nullptr; - } - }, - NodeComparatorName()); + ReverseDFS( + g, + [&](Node* n) { + std::vector<std::string> token_input_nodes; + if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, + &token_input_nodes) + .ok() || + token_input_nodes.empty()) { + return; + } + + if (first_side_effecting_node_on_path != nullptr) { + return; + } + + first_side_effecting_node_on_path = n; + std::string original_node_name; + TF_CHECK_OK(GetNodeAttr(n->def(), + kXlaOriginalOutsideCompilationNodeName, + &original_node_name)); + results.insert(original_node_name); + }, + [&](Node* n) { + if (first_side_effecting_node_on_path == n) { + first_side_effecting_node_on_path = nullptr; + } + }, + NodeComparatorName()); return results; } bool HasSideEffectingNodes(const Graph& g) { for (Node* n : g.nodes()) { - std::vector<string> token_input_nodes; + std::vector<std::string> token_input_nodes; if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes) .ok() && !token_input_nodes.empty()) { @@ -123,10 +125,10 @@ bool HasSideEffectingNodes(const Graph& g) { } absl::Status ParseHostComputeCoreList( - absl::Span<const string> list_from_attr, - std::map<string, int>* host_compute_core) { + absl::Span<const std::string> list_from_attr, + std::map<std::string, int>* host_compute_core) { for (const auto& hc_core : list_from_attr) { - std::vector<string> parts = str_util::Split(hc_core, ":"); + std::vector<std::string> parts = str_util::Split(hc_core, ":"); if (parts.size() != 2) { return errors::InvalidArgument( "Malformed host_compute_core entry ", hc_core, diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index 34f30eb7661bc1..9ba994a16a3c8e 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -61,8 +61,9 @@ bool HasSideEffectingNodes(const Graph& g); // Parse the mapping from outside_compilation_subgraph name to core number, // which is specified in an attr as a list of strings // :. -absl::Status ParseHostComputeCoreList(absl::Span<const string> list_from_attr, -                                      std::map<string, int>* host_compute_core); +absl::Status ParseHostComputeCoreList( + absl::Span<const std::string> list_from_attr, + std::map<std::string, int>* host_compute_core); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 43623a8db8014f..193eb7c08bc08a 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -21,12 +21,12 @@ limitations under the License. 
namespace tensorflow { absl::Status InstantiateFunctionForTest( - const string& name, const FunctionLibraryDefinition& library, + const std::string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result) { const FunctionDef* fdef = library.Find(name); TF_RET_CHECK(fdef != nullptr); - auto get_func_sig = [&library](const string& op, const OpDef** sig) { + auto get_func_sig = [&library](const std::string& op, const OpDef** sig) { return library.LookUpOpDef(op, sig); }; InstantiationResult inst; diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index 2b2eb4f582af3e..2c9cdc1c352238 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -41,7 +41,7 @@ struct InstantiationResultForTest { // Instantiates a function, producing a GraphDef to compare against the // expected graph. absl::Status InstantiateFunctionForTest( - const string& name, const FunctionLibraryDefinition& library, + const std::string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc index 504e9d0246322e..eccc2dfaf8d4a4 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -32,7 +32,8 @@ namespace tensorflow { namespace tf2xla { namespace { -void PrintSupportedOps(const string& device, const string& regen_run) { +void PrintSupportedOps(const std::string& device, + const std::string& regen_run) { XlaOpRegistry::RegisterCompilationKernels(); std::vector<const KernelDef*> kdefs = @@ -46,10 +47,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) { << "Operator | Type Constraint\n" << "-------- | ---------------" << std::endl; for (const KernelDef* kdef : kdefs) { - std::vector<string> constraints; + std::vector<std::string> constraints; constraints.reserve(kdef->constraint().size()); for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) { - std::vector<string> types; + std::vector<std::string> types; const auto& allowed_values = constraint.allowed_values().list().type(); types.reserve(allowed_values.size()); for (int type : allowed_values) { @@ -70,18 +71,18 @@ void PrintSupportedOps(const string& device, const string& regen_run) { } // namespace void SupportedOpsMain(int argc, char** argv, const char* regen_run) { - std::vector<string> device_names = XlaOpRegistry::BackendNames(); + std::vector<std::string> device_names = XlaOpRegistry::BackendNames(); std::sort(device_names.begin(), device_names.end()); // Set up and parse flags. - string device; + std::string device; std::vector<Flag> flag_list = { {"device", &device, "Name of the compilation device for which to print supported ops, " "one of: " + absl::StrJoin(device_names, ",")}, }; - string usage = Flags::Usage(argv[0], flag_list); + std::string usage = Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); QCHECK(parsed_flags_ok) << "\n" << usage; QCHECK(XlaOpRegistry::IsBackendRegistered(device)) diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index d61d66bfe53b72..72bd28f2b47a8c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -118,8 +118,8 @@ TEST(ConvertGraphDefToXla, Sum) { TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. 
-  auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
-  auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
+  auto x_literal = xla::LiteralUtil::CreateR0<int32_t>(10);
+  auto y_literal = xla::LiteralUtil::CreateR0<int32_t>(32);
   auto x_global_or = client->TransferToServer(x_literal);
   auto y_global_or = client->TransferToServer(y_literal);
   TF_EXPECT_OK(x_global_or.status());
@@ -140,23 +140,23 @@ TEST(ConvertGraphDefToXla, Sum) {
       ConvertGraphDefToXla(graph_def, config, client, &computation)));
 }
 
-GraphDef EinsumGraph() {
+GraphDef EinsumGraph(DataType dtype = DT_FLOAT) {
   GraphDef graph_def;
   NodeDef* x = graph_def.add_node();
   x->set_name("x");
   x->set_op("Placeholder");
-  (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT);
+  (*x->mutable_attr())["dtype"] = TypeAttrValue(dtype);
   NodeDef* y = graph_def.add_node();
   y->set_name("y");
   y->set_op("Placeholder");
-  (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_FLOAT);
+  (*y->mutable_attr())["dtype"] = TypeAttrValue(dtype);
   NodeDef* einsum = graph_def.add_node();
   einsum->set_name("einsum");
   einsum->set_op("Einsum");
   einsum->add_input("x");
   einsum->add_input("y");
   (*einsum->mutable_attr())["equation"] = StringAttrValue("ij,jk->ik");
-  (*einsum->mutable_attr())["T"] = TypeAttrValue(DT_FLOAT);
+  (*einsum->mutable_attr())["T"] = TypeAttrValue(dtype);
   (*einsum->mutable_attr())["N"] = IntAttrValue(2);
   return graph_def;
 }
@@ -233,6 +233,35 @@ TEST_F(ConvertGraphDefToXlaWithTF32Disabled,
   EXPECT_EQ(num_dots, 1);
 }
 
+TEST_F(ConvertGraphDefToXlaWithTF32Disabled,
+       EinsumIsConvertedToDotWithDefaultPrecisionIfNotF32) {
+  GraphDef graph_def = EinsumGraph(DT_BFLOAT16);
+  tf2xla::Config config = EinsumConfig();
+
+  xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
+  xla::XlaComputation computation;
+  TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation));
+
+  int num_dots = 0;
+  const xla::HloModuleProto& module_proto = computation.proto();
+  for (const xla::HloComputationProto& computation_proto :
+       module_proto.computations()) {
+    for (const xla::HloInstructionProto& instruction_proto :
+         computation_proto.instructions()) {
+      if (instruction_proto.opcode() == "dot") {
+        num_dots++;
+        ASSERT_EQ(instruction_proto.precision_config().operand_precision_size(),
+                  2);
+        EXPECT_EQ(instruction_proto.precision_config().operand_precision(0),
+                  xla::PrecisionConfig::DEFAULT);
+        EXPECT_EQ(instruction_proto.precision_config().operand_precision(1),
+                  xla::PrecisionConfig::DEFAULT);
+      }
+    }
+  }
+  EXPECT_EQ(num_dots, 1);
+}
+
 GraphDef Conv2DGraph() {
   GraphDef graph_def;
   NodeDef* x = graph_def.add_node();
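The test added above pins down the intended behavior for non-F32 inputs: TensorFloat-32 only degrades F32 math, so a bf16 einsum should keep DEFAULT operand precision even when TF32 is disabled. A brief usage sketch, reusing names from the test file:

    // EinsumGraph() still defaults to DT_FLOAT; the new test builds:
    GraphDef graph_def = EinsumGraph(DT_BFLOAT16);
    // After ConvertGraphDefToXla, the resulting "dot" instruction is expected
    // to carry xla::PrecisionConfig::DEFAULT on both operands, not HIGHEST.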
@@ -338,8 +367,8 @@ TEST(ConvertGraphDefToXla, SumWithUnusedArgument) {
   TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation));
 
   // Set up arguments.
-  auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
-  auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
+  auto x_literal = xla::LiteralUtil::CreateR0<int32_t>(10);
+  auto y_literal = xla::LiteralUtil::CreateR0<int32_t>(32);
   auto x_global_or = client->TransferToServer(x_literal);
   auto y_global_or = client->TransferToServer(y_literal);
   auto unused_global_or = client->TransferToServer(y_literal);
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 9f21af2741dcde..042b572c234355 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -58,8 +58,9 @@ absl::Status ValidateTensorId(const tf2xla::TensorId& id) {
   return absl::OkStatus();
 }
 
-absl::Status CheckNameDuplicates(const string& kind, const string& name,
-                                 std::set<string>* names) {
+absl::Status CheckNameDuplicates(const std::string& kind,
+                                 const std::string& name,
+                                 std::set<std::string>* names) {
   if (!name.empty()) {
     if (!names->insert(name).second) {
       return errors::InvalidArgument("duplicate ", kind, " name: ", name);
@@ -68,12 +69,12 @@ absl::Status CheckNameDuplicates(const string& kind, const string& name,
   return absl::OkStatus();
 }
 
-absl::Status CheckFeedFetchNameConflicts(const string& kind,
-                                         const std::set<string>& names) {
+absl::Status CheckFeedFetchNameConflicts(const std::string& kind,
+                                         const std::set<std::string>& names) {
   // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
   // since that will cause a collision in codegen symbols.
-  for (const string& name : names) {
-    const string name_data(name + "_data");
+  for (const std::string& name : names) {
+    const std::string name_data(name + "_data");
     if (names.find(name_data) != names.end()) {
       return errors::InvalidArgument("conflicting ", kind, " name: ", name,
                                      " and ", name_data);
@@ -227,7 +228,7 @@ absl::Status ReplaceRetvalInputWithArg(
 // the function to replace _Arg nodes in `const_input_index_to_node` with Const
 // inputs.
 absl::Status PropagateConstIntoFuncAttr(
-    Node* n, const string& attr_name,
+    Node* n, const std::string& attr_name,
     const absl::flat_hash_map<int, const Node*>& const_input_index_to_node,
     const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld,
     bool passthrough_arg_to_retval = false) {
@@ -255,7 +256,7 @@ absl::Status PropagateConstIntoFuncAttr(
 
   // Save rewritten function.
   FunctionDef replace_fdef;
-  string new_func_name =
+  std::string new_func_name =
       fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
   const StackTracesMap* stack_traces =
       lookup_fld->GetStackTraces(func_attr.name());
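The "_data" rule above exists because the AOT codegen derives a <name>_data symbol for every feed and fetch, so a user-supplied pair such as "foo" and "foo_data" would collide. A small illustration, relying only on the documented contract:

    #include <set>
    #include <string>

    std::set<std::string> names = {"foo", "foo_data"};
    // CheckFeedFetchNameConflicts("feed", names) returns an InvalidArgument
    // error: the symbol generated for "foo" clashes with the entry "foo_data".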
@@ -301,7 +302,7 @@ absl::Status PropagateConstIntoIfNode(
   // Rewrite "then_branch" and "else_branch" function, replace usage of those
   // _Arg nodes with corresponding const node.
   for (const auto& attr_name :
-       std::vector<string>{"then_branch", "else_branch"}) {
+       std::vector<std::string>{"then_branch", "else_branch"}) {
     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
         if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
   }
@@ -309,13 +310,14 @@ absl::Status PropagateConstIntoIfNode(
   return absl::OkStatus();
 }
 
-using GraphCache = absl::flat_hash_map<string, std::unique_ptr<FunctionBody>>;
+using GraphCache =
+    absl::flat_hash_map<std::string, std::unique_ptr<FunctionBody>>;
 
 absl::StatusOr<FunctionBody*> FindOrInsert(
     GraphCache* cache, const NameAttrList& body_attr,
     const FunctionLibraryDefinition* lookup_fld,
     const FunctionLibraryDefinition* fallback_fld) {
-  const string name = body_attr.name();
+  const std::string name = body_attr.name();
   std::unique_ptr<FunctionBody>& value = (*cache)[name];
   if (!value) {
     const FunctionDef* body_func = lookup_fld->Find(name);
@@ -413,7 +415,7 @@ absl::Status PropagateConstIntoAndAroundWhileNode(
   absl::flat_hash_map<int, Node*> const_input_index_to_mutable_node;
   NameAttrList body_attr;
   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
-  const string fn_name = body_attr.name();
+  const std::string fn_name = body_attr.name();
   const FunctionDef* body_func = lookup_fld->Find(fn_name);
   if (!body_func) {
     return errors::Internal("Propagate: Cannot find body function ", fn_name,
@@ -461,7 +463,7 @@ absl::Status PropagateConstIntoAndAroundWhileNode(
 
   // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with
   // corresponding const node.
-  for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
+  for (const auto& attr_name : std::vector<std::string>{"cond", "body"}) {
     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
         while_node, attr_name, const_input_index_to_node, lookup_fld, fld,
         /*passthrough_arg_to_retval=*/attr_name == "body"));
@@ -487,7 +489,7 @@ absl::StatusOr<bool> IsLoopInvariant(
 }
 
 absl::Status ValidateConfig(const tf2xla::Config& config) {
-  std::set<string> names;
+  std::set<std::string> names;
   for (const tf2xla::Feed& feed : config.feed()) {
     TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
     TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
@@ -508,19 +510,20 @@ absl::Status ValidateConfig(const tf2xla::Config& config) {
 
 absl::Status AddPlaceholdersForFeeds(
     const tf2xla::Config& config, const OpRegistryInterface* op_registry,
-    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
+    std::unordered_map<std::string, std::string>* feed_remapping,
+    GraphDef* graph_def) {
   struct PlaceholderInfo {
     const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
-    string placeholder_name;
+    std::string placeholder_name;
     DataType data_type = DT_INVALID;
   };
 
   // Put each fed tensor into a map by name:port. A map is used for determinism
   // when creating placeholders (genrules want deterministic output).
-  std::map<string, PlaceholderInfo> placeholder_info;
+  std::map<std::string, PlaceholderInfo> placeholder_info;
   for (int i = 0; i < config.feed_size(); ++i) {
     const tf2xla::Feed* feed = &config.feed(i);
-    const string name_port = TensorIdToString(feed->id());
+    const std::string name_port = TensorIdToString(feed->id());
     PlaceholderInfo& info = placeholder_info[name_port];
     info.feed = feed;
     info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
@@ -529,7 +532,7 @@ absl::Status AddPlaceholdersForFeeds(
   }
 
   // Verify node exists and determine data type.
-  std::unordered_map<string, const NodeDef*> name_to_node;
+  std::unordered_map<std::string, const NodeDef*> name_to_node;
   for (int i = 0; i < graph_def->node_size(); ++i) {
     name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
   }
@@ -609,25 +612,25 @@ absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
   out->clear_node();
 
   // Tensors needed for feeding.
-  std::set<std::pair<string, int>> feed_tensors;
+  std::set<std::pair<std::string, int>> feed_tensors;
   for (const tf2xla::Feed& feed : config.feed()) {
     feed_tensors.insert(
         std::make_pair(feed.id().node_name(), feed.id().output_index()));
   }
 
   // Maps node name to reachability.
-  std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
+  std::unordered_map<std::string, std::pair<bool, const NodeDef*>> node_by_name;
   for (const NodeDef& node : in.node()) {
     node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
   }
 
   // Traverse.
-  std::queue<string> name_queue;
+  std::queue<std::string> name_queue;
   for (int i = 0; i < config.fetch_size(); ++i) {
     name_queue.push(config.fetch(i).id().node_name());
   }
   while (!name_queue.empty()) {
-    const string name = name_queue.front();
+    const std::string name = name_queue.front();
     name_queue.pop();
 
     auto find_it = node_by_name.find(name);
@@ -642,9 +645,9 @@ absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
     map_entry.first = true;
 
     // Push input nodes of the currently visited node to name_queue.
-    for (const string& in_edge : map_entry.second->input()) {
+    for (const std::string& in_edge : map_entry.second->input()) {
       auto id = ParseTensorName(in_edge);
-      const string node_name = string(id.first);
+      const std::string node_name = std::string(id.first);
       if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
           feed_tensors.end()) {
         name_queue.push(node_name);
@@ -668,7 +671,7 @@ absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
   return absl::OkStatus();
 }
 
-string TensorIdToString(const tf2xla::TensorId& id) {
+std::string TensorIdToString(const tf2xla::TensorId& id) {
   return absl::StrCat(id.node_name(), ":", id.output_index());
 }
 
@@ -682,7 +685,7 @@ absl::Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
         std::optional<xla::OpSharding> sharding,
         ParseShardingFromDevice(
             *possible_match,
-            /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
+            /*num_cores_per_replica=*/std::numeric_limits<int32_t>::max(),
             /*add_metadata=*/false));
     if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
       const int core_annotation = sharding.value().tile_assignment_devices(0);
@@ -709,7 +712,7 @@ void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
 }
 
 namespace {
-uint32 InitialRandomSeed() {
+uint32_t InitialRandomSeed() {
   // Support plumbing the TF seed through to XLA is being worked on.
   // If a user wants deterministic behavior, their best option
   // is to start with a known checkpoint. This also handles issues when
@@ -724,13 +727,13 @@ uint32 InitialRandomSeed() {
 }
 }  // namespace
 
-uint32 GetXLARandomSeed() {
+uint32_t GetXLARandomSeed() {
   // We initialize counter with an odd number and increment it by two
   // everytime. This ensures that it will never be zero, even
   // after an overflow. When seeded with zero, some XLA backends
   // can return all zeros instead of random numbers.
-  static std::atomic<uint32> counter(InitialRandomSeed());
-  uint32 seed = counter.fetch_add(2);
+  static std::atomic<uint32_t> counter(InitialRandomSeed());
+  uint32_t seed = counter.fetch_add(2);
   std::srand(seed);
   return std::rand() | 1;
 }
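A note on why GetXLARandomSeed above can never hand out a zero counter value: the counter starts odd, and fetch_add(2) preserves parity even across unsigned wraparound, because 2^32 is even. A self-contained sketch of the same invariant, illustrative only:

    #include <atomic>
    #include <cstdint>

    std::atomic<uint32_t> counter(1);  // odd initial value

    uint32_t NextSeed() {
      // fetch_add returns the previous value; odd plus an even step stays odd
      // modulo 2^32, so the result is never zero.
      return counter.fetch_add(2);
    }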
@@ -766,7 +769,7 @@ bool HasAssociatedFunction(const NodeDef& node_def,
 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
     const Node& node, const FunctionLibraryDefinition* fld) {
   std::vector<AssociatedFunctionInfo> results;
-  const string& op = node.type_string();
+  const std::string& op = node.type_string();
   if (fld->Contains(op)) {
     // This is a function call node.
     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
@@ -795,7 +798,7 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
 absl::Status RewriteAssociatedFunction(
     Graph* graph, Node* node, FunctionLibraryDefinition* fld,
     const AssociatedFunctionInfo& associated_function,
-    const string& rewritten_function_name) {
+    const std::string& rewritten_function_name) {
   switch (associated_function.type()) {
     case AssociatedFunctionInfo::kFunctionCallNode: {
       // Change this node to call the new function.
@@ -834,7 +837,7 @@ absl::Status RewriteAssociatedFunction(
       GradientDef gradient_def;
       gradient_def.set_function_name(func.name());
       gradient_def.set_gradient_func(rewritten_function_name);
-      string original_grad_func = fld->FindGradient(func.name());
+      std::string original_grad_func = fld->FindGradient(func.name());
       if (original_grad_func.empty()) {
         TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
       } else if (original_grad_func != rewritten_function_name) {
@@ -863,9 +866,9 @@ absl::Status RewriteAssociatedFunction(
 }
 
 absl::Status CachedFunctionHandles::GetOrInstantiate(
-    const string& func_name, AttrSlice attrs,
+    const std::string& func_name, AttrSlice attrs,
     FunctionLibraryRuntime::Handle* handle) {
-  string canonicalized_name = Canonicalize(func_name, attrs);
+  std::string canonicalized_name = Canonicalize(func_name, attrs);
   auto iter = handles_.find(canonicalized_name);
   if (iter != handles_.end()) {
     *handle = iter->second;
@@ -919,8 +922,8 @@ absl::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
 }
 
 absl::StatusOr<Node*> BuildIdentityNode(
-    Graph* graph, const string& node_name, DataType dtype, const Node* input,
-    std::optional<string> requested_device) {
+    Graph* graph, const std::string& node_name, DataType dtype,
+    const Node* input, std::optional<std::string> requested_device) {
   // Create identity node.
   NodeDef ndef;
   ndef.set_name(node_name);
@@ -975,7 +978,7 @@ absl::Status PruneUnreachableFunctionsFromGraph(
   g.ToGraphDef(&graph_def);
   FunctionLibraryDefinition reachable_functions =
       fld->ReachableDefinitions(graph_def);
-  for (const string& func_name : fld->ListFunctionNames()) {
+  for (const std::string& func_name : fld->ListFunctionNames()) {
     if (!reachable_functions.Find(func_name)) {
       TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
     }
@@ -1106,7 +1109,7 @@ absl::Status RewriteTensorListWithConstElement(Graph* g,
 
   // Add rewritten backward While body function.
   FunctionDef new_fdef;
-  string new_name = fld->UniqueFunctionName(
+  std::string new_name = fld->UniqueFunctionName(
       absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
   TF_RETURN_IF_ERROR(
       GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index f2ce3944ac158c..4da5a474d964dc 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -41,7 +41,8 @@ absl::Status ValidateConfig(const tf2xla::Config& config);
 // feeds).
 absl::Status AddPlaceholdersForFeeds(
     const tf2xla::Config& config, const OpRegistryInterface* op_registry,
-    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
+    std::unordered_map<std::string, std::string>* feed_remapping,
+    GraphDef* graph_def);
 
 // Returns in <out> a copy of <in>, pruned to only include fetches from
 // <config>.
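CachedFunctionHandles::GetOrInstantiate above memoizes instantiations keyed by Canonicalize(func_name, attrs), so repeated requests with identical attributes reuse one handle. A minimal sketch of the pattern, with a hypothetical Instantiate stand-in for the FunctionLibraryRuntime call:

    #include <map>
    #include <string>

    int Instantiate(const std::string& canonical_name);  // hypothetical

    std::map<std::string, int> handles;  // canonical name -> handle

    int GetOrInstantiate(const std::string& canonical_name) {
      auto it = handles.find(canonical_name);
      if (it != handles.end()) return it->second;  // cache hit
      const int handle = Instantiate(canonical_name);
      handles[canonical_name] = handle;
      return handle;
    }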
@@ -49,7 +50,7 @@ absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                                GraphDef* out);
 
 // Returns node:port for the given <id>.
-string TensorIdToString(const tf2xla::TensorId& id);
+std::string TensorIdToString(const tf2xla::TensorId& id);
 
 // Updates the sharding of <n> based on the sharding of its neighbors.
 // If <out_edges> is true, outgoing edges from <n> are considered; else incoming
@@ -61,7 +62,7 @@ void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                    KernelDef* kdef);
 
 // Returns the next random seed to use for seeding xla rng.
-uint32 GetXLARandomSeed();
+uint32_t GetXLARandomSeed();
 
 // Indicates how a FunctionDef is associated with a graph node (e.g. the node is
 // a function call, or the node has function attrs).
@@ -74,14 +75,14 @@ class AssociatedFunctionInfo {
   };
 
   // The function is an attr of the node.
-  static AssociatedFunctionInfo FunctionAttr(const string& func_name,
+  static AssociatedFunctionInfo FunctionAttr(const std::string& func_name,
                                              const AttrValueMap& attrs,
-                                             const string& attr_name) {
+                                             const std::string& attr_name) {
     return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
   }
 
   // The node is a function call.
-  static AssociatedFunctionInfo FunctionCall(const string& func_name,
+  static AssociatedFunctionInfo FunctionCall(const std::string& func_name,
                                              const AttrValueMap& attrs) {
     // attr_name will not be used in this case.
     return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
@@ -89,7 +90,7 @@ class AssociatedFunctionInfo {
   }
 
   // The node is a SymbolicGradient op.
-  static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
+  static AssociatedFunctionInfo SymbolicGradient(const std::string& func_name,
                                                  const AttrValueMap& attrs) {
     // attr_name will not be used in this case.
     return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
@@ -98,15 +99,17 @@ class AssociatedFunctionInfo {
 
   AssociatedFunctionType type() const { return type_; }
 
-  const string& func_name() const { return func_name_; }
+  const std::string& func_name() const { return func_name_; }
 
-  const string& attr_name() const { return attr_name_; }
+  const std::string& attr_name() const { return attr_name_; }
 
   const AttrValueMap& attrs() const { return attrs_; }
 
  private:
-  AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
-                         const AttrValueMap& attrs, const string& attr_name)
+  AssociatedFunctionInfo(AssociatedFunctionType type,
+                         const std::string& func_name,
+                         const AttrValueMap& attrs,
+                         const std::string& attr_name)
       : type_(type),
        func_name_(func_name),
        attrs_(attrs),
@@ -114,11 +117,11 @@ class AssociatedFunctionInfo {
 
   // Available for all instances.
   AssociatedFunctionType type_;
-  string func_name_;
+  std::string func_name_;
   AttrValueMap attrs_;
 
   // Only available if the function is defined in an attr.
-  string attr_name_;
+  std::string attr_name_;
 };
 
 // Returns if the NodeDef has associated function.
@@ -142,7 +145,7 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
 absl::Status RewriteAssociatedFunction(
     Graph* graph, Node* node, FunctionLibraryDefinition* fld,
     const AssociatedFunctionInfo& associated_function,
-    const string& rewritten_function_name);
+    const std::string& rewritten_function_name);
 
 // Class to act as cache for FunctionLibraryRuntime::Handle objects.
 class CachedFunctionHandles {
@@ -152,7 +155,7 @@ class CachedFunctionHandles {
   // Populates `handle` for requested function and attributes. If we have
   // instantiated the function with the same attributes before, `handle` will be
   // cached handle; otherwise instantiate the function and populate `handle`.
-  absl::Status GetOrInstantiate(const string& func_name, AttrSlice attrs,
+  absl::Status GetOrInstantiate(const std::string& func_name, AttrSlice attrs,
                                 FunctionLibraryRuntime::Handle* handle);
 
   // Releases all handles in the cache. Returns first non-OK status if any;
@@ -163,7 +166,7 @@ class CachedFunctionHandles {
 
  private:
   FunctionLibraryRuntime* flr_;
-  std::map<string, FunctionLibraryRuntime::Handle> handles_;
+  std::map<std::string, FunctionLibraryRuntime::Handle> handles_;
 
   CachedFunctionHandles(const CachedFunctionHandles&) = delete;
   void operator=(const CachedFunctionHandles&) = delete;
@@ -179,9 +182,9 @@ struct OutEdgeInfo {
 absl::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def);
 
 // Helper function that builds an Identity node.
-absl::StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
-                                        DataType dtype, const Node* input,
-                                        std::optional<string> requested_device);
+absl::StatusOr<Node*> BuildIdentityNode(
+    Graph* graph, const std::string& node_name, DataType dtype,
+    const Node* input, std::optional<std::string> requested_device);
 
 // For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite
 // body functions to use the Const nodes instead of original _Arg nodes.
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
index e66a8a38813474..ef64b82f50e5be 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
@@ -157,7 +157,7 @@ TEST(ValidateConfig, ConflictingFetchName) {
   ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
 }
 
-static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
+static tf2xla::Config FetchesConfig(std::vector<std::string> fetches) {
   tf2xla::Config config;
   for (const auto& fetch_node_name : fetches) {
     auto* fetch = config.add_fetch();
@@ -409,7 +409,7 @@ TEST(PropagateConstIntoFunctionalNodes, CopiedConstNodeHasUniqueName) {
   TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn));
   const FunctionDef* rewritten_body_fn = fld.Find(body_fn.name());
   ASSERT_NE(rewritten_body_fn, nullptr);
-  std::unordered_map<string, NodeDef> nodes;
+  std::unordered_map<std::string, NodeDef> nodes;
   for (const NodeDef& node_def : rewritten_body_fn->node_def()) {
     nodes[node_def.name()] = node_def;
   }
diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index ec456344bcfced..007ecef7492600 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -87,6 +87,9 @@ absl::Status DataTypeToPrimitiveType(DataType data_type,
     case tensorflow::DT_FLOAT8_E5M2FNUZ:
       *type = xla::F8E5M2FNUZ;
       return absl::OkStatus();
+    case tensorflow::DT_FLOAT4_E2M1FN:
+      *type = xla::F4E2M1FN;
+      return absl::OkStatus();
     case tensorflow::DT_BFLOAT16:
       *type = xla::BF16;
       return absl::OkStatus();
@@ -122,6 +125,7 @@ absl::StatusOr<DataType> EncodePrimitiveTypeAsDataType(
       {xla::F8E4M3FNUZ, DT_FLOAT8_E4M3FNUZ},
       {xla::F8E4M3B11FNUZ, DT_FLOAT8_E4M3B11FNUZ},
       {xla::F8E5M2FNUZ, DT_FLOAT8_E5M2FNUZ},
+      {xla::F4E2M1FN, DT_FLOAT4_E2M1FN},
       {xla::BF16, DT_BFLOAT16},
       {xla::F16, DT_HALF},
       {xla::F32, DT_FLOAT},
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index 215decdb4d8843..add79c369b69ef 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -39,7 +39,7 @@ class XlaCompilationAllocator : public Allocator {
   XlaCompilationAllocator() {}
   ~XlaCompilationAllocator() override {}
 
-  string Name() override { return "xla_compilation"; }
+  std::string Name() override { return "xla_compilation"; }
 
   void* AllocateRaw(size_t alignment, size_t num_bytes) override {
     // Regardless of the size requested, always allocates an XlaExpression.
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index 7ca32b83f158af..5ee45e499cb49e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -34,7 +34,7 @@ namespace tensorflow {
 
 namespace {
 
-int32 GetResultIndex(const int32* result_index_table, int32 num_results) {
+int32_t GetResultIndex(const int32_t* result_index_table, int32_t num_results) {
   auto it =
       std::min_element(result_index_table, result_index_table + num_results);
 
@@ -150,7 +150,7 @@ int LookupNameIndex(absl::string_view name, const char** names) {
 
 }  // namespace
 
-int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const {
+int XlaCompiledCpuFunction::LookupArgIndex(const std::string& name) const {
   return LookupNameIndex(name, arg_names_);
 }
 
@@ -162,7 +162,7 @@ int XlaCompiledCpuFunction::LookupVariableIndex(absl::string_view name) const {
   return num_args_ - num_variables_ + index;
 }
 
-int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const {
+int XlaCompiledCpuFunction::LookupResultIndex(const std::string& name) const {
   return LookupNameIndex(name, result_names_);
 }
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index 3d5bff87b3570f..061982db6fd08f 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -28,6 +28,7 @@ limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/strings/string_view.h"
+#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h"
 #include "xla/backends/cpu/alignment.h"
 #include "xla/backends/cpu/buffer_allocation_info.h"
 #include "xla/backends/cpu/runtime/rng_state_lib.h"
@@ -128,14 +129,14 @@ class XlaCompiledCpuFunction {
 
   // Result parameter i is described by
   // buffer_infos[result_index_table[i]].
-  const int32* result_index_table_ = nullptr;
+  const int32_t* result_index_table_ = nullptr;
 
   // There are num_results result parameters.
   int64_t num_results_ = 0;
 
   // Entry parameter i is described by
   // buffer_infos[arg_index_table[i]].
-  const int32* arg_index_table_ = nullptr;
+  const int32_t* arg_index_table_ = nullptr;
 
   // There are num_args entry parameters.
   int64_t num_args_ = 0;
@@ -209,7 +210,7 @@ class XlaCompiledCpuFunction {
   // TODO(fschneider): For now this always returns an empty string because there
   // is no support for error reporting in XLA. Remove this once all callers are
   // updated.
-  string error_msg() const { return error_msg_; }
+  std::string error_msg() const { return error_msg_; }
 
   void set_error_msg(absl::string_view error_msg) { error_msg_ = error_msg; }
 
@@ -302,7 +303,7 @@ class XlaCompiledCpuFunction {
   // The index remains constant for every instance of XlaCompiledCpuFunction
   // generated from the same static data, and might not be cheap to determine.
   // Recommended usage is to capture this in a variable for re-use.
-  int LookupArgIndex(const string& name) const;
+  int LookupArgIndex(const std::string& name) const;
 
   // Returns the 0-based index for the variable with the given `name`.
   // Returns -1 if the name wasn't found, or data isn't available.
@@ -318,7 +319,7 @@ class XlaCompiledCpuFunction {
   // The index remains constant for every instance of XlaCompiledCpuFunction
   // generated from the same static data, and might not be cheap to determine.
   // Recommended usage is to capture this in a variable for re-use.
-  int LookupResultIndex(const string& name) const;
+  int LookupResultIndex(const std::string& name) const;
 
   // Returns the name of the argument at `index`.
   // Returns nullptr if `HasNameIndices() == false` or `index` is out of range.
@@ -364,7 +365,7 @@ class XlaCompiledCpuFunction {
     return buffer_infos_;
   }
 
-  int32 num_buffers() const { return num_buffers_; }
+  int32_t num_buffers() const { return num_buffers_; }
 
   void** buffer_table() const { return buffer_table_; }
 
@@ -423,7 +424,7 @@ class XlaCompiledCpuFunction {
   }
 
   static void set_static_data_result_index_table(
-      StaticData* static_data, const int32* result_index_table) {
+      StaticData* static_data, const int32_t* result_index_table) {
     static_data->result_index_table_ = result_index_table;
   }
 
@@ -433,7 +434,7 @@ class XlaCompiledCpuFunction {
   }
 
   static void set_static_data_arg_index_table(StaticData* static_data,
-                                              const int32* arg_index_table) {
+                                              const int32_t* arg_index_table) {
     static_data->arg_index_table_ = arg_index_table;
   }
 
@@ -530,21 +531,21 @@ class XlaCompiledCpuFunction {
 
   // Describes the buffers used by the XLA computation.
   const xla::cpu::BufferAllocationInfo* const buffer_infos_;
-  const int32 num_buffers_;
+  const int32_t num_buffers_;
 
   // Indices of expanded result tuple.
-  const int32 num_results_;
-  const int32* const result_index_table_;
+  const int32_t num_results_;
+  const int32_t* const result_index_table_;
 
   // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]]
   // for XLA generated code to be able to find it.
-  const int32* const arg_index_table_;
+  const int32_t* const arg_index_table_;
 
   // The number of incoming arguments.
-  const int32 num_args_;
+  const int32_t num_args_;
 
   // The number of incoming variables.
-  const int32 num_variables_;
+  const int32_t num_variables_;
 
   // Shapes of the input arguments.
   const ShapeInfo* const arg_shape_infos_;
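The lookup methods touched above scan a flat name table linearly, which is why the surrounding comments recommend resolving an index once and reusing it. A hedged usage sketch, assuming an XlaCompiledCpuFunction instance fn and an argument actually named "x":

    // The index is stable across all instances generated from the same
    // StaticData, so look it up once.
    const int x_index = fn.LookupArgIndex("x");
    fn.set_arg_data(x_index, &x_value);
    fn.Run();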
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 9e761dc6003d80..5088badf28e9cb 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include
 #include
+#include
 #include
 #include
 #include
@@ -130,7 +131,7 @@ ComputeArgAndRetvalShardings(const Graph& graph) {
       [](const Node* n) -> absl::StatusOr<std::optional<xla::OpSharding>> {
         TF_ASSIGN_OR_RETURN(
             auto sharding,
-            ParseShardingFromDevice(*n, std::numeric_limits<int32>::max(),
+            ParseShardingFromDevice(*n, std::numeric_limits<int32_t>::max(),
                                     /*add_metadata=*/false));
         return sharding;
       };
@@ -173,7 +174,7 @@ absl::Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
   xla_context->Ref();
   absl::Status status;
   auto step_container = std::make_unique<ScopedStepContainer>(
-      step_id, [&status, device](const string& name) {
+      step_id, [&status, device](const std::string& name) {
        status = device->resource_manager()->Cleanup(name);
      });
   TF_RETURN_IF_ERROR(step_container->Create(device->resource_manager(),
@@ -484,8 +485,8 @@ absl::Status BuildComputation(
 
 }  // namespace
 
-string XlaCompiler::Argument::HumanString() const {
-  string common;
+std::string XlaCompiler::Argument::HumanString() const {
+  std::string common;
   if (!name.empty()) {
     common = absl::StrCat(" name=", name);
   }
@@ -503,7 +504,7 @@ string XlaCompiler::Argument::HumanString() const {
       return absl::StrCat("kind=constant-resource", common,
                           " value=", constant_value.DebugString());
     case kResource: {
-      string output = absl::StrCat(
+      std::string output = absl::StrCat(
           "kind=resource", common,
           " resource_kind=", XlaResource::KindToString(resource_kind),
           " initialized=", initialized, " is_fast_mem=", fast_mem);
@@ -543,7 +544,7 @@ XlaCompiler::Argument::DimensionSizesAsInlinedVector() const {
   }
 }
 
-string XlaCompiler::Argument::ShapeHumanString() const {
+std::string XlaCompiler::Argument::ShapeHumanString() const {
   if (absl::holds_alternative<TensorShape>(shape)) {
     return std::get<TensorShape>(shape).DebugString();
   } else {
@@ -592,9 +593,9 @@ XlaCompiler::~XlaCompiler() = default;
 
 int64_t XlaCompiler::NextStepId() { return next_step_id_++; }
 
-uint64 XlaCompiler::SignatureHash::operator()(
-    const std::pair<string, std::vector<Argument>>& signature) const {
-  return std::hash<string>()(signature.first);
+uint64_t XlaCompiler::SignatureHash::operator()(
+    const std::pair<std::string, std::vector<Argument>>& signature) const {
+  return std::hash<std::string>()(signature.first);
 }
 
 static absl::Status GetFunctionBody(const NameAttrList& function,
@@ -703,9 +704,9 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
                      flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
       .IgnoreError();
   auto node_name_index = graph->BuildNodeNameIndex();
-  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
+  std::unordered_map<std::string, std::vector<PartialTensorShape>> shape_map;
   for (const auto& node_shape_info : shape_info) {
-    const string& node_name = node_shape_info.first;
+    const std::string& node_name = node_shape_info.first;
     const std::vector<PartialTensorShape>& output_shapes = node_shape_info.second;
     const auto& node_iter = node_name_index.find(node_name);
     if (node_iter != node_name_index.end()) {
@@ -726,9 +727,9 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
                      flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
       .IgnoreError();
   auto node_name_index = graph->BuildNodeNameIndex();
-  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
+  std::unordered_map<std::string, std::vector<PartialTensorShape>> shape_map;
   for (const auto& node_shape_info : shape_info) {
-    const string& node_name = node_shape_info.first;
+    const std::string& node_name = node_shape_info.first;
     const std::vector<PartialTensorShape>& output_shapes = node_shape_info.second;
     const auto& node_iter = node_name_index.find(node_name);
     if (node_iter != node_name_index.end()) {
@@ -754,7 +755,7 @@ std::vector<Node*> GetValidControlRets(
   // the map with nodes in FunctionDef control_ret_nodes and later query it
   // using the nodes in `graph`. The Node pointers would be different but the
   // Node name is expected to remain the same between the two.
-  absl::flat_hash_map<string, int> control_ret_nodes_map;
+  absl::flat_hash_map<std::string, int> control_ret_nodes_map;
   for (int i = 0; i < orig_control_ret_nodes.size(); ++i) {
     const Node* n = orig_control_ret_nodes[i];
     control_ret_nodes_map[n->name()] = i;
@@ -814,7 +815,7 @@ absl::Status XlaCompiler::CompileFunction(
     const NameAttrList& fn_name_attrs,
     absl::Span<const XlaCompiler::Argument> args,
     XlaCompiler::CompilationResult* result) {
-  string function_id =
+  std::string function_id =
       Canonicalize(fn_name_attrs.name(), AttrSlice(&fn_name_attrs.attr()));
 
   VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
@@ -1325,7 +1326,7 @@ namespace {
 absl::Status ValidateFunctionDef(const FunctionDef* fdef,
                                  const FunctionLibraryDefinition& flib_def) {
   for (const NodeDef& node : fdef->node_def()) {
-    const string& op = node.op();
+    const std::string& op = node.op();
     if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
       continue;
     }
@@ -1340,7 +1341,8 @@ absl::Status ValidateFunctionDef(const FunctionDef* fdef,
 // Returned pointer points to the internal string either in node's attributes
 // or in its NodeDef. This pointer is valid as long as the node has not been
 // modified.
-absl::Status GetPotentialFunctionName(const Node& node, const string** name) {
+absl::Status GetPotentialFunctionName(const Node& node,
+                                      const std::string** name) {
   if (node.IsPartitionedCall()) {
     const AttrValue* attr_value;
     TF_RETURN_IF_ERROR(
@@ -1361,7 +1363,8 @@ absl::Status GetPotentialFunctionName(const Node& node, const string** name) {
 // given device_type, invalid data type, missing attributes...)
 absl::Status ValidateGraph(const Graph* graph,
                            const FunctionLibraryDefinition& flib_def,
-                           const DeviceType& device_type, const string& name) {
+                           const DeviceType& device_type,
+                           const std::string& name) {
   // Make sure the XLA compilation kernels are registered. This operation is
   // idempotent so it is fine if someone called it already.
   XlaOpRegistry::RegisterCompilationKernels();
@@ -1398,7 +1401,7 @@ absl::Status ValidateGraph(const Graph* graph,
     if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
       continue;
     }
-    const string* function_name;
+    const std::string* function_name;
     TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name));
     const FunctionDef* fdef = flib_def.Find(*function_name);
     absl::Status s;
@@ -1455,6 +1458,36 @@ class DummyStackTrace : public AbstractStackTrace {
 };
 
 namespace {
+const xla::HloInstructionProto* FindInstructionById(
+    const xla::HloComputationProto& computation, int64_t id) {
+  auto iter =
+      absl::c_find_if(computation.instructions(),
+                      [id](const xla::HloInstructionProto& instruction) {
+                        return instruction.id() == id;
+                      });
+  if (iter == computation.instructions().end()) {
+    return nullptr;
+  }
+  return &(*iter);
+}
+
+bool ShouldAddPrecisionToInstruction(
+    const xla::HloInstructionProto& instruction,
+    const xla::HloComputationProto& computation) {
+  static constexpr std::array<absl::string_view, 2> kOpsPossiblyUsingTF32 = {
+      "dot", "convolution"};
+  if (!absl::c_linear_search(kOpsPossiblyUsingTF32, instruction.opcode())) {
+    return false;
+  }
+  if (instruction.shape().element_type() == xla::F32) {
+    return true;
+  }
+  return absl::c_any_of(instruction.operand_ids(), [&](int64_t operand_id) {
+    const xla::HloInstructionProto* operand =
+        FindInstructionById(computation, operand_id);
+    return operand && operand->shape().element_type() == xla::F32;
+  });
+}
 
 // Add precisions configs to the HLO module to avoid TensorFloat32 computations
 // in XLA.
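The helper added above narrows the old blanket rule: rather than forcing HIGHEST precision on every dot and convolution, it asks whether F32 is involved at all, in the result or in any operand. The resulting decision table, sketched with hypothetical shapes:

    // dot(f32, f32)   -> f32  : HIGHEST  (TF32 would silently drop mantissa bits)
    // dot(bf16, bf16) -> bf16 : DEFAULT  (TF32 does not apply to bf16 math)
    // dot(bf16, f32)  -> bf16 : HIGHEST  (one operand is still F32)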
@@ -1462,13 +1495,7 @@ namespace {
 // Some operations, such as Einsum are converted through MlirXlaOpKernel, which
 // doesn't set the precisions, so we set them all here.
 //
-// TODO(tdanyluk): We may want to restrict this logic to only set the operand
-// precision for F32 operands. (Historically, it was set without regard to
-// operand type in other parts of TF2XLA.)
 void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) {
-  static constexpr std::array<absl::string_view, 2> kOpsPossiblyUsingTF32 = {
-      "dot", "convolution"};
-
   xla::PrecisionConfig precision_config;
   precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
   precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
@@ -1476,8 +1503,7 @@ void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) {
   for (xla::HloComputationProto& computation : *module.mutable_computations()) {
     for (xla::HloInstructionProto& instruction :
          *computation.mutable_instructions()) {
-      if (absl::c_find(kOpsPossiblyUsingTF32, instruction.opcode()) !=
-          kOpsPossiblyUsingTF32.end()) {
+      if (ShouldAddPrecisionToInstruction(instruction, computation)) {
         *instruction.mutable_precision_config() = precision_config;
       }
     }
@@ -1487,7 +1513,7 @@ void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) {
 }  // namespace
 
 absl::Status XlaCompiler::CompileGraph(
-    const XlaCompiler::CompileOptions& options, string const& name,
+    const XlaCompiler::CompileOptions& options, const std::string& name,
     std::unique_ptr<Graph> graph, absl::Span<const XlaCompiler::Argument> args,
     CompilationResult* result) {
   VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name;
@@ -1689,7 +1715,7 @@ xla::ChannelHandle XlaCompiler::NewChannel(
   return new_handle;
 }
 
-absl::Status XlaCompiler::GetChannelHandle(const string& key,
+absl::Status XlaCompiler::GetChannelHandle(const std::string& key,
                                            xla::ChannelHandle* channel) {
   auto result = channels_.emplace(key, xla::ChannelHandle());
   if (result.second) {
@@ -1701,7 +1727,7 @@ absl::Status XlaCompiler::GetChannelHandle(const string& key,
 }
 
 absl::Status XlaCompiler::GetHostToDeviceChannelHandle(
-    const string& key, xla::ChannelHandle* channel) {
+    const std::string& key, xla::ChannelHandle* channel) {
   auto result = channels_.emplace(key, xla::ChannelHandle());
   if (result.second) {
     result.first->second = NewChannel(xla::ChannelHandle::HOST_TO_DEVICE);
@@ -1712,7 +1738,7 @@ absl::Status XlaCompiler::GetHostToDeviceChannelHandle(
 }
 
 absl::Status XlaCompiler::GetDeviceToHostChannelHandle(
-    const string& key, xla::ChannelHandle* channel) {
+    const std::string& key, xla::ChannelHandle* channel) {
   auto result = channels_.emplace(key, xla::ChannelHandle());
   if (result.second) {
     result.first->second = NewChannel(xla::ChannelHandle::DEVICE_TO_HOST);
@@ -1724,7 +1750,7 @@ absl::Status XlaCompiler::GetDeviceToHostChannelHandle(
 
 namespace {
 
-void SetTransfer(const string& key, absl::Span<DataType> types,
+void SetTransfer(const std::string& key, absl::Span<DataType> types,
                  absl::Span<const TensorShape> shapes,
                  tf2xla::HostTransferMetadata* transfer) {
   transfer->set_key(key);
@@ -1739,7 +1765,7 @@ void SetTransfer(const string& key, absl::Span<DataType> types,
 }  // namespace
 
 absl::Status XlaCompiler::SetDeviceToHostMetadata(
-    const string& key, absl::Span<DataType> types,
+    const std::string& key, absl::Span<DataType> types,
     absl::Span<const TensorShape> shapes) {
   if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
     tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key];
@@ -1759,7 +1785,7 @@ absl::Status XlaCompiler::SetDeviceToHostMetadata(
 }
 
 absl::Status XlaCompiler::GetDeviceToHostShapes(
-    const string& key, std::vector<TensorShape>* shapes) const {
+    const std::string& key, std::vector<TensorShape>* shapes) const {
   const auto iter = host_compute_sends_.find(key);
   if (iter == host_compute_sends_.end()) {
     return errors::InvalidArgument(
@@ -1774,7 +1800,7 @@ absl::Status XlaCompiler::GetDeviceToHostShapes(
 }
 
 absl::Status XlaCompiler::SetHostToDeviceMetadata(
-    const string& key, absl::Span<DataType> types,
+    const std::string& key, absl::Span<DataType> types,
     absl::Span<const TensorShape> shapes) {
   if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
     tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key];
@@ -1794,7 +1820,7 @@ absl::Status XlaCompiler::SetHostToDeviceMetadata(
 }
 
 absl::Status XlaCompiler::GetHostComputeControlDependency(
-    const string& host_compute_name, xla::XlaOp* handle) {
+    const std::string& host_compute_name, xla::XlaOp* handle) {
   const auto iter = host_compute_control_output_.find(host_compute_name);
   if (iter == host_compute_control_output_.end()) {
     return errors::InvalidArgument(
@@ -1807,7 +1833,7 @@ absl::Status XlaCompiler::GetHostComputeControlDependency(
 }
 
 absl::Status XlaCompiler::SetHostComputeControlDependency(
-    const string& host_compute_name, const xla::XlaOp handle) {
+    const std::string& host_compute_name, const xla::XlaOp handle) {
   if (host_compute_control_output_.find(host_compute_name) !=
       host_compute_control_output_.end()) {
     return errors::InvalidArgument(
@@ -1819,7 +1845,7 @@ absl::Status XlaCompiler::SetHostComputeControlDependency(
 }
 
 void XlaCompiler::PushNodeTokenMapping() {
-  node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
+  node_token_mapping_stack_.emplace(std::map<std::string, xla::XlaOp>{});
 }
 
 absl::Status XlaCompiler::PopNodeTokenMapping() {
@@ -1832,7 +1858,7 @@ absl::Status XlaCompiler::PopNodeTokenMapping() {
   return absl::OkStatus();
 }
 
-absl::Status XlaCompiler::SetNodeToken(const string& node_name,
+absl::Status XlaCompiler::SetNodeToken(const std::string& node_name,
                                        const xla::XlaOp op) {
   if (node_token_mapping_stack_.empty()) {
     return errors::FailedPrecondition(
@@ -1847,7 +1873,8 @@ absl::Status XlaCompiler::SetNodeToken(const string& node_name,
   return absl::OkStatus();
 }
 
-absl::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
+absl::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(
+    const std::string& node_name) {
   if (node_token_mapping_stack_.empty()) {
     return errors::FailedPrecondition(
         "Calling GetNodeToken() when node_token_mapping_stack_ is "
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 2beb730eb06fa3..216125f9cb153e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -277,7 +277,8 @@ class XlaCompiler {
 
   // Compiles a tensorflow::Graph into an xla::XlaComputation.
   // Similar to CompileFunction, but takes a Graph as input rather than a
   // function.
-  absl::Status CompileGraph(const CompileOptions& options, string const& name,
+  absl::Status CompileGraph(const CompileOptions& options,
+                            const std::string& name,
                             std::unique_ptr<Graph> graph,
                             absl::Span<const Argument> args,
                             CompilationResult* result);
@@ -295,31 +296,32 @@ class XlaCompiler {
   // Channel handles can be used to communicate between different
   // computations. Computations that communicate should be compiled with the
   // same XlaCompiler.
-  absl::Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
+  absl::Status GetChannelHandle(const std::string& key,
+                                xla::ChannelHandle* channel);
 
   // Retrieves the host-to-device channel handle associated with `key`.
   // Allocates a new channel handle if none exists.
-  absl::Status GetHostToDeviceChannelHandle(const string& key,
+  absl::Status GetHostToDeviceChannelHandle(const std::string& key,
                                             xla::ChannelHandle* channel);
 
   // Retrieves the device-to-host channel handle associated with `key`.
   // Allocates a new channel handle if none exists.
-  absl::Status GetDeviceToHostChannelHandle(const string& key,
+  absl::Status GetDeviceToHostChannelHandle(const std::string& key,
                                             xla::ChannelHandle* channel);
 
   // Sets the shapes and types for the device to host transfer associated with
   // 'key'.
-  absl::Status SetDeviceToHostMetadata(const string& key,
+  absl::Status SetDeviceToHostMetadata(const std::string& key,
                                        absl::Span<DataType> types,
                                        absl::Span<const TensorShape> shapes);
 
   // Gets the shapes the device to host transfer associated with 'key'.
-  absl::Status GetDeviceToHostShapes(const string& key,
+  absl::Status GetDeviceToHostShapes(const std::string& key,
                                      std::vector<TensorShape>* shapes) const;
 
   // Sets the shapes and types for the host to device transfer associated with
   // 'key'.
-  absl::Status SetHostToDeviceMetadata(const string& key,
+  absl::Status SetHostToDeviceMetadata(const std::string& key,
                                        absl::Span<DataType> types,
                                        absl::Span<const TensorShape> shapes);
 
@@ -334,10 +336,10 @@ class XlaCompiler {
   // 'host_compute_name' can be any string the client wishes to use to identify
   // a given HostCompute Op as long as the names are unique within the
   // compilation.
-  absl::Status GetHostComputeControlDependency(const string& host_compute_name,
-                                               xla::XlaOp* handle);
-  absl::Status SetHostComputeControlDependency(const string& host_compute_name,
-                                               xla::XlaOp handle);
+  absl::Status GetHostComputeControlDependency(
+      const std::string& host_compute_name, xla::XlaOp* handle);
+  absl::Status SetHostComputeControlDependency(
+      const std::string& host_compute_name, xla::XlaOp handle);
 
   const Options& options() const { return options_; }
   xla::Client* client() const { return options_.client; }
@@ -345,8 +347,8 @@ class XlaCompiler {
 
   void PushNodeTokenMapping();
   absl::Status PopNodeTokenMapping();
-  absl::Status SetNodeToken(const string& node_name, xla::XlaOp op);
-  absl::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
+  absl::Status SetNodeToken(const std::string& node_name, xla::XlaOp op);
+  absl::StatusOr<xla::XlaOp> GetNodeToken(const std::string& node_name);
 
   // Sets the function body `fbody` to the one registered as `function`.
   absl::Status FindFunctionBody(const NameAttrList& function,
@@ -405,20 +407,22 @@ class XlaCompiler {
   FunctionLibraryRuntime* flib_runtime_;  // owned by pflr_.
 
   struct SignatureHash {
-    uint64 operator()(
-        const std::pair<string, std::vector<Argument>>& signature) const;
+    uint64_t operator()(
+        const std::pair<std::string, std::vector<Argument>>& signature) const;
   };
 
-  std::unordered_map<std::pair<string, std::vector<Argument>>,
+  std::unordered_map<std::pair<std::string, std::vector<Argument>>,
                      CompilationResult, SignatureHash>
       cache_;
 
-  std::unordered_map<string, xla::ChannelHandle> channels_;
+  std::unordered_map<std::string, xla::ChannelHandle> channels_;
 
-  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
-  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
+  std::unordered_map<std::string, tf2xla::HostTransferMetadata>
+      host_compute_sends_;
+  std::unordered_map<std::string, tf2xla::HostTransferMetadata>
+      host_compute_recvs_;
 
-  std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
+  std::unordered_map<std::string, xla::XlaOp> host_compute_control_output_;
 
   // This is used to store <node name, token output> mapping. Side-effecting
   // ops call SetNodeToken() to record its token output, so later side-effecting
@@ -427,7 +431,7 @@ class XlaCompiler {
   // It's a stack because we need a mapping like this for each level of nested
   // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
   // stack, and pop the mapping before returning.
-  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
+  std::stack<std::map<std::string, xla::XlaOp>> node_token_mapping_stack_;
 
   XlaCompiler(const XlaCompiler&) = delete;
   void operator=(const XlaCompiler&) = delete;
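A note on SignatureHash in the header diff above: it hashes only the function-name half of the key. That is still a valid hash for an unordered_map, since the container compares complete keys on collision; signatures sharing a name merely share a bucket. A minimal sketch of the idea with a simplified key type:

    #include <functional>
    #include <string>
    #include <unordered_map>
    #include <utility>

    struct NameOnlyHash {
      size_t operator()(const std::pair<std::string, int>& key) const {
        return std::hash<std::string>()(key.first);  // coarse but correct
      }
    };

    // Keys that share a name collide harmlessly; std::equal_to on the full
    // pair keeps distinct (name, args) entries separate.
    std::unordered_map<std::pair<std::string, int>, int, NameOnlyHash> cache;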
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 7d8a4f2c431e80..a29094470b911f 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -140,7 +140,7 @@ namespace {
 // compiled kernels.
 class DummyResourceForTest : public ResourceBase {
  public:
-  string DebugString() const override { return "dummy"; }
+  std::string DebugString() const override { return "dummy"; }
   void Increment() { ++value_; }
   int Get() { return value_; }
@@ -268,8 +268,8 @@ TEST_F(XlaCompilerTest, Simple) {
                                      std::move(graph), args, &result));
 
   // Tests that the generated computation works.
-  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
-  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32_t>({7, 42});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32_t>({-3, 101});
   std::unique_ptr<xla::GlobalData> param0_data =
       client_->TransferToServer(param0_literal).value();
   std::unique_ptr<xla::GlobalData> param1_data =
@@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, Simple) {
           .value();
   xla::Literal actual_literal = client_->Transfer(*actual).value();
 
-  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32_t>({4, 143});
   xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }
@@ -366,8 +366,8 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) {
                                      args, &result));
 
   // Tests that the generated computation works.
-  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
-  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32_t>({7, 42});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32_t>({-3, 101});
   std::unique_ptr<xla::GlobalData> param0_data =
       client_->TransferToServer(param0_literal).value();
   std::unique_ptr<xla::GlobalData> param1_data =
@@ -484,7 +484,7 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
   auto read = ops::ReadVariableOp(
       scope.WithControlDependencies(std::vector<Operation>{write}), var,
       DT_INT32);
-  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+  auto read_plus_one = ops::Add(scope, read, ops::Const<int32_t>(scope, 1));
   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   TF_ASSERT_OK(scope.ToGraph(graph.get()));
@@ -602,7 +602,7 @@ TEST_F(XlaCompilerTest, MixedOrderArguments) {
   auto read = ops::ReadVariableOp(
       scope.WithControlDependencies(std::vector<Operation>{write}), var,
       DT_INT32);
-  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+  auto read_plus_one = ops::Add(scope, read, ops::Const<int32_t>(scope, 1));
   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   TF_ASSERT_OK(scope.ToGraph(graph.get()));
@@ -680,7 +680,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
   // func(a) { b=7; c=-a; return b, c; }
   Scope scope = Scope::NewRootScope().ExitOnError();
   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
-  auto b = ops::Const<int32>(scope.WithOpName("B"), 7);
+  auto b = ops::Const<int32_t>(scope.WithOpName("B"), 7);
   auto c = ops::Neg(scope.WithOpName("C"), a);
   auto d = ops::_Retval(scope.WithOpName("D"), b, 0);
   auto e = ops::_Retval(scope.WithOpName("E"), c, 1);
@@ -710,7 +710,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
   EXPECT_FALSE(result.outputs[1].is_constant);
 
   // Tests that the generated computation works.
-  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32_t>({7, 42});
   std::unique_ptr<xla::GlobalData> param0_data =
       client_->TransferToServer(param0_literal).value();
 
@@ -718,8 +718,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
       client_->Execute(*result.computation, {param0_data.get()}).value();
   xla::Literal actual_literal = client_->Transfer(*actual).value();
 
-  xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
-  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+  xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32_t>(7);
+  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32_t>({-7, -42});
   xla::Literal expected = xla::LiteralUtil::MakeTuple({&expected0, &expected1});
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
@@ -885,7 +885,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
       // The names of instructions were uniquified by the XlaBuilder and the
       // unique ids may be different, the rest of the fields should be
       // identical.
-      string str1, str2;
+      std::string str1, str2;
       LOG(INFO) << "instr1 = " << instr1.DebugString();
       LOG(INFO) << "instr2 = " << instr2.DebugString();
       instr1.AppendPartialToString(&str1);
@@ -904,7 +904,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
   auto flow = ops::Const<float>(scope, {});
   auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
   auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2");
-  auto index = ops::Const<int32>(scope, 1);
+  auto index = ops::Const<int32_t>(scope, 1);
   auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index,
                                      grad2.flow_out);
   auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32);
@@ -933,12 +933,12 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
   const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
   EXPECT_EQ(0, update.input_index);
   EXPECT_EQ(DT_INT32, update.type);
-  EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
+  EXPECT_EQ((std::set<std::string>{"grad1", "grad2"}),
             update.tensor_array_gradients_accessed);
 
   // Tests that the generated computation works.
-  xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
-  xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal input_base = xla::LiteralUtil::CreateR1<int32_t>({7, 42});
+  xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32_t>({-3, 101});
   xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
   std::unique_ptr<xla::GlobalData> param0_data =
       client_->TransferToServer(input).value();
@@ -947,10 +947,10 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
       client_->Execute(*result.computation, {param0_data.get()}).value();
   xla::Literal actual_literal = client_->Transfer(*actual).value();
 
-  xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
-  xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
-  xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
-  xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal output_read = xla::LiteralUtil::CreateR0<int32_t>(42);
+  xla::Literal output_base = xla::LiteralUtil::CreateR1<int32_t>({7, 42});
+  xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32_t>({0, 1});
+  xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32_t>({-3, 101});
   xla::Literal output_resource =
       xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
   xla::Literal expected_literal =
@@ -964,7 +964,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
   auto flow = ops::Const<float>(scope, {});
   auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
-  auto index = ops::Const<int32>(scope, 1);
+  auto index = ops::Const<int32_t>(scope, 1);
   auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
   auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
@@ -996,7 +996,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
   auto flow = ops::Const<float>(scope, {});
   auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
-  auto index = ops::Const<int32>(scope, 1);
+  auto index = ops::Const<int32_t>(scope, 1);
   auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
   auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
@@ -1067,8 +1067,8 @@ TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
 
   Scope scope = Scope::NewRootScope().ExitOnError();
-  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
-  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
+  auto value = ops::Const<int32_t>(scope.WithOpName("value"), 1, {});
+  auto shape = ops::Const<int32_t>(scope.WithOpName("shape"), {5}, {1});
   TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
 
   NodeDef def;
@@ -1151,9 +1151,9 @@ TEST_F(XlaCompilerTest, SliceWithDynamicBegins) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
 
   Scope scope = Scope::NewRootScope().ExitOnError();
-  auto value = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
+  auto value = ops::Const<int32_t>(scope.WithOpName("shape"), {5}, {1});
   auto begin = ops::_Arg(scope.WithOpName("arg"), DT_INT32, 0);
-  auto size = ops::Const<int32>(scope.WithOpName("value"), {1}, {1});
+  auto size = ops::Const<int32_t>(scope.WithOpName("value"), {1}, {1});
 
   TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
 
@@ -1188,8 +1188,8 @@ TEST_F(XlaCompilerTest, SliceWithDynamicBegins) {
 
 void RunAndCheckVariablesComputation(
     xla::Client* client, const XlaCompiler::CompilationResult& result) {
-  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
-  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32_t>({7, 42});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32_t>({-3, 101});
   std::unique_ptr<xla::GlobalData> param0_data =
       client->TransferToServer(param0_literal).value();
   std::unique_ptr<xla::GlobalData> param1_data =
@@ -1201,8 +1201,8 @@ void RunAndCheckVariablesComputation(
           .value();
   xla::Literal actual_literal = client->Transfer(*actual).value();
 
-  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
-  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32_t>({5, 144});
+  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32_t>({4, 143});
   xla::Literal expected_literal =
       xla::LiteralUtil::MakeTuple({&expected0, &expected1});
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
@@ -1220,7 +1220,7 @@ TEST_F(XlaCompilerTest, Variables) {
   auto read = ops::ReadVariableOp(
       scope.WithControlDependencies(std::vector<Operation>{write}), var,
       DT_INT32);
-  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+  auto read_plus_one = ops::Add(scope, read, ops::Const<int32_t>(scope, 1));
   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   TF_ASSERT_OK(scope.ToGraph(graph.get()));
@@ -1356,7 +1356,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
                                      std::move(graph), args, &result));
 
   // Tests that the generated computation works.
-  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32_t>({-3, 101});
   std::unique_ptr<xla::GlobalData> param1_data =
       client_->TransferToServer(param1_literal).value();
@@ -1379,7 +1379,7 @@ TEST_F(XlaCompilerTest, ReturnResourceHandle) {
   auto read = ops::ReadVariableOp(
       scope.WithControlDependencies(std::vector<Operation>{write}), var,
       DT_INT32);
-  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+  auto read_plus_one = ops::Add(scope, read, ops::Const<int32_t>(scope, 1));
   auto r = ops::_Retval(scope.WithOpName("R"), var, 0);
   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1);
@@ -1414,7 +1414,7 @@ absl::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
   auto read = ops::ReadVariableOp(
       scope.WithControlDependencies(std::vector<Operation>{write}), var,
       DT_INT32);
-  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+  auto read_plus_one = ops::Add(scope, read, ops::Const<int32_t>(scope, 1));
   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
@@ -1475,9 +1475,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
   // Tests that the generated computation works.
   xla::Literal param0_literal =
-      xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
+      xla::LiteralUtil::CreateR2<int32_t>({{4, 55}, {1, -3}});
   xla::Literal param1_literal =
-      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
+      xla::LiteralUtil::CreateR1<int32_t>({22, 11, 33, 404});
   std::unique_ptr<xla::GlobalData> param0_data =
       client_->TransferToServer(param0_literal).value();
   std::unique_ptr<xla::GlobalData> param1_data =
@@ -1490,8 +1490,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
   xla::Literal actual_literal = client_->Transfer(*actual).value();

   xla::Literal expected0 =
-      xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
-  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+      xla::LiteralUtil::CreateR2<int32_t>({{27, 67}, {35, 402}});
+  xla::Literal expected1 =
+      xla::LiteralUtil::CreateR1<int32_t>({26, 66, 34, 401});
   xla::Literal expected_literal =
       xla::LiteralUtil::MakeTuple({&expected0, &expected1});
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
@@ -1547,9 +1548,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
   // Tests that the generated computation works.
   xla::Literal param0_literal =
-      xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
+      xla::LiteralUtil::CreateR1<int32_t>({4, 55, 1, -3});
   xla::Literal param1_literal =
-      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
+      xla::LiteralUtil::CreateR1<int32_t>({22, 11, 33, 404});
   std::unique_ptr<xla::GlobalData> param0_data =
       client_->TransferToServer(param0_literal).value();
   std::unique_ptr<xla::GlobalData> param1_data =
@@ -1561,8 +1562,10 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
           .value();
   xla::Literal actual_literal = client_->Transfer(*actual).value();

-  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
-  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+  xla::Literal expected0 =
+      xla::LiteralUtil::CreateR1<int32_t>({27, 67, 35, 402});
+  xla::Literal expected1 =
+      xla::LiteralUtil::CreateR1<int32_t>({26, 66, 34, 401});
   xla::Literal expected_literal =
       xla::LiteralUtil::MakeTuple({&expected0, &expected1});
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
@@ -1587,8 +1590,8 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   Scope scope = Scope::NewRootScope().ExitOnError();
-  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
-  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
+  auto value = ops::Const<int32_t>(scope.WithOpName("value"), 1, {});
+  auto shape = ops::Const<int32_t>(scope.WithOpName("shape"), {5}, {1});
   TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));

   NodeDef def;
@@ -1684,7 +1687,8 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) {
   side_effecting_op.set_name("DummySideEffectingOp");
   side_effecting_op.set_op("DummySideEffectingOp");
   AddNodeAttr(kXlaTokenInputNodesAttrName,
-              std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
+              std::vector<std::string>{kXlaTokenArgNodeName},
+              &side_effecting_op);
   AddNodeAttr(kXlaOriginalOutsideCompilationNodeName, side_effecting_op.name(),
               &side_effecting_op);
   absl::Status status;
@@ -1768,8 +1772,8 @@ TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
   }
   Scope scope = Scope::NewRootScope().ExitOnError();
-  auto element_shape = ops::Const<int32>(scope, {1}, {1});
-  auto max_elements = ops::Const<int32>(scope, {10}, {});
+  auto element_shape = ops::Const<int32_t>(scope, {1}, {1});
+  auto max_elements = ops::Const<int32_t>(scope, {10}, {});
   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
   std::initializer_list<Input> out = {arg, arg};
   auto add_n = ops::AddN(scope, out);
@@ -1822,7 +1826,7 @@ TEST_F(XlaCompilerTest, WhileWithResources) {
   auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
   auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
   auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2);
-  auto less = ops::Less(scope, arg0, ops::Const<int32>(scope, 10));
+  auto less = ops::Less(scope, arg0, ops::Const<int32_t>(scope, 10));
   (void)ops::_Retval(scope.WithOpName("ret"), less, 0);
   TF_ASSERT_OK(scope.ToGraph(graph.get()));
   FunctionDef fdef;
@@ -1899,9 +1903,9 @@ TEST_F(XlaCompilerTest, WhileWithResources) {
   ASSERT_EQ(output2.input_index, 2);

   // Tests that the generated computation works.
-  xla::Literal literal0 = xla::LiteralUtil::CreateR0<int32>(0);
-  xla::Literal literal1 = xla::LiteralUtil::CreateR0<int32>(2);
-  xla::Literal literal2 = xla::LiteralUtil::CreateR0<int32>(1);
+  xla::Literal literal0 = xla::LiteralUtil::CreateR0<int32_t>(0);
+  xla::Literal literal1 = xla::LiteralUtil::CreateR0<int32_t>(2);
+  xla::Literal literal2 = xla::LiteralUtil::CreateR0<int32_t>(1);
   std::unique_ptr<xla::GlobalData> data0 =
       client_->TransferToServer(literal0).value();
   std::unique_ptr<xla::GlobalData> data1 =
@@ -1916,9 +1920,9 @@ TEST_F(XlaCompilerTest, WhileWithResources) {
           .value();
   xla::Literal actual_literal = client_->Transfer(*actual).value();

-  xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(10);
-  xla::Literal expected1 = xla::LiteralUtil::CreateR0<int32>(2);
-  xla::Literal expected2 = xla::LiteralUtil::CreateR0<int32>(1);
+  xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32_t>(10);
+  xla::Literal expected1 = xla::LiteralUtil::CreateR0<int32_t>(2);
+  xla::Literal expected2 = xla::LiteralUtil::CreateR0<int32_t>(1);
   xla::Literal expected_literal =
       xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
@@ -1978,7 +1982,7 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {

 TEST_F(XlaCompilerTest, AliasResourceUpdates) {
   Scope scope = Scope::NewRootScope().ExitOnError();
-  auto a = ops::Const<int32>(scope.WithOpName("A"), {1, 2});
+  auto a = ops::Const<int32_t>(scope.WithOpName("A"), {1, 2});
   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
   auto write = ops::AssignAddVariableOp(scope, var, a);
   auto read = ops::ReadVariableOp(
@@ -2022,7 +2026,7 @@ TEST_F(XlaCompilerTest, AliasResourceUpdates) {
 TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) {
   XlaCompiler compiler(DefaultOptions());

-  const string& key = "comm_key";
+  const std::string& key = "comm_key";
   std::vector<DataType> types{DT_INT32};
   std::vector<TensorShape> shapes{TensorShape({2})};
@@ -2035,7 +2039,7 @@ TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) {
 TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) {
   XlaCompiler compiler(DefaultOptions());

-  const string& key = "comm_key";
+  const std::string& key = "comm_key";
   std::vector<DataType> types{DT_INT32};
   std::vector<TensorShape> shapes{TensorShape({2})};
   std::vector<DataType> types2{DT_FLOAT};
@@ -2051,7 +2055,7 @@ TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) {
 TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) {
   XlaCompiler compiler(DefaultOptions());

-  const string& key = "comm_key";
+  const std::string& key = "comm_key";
   std::vector<DataType> types{DT_INT32};
   std::vector<TensorShape> shapes{TensorShape({2})};
@@ -2064,7 +2068,7 @@ TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) {
 TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) {
   XlaCompiler compiler(DefaultOptions());

-  const string& key = "comm_key";
+  const std::string& key = "comm_key";
   std::vector<DataType> types{DT_INT32};
   std::vector<TensorShape> shapes{TensorShape({2})};
   std::vector<DataType> types2{DT_FLOAT};
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 92ddf0125aded1..fad607b1ae1333 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -67,7 +67,7 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
   }
 }

-string XlaContext::DebugString() const { return "XLA JIT context"; }
+std::string XlaContext::DebugString() const { return "XLA JIT context"; }

 void XlaContext::SetRetval(int index, const XlaExpression& expression) {
   const int64_t retvals_size = retvals_.size();
@@ -84,7 +84,7 @@ XlaResource* XlaContext::AddResource(std::unique_ptr<XlaResource> resource) {

 const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
   return LookupOrCreate(type, &max_func_, [type] {
-    const string type_string = DataTypeString(type);
+    const std::string type_string = DataTypeString(type);
     VLOG(1) << "Building Max() for " << type_string;
     xla::XlaBuilder b("max<" + type_string + ">");
     xla::PrimitiveType xla_type;
@@ -100,7 +100,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {

 const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {
   return LookupOrCreate(type, &min_func_, [type] {
-    const string type_string = DataTypeString(type);
+    const std::string type_string = DataTypeString(type);
     VLOG(1) << "Building Min() for " << type_string;
     xla::XlaBuilder b("min<" + type_string + ">");
     xla::PrimitiveType xla_type;
@@ -116,7 +116,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {

 const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
   return LookupOrCreate(type, &add_func_, [type] {
-    const string type_string = DataTypeString(type);
+    const std::string type_string = DataTypeString(type);
     VLOG(1) << "Building Add() for " << type_string;
     xla::XlaBuilder b("add<" + type_string + ">");
     xla::PrimitiveType xla_type;
@@ -133,7 +133,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
 const xla::XlaComputation* XlaContext::GetOrCreateLogAddExp(
     const DataType type) {
   return LookupOrCreate(type, &log_add_exp_func_, [type] {
-    const string type_string = DataTypeString(type);
+    const std::string type_string = DataTypeString(type);
     VLOG(1) << "Building LogAddExp() for " << type_string;
     xla::XlaBuilder b("log_add_exp<" + type_string + ">");
     xla::PrimitiveType xla_type;
@@ -154,7 +154,7 @@ const xla::XlaComputation* XlaContext::GetOrCreateLogAddExp(

 const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) {
   return LookupOrCreate(type, &mul_func_, [type] {
-    const string type_string = DataTypeString(type);
+    const std::string type_string = DataTypeString(type);
     VLOG(1) << "Building Mul() for " << type_string;
     xla::XlaBuilder b("mul<" + type_string + ">");
     xla::PrimitiveType xla_type;
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 9184fb4300633c..1d72f0c756f364 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -50,7 +50,7 @@ class XlaContext : public ResourceBase {
              const Graph* graph);

   // Virtual method defined by ResourceBase.
-  string DebugString() const override;
+  std::string DebugString() const override;

   XlaCompiler* compiler() const { return compiler_; }
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index 61bd10e413ccf3..e867dd14209ab8 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -73,7 +73,7 @@ XlaExpression XlaExpression::Resource(XlaResource* resource) {
   return e;
 }

-string XlaExpression::HumanString() const {
+std::string XlaExpression::HumanString() const {
   switch (kind_) {
     case Kind::kInvalid:
       return "invalid";
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
index d410b79a3da137..ed0041fc9942a0 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.h
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -115,7 +115,7 @@ class XlaExpression {
   XlaResource* resource() const { return resource_; }

   // Returns a human-readable summary of the expression.
-  string HumanString() const;
+  std::string HumanString() const;

   // Returns the value of a kValue or kXlaOp as an xla::XlaOp. Returns
   // an erroneous XlaOp if the expression is not a constant or an expression.
diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc
index 7a0cc34de9af2e..797002476aeb1c 100644
--- a/tensorflow/compiler/tf2xla/xla_expression_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc
@@ -38,14 +38,15 @@ class XlaExpressionTest : public ::testing::Test {
   void SetUp() override {
     client_ = xla::ClientLibrary::LocalClientOrDie();
     builder_ = std::make_unique<xla::XlaBuilder>("acomputation");
-    constant_ = test::AsScalar<int32>(42);
-    op_ = xla::ConstantR0<int32>(builder_.get(), 7);
+    constant_ = test::AsScalar<int32_t>(42);
+    op_ = xla::ConstantR0<int32_t>(builder_.get(), 7);
     non_constant_op_ = xla::Parameter(
         builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
     resource_ = std::make_unique<XlaResource>(
-        XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"),
-        DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1,
-        /*tensor_array_gradients=*/std::set<string>(),
+        XlaResource::kVariable, /*arg_num=*/0,
+        /*name=*/std::string("avariable"), DT_INT32, TensorShape({17, 3}), op_,
+        /*tensor_array_size=*/-1,
+        /*tensor_array_gradients=*/std::set<std::string>(),
         /*tensor_array_multiple_writes_aggregate=*/false);
   }
@@ -87,8 +88,8 @@ TEST_F(XlaExpressionTest, AsXlaOp) {
                           builder_->BuildConstantSubGraph(const_as_op));
   TF_ASSERT_OK_AND_ASSIGN(xla::Literal value,
                           client_->ComputeConstant(computation));
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0<int32>(42),
-                                          value));
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(
+      xla::LiteralUtil::CreateR0<int32_t>(42), value));
 }

 TEST_F(XlaExpressionTest, GetShape) {
@@ -120,7 +121,7 @@ TEST_F(XlaExpressionTest, ResolveConstant) {
       std::optional<Tensor> op_constant,
       XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_));
   ASSERT_TRUE(op_constant.has_value());
-  test::ExpectTensorEqual<int32>(test::AsScalar<int32>(7), *op_constant);
+  test::ExpectTensorEqual<int32_t>(test::AsScalar<int32_t>(7), *op_constant);

   TF_ASSERT_OK_AND_ASSIGN(std::optional<Tensor> op_nonconstant,
                           XlaExpression::XlaOp(non_constant_op_, DT_FLOAT)
@@ -131,7 +132,7 @@ TEST_F(XlaExpressionTest, ResolveConstant) {
       std::optional<Tensor> constant_constant,
       XlaExpression::Constant(constant_).ResolveConstant(client_));
   ASSERT_TRUE(constant_constant.has_value());
-  test::ExpectTensorEqual<int32>(constant_, *constant_constant);
+  test::ExpectTensorEqual<int32_t>(constant_, *constant_constant);
 }
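// ---------------------------------------------------------------------------
// A minimal sketch (not part of the patch itself) of the mechanical renaming
// these hunks apply throughout: TensorFlow's legacy `string` alias becomes
// `std::string`, and `int32` becomes the fixed-width `int32_t`, including
// inside template arguments. `Describe` below is a hypothetical function,
// shown only to illustrate the before/after shape of the migration.
//
//   // Before:
//   string Describe(const std::set<string>& names);
//   xla::Literal lit = xla::LiteralUtil::CreateR0<int32>(7);
//
//   // After:
//   std::string Describe(const std::set<std::string>& names);
//   xla::Literal lit = xla::LiteralUtil::CreateR0<int32_t>(7);
// ---------------------------------------------------------------------------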
 TEST_F(XlaExpressionTest, ResolveConstantOnResource) {
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index 38f01c83db8251..0b3425e5b8524a 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -136,7 +136,7 @@ struct XlaResourceUpdate {
   bool modified;

   // If the resource is a TensorArray, the set of gradients read or written.
-  std::set<string> tensor_array_gradients_accessed;
+  std::set<std::string> tensor_array_gradients_accessed;
 };

 struct XlaCompilationResult {
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 48e562ce5c7810..b374e8c8e81dd6 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -76,12 +76,12 @@ int CountResults(
 // tf2xla::{Feed,Fetch,Variable}. We hold the actual strings in nonempty_names,
 // and hold arrays of pointers in name_ptrs, terminated by a nullptr entry.
 template <typename T>
-void CollectNames(const T& entries, std::vector<string>* nonempty_names,
+void CollectNames(const T& entries, std::vector<std::string>* nonempty_names,
                   std::vector<const char*>* name_ptrs) {
   // First collect `nonempty_names`, to ensure the underlying strings won't
   // change out from under us.
   for (const auto& entry : entries) {
-    const string& name = entry.name();
+    const std::string& name = entry.name();
     if (!name.empty()) {
       nonempty_names->push_back(name);
     }
@@ -90,7 +90,7 @@ void CollectNames(const T& entries, std::vector<string>* nonempty_names,
   name_ptrs->reserve(entries.size() + 1);  // +1 for nullptr array terminator
   size_t nonempty_index = 0;
   for (const auto& entry : entries) {
-    const string& name = entry.name();
+    const std::string& name = entry.name();
     if (!name.empty()) {
       name_ptrs->push_back(nonempty_names->at(nonempty_index).c_str());
       ++nonempty_index;
@@ -158,9 +158,9 @@ XlaJitCompiledCpuFunction::Compile(
       xla::cpu::CreateBufferAllocationInfos(cpu_executable->module(),
                                             buffer_assignment);

-  std::vector<int32> arg_index_table =
+  std::vector<int32_t> arg_index_table =
       xla::cpu::CreateArgIndexTable(buffer_infos);
-  std::vector<int32> result_index_table =
+  std::vector<int32_t> result_index_table =
       xla::cpu::CreateResultIndexTable(buffer_infos);
   TF_ASSIGN_OR_RETURN(size_t result_index,
                       ComputeResultIndex(buffer_assignment));
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
index 0678c3be6c67f6..6f61f472a2fd5a 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
@@ -22,6 +22,7 @@ limitations under the License.

 #include "absl/container/flat_hash_map.h"
 #include "absl/log/check.h"
+#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h"
 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h"
 #include "xla/backends/cpu/buffer_allocation_info.h"
@@ -85,17 +86,17 @@ class XlaJitCompiledCpuFunction {
   std::vector buffer_infos_;

   // The backing array for the arg index table.
-  std::vector<int32> arg_index_table_;
+  std::vector<int32_t> arg_index_table_;

   // The backing array for the result index table.
-  std::vector<int32> result_index_table_;
+  std::vector<int32_t> result_index_table_;

   // The backing arrays of arg and result names. We hold the actual strings in
   // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static
   // data to refer to.
-  std::vector<string> nonempty_arg_names_;
-  std::vector<string> nonempty_variable_names_;
-  std::vector<string> nonempty_result_names_;
+  std::vector<std::string> nonempty_arg_names_;
+  std::vector<std::string> nonempty_variable_names_;
+  std::vector<std::string> nonempty_result_names_;
   std::vector<const char*> arg_names_;
   std::vector<const char*> variable_names_;
   std::vector<const char*> result_names_;
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
index acac1efd73881f..b49e699d6e267f 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
@@ -182,18 +182,18 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
   ASSERT_EQ(function.num_results(), 1);

   // Run the function and check results.
-  *static_cast<int32*>(function.arg_data(0)) = 10;
-  *static_cast<int32*>(function.arg_data(1)) = 32;
+  *static_cast<int32_t*>(function.arg_data(0)) = 10;
+  *static_cast<int32_t*>(function.arg_data(1)) = 32;
   EXPECT_TRUE(function.Run());
   EXPECT_EQ(function.error_msg(), "");
-  EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 42);
+  EXPECT_EQ(*static_cast<int32_t*>(function.result_data(0)), 42);

   // Run the function again.
-  *static_cast<int32*>(function.arg_data(0)) = 100;
-  *static_cast<int32*>(function.arg_data(1)) = 320;
+  *static_cast<int32_t*>(function.arg_data(0)) = 100;
+  *static_cast<int32_t*>(function.arg_data(1)) = 320;
   EXPECT_TRUE(function.Run());
   EXPECT_EQ(function.error_msg(), "");
-  EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 420);
+  EXPECT_EQ(*static_cast<int32_t*>(function.result_data(0)), 420);

   // Check name to index lookups.
   EXPECT_TRUE(function.HasNameIndices());
@@ -268,20 +268,20 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) {
   ASSERT_EQ(function.num_results(), 2);

   // Run the function and check results.
-  *static_cast<int32*>(function.arg_data(0)) = 10;
-  *static_cast<int32*>(function.arg_data(1)) = 32;
+  *static_cast<int32_t*>(function.arg_data(0)) = 10;
+  *static_cast<int32_t*>(function.arg_data(1)) = 32;
   EXPECT_TRUE(function.Run());
   EXPECT_EQ(function.error_msg(), "");
-  EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 10);
-  EXPECT_EQ(*static_cast<int32*>(function.result_data(1)), 42);
+  EXPECT_EQ(*static_cast<int32_t*>(function.result_data(0)), 10);
+  EXPECT_EQ(*static_cast<int32_t*>(function.result_data(1)), 42);

   // Run the function again.
-  *static_cast<int32*>(function.arg_data(0)) = 100;
-  *static_cast<int32*>(function.arg_data(1)) = 320;
+  *static_cast<int32_t*>(function.arg_data(0)) = 100;
+  *static_cast<int32_t*>(function.arg_data(1)) = 320;
   EXPECT_TRUE(function.Run());
   EXPECT_EQ(function.error_msg(), "");
-  EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 100);
-  EXPECT_EQ(*static_cast<int32*>(function.result_data(1)), 420);
+  EXPECT_EQ(*static_cast<int32_t*>(function.result_data(0)), 100);
+  EXPECT_EQ(*static_cast<int32_t*>(function.result_data(1)), 420);

   // Check name to index lookups.
   EXPECT_TRUE(function.HasNameIndices());
@@ -325,7 +325,7 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) {

     int VisibleDeviceCount() const override { return 0; }

-    const string& Name() const override { return name_; }
+    const std::string& Name() const override { return name_; }

     absl::StatusOr<std::unique_ptr<se::DeviceDescription>> DescriptionForDevice(
         int ordinal) const override {
@@ -338,7 +338,7 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) {
     }

    private:
-    string name_;
+    std::string name_;
   };

   TF_EXPECT_OK(
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 0456328617e8a8..baefe0138d43dd 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -207,9 +207,9 @@ static absl::Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
     return errors::InvalidArgument("value is not a scalar");
   }
   if (literal.shape().element_type() == xla::S16) {
-    *out = literal.Get<int16>({});
+    *out = literal.Get<int16_t>({});
   } else if (literal.shape().element_type() == xla::S32) {
-    *out = literal.Get<int32>({});
+    *out = literal.Get<int32_t>({});
   } else if (literal.shape().element_type() == xla::S64) {
     *out = literal.Get<int64_t>({});
   } else {
@@ -370,7 +370,7 @@ static absl::Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
   int64_t size = xla::ShapeUtil::ElementsIn(literal.shape());
   if (literal.shape().element_type() == xla::S32) {
     for (int64_t i = 0; i < size; ++i) {
-      out->push_back(literal.Get<int32>({i}));
+      out->push_back(literal.Get<int32_t>({i}));
     }
   } else if (literal.shape().element_type() == xla::S64) {
     for (int64_t i = 0; i < size; ++i) {
@@ -422,7 +422,7 @@ absl::Status XlaOpKernelContext::ConstantInputAsInt64Literal(
     case xla::S32: {
       *out = xla::Literal(
           xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
-      auto src_data = literal.data<int32>();
+      auto src_data = literal.data<int32_t>();
       for (int64_t i = 0; i < src_data.size(); ++i) {
         out->data<int64_t>()[i] = src_data[i];
       }
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 445065971f2a6a..c74db865769229 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -61,7 +61,7 @@ static absl::Status LaunchOpHasKernelForDevice(const DeviceType& device_type) {
   NodeDef node_def;
   node_def.set_name("_XlaLaunch-op");
   node_def.set_op("XlaLaunch");
-  string kernel_class_name;
+  std::string kernel_class_name;
   TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def,
                                    /*KernelDef*/ nullptr, &kernel_class_name));
   VLOG(1) << "LaunchOpHasKernelForDevice"
@@ -128,7 +128,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
 }

 /* static */ void XlaOpRegistry::RegisterCompilationDevice(
-    const string& device_name, const DeviceRegistration& registration) {
+    const std::string& device_name, const DeviceRegistration& registration) {
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
   auto result =
@@ -138,7 +138,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
 }

 /* static */ void XlaOpRegistry::RegisterBackend(
-    const string& compilation_device_name,
+    const std::string& compilation_device_name,
     absl::Span<const DataType> supported_types, BackendOpFilter op_filter) {
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
@@ -151,14 +151,14 @@ XlaOpRegistry::~XlaOpRegistry() = default;
 }

 /* static */ bool XlaOpRegistry::IsCompilationDevice(
-    const string& device_name) {
+    const std::string& device_name) {
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
   return registry.backends_.find(device_name) != registry.backends_.end();
 }

 /* static */ bool XlaOpRegistry::GetCompilationDevice(
-    const string& device_name, const DeviceRegistration** registration) {
+    const std::string& device_name, const DeviceRegistration** registration) {
   XlaOpRegistry& registry = Instance();

   // Lazily register the CPU and GPU JIT devices the first time
@@ -235,7 +235,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
   // 2. Process op registration without device allowlists:
   //    this pass registers the kernels for all the other supported backends.
   for (auto& ops : registry.ops_) {
-    const string& op_name = ops.first;
+    const std::string& op_name = ops.first;
     std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
     // Partition the op registration so that the ones with device allowlists
    // precede the one without device allowlist.
@@ -247,7 +247,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
     // Collect a set of backend registered by ops with device allowlists.
     // The op registration without allowlists will register a generic kernel
     // for all other backends not in this set.
-    std::unordered_set<string> allowlisted_backend;
+    std::unordered_set<std::string> allowlisted_backend;
     for (auto& op_registration : op_registrations) {
       if (op_registration->has_device_allowlist) {
         allowlisted_backend.insert(op_registration->device_allowlist.begin(),
@@ -267,7 +267,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
       }
       TF_CHECK_OK(lookup_status);

-      std::unordered_set<string> type_attrs;
+      std::unordered_set<std::string> type_attrs;
       for (const OpDef::AttrDef& attr_def : op_def->attr()) {
         if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
           type_attrs.insert(attr_def.name());
@@ -309,7 +309,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
       //    b) the types allowed by the OpDef, and
       //    c) the type constraints.
       bool unsatisfiable_type_constraint = false;
-      for (const string& type_attr : type_attrs) {
+      for (const std::string& type_attr : type_attrs) {
         KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
         attr_constraint->set_name(type_attr);
         auto* allowed_values =
@@ -375,7 +375,7 @@ void XlaOpRegistry::RegisterCompilationKernels() {
 }

 std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
-    const string& compilation_device_name,
+    const std::string& compilation_device_name,
     bool include_compilation_only_kernels) {
   // Ensure compilation kernels registered.
   RegisterCompilationKernels();
@@ -403,8 +403,8 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
   return kernels;
 }

-/*static*/ std::vector<string> XlaOpRegistry::GetAllRegisteredOps() {
-  std::vector<string> ops;
+/*static*/ std::vector<std::string> XlaOpRegistry::GetAllRegisteredOps() {
+  std::vector<std::string> ops;
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
   ops.reserve(registry.ops_.size());
@@ -416,7 +416,7 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
 }

 /*static*/ const std::unordered_set<string>*
-XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
+XlaOpRegistry::CompileTimeConstantInputArgNames(const std::string& op) {
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
   auto it = registry.ops_.find(op);
@@ -435,10 +435,10 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {

   DCHECK(op_def != nullptr || op_kernel != nullptr);

-  std::unordered_set<string> compile_time_constant_inputs_from_attr;
-  std::vector<string> compile_time_constant_inputs_vect_from_attr;
+  std::unordered_set<std::string> compile_time_constant_inputs_from_attr;
+  std::vector<std::string> compile_time_constant_inputs_vect_from_attr;

-  const std::unordered_set<string>* compile_time_constant_inputs;
+  const std::unordered_set<std::string>* compile_time_constant_inputs;

   if (TryGetNodeAttr(node_def, kXlaCompileTimeConstantInputsAttr,
                      &compile_time_constant_inputs_vect_from_attr)) {
@@ -459,7 +459,7 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
           << " required constants are: "
           << absl::StrJoin(*compile_time_constant_inputs, ", ");

-  for (const string& input : *compile_time_constant_inputs) {
+  for (const std::string& input : *compile_time_constant_inputs) {
     if (op_def) {
       NameRangeMap input_name_ranges;
       TF_RETURN_IF_ERROR(
@@ -486,7 +486,7 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
   return absl::OkStatus();
 }

-/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) {
+/*static*/ bool XlaOpRegistry::IsMetadataOp(const std::string& op) {
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
   auto it = registry.ops_.find(op);
@@ -500,8 +500,8 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) {
   return it->second.front()->is_metadata_op;
 }

-std::vector<string> XlaOpRegistry::BackendNames() {
-  std::vector<string> names;
+std::vector<std::string> XlaOpRegistry::BackendNames() {
+  std::vector<std::string> names;
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
   names.reserve(registry.backends_.size());
@@ -511,7 +511,7 @@ std::vector<string> XlaOpRegistry::BackendNames() {
   return names;
 }

-bool XlaOpRegistry::IsBackendRegistered(const string& name) {
+bool XlaOpRegistry::IsBackendRegistered(const std::string& name) {
   XlaOpRegistry& registry = Instance();
   mutex_lock lock(registry.mutex_);
   return registry.backends_.find(name) != registry.backends_.end();
@@ -524,7 +524,7 @@ XlaOpRegistry& XlaOpRegistry::Instance() {

 XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) {
   registration_.reset(new XlaOpRegistry::OpRegistration);
-  registration_->name = string(name);
+  registration_->name = std::string(name);
 }

 XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
@@ -572,7 +572,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowStringType() {
 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
     absl::string_view attr_name, DataType allowed) {
   std::set<DataType>& types =
-      registration_->type_constraints[string(attr_name)];
+      registration_->type_constraints[std::string(attr_name)];
   types.insert(allowed);
   return *this;
 }

@@ -580,7 +580,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
     absl::string_view attr_name, absl::Span<const DataType> allowed) {
   std::set<DataType>& types =
-      registration_->type_constraints[string(attr_name)];
+      registration_->type_constraints[std::string(attr_name)];
   for (DataType t : allowed) {
     types.insert(t);
   }
@@ -628,7 +628,7 @@ XlaBackendRegistrar::XlaBackendRegistrar(
     absl::string_view name, absl::Span<const DataType> types,
     XlaOpRegistry::BackendOpFilter op_filter) {
   XlaOpRegistry& registry = XlaOpRegistry::Instance();
-  registry.RegisterBackend(string(name), types, op_filter);
+  registry.RegisterBackend(std::string(name), types, op_filter);

   AddSymbolicExecutionDevice(name);
 }
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 5eaf0fb2d42bfa..9ce6e263f8feb4 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -139,7 +139,7 @@ class XlaOpRegistry {
   // Describes how to compile operators assigned to a device.
   struct DeviceRegistration {
     // The name of the an XLA compilation device to use to compile code.
-    string compilation_device_name;
+    std::string compilation_device_name;

     // When should we autocluster operators assigned to this device?
     AutoclusteringPolicy autoclustering_policy;
@@ -190,25 +190,25 @@ class XlaOpRegistry {
   // `backend_op_filter` should return true if the op should be registered on
   // the device; it may optionally modify the KernelDef.
   typedef bool (*BackendOpFilter)(KernelDef* kdef);
-  static void RegisterBackend(const string& compilation_device_name,
+  static void RegisterBackend(const std::string& compilation_device_name,
                               absl::Span<const DataType> supported_types,
                               BackendOpFilter op_filter);

   // Returns the names of the registered backends.
-  static std::vector<string> BackendNames();
+  static std::vector<std::string> BackendNames();

   // Returns true iff a backend with the given name is registered.
-  static bool IsBackendRegistered(const string& name);
+  static bool IsBackendRegistered(const std::string& name);

   // Registers `device_name` for XLA compilation, using information from
   // `registration`.
   // Does nothing if a registration for `device_name` already exists.
-  static void RegisterCompilationDevice(const string& device_name,
+  static void RegisterCompilationDevice(const std::string& device_name,
                                         const DeviceRegistration& registration);

   // Returns whether the device name is for the JIT device used exclusively for
   // TF2XLA conversion.
-  static bool IsCompilationDevice(const string& device_name);
+  static bool IsCompilationDevice(const std::string& device_name);

   // Returns the JIT device name associated with 'device_name', setting
   // 'jit_device_name', 'requires_jit', and 'enabled_jit_by_default', if they
@@ -216,7 +216,7 @@ class XlaOpRegistry {
   // JIT device is registered.
   // '*enable_jit_by_default' is set to true if we should try to JIT using this
   // device when the JIT is enabled via the Session OptimizerOptions.
-  static bool GetCompilationDevice(const string& device_name,
+  static bool GetCompilationDevice(const std::string& device_name,
                                    const DeviceRegistration** registration);

   // Registers all JIT kernels on JIT devices, if not already registered.
@@ -227,11 +227,11 @@ class XlaOpRegistry {
   // 'compilation_device_name'. Does not include kernels registered as
   // CompilationOnly, iff include_compilation_only_kernels=false.
   static std::vector<const KernelDef*> DeviceKernels(
-      const string& compilation_device_name,
+      const std::string& compilation_device_name,
       bool include_compilation_only_kernels);

   // Returns all operations for which there are XLA kernels on any device.
-  static std::vector<string> GetAllRegisteredOps();
+  static std::vector<std::string> GetAllRegisteredOps();

   // Returns (via `result`) the indices of inputs to `node_def` that must be
   // compile-time constants. Returns an empty vector if the op is not
@@ -265,11 +265,11 @@ class XlaOpRegistry {
   // Return names of arguments for a given op which are supposed to be
   // constants.
   static const std::unordered_set<string>*
-  CompileTimeConstantInputArgNames(const string& op);
+  CompileTimeConstantInputArgNames(const std::string& op);

   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
   // of its operands and not their values.
-  static bool IsMetadataOp(const string& op);
+  static bool IsMetadataOp(const std::string& op);

  private:
   friend class XlaBackendRegistrar;
@@ -298,15 +298,15 @@ class XlaOpRegistry {
   };

   // Map from compilation device names to a description of the backend.
-  std::unordered_map<string, Backend> backends_ TF_GUARDED_BY(mutex_);
+  std::unordered_map<std::string, Backend> backends_ TF_GUARDED_BY(mutex_);

   // Map from Tensorflow device names to the corresponding JIT device metadata.
-  std::unordered_map<string, DeviceRegistration> compilation_devices_
+  std::unordered_map<std::string, DeviceRegistration> compilation_devices_
       TF_GUARDED_BY(mutex_);

   // A description of a Tensorflow operator that can be compiled to XLA.
   struct OpRegistration {
-    string name;
+    std::string name;

     // Should this operator be registered only on compilation devices, without a
     // dummy kernel registered on the corresponding XLA device?
@@ -325,15 +325,15 @@ class XlaOpRegistry {
     bool allow_string_type = false;

     // Mapping from attribute name to a list of supported types.
-    std::unordered_map<string, std::set<DataType>> type_constraints;
+    std::unordered_map<std::string, std::set<DataType>> type_constraints;

     // An optional allowlist of devices. If there is no allowlist, all devices
     // are permitted.
     bool has_device_allowlist = false;
-    std::unordered_set<string> device_allowlist;
+    std::unordered_set<std::string> device_allowlist;

     // Names of arguments that must be compile-time constants.
-    std::unordered_set<string> compile_time_constant_inputs;
+    std::unordered_set<std::string> compile_time_constant_inputs;

     // True if this is a "metadata" op, one that only looks at the shapes of its
     // operands and not their values.
@@ -360,8 +360,8 @@ class XlaOpRegistry {
   // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
   // Registrations present under the same key must satisfy IsCompatible above,
   // and this is checked during registration.
-  std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
-      TF_GUARDED_BY(mutex_);
+  std::unordered_map<std::string, std::vector<std::unique_ptr<OpRegistration>>>
+      ops_ TF_GUARDED_BY(mutex_);

   // Have we already registered the JIT kernels on the JIT devices?
   bool jit_kernels_registered_ = false;
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index 5b894d07e121ba..962b0e473a826c 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -51,29 +51,29 @@ namespace tensorflow {
 }

 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
-    string name, DataType type, int64_t max_size) {
+    std::string name, DataType type, int64_t max_size) {
   return std::make_unique<XlaResource>(
       XlaResource::kStack, /*arg_num=*/-1, std::move(name), type,
       TensorShape(), /*initial_value=*/xla::XlaOp(),
       /*max_array_size=*/max_size,
-      /*tensor_array_gradients=*/std::set<string>{},
+      /*tensor_array_gradients=*/std::set<std::string>{},
       /*tensor_array_multiple_writes_aggregate=*/false);
 }

 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
-    string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
-    int64_t max_array_size) {
+    std::string name, DataType type, TensorShape shape,
+    xla::XlaOp initial_value, int64_t max_array_size) {
   return std::make_unique<XlaResource>(
       XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
       initial_value, max_array_size,
-      /*tensor_array_gradients=*/std::set<string>{},
+      /*tensor_array_gradients=*/std::set<std::string>{},
       /*tensor_array_multiple_writes_aggregate=*/false);
 }

 XlaResource::XlaResource(
-    Kind kind, int arg_num, string name, DataType type, TensorShape shape,
+    Kind kind, int arg_num, std::string name, DataType type, TensorShape shape,
     xla::XlaOp initial_value, int64_t max_array_size,
-    const std::set<string>& tensor_array_gradients,
+    const std::set<std::string>& tensor_array_gradients,
     bool tensor_array_multiple_writes_aggregate,
     const std::optional<ManagedStackTrace>& definition_stack_trace)
     : kind_(kind),
@@ -89,7 +89,7 @@ XlaResource::XlaResource(
       definition_stack_trace_(definition_stack_trace) {
   CHECK(kind_ != kInvalid);

-  for (const string& gradient : tensor_array_gradients) {
+  for (const std::string& gradient : tensor_array_gradients) {
     tensor_array_gradients_[gradient].reset(new XlaResource(
         /*kind=*/kTensorArray, /*arg_num=*/-1,
         /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
@@ -163,7 +163,7 @@ absl::Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
       value_ =
           xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
                                               ta_shape.dim_sizes()),
-                               xla::ConstantR0<int32>(builder, 0)});
+                               xla::ConstantR0<int32_t>(builder, 0)});
       break;
     }
@@ -175,7 +175,7 @@ absl::Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
 }

 absl::Status XlaResource::GetOrCreateTensorArrayGradient(
-    const string& source, xla::XlaBuilder* builder,
+    const std::string& source, xla::XlaBuilder* builder,
     XlaResource** gradient_out) {
   VLOG(2) << "Gradient lookup for resource: " << name_
           << " gradient: " << source;
@@ -214,9 +214,9 @@ absl::Status XlaResource::Pack(xla::XlaOp* pack,
   return absl::OkStatus();
 }

-absl::Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
-                                      const xla::XlaOp pack,
-                                      xla::XlaBuilder* builder) {
+absl::Status XlaResource::SetFromPack(
+    const std::set<std::string>& gradient_sources, const xla::XlaOp pack,
+    xla::XlaBuilder* builder) {
   if (gradient_sources.empty()) {
     if (!initialized()) {
       initial_value_ = pack;
diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h
index d4c8f7c1c9347f..07c826d21e8b3d 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.h
+++ b/tensorflow/compiler/tf2xla/xla_resource.h
@@ -43,18 +43,19 @@ class XlaResource {
   static absl::string_view KindToString(Kind kind);

   // Creates a new Stack resource.
-  static std::unique_ptr<XlaResource> CreateStack(string name, DataType type,
+  static std::unique_ptr<XlaResource> CreateStack(std::string name,
+                                                  DataType type,
                                                   int64_t max_size);

   // Creates a new TensorArray resource.
   static std::unique_ptr<XlaResource> CreateTensorArray(
-      string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
-      int64_t max_array_size);
+      std::string name, DataType type, TensorShape shape,
+      xla::XlaOp initial_value, int64_t max_array_size);

-  XlaResource(Kind kind, int arg_num, string name, DataType type,
+  XlaResource(Kind kind, int arg_num, std::string name, DataType type,
               TensorShape shape, xla::XlaOp initial_value,
               int64_t max_array_size,
-              const std::set<string>& tensor_array_gradients,
+              const std::set<std::string>& tensor_array_gradients,
               bool tensor_array_multiple_writes_aggregate,
               const std::optional<ManagedStackTrace>& definition_stack_trace =
                   std::nullopt);
@@ -72,7 +73,7 @@ class XlaResource {
   int arg_num() const { return arg_num_; }

   // A descriptive name for the resource, used in error messages.
-  const string& name() const { return name_; }
+  const std::string& name() const { return name_; }

   // Current type and value of the resource. Uninitialized resources are
   // represented by a default (zero) handle and type DT_INVALID.
@@ -121,7 +122,7 @@ class XlaResource {
   // exist. The call target must be an initialized TensorArray resource. A
   // TensorArray can have multiple named gradients; see the operator
   // documentation for TensorArrayGradV3 for details.
-  absl::Status GetOrCreateTensorArrayGradient(const string& source,
+  absl::Status GetOrCreateTensorArrayGradient(const std::string& source,
                                               xla::XlaBuilder* builder,
                                               XlaResource** gradient_out);

@@ -138,7 +139,7 @@ class XlaResource {
   // If `reset_initial_values` is true, sets the initial_values as well as the
   // values.
   // Opposite of Pack().
-  absl::Status SetFromPack(const std::set<string>& gradient_sources,
+  absl::Status SetFromPack(const std::set<std::string>& gradient_sources,
                            xla::XlaOp pack, xla::XlaBuilder* builder);

   bool IsOverwritten() { return is_overwritten_; }
@@ -164,15 +165,15 @@ class XlaResource {
   // string, irrespective of the number of calls to TensorArrayGrad. The map
   // is ordered since values are packed into tuples by Pack() sorted by name
   // order.
-  const std::map<string, std::unique_ptr<XlaResource>>& tensor_array_gradients()
-      const {
+  const std::map<std::string, std::unique_ptr<XlaResource>>&
+  tensor_array_gradients() const {
     return tensor_array_gradients_;
   }

 private:
   const Kind kind_;
   const int arg_num_;
-  const string name_;
+  const std::string name_;

   DataType type_;
   TensorShape shape_;
@@ -186,7 +187,7 @@ class XlaResource {
   int64_t max_array_size_ = -1;
   bool tensor_array_multiple_writes_aggregate_ = false;

-  std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
+  std::map<std::string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
   bool is_overwritten_ = false;

   std::optional<ManagedStackTrace> definition_stack_trace_;
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 561a4ca410c5e0..b76a4ffd8955b9 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1944,21 +1944,15 @@ tf_cc_tests(
 )

 tf_cc_tests(
-    name = "cell_reader_test",
+    name = "test_utils_test",
     size = "small",
     srcs = [
-        "//tensorflow/core/lib/monitoring:cell_reader_test.cc",
         "//tensorflow/core/lib/monitoring:test_utils_test.cc",
     ],
     deps = [
         ":protos_all_cc",
         ":test",
         ":test_main",
-        "//tensorflow/core/lib/monitoring:cell_reader",
-        "//tensorflow/core/lib/monitoring:counter",
-        "//tensorflow/core/lib/monitoring:gauge",
-        "//tensorflow/core/lib/monitoring:percentile_sampler",
-        "//tensorflow/core/lib/monitoring:sampler",
         "//tensorflow/core/lib/monitoring:test_utils",
         "//tensorflow/core/lib/monitoring:types",
         "//tensorflow/core/platform:errors",
diff --git a/tensorflow/core/activity_watcher/activity.h b/tensorflow/core/activity_watcher/activity.h
index eecd207a33fe27..fba51b43f8a3ce 100644
--- a/tensorflow/core/activity_watcher/activity.h
+++ b/tensorflow/core/activity_watcher/activity.h
@@ -32,7 +32,7 @@ namespace tensorflow {
 namespace activity_watcher {

-using ActivityId = tsl::uint64;
+using ActivityId = uint64_t;
 constexpr ActivityId kActivityNotRecorded = 0;
 constexpr int kWatcherDisabled = 0;

@@ -45,7 +45,7 @@ enum ActivityCategory {
   kRendezvous = 5,
 };

-static tsl::string ToString(ActivityCategory category) {
+static std::string ToString(ActivityCategory category) {
   switch (category) {
     case ActivityCategory::kCollective:
       return "Collective";
@@ -64,17 +64,17 @@ static tsl::string ToString(ActivityCategory category) {

 // An activity to be recorded.
 struct Activity {
-  using Attributes = absl::flat_hash_map<tsl::string, tsl::string>;
+  using Attributes = absl::flat_hash_map<std::string, std::string>;
   // A human readable title of the activity.
-  tsl::string title;
+  std::string title;
   // The category of the activity.
   ActivityCategory category = ActivityCategory::kMisc;
   // Key/value pairs that are attached to the activity.
   Attributes attributes;

   Activity() = default;
-  Activity(tsl::string title, ActivityCategory category)
+  Activity(std::string title, ActivityCategory category)
       : title(std::move(title)), category(category) {}
-  Activity(tsl::string title, ActivityCategory category, Attributes attributes)
+  Activity(std::string title, ActivityCategory category, Attributes attributes)
       : title(std::move(title)),
         category(category),
         attributes(std::move(attributes)) {}
diff --git a/tensorflow/core/activity_watcher/activity_utils.cc b/tensorflow/core/activity_watcher/activity_utils.cc
index b3631076c5c2d9..58b3909a25789c 100644
--- a/tensorflow/core/activity_watcher/activity_utils.cc
+++ b/tensorflow/core/activity_watcher/activity_utils.cc
@@ -28,7 +28,7 @@ namespace tensorflow {
 namespace activity_watcher {

 std::unique_ptr<Activity> ActivityFromContext(
-    OpKernelContext* context, tsl::string name, ActivityCategory category,
+    OpKernelContext* context, std::string name, ActivityCategory category,
     Activity::Attributes additional_attributes) {
   Activity::Attributes attributes(std::move(additional_attributes));
   if (context) {
diff --git a/tensorflow/core/activity_watcher/activity_utils.h b/tensorflow/core/activity_watcher/activity_utils.h
index 64958cd5e09744..749ef1326ae565 100644
--- a/tensorflow/core/activity_watcher/activity_utils.h
+++ b/tensorflow/core/activity_watcher/activity_utils.h
@@ -29,7 +29,7 @@ namespace activity_watcher {
 // A convenient way to create an activity. Writes OpKernelContext information
 // and given attributes to a new activity and returns.
 std::unique_ptr<Activity> ActivityFromContext(
-    OpKernelContext* context, tsl::string name, ActivityCategory category,
+    OpKernelContext* context, std::string name, ActivityCategory category,
     Activity::Attributes additional_attributes = Activity::Attributes());

 }  // namespace activity_watcher
diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD
index 76b8cc01324619..caf20c11b93566 100644
--- a/tensorflow/core/api_def/BUILD
+++ b/tensorflow/core/api_def/BUILD
@@ -65,6 +65,7 @@ cc_library(
         "//tensorflow/core:op_gen_lib",
         "//tensorflow/core:ops",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings:str_format",
     ],
 )
diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc
index 7f844e88ba90c6..3c954cf076ddc8 100644
--- a/tensorflow/core/api_def/api_test.cc
+++ b/tensorflow/core/api_def/api_test.cc
@@ -43,26 +43,27 @@ namespace {

 constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt";

-string DefaultApiDefDir() {
+std::string DefaultApiDefDir() {
   return GetDataDependencyFilepath(
       io::JoinPath("tensorflow", "core", "api_def", "base_api"));
 }

-string PythonApiDefDir() {
+std::string PythonApiDefDir() {
   return GetDataDependencyFilepath(
       io::JoinPath("tensorflow", "core", "api_def", "python_api"));
 }

 // Reads golden ApiDef files and returns a map from file name to ApiDef file
 // contents.
-void GetGoldenApiDefs(Env* env, const string& api_files_dir,
-                      std::unordered_map<string, ApiDef>* name_to_api_def) {
-  std::vector<string> matching_paths;
+void GetGoldenApiDefs(
+    Env* env, const std::string& api_files_dir,
+    std::unordered_map<std::string, ApiDef>* name_to_api_def) {
+  std::vector<std::string> matching_paths;
   TF_CHECK_OK(env->GetMatchingPaths(
       io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths));

   for (auto& file_path : matching_paths) {
-    string file_contents;
+    std::string file_contents;
     TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents));
     file_contents = PBTxtFromMultiline(file_contents);
@@ -76,8 +77,9 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir,
 }

 void TestAllApiDefsHaveCorrespondingOp(
-    const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
-  std::unordered_set<string> op_names;
+    const OpList& ops,
+    const std::unordered_map<std::string, ApiDef>& api_defs_map) {
+  std::unordered_set<std::string> op_names;
   for (const auto& op : ops.op()) {
     op_names.insert(op.name());
   }
@@ -89,7 +91,8 @@ void TestAllApiDefsHaveCorrespondingOp(
 }

 void TestAllApiDefInputArgsAreValid(
-    const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
+    const OpList& ops,
+    const std::unordered_map<std::string, ApiDef>& api_defs_map) {
   for (const auto& op : ops.op()) {
     const auto api_def_iter = api_defs_map.find(op.name());
     if (api_def_iter == api_defs_map.end()) {
@@ -113,7 +116,8 @@ void TestAllApiDefInputArgsAreValid(
 }

 void TestAllApiDefOutputArgsAreValid(
-    const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
+    const OpList& ops,
+    const std::unordered_map<std::string, ApiDef>& api_defs_map) {
   for (const auto& op : ops.op()) {
     const auto api_def_iter = api_defs_map.find(op.name());
     if (api_def_iter == api_defs_map.end()) {
@@ -137,7 +141,8 @@ void TestAllApiDefOutputArgsAreValid(
 }

 void TestAllApiDefAttributeNamesAreValid(
-    const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
+    const OpList& ops,
+    const std::unordered_map<std::string, ApiDef>& api_defs_map) {
   for (const auto& op : ops.op()) {
     const auto api_def_iter = api_defs_map.find(op.name());
     if (api_def_iter == api_defs_map.end()) {
@@ -159,7 +164,7 @@ void TestAllApiDefAttributeNamesAreValid(
 }

 void TestDeprecatedAttributesSetCorrectly(
-    const std::unordered_map<string, ApiDef>& api_defs_map) {
+    const std::unordered_map<std::string, ApiDef>& api_defs_map) {
   for (const auto& name_and_api_def : api_defs_map) {
     int num_deprecated_endpoints = 0;
     const auto& api_def = name_and_api_def.second;
@@ -186,7 +191,7 @@ void TestDeprecatedAttributesSetCorrectly(
 }

 void TestDeprecationVersionSetCorrectly(
-    const std::unordered_map<string, ApiDef>& api_defs_map) {
+    const std::unordered_map<std::string, ApiDef>& api_defs_map) {
   for (const auto& name_and_api_def : api_defs_map) {
     const auto& name = name_and_api_def.first;
     const auto& api_def = name_and_api_def.second;
@@ -205,13 +210,13 @@ class BaseApiTest : public ::testing::Test {
 protected:
  BaseApiTest() {
    OpRegistry::Global()->Export(false, &ops_);
-    const std::vector<string> multi_line_fields = {"description"};
+    const std::vector<std::string> multi_line_fields = {"description"};

    Env* env = Env::Default();
    GetGoldenApiDefs(env, DefaultApiDefDir(), &api_defs_map_);
  }

  OpList ops_;
-  std::unordered_map<string, ApiDef> api_defs_map_;
+  std::unordered_map<std::string, ApiDef> api_defs_map_;
 };

 // Check that all ops have an ApiDef.
@@ -233,7 +238,7 @@ TEST_F(BaseApiTest, AllApiDefsHaveCorrespondingOp) {
   TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_);
 }

-string GetOpDefHasDocStringError(const string& op_name) {
+std::string GetOpDefHasDocStringError(const std::string& op_name) {
   return strings::Printf(
       "OpDef for %s has a doc string. "
       "Doc strings must be defined in ApiDef instead of OpDef. "
@@ -301,13 +306,13 @@ class PythonApiTest : public ::testing::Test {
 protected:
  PythonApiTest() {
    OpRegistry::Global()->Export(false, &ops_);
-    const std::vector<string> multi_line_fields = {"description"};
+    const std::vector<std::string> multi_line_fields = {"description"};

    Env* env = Env::Default();
    GetGoldenApiDefs(env, PythonApiDefDir(), &api_defs_map_);
  }

  OpList ops_;
-  std::unordered_map<string, ApiDef> api_defs_map_;
+  std::unordered_map<std::string, ApiDef> api_defs_map_;
 };

 // Check that ApiDefs have a corresponding op.
diff --git a/tensorflow/core/api_def/base_api/api_def_ComplexAbs.pbtxt b/tensorflow/core/api_def/base_api/api_def_ComplexAbs.pbtxt
index 7c4db1f721a032..41868ddc6c649f 100644
--- a/tensorflow/core/api_def/base_api/api_def_ComplexAbs.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ComplexAbs.pbtxt
@@ -1,5 +1,12 @@
 op {
   graph_op_name: "ComplexAbs"
+  attr {
+    name: "Tout"
+    description: <
diff --git a/tensorflow/core/api_def/update_api_def.cc b/tensorflow/core/api_def/update_api_def.cc
 #include
+#include "absl/strings/str_format.h"
 #include "tensorflow/core/api_def/excluded_ops.h"
 #include "tensorflow/core/framework/api_def.pb.h"
 #include "tensorflow/core/framework/op.h"
@@ -124,7 +125,7 @@ bool CheckDocsMatch(const OpDef& op1, const OpDef& op2) {

 // Returns true if descriptions and summaries in op match a
 // given single doc-string.
-bool ValidateOpDocs(const OpDef& op, const string& doc) {
+bool ValidateOpDocs(const OpDef& op, const std::string& doc) {
   OpDefBuilder b(op.name());
   // We don't really care about type we use for arguments and
   // attributes. We just want to make sure attribute and argument names
@@ -146,28 +147,28 @@ bool ValidateOpDocs(const OpDef& op, const std::string& doc) {
   }
 }  // namespace

-string RemoveDoc(const OpDef& op, const string& file_contents,
-                 size_t start_location) {
+std::string RemoveDoc(const OpDef& op, const std::string& file_contents,
+                      size_t start_location) {
   // Look for a line starting with .Doc( after the REGISTER_OP.
   const auto doc_start_location = file_contents.find(kDocStart, start_location);
-  const string format_error = strings::Printf(
+  const std::string format_error = strings::Printf(
       "Could not find %s doc for removal. Make sure the doc is defined with "
       "'%s' prefix and '%s' suffix or remove the doc manually.",
       op.name().c_str(), kDocStart, kDocEnd);
-  if (doc_start_location == string::npos) {
+  if (doc_start_location == std::string::npos) {
     std::cerr << format_error << std::endl;
     LOG(ERROR) << "Didn't find doc start";
     return file_contents;
   }
   const auto doc_end_location = file_contents.find(kDocEnd, doc_start_location);
-  if (doc_end_location == string::npos) {
+  if (doc_end_location == std::string::npos) {
     LOG(ERROR) << "Didn't find doc start";
     std::cerr << format_error << std::endl;
     return file_contents;
   }

   const auto doc_start_size = sizeof(kDocStart) - 1;
-  string doc_text = file_contents.substr(
+  std::string doc_text = file_contents.substr(
       doc_start_location + doc_start_size,
       doc_end_location - doc_start_location - doc_start_size);

@@ -189,12 +190,12 @@ namespace {
 // Remove .Doc calls that follow REGISTER_OP calls for the given ops.
 // We search for REGISTER_OP calls in the given op_files list.
 void RemoveDocs(const std::vector<const OpDef*>& ops,
-                const std::vector<string>& op_files) {
+                const std::vector<std::string>& op_files) {
   // Set of ops that we already found REGISTER_OP calls for.
- std::set processed_ops; + std::set processed_ops; for (const auto& file : op_files) { - string file_contents; + std::string file_contents; bool file_contents_updated = false; TF_CHECK_OK(ReadFileToString(Env::Default(), file, &file_contents)); @@ -203,11 +204,11 @@ void RemoveDocs(const std::vector& ops, // We already found REGISTER_OP call for this op in another file. continue; } - string register_call = + std::string register_call = strings::Printf("REGISTER_OP(\"%s\")", op->name().c_str()); const auto register_call_location = file_contents.find(register_call); // Find REGISTER_OP(OpName) call. - if (register_call_location == string::npos) { + if (register_call_location == std::string::npos) { continue; } std::cout << "Removing .Doc call for " << op->name() << " from " << file @@ -228,11 +229,11 @@ void RemoveDocs(const std::vector& ops, // Returns ApiDefs text representation in multi-line format // constructed based on the given op. -string CreateApiDef(const OpDef& op) { +std::string CreateApiDef(const OpDef& op) { ApiDefs api_defs; FillBaseApiDef(api_defs.add_op(), op); - const std::vector multi_line_fields = {"description"}; + const std::vector multi_line_fields = {"description"}; std::string new_api_defs_str; ::tensorflow::protobuf::TextFormat::PrintToString(api_defs, &new_api_defs_str); @@ -242,8 +243,8 @@ string CreateApiDef(const OpDef& op) { // Creates ApiDef files for any new ops. // If op_file_pattern is not empty, then also removes .Doc calls from // new op registrations in these files. -void CreateApiDefs(const OpList& ops, const string& api_def_dir, - const string& op_file_pattern) { +void CreateApiDefs(const OpList& ops, const std::string& api_def_dir, + const std::string& op_file_pattern) { auto* excluded_ops = GetExcludedOps(); std::vector new_ops_with_docs; @@ -252,9 +253,8 @@ void CreateApiDefs(const OpList& ops, const string& api_def_dir, continue; } // Form the expected ApiDef path. - string file_path = - io::JoinPath(tensorflow::string(api_def_dir), kApiDefFileFormat); - file_path = strings::Printf(file_path.c_str(), op.name().c_str()); + std::string file_name = absl::StrFormat(kApiDefFileFormat, op.name()); + std::string file_path = io::JoinPath(api_def_dir, file_name); // Create ApiDef if it doesn't exist. if (!Env::Default()->FileExists(file_path).ok()) { @@ -268,7 +268,7 @@ void CreateApiDefs(const OpList& ops, const string& api_def_dir, } } if (!op_file_pattern.empty()) { - std::vector op_files; + std::vector op_files; TF_CHECK_OK(Env::Default()->GetMatchingPaths(op_file_pattern, &op_files)); RemoveDocs(new_ops_with_docs, op_files); } diff --git a/tensorflow/core/api_def/update_api_def.h b/tensorflow/core/api_def/update_api_def.h index 1e285c06883efa..1ac71689bba2d0 100644 --- a/tensorflow/core/api_def/update_api_def.h +++ b/tensorflow/core/api_def/update_api_def.h @@ -23,14 +23,14 @@ namespace tensorflow { // Returns ApiDefs text representation in multi-line format // constructed based on the given op. -string CreateApiDef(const OpDef& op); +std::string CreateApiDef(const OpDef& op); // Removes .Doc call for the given op. // If unsuccessful, returns original file_contents and prints an error. // start_location - We search for .Doc call starting at this location // in file_contents. -string RemoveDoc(const OpDef& op, const string& file_contents, - size_t start_location); +std::string RemoveDoc(const OpDef& op, const std::string& file_contents, + size_t start_location); // Creates api_def_*.pbtxt files for any new ops (i.e. 
diff --git a/tensorflow/core/api_def/update_api_def.h b/tensorflow/core/api_def/update_api_def.h
index 1e285c06883efa..1ac71689bba2d0 100644
--- a/tensorflow/core/api_def/update_api_def.h
+++ b/tensorflow/core/api_def/update_api_def.h
@@ -23,14 +23,14 @@ namespace tensorflow {

 // Returns ApiDefs text representation in multi-line format
 // constructed based on the given op.
-string CreateApiDef(const OpDef& op);
+std::string CreateApiDef(const OpDef& op);

 // Removes .Doc call for the given op.
 // If unsuccessful, returns original file_contents and prints an error.
 // start_location - We search for .Doc call starting at this location
 // in file_contents.
-string RemoveDoc(const OpDef& op, const string& file_contents,
-                 size_t start_location);
+std::string RemoveDoc(const OpDef& op, const std::string& file_contents,
+                      size_t start_location);

 // Creates api_def_*.pbtxt files for any new ops (i.e. ops that don't have an
 // api_def_*.pbtxt file yet).
@@ -38,8 +38,8 @@ string RemoveDoc(const OpDef& op, const string& file_contents,
 // look for a REGISTER_OP call for the new ops and removes corresponding
 // .Doc() calls since the newly generated api_def_*.pbtxt files will
 // store the doc strings.
-void CreateApiDefs(const OpList& ops, const string& api_def_dir,
-                   const string& op_file_pattern);
+void CreateApiDefs(const OpList& ops, const std::string& api_def_dir,
+                   const std::string& op_file_pattern);

 }  // namespace tensorflow
 #endif  // TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
diff --git a/tensorflow/core/api_def/update_api_def_main.cc b/tensorflow/core/api_def/update_api_def_main.cc
index 3fd975ce178b5f..4cf74abf82cb6f 100644
--- a/tensorflow/core/api_def/update_api_def_main.cc
+++ b/tensorflow/core/api_def/update_api_def_main.cc
@@ -33,8 +33,8 @@ limitations under the License.
 #include "tensorflow/core/util/command_line_flags.h"

 int main(int argc, char** argv) {
-  tensorflow::string api_files_dir;
-  tensorflow::string op_file_pattern;
+  std::string api_files_dir;
+  std::string op_file_pattern;
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("api_def_dir", &api_files_dir,
                       "Base directory of api_def*.pbtxt files."),
diff --git a/tensorflow/core/api_def/update_api_def_test.cc b/tensorflow/core/api_def/update_api_def_test.cc
index 4200c9da23c093..23751ffa3ecd25 100644
--- a/tensorflow/core/api_def/update_api_def_test.cc
+++ b/tensorflow/core/api_def/update_api_def_test.cc
@@ -24,7 +24,7 @@ namespace tensorflow {
 namespace {

 TEST(UpdateApiDefTest, TestRemoveDocSingleOp) {
-  const string op_def_text = R"opdef(
+  const std::string op_def_text = R"opdef(
 REGISTER_OP("Op1")
     .Input("a: T")
     .Output("output: T")
@@ -32,7 +32,7 @@ REGISTER_OP("Op1")
     .SetShapeFn(shape_inference::UnchangedShape);
 )opdef";

-  const string op_def_text_with_doc = R"opdef(
+  const std::string op_def_text_with_doc = R"opdef(
 REGISTER_OP("Op1")
     .Input("a: T")
     .Output("output: T")
@@ -50,7 +50,7 @@ output: Description for output.
 )doc");
 )opdef";

-  const string op_text = R"(
+  const std::string op_text = R"(
 name: "Op1"
 input_arg {
   name: "a"
@@ -75,7 +75,7 @@ description: "Description\nfor Op1."
 }

 TEST(UpdateApiDefTest, TestRemoveDocMultipleOps) {
-  const string op_def_text = R"opdef(
+  const std::string op_def_text = R"opdef(
 REGISTER_OP("Op1")
     .Input("a: T")
     .SetShapeFn(shape_inference::UnchangedShape);
@@ -89,7 +89,7 @@ REGISTER_OP("Op3")
     .SetShapeFn(shape_inference::UnchangedShape);
 )opdef";

-  const string op_def_text_with_doc = R"opdef(
+  const std::string op_def_text_with_doc = R"opdef(
 REGISTER_OP("Op1")
     .Input("a: T")
     .Doc(R"doc(
@@ -112,21 +112,21 @@ Summary for Op3.
 )doc");
 )opdef";

-  const string op1_text = R"(
+  const std::string op1_text = R"(
 name: "Op1"
 input_arg {
   name: "a"
 }
 summary: "Summary for Op1."
 )";

-  const string op2_text = R"(
+  const std::string op2_text = R"(
 name: "Op2"
 input_arg {
   name: "a"
 }
 summary: "Summary for Op2."
 )";

-  const string op3_text = R"(
+  const std::string op3_text = R"(
 name: "Op3"
 input_arg {
   name: "c"
@@ -138,12 +138,12 @@ summary: "Summary for Op3."
protobuf::TextFormat::ParseFromString(op2_text, &op2); // NOLINT protobuf::TextFormat::ParseFromString(op3_text, &op3); // NOLINT - string updated_text = + std::string updated_text = RemoveDoc(op2, op_def_text_with_doc, op_def_text_with_doc.find("Op2") /* start_location */); - EXPECT_EQ(string::npos, updated_text.find("Summary for Op2")); - EXPECT_NE(string::npos, updated_text.find("Summary for Op1")); - EXPECT_NE(string::npos, updated_text.find("Summary for Op3")); + EXPECT_EQ(std::string::npos, updated_text.find("Summary for Op2")); + EXPECT_NE(std::string::npos, updated_text.find("Summary for Op1")); + EXPECT_NE(std::string::npos, updated_text.find("Summary for Op3")); updated_text = RemoveDoc(op3, updated_text, updated_text.find("Op3") /* start_location */); @@ -153,7 +153,7 @@ summary: "Summary for Op3." } TEST(UpdateApiDefTest, TestCreateApiDef) { - const string op_text = R"( + const std::string op_text = R"( name: "Op1" input_arg { name: "a" @@ -173,7 +173,7 @@ description: "Description\nfor Op1." OpDef op; protobuf::TextFormat::ParseFromString(op_text, &op); // NOLINT - const string expected_api_def = R"(op { + const std::string expected_api_def = R"(op { graph_op_name: "Op1" in_arg { name: "a" diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 5ada8b377b880a..86c504a3fda8e5 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -375,6 +375,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/config:flag_defs", "//tensorflow/core/profiler/lib:connected_traceme", "//tensorflow/core/profiler/lib:scoped_memory_debug_annotation", "//tensorflow/core/profiler/lib:traceme", @@ -1110,6 +1111,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", + "@com_google_absl//absl/log:check", ], ) diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index 77231dee43e240..5150869da370f7 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/config/flag_defs.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/op_kernel.h" @@ -229,6 +230,9 @@ CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks, BaseCollectiveExecutor::~BaseCollectiveExecutor() {} void BaseCollectiveExecutor::StartAbort(const absl::Status& s) { + if (flags::Global().enable_fatal_error_on_collective_abort.value()) { + LOG(FATAL) << "BaseCollectiveExecutor::StartAbort: " << s; + } absl::Status status; { mutex_lock l(status_mu_); diff --git a/tensorflow/core/common_runtime/device/device_event_mgr.h b/tensorflow/core/common_runtime/device/device_event_mgr.h index 7fb0dbc822d676..75847bf66a6e2c 100644 --- a/tensorflow/core/common_runtime/device/device_event_mgr.h +++ b/tensorflow/core/common_runtime/device/device_event_mgr.h @@ -83,7 +83,7 @@ class EventMgr { friend class EventMgrFactory; se::StreamExecutor* const exec_; - const int32 polling_active_delay_usecs_; + const int32_t polling_active_delay_usecs_; mutex mu_; condition_variable events_pending_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 65bd6a8dae1e5d..583fce11a0ef28 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -134,9 +134,9 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) { // Time the execution of kernels (in CPU cycles). Used to dynamically identify // inexpensive kernels which can be dispatched inline. struct KernelTimer { - uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); + uint64_t start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); - uint64 ElapsedCycles() { + uint64_t ElapsedCycles() { return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles; } }; @@ -197,14 +197,14 @@ class ExecutorImpl : public Executor { // given node is expensive. The new cost estimate is a weighted average of // the old cost estimate and the latest cost. We only update cost estimates // for kernels for which IsExpensive() return true. - void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) { + void UpdateCostEstimate(const NodeItem& node, uint64_t elapsed_cycles) { // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous // updates may result in one or more updates being ignored. This does not // affect correctness but may slow down the update frequency. std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id]; auto prev_estimate = cost_estimate.load(std::memory_order_relaxed); - uint64 new_estimate = + uint64_t new_estimate = ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay; cost_estimate.store(new_estimate, std::memory_order_relaxed); @@ -214,9 +214,9 @@ class ExecutorImpl : public Executor { // Initial time (in CPU cycles) we expect an operation to take. Used to // determine whether an operation should be place in a threadpool. // Operations start out "expensive". 
diff --git a/tensorflow/core/common_runtime/device/device_event_mgr.h b/tensorflow/core/common_runtime/device/device_event_mgr.h
index 7fb0dbc822d676..75847bf66a6e2c 100644
--- a/tensorflow/core/common_runtime/device/device_event_mgr.h
+++ b/tensorflow/core/common_runtime/device/device_event_mgr.h
@@ -83,7 +83,7 @@ class EventMgr {
  friend class EventMgrFactory;

  se::StreamExecutor* const exec_;
-  const int32 polling_active_delay_usecs_;
+  const int32_t polling_active_delay_usecs_;
  mutex mu_;
  condition_variable events_pending_ TF_GUARDED_BY(mu_);
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 65bd6a8dae1e5d..583fce11a0ef28 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -134,9 +134,9 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
 // Time the execution of kernels (in CPU cycles). Used to dynamically identify
 // inexpensive kernels which can be dispatched inline.
 struct KernelTimer {
-  uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
+  uint64_t start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();

-  uint64 ElapsedCycles() {
+  uint64_t ElapsedCycles() {
    return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
  }
 };
@@ -197,14 +197,14 @@ class ExecutorImpl : public Executor {
  // given node is expensive. The new cost estimate is a weighted average of
  // the old cost estimate and the latest cost. We only update cost estimates
  // for kernels for which IsExpensive() returns true.
-  void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
+  void UpdateCostEstimate(const NodeItem& node, uint64_t elapsed_cycles) {
    // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous
    // updates may result in one or more updates being ignored. This does not
    // affect correctness but may slow down the update frequency.
    std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
    auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);

-    uint64 new_estimate =
+    uint64_t new_estimate =
        ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;

    cost_estimate.store(new_estimate, std::memory_order_relaxed);
@@ -214,9 +214,9 @@ class ExecutorImpl : public Executor {
  // Initial time (in CPU cycles) we expect an operation to take. Used to
  // determine whether an operation should be placed in a threadpool.
  // Operations start out "expensive".
-  static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
-  static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
-  static constexpr uint64 kCostDecay = 10;
+  static constexpr uint64_t kInitialCostEstimateCycles = 100 * 1000 * 1000;
+  static constexpr uint64_t kOpIsExpensiveThresholdCycles = 8000;
+  static constexpr uint64_t kCostDecay = 10;

  std::vector<bool> is_expensive_;
  // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
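The cost-model constants above feed the UpdateCostEstimate recurrence shown earlier: each measurement moves the estimate 1/kCostDecay of the way toward the latest sample, starting from a deliberately "expensive" prior. A tiny numeric illustration of that decay, using the same formula in plain C++:

#include <cstdint>
#include <iostream>

constexpr uint64_t kCostDecay = 10;

// Same weighted average as UpdateCostEstimate above.
uint64_t Update(uint64_t prev_estimate, uint64_t elapsed_cycles) {
  return ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
}

int main() {
  // From the 100M-cycle prior, a kernel that keeps measuring 1000 cycles
  // decays toward 1000 and eventually drops below the 8000-cycle
  // "expensive" threshold.
  uint64_t estimate = 100 * 1000 * 1000;
  for (int i = 0; i < 5; ++i) {
    estimate = Update(estimate, 1000);
    std::cout << estimate << "\n";  // 90000100, 81000190, 72900271, ...
  }
}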
@@ -369,14 +369,14 @@ class ExecutorState {
  // Maximum number of kernels that can be scheduled inline. If lots of kernels
  // are ready at the same time, scheduling them in one thread can be very slow.
  // TODO(fishx): Make it configurable if necessary.
-  static constexpr uint64 kInlineScheduleReadyThreshold = 500;
+  static constexpr uint64_t kInlineScheduleReadyThreshold = 500;

  // Not owned.
  RendezvousInterface* rendezvous_;
  CollectiveExecutor* collective_executor_ = nullptr;
  const ConfigProto* const session_config_;
  SessionState* session_state_;
-  string session_handle_;
+  std::string session_handle_;
  const SessionMetadata* session_metadata_ = nullptr;
  TensorStore* tensor_store_;
  // Step-local container.
@@ -1099,7 +1099,7 @@ absl::Status ExecutorState::ProcessOutputs(
  }
  if (s.code() == error::RESOURCE_EXHAUSTED) {
    if (stats_collector_) {
-      string err =
+      std::string err =
          stats_collector_->ReportAllocsOnResourceExhausted(s.message());
      s = errors::CreateWithUpdatedMessage(s, absl::StrCat(s.message(), err));
    } else {
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 850d48694c1982..cbe63568f69de9 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -105,7 +105,7 @@ class Executor {
    const ConfigProto* session_config = nullptr;
    SessionState* session_state = nullptr;
    // Unique session identifier. Can be empty.
-    string session_handle;
+    std::string session_handle;
    TensorStore* tensor_store = nullptr;
    ScopedStepContainer* step_container = nullptr;
    CollectiveExecutor* collective_executor = nullptr;
diff --git a/tensorflow/core/common_runtime/executor_factory.cc b/tensorflow/core/common_runtime/executor_factory.cc
index 8d4c440b46aa4b..8346b47748484d 100644
--- a/tensorflow/core/common_runtime/executor_factory.cc
+++ b/tensorflow/core/common_runtime/executor_factory.cc
@@ -29,7 +29,7 @@ namespace {

 static mutex executor_factory_lock(LINKER_INITIALIZED);

-typedef std::unordered_map<string, ExecutorFactory*> ExecutorFactories;
+typedef std::unordered_map<std::string, ExecutorFactory*> ExecutorFactories;
 ExecutorFactories* executor_factories() {
  static ExecutorFactories* factories = new ExecutorFactories;
  return factories;
@@ -37,7 +37,7 @@ ExecutorFactories* executor_factories() {
 }  // namespace

-void ExecutorFactory::Register(const string& executor_type,
+void ExecutorFactory::Register(const std::string& executor_type,
                               ExecutorFactory* factory) {
  mutex_lock l(executor_factory_lock);
  if (!executor_factories()->insert({executor_type, factory}).second) {
@@ -47,9 +47,9 @@ void ExecutorFactory::Register(const string& executor_type,
 }

 namespace {
-const string RegisteredFactoriesErrorMessageLocked()
+const std::string RegisteredFactoriesErrorMessageLocked()
    TF_SHARED_LOCKS_REQUIRED(executor_factory_lock) {
-  std::vector<string> factory_types;
+  std::vector<std::string> factory_types;
  for (const auto& executor_factory : *executor_factories()) {
    factory_types.push_back(executor_factory.first);
  }
@@ -58,7 +58,7 @@ const string RegisteredFactoriesErrorMessageLocked()
 }
 }  // namespace

-absl::Status ExecutorFactory::GetFactory(const string& executor_type,
+absl::Status ExecutorFactory::GetFactory(const std::string& executor_type,
                                         ExecutorFactory** out_factory) {
  tf_shared_lock l(executor_factory_lock);
@@ -73,7 +73,7 @@ absl::Status ExecutorFactory::GetFactory(const string& executor_type,
  return absl::OkStatus();
 }

-absl::Status NewExecutor(const string& executor_type,
+absl::Status NewExecutor(const std::string& executor_type,
                         const LocalExecutorParams& params, const Graph& graph,
                         std::unique_ptr<Executor>* out_executor) {
  ExecutorFactory* factory = nullptr;
diff --git a/tensorflow/core/common_runtime/executor_factory.h b/tensorflow/core/common_runtime/executor_factory.h
index 14a8d2777bcfcb..3459a4a38b06c9 100644
--- a/tensorflow/core/common_runtime/executor_factory.h
+++ b/tensorflow/core/common_runtime/executor_factory.h
@@ -36,12 +36,13 @@ class ExecutorFactory {
                                  std::unique_ptr<Executor>* out_executor) = 0;
  virtual ~ExecutorFactory() {}

-  static void Register(const string& executor_type, ExecutorFactory* factory);
-  static absl::Status GetFactory(const string& executor_type,
+  static void Register(const std::string& executor_type,
+                       ExecutorFactory* factory);
+  static absl::Status GetFactory(const std::string& executor_type,
                                 ExecutorFactory** out_factory);
 };

-absl::Status NewExecutor(const string& executor_type,
+absl::Status NewExecutor(const std::string& executor_type,
                         const LocalExecutorParams& params, const Graph& graph,
                         std::unique_ptr<Executor>* out_executor);
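The registry being migrated above follows a common pattern: a lazily constructed global map guarded by a lock, with duplicate registrations rejected. A simplified standalone analogue, using std::mutex and bool returns in place of TF's mutex and absl::Status:

#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

struct ExecutorFactory { /* ... */ };

std::mutex registry_mu;

// Lazily constructed and intentionally leaked, like executor_factories()
// above, so it stays usable during static initialization and shutdown.
std::unordered_map<std::string, ExecutorFactory*>& Registry() {
  static auto* factories =
      new std::unordered_map<std::string, ExecutorFactory*>;
  return *factories;
}

bool Register(const std::string& executor_type, ExecutorFactory* factory) {
  std::lock_guard<std::mutex> l(registry_mu);
  return Registry().insert({executor_type, factory}).second;  // false = dup
}

ExecutorFactory* GetFactory(const std::string& executor_type) {
  std::lock_guard<std::mutex> l(registry_mu);
  auto it = Registry().find(executor_type);
  return it == Registry().end() ? nullptr : it->second;
}

int main() {
  static ExecutorFactory default_factory;
  std::cout << Register("DEFAULT", &default_factory) << "\n";  // 1
  std::cout << Register("DEFAULT", &default_factory) << "\n";  // 0 (dup)
  std::cout << (GetFactory("DEFAULT") != nullptr) << "\n";     // 1
}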
diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc
index 9ca90f01f6c1c2..81719752519e56 100644
--- a/tensorflow/core/common_runtime/executor_test.cc
+++ b/tensorflow/core/common_runtime/executor_test.cc
@@ -128,7 +128,7 @@ Tensor V(const float val) {
 // An int32 val -> Tensor
 Tensor VI(const int32_t val) {
  Tensor tensor(DT_INT32, TensorShape({}));
-  tensor.scalar<int32>()() = val;
+  tensor.scalar<int32_t>()() = val;
  return tensor;
 }
@@ -153,10 +153,11 @@ float V(const Tensor& tensor) {
  return tensor.scalar<float>()();
 }

-static uint64 kIncarnation = 1;  // Uses in following tests.
+static uint64_t kIncarnation = 1;  // Used in following tests.

-Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
-                          const string& receiver, const string& name) {
+Rendezvous::ParsedKey Key(const std::string& sender, const uint64_t incarnation,
+                          const std::string& receiver,
+                          const std::string& name) {
  Rendezvous::ParsedKey result;
  CHECK(
      Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
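The VI()/V() helpers above rely on Tensor::scalar<T>(), whose template argument must match the tensor's dtype, which is why the hunk spells the element type as int32_t. A minimal usage sketch, assuming a TensorFlow C++ build for the headers below:

#include <cstdint>
#include <iostream>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

int main() {
  // Rank-0 (scalar) int32 tensor, as in VI() above.
  tensorflow::Tensor t(tensorflow::DT_INT32, tensorflow::TensorShape({}));
  t.scalar<int32_t>()() = 42;  // scalar<T>() checks T against the dtype
  std::cout << t.scalar<int32_t>()() << "\n";  // 42
}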
@@ -508,8 +509,8 @@ static void BM_executor(::testing::benchmark::State& state) {
  Graph* g = new Graph(OpRegistry::Global());
  random::PhiloxRandom philox(1729, 17);
  random::SimplePhilox rand(&philox);
-  uint64 cur = 0;
-  uint32 r = 1 + rand.Rand32() % width;
+  uint64_t cur = 0;
+  uint32_t r = 1 + rand.Rand32() % width;
  std::vector<Node*> ready_nodes;
  for (int i = 0; i < r; ++i) {
    ready_nodes.push_back(test::graph::NoOp(g, {}));
@@ -589,9 +590,9 @@ static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) {
  Node* sum = test::graph::Add(g, x, y);
  Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);

-  string x_key = test::GetRendezvousKey(x);
-  string y_key = test::GetRendezvousKey(y);
-  string z_key = test::GetRendezvousKey(z);
+  std::string x_key = test::GetRendezvousKey(x);
+  std::string y_key = test::GetRendezvousKey(y);
+  std::string z_key = test::GetRendezvousKey(z);

  Tensor val(DT_FLOAT, TensorShape({}));
  val.scalar<float>()() = 3.14;
@@ -603,9 +604,10 @@ static void BM_FeedInputFetchOutput(::testing::benchmark::State& state) {
 BENCHMARK(BM_FeedInputFetchOutput);

 absl::Status ReplaceEdgeWithSendRecv(Graph* g, const Edge* edge,
-                                     const string& tensor, const string& sender,
-                                     const uint64 sender_incarnation,
-                                     const string& receiver) {
+                                     const std::string& tensor,
+                                     const std::string& sender,
+                                     const uint64_t sender_incarnation,
+                                     const std::string& receiver) {
  Node* send;
  NodeDef send_def;
  TF_CHECK_OK(NodeDefBuilder(g->NewName("n"), "_Send")
@@ -662,16 +664,16 @@ static void BM_WhileLoopHelper(::testing::benchmark::State& state,
  FunctionDefLibrary f_lib_proto;

  // Define the loop body as a function: `x = x + 1`.
-  const Tensor one_t = test::AsScalar<int32>(1);
+  const Tensor one_t = test::AsScalar<int32_t>(1);

-  std::vector<string> args;
+  std::vector<std::string> args;
  args.reserve(loop_vars);
  args.push_back("x: int32");
  for (int i = 1; i < loop_vars; ++i) {
    args.push_back(absl::StrCat("x", i, ": int32"));
  }

-  std::vector<string> body_rets;
+  std::vector<std::string> body_rets;
  body_rets.reserve(loop_vars);
  body_rets.push_back("y: int32");
  for (int i = 1; i < loop_vars; ++i) {
@@ -703,7 +705,7 @@ static void BM_WhileLoopHelper(::testing::benchmark::State& state,
      body_nodes);

  // Define the loop condition as a function: `x < loop_iters`.
-  const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
+  const Tensor loop_iters_t = test::AsScalar<int32_t>(loop_iters);
  *f_lib_proto.add_function() = FunctionDefHelper::Define(
      // Name
      "LessThanOrEqualToN",
@@ -775,7 +777,7 @@ static void BM_WhileLoopHelper(::testing::benchmark::State& state,
    if (edge->dst()->type_string() != "Switch") {
      continue;
    }
-    string tensor_name = absl::StrCat("c", edge->id());
+    std::string tensor_name = absl::StrCat("c", edge->id());
    TF_ASSERT_OK(ReplaceEdgeWithSendRecv(graph.get(), edge, tensor_name, BOB,
                                         1, ALICE));
  }
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index a57fe0323a3273..90080692323345 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -88,7 +88,7 @@ struct Endpoint {
  int index;

  // Returns the string name representing this endpoint.
-  string name() const {
+  std::string name() const {
    if (index == 0) {
      return node->name();
    } else {
@@ -100,7 +100,7 @@ struct Endpoint {
 };

 struct EndpointHash {
-  uint64 operator()(const Endpoint& x) const {
+  uint64_t operator()(const Endpoint& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
@@ -166,7 +166,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
      : base_flr_(base_flr), lib_def_(std::move(lib_def)) {}
  ~FunctionLibraryRuntimeOverlay() override;

-  absl::Status Instantiate(const string& function_name, AttrSlice attrs,
+  absl::Status Instantiate(const std::string& function_name, AttrSlice attrs,
                           const InstantiateOptions& options,
                           Handle* handle) override;
@@ -192,7 +192,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
  absl::Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
                            OpKernel** kernel) override;

-  bool IsStateful(const string& function_name) const override;
+  bool IsStateful(const std::string& function_name) const override;

  const FunctionLibraryDefinition* GetFunctionLibraryDefinition()
      const override;
@@ -204,7 +204,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
  std::function<void(std::function<void()>)>* runner() override;
  const DeviceMgr* device_mgr() const override;

-  string DebugString(Handle handle) override;
+  std::string DebugString(Handle handle) override;
  int graph_def_version() const override;

  absl::Status Clone(std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
@@ -220,7 +220,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
 FunctionLibraryRuntimeOverlay::~FunctionLibraryRuntimeOverlay() = default;

 absl::Status FunctionLibraryRuntimeOverlay::Instantiate(
-    const string& function_name, AttrSlice attrs,
+    const std::string& function_name, AttrSlice attrs,
    const InstantiateOptions& options, Handle* handle) {
  // We automatically set the `lib_def` option for all instantiations, if the
  // caller doesn't set this option explicitly.
@@ -284,7 +284,7 @@ absl::Status FunctionLibraryRuntimeOverlay::CreateKernel(
 }

 bool FunctionLibraryRuntimeOverlay::IsStateful(
-    const string& function_name) const {
+    const std::string& function_name) const {
  // Important: we do not forward lookup to the base FLR.
const OpDef* op_def; const absl::Status s = lib_def_.LookUpOpDef(function_name, &op_def); @@ -317,7 +317,7 @@ FunctionLibraryRuntimeOverlay::GetFunctionLibraryDefinition() const { return &lib_def_; } -string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) { +std::string FunctionLibraryRuntimeOverlay::DebugString(Handle handle) { return base_flr_->DebugString(handle); } @@ -348,7 +348,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { ~FunctionLibraryRuntimeImpl() override; - absl::Status Instantiate(const string& function_name, AttrSlice attrs, + absl::Status Instantiate(const std::string& function_name, AttrSlice attrs, const InstantiateOptions& options, Handle* handle) override; @@ -375,7 +375,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { absl::Status RunSync(Options opts, Handle handle, CallFrameInterface* call_frame) override; - bool IsStateful(const string& function) const override; + bool IsStateful(const std::string& function) const override; // TODO: b/396484774 - Consider handling the case where the FLR is already // finalized instead of always returning the pointer to the unowned library @@ -397,7 +397,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const ConfigProto* const config_proto() override { return config_; } int graph_def_version() const override { return graph_def_version_; } - string DebugString(Handle h) override; + std::string DebugString(Handle h) override; absl::Status Clone(std::unique_ptr* out_lib_def, std::unique_ptr* out_pflr, @@ -416,9 +416,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { GraphOptimizer optimizer_; const SessionMetadata* const session_metadata_; Executor::Args::Runner default_runner_; - const string device_name_; + const std::string device_name_; - std::function get_func_sig_; + std::function get_func_sig_; std::function&, OpKernel**)> create_kernel_; @@ -432,13 +432,13 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { // The instantiated and transformed function is encoded as a Graph // object, and an executor is created for the graph. struct Item { - uint64 instantiation_counter = 0; + uint64_t instantiation_counter = 0; std::unique_ptr graph = nullptr; const FunctionLibraryDefinition* lib_def = nullptr; // Not owned. 
FunctionBody* func_graph = nullptr; Executor* exec = nullptr; core::RefCountPtr overlay_flr = nullptr; - string executor_type; + std::string executor_type; bool allow_small_function_optimizations = false; bool allow_control_flow_sync_execution = false; bool function_runs_at_most_once = false; @@ -517,7 +517,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( absl::flat_hash_map>>()), function_handle_cache_(std::make_unique(this)), parent_(parent) { - get_func_sig_ = [this](const string& op, const OpDef** sig) { + get_func_sig_ = [this](const std::string& op, const OpDef** sig) { return base_lib_def_->LookUpOpDef(op, sig); }; create_kernel_ = [this](const std::shared_ptr& props, @@ -714,7 +714,7 @@ absl::Status FunctionLibraryRuntimeImpl::FunctionDefToBody( return FunctionDefToBodyHelper(std::move(record), attrs, lib_def, get_func_sig_, fbody); } else { - auto get_func_sig = [lib_def](const string& op, const OpDef** sig) { + auto get_func_sig = [lib_def](const std::string& op, const OpDef** sig) { return lib_def->LookUpOpDef(op, sig); }; return FunctionDefToBodyHelper(std::move(record), attrs, lib_def, @@ -779,7 +779,7 @@ bool FunctionLibraryRuntimeImpl::IsLocalTarget( } absl::Status FunctionLibraryRuntimeImpl::Instantiate( - const string& function_name, AttrSlice attrs, + const std::string& function_name, AttrSlice attrs, const InstantiateOptions& options, Handle* handle) { if (!IsLocalTarget(options)) { return parent_->Instantiate(function_name, attrs, options, handle); @@ -796,7 +796,7 @@ absl::Status FunctionLibraryRuntimeImpl::Instantiate( // in the canonical key. InstantiateOptions options_copy(options); options_copy.target = device_name_; - const string key = Canonicalize(function_name, attrs, options_copy); + const std::string key = Canonicalize(function_name, attrs, options_copy); { mutex_lock l(mu_); @@ -837,7 +837,7 @@ absl::Status FunctionLibraryRuntimeImpl::Instantiate( if (func.name() == kGradientOp) { return errors::InvalidArgument("Can't take gradient of SymbolicGradient"); } - const string grad = lib_def->FindGradient(func.name()); + const std::string grad = lib_def->FindGradient(func.name()); if (!grad.empty()) { return Instantiate(grad, AttrSlice(&func.attr()), options, handle); } @@ -941,7 +941,7 @@ absl::Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { absl::Status FunctionLibraryRuntimeImpl::CreateItem(Item** item) { const FunctionBody* fbody; FunctionLibraryRuntime* flr; - string executor_type; + std::string executor_type; { tf_shared_lock l(mu_); fbody = (*item)->func_graph; @@ -1120,8 +1120,8 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, absl::Span args, std::vector* rets, Item* item, DoneCallback done) { - string target_device = parent_->GetDeviceName(handle); - string source_device = opts.source_device; + std::string target_device = parent_->GetDeviceName(handle); + std::string source_device = opts.source_device; RendezvousInterface* rendezvous = opts.rendezvous; DeviceContext* device_context; absl::Status s = parent_->GetDeviceContext(target_device, &device_context); @@ -1436,13 +1436,13 @@ absl::Status FunctionLibraryRuntimeImpl::RunSync( return absl::OkStatus(); } -bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const { +bool FunctionLibraryRuntimeImpl::IsStateful(const std::string& func) const { const OpDef* op_def; const absl::Status s = base_lib_def_->LookUpOpDef(func, &op_def); return s.ok() && op_def->is_stateful(); } -string FunctionLibraryRuntimeImpl::DebugString(Handle handle) 
{ +std::string FunctionLibraryRuntimeImpl::DebugString(Handle handle) { Item* item = nullptr; LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); absl::Status s = GetOrCreateItem(local_handle, &item); diff --git a/tensorflow/core/common_runtime/function_def_utils.cc b/tensorflow/core/common_runtime/function_def_utils.cc index 17263a4465cf31..570791252dda4e 100644 --- a/tensorflow/core/common_runtime/function_def_utils.cc +++ b/tensorflow/core/common_runtime/function_def_utils.cc @@ -41,7 +41,7 @@ namespace tensorflow { absl::Status FunctionDefToBodyHelper( core::RefCountPtr&& record, const AttrSlice& attrs, const FunctionLibraryDefinition* const lib_def, - const std::function& + const std::function& get_func_sig, std::unique_ptr* fbody) { // Instantiates the function template into a graph def. @@ -96,7 +96,8 @@ absl::Status FunctionDefToBodyHelper(core::RefCountPtr&& record, const AttrSlice& attrs, const FunctionLibraryDefinition* lib_def, std::unique_ptr* fbody) { - const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) { + const auto get_func_sig = [&lib_def](const std::string& op, + const OpDef** sig) { return lib_def->LookUpOpDef(op, sig); }; return FunctionDefToBodyHelper(std::move(record), attrs, lib_def, @@ -109,7 +110,8 @@ absl::Status FunctionDefToBodyHelper(const FunctionDef& fdef, std::unique_ptr* fbody) { core::RefCountPtr record( new FunctionRecord(FunctionDef(fdef), {}, true)); - const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) { + const auto get_func_sig = [&lib_def](const std::string& op, + const OpDef** sig) { return lib_def->LookUpOpDef(op, sig); }; return FunctionDefToBodyHelper(std::move(record), attrs, lib_def, @@ -125,8 +127,8 @@ bool PrunableStatefulNode(const Node* n) { // and can produce different results on each invocation (due to variable // updates) but it does not itself modify the variable. // TODO(b/341721055): Consolidate this set with other side effect modeling. 
- static const absl::flat_hash_set* prunable_stateful_ops = - new absl::flat_hash_set{ + static const absl::flat_hash_set* prunable_stateful_ops = + new absl::flat_hash_set{ FunctionLibraryDefinition::kArgOp, "ResourceGather", "ResourceGatherNd", diff --git a/tensorflow/core/common_runtime/function_def_utils.h b/tensorflow/core/common_runtime/function_def_utils.h index cd3b021ec2f3c9..589dd9304edea9 100644 --- a/tensorflow/core/common_runtime/function_def_utils.h +++ b/tensorflow/core/common_runtime/function_def_utils.h @@ -55,7 +55,7 @@ absl::Status FunctionDefToBodyHelper(const FunctionDef& fdef, absl::Status FunctionDefToBodyHelper( core::RefCountPtr&& record, const AttrSlice& attrs, const FunctionLibraryDefinition* lib_def, - const std::function& + const std::function& get_func_sig, std::unique_ptr* fbody); diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 08898cc8052396..adf7ea36fdd99d 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -74,7 +74,7 @@ using ::tsl::testing::StatusIs; using FDH = ::tensorflow::FunctionDefHelper; using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; -absl::Status GetOpSig(const string& op, const OpDef** sig) { +absl::Status GetOpSig(const std::string& op, const OpDef** sig) { return OpRegistry::Global()->LookUpOpDef(op, sig); } @@ -220,14 +220,14 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return absl::OkStatus(); } - absl::Status Instantiate(FunctionLibraryRuntime* flr, const string& name, + absl::Status Instantiate(FunctionLibraryRuntime* flr, const std::string& name, test::function::Attrs attrs, FunctionLibraryRuntime::Handle* handle) { return flr->Instantiate(name, attrs, handle); } absl::Status Instantiate( - FunctionLibraryRuntime* flr, const string& name, + FunctionLibraryRuntime* flr, const std::string& name, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle) { @@ -235,7 +235,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } absl::Status InstantiateAndRun(FunctionLibraryRuntime* flr, - const string& name, + const std::string& name, test::function::Attrs attrs, const std::vector& args, std::vector rets) { @@ -245,7 +245,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } absl::Status InstantiateAndRun( - FunctionLibraryRuntime* flr, const string& name, + FunctionLibraryRuntime* flr, const std::string& name, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& options, const std::vector& args, std::vector rets) { @@ -295,7 +295,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } absl::Status InstantiateAndRunViaCallFrameInterface( - FunctionLibraryRuntime* flr, const string& name, + FunctionLibraryRuntime* flr, const std::string& name, test::function::Attrs attrs, const std::vector& args, std::vector rets) { FunctionLibraryRuntime::Handle handle; @@ -331,7 +331,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } std::unique_ptr GetFuncBody(FunctionLibraryRuntime* flr, - const string& name, + const std::string& name, test::function::Attrs attrs) { FunctionLibraryRuntime::Handle handle; absl::Status status = flr->Instantiate(name, attrs, &handle); @@ -347,7 +347,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } std::unique_ptr GetGradBody(FunctionLibraryRuntime* flr, - const string& func, + const 
std::string& func, test::function::Attrs attrs) { FunctionLibraryRuntime::Handle handle; absl::Status status = flr->Instantiate(func, attrs, &handle); @@ -646,9 +646,9 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { // Attrs {}, // Nodes - {FDH::Const("shape", absl::Span({1})), - FDH::Const("minval", 0), - FDH::Const("maxval", 10), + {FDH::Const("shape", absl::Span({1})), + FDH::Const("minval", 0), + FDH::Const("maxval", 10), // A stateful node. {{"y"}, "RandomUniformInt", @@ -665,7 +665,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { // Simple case: instantiating with no state_handle. for (int32_t expected : {6, 4}) { TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y})); - test::ExpectTensorEqual(y, test::AsTensor({expected})); + test::ExpectTensorEqual(y, test::AsTensor({expected})); } } @@ -678,7 +678,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { EXPECT_EQ(handle, handle_non_isolated); for (int32_t expected : {0, 1}) { TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y})); - test::ExpectTensorEqual(y, test::AsTensor({expected})); + test::ExpectTensorEqual(y, test::AsTensor({expected})); } } @@ -693,7 +693,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { EXPECT_NE(handle, handle_isolated); for (int32_t expected : {6, 4, 0, 1}) { TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y})); - test::ExpectTensorEqual(y, test::AsTensor({expected})); + test::ExpectTensorEqual(y, test::AsTensor({expected})); } } @@ -708,7 +708,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { EXPECT_NE(handle, handle_isolated); for (int32_t expected : {6, 4, 0, 1}) { TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y})); - test::ExpectTensorEqual(y, test::AsTensor({expected})); + test::ExpectTensorEqual(y, test::AsTensor({expected})); } } @@ -725,7 +725,7 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) { EXPECT_NE(handle, handle_isolated); for (int32_t expected : {6, 4, 0, 1}) { TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y})); - test::ExpectTensorEqual(y, test::AsTensor({expected})); + test::ExpectTensorEqual(y, test::AsTensor({expected})); } TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated)); } @@ -1128,9 +1128,9 @@ TEST_F(FunctionLibraryRuntimeTest, std::unique_ptr g; ExpandInlineFunctionsOptions opts; - const string input_node = "Func/b/input/_0"; - const string output_node = "Func/b/output/_1"; - const string output_control_node = "Func/b/output_control_node/_2"; + const std::string input_node = "Func/b/input/_0"; + const std::string output_node = "Func/b/output/_1"; + const std::string output_control_node = "Func/b/output_control_node/_2"; // Use data outputs as output control source. opts.native_options.output_control_src = OutputControlSrc::kDataOutputs; @@ -1203,9 +1203,9 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndKeepCallerNode) { return absl::OkStatus(); }; - const string input_node = "Func/b/input/_0"; - const string output_node = "Func/b/output/_1"; - const string output_control_node = "Func/b/output_control_node/_2"; + const std::string input_node = "Func/b/input/_0"; + const std::string output_node = "Func/b/output/_1"; + const std::string output_control_node = "Func/b/output_control_node/_2"; // Construct expected graph after function inlining. 
auto expected_graph = [&](const NodeDef& caller) -> GraphDef { @@ -1266,9 +1266,9 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) { using test::function::NDef; using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode; - const string arg_device = "/job:arg/replica:0/task:0/device:GPU"; - const string call_device = "/job:call/replica:0/task:1/device:GPU"; - const string body_device = "/job:body/replica:0/task:1/device:CPU"; + const std::string arg_device = "/job:arg/replica:0/task:0/device:GPU"; + const std::string call_device = "/job:call/replica:0/task:1/device:GPU"; + const std::string body_device = "/job:body/replica:0/task:1/device:CPU"; const FunctionDef func = FDH::Create( "AddFunc", {"i: float"}, {"o: float"}, {}, @@ -1291,12 +1291,13 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) { return absl::OkStatus(); }; - const string input_node = "Func/b/input/_0"; - const string output_node = "Func/b/output/_1"; - const string output_control_node = "Func/b/output_control_node/_2"; + const std::string input_node = "Func/b/input/_0"; + const std::string output_node = "Func/b/output/_1"; + const std::string output_control_node = "Func/b/output_control_node/_2"; // Construct expected graph after function inlining. - auto expected_graph = [&](const std::vector& placed) -> GraphDef { + auto expected_graph = + [&](const std::vector& placed) -> GraphDef { return test::function::GDef( { NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}, placed[0]), @@ -1364,7 +1365,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) { auto g = std::make_unique(OpRegistry::Global()); TF_ASSERT_OK(construct_graph(&g)); - const string merged_device = "/job:body/replica:0/task:1/device:CPU:*"; + const std::string merged_device = "/job:body/replica:0/task:1/device:CPU:*"; ExpandInlineFunctions(flr0_, g.get(), opts); GraphDef expected = expected_graph({/*a*/ arg_device, // @@ -1400,7 +1401,7 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { {{"x1"}, "Add", {"o", "o"}, {{"T", T}}}, {{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}}, {{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}}, - FDH::Const("shape", {1, 2}), + FDH::Const("shape", {1, 2}), // A stateful node. {{"keep_me"}, "RandomUniform", @@ -1410,7 +1411,7 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { {{"z"}, "Add", {"a", "o"}, {{"T", T}}}}); Init({stateful_func}); - auto x = test::AsTensor({1, 2, 3, 4}); + auto x = test::AsTensor({1, 2, 3, 4}); auto y = test::AsTensor({1.0, 2.0, 3.0, 4.0}); Tensor z; @@ -1427,15 +1428,15 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {}, {x, y}, {&z})); - test::ExpectTensorEqual(z, test::AsTensor({2, 5, 10, 17})); + test::ExpectTensorEqual(z, test::AsTensor({2, 5, 10, 17})); stats_collector.FinalizeAndSwap(&stats); // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to // execute. 
- std::set expected_node_names( + std::set expected_node_names( {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"}); - std::set executed_node_names; + std::set executed_node_names; for (const auto& node_stats : stats.dev_stats()[0].node_stats()) { executed_node_names.insert(node_stats.node_name()); } @@ -1475,9 +1476,9 @@ TEST_F(FunctionLibraryRuntimeTest, DoNotPruneControlOutputsFromBody) { stats_collector.FinalizeAndSwap(&stats); - std::set expected_node_names( + std::set expected_node_names( {"_SOURCE", "i", "add", "ret", "o_RetVal"}); - std::set executed_node_names; + std::set executed_node_names; for (const auto& node_stats : stats.dev_stats()[0].node_stats()) { executed_node_names.insert(node_stats.node_name()); } @@ -1645,7 +1646,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) { TEST_F(FunctionLibraryRuntimeTest, Error_BadControlFlow) { Init({test::function::InvalidControlFlow()}); - auto x = test::AsTensor({0}); + auto x = test::AsTensor({0}); DCHECK_EQ(x.dtype(), DT_INT32); Tensor y; HasError(InstantiateAndRun(flr0_, "InvalidControlFlow", {}, {x}, {&y}), @@ -2117,7 +2118,7 @@ TEST_F(FunctionLibraryRuntimeTest, FullTypeForInt32) { {{"z"}, "Add", {"x", "x"}, {{"T", T}}}}); Init({int32_func}); - auto x = test::AsTensor({1, 2, 3, 4}); + auto x = test::AsTensor({1, 2, 3, 4}); auto y = test::AsTensor({1.0, 2.0, 3.0, 4.0}); Tensor z; diff --git a/tensorflow/core/common_runtime/function_testlib.cc b/tensorflow/core/common_runtime/function_testlib.cc index 77ee26e29d0e1f..a37f05da7df38e 100644 --- a/tensorflow/core/common_runtime/function_testlib.cc +++ b/tensorflow/core/common_runtime/function_testlib.cc @@ -126,8 +126,8 @@ FunctionDef BlockingOpFn() { } // TODO(phawkins): replace with C++ API for calling functions, when that exists. -Output Call(Scope* scope, const string& op_name, const string& fn_name, - absl::Span inputs) { +Output Call(Scope* scope, const std::string& op_name, + const std::string& fn_name, absl::Span inputs) { NodeDef def; NodeDefBuilder builder(op_name, fn_name, scope->graph()->op_registry()); for (const Input& input : inputs) { diff --git a/tensorflow/core/common_runtime/function_testlib.h b/tensorflow/core/common_runtime/function_testlib.h index 9618c4083b869e..b71acef0c83408 100644 --- a/tensorflow/core/common_runtime/function_testlib.h +++ b/tensorflow/core/common_runtime/function_testlib.h @@ -44,8 +44,8 @@ FunctionDef BlockingOpFn(); // Adds a function call to the given scope and returns the output for the node. // TODO(phawkins): replace with C++ API for calling functions, when that exists. 
-Output Call(Scope* scope, const string& op_name, const string& fn_name, - absl::Span inputs); +Output Call(Scope* scope, const std::string& op_name, + const std::string& fn_name, absl::Span inputs); } // namespace function } // namespace test diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc index e28eb03fd8787b..4c6846593885f5 100644 --- a/tensorflow/core/common_runtime/function_threadpool_test.cc +++ b/tensorflow/core/common_runtime/function_threadpool_test.cc @@ -81,7 +81,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { FunctionLibraryRuntime::Options opts, const std::vector& args, std::vector rets, bool add_runner = true) { - std::atomic call_count(0); + std::atomic call_count(0); std::function)> runner = [&call_count](std::function fn) { ++call_count; @@ -115,14 +115,14 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return absl::OkStatus(); } - absl::Status Instantiate(FunctionLibraryRuntime* flr, const string& name, + absl::Status Instantiate(FunctionLibraryRuntime* flr, const std::string& name, test::function::Attrs attrs, FunctionLibraryRuntime::Handle* handle) { return flr->Instantiate(name, attrs, handle); } absl::Status Instantiate( - FunctionLibraryRuntime* flr, const string& name, + FunctionLibraryRuntime* flr, const std::string& name, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle) { @@ -130,7 +130,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } absl::Status InstantiateAndRun(FunctionLibraryRuntime* flr, - const string& name, + const std::string& name, test::function::Attrs attrs, const std::vector& args, std::vector rets, @@ -141,7 +141,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } absl::Status InstantiateAndRun( - FunctionLibraryRuntime* flr, const string& name, + FunctionLibraryRuntime* flr, const std::string& name, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& options, const std::vector& args, std::vector rets, @@ -171,7 +171,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { FunctionLibraryRuntime::Handle handle, FunctionLibraryRuntime::Options opts, CallFrameInterface* frame, bool add_runner = true) { - std::atomic call_count(0); + std::atomic call_count(0); std::function)> runner = [&call_count](std::function fn) { ++call_count; @@ -232,7 +232,7 @@ TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) { TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, &h)); auto x1 = test::AsTensor({1, 2, 3, 4}); - std::atomic num_done(0); + std::atomic num_done(0); FunctionLibraryRuntime::Options opts; for (int i = 0; i < 4; ++i) { tp1->Schedule([&h, &x1, &opts, &num_done, this]() { diff --git a/tensorflow/core/common_runtime/function_utils.cc b/tensorflow/core/common_runtime/function_utils.cc index 5c743abd0e81df..736dcc4db4811b 100644 --- a/tensorflow/core/common_runtime/function_utils.cc +++ b/tensorflow/core/common_runtime/function_utils.cc @@ -36,7 +36,7 @@ struct Endpoint { int index; // Returns the string name represents this endpoint. 
-  string name() const {
+  std::string name() const {
    if (index == 0) {
      return node->name();
    } else {
@@ -285,7 +285,7 @@ bool IsFunctionCall(const FunctionLibraryDefinition& lib_def,
  return node.IsFunctionCall();
 }

-string NewName(const Node* n, bool pretty) {
+std::string NewName(const Node* n, bool pretty) {
  if (pretty) {
    return absl::StrCat(n->type_string(), n->id());
  } else {
@@ -347,7 +347,7 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
      ndef->add_input("unknown");
      continue;
    }
-    const string srcname = NewName(e->src(), pretty);
+    const std::string srcname = NewName(e->src(), pretty);
    if (!e->src()->IsOp()) {
    } else if (e->IsControlEdge()) {
      ndef->add_input(absl::StrCat("^", srcname));
@@ -360,7 +360,7 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) {
  });
 }

-string DebugString(const Graph* g) {
+std::string DebugString(const Graph* g) {
  GraphDef gdef;
  ToGraphDef(g, &gdef);
  return DebugString(gdef);
diff --git a/tensorflow/core/common_runtime/function_utils.h b/tensorflow/core/common_runtime/function_utils.h
index cfbfe86936421b..97cd4cc63e8ea4 100644
--- a/tensorflow/core/common_runtime/function_utils.h
+++ b/tensorflow/core/common_runtime/function_utils.h
@@ -34,7 +34,7 @@ class OpDef;

 // Debugging facility. Returns a debug string for a graph
 // representing an instantiated function.
-string DebugString(const Graph* g);
+std::string DebugString(const Graph* g);

 // Dump the contents of the "graph" to log files if the logging level is
 // sufficiently high.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
index 049155675fcdb6..4d192d8af9fab4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -34,8 +34,8 @@ limitations under the License.
 #ifdef TF_GPU_USE_PJRT
 #include "tensorflow/compiler/jit/pjrt_tensor_buffer.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "xla/future.h"
 #include "xla/literal.h"
-#include "xla/pjrt/pjrt_future.h"
 #endif  // TF_GPU_USE_PJRT

 #include "tensorflow/core/common_runtime/copy_tensor.h"
diff --git a/tensorflow/core/common_runtime/graph_view.cc b/tensorflow/core/common_runtime/graph_view.cc
index 072c3353c5b8d9..f84dbfac0d3f6d 100644
--- a/tensorflow/core/common_runtime/graph_view.cc
+++ b/tensorflow/core/common_runtime/graph_view.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include <string>
 #include <vector>

+#include "absl/log/check.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -69,7 +70,7 @@ namespace {
 typedef std::tuple<int32, int32> OutputAndControlEdges;

 OutputAndControlEdges CountOutputEdges(const Node* n) {
-  DCHECK_LE(n->out_edges().size(), kint32max);
+  DCHECK_LE(n->out_edges().size(), std::numeric_limits<int32_t>::max());
  int32_t num_output_edges = 0;
  int32_t num_output_control_edges = 0;
  for (auto e : n->out_edges()) {
@@ -125,7 +126,8 @@ size_t GraphView::NodeItemBytes(const Node* n) {

 char* GraphView::InitializeNode(char* ptr, const Node* n) {
  const int id = n->id();
-  CHECK(node_offsets_[id] == kuint32max);  // Initial value in constructor
+  CHECK(node_offsets_[id] ==
+        std::numeric_limits<uint32>::max());  // Initial value in constructor
  const size_t bytes = NodeItemBytes(n);
  constexpr size_t kItemAlignment = sizeof(NodeItem*);
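The kuint32max replacements in this file all serve the sentinel pattern visible above: a node offset equal to numeric_limits<uint32>::max() means "not initialized yet", and the CHECKs assert each slot is written exactly once. A standalone sketch of that sentinel pattern with the <limits> spelling (uint32_t here; TF's uint32 alias has the same width):

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

constexpr uint32_t kUninitialized = std::numeric_limits<uint32_t>::max();

int main() {
  // All offsets start at the sentinel, as in GraphView's constructor.
  std::vector<uint32_t> node_offsets(4, kUninitialized);
  node_offsets[2] = 128;  // initialize one node's offset
  for (uint32_t offset : node_offsets) {
    if (offset == kUninitialized)
      std::cout << "uninitialized\n";
    else
      std::cout << offset << "\n";
  }
}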
@@ -137,7 +139,8 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) {
  // (versus 64 bits on most machines if we just stored an array of NodeItem*
  // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
  // values as "int" vs "size_t" in CHECK_LE.
-  CHECK_LE(static_cast<int64_t>(ptr - space_), kuint32max);
+  CHECK_LE(static_cast<int64_t>(ptr - space_),
+           std::numeric_limits<uint32>::max());
  const uint32 offset = static_cast<uint32>(ptr - space_);
  node_offsets_[id] = offset;
  ptr += bytes;
@@ -252,7 +255,7 @@ absl::Status GraphView::Initialize(const Graph* g) {
  num_nodes_ = num_nodes;
  size_t total_bytes = 0;
  for (const Node* n : g->nodes()) {
-    if (n->out_edges().size() > kint32max) {
+    if (n->out_edges().size() > std::numeric_limits<int32_t>::max()) {
      return errors::InvalidArgument(
          "The executor cannot handle nodes with more than ",
          std::numeric_limits<int32_t>::max(), " output edges. Node ",
@@ -263,7 +266,7 @@ absl::Status GraphView::Initialize(const Graph* g) {

  node_offsets_ = new uint32[num_nodes];
  for (int i = 0; i < num_nodes; i++) {
-    node_offsets_[i] = kuint32max;
+    node_offsets_[i] = std::numeric_limits<uint32>::max();
  }

  space_ = new char[total_bytes];  // NodeItem objects are allocated here
diff --git a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h
index 83d15e71282024..3864df8a6ce165 100644
--- a/tensorflow/core/common_runtime/graph_view.h
+++ b/tensorflow/core/common_runtime/graph_view.h
@@ -221,7 +221,7 @@ class GraphView {
    DCHECK_GE(id, 0);
    DCHECK_LT(id, num_nodes_);
    uint32 offset = node_offsets_[id];
-    return ((offset == kuint32max)
+    return ((offset == std::numeric_limits<uint32>::max())
                ? nullptr
                : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
  }
@@ -233,7 +233,7 @@ class GraphView {
    DCHECK_GE(id, 0);
    DCHECK_LT(id, num_nodes_);
    uint32 offset = node_offsets_[id];
-    DCHECK_NE(offset, kuint32max);
+    DCHECK_NE(offset, std::numeric_limits<uint32>::max());
    return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]);
  }
diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc
index 998c81efc85d97..a9ebb6f4c3559d 100644
--- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc
+++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc
@@ -242,7 +242,7 @@ absl::Status PluggableDevice::Init(const SessionOptions& options) {
  // callback instead of GPU environment variables: TF_GPU_THREAD_MODE,
  // TF_GPU_THREAD_COUNT, TF_FORCE_GPU_ALLOC_GROWTH,
  // TF_ENABLE_GPU_GARBAGE_COLLECTION, and TF_GPU_HOST_MEM_LIMIT_IN_MB.
- string device_thread_mode; + std::string device_thread_mode; TF_RETURN_IF_ERROR(ReadStringFromEnvVar("TF_GPU_THREAD_MODE", "global", &device_thread_mode)); device_thread_mode = absl::AsciiStrToLower(device_thread_mode); @@ -256,19 +256,19 @@ absl::Status PluggableDevice::Init(const SessionOptions& options) { thread_pool_ = std::make_unique( options.env, ThreadOptions(), absl::StrCat("gpu_private_", tf_device_id_.value()), - static_cast(device_thread_count), + static_cast(device_thread_count), !options.config.experimental().disable_thread_spinning(), /*allocator=*/nullptr); set_tensorflow_device_thread_pool(thread_pool_.get()); } else if (device_thread_mode == "gpu_shared") { static thread::ThreadPool* thread_pool = new thread::ThreadPool( options.env, ThreadOptions(), "gpu_shared", - static_cast(device_thread_count), + static_cast(device_thread_count), !options.config.experimental().disable_thread_spinning(), /*allocator=*/nullptr); set_tensorflow_device_thread_pool(thread_pool); } else { - string error_message = + std::string error_message = absl::StrCat("Invalid gpu_thread_mode: ", device_thread_mode); LOG(WARNING) << error_message; return errors::InvalidArgument(error_message); @@ -293,8 +293,8 @@ Allocator* PluggableDevice::GetAllocator(AllocatorAttributes attr) { } } -string PluggableDevice::ComputeOpKernelDebugString(const OpKernel& op_kernel, - const int stream_id) { +std::string PluggableDevice::ComputeOpKernelDebugString( + const OpKernel& op_kernel, const int stream_id) { return strings::StrCat(op_kernel.name(), " op ", op_kernel.type_string(), " on ", platform_name_, tf_device_id_.value(), " stream[", stream_id, "]"); diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h index bfcbc16d0eb2da..9ccdc04192e071 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h @@ -48,9 +48,9 @@ namespace tensorflow { class PluggableDevice : public LocalDevice { public: PluggableDevice(const SessionOptions& options, const std::string& name, - const string& device_type, const string& platform_name, - Bytes memory_limit, const DeviceLocality& locality, - TfDeviceId tf_device_id, + const std::string& device_type, + const std::string& platform_name, Bytes memory_limit, + const DeviceLocality& locality, TfDeviceId tf_device_id, const std::string& physical_device_desc, Allocator* device_allocator, Allocator* cpu_allocator, bool sync_every_op); @@ -99,7 +99,7 @@ class PluggableDevice : public LocalDevice { // TODO(penpornk): Investigate renaming `GpuDeviceInfo` to `DeviceInfo`. 
DeviceBase::AcceleratorDeviceInfo* pluggable_device_info_ = nullptr; TfDeviceId tf_device_id_; - const string platform_name_; + const std::string platform_name_; const bool sync_every_op_ = false; EventMgr* em_ = nullptr; std::unique_ptr thread_pool_; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc index e4b3ef4c8e7f2b..ac2488d0b57664 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc @@ -94,14 +94,14 @@ bool PluggableDeviceBFCAllocator::GetGarbageCollectionValue() { } PluggableDeviceBFCAllocator::PluggableDeviceBFCAllocator( - tsl::SubAllocator* sub_allocator, size_t total_memory, const string& name, - bool force_memory_growth_requested) + tsl::SubAllocator* sub_allocator, size_t total_memory, + const std::string& name, bool force_memory_growth_requested) : PluggableDeviceBFCAllocator(sub_allocator, total_memory, GPUOptions(), name, force_memory_growth_requested) {} PluggableDeviceBFCAllocator::PluggableDeviceBFCAllocator( tsl::SubAllocator* sub_allocator, size_t total_memory, - const GPUOptions& gpu_options, const string& name, + const GPUOptions& gpu_options, const std::string& name, bool force_memory_growth_requested) : BFCAllocator(absl::WrapUnique(sub_allocator), total_memory, name, [&] { BFCAllocator::Options o; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h index b968b9dbc1c734..9e87b2612343a6 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h @@ -30,11 +30,12 @@ namespace tensorflow { class PluggableDeviceBFCAllocator : public BFCAllocator { public: PluggableDeviceBFCAllocator(tsl::SubAllocator* sub_allocator, - size_t total_memory, const string& name, + size_t total_memory, const std::string& name, bool force_memory_growth_requested); PluggableDeviceBFCAllocator(tsl::SubAllocator* sub_allocator, size_t total_memory, - const GPUOptions& gpu_options, const string& name, + const GPUOptions& gpu_options, + const std::string& name, bool force_memory_growth_requested); ~PluggableDeviceBFCAllocator() override = default; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc index d580a185f6ed56..855e796ee7903d 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc @@ -82,7 +82,7 @@ int64_t MinSystemMemory(int64_t available_memory) { // Get the memory limit for the virtual device being created on PluggableDevice // with 'platform_device_id', when that virtual device is the only // virtual device being created on that PluggableDevice. 
-absl::Status SingleVirtualDeviceMemoryLimit(const string& platform_name, +absl::Status SingleVirtualDeviceMemoryLimit(const std::string& platform_name, const GPUOptions& device_options, PlatformDeviceId platform_device_id, int64_t* memory_limit) { @@ -119,18 +119,18 @@ absl::Status SingleVirtualDeviceMemoryLimit(const string& platform_name, } } // namespace -PluggableDeviceFactory::PluggableDeviceFactory(const string& device_type, - const string& platform_name) +PluggableDeviceFactory::PluggableDeviceFactory(const std::string& device_type, + const std::string& platform_name) : device_type_(device_type), platform_name_(platform_name) {} absl::Status PluggableDeviceFactory::ListPhysicalDevices( - std::vector* devices) { + std::vector* devices) { TF_RETURN_IF_ERROR(ValidatePluggableDeviceMachineManager(platform_name_)); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); int device_count = platform->VisibleDeviceCount(); for (int i = 0; i < device_count; ++i) { - const string device_name = + const std::string device_name = absl::StrCat("/physical_device:", device_type_, ":", i); devices->push_back(device_name); } @@ -139,7 +139,7 @@ absl::Status PluggableDeviceFactory::ListPhysicalDevices( } absl::Status PluggableDeviceFactory::GetDeviceDetails( - int device_index, std::unordered_map* details) { + int device_index, std::unordered_map* details) { TF_RETURN_IF_ERROR(ValidatePluggableDeviceMachineManager(platform_name_)); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); if (platform == nullptr) { @@ -163,7 +163,7 @@ absl::Status PluggableDeviceFactory::GetDeviceDetails( } absl::Status PluggableDeviceFactory::CreateDevices( - const SessionOptions& options, const string& name_prefix, + const SessionOptions& options, const std::string& name_prefix, std::vector>* devices) { TF_RETURN_IF_ERROR(ValidatePluggableDeviceMachineManager(platform_name_)); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); @@ -214,20 +214,20 @@ absl::Status PluggableDeviceFactory::CreateDevices( return absl::OkStatus(); } -static string GetShortDeviceDescription(PlatformDeviceId platform_device_id, - const se::DeviceDescription& desc) { +static std::string GetShortDeviceDescription( + PlatformDeviceId platform_device_id, const se::DeviceDescription& desc) { return strings::StrCat("device: ", platform_device_id.value(), ", name: ", desc.name(), ", pci bus id: ", desc.pci_bus_id()); } absl::Status PluggableDeviceFactory::CreatePluggableDevice( - const SessionOptions& options, const string& name_prefix, + const SessionOptions& options, const std::string& name_prefix, TfDeviceId tf_device_id, int64_t memory_limit, const DeviceLocality& dev_locality, std::vector>* devices) { DCHECK_GE(tf_device_id.value(), 0); - const string device_name = strings::StrCat( + const std::string device_name = strings::StrCat( name_prefix, "/device:", device_type_, ":", tf_device_id.value()); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h index 3f6ab10f9951fc..92a145080a0ba4 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h @@ -34,14 +34,15 @@ limitations under the License. 
namespace tensorflow { class PluggableDeviceFactory : public DeviceFactory { public: - PluggableDeviceFactory(const string& device_type, - const string& platform_name); - absl::Status ListPhysicalDevices(std::vector<string>* devices) override; + PluggableDeviceFactory(const std::string& device_type, + const std::string& platform_name); + absl::Status ListPhysicalDevices(std::vector<std::string>* devices) override; absl::Status CreateDevices( const SessionOptions& options, const std::string& name_prefix, std::vector<std::unique_ptr<Device>>* devices) override; absl::Status GetDeviceDetails( - int device_index, std::unordered_map<string, string>* details) override; + int device_index, + std::unordered_map<std::string, std::string>* details) override; private: // Populates *device_localities with the DeviceLocality descriptor for @@ -57,8 +58,8 @@ class PluggableDeviceFactory : public DeviceFactory { const DeviceLocality& dev_locality, std::vector<std::unique_ptr<Device>>* devices); - const string device_type_; - const string platform_name_; + const std::string device_type_; + const std::string platform_name_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc index 52c09016bddcd1..696248aba12122 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc @@ -25,11 +25,11 @@ limitations under the License. namespace tensorflow { absl::Status ValidatePluggableDeviceMachineManager( - const string& platform_name) { + const std::string& platform_name) { return se::PlatformManager::PlatformWithName(platform_name).status(); } -se::Platform* PluggableDeviceMachineManager(const string& platform_name) { +se::Platform* PluggableDeviceMachineManager(const std::string& platform_name) { auto result = se::PlatformManager::PlatformWithName(platform_name); if (!result.ok()) { LOG(FATAL) << "Could not find platform with name " // Crash OK diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h index b77917d14701c5..6d385ac31c435d 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h @@ -30,7 +30,8 @@ namespace tensorflow { // Initializes the PluggableDevice platform and returns OK if the // PluggableDevice platform could be initialized. -absl::Status ValidatePluggableDeviceMachineManager(const string& platform_name); +absl::Status ValidatePluggableDeviceMachineManager( + const std::string& platform_name); // Returns the PluggableDevice machine manager singleton, creating it and // initializing the PluggableDevices on the machine if needed the first time it @@ -38,7 +39,7 @@ absl::Status ValidatePluggableDeviceMachineManager(const string& platform_name); // environment in the process (e.g., ValidatePluggableDeviceMachineManager() // returns OK).
stream_executor::Platform* PluggableDeviceMachineManager( - const string& platform_name); + const std::string& platform_name); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc index 5e41c8db0c39b6..d348c678a15ea3 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc @@ -49,7 +49,7 @@ static absl::Status InitDeviceModule(stream_executor::SEInitPluginFn init_fn) { return absl::OkStatus(); } - string device_type, platform_name; + std::string device_type, platform_name; TF_RETURN_IF_ERROR(stream_executor::InitStreamExecutorPlugin( init_fn, &device_type, &platform_name)); diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc index 581f6b6c5c306f..01f6aa0e97bb00 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc @@ -57,9 +57,9 @@ limitations under the License. namespace tensorflow { /*static*/ PluggableDeviceProcessState* PluggableDeviceProcessState::singleton( - const string& device_type, const string& platform_name) { + const std::string& device_type, const std::string& platform_name) { using ProcessStateMap = - std::unordered_map<string, PluggableDeviceProcessState*>; + std::unordered_map<std::string, PluggableDeviceProcessState*>; static ProcessStateMap* process_state_map = new ProcessStateMap; auto iter = process_state_map->find(platform_name); if (iter != process_state_map->end()) { @@ -71,7 +71,7 @@ namespace tensorflow { } PluggableDeviceProcessState::PluggableDeviceProcessState( - const string& device_type, const string& platform_name) + const std::string& device_type, const std::string& platform_name) : pluggable_device_enabled_(false), device_type_(device_type), platform_name_(platform_name) { @@ -93,7 +93,7 @@ int PluggableDeviceProcessState::BusIdForPluggableDevice( Allocator* PluggableDeviceProcessState::GetPluggableDeviceAllocator( const GPUOptions& options, TfDeviceId tf_device_id, size_t total_bytes) { DCHECK(process_state_); - const string& allocator_type = options.allocator_type(); + const std::string& allocator_type = options.allocator_type(); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); mutex_lock lock(mu_); tsl::CheckValidTfDeviceId(DeviceType(device_type_), diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h index 6e6b45fe887dca..6afb0daa77a2da 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h @@ -43,8 +43,8 @@ class PluggableDeviceProcessState { public: // Singleton that manages each platform's per-process state. e.g. allocation // of shared resource. - static PluggableDeviceProcessState* singleton(const string& device_type, - const string& platform_name); + static PluggableDeviceProcessState* singleton( + const std::string& device_type, const std::string& platform_name); // Query whether any PluggableDevice has been created so far. // Disable thread safety analysis since a race is benign here.
@@ -89,8 +89,8 @@ class PluggableDeviceProcessState { protected: // PluggableDeviceProcessState is a singleton that should not normally be // deleted except at process shutdown. - PluggableDeviceProcessState(const string& device_type, - const string& platform_name); + PluggableDeviceProcessState(const std::string& device_type, + const std::string& platform_name); virtual ~PluggableDeviceProcessState() = default; ProcessState::MDMap* mem_desc_map() { @@ -101,8 +101,8 @@ class PluggableDeviceProcessState { static PluggableDeviceProcessState* instance_; ProcessState* process_state_; // Not owned. bool pluggable_device_enabled_; - const string device_type_; - const string platform_name_; + const std::string device_type_; + const std::string platform_name_; mutex mu_; struct AllocatorParts { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h index 27304954c25b0c..b7e9424982b22a 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h @@ -35,7 +35,7 @@ class PluggableDeviceSimpleAllocator : public Allocator { void DeallocateRaw(void* ptr) override; bool TracksAllocationSizes() const override { return false; } - string Name() override { return "Simple allocator"; } + std::string Name() override { return "Simple allocator"; } std::optional<AllocatorStats> GetStats() override; AllocatorMemoryType GetMemoryType() const override { diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc index 29ae03e0d1f996..e74d99fd2af2ad 100644 --- a/tensorflow/core/common_runtime/pool_allocator.cc +++ b/tensorflow/core/common_runtime/pool_allocator.cc @@ -37,7 +37,7 @@ namespace tensorflow { PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize, SubAllocator* allocator, - RoundUpInterface* size_rounder, string name) + RoundUpInterface* size_rounder, std::string name) : name_(std::move(name)), has_size_limit_(pool_size_limit > 0), auto_resize_(auto_resize), diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h index 6ce3b7886cfa6b..69c1e7a75b88d9 100644 --- a/tensorflow/core/common_runtime/pool_allocator.h +++ b/tensorflow/core/common_runtime/pool_allocator.h @@ -55,10 +55,10 @@ class PoolAllocator : public Allocator { // malloc/free operations. This object takes ownership of allocator. PoolAllocator(size_t pool_size_limit, bool auto_resize, SubAllocator* allocator, RoundUpInterface* size_rounder, - string name); + std::string name); ~PoolAllocator() override; - string Name() override { return name_; } + std::string Name() override { return name_; } void* AllocateRaw(size_t alignment, size_t num_bytes) override; @@ -121,7 +121,7 @@ class PoolAllocator : public Allocator { // Delete the least recently used record.
void EvictOne() TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - const string name_; + const std::string name_; const bool has_size_limit_; const bool auto_resize_; size_t pool_size_limit_; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 98af5aedeedee1..c26495dfa83117 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -92,7 +92,7 @@ int64_t GetParallelSubgraphThreshold() { const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null"; void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit( - DistributedFunctionLibraryRuntime* parent, const string& function_name, + DistributedFunctionLibraryRuntime* parent, const std::string& function_name, const FunctionLibraryDefinition& lib_def, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::DoneCallback done) { @@ -149,16 +149,17 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( /* static */ absl::Status ProcessFunctionLibraryRuntime::SendTensors( - const string& source_device, const string& target_device, - const string& key_prefix, int64_t src_incarnation, + const std::string& source_device, const std::string& target_device, + const std::string& key_prefix, int64_t src_incarnation, absl::Span<const Tensor> tensors_to_send, DeviceContext* device_context, const std::vector<AllocatorAttributes>& alloc_attrs, RendezvousInterface* rendezvous) { - std::vector<string> keys; + std::vector<std::string> keys; for (int i = 0; i < tensors_to_send.size(); ++i) { - string name = strings::StrCat(key_prefix, i); - string key = Rendezvous::CreateKey(source_device, src_incarnation, - target_device, name, FrameAndIter(0, 0)); + std::string name = absl::StrCat(key_prefix, i); + std::string key = + Rendezvous::CreateKey(source_device, src_incarnation, target_device, + name, FrameAndIter(0, 0)); keys.push_back(key); } TF_RETURN_IF_ERROR(SendTensorsToRendezvous( @@ -168,17 +169,18 @@ absl::Status ProcessFunctionLibraryRuntime::SendTensors( /* static */ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( - const string& source_device, const string& target_device, - const string& key_prefix, int64_t src_incarnation, int64_t num_tensors, + const std::string& source_device, const std::string& target_device, + const std::string& key_prefix, int64_t src_incarnation, int64_t num_tensors, DeviceContext* device_context, const std::vector<AllocatorAttributes>& alloc_attrs, RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors, StatusCallback done) { - std::vector<string> keys; + std::vector<std::string> keys; for (int64_t i = 0; i < num_tensors; ++i) { - string name = strings::StrCat(key_prefix, i); - string key = Rendezvous::CreateKey(source_device, src_incarnation, - target_device, name, FrameAndIter(0, 0)); + std::string name = absl::StrCat(key_prefix, i); + std::string key = + Rendezvous::CreateKey(source_device, src_incarnation, target_device, + name, FrameAndIter(0, 0)); keys.push_back(key); } RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys, @@ -207,7 +209,7 @@ absl::Status ProcessFunctionLibraryRuntime::GetRetTypes( } absl::Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation( - const string& device_name, int64_t* incarnation) const { + const std::string& device_name, int64_t* incarnation) const { FunctionLibraryRuntime* flr = GetFLR(device_name); if (flr == nullptr) { return errors::InvalidArgument("Device name: ", device_name, " not
found."); @@ -217,14 +219,14 @@ absl::Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation( } absl::Status ProcessFunctionLibraryRuntime::GetDeviceContext( - const string& device_name, DeviceContext** device_context) const { + const std::string& device_name, DeviceContext** device_context) const { *device_context = nullptr; FunctionLibraryRuntime* flr = GetFLR(device_name); if (flr == nullptr) { return errors::InvalidArgument("Device name: ", device_name, " not found."); } Device* device = flr->device(); - string device_type = device->parsed_name().type; + std::string device_type = device->parsed_name().type; if (device_type == "CPU" || device_type == "TPU_SYSTEM") { // "TPU_SYSTEM" indicates that `device` is a CPU. return absl::OkStatus(); @@ -281,7 +283,7 @@ void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() { } FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( - const string& device_name) const { + const std::string& device_name) const { Device* device = nullptr; if (device_name != kDefaultFLRDevice) { if (!device_mgr_->LookupDevice(device_name, &device).ok()) { @@ -299,14 +301,14 @@ FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( } FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle( - const string& function_key, const string& device_name, + const std::string& function_key, const std::string& device_name, FunctionLibraryRuntime::LocalHandle local_handle) { mutex_lock l(mu_); return AddHandleLocked(function_key, device_name, local_handle); } FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked( - const string& function_key, const string& device_name, + const std::string& function_key, const std::string& device_name, FunctionLibraryRuntime::LocalHandle local_handle) { auto h = next_handle_; function_data_[h] = @@ -318,7 +320,8 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked( FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddMultiDeviceHandle( - std::unique_ptr data, const string& function_key) { + std::unique_ptr data, + const std::string& function_key) { mutex_lock l(mu_); auto h = next_handle_; mdevice_data_[h] = std::move(data); @@ -338,14 +341,14 @@ bool ProcessFunctionLibraryRuntime::HasMultiDeviceHandle( } FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle( - const string& function_key) const { + const std::string& function_key) const { tf_shared_lock l(mu_); return gtl::FindWithDefault(table_, function_key, kInvalidHandle); } FunctionLibraryRuntime::LocalHandle ProcessFunctionLibraryRuntime::GetHandleOnDevice( - const string& device_name, FunctionLibraryRuntime::Handle handle, + const std::string& device_name, FunctionLibraryRuntime::Handle handle, bool include_multi_device) const { tf_shared_lock l(mu_); @@ -357,7 +360,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( if (data.glue_.size() != 1) return kInvalidLocalHandle; const auto& pair = *data.glue_.begin(); - const string& func_device_name = pair.first; + const std::string& func_device_name = pair.first; const ComponentFunctionData& component_data = pair.second; if (func_device_name != device_name) return kInvalidLocalHandle; @@ -377,7 +380,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( return function_data->local_handle(); } -string ProcessFunctionLibraryRuntime::GetDeviceName( +std::string ProcessFunctionLibraryRuntime::GetDeviceName( FunctionLibraryRuntime::Handle handle) const { tf_shared_lock l(mu_); auto iter = function_data_.find(handle); @@ -496,11 
+499,11 @@ void ProcessFunctionLibraryRuntime::PublishSubgraphs( } absl::Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( - const string& function_name, AttrSlice attrs, + const std::string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle) { // Check if this function has already been instantiated. - const string& function_key = Canonicalize(function_name, attrs, options); + const std::string& function_key = Canonicalize(function_name, attrs, options); { mutex_lock l(mu_); @@ -517,12 +520,12 @@ absl::Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( if (VLOG_IS_ON(3)) { int index = 0; VLOG(3) << "Requested input devices:"; - for (const string& device : options.input_devices) { + for (const std::string& device : options.input_devices) { VLOG(3) << " [input " << index++ << "] " << device; } index = 0; VLOG(3) << "Requested output devices:"; - for (const string& device : options.output_devices) { + for (const std::string& device : options.output_devices) { VLOG(3) << " [output " << index++ << "] " << device; } } @@ -552,7 +555,7 @@ absl::Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( Device* cpu_device; TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device)); - const uint64 optimization_start_time_usecs = Env::Default()->NowMicros(); + const uint64_t optimization_start_time_usecs = Env::Default()->NowMicros(); // Look up for optimized function graph in library. If found, skip // `OptimizeFunctionGraph` step. std::optional<absl::StatusOr<OptimizedFunctionGraph>> optimized_graph_proto = @@ -593,8 +596,8 @@ absl::Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( function_name, *optimized_graph_info, options, *dev_set, lib_def_, composite_devices, cpu_device, env_)); - const uint64 optimization_end_time_usecs = Env::Default()->NowMicros(); - const uint64 graph_optimization_duration = + const uint64_t optimization_end_time_usecs = Env::Default()->NowMicros(); + const uint64_t graph_optimization_duration = optimization_end_time_usecs - optimization_start_time_usecs; metrics::UpdateFunctionGraphOptimizationTime(graph_optimization_duration); VLOG(1) << "Finished graph optimizations for MultiDevice function \"" @@ -617,11 +620,11 @@ absl::Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( // We must preserve control returns in each of the function components, // otherwise after function inlining we might prune side-effectful nodes. const auto control_ret = - [&node_name_to_control_ret](const Node* n) -> std::optional<string> { + [&node_name_to_control_ret](const Node* n) -> std::optional<std::string> { const auto it = node_name_to_control_ret.find(n->name()); return it != node_name_to_control_ret.end() // NOLINTNEXTLINE - ? absl::make_optional<string>(it->second) + ?
absl::make_optional<std::string>(it->second) // NOLINTNEXTLINE : absl::nullopt; }; @@ -659,11 +662,11 @@ absl::Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( auto instantiate_component = [this, dev_set, &data_lib_def, &control_ret, &options, - &data](const string& target, + &data](const std::string& target, std::unique_ptr<Graph> subgraph, ComponentFunctionData* comp_data, std::function<void(absl::Status)> done) { - const string& device_type = + const std::string& device_type = dev_set->FindDeviceByName(target)->device_type(); bool ints_on_device = @@ -854,7 +857,7 @@ absl::Status ProcessFunctionLibraryRuntime::GetOutputDevices( continue; } - const string& target = pair.first; + const std::string& target = pair.first; FunctionLibraryRuntime* target_flr = GetFLR(target); Device* target_device = nullptr; Device* host = nullptr; @@ -863,7 +866,7 @@ absl::Status ProcessFunctionLibraryRuntime::GetOutputDevices( data->has_remote_outputs = true; } target_device = device_set()->FindDeviceByName(target); - string remote_host; + std::string remote_host; TF_RETURN_IF_ERROR( DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host)); host = device_set()->FindDeviceByName(remote_host); @@ -917,14 +920,14 @@ absl::Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice( return absl::OkStatus(); } -std::vector<string> ProcessFunctionLibraryRuntime::GetOrderedSubgraphs( +std::vector<std::string> ProcessFunctionLibraryRuntime::GetOrderedSubgraphs( const MultiDeviceFunctionData* data) const { - std::vector<string> subgraph_keys; + std::vector<std::string> subgraph_keys; subgraph_keys.reserve(data->glue_.size()); for (const auto& pair : data->glue_) { subgraph_keys.push_back(pair.first); } - auto send_first_ordering = [&](const string& a, const string& b) { + auto send_first_ordering = [&](const std::string& a, const std::string& b) { auto a_summary = data->glue_.at(a).async_attributes.summary(); auto b_summary = data->glue_.at(b).async_attributes.summary(); if (a_summary == b_summary) { @@ -969,9 +972,9 @@ absl::Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( // // We assume that the partitioning has a valid deadlock-free ordering and the // safety of running synchronously has already been confirmed by this point.
- std::vector<string> subgraph_keys = GetOrderedSubgraphs(data); + std::vector<std::string> subgraph_keys = GetOrderedSubgraphs(data); - for (const string& target : subgraph_keys) { + for (const std::string& target : subgraph_keys) { const ComponentFunctionData& comp_data = data->glue_.at(target); FunctionLibraryRuntime::Handle comp_handle = comp_data.handle; @@ -1003,9 +1006,9 @@ absl::Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( &comp_tensor_rets); if (!run_status.ok()) { VLOG(2) << "Component function execution failed: " << run_status; - const string function_and_msg = strings::StrCat( - errors::FormatFunctionForError(data->function_name_), " ", - run_status.message()); + const std::string function_and_msg = + absl::StrCat(errors::FormatFunctionForError(data->function_name_), + " ", run_status.message()); if (opts.rendezvous != nullptr) opts.rendezvous->StartAbort(run_status); return errors::CreateWithUpdatedMessage(run_status, function_and_msg); } else { @@ -1067,7 +1070,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync( FunctionLibraryRuntime::Options opts_copy = opts; for (const auto& pair : data->glue_) { - const string& target = pair.first; + const std::string& target = pair.first; const ComponentFunctionData& comp_data = pair.second; FunctionLibraryRuntime::Handle comp_handle = pair.second.handle; @@ -1094,9 +1097,9 @@ void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync( VLOG(2) << "Component function execution on target " << target << " from " << data->function_name_ << " with handle " << comp_handle << " failed: " << status; - const string function_and_msg = strings::StrCat( - errors::FormatFunctionForError(data->function_name_), " ", - status.message()); + const std::string function_and_msg = + absl::StrCat(errors::FormatFunctionForError(data->function_name_), + " ", status.message()); refcounted_done->UpdateStatus( errors::CreateWithUpdatedMessage(status, function_and_msg)); // Cancel the execution of other component functions. @@ -1147,7 +1150,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync( } absl::Status ProcessFunctionLibraryRuntime::Instantiate( - const string& function_name, AttrSlice attrs, + const std::string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle) { if (options.is_multi_device_function) { @@ -1195,7 +1198,7 @@ absl::Status ProcessFunctionLibraryRuntime::IsCrossProcess( } void ProcessFunctionLibraryRuntime::InstantiateRemote( - const string& function_name, AttrSlice attrs, + const std::string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle, FunctionLibraryRuntime::DoneCallback done) { @@ -1207,7 +1210,7 @@ void ProcessFunctionLibraryRuntime::InstantiateRemote( } auto target = options.target; VLOG(1) << "ProcessFLR Instantiate: " << function_name << " on: " << target; - string function_key = Canonicalize(function_name, attrs, options); + std::string function_key = Canonicalize(function_name, attrs, options); FunctionData* f; { mutex_lock l(mu_); @@ -1257,7 +1260,7 @@ absl::Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( // Release all component function handles.
absl::Status overall_status; for (const auto& it : mdata->glue_) { - const string& device = it.first; + const std::string& device = it.first; FunctionLibraryRuntime::Handle flr_handle = it.second.handle; FunctionLibraryRuntime* flr = GetFLR(device); if (flr == nullptr) { @@ -1291,7 +1294,7 @@ absl::Status ProcessFunctionLibraryRuntime::ReleaseHandle( } FunctionLibraryRuntime* flr = nullptr; - string target_device; + std::string target_device; { mutex_lock l(mu_); @@ -1455,7 +1458,7 @@ void ProcessFunctionLibraryRuntime::RunInternal( std::vector<std::unique_ptr<CleanUpItem>>* cleanup_items, FunctionLibraryRuntime::DoneCallback done) const { FunctionLibraryRuntime* flr = nullptr; - string target_device; + std::string target_device; FunctionLibraryRuntime::LocalHandle local_handle; { tf_shared_lock l(mu_); @@ -1480,7 +1483,7 @@ void ProcessFunctionLibraryRuntime::RunInternal( flr = GetFLR(target_device); if (flr != nullptr) { auto rendezvous = opts.rendezvous; - string source_device = opts.source_device; + std::string source_device = opts.source_device; DeviceContext* device_context; absl::Status s = GetDeviceContext(source_device, &device_context); if (!s.ok()) { diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 0305bde12e6cba..d37f341ae83531 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -94,8 +94,8 @@ class ProcessFunctionLibraryRuntime { // `tensors_to_send` and indicates how the input tensors are allocated. Method // takes references on each of the `tensors_to_send`. Method doesn't block. static absl::Status SendTensors( - const string& source_device, const string& target_device, - const string& key_prefix, int64_t src_incarnation, + const std::string& source_device, const std::string& target_device, + const std::string& key_prefix, int64_t src_incarnation, absl::Span<const Tensor> tensors_to_send, DeviceContext* device_context, const std::vector<AllocatorAttributes>& alloc_attrs, RendezvousInterface* rendezvous); @@ -107,23 +107,23 @@ class ProcessFunctionLibraryRuntime { // tensors and should either be empty or `num_tensors` in size. Method doesn't // block and calls `done` when `num_tensors` are fetched. static void ReceiveTensorsAsync( - const string& source_device, const string& target_device, - const string& key_prefix, int64_t src_incarnation, int64_t num_tensors, - DeviceContext* device_context, + const std::string& source_device, const std::string& target_device, + const std::string& key_prefix, int64_t src_incarnation, + int64_t num_tensors, DeviceContext* device_context, const std::vector<AllocatorAttributes>& alloc_attrs, RendezvousInterface* rendezvous, std::vector<Tensor>* received_tensors, StatusCallback done); static const char kDefaultFLRDevice[]; // Returns the FunctionLibraryRuntime for the corresponding device_name. - FunctionLibraryRuntime* GetFLR(const string& device_name) const; + FunctionLibraryRuntime* GetFLR(const std::string& device_name) const; // Returns the return types for the function identified by handle `h`. absl::Status GetRetTypes(FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types); // Returns the device incarnation for the given device_name.
- absl::Status GetDeviceIncarnation(const string& device_name, + absl::Status GetDeviceIncarnation(const std::string& device_name, int64_t* incarnation) const; // For a given canonicalized key signature of the function instantiated @@ -131,11 +131,12 @@ class ProcessFunctionLibraryRuntime { // that value. Uses core/common_runtime/framework/function.h::Canonicalize // to canonicalize the function signature. FunctionLibraryRuntime::Handle AddHandle( - const string& function_key, const string& device_name, + const std::string& function_key, const std::string& device_name, FunctionLibraryRuntime::LocalHandle local_handle); // Returns a handle if found for the given key, else returns kInvalidHandle. - FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const; + FunctionLibraryRuntime::Handle GetHandle( + const std::string& function_key) const; // For the given handle instantiated on device `device_name` returns the local // index of instantiation of that function. If the function was not @@ -146,7 +147,7 @@ class ProcessFunctionLibraryRuntime { // with a single component that is placed on `device_name`, then this method // will return the local handle for that component. FunctionLibraryRuntime::LocalHandle GetHandleOnDevice( - const string& device_name, FunctionLibraryRuntime::Handle handle, + const std::string& device_name, FunctionLibraryRuntime::Handle handle, bool include_multi_device = false) const; // Fills `output_devices` with the devices on which the results will @@ -161,7 +162,7 @@ class ProcessFunctionLibraryRuntime { // Allows for function_name to be instantiated on different devices // as specified in attrs. absl::Status Instantiate( - const string& function_name, AttrSlice attrs, + const std::string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle); @@ -273,7 +274,7 @@ class ProcessFunctionLibraryRuntime { // The handle for the instantiated component function. FunctionLibraryRuntime::Handle handle; // The name for the component function. - string name; + std::string name; // arg_indices.size() is the number of arguments to the component function. // The i-th argument of the component function comes from the // `arg_indices[i]`-th argument of the multi-device function. @@ -297,8 +298,8 @@ class ProcessFunctionLibraryRuntime { // The fields are filled in during instantiation. Once the object is // added to mdevice_data_, all fields are constant. struct MultiDeviceFunctionData { - MultiDeviceFunctionData(const string& function_name, - const string& function_key, int num_outputs, + MultiDeviceFunctionData(const std::string& function_name, + const std::string& function_key, int num_outputs, DataTypeVector ret_types) : function_name_(function_name), function_key_(function_key), @@ -308,9 +309,9 @@ class ProcessFunctionLibraryRuntime { is_cross_process_(false), has_remote_outputs(false) {} - const string function_name_; - const string function_key_; - uint64 instantiation_counter_; + const std::string function_name_; + const std::string function_key_; + uint64_t instantiation_counter_; // Stored here to resize the output tensor vector when function is run. const int num_outputs_; DataTypeVector ret_types_; @@ -325,12 +326,12 @@ class ProcessFunctionLibraryRuntime { // Maps the device name to the information about the component function // be run on this device. 
- std::unordered_map<string, ComponentFunctionData> glue_; + std::unordered_map<std::string, ComponentFunctionData> glue_; }; struct CleanUpItem { - string device; - uint64 step_id; + std::string device; + uint64_t step_id; FunctionLibraryRuntime::Handle local_handle; }; @@ -343,18 +344,18 @@ class ProcessFunctionLibraryRuntime { private: FunctionLibraryRuntime::Handle AddHandleLocked( - const string& function_key, const string& device_name, + const std::string& function_key, const std::string& device_name, FunctionLibraryRuntime::LocalHandle local_handle) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // For a given device_name, returns a DeviceContext for copying // tensors to/from the device. - absl::Status GetDeviceContext(const string& device_name, + absl::Status GetDeviceContext(const std::string& device_name, DeviceContext** device_context) const; // Looks up the information for the given `handle` and returns the name // of the device where the function is registered. - string GetDeviceName(FunctionLibraryRuntime::Handle handle) const; + std::string GetDeviceName(FunctionLibraryRuntime::Handle handle) const; // Removes handle from the state owned by this object. absl::Status RemoveHandle(FunctionLibraryRuntime::Handle handle); @@ -380,19 +381,19 @@ class ProcessFunctionLibraryRuntime { absl::Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle); absl::Status InstantiateMultiDevice( - const string& function_name, AttrSlice attrs, + const std::string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle); void InstantiateRemote( - const string& function_name, AttrSlice attrs, + const std::string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle, FunctionLibraryRuntime::DoneCallback done); FunctionLibraryRuntime::Handle AddMultiDeviceHandle( const std::unique_ptr<MultiDeviceFunctionData> data, - const string& function_key); + const std::string& function_key); bool HasMultiDeviceHandle(FunctionLibraryRuntime::Handle handle) const; @@ -426,7 +427,7 @@ class ProcessFunctionLibraryRuntime { InternalArgs* comp_args); #endif // IS_MOBILE_PLATFORM - std::vector<string> GetOrderedSubgraphs( + std::vector<std::string> GetOrderedSubgraphs( const MultiDeviceFunctionData* data) const; absl::Status PrepareRunMultiDevice( @@ -458,15 +459,15 @@ // (to be executed on `target_device`) function. class FunctionData { public: - FunctionData(const string& target_device, + FunctionData(const std::string& target_device, FunctionLibraryRuntime::LocalHandle local_handle, - const string& function_key) + const std::string& function_key) : target_device_(target_device), local_handle_(local_handle), function_key_(function_key) {} - const string& target_device() { return target_device_; } - const string& function_key() { return function_key_; } + const std::string& target_device() { return target_device_; } + const std::string& function_key() { return function_key_; } FunctionLibraryRuntime::LocalHandle local_handle() { mutex_lock l(mu_); @@ -476,7 +477,8 @@ class ProcessFunctionLibraryRuntime { // Initializes the FunctionData object by potentially making an Initialize // call to the DistributedFunctionLibraryRuntime.
void DistributedInit( - DistributedFunctionLibraryRuntime* parent, const string& function_name, + DistributedFunctionLibraryRuntime* parent, + const std::string& function_name, const FunctionLibraryDefinition& lib_def, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::DoneCallback done); @@ -489,9 +491,9 @@ class ProcessFunctionLibraryRuntime { private: mutex mu_; - const string target_device_; + const std::string target_device_; FunctionLibraryRuntime::LocalHandle local_handle_ TF_GUARDED_BY(mu_); - const string function_key_; + const std::string function_key_; bool is_cross_process_ TF_GUARDED_BY(mu_) = false; bool init_started_ TF_GUARDED_BY(mu_) = false; absl::Status init_result_ TF_GUARDED_BY(mu_); @@ -516,7 +518,7 @@ class ProcessFunctionLibraryRuntime { std::vector<CompositeDevice*> composite_devices_ TF_GUARDED_BY(mu_); // Holds all the function instantiations. Maps function_keys to handles. - std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ + std::unordered_map<std::string, FunctionLibraryRuntime::Handle> table_ TF_GUARDED_BY(mu_); // Function data for instantiated remote functions. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index df2f3db3f68ca7..5458203f8c592c 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -59,7 +59,7 @@ class TestClusterFLR : public DistributedFunctionLibraryRuntime { public: explicit TestClusterFLR(DeviceMgr* device_mgr) : device_mgr_(device_mgr) {} - void Instantiate(const string& function_name, + void Instantiate(const std::string& function_name, const FunctionLibraryDefinition& lib_def, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::LocalHandle* handle, @@ -82,7 +82,7 @@ class TestClusterFLR : public DistributedFunctionLibraryRuntime { absl::Span<const Tensor> args, std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) override {} - void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, + void CleanUp(uint64_t step_id, FunctionLibraryRuntime::LocalHandle handle, FunctionLibraryRuntime::DoneCallback done) override {} DeviceMgr* remote_device_mgr() const override { return device_mgr_; } @@ -169,7 +169,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { } absl::Status Instantiate( - const string& name, test::function::Attrs attrs, + const std::string& name, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, FunctionLibraryRuntime::Handle* handle) { return proc_flr_->Instantiate(name, attrs, instantiate_opts, handle); @@ -214,7 +214,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { template <typename T, typename K> absl::Status RunWithRuntime( - const string& name, FunctionLibraryRuntime::Options opts, + const std::string& name, FunctionLibraryRuntime::Options opts, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, const T& args, std::vector<K*> rets, @@ -270,7 +270,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { } absl::Status Run( - const string& name, FunctionLibraryRuntime::Options opts, + const std::string& name, FunctionLibraryRuntime::Options opts, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, const std::vector<Tensor>& args, std::vector<Tensor*> rets, @@ -280,7 +280,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test
{ } absl::Status RunWithPackedArgs( - const string& name, FunctionLibraryRuntime::Options opts, + const std::string& name, FunctionLibraryRuntime::Options opts, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, const FunctionArgsInterface& args, std::vector<Tensor*> rets, @@ -503,7 +503,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { TEST_F(ProcessFunctionLibraryRuntimeTest, SameDeviceXTimesFourInt32MultiDevice) { Init({test::function::XTimesTwoInt32(), test::function::XTimesFourInt32()}); - auto x = test::AsTensor<int32>({1, 2, 3, 4}); + auto x = test::AsTensor<int32_t>({1, 2, 3, 4}); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; opts.remote_execution = true; @@ -515,13 +515,13 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Tensor y; TF_CHECK_OK(Run("XTimesFourInt32", opts, {{"T", DT_INT32}}, instantiate_opts, {x}, {&y})); - test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({4, 8, 12, 16})); + test::ExpectTensorEqual<int32_t>(y, test::AsTensor<int32_t>({4, 8, 12, 16})); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimesMultiDevice) { Init({test::function::XTimesTwoInt32(), test::function::XTimesFourInt32()}); - auto x = test::AsTensor<int32>({1, 2, 3, 4}); + auto x = test::AsTensor<int32_t>({1, 2, 3, 4}); FunctionLibraryRuntime::Options opts; opts.source_device = "/job:a/replica:0/task:0/cpu:0"; opts.remote_execution = true; @@ -533,10 +533,10 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Tensor y; TF_CHECK_OK(Run("XTimesTwoInt32", opts, {{"T", DT_INT32}}, instantiate_opts, {x}, {&y})); - test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({2, 4, 6, 8})); + test::ExpectTensorEqual<int32_t>(y, test::AsTensor<int32_t>({2, 4, 6, 8})); TF_CHECK_OK(Run("XTimesFourInt32", opts, {{"T", DT_INT32}}, instantiate_opts, {x}, {&y})); - test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({4, 8, 12, 16})); + test::ExpectTensorEqual<int32_t>(y, test::AsTensor<int32_t>({4, 8, 12, 16})); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { @@ -668,7 +668,7 @@ bool IsCUDATensor(const Tensor& t) { void TestTwoDeviceMult( ProcessFunctionLibraryRuntimeTest* fixture, const FunctionLibraryRuntime::InstantiateOptions& inst_opts, - const string& error = "") { + const std::string& error = "") { fixture->Init({test::function::TwoDeviceMult()}); FunctionLibraryRuntime::Options opts; auto x = test::AsTensor<float>({1, 2, 3}); @@ -764,18 +764,18 @@ void TestTwoDeviceInputOutput( test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({30, 60})); } -std::vector<string> CompleteDevices(const std::vector<string>& v) { - std::vector<string> result; +std::vector<std::string> CompleteDevices(const std::vector<std::string>& v) { + std::vector<std::string> result; result.reserve(v.size()); - for (const string& s : v) { - result.push_back(strings::StrCat("/job:a/replica:0/task:0/device:", s)); + for (const std::string& s : v) { + result.push_back(absl::StrCat("/job:a/replica:0/task:0/device:", s)); } return result; } FunctionLibraryRuntime::InstantiateOptions MakeOptions( - const string& target, const std::vector<string>& input_devices, - const std::vector<string>& output_devices) { + const std::string& target, const std::vector<std::string>& input_devices, + const std::vector<std::string>& output_devices) { FunctionLibraryRuntime::InstantiateOptions inst_opts; inst_opts.target = target; inst_opts.input_devices = CompleteDevices(input_devices); @@ -924,8 +924,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_EmptyBodySwap) { test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({1, 2})); } -Tensor GetResourceHandle(const string& var_name, const string& container, - const string& device_name) { +Tensor
GetResourceHandle(const std::string& var_name, + const std::string& container, + const std::string& device_name) { ResourceHandle handle; handle.set_device(device_name); handle.set_container(container); @@ -1189,8 +1190,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_StateHandle) { // Attrs {}, // Nodes - {FunctionDefHelper::Const("shape", absl::Span<const int32>({1})), - FunctionDefHelper::Const("minval", 0), + {FunctionDefHelper::Const("shape", + absl::Span<const int32_t>({1})), + FunctionDefHelper::Const("minval", 0), {{"maxval"}, "ReadVariableOp", {"x"}, {{"dtype", T}}, {}}, // A stateful node. {{"y"}, diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc index a91eb74f1ef464..c79b42faffe83c 100644 --- a/tensorflow/core/common_runtime/process_state.cc +++ b/tensorflow/core/common_runtime/process_state.cc @@ -46,7 +46,7 @@ namespace tensorflow { ProcessState::ProcessState() : numa_enabled_(false), cpu_allocators_cached_(0) {} -string ProcessState::MemDesc::DebugString() { +std::string ProcessState::MemDesc::DebugString() { return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ", gpu_registered, ", nic: ", nic_registered); } diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h index dd667cc236a8e9..eb0b7f53a8c7a4 100644 --- a/tensorflow/core/common_runtime/process_state.h +++ b/tensorflow/core/common_runtime/process_state.h @@ -51,7 +51,7 @@ class ProcessState : public ProcessStateInterface { dev_index(0), gpu_registered(false), nic_registered(false) {} - string DebugString(); + std::string DebugString(); }; // If NUMA Allocators are desired, call this before calling any @@ -122,7 +122,7 @@ class RecordingAllocator : public Allocator { ProcessState::MemDesc md, mutex* mu) : mm_(mm), a_(a), md_(md), mu_(mu) {} - string Name() override { return a_->Name(); } + std::string Name() override { return a_->Name(); } void* AllocateRaw(size_t alignment, size_t num_bytes) override { void* p = a_->AllocateRaw(alignment, num_bytes); mutex_lock l(*mu_); diff --git a/tensorflow/core/common_runtime/process_util.cc b/tensorflow/core/common_runtime/process_util.cc index 65733614bdc54c..233dcde498a6bc 100644 --- a/tensorflow/core/common_runtime/process_util.cc +++ b/tensorflow/core/common_runtime/process_util.cc @@ -35,12 +35,12 @@ namespace tensorflow { namespace { // Use environment setting if specified (init once) -int32 GetEnvNumInterOpThreads() { +int32_t GetEnvNumInterOpThreads() { static int32_t env_num_threads = NumInterOpThreadsFromEnvironment(); return env_num_threads; } -int32 DefaultNumInterOpThreads() { +int32_t DefaultNumInterOpThreads() { #ifndef __ANDROID__ int32_t env_num_threads = GetEnvNumInterOpThreads(); if (env_num_threads > 0) { @@ -90,13 +90,13 @@ thread::ThreadPool* ComputePool(const SessionOptions& options) { return compute_pool; } -int32 NumInterOpThreadsFromEnvironment() { +int32_t NumInterOpThreadsFromEnvironment() { int32_t num; const char* val = std::getenv("TF_NUM_INTEROP_THREADS"); return (val && absl::SimpleAtoi(val, &num)) ? num : 0; } -int32 NumIntraOpThreadsFromEnvironment() { +int32_t NumIntraOpThreadsFromEnvironment() { int32_t num; const char* val = std::getenv("TF_NUM_INTRAOP_THREADS"); return (val && absl::SimpleAtoi(val, &num)) ?
num : 0; @@ -122,7 +122,7 @@ int32 DefaultNumIntraOpThreads() { return port::MaxParallelism(); } #endif // defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_MKL) -int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) { +int32_t NumInterOpThreadsFromSessionOptions(const SessionOptions& options) { const int32_t inter_op = options.config.inter_op_parallelism_threads(); if (inter_op > 0) return inter_op; const int32_t env_inter_op = GetEnvNumInterOpThreads(); @@ -169,7 +169,7 @@ void SchedClosure(absl::AnyInvocable<void()> closure) { if (!tsl::tracing::EventCollector::IsEnabled()) { return Env::Default()->SchedClosure(std::move(closure)); } - uint64 id = tsl::tracing::GetUniqueArg(); + uint64_t id = tsl::tracing::GetUniqueArg(); tsl::tracing::RecordEvent(tsl::tracing::EventCategory::kScheduleClosure, id); Env::Default()->SchedClosure([id, closure = std::move(closure)]() mutable { diff --git a/tensorflow/core/common_runtime/process_util.h b/tensorflow/core/common_runtime/process_util.h index cc2bc4390793c0..682556d19fbfad 100644 --- a/tensorflow/core/common_runtime/process_util.h +++ b/tensorflow/core/common_runtime/process_util.h @@ -32,10 +32,10 @@ namespace tensorflow { thread::ThreadPool* ComputePool(const SessionOptions& options); // Returns the TF_NUM_INTEROP_THREADS environment value, or 0 if not specified. -int32 NumInterOpThreadsFromEnvironment(); +int32_t NumInterOpThreadsFromEnvironment(); // Returns the TF_NUM_INTRAOP_THREADS environment value, or 0 if not specified. -int32 NumIntraOpThreadsFromEnvironment(); +int32_t NumIntraOpThreadsFromEnvironment(); // Returns the number of inter op threads specified in `options` or a default. // If no value or a negative value is specified in the provided options, then @@ -43,7 +43,7 @@ int32 NumIntraOpThreadsFromEnvironment(); // environment variable. If neither a value is specified in the options or in // the environment, this function will return a reasonable default value based // on the number of schedulable CPUs, and any MKL and OpenMP configurations. -int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options); +int32_t NumInterOpThreadsFromSessionOptions(const SessionOptions& options); // Creates a thread pool with number of inter op threads. // The number is set if `num_threads` > 0, otherwise it will be configured by diff --git a/tensorflow/core/common_runtime/profile_handler.h b/tensorflow/core/common_runtime/profile_handler.h index 71aac10bf6887a..28ae706c3f08b7 100644 --- a/tensorflow/core/common_runtime/profile_handler.h +++ b/tensorflow/core/common_runtime/profile_handler.h @@ -40,9 +40,9 @@ class ProfileHandler { // - label: Extra content for timeline click text. // - op_type: String name of the Op. // - details: Main content for timeline click text. - virtual void RecordOneOp(const string& device, const NodeExecStats& stats, - bool is_copy, absl::string_view label, - absl::string_view op_type, + virtual void RecordOneOp(const std::string& device, + const NodeExecStats& stats, bool is_copy, + absl::string_view label, absl::string_view op_type, absl::string_view details) = 0; // Records that the current step finished.
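(Aside: the process_util hunks above only rename int32 to int32_t; the lookup order for inter-op threads is unchanged: an explicit SessionOptions value wins, then TF_NUM_INTEROP_THREADS parsed with absl::SimpleAtoi, then a machine-derived default. A minimal standalone C++ sketch of that fallback chain follows; the helper name ResolveInterOpThreads is illustrative only, not TensorFlow API.)

#include <cstdint>
#include <cstdlib>
#include "absl/strings/numbers.h"

// Illustrative sketch, not TF source: mirrors the fallback chain of
// NumInterOpThreadsFromSessionOptions shown in the hunks above.
int32_t ResolveInterOpThreads(int32_t option_value, int32_t default_value) {
  if (option_value > 0) return option_value;  // Explicit option wins.
  int32_t env_value;
  const char* val = std::getenv("TF_NUM_INTEROP_THREADS");
  if (val != nullptr && absl::SimpleAtoi(val, &env_value) && env_value > 0) {
    return env_value;  // Environment override, parsed exactly as above.
  }
  return default_value;  // e.g. a value derived from port::MaxParallelism().
}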
diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index dee1903b112b2b..6d65024cc4f50a 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -159,7 +159,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, if (need_create_iter) { tsl::profiler::TraceMe activity( [&]() { - return strings::StrCat( + return absl::StrCat( "PropagateOutputs::NextIteration::CreateIterationState"); }, tsl::profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); @@ -259,7 +259,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, const ImmutableExecutorState::FrameInfo& frame_info = immutable_state_.get_enter_frame_info(node_item); - const uint64 child_id = Hash64Combine( + const uint64_t child_id = Hash64Combine( frame->frame_id, Hash64Combine(iter_state->iter_num, Hash64(frame_info.name))); @@ -275,7 +275,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, // Need to create a new frame instance. // Note that this new frame instance is created without any locks. if (vlog_) { - const string child_name = strings::StrCat( + const std::string child_name = strings::StrCat( frame->frame_name, ";", iter_state->iter_num, ";", frame_info.name); VLOG(2) << "Create frame: " << child_name << " id: " << child_id; } diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index bdfea225a5ac2d..238cb0552b2c67 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -255,11 +255,11 @@ class PropagatorState { // The name of this frame, which is the concatenation of its parent // frame name, the iteration of the parent frame when this frame was // created, and the value of the attr 'frame_name'. - string frame_name; + std::string frame_name; // The unique id for this frame. Generated by fingerprinting // frame_name. - uint64 frame_id; + uint64_t frame_id; // The iteration state of its parent frame when this frame is created. // nullptr if there is no parent frame. The frame_name/parent_iter pair @@ -543,7 +543,7 @@ class PropagatorState { // child frame is a hash composed of the ID of the parent frame, the iteration // number at which the parent frame is creating the new frame, and the // name of the new frame from nodedef.
- absl::flat_hash_map<uint64, FrameState*> outstanding_frames_ + absl::flat_hash_map<uint64_t, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_); PropagatorState(const PropagatorState&) = delete; @@ -579,12 +579,12 @@ class OrderedPropagatorState : public PropagatorState { private: static bool compare(TaggedNode const& lhs, TaggedNode const& rhs) { - std::tuple<int, uint64, int64> lhs_prio{lhs.node_item->node_id, - lhs.input_frame->frame_id, - lhs.input_iter->iter_num}; - std::tuple<int, uint64, int64> rhs_prio{rhs.node_item->node_id, - rhs.input_frame->frame_id, - rhs.input_iter->iter_num}; + std::tuple<int, uint64_t, int64_t> lhs_prio{lhs.node_item->node_id, + lhs.input_frame->frame_id, + lhs.input_iter->iter_num}; + std::tuple<int, uint64_t, int64_t> rhs_prio{rhs.node_item->node_id, + rhs.input_frame->frame_id, + rhs.input_iter->iter_num}; return lhs_prio < rhs_prio; } diff --git a/tensorflow/core/common_runtime/quantize_training.cc b/tensorflow/core/common_runtime/quantize_training.cc index c800552b5d3bca..3459153ed7dace 100644 --- a/tensorflow/core/common_runtime/quantize_training.cc +++ b/tensorflow/core/common_runtime/quantize_training.cc @@ -35,18 +35,18 @@ namespace tensorflow { namespace { // TODO(suharshs): If desired, make these values configurable. -const uint32 kAllowedInputs = 2; +const uint32_t kAllowedInputs = 2; const float kEMADecay = 0.999; // Node types to rewrite. Insert quantize_and_dequantize op for their inputs. const auto* nodes_to_rewrite = - new std::unordered_set<string>{"MatMul", "Conv2D"}; + new std::unordered_set<std::string>{"MatMul", "Conv2D"}; // Contains necessary parameters to convert an edge. struct EdgeToConvert { // edge is not owned here. const Edge* edge; - int32 num_bits; + int32_t num_bits; bool signed_input; bool range_given; float input_min; @@ -67,7 +67,7 @@ struct EdgeToConvert { // TODO(jmchen): Make this check more robust as it is not guaranteed that the // forward node will not be named with a leading "gradients". inline bool IsGradientNode(const Graph* graph, const Node* node) { - static const string tag = "gradients"; + static const std::string tag = "gradients"; return (node->name().compare(0, tag.size(), tag) == 0); } @@ -76,7 +76,7 @@ inline bool IsGradientNode(const Graph* graph, const Node* node) { // Returns true if the root tensor op type is known, false otherwise. bool FindType(const Graph* graph, const Node* node, bool* signed_input, bool* range_given, float* input_min, float* input_max) { - const string& src_op = node->type_string(); + const std::string& src_op = node->type_string(); if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2") { *signed_input = true; *range_given = false; @@ -154,7 +154,7 @@ absl::Status FindSaveOp(const Graph* graph, Node** save_op, Node* FindRestoreAllOp(const Graph* graph, absl::string_view save_prefix) { for (Node* node : graph->op_nodes()) { // The restore_all op should have the same prefix of the save_op.
- if (node->name() == strings::StrCat(save_prefix, "/restore_all")) { + if (node->name() == absl::StrCat(save_prefix, "/restore_all")) { return node; } } @@ -254,21 +254,21 @@ absl::Status AddRestoreVariableSubgraphs( if (restore_all == nullptr) { return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp"); } - const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2"); - const string assign_op_name = strings::StrCat(name_prefix, "/Assign"); + const std::string restore_op_name = absl::StrCat(name_prefix, "/RestoreV2"); + const std::string assign_op_name = absl::StrCat(name_prefix, "/Assign"); for (Node* var : variables) { // Add an extra prefix after calling graph->NewName because the "unique" // name may conflict with names generated for Send nodes. // TODO(b/77547936): fix this more generally and get rid of the extra prefix // here. - string new_restore_op_name = - strings::StrCat(graph->NewName(restore_op_name), "_qt"); - string new_assign_op_name = - strings::StrCat(graph->NewName(assign_op_name), "_qt"); - string tensor_names_op_name = - strings::StrCat(new_restore_op_name, "/tensor_names"); - string shape_and_slices_op_name = - strings::StrCat(new_restore_op_name, "/shape_and_slices"); + std::string new_restore_op_name = + absl::StrCat(graph->NewName(restore_op_name), "_qt"); + std::string new_assign_op_name = + absl::StrCat(graph->NewName(assign_op_name), "_qt"); + std::string tensor_names_op_name = + absl::StrCat(new_restore_op_name, "/tensor_names"); + std::string shape_and_slices_op_name = + absl::StrCat(new_restore_op_name, "/shape_and_slices"); // Construct the tensor_names input with the variable name. Node* tensor_names; @@ -329,32 +329,32 @@ absl::Status AddSaveAndRestore(Graph* graph, // Sets output to the Node that computes reduction axes corresponding to all // dimensions of input and return. -absl::Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input, - Node** output) { - name_prefix = strings::StrCat(name_prefix, "/ReductionAxes"); +absl::Status MakeReductionAxes(Graph* graph, std::string name_prefix, + Node* input, Node** output) { + name_prefix = absl::StrCat(name_prefix, "/ReductionAxes"); Node* start; Tensor zero_tensor(DT_INT32, TensorShape()); - zero_tensor.flat<int32>()(0) = 0; + zero_tensor.flat<int32_t>()(0) = 0; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const") + NodeBuilder(absl::StrCat(name_prefix, "/RangeStart"), "Const") .Attr("dtype", DT_INT32) .Attr("value", zero_tensor) .Finalize(graph, &start)); Node* delta; Tensor one_tensor(DT_INT32, TensorShape()); - one_tensor.flat<int32>()(0) = 1; + one_tensor.flat<int32_t>()(0) = 1; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const") + NodeBuilder(absl::StrCat(name_prefix, "/RangeDelta"), "Const") .Attr("dtype", DT_INT32) .Attr("value", one_tensor) .Finalize(graph, &delta)); Node* rank; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank") + NodeBuilder(absl::StrCat(name_prefix, "/InputRank"), "Rank") .Input(input) .Finalize(graph, &rank)); TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range") + NodeBuilder(absl::StrCat(name_prefix, "/ReductionAxes"), "Range") .Input(start) .Input(rank) .Input(delta) @@ -363,45 +363,43 @@ absl::Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input, } // Computes the exponential moving average of input, updated in update_variable.
-absl::Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
+absl::Status MakeExponentialMovingAverage(Graph* graph, std::string name_prefix,
                                           const NodeBuilder::NodeOut& input,
                                           Node* decay, Node* update_variable,
                                           Node** assign_value) {
   // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
-  name_prefix = strings::StrCat(name_prefix, "/EMA");
+  name_prefix = absl::StrCat(name_prefix, "/EMA");
   Node* one;
   Tensor one_tensor(DT_FLOAT, TensorShape());
   one_tensor.flat<float>()(0) = 1.0;
   TF_RETURN_IF_ERROR(
-      NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
+      NodeBuilder(absl::StrCat(name_prefix, "/OneConst"), "Const")
          .Attr("dtype", DT_FLOAT)
          .Attr("value", one_tensor)
          .Finalize(graph, &one));
   Node* decay_complement;
   TF_RETURN_IF_ERROR(
-      NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
+      NodeBuilder(absl::StrCat(name_prefix, "/DecayComplement"), "Sub")
          .Input(one)
          .Input(decay)
          .Finalize(graph, &decay_complement));
 
   Node* value_diff;
-  TF_RETURN_IF_ERROR(
-      NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
-          .Input(update_variable)
-          .Input(input)
-          .Finalize(graph, &value_diff));
+  TF_RETURN_IF_ERROR(NodeBuilder(absl::StrCat(name_prefix, "/ValueDiff"), "Sub")
+                         .Input(update_variable)
+                         .Input(input)
+                         .Finalize(graph, &value_diff));
   Node* update_value;
   TF_RETURN_IF_ERROR(
-      NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
+      NodeBuilder(absl::StrCat(name_prefix, "/UpdateValue"), "Mul")
          .Input(value_diff)
          .Input(decay_complement)
          .Finalize(graph, &update_value));
-  TF_RETURN_IF_ERROR(
-      NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
-          .Input(update_variable)
-          .Input(update_value)
-          .Finalize(graph, assign_value));
+  TF_RETURN_IF_ERROR(NodeBuilder(absl::StrCat(name_prefix, "/EMAValue"), "Sub")
+                         .Input(update_variable)
+                         .Input(update_value)
+                         .Finalize(graph, assign_value));
   return absl::OkStatus();
 }
 
@@ -416,25 +414,24 @@ absl::Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
 // |      EMA    init_val
 // |       \    /
 // +----------- assign
-absl::Status MakeInitializedEMAVariable(Graph* graph, const string& name,
+absl::Status MakeInitializedEMAVariable(Graph* graph, const std::string& name,
                                         Node* decay, Node* init_val,
                                         std::vector<Node*>* added_variables,
                                         Node** var) {
   // TODO(suharshs): Update this to use ResourceVariables when they are ready.
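Aside: per element, the node chain assembled in MakeExponentialMovingAverage above (OneConst, DecayComplement, ValueDiff, UpdateValue, EMAValue) reduces to one scalar update. A minimal sketch in plain C++, not part of the patch; EmaStep is a hypothetical name:

// ema' = ema - (ema - x) * (1 - decay), which is algebraically
// decay * ema + (1 - decay) * x.
float EmaStep(float ema, float x, float decay) {
  return ema - (ema - x) * (1.0f - decay);
}

With kEMADecay = 0.999, each run nudges the tracked min/max value 0.1% of the way toward the newly observed extremum.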
-  TF_RETURN_IF_ERROR(
-      NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
-          .Attr("shape", TensorShape())
-          .Attr("dtype", DT_FLOAT)
-          .Finalize(graph, var));
+  TF_RETURN_IF_ERROR(NodeBuilder(absl::StrCat(name, "/Variable"), "VariableV2")
+                         .Attr("shape", TensorShape())
+                         .Attr("dtype", DT_FLOAT)
+                         .Finalize(graph, var));
   added_variables->push_back(*var);
 
   Node* is_initialized;
-  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
-                                 "IsVariableInitialized")
-                         .Input(*var)
-                         .Finalize(graph, &is_initialized));
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(absl::StrCat(name, "/IsInitialized"), "IsVariableInitialized")
+          .Input(*var)
+          .Finalize(graph, &is_initialized));
   Node* switch_node;
-  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
+  TF_RETURN_IF_ERROR(NodeBuilder(absl::StrCat(name, "/Switch"), "Switch")
                          .Input(init_val)
                          .Input(is_initialized)
                          .Finalize(graph, &switch_node));
@@ -446,20 +443,19 @@ absl::Status MakeInitializedEMAVariable(Graph* graph, const string& name,
                                             decay, *var, &ema_value));
 
   Node* assign_value;
-  TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
+  TF_RETURN_IF_ERROR(NodeBuilder(absl::StrCat(name, "/Merge"), "Merge")
                          .Input({output_false, ema_value})
                          .Finalize(graph, &assign_value));
 
-  TF_RETURN_IF_ERROR(
-      NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
-          .Input(*var)
-          .Input(assign_value)
-          .Finalize(graph, var));
+  TF_RETURN_IF_ERROR(NodeBuilder(absl::StrCat(name, "/AssignValue"), "Assign")
+                         .Input(*var)
+                         .Input(assign_value)
+                         .Finalize(graph, var));
   return absl::OkStatus();
 }
 
 // Computes the min and max EMA of input and stores them in min_var and max_var.
-absl::Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix,
+absl::Status MakeEMAMinMaxVars(Graph* graph, const std::string& name_prefix,
                                Node* input, std::vector<Node*>* added_variables,
                                Node** min_var, Node** max_var) {
   // TODO(suharshs): The decay will be constant, so we could make only one for
@@ -468,23 +464,22 @@
   Tensor decay_tensor(DT_FLOAT, TensorShape());
   decay_tensor.flat<float>()(0) = kEMADecay;
   Node* decay;
-  TF_RETURN_IF_ERROR(
-      NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
-          .Attr("dtype", DT_FLOAT)
-          .Attr("value", decay_tensor)
-          .Finalize(graph, &decay));
+  TF_RETURN_IF_ERROR(NodeBuilder(absl::StrCat(name_prefix, "/Decay"), "Const")
+                         .Attr("dtype", DT_FLOAT)
+                         .Attr("value", decay_tensor)
+                         .Finalize(graph, &decay));
 
   Node* reduction_axes;
   TF_RETURN_IF_ERROR(
       MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
   Node* min;
-  string min_name = strings::StrCat(name_prefix, "/Min");
+  std::string min_name = absl::StrCat(name_prefix, "/Min");
   TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
                          .Input(input)
                          .Input(reduction_axes)
                          .Finalize(graph, &min));
   Node* max;
-  string max_name = strings::StrCat(name_prefix, "/Max");
+  std::string max_name = absl::StrCat(name_prefix, "/Max");
   TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
                          .Input(input)
                          .Input(reduction_axes)
                          .Finalize(graph, &max));
@@ -498,7 +493,7 @@ absl::Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix,
 
 // Makes an input min and max constant if the range is given. Otherwise, makes
 // min and max variables that are updated by an EMA.
-absl::Status MakeInputMinMax(Graph* graph, const string& name_prefix,
+absl::Status MakeInputMinMax(Graph* graph, const std::string& name_prefix,
                              const EdgeToConvert& edge,
                              std::vector<Node*>* added_variables,
                              Node** input_min, Node** input_max) {
@@ -508,14 +503,14 @@
     Tensor input_min_tensor(DT_FLOAT, TensorShape());
     input_min_tensor.flat<float>()(0) = edge.input_min;
     TF_RETURN_IF_ERROR(
-        NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
+        NodeBuilder(absl::StrCat(name_prefix, "/InputMin"), "Const")
            .Attr("dtype", DT_FLOAT)
            .Attr("value", input_min_tensor)
            .Finalize(graph, input_min));
     Tensor input_max_tensor(DT_FLOAT, TensorShape());
     input_max_tensor.flat<float>()(0) = edge.input_max;
     TF_RETURN_IF_ERROR(
-        NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
+        NodeBuilder(absl::StrCat(name_prefix, "/InputMax"), "Const")
            .Attr("dtype", DT_FLOAT)
            .Attr("value", input_max_tensor)
            .Finalize(graph, input_max));
@@ -532,8 +527,8 @@
 // Adds a QuantizeAndDequantizeV2 or FakeQuantizeWithMinMaxVars op
 // (and required input nodes) based on edge.
 // The result is stored in convert_node.
-absl::Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
-                            const string& quant_op_type,
+absl::Status MakeQuantizeOp(Graph* graph, const std::string& name_prefix,
+                            const std::string& quant_op_type,
                             const EdgeToConvert& edge,
                             std::vector<Node*>* added_variables,
                             Node** convert_node) {
@@ -541,7 +536,7 @@ absl::Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
   Node* input_max;
   TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
                                      &input_min, &input_max));
-  string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
+  std::string quant_name = absl::StrCat(name_prefix, "/", quant_op_type);
   if (quant_op_type == "QuantizeAndDequantizeV2") {
     TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
                            .Input(edge.edge->src())
@@ -566,15 +561,15 @@ absl::Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
 
 // Insert conversion op, connect it to the graph and remove the old edge.
 absl::Status ProcessTargetEdges(
-    Graph* graph, const string& quant_op_type,
+    Graph* graph, const std::string& quant_op_type,
     const std::vector<EdgeToConvert>& target_edges) {
   // Remember previously converted ops to avoid duplicated conversion on the
   // same input.
-  std::unordered_map<string, Node*> name_index;
+  std::unordered_map<std::string, Node*> name_index;
   std::vector<Node*> added_variables;
   for (const EdgeToConvert edge : target_edges) {
     Node* convert_node;
-    string name_prefix = edge.edge->src()->name();
+    std::string name_prefix = edge.edge->src()->name();
 
     auto iter = name_index.find(name_prefix);
     if (iter == name_index.end()) {
@@ -596,7 +591,8 @@ absl::Status ProcessTargetEdges(
 
 }  // namespace
 
-absl::Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type,
+absl::Status DoQuantizeTraining(int32_t num_bits,
+                                const std::string& quant_op_type,
                                 Graph* graph) {
   if (graph == nullptr) {
     return errors::InvalidArgument("Cannot accept empty graph pointer.");
@@ -663,7 +659,7 @@ absl::Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type,
 
 absl::Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
                                           int32_t num_bits,
-                                          const string& quant_op_type,
+                                          const std::string& quant_op_type,
                                           GraphDef* result_graphdef) {
   Graph graph(OpRegistry::Global());
   GraphConstructorOptions opts;
@@ -678,8 +674,8 @@ absl::Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
 }
 
 absl::Status DoQuantizeTrainingOnSerializedGraphDef(
-    const string& input_graph_string, int32_t num_bits,
-    const string& quant_op_type, string* result_graph_string) {
+    const std::string& input_graph_string, int32_t num_bits,
+    const std::string& quant_op_type, std::string* result_graph_string) {
   // First create the graph from the GraphDef.
   GraphDef input_graphdef;
   if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
diff --git a/tensorflow/core/common_runtime/quantize_training.h b/tensorflow/core/common_runtime/quantize_training.h
index de3ed6b476b24a..21f794cbec8f2c 100644
--- a/tensorflow/core/common_runtime/quantize_training.h
+++ b/tensorflow/core/common_runtime/quantize_training.h
@@ -35,21 +35,20 @@ namespace tensorflow {
 //  - num_bits out of range.
 //  - g is null.
 //  - More than 1 unknown ops encountered.
-absl::Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type,
-                                Graph* g);
+absl::Status DoQuantizeTraining(int32_t num_bits,
+                                const std::string& quant_op_type, Graph* g);
 
 // Converts the input serialized GraphDef and returns a rewritten serialized
 // GraphDef for quantized training.
-absl::Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph,
-                                                    int32_t num_bits,
-                                                    const string& quant_op_type,
-                                                    string* result_graph);
+absl::Status DoQuantizeTrainingOnSerializedGraphDef(
+    const std::string& input_graph, int32_t num_bits,
+    const std::string& quant_op_type, std::string* result_graph);
 
 // Converts the input GraphDef and returns a rewritten GraphDef for quantized
 // training.
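For orientation, a hedged sketch of how the serialized entry point is typically driven, mirroring the QuantizeSerializedGraphDef test later in this patch (`graph` is an assumed pre-built Graph*; error handling condensed; the GraphDef overload declared next follows the same pattern):

// Serialize a graph, rewrite it for quantized training, parse the result.
GraphDef input_graphdef;
graph->ToGraphDef(&input_graphdef);
std::string input_string;
input_graphdef.SerializeToString(&input_string);

std::string result_string;
TF_RETURN_IF_ERROR(DoQuantizeTrainingOnSerializedGraphDef(
    input_string, /*num_bits=*/8, "QuantizeAndDequantizeV2", &result_string));

GraphDef result_graphdef;
if (!ParseProtoUnlimited(&result_graphdef, result_string)) {
  return errors::InvalidArgument("Could not parse rewritten GraphDef.");
}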
 absl::Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
                                           int32_t num_bits,
-                                          const string& quant_op_type,
+                                          const std::string& quant_op_type,
                                           GraphDef* result_graphdef);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/quantize_training_test.cc b/tensorflow/core/common_runtime/quantize_training_test.cc
index 7f2e1b0e709d35..5d4a1ac2618de2 100644
--- a/tensorflow/core/common_runtime/quantize_training_test.cc
+++ b/tensorflow/core/common_runtime/quantize_training_test.cc
@@ -51,7 +51,7 @@ class QuantizeTrainingTest : public ::testing::Test {
     return test::graph::Constant(g_.get(), test::AsTensor(values, shape));
   }
 
-  absl::Status Placeholder(Graph* g, const string& name, TensorShape shape,
+  absl::Status Placeholder(Graph* g, const std::string& name, TensorShape shape,
                            Node** out) {
     TF_RETURN_IF_ERROR(NodeBuilder(name, "Placeholder")
                            .Attr("dtype", DT_FLOAT)
@@ -60,7 +60,7 @@ class QuantizeTrainingTest : public ::testing::Test {
     return absl::OkStatus();
   }
 
-  absl::Status FindNode(Graph* g, const string& name, Node** out) {
+  absl::Status FindNode(Graph* g, const std::string& name, Node** out) {
     for (Node* node : g->nodes()) {
       if (node->name() == name) {
         *out = node;
@@ -111,15 +111,14 @@ TEST_F(QuantizeTrainingTest, SignedInput) {
   // Quantize_and_dequantize node for identity should have signed_input==true.
   Node* identity_q_node;
   TF_ASSERT_OK(
-      FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
+      FindNode(g, absl::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
                &identity_q_node));
   ASSERT_EQ("true",
             SummarizeAttrValue(*identity_q_node->attrs().Find("signed_input")));
   // Quantize_and_dequantize node for relu should have signed_input==false.
   Node* relu_q_node;
-  TF_ASSERT_OK(
-      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
-               &relu_q_node));
+  TF_ASSERT_OK(FindNode(
+      g, absl::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), &relu_q_node));
   ASSERT_EQ("false",
             SummarizeAttrValue(*relu_q_node->attrs().Find("signed_input")));
 }
@@ -161,16 +160,15 @@ TEST_F(QuantizeTrainingTest, RangeGivenTrue) {
 
   // Quantize_and_dequantize node for relu6 should have range_given==true.
   Node* relu6_q_node;
-  TF_ASSERT_OK(
-      FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"),
-               &relu6_q_node));
+  TF_ASSERT_OK(FindNode(g,
+                        absl::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"),
+                        &relu6_q_node));
   ASSERT_EQ("true",
             SummarizeAttrValue(*relu6_q_node->attrs().Find("range_given")));
   // Quantize_and_dequantize node for relu should have range_given==true.
   Node* relu_q_node;
-  TF_ASSERT_OK(
-      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
-               &relu_q_node));
+  TF_ASSERT_OK(FindNode(
+      g, absl::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), &relu_q_node));
   ASSERT_EQ("true",
             SummarizeAttrValue(*relu_q_node->attrs().Find("range_given")));
 }
@@ -215,18 +213,17 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) {
   // Ensure that the backwards matmul input was not quantized.
   Node* found_node;
   absl::Status s = FindNode(
-      g, strings::StrCat(d->name(), "/QuantizeAndDequantizeV2"), &found_node);
+      g, absl::StrCat(d->name(), "/QuantizeAndDequantizeV2"), &found_node);
   EXPECT_TRUE(absl::StrContains(s.ToString(), "not found")) << s;
 
   // Ensure that m1 and m2's inputs were quantized.
+  TF_ASSERT_OK(FindNode(
+      g, absl::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), &found_node));
   TF_ASSERT_OK(
-      FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
-               &found_node));
-  TF_ASSERT_OK(
-      FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
+      FindNode(g, absl::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
                &found_node));
-  TF_ASSERT_OK(FindNode(
-      g, strings::StrCat(c->name(), "/QuantizeAndDequantizeV2"), &found_node));
+  TF_ASSERT_OK(FindNode(g, absl::StrCat(c->name(), "/QuantizeAndDequantizeV2"),
+                        &found_node));
 }
 
 TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) {
@@ -269,18 +266,17 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) {
   // Ensure that the backwards matmul input was not quantized.
   Node* found_node;
   absl::Status s = FindNode(
-      g, strings::StrCat(d->name(), "/FakeQuantWithMinMaxVars"), &found_node);
+      g, absl::StrCat(d->name(), "/FakeQuantWithMinMaxVars"), &found_node);
   EXPECT_TRUE(absl::StrContains(s.ToString(), "not found")) << s;
 
   // Ensure that m1 and m2's inputs were quantized.
+  TF_ASSERT_OK(FindNode(
+      g, absl::StrCat(relu->name(), "/FakeQuantWithMinMaxVars"), &found_node));
   TF_ASSERT_OK(
-      FindNode(g, strings::StrCat(relu->name(), "/FakeQuantWithMinMaxVars"),
-               &found_node));
-  TF_ASSERT_OK(
-      FindNode(g, strings::StrCat(identity->name(), "/FakeQuantWithMinMaxVars"),
+      FindNode(g, absl::StrCat(identity->name(), "/FakeQuantWithMinMaxVars"),
                &found_node));
-  TF_ASSERT_OK(FindNode(
-      g, strings::StrCat(c->name(), "/FakeQuantWithMinMaxVars"), &found_node));
+  TF_ASSERT_OK(FindNode(g, absl::StrCat(c->name(), "/FakeQuantWithMinMaxVars"),
+                        &found_node));
 }
 
 TEST_F(QuantizeTrainingTest, QuantizeSerializedGraphDef) {
@@ -301,10 +297,10 @@ TEST_F(QuantizeTrainingTest, QuantizeSerializedGraphDef) {
   // Convert the graph to the graphdef string.
   GraphDef input_graph;
   graph->ToGraphDef(&input_graph);
-  string input_string;
+  std::string input_string;
   input_graph.SerializeToString(&input_string);
 
-  string result_string;
+  std::string result_string;
   TF_ASSERT_OK(DoQuantizeTrainingOnSerializedGraphDef(
       input_string, num_bits, "QuantizeAndDequantizeV2", &result_string));
 
@@ -400,8 +396,8 @@ TEST_F(QuantizeTrainingTest, FixedRangeAndEMARange_QuantizeAndDequantize) {
 
   // The min and max values of the relu6 quantization should be constant values
   // of 0 and 6.
-  string min_const_name = strings::StrCat(relu6->name(), "/InputMin");
-  string max_const_name = strings::StrCat(relu6->name(), "/InputMax");
+  std::string min_const_name = absl::StrCat(relu6->name(), "/InputMin");
+  std::string max_const_name = absl::StrCat(relu6->name(), "/InputMax");
   std::vector<Tensor> outputs;
   TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
   EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
@@ -416,8 +412,8 @@ TEST_F(QuantizeTrainingTest, FixedRangeAndEMARange_QuantizeAndDequantize) {
 
   // The value of the min and max should be set to the min and max of a1 since
   // this is the first run that initializes the EMA variables.
-  string min_var_name = strings::StrCat(relu->name(), "/Min/Variable");
-  string max_var_name = strings::StrCat(relu->name(), "/Max/Variable");
+  std::string min_var_name = absl::StrCat(relu->name(), "/Min/Variable");
+  std::string max_var_name = absl::StrCat(relu->name(), "/Max/Variable");
   TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs));
   EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
   EXPECT_EQ(outputs[1].flat<float>()(0), 3.0);
@@ -494,8 +490,8 @@ TEST_F(QuantizeTrainingTest, FixedRangeAndEMARange_FakeQuant) {
 
   // The min and max values of the relu6 quantization should be constant values
   // of 0 and 6.
-  string min_const_name = strings::StrCat(relu6->name(), "/InputMin");
-  string max_const_name = strings::StrCat(relu6->name(), "/InputMax");
+  std::string min_const_name = absl::StrCat(relu6->name(), "/InputMin");
+  std::string max_const_name = absl::StrCat(relu6->name(), "/InputMax");
   std::vector<Tensor> outputs;
   TF_ASSERT_OK(sess->Run({}, {min_const_name, max_const_name}, {}, &outputs));
   EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
@@ -510,8 +506,8 @@ TEST_F(QuantizeTrainingTest, FixedRangeAndEMARange_FakeQuant) {
 
   // The value of the min and max should be set to the min and max of a1 since
   // this is the first run that initializes the EMA variables.
-  string min_var_name = strings::StrCat(relu->name(), "/Min/Variable");
-  string max_var_name = strings::StrCat(relu->name(), "/Max/Variable");
+  std::string min_var_name = absl::StrCat(relu->name(), "/Min/Variable");
+  std::string max_var_name = absl::StrCat(relu->name(), "/Max/Variable");
   TF_ASSERT_OK(sess->Run({}, {min_var_name, max_var_name}, {}, &outputs));
   EXPECT_EQ(outputs[0].flat<float>()(0), 0.0);
   EXPECT_EQ(outputs[1].flat<float>()(0), 3.0);
diff --git a/tensorflow/core/common_runtime/renamed_device.cc b/tensorflow/core/common_runtime/renamed_device.cc
index 0bfc121b23cab4..a4c15f74b49774 100644
--- a/tensorflow/core/common_runtime/renamed_device.cc
+++ b/tensorflow/core/common_runtime/renamed_device.cc
@@ -28,7 +28,7 @@ namespace tensorflow {
 
 /* static */
 std::unique_ptr<Device> RenamedDevice::NewRenamedDevice(
-    const string& new_base, Device* underlying, bool owns_underlying,
+    const std::string& new_base, Device* underlying, bool owns_underlying,
     bool isolate_session_state,
     thread::ThreadPoolInterface* underlying_threadpool) {
   DeviceNameUtils::ParsedName parsed_name;
@@ -39,9 +39,9 @@ std::unique_ptr<Device> RenamedDevice::NewRenamedDevice(
   CHECK(underlying_parsed_name.has_id);
   parsed_name.type = underlying_parsed_name.type;
   parsed_name.id = underlying_parsed_name.id;
-  string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica,
-                                          parsed_name.task, parsed_name.type,
-                                          parsed_name.id);
+  std::string name = DeviceNameUtils::FullName(
+      parsed_name.job, parsed_name.replica, parsed_name.task, parsed_name.type,
+      parsed_name.id);
   DeviceAttributes attributes(underlying->attributes());
   attributes.set_name(name);
   // Call absl::WrapUnique to access private constructor.
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index 4a0e1057b398a4..687f61f8eff2d8 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -30,7 +30,7 @@ namespace tensorflow {
 class RenamedDevice : public Device {
  public:
   static std::unique_ptr<Device> NewRenamedDevice(
-      const string& new_base, Device* underlying, bool owns_underlying,
+      const std::string& new_base, Device* underlying, bool owns_underlying,
       bool isolate_session_state,
       thread::ThreadPoolInterface* underlying_threadpool = nullptr);
 
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
index f1a199ba97250d..1d6e53c6585068 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.cc
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -99,9 +99,9 @@ void SameWorkerRecvDone(const DeviceMgr* device_mgr,
   if (in.dtype() != DT_VARIANT) {
     // Variants are handled by CopyTensor::ViaDMA.
     AllocationAttributes aa;
-    uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
-    std::function<uint64()> freed_by_func = [dst_device,
-                                             &safe_alloc_frontier]() {
+    uint64_t safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
+    std::function<uint64_t()> freed_by_func = [dst_device,
+                                               &safe_alloc_frontier]() {
       safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
       return safe_alloc_frontier;
     };
diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc
index 532f4e84a2f9f2..8f4e7acbf77ed5 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util.cc
@@ -22,7 +22,8 @@ namespace tensorflow {
 absl::Status SendTensorsToRendezvous(
     RendezvousInterface* rendezvous, DeviceContext* device_context,
     const std::vector<AllocatorAttributes>& alloc_attrs,
-    const std::vector<string>& keys, absl::Span<const Tensor> tensors_to_send) {
+    const std::vector<std::string>& keys,
+    absl::Span<const Tensor> tensors_to_send) {
   if (keys.size() != tensors_to_send.size()) {
     return errors::InvalidArgument(
         "keys and tensors_to_send are not the same size. keys.size() = ",
@@ -56,7 +57,7 @@ absl::Status SendTensorsToRendezvous(
 void RecvOutputsFromRendezvousAsync(
     RendezvousInterface* rendezvous, DeviceContext* device_context,
     const std::vector<AllocatorAttributes>& alloc_attrs,
-    const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
+    const std::vector<std::string>& keys, std::vector<Tensor>* received_tensors,
     StatusCallback done) {
   if (keys.empty()) {
     done(absl::OkStatus());
@@ -69,8 +70,8 @@ void RecvOutputsFromRendezvousAsync(
   }
 
   received_tensors->reserve(keys.size());
-  std::vector<
-      std::tuple<string, Tensor*, Rendezvous::ParsedKey>>
+  std::vector<std::tuple<std::string, Tensor*, Rendezvous::ParsedKey>>
       arguments;
   for (int i = 0; i < keys.size(); ++i) {
     Rendezvous::ParsedKey parsed;
@@ -90,7 +91,7 @@ void RecvOutputsFromRendezvousAsync(
   auto status_cb = new ReffedStatusCallback(std::move(done));
   for (auto& p : arguments) {
-    const string& key = std::get<0>(p);
+    const std::string& key = std::get<0>(p);
     Tensor* val = std::get<1>(p);
     Rendezvous::ParsedKey parsed = std::get<2>(p);
     Rendezvous::Args rendez_args;
@@ -124,7 +125,7 @@ absl::Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous,
   // Receives values requested by the caller.
   Rendezvous::ParsedKey parsed;
   for (auto& p : *out) {
-    const string& key = p.first;
+    const std::string& key = p.first;
     Tensor* val = &p.second;
     bool is_dead = false;
     TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h
index 8ed1dd7a11ad16..1c9ac0ef221a54 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.h
+++ b/tensorflow/core/common_runtime/rendezvous_util.h
@@ -22,7 +22,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-typedef std::map<string, Tensor> NamedTensors;
+typedef std::map<std::string, Tensor> NamedTensors;
 typedef std::function<void(const absl::Status&)> StatusCallback;
 
 // Uses `rendezvous` to send tensors in `tensors_to_send`. `device_context`
@@ -33,7 +33,8 @@ typedef std::function<void(const absl::Status&)> StatusCallback;
 absl::Status SendTensorsToRendezvous(
     RendezvousInterface* rendezvous, DeviceContext* device_context,
     const std::vector<AllocatorAttributes>& alloc_attrs,
-    const std::vector<string>& keys, absl::Span<const Tensor> tensors_to_send);
+    const std::vector<std::string>& keys,
+    absl::Span<const Tensor> tensors_to_send);
 
 // Uses `rendezvous` to obtain tensors. `device_context` should be the
 // DeviceContext associated with the receiving device. `alloc_attrs` contains
@@ -42,7 +43,7 @@ absl::Status SendTensorsToRendezvous(
 void RecvOutputsFromRendezvousAsync(
     RendezvousInterface* rendezvous, DeviceContext* device_context,
     const std::vector<AllocatorAttributes>& alloc_attrs,
-    const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
+    const std::vector<std::string>& keys, std::vector<Tensor>* received_tensors,
     StatusCallback done);
 
 absl::Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous,
diff --git a/tensorflow/core/common_runtime/rendezvous_util_test.cc b/tensorflow/core/common_runtime/rendezvous_util_test.cc
index 484746ce416f0b..f2c866c307905c 100644
--- a/tensorflow/core/common_runtime/rendezvous_util_test.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util_test.cc
@@ -32,20 +32,20 @@ class RendezvousUtilTest : public ::testing::Test {
 };
 
 // string -> Tensor
-Tensor V(const string& content) {
+Tensor V(const std::string& content) {
   Tensor tensor(DT_STRING, TensorShape({}));
   tensor.scalar<tstring>()() = content;
   return tensor;
 }
 
 // Tensor -> string
-string V(const Tensor& tensor) {
+std::string V(const Tensor& tensor) {
   CHECK_EQ(tensor.dtype(), DT_STRING);
   CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
   return tensor.scalar<tstring>()();
 }
 
-string MakeStringKey(const string& name) {
+std::string MakeStringKey(const std::string& name) {
   return Rendezvous::CreateKey(
       "/job:localhost/replica:0/task:0/device:CPU:0", 0,
       "/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0));
diff --git a/tensorflow/core/common_runtime/replicate_constants_pass.cc b/tensorflow/core/common_runtime/replicate_constants_pass.cc
index 7da785ca6f54e3..9dfa50ae0dc2a4 100644
--- a/tensorflow/core/common_runtime/replicate_constants_pass.cc
+++ b/tensorflow/core/common_runtime/replicate_constants_pass.cc
@@ -70,8 +70,8 @@ bool HasCpuDevice(const Node* node) {
 // Convert the CPU device name to the corresponding CPU device name. If
 // multiple local CPU devices are enabled, the CPU device name will also
 // contain the device id.
-absl::Status DeviceNameToCpuDeviceNameWithDeviceId(const string& device_name,
-                                                   string* host_device_name) {
+absl::Status DeviceNameToCpuDeviceNameWithDeviceId(
+    const std::string& device_name, std::string* host_device_name) {
   DeviceNameUtils::ParsedName device;
   if (!DeviceNameUtils::ParseFullName(device_name, &device)) {
     return absl::InternalError(
diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc
index 3f4cf1498769a0..e60117f588f8c0 100644
--- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc
+++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc
@@ -45,7 +45,7 @@ class ReplicateHelper {
 
   // Replicate the given node to an allowed device.
   absl::Status ReplicateNode(const Node* node,
-                             const std::vector<string>& allowed_devices,
+                             const std::vector<std::string>& allowed_devices,
                              int allowed_device_index, Graph* graph) {
     auto& replicated_nodes = replicated_nodes_map_.at(node);
     if (replicated_nodes[allowed_device_index] != nullptr) {
@@ -53,8 +53,8 @@ class ReplicateHelper {
     }
     const auto& device = allowed_devices.at(allowed_device_index);
     NodeDef node_def = node->def();
-    const string suffix = strings::StrCat("/R", allowed_device_index);
-    node_def.set_name(graph->NewName(strings::StrCat(node_def.name(), suffix)));
+    const std::string suffix = absl::StrCat("/R", allowed_device_index);
+    node_def.set_name(graph->NewName(absl::StrCat(node_def.name(), suffix)));
     TF_ASSIGN_OR_RETURN(Node * replicated_node, graph->AddNode(node_def));
     replicated_node->set_assigned_device_name(device);
     if (replicated_node->IsArg()) {
@@ -83,7 +83,7 @@ class ReplicateHelper {
   // Replace an edge (composite device -> composite device) with
   // N edges (allowed devices -> allowed devices).
   absl::Status ReplicateFromCompositeDeviceToCompositeDevice(
-      const Edge* edge, const std::vector<string>& allowed_devices,
+      const Edge* edge, const std::vector<std::string>& allowed_devices,
       Graph* graph) {
     const std::vector<Node*>& src_replicated_nodes =
         replicated_nodes_map_.at(edge->src());
@@ -115,12 +115,12 @@ class ReplicateHelper {
   // Control edge: replace an edge (composite device -> a regular device) with
   // N edges (allowed devices -> a regular device).
   absl::Status ReplicateFromCompositeDeviceToRegularDevice(
-      const Edge* edge, const std::vector<string>& allowed_devices,
+      const Edge* edge, const std::vector<std::string>& allowed_devices,
       Graph* graph) {
     const std::vector<Node*>& src_replicated_nodes =
         replicated_nodes_map_.at(edge->src());
     Node* dst = edge->dst();
-    const string& dst_device = dst->assigned_device_name();
+    const std::string& dst_device = dst->assigned_device_name();
     bool found_src_node = false;
     for (int i = 0; i < allowed_devices.size(); ++i) {
       if (allowed_devices.at(i) == dst_device) {
@@ -198,7 +198,7 @@ class ReplicateHelper {
 
 // Replicate the nodes in cluster_nodes and update edges.
 absl::Status ReplicateNodesAndEdges(
-    const std::vector<string>& allowed_devices,
+    const std::vector<std::string>& allowed_devices,
     absl::flat_hash_map<Node*, int>* cluster_nodes, ReplicateHelper* helper,
     Graph* graph) {
   // Contains nodes in cluster_nodes whose out nodes are all on physical
@@ -253,19 +253,19 @@ absl::Status ReplicateNodesAndEdges(
 }  // namespace
 
 absl::Status ReplicatePerReplicaNodesInFunctionGraph(
-    const absl::flat_hash_map<string, const std::vector<string>*>&
+    const absl::flat_hash_map<std::string, const std::vector<std::string>*>&
         composite_devices,
     Graph* graph) {
   VLOG(1) << "Starting ReplicatePerReplicaNodesInFunctionGraph";
   VLOG(1) << "Graph #nodes " << graph->num_nodes() << " #edges "
          << graph->num_edges();
-  std::set<string> composite_device_names;
+  std::set<std::string> composite_device_names;
   for (const auto& it : composite_devices) {
     composite_device_names.insert(it.first);
   }
   // Map from a composite device to a cluster of nodes assigned to the
   // composite device and the numbers of their out edges to process.
-  absl::flat_hash_map<string, std::map<Node*, int>>
+  absl::flat_hash_map<std::string, std::map<Node*, int>>
       composite_device_to_cluster_nodes;
   for (Node* n : graph->op_nodes()) {
     if (composite_device_names.find(n->assigned_device_name()) !=
@@ -284,7 +284,7 @@ absl::Status ReplicatePerReplicaNodesInFunctionGraph(
   }
 
   for (auto& it : composite_device_to_cluster_nodes) {
-    const std::vector<string>& allowed_devices =
+    const std::vector<std::string>& allowed_devices =
         *composite_devices.at(it.first);
     if (allowed_devices.empty()) {
       return errors::InvalidArgument("No allowed device of composite device: ",
diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.h b/tensorflow/core/common_runtime/replicate_per_replica_nodes.h
index 4be95ea32ca44b..414bd21de35361 100644
--- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.h
+++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.h
@@ -35,7 +35,7 @@ namespace tensorflow {
 // dependency.
 // TODO(b/145922293): Register it as a POST_REWRITE_FOR_EXEC pass.
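A hypothetical invocation of the rewrite declared just below, mirroring the SingleCompositeDevice test later in this patch (device names illustrative; `graph` is an assumed existing tensorflow::Graph):

// Map a composite device to its per-replica devices, then replicate nodes
// assigned to the composite device onto each underlying device.
const std::vector<std::string> underlying_devices = {"/device:TPU:0",
                                                     "/device:TPU:1"};
const absl::flat_hash_map<std::string, const std::vector<std::string>*>
    composite_devices = {{"/device:TPU_COMPOSITE:0", &underlying_devices}};
TF_RETURN_IF_ERROR(
    ReplicatePerReplicaNodesInFunctionGraph(composite_devices, &graph));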
 absl::Status ReplicatePerReplicaNodesInFunctionGraph(
-    const absl::flat_hash_map<string, const std::vector<string>*>&
+    const absl::flat_hash_map<std::string, const std::vector<std::string>*>&
         composite_devices,
     Graph* graph);
 
diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc
index ff6fcb4b8bc735..f0a859286fba06 100644
--- a/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc
+++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc
@@ -40,7 +40,7 @@ class GraphHelper {
     }
   }
 
-  Node* GetNodeByName(const string& name) {
+  Node* GetNodeByName(const std::string& name) {
     const auto it = nodes_by_name_.find(name);
     if (it != nodes_by_name_.end()) {
       return it->second;
@@ -53,7 +53,8 @@ class GraphHelper {
     return nullptr;
   }
 
-  void SetAssignedDevice(const string& node_name, const string& device_name) {
+  void SetAssignedDevice(const std::string& node_name,
+                         const std::string& device_name) {
     CHECK_NOTNULL(GetNodeByName(node_name))
         ->set_assigned_device_name(device_name);
   }
@@ -68,14 +69,14 @@ class GraphHelper {
     EXPECT_EQ(arg_num, expected_num);
   }
 
-  void CheckAssignedDevice(const string& node_name,
-                           const string& expected_device_name) {
+  void CheckAssignedDevice(const std::string& node_name,
+                           const std::string& expected_device_name) {
     EXPECT_EQ(expected_device_name,
               CHECK_NOTNULL(GetNodeByName(node_name))->assigned_device_name());
   }
 
-  void CheckAssignedDevicePrefix(const string& node_name,
-                                 const string& expected_device_name) {
+  void CheckAssignedDevicePrefix(const std::string& node_name,
+                                 const std::string& expected_device_name) {
     auto assigned =
         CHECK_NOTNULL(GetNodeByName(node_name))->assigned_device_name();
     EXPECT_EQ(assigned.rfind(expected_device_name, 0), 0);
@@ -85,21 +86,21 @@ class GraphHelper {
   const Graph& graph_;
   // Maps from a node name to a Node* in the graph. We use an ordered map here
   // to ensure stability of GetNodeByName().
-  std::map<string, Node*> nodes_by_name_;
+  std::map<std::string, Node*> nodes_by_name_;
 };
 
 TEST(ReplicatePerReplicaNodesTest, SingleCompositeDevice) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
   Output arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
   auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
-  auto one = ops::Const<int32>(scope.WithOpName("one"), 1);
+  auto one = ops::Const<int32_t>(scope.WithOpName("one"), 1);
   auto write = ops::AssignVariableOp(scope.WithOpName("write"), arg, one);
   auto ret = ops::_Retval(
       scope.WithOpName("ret").WithControlDependencies({write}), read, 0);
 
-  const std::vector<string> underlying_devices = {"/device:TPU:0",
-                                                  "/device:TPU:1"};
-  const absl::flat_hash_map<string, const std::vector<string>*>
+  const std::vector<std::string> underlying_devices = {"/device:TPU:0",
+                                                       "/device:TPU:1"};
+  const absl::flat_hash_map<std::string, const std::vector<std::string>*>
      composite_devices = {{"/device:TPU_COMPOSITE:0", &underlying_devices}};
 
   Graph graph(OpRegistry::Global());
@@ -143,8 +144,8 @@ TEST(ReplicatePerReplicaNodesTest, SingleCompositeDeviceToSingleDevice) {
   auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
   auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);
 
-  const std::vector<string> underlying_devices = {"/device:TPU:0"};
-  const absl::flat_hash_map<string, const std::vector<string>*>
+  const std::vector<std::string> underlying_devices = {"/device:TPU:0"};
+  const absl::flat_hash_map<std::string, const std::vector<std::string>*>
      composite_devices = {{"/device:TPU_COMPOSITE:0", &underlying_devices}};
 
   Graph graph(OpRegistry::Global());
@@ -183,11 +184,11 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
   auto add = ops::Add(scope.WithOpName("add"), identity0, identity1);
   auto ret = ops::_Retval(scope.WithOpName("ret"), add, 0);
 
-  const std::vector<string> underlying_devices_0 = {"/device:TPU:0",
-                                                    "/device:TPU:1"};
-  const std::vector<string> underlying_devices_1 = {"/device:TPU:2",
-                                                    "/device:TPU:3"};
-  const absl::flat_hash_map<string, const std::vector<string>*>
+  const std::vector<std::string> underlying_devices_0 = {"/device:TPU:0",
+                                                         "/device:TPU:1"};
+  const std::vector<std::string> underlying_devices_1 = {"/device:TPU:2",
+                                                         "/device:TPU:3"};
+  const absl::flat_hash_map<std::string, const std::vector<std::string>*>
      composite_devices = {{"/device:TPU_COMPOSITE:0", &underlying_devices_0},
                           {"/device:TPU_COMPOSITE:1", &underlying_devices_1}};
 
@@ -232,9 +233,9 @@ TEST(ReplicatePerReplicaNodesTest, MultipleCompositeDevices) {
 }
 
 TEST(ReplicatePerReplicaNodesTest, NestedFunctions) {
-  const std::vector<string> underlying_devices = {"/device:TPU:0",
-                                                  "/device:TPU:1"};
-  const absl::flat_hash_map<string, const std::vector<string>*>
+  const std::vector<std::string> underlying_devices = {"/device:TPU:0",
+                                                       "/device:TPU:1"};
+  const absl::flat_hash_map<std::string, const std::vector<std::string>*>
      composite_devices = {{"/device:TPU_COMPOSITE:0", &underlying_devices}};
 
   FunctionDefLibrary fdef_lib;
@@ -311,9 +312,9 @@ TEST(ReplicatePerReplicaNodesTest, DeadArgNodes) {
   auto read = ops::ReadVariableOp(scope.WithOpName("read"), arg, DT_INT32);
   auto ret = ops::_Retval(scope.WithOpName("ret"), read, 0);
 
-  const std::vector<string> underlying_devices = {"/device:TPU:0",
-                                                  "/device:TPU:1"};
-  const absl::flat_hash_map<string, const std::vector<string>*>
+  const std::vector<std::string> underlying_devices = {"/device:TPU:0",
+                                                       "/device:TPU:1"};
+  const absl::flat_hash_map<std::string, const std::vector<std::string>*>
      composite_devices = {{"/device:TPU_COMPOSITE:0", &underlying_devices}};
 
   Graph graph(OpRegistry::Global());
diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc
index a12acfdf64c9dd..ff44370ecbd451 100644
--- a/tensorflow/core/common_runtime/ring_alg.cc
+++ b/tensorflow/core/common_runtime/ring_alg.cc
@@ -61,8 +61,8 @@ namespace {
 // RingAlg instances.  Note that the exec_key will differentiate between
 // different instances consequently we don't need to further differentiate
 // between subclasses of RingAlg.
-string RingAlgBufKey(const string& name, const string& exec_key, int pass,
-                     int section, int source_rank) {
+std::string RingAlgBufKey(const std::string& name, const std::string& exec_key,
+                          int pass, int section, int source_rank) {
   if (READABLE_KEYS) {
     return strings::StrCat(name, "(", exec_key, "):pass(", pass, "):section(",
                            section, "):srcrank(", source_rank, ")");
@@ -97,7 +97,7 @@ RingAlg::RingField* RingAlg::PCQueue::Dequeue() {
   return rf;
 }
 
-RingAlg::RingAlg(CollectiveType type, const string& name)
+RingAlg::RingAlg(CollectiveType type, const std::string& name)
     : type_(type),
       name_(name),
       col_ctx_(nullptr),
@@ -163,10 +163,10 @@ absl::Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
   }
 
   if (VLOG_IS_ON(2)) {
-    string subdiv_buf;
+    std::string subdiv_buf;
     for (const int subdiv_offset :
          col_params->instance.impl_details.subdiv_offsets) {
-      strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
+      absl::StrAppend(&subdiv_buf, " ", subdiv_offset);
     }
     VLOG(2) << "Dynamically generated " << num_subdivs
             << " subdiv_offsets:" << subdiv_buf << " tensor_size "
@@ -178,7 +178,7 @@ absl::Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
 }  // namespace
 
 absl::Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
-  const string& device_name =
+  const std::string& device_name =
       col_params->group.members[col_params->default_rank].device.name();
   // Each subdiv permutation is a ring formed by rotating each
   // single-task subsequence of devices by an offset.  This makes most
@@ -190,7 +190,7 @@ absl::Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) {
   // Precondition: device_names must be sorted so that all devices in
   // the same task are adjacent.
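As a worked example of the READABLE_KEYS format produced by RingAlgBufKey above (all argument values hypothetical):

// RingAlgBufKey("Gather", "exec_key_1", /*pass=*/0, /*section=*/2,
//               /*source_rank=*/3)
// -> "Gather(exec_key_1):pass(0):section(2):srcrank(3)"

Distinct (pass, section, source_rank) tuples therefore yield distinct buffer keys within one exec_key, which is what keeps concurrent sends and receives of different subchunks from colliding.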
   std::vector<int> dev_per_task;
-  const string* prior_task_name = &col_params->group.members[0].task;
+  const std::string* prior_task_name = &col_params->group.members[0].task;
   int dev_count = 1;
   for (int di = 1; di < col_params->group.group_size; ++di) {
     if (col_params->group.members[di].task != *prior_task_name) {
@@ -265,7 +265,7 @@ absl::Status RingAlg::InitializeCollectiveContext(
                                         &col_ctx->device_locality);
 }
 
-string RingAlg::TensorDebugString(const Tensor& tensor) {
+std::string RingAlg::TensorDebugString(const Tensor& tensor) {
   const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info =
       col_ctx_->op_ctx->device()->tensorflow_accelerator_device_info();
   if (accelerator_device_info) {
@@ -383,11 +383,11 @@ void RingAlg::AdvanceToSecondPass(RingField* rf) {
   VLOG(3) << "IncrRingField new value " << rf->DebugString();
 }
 
-string RingAlg::RingField::DebugString() const {
-  string rv = strings::StrCat("RingField rank=", rank, " chunk_idx=", chunk_idx,
-                              " subdiv=", subdiv_idx, " sc_idx=", sc_idx,
-                              " action=", action);
-  strings::StrAppend(&rv, " pass=", second_pass);
+std::string RingAlg::RingField::DebugString() const {
+  std::string rv = strings::StrCat(
+      "RingField rank=", rank, " chunk_idx=", chunk_idx, " subdiv=", subdiv_idx,
+      " sc_idx=", sc_idx, " action=", action);
+  absl::StrAppend(&rv, " pass=", second_pass);
   strings::StrAppend(&rv, " do_send=", do_send, " do_recv=", do_recv,
                      " is_final=", is_final, " recv_is_remote=", recv_is_remote,
                      " recv_dev_idx=", recv_dev_idx, " sc_idx=", sc_idx);
@@ -396,8 +396,8 @@ string RingAlg::RingField::DebugString() const {
 
 void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
   DCHECK(rf->do_send);
-  string send_buf_key = RingAlgBufKey(name_, col_ctx_->exec_key,
-                                      rf->second_pass, rf->sc_idx, rf->rank);
+  std::string send_buf_key = RingAlgBufKey(
+      name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx, rf->rank);
   VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
           << send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
          << rf->sc_idx;
@@ -415,7 +415,7 @@ void RingAlg::DispatchSend(RingField* rf, const StatusCallback& done) {
 
 void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
   DCHECK(rf->do_recv);
-  string recv_buf_key =
+  std::string recv_buf_key =
       RingAlgBufKey(name_, col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
                     (rf->rank + (group_size_ - 1)) % group_size_);
   VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
@@ -434,9 +434,9 @@ void RingAlg::DispatchRecv(RingField* rf, const StatusCallback& done) {
       col_ctx_->op_ctx->cancellation_manager(), done);
 }
 
-string RingAlg::FieldState() {
-  string s = strings::StrCat(
-      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64>(this)),
+std::string RingAlg::FieldState() {
+  std::string s = strings::StrCat(
+      "Ring", name_, " ", strings::Hex(reinterpret_cast<uint64_t>(this)),
       " exec ", col_ctx_->exec_key, " step_id=", col_ctx_->step_id,
       " state of all ", rfv_.size(), " fields:");
   for (int i = 0; i < rfv_.size(); ++i) {
diff --git a/tensorflow/core/common_runtime/ring_alg.h b/tensorflow/core/common_runtime/ring_alg.h
index d2294f830db2c1..b54da03a01a739 100644
--- a/tensorflow/core/common_runtime/ring_alg.h
+++ b/tensorflow/core/common_runtime/ring_alg.h
@@ -31,7 +31,7 @@ class Device;
 // for specific collective functions.
 class RingAlg : public CollectiveImplementationInterface {
  public:
-  explicit RingAlg(CollectiveType type, const string& name);
+  explicit RingAlg(CollectiveType type, const std::string& name);
   ~RingAlg() override {}
 
   // Establishes the requested number of subdivision permutations based on the
@@ -63,11 +63,11 @@ class RingAlg : public CollectiveImplementationInterface {
 
   // Tracks progress of actions on a single subfield of the entire tensor.
   struct RingField {
-    int16 chunk_idx;     // major division index
-    int16 subdiv_idx;    // minor division index
-    int16 sc_idx;        // subchunk index
-    int16 rank;          // rank within subdiv permutation
-    int16 recv_dev_idx;  // dev from which value should be recv'd
+    int16_t chunk_idx;     // major division index
+    int16_t subdiv_idx;    // minor division index
+    int16_t sc_idx;        // subchunk index
+    int16_t rank;          // rank within subdiv permutation
+    int16_t recv_dev_idx;  // dev from which value should be recv'd
     RingFieldAction action;
     bool second_pass;
     bool recv_is_remote = false;
@@ -78,7 +78,7 @@ class RingAlg : public CollectiveImplementationInterface {
     Tensor chunk;      // alias to field values
     Tensor tmp_chunk;
     absl::Status status;
-    string DebugString() const;
+    std::string DebugString() const;
   };
   virtual void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
                              int field_idx);
@@ -87,8 +87,8 @@ class RingAlg : public CollectiveImplementationInterface {
   void DispatchRecv(RingField* rf, const StatusCallback& done);
 
   // For constructing log messages for debugging.
-  string FieldState();
-  string TensorDebugString(const Tensor& tensor);
+  std::string FieldState();
+  std::string TensorDebugString(const Tensor& tensor);
 
   // Producer/Consumer Queue of RingField structs.
   class PCQueue {
@@ -104,7 +104,7 @@
   };
 
   const CollectiveType type_;
-  const string name_;
+  const std::string name_;
   std::shared_ptr<CollectiveContext> col_ctx_;
   const CollectiveParams* col_params_;  // Not owned
   StatusCallback done_;
diff --git a/tensorflow/core/common_runtime/ring_gatherer.cc b/tensorflow/core/common_runtime/ring_gatherer.cc
index bc016b366696d4..bd85f07aef1840 100644
--- a/tensorflow/core/common_runtime/ring_gatherer.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer.cc
@@ -71,7 +71,7 @@ void RingGatherer::Run(StatusCallback done) {
   DCHECK_GT(num_subdivs_, 0);
 
   if (VLOG_IS_ON(1)) {
-    string buf;
+    std::string buf;
     for (int r = 0; r < col_params_->group.members.size(); ++r) {
       strings::StrAppend(&buf, "dev ", r, " : ",
                          col_params_->group.members[r].device.name(), "\n");
@@ -79,10 +79,10 @@ void RingGatherer::Run(StatusCallback done) {
     for (int sd = 0;
          sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
-      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
+      absl::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
       for (auto x :
            col_params_->instance.impl_details.subdiv_permutations[sd]) {
-        strings::StrAppend(&buf, x, ", ");
+        absl::StrAppend(&buf, x, ", ");
       }
     }
     VLOG(1) << "RingGatherer::Run for device " << col_ctx_->device_name
diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc
index 595ff502737b93..884fb17340c4c0 100644
--- a/tensorflow/core/common_runtime/ring_gatherer_test.cc
+++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc
@@ -105,7 +105,7 @@ class RingGathererTest : public ::testing::Test {
       // Confirm that every device terminated with the expected error status.
       for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
         EXPECT_NE(instances_[di]->status_.message().find("Deliberate failure"),
-                  string::npos);
+                  std::string::npos);
       }
     } else {
       // Confirm that every device accumulated the same set of correct
@@ -130,7 +130,7 @@ class RingGathererTest : public ::testing::Test {
       GenerateEvenSubdivOffsets(test_env->num_devices_per_worker, num_subdivs);
     }
 
-    string dev_name = col_params_->group.members[rank].device.name();
+    std::string dev_name = col_params_->group.members[rank].device.name();
     TF_CHECK_OK(test_env_->device_mgr->LookupDevice(dev_name, &device_))
         << "Couldn't find device " << dev_name
        << " existing devices: " << test_env_->device_mgr->DebugString();
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index c448f021f055f0..3ad099caee9b9b 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -67,7 +67,7 @@ void RingReducer::Run(StatusCallback done) {
   CHECK_GT(num_subdivs_, 0);
 
   if (VLOG_IS_ON(1)) {
-    string buf;
+    std::string buf;
     for (int r = 0; r < col_params_->group.members.size(); ++r) {
       strings::StrAppend(&buf, "dev ", r, " : ",
                          col_params_->group.members[r].device.name(), "\n");
@@ -75,10 +75,10 @@ void RingReducer::Run(StatusCallback done) {
     for (int sd = 0;
          sd < col_params_->instance.impl_details.subdiv_permutations.size();
         ++sd) {
-      strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
+      absl::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
       for (auto x :
            col_params_->instance.impl_details.subdiv_permutations[sd]) {
-        strings::StrAppend(&buf, x, ", ");
+        absl::StrAppend(&buf, x, ", ");
       }
     }
     VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
@@ -129,9 +129,9 @@ void RingReducer::ContinueAfterInputCopy() {
     // can be provided to the kernel in host memory?
     Tensor group_size_val = ca_->Scalar(group_size_);
     if (col_params_->group.device_type != "CPU") {
-      uint64 safe_alloc_frontier = col_ctx_->device->SafeAllocFrontier(0);
+      uint64_t safe_alloc_frontier = col_ctx_->device->SafeAllocFrontier(0);
       AllocationAttributes aa;
-      std::function<uint64()> freed_by_func = [this, &safe_alloc_frontier]() {
+      std::function<uint64_t()> freed_by_func = [this, &safe_alloc_frontier]() {
         safe_alloc_frontier =
             col_ctx_->device->SafeAllocFrontier(safe_alloc_frontier);
         return safe_alloc_frontier;
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index d4baa4aaef652e..bedfa64134de51 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -138,7 +138,7 @@ class RingReducerTest : public ::testing::Test {
       // Confirm that every device terminated with the expected error status.
       for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
         EXPECT_NE(instances_[di]->status_.message().find("Deliberate failure"),
-                  string::npos);
+                  std::string::npos);
       }
     } else {
       // Confirm that every device computed the same correct reduction value.
@@ -165,7 +165,7 @@ class RingReducerTest : public ::testing::Test {
       GenerateEvenSubdivOffsets(test_env->num_devices_per_worker, num_subdivs);
     }
 
-    string dev_name = col_params_->group.members[rank].device.name();
+    std::string dev_name = col_params_->group.members[rank].device.name();
     TF_CHECK_OK(test_env_->device_mgr->LookupDevice(dev_name, &device_))
         << "Couldn't find device " << dev_name
        << " existing devices: " << test_env_->device_mgr->DebugString();
@@ -200,7 +200,7 @@ class RingReducerTest : public ::testing::Test {
   std::unique_ptr<CollectiveTestEnv> test_env_;
   std::vector<std::unique_ptr<DeviceInstance>> instances_;
   mutex mu_;
-  int32 reduce_counter_ TF_GUARDED_BY(mu_) = 0;
+  int32_t reduce_counter_ TF_GUARDED_BY(mu_) = 0;
 };
 
 class RingReducerInitParamsTest : public ::testing::Test {
diff --git a/tensorflow/core/common_runtime/scoped_allocator.cc b/tensorflow/core/common_runtime/scoped_allocator.cc
index 1b3d39a8c6e996..24e7e089784e17 100644
--- a/tensorflow/core/common_runtime/scoped_allocator.cc
+++ b/tensorflow/core/common_runtime/scoped_allocator.cc
@@ -20,7 +20,7 @@ limitations under the License.
 namespace tensorflow {
 
 ScopedAllocator::ScopedAllocator(const Tensor& backing_tensor, int32_t scope_id,
-                                 const string& name,
+                                 const std::string& name,
                                  const absl::Span<const Field> fields,
                                  int32_t expected_call_count,
                                  ScopedAllocatorContainer* container)
@@ -69,7 +69,7 @@ void* ScopedAllocator::AllocateRaw(int32_t field_index, size_t num_bytes) {
     return nullptr;
   }
 
-  int32_t num_fields = static_cast<int32>(fields_.size());
+  int32_t num_fields = static_cast<int32_t>(fields_.size());
   if (field_index >= num_fields) {
     LOG(ERROR) << "ScopedAllocator " << name_
                << " received unexpected field number " << field_index;
@@ -228,8 +228,8 @@ void ScopedAllocatorInstance::DeallocateRaw(void* p) {
   if (del) delete this;
 }
 
-string ScopedAllocatorInstance::Name() {
-  return strings::StrCat(scoped_allocator_->name(), "_field_", field_index_);
+std::string ScopedAllocatorInstance::Name() {
+  return absl::StrCat(scoped_allocator_->name(), "_field_", field_index_);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/scoped_allocator.h b/tensorflow/core/common_runtime/scoped_allocator.h
index 5b22deb264ce52..8c894372fbee15 100644
--- a/tensorflow/core/common_runtime/scoped_allocator.h
+++ b/tensorflow/core/common_runtime/scoped_allocator.h
@@ -33,7 +33,7 @@ class ScopedAllocator {
   // A subrange of the TensorBuffer associated with this object that
   // will be the backing memory for one aliased tensor.
   struct Field {
-    int32 scope_id;
+    int32_t scope_id;
     size_t offset;
     size_t bytes_requested;
     size_t bytes_allocated;
@@ -71,13 +71,13 @@ class ScopedAllocator {
   void DeallocateRaw(void* p) TF_LOCKS_EXCLUDED(mu_);
   Tensor backing_tensor_;
   TensorBuffer* tbuf_;
-  int32 id_;
+  int32_t id_;
   std::string name_;
   ScopedAllocatorContainer* container_;
   std::vector<Field> fields_;
   mutex mu_;
-  int32 expected_call_count_ TF_GUARDED_BY(mu_);
-  int32 live_alloc_count_ TF_GUARDED_BY(mu_);
+  int32_t expected_call_count_ TF_GUARDED_BY(mu_);
+  int32_t live_alloc_count_ TF_GUARDED_BY(mu_);
 };
 
 // An Allocator that will return a pointer into the backing buffer of
@@ -117,7 +117,7 @@ class ScopedAllocatorInstance : public Allocator {
  private:
   mutex mu_;
   ScopedAllocator* scoped_allocator_;
-  int32 field_index_;
+  int32_t field_index_;
   bool allocated_ TF_GUARDED_BY(mu_);
   bool deallocated_ TF_GUARDED_BY(mu_);
   bool in_table_ TF_GUARDED_BY(mu_);
diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
index 47ddfabbc27efe..d4fe07b5f27d2b 100644
--- a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
+++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc
@@ -20,7 +20,8 @@ limitations under the License.
 namespace tensorflow {
 
 absl::Status ScopedAllocatorContainer::AddScopedAllocator(
-    const Tensor& backing_tensor, int32_t scope_id, const string& scope_name,
+    const Tensor& backing_tensor, int32_t scope_id,
+    const std::string& scope_name,
     const absl::Span<const ScopedAllocator::Field>& fields,
     int32_t expected_call_count) {
   VLOG(1) << "AddScopedAllocator " << mgr_->device_name()
@@ -152,7 +153,7 @@ ScopedAllocatorContainer* ScopedAllocatorMgr::GetContainer(int64_t step_id) {
 
 absl::Status ScopedAllocatorMgr::AddScopedAllocator(
     const Tensor& backing_tensor, int64_t step_id, int32_t scope_id,
-    const string& scope_name,
+    const std::string& scope_name,
     const absl::Span<const ScopedAllocator::Field>& fields,
     int32_t expected_call_count) {
   ScopedAllocatorContainer* sac = GetContainer(step_id);
@@ -164,7 +165,7 @@ absl::Status ScopedAllocatorMgr::AddScopedAllocator(
 size_t ScopedAllocatorMgr::PopulateFields(
     int32_t scope_id, const absl::Span<const TensorShape>& shapes,
     const DataType dtype, std::vector<ScopedAllocator::Field>* fields) {
-  const int32_t num_fields = static_cast<int32>(shapes.size());
+  const int32_t num_fields = static_cast<int32_t>(shapes.size());
   fields->resize(num_fields);
   // At the end of iteration `i`, `offset` points to the offset from the start
   // of the backing buffer until the end of `field[i].bytes_allocated`. This
diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.h b/tensorflow/core/common_runtime/scoped_allocator_mgr.h
index dbbf7c3249ae54..22924a7005e892 100644
--- a/tensorflow/core/common_runtime/scoped_allocator_mgr.h
+++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.h
@@ -54,7 +54,7 @@ class ScopedAllocatorContainer : public core::RefCounted {
   int64_t step_id_;
   mutex mu_;
   struct SAField {
-    int32 field_index;
+    int32_t field_index;
     union {
       ScopedAllocator* scoped_allocator;
       ScopedAllocatorInstance* instance;
@@ -67,7 +67,7 @@ class ScopedAllocatorContainer : public core::RefCounted {
         : field_index(ScopedAllocator::kBackingIndex),
          scoped_allocator(nullptr) {}
   };
-  std::unordered_map<int32, SAField> allocators_ TF_GUARDED_BY(mu_);
+  std::unordered_map<int32_t, SAField> allocators_ TF_GUARDED_BY(mu_);
 };
 
 // At most one of these exists per device.
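A minimal sketch of the offset bookkeeping PopulateFields performs, under the assumption that each field's allocation is rounded up so the next field starts at an aligned offset (kAlignment, AlignUp, and LayOutFields are hypothetical stand-ins, not the TF implementation):

// Lay fields out back-to-back in the backing buffer; returns total bytes.
constexpr size_t kAlignment = 64;  // assumed allocator alignment, power of two
size_t AlignUp(size_t n) { return (n + kAlignment - 1) & ~(kAlignment - 1); }

size_t LayOutFields(std::vector<ScopedAllocator::Field>* fields) {
  size_t offset = 0;
  for (auto& f : *fields) {
    f.offset = offset;                               // where field[i] starts
    f.bytes_allocated = AlignUp(f.bytes_requested);  // requested plus padding
    offset += f.bytes_allocated;                     // end of field[i]
  }
  return offset;  // size needed for the backing tensor
}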
diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc
index ab0d769ceebe8b..59deffca41c19c 100644
--- a/tensorflow/core/common_runtime/session.cc
+++ b/tensorflow/core/common_runtime/session.cc
@@ -36,27 +36,29 @@ Session::Session() {}
 
 Session::~Session() {}
 
-absl::Status Session::Run(const RunOptions& run_options,
-                          const std::vector<std::pair<string, Tensor> >& inputs,
-                          const std::vector<string>& output_tensor_names,
-                          const std::vector<string>& target_tensor_names,
-                          std::vector<Tensor>* outputs,
-                          RunMetadata* run_metadata) {
+absl::Status Session::Run(
+    const RunOptions& run_options,
+    const std::vector<std::pair<std::string, Tensor> >& inputs,
+    const std::vector<std::string>& output_tensor_names,
+    const std::vector<std::string>& target_tensor_names,
+    std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
   return errors::Unimplemented(
       "Run with options is not supported for this session.");
 }
 
-absl::Status Session::PRunSetup(const std::vector<string>& input_names,
-                                const std::vector<string>& output_names,
-                                const std::vector<string>& target_nodes,
-                                string* handle) {
+absl::Status Session::PRunSetup(const std::vector<std::string>& input_names,
+                                const std::vector<std::string>& output_names,
+                                const std::vector<std::string>& target_nodes,
+                                std::string* handle) {
   return errors::Unimplemented(
       "Partial run is not supported for this session.");
 }
 
 absl::Status Session::PRun(
-    const string& handle, const std::vector<std::pair<string, Tensor> >& inputs,
-    const std::vector<string>& output_names, std::vector<Tensor>* outputs) {
+    const std::string& handle,
+    const std::vector<std::pair<std::string, Tensor> >& inputs,
+    const std::vector<std::string>& output_names,
+    std::vector<Tensor>* outputs) {
   return errors::Unimplemented(
       "Partial run is not supported for this session.");
 }
@@ -96,7 +98,7 @@ absl::Status NewSession(const SessionOptions& options, Session** out_session) {
 }
 
 absl::Status Reset(const SessionOptions& options,
-                   const std::vector<string>& containers) {
+                   const std::vector<std::string>& containers) {
   SessionFactory* factory;
   TF_RETURN_IF_ERROR(SessionFactory::GetFactory(options, &factory));
   return factory->Reset(options, containers);
diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc
index c21f1dc9483ee2..fc28ab4e05e887 100644
--- a/tensorflow/core/common_runtime/session_factory.cc
+++ b/tensorflow/core/common_runtime/session_factory.cc
@@ -33,7 +33,7 @@ static mutex* get_session_factory_lock() {
   return &session_factory_lock;
 }
 
-typedef std::unordered_map<string, SessionFactory*> SessionFactories;
+typedef std::unordered_map<std::string, SessionFactory*> SessionFactories;
 SessionFactories* session_factories() {
   static SessionFactories* factories = new SessionFactories;
   return factories;
@@ -41,7 +41,7 @@ SessionFactories* session_factories() {
 
 }  // namespace
 
-void SessionFactory::Register(const string& runtime_type,
+void SessionFactory::Register(const std::string& runtime_type,
                               SessionFactory* factory) {
   mutex_lock l(*get_session_factory_lock());
   if (!session_factories()->insert({runtime_type, factory}).second) {
@@ -51,17 +51,17 @@ void SessionFactory::Register(const string& runtime_type,
 }
 
 namespace {
-const string RegisteredFactoriesErrorMessageLocked() {
-  std::vector<string> factory_types;
+const std::string RegisteredFactoriesErrorMessageLocked() {
+  std::vector<std::string> factory_types;
   for (const auto& session_factory : *session_factories()) {
     factory_types.push_back(session_factory.first);
   }
-  return strings::StrCat("Registered factories are {",
-                         absl::StrJoin(factory_types, ", "), "}.");
+  return absl::StrCat("Registered factories are {",
+                      absl::StrJoin(factory_types, ", "), "}.");
 }
 
-string SessionOptionsToString(const SessionOptions& options) {
-  return strings::StrCat("target: \"", options.target,
-                         "\" config: ", options.config.ShortDebugString());
config: ", options.config.ShortDebugString()); +std::string SessionOptionsToString(const SessionOptions& options) { + return absl::StrCat("target: \"", options.target, + "\" config: ", options.config.ShortDebugString()); } } // namespace @@ -69,7 +69,7 @@ absl::Status SessionFactory::GetFactory(const SessionOptions& options, SessionFactory** out_factory) { mutex_lock l(*get_session_factory_lock()); // could use reader lock - std::vector> candidate_factories; + std::vector> candidate_factories; for (const auto& session_factory : *session_factories()) { if (session_factory.second->AcceptsOptions(options)) { VLOG(2) << "SessionFactory type " << session_factory.first @@ -93,7 +93,7 @@ absl::Status SessionFactory::GetFactory(const SessionOptions& options, // the number of sessions grows. // TODO(mrry): Consider providing a system-default fallback option // in this case. - std::vector factory_types; + std::vector factory_types; factory_types.reserve(candidate_factories.size()); for (const auto& candidate_factory : candidate_factories) { factory_types.push_back(candidate_factory.first); diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h index ffadb29ae21a6c..3c9d08db121c68 100644 --- a/tensorflow/core/common_runtime/session_factory.h +++ b/tensorflow/core/common_runtime/session_factory.h @@ -61,12 +61,13 @@ class SessionFactory { // // Sessions that support resource containers should override this function. virtual absl::Status Reset(const SessionOptions& options, - const std::vector& containers) { + const std::vector& containers) { return errors::Unimplemented("Reset()"); } virtual ~SessionFactory() {} - static void Register(const string& runtime_type, SessionFactory* factory); + static void Register(const std::string& runtime_type, + SessionFactory* factory); static absl::Status GetFactory(const SessionOptions& options, SessionFactory** out_factory); }; diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc index 47341276fef563..5a236367357099 100644 --- a/tensorflow/core/common_runtime/session_state.cc +++ b/tensorflow/core/common_runtime/session_state.cc @@ -23,7 +23,8 @@ namespace tensorflow { // kTensorHandleResourceTypeName. 
const char* SessionState::kTensorHandleResourceTypeName = "TensorHandle"; -absl::Status SessionState::GetTensor(const string& handle, Tensor* tensor) { +absl::Status SessionState::GetTensor(const std::string& handle, + Tensor* tensor) { mutex_lock l(state_lock_); auto it = tensors_.find(handle); if (it == tensors_.end()) { @@ -34,7 +35,7 @@ absl::Status SessionState::GetTensor(const string& handle, Tensor* tensor) { return absl::OkStatus(); } -absl::Status SessionState::AddTensor(const string& handle, +absl::Status SessionState::AddTensor(const std::string& handle, const Tensor& tensor) { mutex_lock l(state_lock_); if (!tensors_.insert({handle, tensor}).second) { @@ -44,7 +45,7 @@ absl::Status SessionState::AddTensor(const string& handle, return absl::OkStatus(); } -absl::Status SessionState::DeleteTensor(const string& handle) { +absl::Status SessionState::DeleteTensor(const std::string& handle) { mutex_lock l(state_lock_); if (tensors_.erase(handle) == 0) { return errors::InvalidArgument("Failed to delete a tensor with handle '", @@ -58,7 +59,7 @@ int64_t SessionState::GetNewId() { return tensor_id_++; } -absl::Status TensorStore::AddTensor(const string& name, +absl::Status TensorStore::AddTensor(const std::string& name, const TensorAndKey& tk) { mutex_lock l(lock_); if (!tensors_.insert({name, tk}).second) { @@ -69,18 +70,18 @@ absl::Status TensorStore::AddTensor(const string& name, return absl::OkStatus(); } -absl::Status TensorStore::SaveTensors(const std::vector<string>& output_names, - SessionState* session_state) { +absl::Status TensorStore::SaveTensors( + const std::vector<std::string>& output_names, SessionState* session_state) { mutex_lock l(lock_); if (!tensors_.empty()) { // Save only the tensors in output_names in the session. - for (const string& name : output_names) { + for (const std::string& name : output_names) { TensorId id(ParseTensorName(name)); - const string op_name(id.first); + const std::string op_name(id.first); auto it = tensors_.find(op_name); if (it != tensors_.end()) { // Save the tensor to the session state. - string key = it->second.GetHandle(op_name); + std::string key = it->second.GetHandle(op_name); TF_RETURN_IF_ERROR(session_state->AddTensor(key, it->second.tensor)); } } diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 0893140693fdf9..bc4787864315b0 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -162,7 +162,7 @@ absl::Status ShapeRefiner::InferShapesForFunction( const FunctionDef* function_def, AttrSlice attributes, InferenceContext* outer_context) { const Graph* graph; - const string& fname = function_def->signature().name(); + const std::string& fname = function_def->signature().name(); auto it = functions_.find(fname); if (it != functions_.end()) { graph = it->second.get(); @@ -170,7 +170,7 @@ absl::Status ShapeRefiner::InferShapesForFunction( InstantiationResult result; TF_RETURN_IF_ERROR(InstantiateFunction( *function_def, attributes, - [this](const string& op, const OpDef** sig) { + [this](const std::string& op, const OpDef** sig) { return this->function_library_->LookUpOpDef(op, sig); }, &result)); @@ -476,7 +476,7 @@ absl::Status ShapeRefiner::EvaluateConstantIntScalarEdge( scalar.NumElements()); } if (scalar.dtype() == DT_INT32) { - *result = scalar.scalar<int32>()(); + *result = scalar.scalar<int32_t>()(); } else { if (scalar.dtype() != DT_INT64) { return errors::InvalidArgument( @@ -515,7 +515,7 @@ absl::Status ShapeRefiner::ConstantPartialShape( "of '-1' is required to represent an unknown shape."); } if (t.dims() == 0) { - if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) { + if (t.dtype() == DT_INT32 && t.scalar<int32_t>()() == -1) { *result = target_context->UnknownShape(); return absl::OkStatus(); } else if (t.dtype() == DT_INT64 && t.scalar<int64_t>()() == -1) { @@ -531,7 +531,7 @@ absl::Status ShapeRefiner::ConstantPartialShape( TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape)); - const string& src_op = input_edge->src()->type_string(); + const std::string& src_op = input_edge->src()->type_string(); if (src_context->Value(src_context->Dim(src_shape, 0)) == 0) { // Source tensor is a vector of length 0, so the shape it // represents is as scalar. diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 111303a095f5ae..f67e5dd4b388e7 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -87,7 +87,7 @@ class ShapeRefiner { } // Getters and setters for graph_def_version_.
- int32 graph_def_version() const { return graph_def_version_; } + int32_t graph_def_version() const { return graph_def_version_; } void set_graph_def_version(int32_t version) { graph_def_version_ = version; } void set_require_shape_inference_fns(bool require_shape_inference_fns) { @@ -250,7 +250,7 @@ class ShapeRefiner { shape_inference::InferenceContext* context, shape_inference::InferenceContext* outer_context = nullptr); - int32 graph_def_version_; + int32_t graph_def_version_; const OpRegistryInterface* const ops_registry_; // The lifetime of the tensors are bound to the runner, so it should be the diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index c54f26e7cc460c..580a987b3ccffd 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -281,9 +281,10 @@ TEST_F(ShapeRefinerTest, ExtractConstantSubgraphMultiOutput) { // input_tensor from the shape function. { Scope root = Scope::NewRootScope(); - auto small = ops::Const(root, {static_cast<int32>(1), TensorShape({1, 1})}); + auto small = + ops::Const(root, {static_cast<int32_t>(1), TensorShape({1, 1})}); auto large = ops::Const( - root, {static_cast<int32>(2), TensorShape({4, kMaxTensorSize / 2})}); + root, {static_cast<int32_t>(2), TensorShape({4, kMaxTensorSize / 2})}); Node* multi; TF_ASSERT_OK(NodeBuilder("MI", "MultiIdentity") .Input(std::vector<NodeBuilder::NodeOut>{small.node(), @@ -313,7 +314,7 @@ TEST_F(ShapeRefinerTest, ExtractConstantSubgraphMultiOutput) { // The add adds 1 and 2 together, and its output has kMaxTensorSize*2 // elements. shape_inference::InferenceContext* ctx = m.GetContext(shape_v2); - EXPECT_EQ(strings::StrCat("[", kMaxTensorSize * 2 * 3, "]"), + EXPECT_EQ(absl::StrCat("[", kMaxTensorSize * 2 * 3, "]"), ctx->DebugString(ctx->output(0))); } } @@ -380,7 +381,7 @@ REGISTER_OP("ShapeData") std::vector<DimensionHandle> dims; dims.reserve(shape_data->NumElements()); for (int i = 0; i < shape_data->NumElements(); ++i) { - dims.emplace_back(c->MakeDim(shape_data->flat<int32>()(i))); + dims.emplace_back(c->MakeDim(shape_data->flat<int32_t>()(i))); } c->set_output(0, c->MakeShape(dims)); @@ -418,7 +419,7 @@ REGISTER_OP("ShapeVectorForAllElements") } int64_t total = 0; for (int i = 0; i < shape_data->NumElements(); ++i) { - total += shape_data->flat<int32>()(i); + total += shape_data->flat<int32_t>()(i); } c->set_output(0, c->Vector(total)); @@ -487,7 +488,8 @@ TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt64) { // Create variable 2x4 tensor. auto input = ops::Variable( - root, {2, 4, static_cast<int64_t>(std::numeric_limits<int32>::max()) * 2}, + root, + {2, 4, static_cast<int64_t>(std::numeric_limits<int32_t>::max()) * 2}, DT_INT64); // Shape is a vector of 2 elements (2,4) @@ -521,7 +523,8 @@ TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt32Overflow) { // Create variable 2x4 tensor. auto input = ops::Variable( - root, {2, 4, static_cast<int64_t>(std::numeric_limits<int32>::max()) * 2}, + root, + {2, 4, static_cast<int64_t>(std::numeric_limits<int32_t>::max()) * 2}, DT_INT32); // Shape is a vector of 2 elements (2,4) @@ -607,7 +610,7 @@ TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt64) { auto input = ops::Variable( root, {1, 2, 3, 4, 5, - static_cast<int64_t>(std::numeric_limits<int32>::max()) * 2}, + static_cast<int64_t>(std::numeric_limits<int32_t>::max()) * 2}, DT_INT64); // 5! * int32_max_value * 2. @@ -638,7 +641,7 @@ TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt32Overflow) { auto input = ops::Variable( root, {1, 2, 3, 4, 5, - static_cast<int64_t>(std::numeric_limits<int32>::max()) * 2}, + static_cast<int64_t>(std::numeric_limits<int32_t>::max()) * 2}, DT_INT32); // 5!. @@ -845,7 +848,7 @@ absl::Status PartialTensorAsShapeShapeFn(shape_inference::InferenceContext* c) { return absl::OkStatus(); } TF_RETURN_IF_ERROR( - c->MakeShapeFromTensorShape(TensorShape({t->flat<int32>()(0)}), &out)); + c->MakeShapeFromTensorShape(TensorShape({t->flat<int32_t>()(0)}), &out)); c->set_output(0, out); return absl::OkStatus(); } @@ -967,10 +970,10 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) { InputList inputs{ // clang-format off - Input(ops::Const<int32>(root, 10)), - Input(ops::Const<int32>(root, 20)), + Input(ops::Const<int32_t>(root, 10)), + Input(ops::Const<int32_t>(root, 20)), Input(Output(scalar_non_const)), - Input(ops::Const<int32>(root, 40)), + Input(ops::Const<int32_t>(root, 40)), }; // clang-format on auto pack = ops::Stack(root, inputs); TF_ASSERT_OK(root.status()); diff --git a/tensorflow/core/common_runtime/simple_propagator_state.cc b/tensorflow/core/common_runtime/simple_propagator_state.cc index af721c1893baa0..3855c6a3d6cfce 100644 --- a/tensorflow/core/common_runtime/simple_propagator_state.cc +++ b/tensorflow/core/common_runtime/simple_propagator_state.cc @@ -35,7 +35,7 @@ SimplePropagatorState::SimplePropagatorState( vlog_(vlog || VLOG_IS_ON(1)), input_tensors_(finfo.total_inputs), pending_( - new std::atomic<int32>[immutable_state.graph_view().num_nodes()]), + new std::atomic<int32_t>[immutable_state.graph_view().num_nodes()]), active_(vlog_ ? new std::vector<bool>( immutable_state.graph_view().num_nodes()) : nullptr), diff --git a/tensorflow/core/common_runtime/simple_propagator_state.h b/tensorflow/core/common_runtime/simple_propagator_state.h index 3c53a5f900414f..8ef9775f93aee8 100644 --- a/tensorflow/core/common_runtime/simple_propagator_state.h +++ b/tensorflow/core/common_runtime/simple_propagator_state.h @@ -167,7 +167,7 @@ class SimplePropagatorState { // is never concurrent access to the same entry. std::vector<Entry> input_tensors_; - std::unique_ptr<std::atomic<int32>[]> pending_; + std::unique_ptr<std::atomic<int32_t>[]> pending_; // If `vlog_` is true, this stores a bit vector of active nodes, indexed by // node ID. diff --git a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc index 10226237cbb8e3..5eb084d0def629 100644 --- a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc +++ b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc @@ -109,12 +109,12 @@ void RedirectEdge(Graph* graph, Node* old_src_node, Node* dst_node, } // Find the corresponding host device name from the TPU device name. -string GetHostDeviceName(Node* tpu_node) { +std::string GetHostDeviceName(Node* tpu_node) { auto device_name = tpu_node->requested_device(); if (device_name.empty()) device_name = tpu_node->assigned_device_name(); DeviceNameUtils::ParsedName parsed_device_name; DeviceNameUtils::ParseFullName(device_name, &parsed_device_name); - string host_device_name = DeviceNameUtils::FullName( + std::string host_device_name = DeviceNameUtils::FullName( parsed_device_name.job, parsed_device_name.replica, parsed_device_name.task, /*type=*/"CPU", /*id=*/0); return host_device_name; @@ -143,7 +143,8 @@ int GetTPUTaskId(Node* tpu_node) { // Build the fill op. Its value is 0 and the fill op is put on the host device // with the same task id as the TPUExecute node.
Node* BuildFillOp(GraphDefBuilder::Options& bopts, Node* tpu_node, - Node* in_node, int input_index, string host_device_name) { + Node* in_node, int input_index, + std::string host_device_name) { // Find the output_shape vector auto output_shape_vec = GetOutputShapeVec(in_node); if (!output_shape_vec.has_value()) return nullptr; @@ -191,7 +192,7 @@ absl::Status ReplaceIciDummyVariables(Graph* graph, int input_index, continue; } - string host_device_name = GetHostDeviceName(tpu_node); + std::string host_device_name = GetHostDeviceName(tpu_node); // If the node corresponding to host_device_name is already in the graph, // replace the edge from in_node to tpu_node with the edge from diff --git a/tensorflow/core/common_runtime/single_threaded_executor.cc b/tensorflow/core/common_runtime/single_threaded_executor.cc index a7c30baec739ad..c737d274fbcd64 100644 --- a/tensorflow/core/common_runtime/single_threaded_executor.cc +++ b/tensorflow/core/common_runtime/single_threaded_executor.cc @@ -65,8 +65,8 @@ namespace { typedef absl::InlinedVector TensorValueVec; typedef absl::InlinedVector AllocatorAttributeVec; -static const string& kSingleThreadedExecutor = - *new string("SINGLE_THREADED_EXECUTOR"); +static const std::string& kSingleThreadedExecutor = + *new std::string("SINGLE_THREADED_EXECUTOR"); class SingleThreadedExecutorImpl : public Executor { public: diff --git a/tensorflow/core/common_runtime/single_threaded_executor_test.cc b/tensorflow/core/common_runtime/single_threaded_executor_test.cc index 334ada5ad0a389..b081e17d86a978 100644 --- a/tensorflow/core/common_runtime/single_threaded_executor_test.cc +++ b/tensorflow/core/common_runtime/single_threaded_executor_test.cc @@ -170,8 +170,9 @@ float V(const Tensor& tensor) { return tensor.scalar()(); } -Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, - const string& receiver, const string& name) { +Rendezvous::ParsedKey Key(const std::string& sender, const uint64_t incarnation, + const std::string& receiver, + const std::string& name) { Rendezvous::ParsedKey result; TF_CHECK_OK( Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver, @@ -363,8 +364,8 @@ void BM_executor(::testing::benchmark::State& state) { Graph* g = new Graph(OpRegistry::Global()); random::PhiloxRandom philox(1729, 17); random::SimplePhilox rand(&philox); - uint64 cur = 0; - uint32 r = 1 + rand.Rand32() % width; + uint64_t cur = 0; + uint32_t r = 1 + rand.Rand32() % width; std::vector ready_nodes; for (int i = 0; i < r; ++i) { ready_nodes.push_back(test::graph::NoOp(g, {})); @@ -392,7 +393,7 @@ void BM_executor(::testing::benchmark::State& state) { test::Benchmark("cpu", g, nullptr, nullptr, nullptr, "SINGLE_THREADED_EXECUTOR", /*old_benchmark_api=*/false) .Run(state); - state.SetLabel(strings::StrCat("Nodes = ", cur)); + state.SetLabel(absl::StrCat("Nodes = ", cur)); state.SetItemsProcessed(cur * static_cast(state.iterations())); } @@ -424,7 +425,7 @@ void BM_const_identity(::testing::benchmark::State& state) { "SINGLE_THREADED_EXECUTOR", /*old_benchmark_api=*/false) .Run(state); - state.SetLabel(strings::StrCat("Nodes = ", (1 + outputs_per_const) * width)); + state.SetLabel(absl::StrCat("Nodes = ", (1 + outputs_per_const) * width)); state.SetItemsProcessed((1 + outputs_per_const) * width * static_cast(state.iterations())); } diff --git a/tensorflow/core/common_runtime/stats_publisher_interface.cc b/tensorflow/core/common_runtime/stats_publisher_interface.cc index 8b04ac9f80523d..610efbdadb7dc8 100644 --- 
a/tensorflow/core/common_runtime/stats_publisher_interface.cc +++ b/tensorflow/core/common_runtime/stats_publisher_interface.cc @@ -43,7 +43,8 @@ class NoOpStatsPublisher : public StatsPublisherInterface { function_records) override {} std::unique_ptr GetProfileHandler( - uint64 step, int64_t execution_count, const RunOptions& ropts) override { + uint64_t step, int64_t execution_count, + const RunOptions& ropts) override { return nullptr; } @@ -74,7 +75,7 @@ StatsPublisherFactory StatsPublisherInterface::GetStatsPublisherFactory() { } std::unique_ptr CreateNoOpStatsPublisher( - const string& session, const BuildGraphOptions& bopts, + const std::string& session, const BuildGraphOptions& bopts, const SessionOptions& sopts) { return std::unique_ptr(new NoOpStatsPublisher); } diff --git a/tensorflow/core/common_runtime/stats_publisher_interface.h b/tensorflow/core/common_runtime/stats_publisher_interface.h index 450683e643dc0c..2f0e3221be97cb 100644 --- a/tensorflow/core/common_runtime/stats_publisher_interface.h +++ b/tensorflow/core/common_runtime/stats_publisher_interface.h @@ -61,7 +61,7 @@ class StatsPublisherInterface { // // This method may return a null pointer, if no handler was created. virtual std::unique_ptr GetProfileHandler( - uint64 step, int64_t execution_count, const RunOptions& ropts) = 0; + uint64_t step, int64_t execution_count, const RunOptions& ropts) = 0; virtual ~StatsPublisherInterface() {} @@ -77,7 +77,7 @@ class StatsPublisherInterface { }; std::unique_ptr CreateNoOpStatsPublisher( - const string& session, const BuildGraphOptions& bopts, + const std::string& session, const BuildGraphOptions& bopts, const SessionOptions& sopts); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 03fe4d946bdb0d..cc32e668309402 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -36,7 +36,7 @@ const int kMaxAllocReportNodes = 100; const float kMaxAllocReportFraction = 0.99; struct AllocStats { - std::map> nodes_by_size; + std::map> nodes_by_size; int64_t total_bytes = 0; int64_t total_nodes = 0; }; @@ -65,39 +65,39 @@ NodeExecStatsWrapper::NodeExecStatsWrapper( node_(node), step_stats_collector_(step_stats_collector) {} -void NodeExecStatsWrapper::Done(const string& device) { +void NodeExecStatsWrapper::Done(const std::string& device) { // TODO(tucker): merge with the DetailText function in session.cc in a common // location. 
DCHECK(node_); - string memory; + std::string memory; for (auto& all : stats_->memory()) { int64_t tot = all.total_bytes(); if (tot >= 0.1 * 1048576.0) { int64_t peak = all.peak_bytes(); if (peak > 0) { memory = - strings::StrCat(memory, "[", all.allocator_name(), - strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0, - peak / 1048576.0)); + absl::StrCat(memory, "[", all.allocator_name(), + strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0, + peak / 1048576.0)); } else { - memory = strings::StrCat(memory, "[", all.allocator_name(), - strings::Printf(" %.1fMB] ", tot / 1048576.0)); + memory = absl::StrCat(memory, "[", all.allocator_name(), + strings::Printf(" %.1fMB] ", tot / 1048576.0)); } } } const AttrSlice attrs(*node_); - string text; + std::string text; if (IsSend(node_)) { - string tensor_name; + std::string tensor_name; TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); - string recv_device; + std::string recv_device; TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device)); text = strings::StrCat(memory, node_->name(), " = ", node_->op(), "(", tensor_name, " @", recv_device, ")"); } else if (IsRecv(node_)) { - string tensor_name; + std::string tensor_name; TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); - string send_device; + std::string send_device; TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device)); text = strings::StrCat(memory, node_->name(), " = ", node_->op(), "(", tensor_name, " @", send_device, ")"); @@ -197,7 +197,7 @@ void NodeExecStatsWrapper::Finalize() { StepStatsCollector::StepStatsCollector(StepStats* step_stats) : finalized_(false), step_stats_(step_stats) {} -static int ExtractGpuWithStreamAll(string device_name) { +static int ExtractGpuWithStreamAll(std::string device_name) { // Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp, // and if it does return the stream index (always positive). If it doesn't // return -1. @@ -220,7 +220,7 @@ static int ExtractGpuWithStreamAll(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture(capture); + std::string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(absl::SimpleAtoi(ordered_capture, &gpu_id)); @@ -228,7 +228,7 @@ static int ExtractGpuWithStreamAll(string device_name) { } } -static int ExtractGpuWithoutStream(string device_name) { +static int ExtractGpuWithoutStream(std::string device_name) { // Check if the device name matches the ".*gpu:(\\d+)$" regexp, // and if it does return the stream index (always positive). If it doesn't // return -1. @@ -249,7 +249,7 @@ static int ExtractGpuWithoutStream(string device_name) { } else { // Convert the captured string into an integer. 
But first we need to put // the digits back in order - string ordered_capture(capture); + std::string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(absl::SimpleAtoi(ordered_capture, &gpu_id)); @@ -259,7 +259,7 @@ static int ExtractGpuWithoutStream(string device_name) { void StepStatsCollector::BuildCostModel( CostModelManager* cost_model_manager, - const std::unordered_map& device_map) { + const std::unordered_map& device_map) { mutex_lock lock(mu_); if (!finalized_) { @@ -282,7 +282,7 @@ void StepStatsCollector::BuildCostModel( for (int i = 0; i < step_stats_->dev_stats_size(); ++i) { const DeviceStepStats& device_stats = step_stats_->dev_stats(i); - const string& device_name = device_stats.device(); + const std::string& device_name = device_stats.device(); const int gpu_id = ExtractGpuWithStreamAll(device_name); if (gpu_id >= 0) { // These are gpu hardware stats @@ -296,7 +296,7 @@ void StepStatsCollector::BuildCostModel( for (auto& itr : per_device_stats) { const absl::string_view device_name = itr.first; - const int gpu_id = ExtractGpuWithoutStream(string(device_name)); + const int gpu_id = ExtractGpuWithoutStream(std::string(device_name)); if (gpu_id >= 0) { // Reference the gpu hardware stats in addition to the regular stats // for this gpu device if they're available. @@ -324,10 +324,10 @@ void StepStatsCollector::BuildCostModel( const DeviceStats& dev_stats = per_device_stats.find(device)->second; - std::unordered_map name_to_hw_node_stats; + std::unordered_map name_to_hw_node_stats; if (dev_stats.hardware_stats) { for (const auto& node_stats : dev_stats.hardware_stats->node_stats()) { - string node_name = node_stats.node_name(); + std::string node_name = node_stats.node_name(); // Remove the part of op name (e.g. :Conv2D) in the end of a node name. 
size_t pos = node_name.find_first_of(':'); if (pos != std::string::npos) { @@ -368,7 +368,8 @@ void StepStatsCollector::BuildCostModel( cm->RecordMemoryStats(node, stats.memory_stats()); // Use hardware stats to record the execution time if they're available, // otherwise use the regular (less accurate) stats - string node_name = dev_stats.regular_stats->node_stats(i).node_name(); + std::string node_name = + dev_stats.regular_stats->node_stats(i).node_name(); if (dev_stats.hardware_stats && name_to_hw_node_stats.find(node_name) != name_to_hw_node_stats.end()) { const NodeExecStats& hw_stats = name_to_hw_node_stats[node_name]; @@ -383,14 +384,14 @@ void StepStatsCollector::BuildCostModel( } } -void StepStatsCollector::Save(const string& device, +void StepStatsCollector::Save(const std::string& device, NodeExecStats* node_stats_pb) { Save(device, new NodeExecStatsWrapper(std::unique_ptr(node_stats_pb), nullptr, this)); } -void StepStatsCollector::Save(const string& device, +void StepStatsCollector::Save(const std::string& device, NodeExecStatsWrapper* node_stats) { if (!node_stats) return; VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats(); @@ -410,9 +411,9 @@ void StepStatsCollector::Save(const string& device, } } -void StepStatsCollector::SaveThreadName(const string& device, - const uint32 thread_id, - const string& thread_name) { +void StepStatsCollector::SaveThreadName(const std::string& device, + const uint32_t thread_id, + const std::string& thread_name) { VLOG(1) << "Save dev " << device << " thread id " << thread_id << " name " << thread_name; { @@ -434,17 +435,17 @@ NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats( return new NodeExecStatsWrapper(node, this); } -string StepStatsCollector::ReportAllocsOnResourceExhausted( +std::string StepStatsCollector::ReportAllocsOnResourceExhausted( const absl::string_view err) { mutex_lock l(mu_); if (err.find("OOM") == err.npos) { return ""; } // -> AllocStats - std::map, AllocStats> allocs_map; - string report = "\n"; + std::map, AllocStats> allocs_map; + std::string report = "\n"; for (const auto& dev_stat : dev_stats_) { - const string& device = dev_stat.first; + const std::string& device = dev_stat.first; // Only print the device that has OOM. // TODO(xpan): Extract device from err first to speed it up. if (err.find(device) == err.npos) { @@ -490,7 +491,7 @@ string StepStatsCollector::ReportAllocsOnResourceExhausted( // Print allocations stats of the pair. for (auto it = dev_allocs_stats.nodes_by_size.rbegin(); it != dev_allocs_stats.nodes_by_size.rend(); ++it) { - for (const string& node_name : it->second) { + for (const std::string& node_name : it->second) { reported_bytes += it->first; strings::StrAppend(&report, " ", strings::HumanReadableNumBytes(it->first), " from ", @@ -532,7 +533,7 @@ void StepStatsCollector::FinalizeInternal() { return; } finalized_ = true; - std::map dev_stats_pb; + std::map dev_stats_pb; for (auto& ds : *step_stats_->mutable_dev_stats()) { dev_stats_pb[ds.device()] = &ds; } diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index 277630cd40f9de..1c3503a8101654 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -51,7 +51,7 @@ class NodeExecStatsInterface { // Called when the statistics collection for the node has finished. Once this // method is called, the caller should not make assumptions about the validity // of this object. 
- virtual void Done(const string& device) = 0; + virtual void Done(const std::string& device) = 0; // Called immediately after this node starts being processed by the executor. virtual void RecordExecutorStarted() = 0; @@ -101,7 +101,7 @@ class NodeExecStatsWrapper : public NodeExecStatsInterface { // Destructor calls Finalize() to release the TrackingAllocators. ~NodeExecStatsWrapper() override { Finalize(); } - void Done(const string& device) override; + void Done(const std::string& device) override; void RecordExecutorStarted() override; void RecordComputeStarted() override; void RecordComputeEnded() override; @@ -148,7 +148,8 @@ class StepStatsCollectorInterface { // `err` message needs to contain device name and allocator name, e.g.: // "ResourceExhaustedError: OOM when allocating tensor ... // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc" - virtual string ReportAllocsOnResourceExhausted(absl::string_view err) = 0; + virtual std::string ReportAllocsOnResourceExhausted( + absl::string_view err) = 0; }; // StepStatsCollector manages the collection of a StepStats object. @@ -164,19 +165,19 @@ class StepStatsCollector : public StepStatsCollectorInterface { // device_map. void BuildCostModel( CostModelManager* cost_model_manager, - const std::unordered_map& device_map); + const std::unordered_map& device_map); // Saves node statistics to the DeviceStats object associated with device. // Should be called before Finalize. - void Save(const string& device, NodeExecStats* node_stats_pb); - void Save(const string& device, NodeExecStatsWrapper* node_stats); + void Save(const std::string& device, NodeExecStats* node_stats_pb); + void Save(const std::string& device, NodeExecStatsWrapper* node_stats); // Saves thread name. - void SaveThreadName(const string& device, const uint32 thread_id, - const string& thread_name); + void SaveThreadName(const std::string& device, const uint32_t thread_id, + const std::string& thread_name); NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override; - string ReportAllocsOnResourceExhausted(absl::string_view err) override; + std::string ReportAllocsOnResourceExhausted(absl::string_view err) override; // The following 2 Finalize methods populate the StepStats passed // from the constructor. Calling it more than once won't have any effect. @@ -188,19 +189,21 @@ class StepStatsCollector : public StepStatsCollectorInterface { private: // TODO(suharshs): Make this configurable if its not possible to find a value // that works for all cases. 
- static constexpr uint64 kMaxCollectedNodes = 1 << 20; + static constexpr uint64_t kMaxCollectedNodes = 1 << 20; typedef std::vector> NodeStatsVector; - typedef std::unordered_map ThreadNamesMap; + typedef std::unordered_map ThreadNamesMap; void FinalizeInternal() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); mutex mu_; bool finalized_ TF_GUARDED_BY(mu_); - std::unordered_map dev_stats_ TF_GUARDED_BY(mu_); - std::unordered_map thread_names_ TF_GUARDED_BY(mu_); + std::unordered_map dev_stats_ + TF_GUARDED_BY(mu_); + std::unordered_map thread_names_ + TF_GUARDED_BY(mu_); StepStats* step_stats_ TF_GUARDED_BY(mu_); - uint64 collected_nodes_ TF_GUARDED_BY(mu_) = 0; + uint64_t collected_nodes_ TF_GUARDED_BY(mu_) = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/testlib_ops.cc b/tensorflow/core/common_runtime/testlib_ops.cc index 11970bee114128..d36ad0a20dc0a0 100644 --- a/tensorflow/core/common_runtime/testlib_ops.cc +++ b/tensorflow/core/common_runtime/testlib_ops.cc @@ -46,7 +46,7 @@ class ErrorOp : public OpKernel { } private: - string errmsg_; + std::string errmsg_; bool log_error_ = false; }; REGISTER_KERNEL_BUILDER(Name("Error").Device(DEVICE_CPU), ErrorOp); diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 23166b69540083..8ada1107f7f044 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -54,7 +54,7 @@ info. It does not have any negative impact on performance. */ namespace tensorflow { ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, - const string& name, Bytes memory_limit, + const std::string& name, Bytes memory_limit, const DeviceLocality& locality, Allocator* allocator) : LocalDevice(options, Device::BuildDeviceAttributes( diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h index 08175ccb1f231c..4e6c0b87935082 100644 --- a/tensorflow/core/common_runtime/threadpool_device.h +++ b/tensorflow/core/common_runtime/threadpool_device.h @@ -25,7 +25,7 @@ namespace tensorflow { // CPU device implementation. class ThreadPoolDevice : public LocalDevice { public: - ThreadPoolDevice(const SessionOptions& options, const string& name, + ThreadPoolDevice(const SessionOptions& options, const std::string& name, Bytes memory_limit, const DeviceLocality& locality, Allocator* allocator); ~ThreadPoolDevice() override; diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc index 3ac8ea5ae8b68c..a6756935e63e27 100644 --- a/tensorflow/core/common_runtime/threadpool_device_factory.cc +++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc @@ -29,14 +29,14 @@ namespace tensorflow { // TODO(zhifengc/tucker): Figure out the bytes of available RAM. 
class ThreadPoolDeviceFactory : public DeviceFactory { public: - absl::Status ListPhysicalDevices(std::vector* devices) override { + absl::Status ListPhysicalDevices(std::vector* devices) override { devices->push_back("/physical_device:CPU:0"); return absl::OkStatus(); } absl::Status CreateDevices( - const SessionOptions& options, const string& name_prefix, + const SessionOptions& options, const std::string& name_prefix, std::vector>* devices) override { int num_numa_nodes = port::NUMANumNodes(); int n = 1; @@ -45,7 +45,7 @@ class ThreadPoolDeviceFactory : public DeviceFactory { n = iter->second; } for (int i = 0; i < n; i++) { - string name = strings::StrCat(name_prefix, "/device:CPU:", i); + std::string name = absl::StrCat(name_prefix, "/device:CPU:", i); std::unique_ptr tpd; if (options.config.experimental().use_numa_affinity()) { int numa_node = i % num_numa_nodes; diff --git a/tensorflow/core/common_runtime/type_inference.cc b/tensorflow/core/common_runtime/type_inference.cc index 8239e1c2196767..0434c287f31a5e 100644 --- a/tensorflow/core/common_runtime/type_inference.cc +++ b/tensorflow/core/common_runtime/type_inference.cc @@ -125,7 +125,7 @@ absl::Status update_inferred_type(Node* target, const FullTypeDef& t, return absl::OkStatus(); } -absl::StatusOr run_inference(const string& fn_name, +absl::StatusOr run_inference(const std::string& fn_name, const TypeRefVector& in_types) { // TODO(b/224776031): Things remaining to implement: // * look up function by name diff --git a/tensorflow/core/common_runtime/type_inference_test.cc b/tensorflow/core/common_runtime/type_inference_test.cc index 068f81ea191ace..6f7a165e695326 100644 --- a/tensorflow/core/common_runtime/type_inference_test.cc +++ b/tensorflow/core/common_runtime/type_inference_test.cc @@ -60,7 +60,6 @@ TEST(TypeInferenceTest, BasicStraightline) { Node* ds; TensorShapeProto shape; - shape.mutable_dim(); shape.set_unknown_rank(false); TF_ASSERT_OK(NodeBuilder("ds", "RangeDataset", &root.graph()->flib_def()) .Input({NodeBuilder::NodeOut(start.node())}) @@ -100,7 +99,6 @@ TEST(TypeInferenceTest, CyclicGraphWithV1ControlFlow) { Node* ds; TensorShapeProto shape; - shape.mutable_dim(); shape.set_unknown_rank(false); TF_ASSERT_OK(NodeBuilder("ds", "RangeDataset", &root.graph()->flib_def()) .Input({NodeBuilder::NodeOut(start.node())}) @@ -443,7 +441,6 @@ TEST(ReverseTypeInferenceTest, BasicVDependency) { Node* ds; // This node has a type constructor. 
TensorShapeProto shape; - shape.mutable_dim(); shape.set_unknown_rank(false); TF_ASSERT_OK(NodeBuilder("ds", "RangeDataset", &root.graph()->flib_def()) .Input({NodeBuilder::NodeOut(start.node())}) @@ -491,7 +488,6 @@ TEST(ReverseTypeInferenceTest, FromUnsetType) { Node* it; TensorShapeProto shape; - shape.mutable_dim(); shape.set_unknown_rank(false); TF_ASSERT_OK( NodeBuilder("it", "AnonymousIteratorV2", &root.graph()->flib_def()) diff --git a/tensorflow/core/config/flag_defs.h b/tensorflow/core/config/flag_defs.h index d6bc4d9531173f..f257876ad907ad 100644 --- a/tensorflow/core/config/flag_defs.h +++ b/tensorflow/core/config/flag_defs.h @@ -69,6 +69,9 @@ class Flags { "graphs.") TF_DECLARE_FLAG(enable_graph_debug_info_caching_for_stack_frames, true, "If true, graph debug info will cache the stack frames.") + TF_DECLARE_FLAG( + enable_fatal_error_on_collective_abort, false, + "If true, a fatal error will be raised when a collective is aborted.") // LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc) }; diff --git a/tensorflow/core/config/flags_api_wrapper.cc b/tensorflow/core/config/flags_api_wrapper.cc index 9da0ba1b64b0b7..13b3fcb0d135bf 100644 --- a/tensorflow/core/config/flags_api_wrapper.cc +++ b/tensorflow/core/config/flags_api_wrapper.cc @@ -56,5 +56,6 @@ PYBIND11_MODULE(flags_pybind, m) { TF_PY_DECLARE_FLAG(enable_function_pruning_before_inlining) TF_PY_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs) TF_PY_DECLARE_FLAG(enable_graph_debug_info_caching_for_stack_frames) + TF_PY_DECLARE_FLAG(enable_fatal_error_on_collective_abort) // LINT.ThenChange(//tensorflow/core/config/flag_defs.h) }; diff --git a/tensorflow/core/data/compression_utils.cc b/tensorflow/core/data/compression_utils.cc index bd65978ab1da1f..cfc5e4b18208ef 100644 --- a/tensorflow/core/data/compression_utils.cc +++ b/tensorflow/core/data/compression_utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/data/compression_utils.h" #include +#include <cstdint> +#include <limits> #include #include @@ -123,7 +125,7 @@ absl::Status CompressElement(const std::vector<Tensor>& element, } } - if (iov.NumBytes() > kuint32max) { + if (iov.NumBytes() > std::numeric_limits<uint32_t>::max()) { return errors::OutOfRange("Encountered dataset element of size ", iov.NumBytes(), ", exceeding the 4GB Snappy limit."); diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 14f385dcaa48cb..a9300194089efe 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -493,7 +494,7 @@ bool MatchesAnyVersion(absl::string_view op_prefix, return true; } size_t index = op_to_match.length() - 1; - while (isdigit(op_to_match[index])) { + while (absl::ascii_isdigit(op_to_match[index])) { index--; } return (op_to_match[index] == 'V') && (op_prefix.length() == index); diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc index cf99ef7088c5ae..b8f0434adae02d 100644 --- a/tensorflow/core/data/dataset_utils_test.cc +++ b/tensorflow/core/data/dataset_utils_test.cc @@ -718,8 +718,6 @@ INSTANTIATE_TEST_SUITE_P(Test, GetOptimizationsTest, GetOptimizationTestCase4())); TEST(DeterministicOpsTest, GetOptimizations) { - // TODO(b/259305727): Re-enable for MacOS when the bug is fixed. -#if !defined(__APPLE__) tsl::test::DeterministicOpsScope det_scope; Options options; // options.deterministic should be ignored when deterministic ops are enabled. @@ -729,7 +727,6 @@ TEST(DeterministicOpsTest, GetOptimizations) { EXPECT_THAT(std::vector<tstring>(actual_enabled.begin(), actual_enabled.end()), ::testing::UnorderedElementsAreArray({"make_deterministic"})); EXPECT_EQ(actual_disabled.size(), 0); -#endif } REGISTER_DATASET_EXPERIMENT("test_only_experiment", diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index 36aa6492a22faa..1a79089fbccc0f 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -119,7 +119,7 @@ absl::Status DataServiceClient::Initialize( << " in tf.data service client."; dispatcher_ = std::make_unique<DataServiceDispatcherClient>(params_.address, params_.protocol); - int64_t deadline_micros = kint64max; + int64_t deadline_micros = std::numeric_limits<int64_t>::max(); std::optional<std::string> job_name; if (!params_.job_name.empty()) { job_name = params_.job_name; } @@ -668,7 +668,7 @@ void DataServiceClient::RunWorkerThread(std::function<void()> done) } VLOG(3) << "Processing task " << task_to_process->info.task_id(); } - int64_t deadline_micros = kint64max; + int64_t deadline_micros = std::numeric_limits<int64_t>::max(); absl::Status s = GetElementTraced(task_to_process.get(), deadline_micros, /*enqueue_result=*/!IsCoordinatedRead(), allow_skip, result); diff --git a/tensorflow/core/data/service/dispatcher_client.cc b/tensorflow/core/data/service/dispatcher_client.cc index 87608b858eedee..c06acb3e332ddf 100644 --- a/tensorflow/core/data/service/dispatcher_client.cc +++ b/tensorflow/core/data/service/dispatcher_client.cc @@ -379,9 +379,9 @@ absl::Status DataServiceDispatcherClient::DisableCompressionAtRuntime( } absl::Status DataServiceDispatcherClient::EnsureInitialized() { - return grpc_util::Retry([this] { return Initialize(); }, - "Initialize dispatcher client", - /*deadline_micros=*/kint64max); + return grpc_util::Retry( + [this] { return Initialize(); }, "Initialize dispatcher client", + /*deadline_micros=*/std::numeric_limits<int64_t>::max()); } } // namespace data diff --git a/tensorflow/core/data/service/task_runner.cc b/tensorflow/core/data/service/task_runner.cc index 2b85af5aa20b73..a4d82ede95d362 100644 --- a/tensorflow/core/data/service/task_runner.cc +++ b/tensorflow/core/data/service/task_runner.cc @@ -232,7 +232,7 @@ std::shared_ptr<model::Model> CachingTaskRunner::model() const {
RoundRobinTaskRunner::RoundRobinTaskRunner( std::unique_ptr<TaskIterator> iterator, int64_t num_consumers, - string worker_address) + std::string worker_address) : num_consumers_(num_consumers), worker_address_(worker_address), buffer_(num_consumers_), diff --git a/tensorflow/core/data/service/task_runner.h b/tensorflow/core/data/service/task_runner.h index 79d698f9edc65f..9f208b6bc0c35e 100644 --- a/tensorflow/core/data/service/task_runner.h +++ b/tensorflow/core/data/service/task_runner.h @@ -261,7 +261,7 @@ class PrefetchThread { class RoundRobinTaskRunner : public TaskRunner { public: RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator, - int64_t num_consumers, string worker_address); + int64_t num_consumers, std::string worker_address); absl::Status GetNext(const GetElementRequest& req, GetElementResult& result) override; @@ -280,7 +280,7 @@ class RoundRobinTaskRunner : public TaskRunner { // start. absl::Status PrepareRound(const GetElementRequest& req); const int64_t num_consumers_; - const string worker_address_; + const std::string worker_address_; mutex mu_; bool cancelled_ TF_GUARDED_BY(mu_) = false; // Condition variable notified whenever we start a new round of round-robin. @@ -291,7 +291,7 @@ requests_ TF_GUARDED_BY(mu_); // Index of the first round we plan to serve. At startup, this is the minimum // of all requested element indices. - int64_t first_round_ TF_GUARDED_BY(mu_) = kint64max; + int64_t first_round_ TF_GUARDED_BY(mu_) = std::numeric_limits<int64_t>::max(); int64_t current_round_ TF_GUARDED_BY(mu_) = -1; bool round_skipped_ TF_GUARDED_BY(mu_) = false; // Buffered results for the current round. diff --git a/tensorflow/core/data/service/thread_safe_buffer_test.cc b/tensorflow/core/data/service/thread_safe_buffer_test.cc index ea4008b3886dde..b486a078cf92cc 100644 --- a/tensorflow/core/data/service/thread_safe_buffer_test.cc +++ b/tensorflow/core/data/service/thread_safe_buffer_test.cc @@ -167,7 +167,7 @@ TEST_P(ThreadSafeBufferTest, BlockWriterWhenBufferIsFull) { ASSERT_THAT(buffer.Push(Tensor("Test tensor")), absl_testing::IsOk()); } - uint64 push_time = 0; + uint64_t push_time = 0; auto thread = absl::WrapUnique(Env::Default()->StartThread( /*thread_options=*/{}, /*name=*/"writer_thread", [&buffer, &push_time]() { ASSERT_THAT(buffer.Push(Tensor("Test tensor")), absl_testing::IsOk()); @@ -176,7 +176,7 @@ // Popping an element unblocks the `Push` call. Env::Default()->SleepForMicroseconds(10000); - uint64 pop_time = Env::Default()->NowMicros(); + uint64_t pop_time = Env::Default()->NowMicros(); ASSERT_THAT(buffer.Pop(), absl_testing::IsOk()); thread.reset(); EXPECT_LE(pop_time, push_time); diff --git a/tensorflow/core/data/service/utils.cc b/tensorflow/core/data/service/utils.cc index 4f79b9384de3b7..c4a1ea3dfd351a 100644 --- a/tensorflow/core/data/service/utils.cc +++ b/tensorflow/core/data/service/utils.cc @@ -44,7 +44,7 @@ absl::Status ReadDatasetDef(const std::string& path, DatasetDef& dataset_def) { std::unique_ptr<RandomAccessFile> file; TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(path, &file)); io::RecordReader reader(file.get()); - uint64 offset = 0; + uint64_t offset = 0; tstring record; TF_RETURN_IF_ERROR(reader.ReadRecord(&offset, &record)); if (!dataset_def.ParseFromString(record)) { diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index c89c8a1c4881f4..f5978a573b24e4 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/data/service/worker_impl.h" #include +#include <limits> #include #include #include @@ -183,9 +184,9 @@ absl::Status DataServiceWorkerImpl::Start( mutex_lock l(mu_); return !cancelled_; }; - TF_RETURN_IF_ERROR(grpc_util::Retry([this]() { return Heartbeat(); }, - should_retry, "Worker heartbeat.", - /*deadline_micros=*/kint64max)); + TF_RETURN_IF_ERROR(grpc_util::Retry( + [this]() { return Heartbeat(); }, should_retry, "Worker heartbeat.", + /*deadline_micros=*/std::numeric_limits<int64_t>::max())); LOG(INFO) << "Worker registered with dispatcher running at " << config_.dispatcher_address() << ". Worker config: " << config_.DebugString(); @@ -248,10 +249,10 @@ DataServiceWorkerImpl::CreateDispatcherClient() const TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); return !cancelled_; }; - TF_RETURN_IF_ERROR( - grpc_util::Retry([&dispatcher]() { return dispatcher->Initialize(); }, - should_retry, "Initialize dispatcher client.", - /*deadline_micros=*/kint64max)); + TF_RETURN_IF_ERROR(grpc_util::Retry( + [&dispatcher]() { return dispatcher->Initialize(); }, should_retry, + "Initialize dispatcher client.", + /*deadline_micros=*/std::numeric_limits<int64_t>::max())); return dispatcher; } diff --git a/tensorflow/core/data/snapshot_utils.cc b/tensorflow/core/data/snapshot_utils.cc index 576cbed01fb633..9946b3dc213020 100644 --- a/tensorflow/core/data/snapshot_utils.cc +++ b/tensorflow/core/data/snapshot_utils.cc @@ -710,7 +710,7 @@ absl::Status Reader::MakeNestedDataset( datasets.push_back( new Dataset(DatasetContext(DatasetContext::Params( {"SnapshotDatasetReader", - strings::StrCat("SnapshotDatasetReader/_", i)})), + absl::StrCat("SnapshotDatasetReader/_", i)})), shard_dirs.at(i), compression_type, version, dtypes, shapes, dataset_start_index)); datasets.back()->Initialize(/*metadata=*/{}); diff --git a/tensorflow/core/data/snapshot_utils.h b/tensorflow/core/data/snapshot_utils.h index f083cbe495fa72..1af57f169efd97 100644 --- a/tensorflow/core/data/snapshot_utils.h +++ b/tensorflow/core/data/snapshot_utils.h @@ -65,10 +65,10 @@ constexpr char kShardDirectorySuffix[] = ".shard"; enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; // Returns the name of the "hash" directory for the given base path and hash ID.
-std::string HashDirectory(const std::string& path, uint64 hash); +std::string HashDirectory(const std::string& path, uint64_t hash); // Returns the name of the "run" directory for the given base path and run ID. -std::string RunDirectory(const std::string& hash_directory, uint64 run_id); +std::string RunDirectory(const std::string& hash_directory, uint64_t run_id); std::string RunDirectory(const std::string& hash_directory, const std::string& run_id); @@ -78,7 +78,7 @@ std::string ShardDirectory(const std::string& run_directory, int64_t shard_id); // Returns the checkpoint file name for the given directory and checkpoint ID. std::string GetCheckpointFileName(const std::string& shard_directory, - uint64 checkpoint_id); + uint64_t checkpoint_id); // This is a interface class that exposes snapshot writing functionality. class Writer { @@ -132,7 +132,7 @@ class TFRecordWriter : public Writer { // Writes snapshot with a custom (legacy) file format. class CustomWriter : public Writer { public: - static constexpr const size_t kHeaderSize = sizeof(uint64); + static constexpr const size_t kHeaderSize = sizeof(uint64_t); static constexpr const char* const kClassName = "SnapshotWriter"; static constexpr const char* const kWriteStringPiece = "WriteStringPiece"; @@ -210,7 +210,7 @@ class Reader { // the `version`, `compression_type`, and `dtypes` arguments passed into // `Writer` and `Reader` must be the same for the reading to succeed. static absl::Status Create(Env* env, const std::string& filename, - const string& compression_type, int version, + const std::string& compression_type, int version, const DataTypeVector& dtypes, std::unique_ptr* out_reader); @@ -221,7 +221,8 @@ class Reader { // contains all the elements written out to each individual snapshot file. static absl::Status MakeNestedDataset( Env* env, const std::vector& shard_dirs, - const string& compression_type, int version, const DataTypeVector& dtypes, + const std::string& compression_type, int version, + const DataTypeVector& dtypes, const std::vector& shapes, int64_t start_index, DatasetBase** output); @@ -253,7 +254,8 @@ class TFRecordReaderImpl { // tensorflow/compiler/xla/tsl/lib/io/compression.h. // `output_buffer_size` specifies the buffer size required by Snappy/Zlib // compression algorithms. Ignored if compression is not enabled. - TFRecordReaderImpl(const std::string& filename, const string& compression, + TFRecordReaderImpl(const std::string& filename, + const std::string& compression, std::optional output_buffer_size = std::nullopt); // Initializes the reader. Callers must initialize the reader before calling @@ -279,14 +281,14 @@ class TFRecordReaderImpl { uint64_t offset_ = 0; uint64_t bytes_read_ = 0; - const string compression_; + const std::string compression_; const std::optional output_buffer_size_; }; // Reads snapshots previously written with `TFRecordWriter`. class TFRecordReader : public Reader { public: - TFRecordReader(const std::string& filename, const string& compression, + TFRecordReader(const std::string& filename, const std::string& compression, const DataTypeVector& dtypes, std::optional output_buffer_size = std::nullopt) : reader_impl_(filename, compression, output_buffer_size), @@ -321,14 +323,14 @@ class CustomReader : public Reader { // TODO(b/148804377): Set this in a smarter fashion. 
static constexpr const int64_t kSnappyReaderOutputBufferSizeBytes = 32 << 20; // 32 MiB - static constexpr const size_t kHeaderSize = sizeof(uint64); + static constexpr const size_t kHeaderSize = sizeof(uint64_t); static constexpr const char* const kClassName = "SnapshotReader"; static constexpr const char* const kReadString = "ReadString"; static constexpr const char* const kReadCord = "ReadCord"; static constexpr const char* const kSeparator = "::"; - CustomReader(const std::string& filename, const string& compression_type, + CustomReader(const std::string& filename, const std::string& compression_type, int version, const DataTypeVector& dtypes); absl::Status ReadTensors(std::vector* read_tensors) override; @@ -356,7 +358,7 @@ class CustomReader : public Reader { std::string filename_; std::unique_ptr file_; std::unique_ptr input_stream_; - const string compression_type_; + const std::string compression_type_; const int version_; const DataTypeVector dtypes_; int num_simple_ = 0; @@ -366,18 +368,18 @@ class CustomReader : public Reader { // Writes snapshot metadata to the given directory. absl::Status WriteMetadataFile( - Env* env, const string& dir, + Env* env, const std::string& dir, const experimental::SnapshotMetadataRecord* metadata); // Writes distributed snapshot metadata to the given directory. An error is // returned if `dir` is unable to be created or if `metadata` is unable to be // written. absl::Status WriteMetadataFile( - Env* env, const string& dir, + Env* env, const std::string& dir, const experimental::DistributedSnapshotMetadata* metadata); // Reads snapshot metadata from the given directory. -absl::Status ReadMetadataFile(Env* env, const string& dir, +absl::Status ReadMetadataFile(Env* env, const std::string& dir, experimental::SnapshotMetadataRecord* metadata, bool* file_exists); @@ -386,17 +388,17 @@ absl::Status ReadMetadataFile(Env* env, const string& dir, // returned. If the file exists in `dir` but is unable to be opened, an error // is returned. absl::Status ReadMetadataFile( - Env* env, const string& dir, + Env* env, const std::string& dir, experimental::DistributedSnapshotMetadata* metadata, bool* file_exists); // Writes a dataset graph to the given directory. -absl::Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, +absl::Status DumpDatasetGraph(Env* env, const std::string& path, uint64_t hash, const GraphDef* graph); absl::Status DetermineOpState( const std::string& mode_string, bool file_exists, const experimental::SnapshotMetadataRecord* metadata, - uint64 pending_snapshot_expiry_seconds, Mode* mode); + uint64_t pending_snapshot_expiry_seconds, Mode* mode); // Represents a dataset element or EOF. struct ElementOrEOF { @@ -420,9 +422,9 @@ struct ElementOrEOF { class AsyncWriter { public: explicit AsyncWriter(Env* env, int64_t file_index, - const std::string& shard_directory, uint64 checkpoint_id, - const std::string& compression, int64_t version, - const DataTypeVector& output_types, + const std::string& shard_directory, + uint64_t checkpoint_id, const std::string& compression, + int64_t version, const DataTypeVector& output_types, std::function done); // Writes the given tensors. 
The method is non-blocking and returns without @@ -437,7 +439,7 @@ class AsyncWriter { void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_); bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); absl::Status WriterThread(Env* env, const std::string& shard_directory, - uint64 checkpoint_id, + uint64_t checkpoint_id, const std::string& compression, int64_t version, DataTypeVector output_types); diff --git a/tensorflow/core/data/split_utils.cc b/tensorflow/core/data/split_utils.cc index 5248e6370781b6..44eda7649af7cc 100644 --- a/tensorflow/core/data/split_utils.cc +++ b/tensorflow/core/data/split_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -82,7 +83,7 @@ absl::Status IndexSplitProvider::Restore( int64_t IndexSplitProvider::Cardinality() const { // RandomDataset uses kint64max to simulate infinite splits. // See RandomDatasetOp::Dataset::MakeSplitProviders. - if (n_ == tsl::kint64max) { + if (n_ == std::numeric_limits::max()) { return kInfiniteCardinality; } return n_; diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc index 8052db9883010d..74c87c5103e21b 100644 --- a/tensorflow/core/data/standalone.cc +++ b/tensorflow/core/data/standalone.cc @@ -162,7 +162,7 @@ absl::Status Dataset::FromGraph(Params params, const GraphDef& graph_def, return absl::OkStatus(); }}); - string fetch_node = ""; + std::string fetch_node = ""; for (const auto& node : graph_def.node()) { if (node.op() == "_Retval") { fetch_node = node.input(0); diff --git a/tensorflow/core/data/standalone_test.cc b/tensorflow/core/data/standalone_test.cc index b60b4e752492e9..aa8e7259b5d90b 100644 --- a/tensorflow/core/data/standalone_test.cc +++ b/tensorflow/core/data/standalone_test.cc @@ -514,7 +514,7 @@ constexpr const char* const kMapGraphNoAutotuneProto = R"pb( TEST(Scalar, Standalone) { struct TestCase { - string graph_string; + std::string graph_string; std::vector expected_outputs; }; auto test_cases = { diff --git a/tensorflow/core/data/stats_utils.cc b/tensorflow/core/data/stats_utils.cc index 80c1de6dbd5576..12c12b8907994e 100644 --- a/tensorflow/core/data/stats_utils.cc +++ b/tensorflow/core/data/stats_utils.cc @@ -33,40 +33,40 @@ ABSL_CONST_INIT const char kFeaturesCount[] = "features_count"; ABSL_CONST_INIT const char kFeatureValuesCount[] = "feature_values_count"; ABSL_CONST_INIT const char kExamplesCount[] = "examples_count"; -string ExecutionTimeHistogramName(const string& prefix) { - return strings::StrCat(prefix, kDelimiter, kExecutionTime); +std::string ExecutionTimeHistogramName(const std::string& prefix) { + return absl::StrCat(prefix, kDelimiter, kExecutionTime); } -string ThreadUtilizationScalarName(const string& prefix) { - return strings::StrCat(prefix, kDelimiter, kThreadUtilization); +std::string ThreadUtilizationScalarName(const std::string& prefix) { + return absl::StrCat(prefix, kDelimiter, kThreadUtilization); } -string BufferSizeScalarName(const string& prefix) { - return strings::StrCat(prefix, kDelimiter, kBufferSize); +std::string BufferSizeScalarName(const std::string& prefix) { + return absl::StrCat(prefix, kDelimiter, kBufferSize); } -string BufferCapacityScalarName(const string& prefix) { - return strings::StrCat(prefix, kDelimiter, kBufferCapacity); +std::string BufferCapacityScalarName(const std::string& prefix) { + return absl::StrCat(prefix, kDelimiter, kBufferCapacity); } -string BufferUtilizationHistogramName(const string& prefix) { - return strings::StrCat(prefix, 
diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc
index 8052db9883010d..74c87c5103e21b 100644
--- a/tensorflow/core/data/standalone.cc
+++ b/tensorflow/core/data/standalone.cc
@@ -162,7 +162,7 @@ absl::Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
         return absl::OkStatus();
       }});
 
-  string fetch_node = "";
+  std::string fetch_node = "";
   for (const auto& node : graph_def.node()) {
     if (node.op() == "_Retval") {
       fetch_node = node.input(0);
diff --git a/tensorflow/core/data/standalone_test.cc b/tensorflow/core/data/standalone_test.cc
index b60b4e752492e9..aa8e7259b5d90b 100644
--- a/tensorflow/core/data/standalone_test.cc
+++ b/tensorflow/core/data/standalone_test.cc
@@ -514,7 +514,7 @@ constexpr const char* const kMapGraphNoAutotuneProto = R"pb(
 
 TEST(Scalar, Standalone) {
   struct TestCase {
-    string graph_string;
+    std::string graph_string;
     std::vector<Tensor> expected_outputs;
   };
   auto test_cases = {
diff --git a/tensorflow/core/data/stats_utils.cc b/tensorflow/core/data/stats_utils.cc
index 80c1de6dbd5576..12c12b8907994e 100644
--- a/tensorflow/core/data/stats_utils.cc
+++ b/tensorflow/core/data/stats_utils.cc
@@ -33,40 +33,40 @@ ABSL_CONST_INIT const char kFeaturesCount[] = "features_count";
 ABSL_CONST_INIT const char kFeatureValuesCount[] = "feature_values_count";
 ABSL_CONST_INIT const char kExamplesCount[] = "examples_count";
 
-string ExecutionTimeHistogramName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kExecutionTime);
+std::string ExecutionTimeHistogramName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kExecutionTime);
 }
 
-string ThreadUtilizationScalarName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kThreadUtilization);
+std::string ThreadUtilizationScalarName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kThreadUtilization);
 }
 
-string BufferSizeScalarName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kBufferSize);
+std::string BufferSizeScalarName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kBufferSize);
 }
 
-string BufferCapacityScalarName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kBufferCapacity);
+std::string BufferCapacityScalarName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kBufferCapacity);
 }
 
-string BufferUtilizationHistogramName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kBufferUtilization);
+std::string BufferUtilizationHistogramName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kBufferUtilization);
 }
 
-string FilterdElementsScalarName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kFilteredElements);
+std::string FilterdElementsScalarName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kFilteredElements);
 }
 
-string DroppedElementsScalarName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kDroppedElements);
+std::string DroppedElementsScalarName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kDroppedElements);
 }
 
-string FeatureHistogramName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kFeaturesCount);
+std::string FeatureHistogramName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kFeaturesCount);
 }
 
-string FeatureValueHistogramName(const string& prefix) {
-  return strings::StrCat(prefix, kDelimiter, kFeatureValuesCount);
+std::string FeatureValueHistogramName(const std::string& prefix) {
+  return absl::StrCat(prefix, kDelimiter, kFeatureValuesCount);
 }
 
 }  // namespace stats_utils
diff --git a/tensorflow/core/data/stats_utils.h b/tensorflow/core/data/stats_utils.h
index 5fa1eae397b39e..22a40b9be963a7 100644
--- a/tensorflow/core/data/stats_utils.h
+++ b/tensorflow/core/data/stats_utils.h
@@ -33,33 +33,33 @@ extern const char kFeatureValuesCount[];
 extern const char kExamplesCount[];
 
 // Name for tf.data function execution time (in ns) histogram metrics.
-string ExecutionTimeHistogramName(const string& prefix);
+std::string ExecutionTimeHistogramName(const std::string& prefix);
 
 // Name for thread utilization (ratio of threads being used and maximum number
 // of threads allocated) scalar metrics.
-string ThreadUtilizationScalarName(const string& prefix);
+std::string ThreadUtilizationScalarName(const std::string& prefix);
 
 // Name for buffer size scalar metrics.
-string BufferSizeScalarName(const string& prefix);
+std::string BufferSizeScalarName(const std::string& prefix);
 
 // Name for buffer capacity (maximum allocated buffer size) scalar metrics.
-string BufferCapacityScalarName(const string& prefix);
+std::string BufferCapacityScalarName(const std::string& prefix);
 
 // Name for buffer utilization (ratio of buffer size and maximum allocated
 // buffer size.) histogram metrics.
-string BufferUtilizationHistogramName(const string& prefix);
+std::string BufferUtilizationHistogramName(const std::string& prefix);
 
 // Name for filtered elements scalar metrics.
-string FilterdElementsScalarName(const string& prefix);
+std::string FilterdElementsScalarName(const std::string& prefix);
 
 // Name for dropped elements scalar metrics.
-string DroppedElementsScalarName(const string& prefix);
+std::string DroppedElementsScalarName(const std::string& prefix);
 
 // Name for features count histogram metrics.
-string FeatureHistogramName(const string& prefix);
+std::string FeatureHistogramName(const std::string& prefix);
 
 // Name for feature-values count histogram metrics.
-string FeatureValueHistogramName(const string& prefix);
+std::string FeatureValueHistogramName(const std::string& prefix);
 
 }  // namespace stats_utils
 }  // namespace data
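
The stats_utils helpers above all follow one pattern: join the caller's prefix and a metric keyword with the kDelimiter separator. A standalone sketch of the names they compose, assuming kDelimiter is "::" as in the tf.data code; the prefix is illustrative:

    #include <iostream>
    #include <string>
    #include "absl/strings/str_cat.h"

    int main() {
      const std::string prefix = "ParallelMapDataset";  // hypothetical prefix
      // ExecutionTimeHistogramName(prefix) composes a name of this shape:
      std::cout << absl::StrCat(prefix, "::", "execution_time") << "\n";
      // -> ParallelMapDataset::execution_time
    }
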
diff --git a/tensorflow/core/data/unbounded_thread_pool.cc b/tensorflow/core/data/unbounded_thread_pool.cc
index 3ffbbd5c70569f..6d196322298612 100644
--- a/tensorflow/core/data/unbounded_thread_pool.cc
+++ b/tensorflow/core/data/unbounded_thread_pool.cc
@@ -59,7 +59,7 @@ class UnboundedThreadPool::LogicalThreadFactory : public ThreadFactory {
  public:
   explicit LogicalThreadFactory(UnboundedThreadPool* pool) : pool_(pool) {}
 
-  std::unique_ptr<Thread> StartThread(const string& name,
+  std::unique_ptr<Thread> StartThread(const std::string& name,
                                       std::function<void()> fn) override {
     auto done = std::make_shared<Notification>();
     pool_->ScheduleOnWorkQueue(std::move(fn), done);
diff --git a/tensorflow/core/data/unbounded_thread_pool.h b/tensorflow/core/data/unbounded_thread_pool.h
index 1b89024a8db86e..1046c8ad5e7bb7 100644
--- a/tensorflow/core/data/unbounded_thread_pool.h
+++ b/tensorflow/core/data/unbounded_thread_pool.h
@@ -35,9 +35,9 @@ namespace data {
 // `UnboundedWorkQueue`.
 class UnboundedThreadPool : public thread::ThreadPoolInterface {
  public:
-  UnboundedThreadPool(Env* env, const string& thread_name)
+  UnboundedThreadPool(Env* env, const std::string& thread_name)
       : unbounded_work_queue_(env, thread_name) {}
-  UnboundedThreadPool(Env* env, const string& thread_name,
+  UnboundedThreadPool(Env* env, const std::string& thread_name,
                       const ThreadOptions& thread_options)
       : unbounded_work_queue_(env, thread_name, thread_options) {}
   ~UnboundedThreadPool() override = default;
diff --git a/tensorflow/core/debug/bfc_dump_reader.cc b/tensorflow/core/debug/bfc_dump_reader.cc
index 9ff9dd9d474e7b..dbad9888c99caa 100644
--- a/tensorflow/core/debug/bfc_dump_reader.cc
+++ b/tensorflow/core/debug/bfc_dump_reader.cc
@@ -23,9 +23,9 @@ limitations under the License.
 #include "tensorflow/core/util/command_line_flags.h"
 
 namespace tensorflow {
-MemoryDump ReadDumpFile(const string& fname) {
+MemoryDump ReadDumpFile(const std::string& fname) {
   absl::Status status;
-  uint64 file_size = 0;
+  uint64_t file_size = 0;
   status = Env::Default()->GetFileSize(fname, &file_size);
   if (!status.ok()) {
     LOG(ERROR) << "Failed to get size of " << fname;
@@ -66,7 +66,7 @@ MemoryDump FilterByChunkType(MemoryDump md, const char chunk_type) {
   return filtered;
 }
 
-void PrintChunk(const MemChunk& mc, const uint64 ac_offset, bool freed_at,
+void PrintChunk(const MemChunk& mc, const uint64_t ac_offset, bool freed_at,
                 const int64_t total_bytes, int64_t* cumulative_bytes) {
   // A size class corresponding approximately to log base 100.
   int size_class = floor(0.5 * log10(static_cast<double>(mc.size())));
@@ -120,7 +120,7 @@ void PrintSortedChunks(
   chunks.reserve(md.chunk_size());
   int64_t total_bytes = 0;
   int64_t cumulative_bytes = 0;
-  uint64 max_action_count = 0;
+  uint64_t max_action_count = 0;
   for (auto& it : md.chunk()) {
     chunks.push_back(&it);
     total_bytes += it.size();
@@ -129,7 +129,7 @@ void PrintSortedChunks(
     }
   }
   sort(chunks.begin(), chunks.end(), compare);
-  uint64 last_end = 0;
+  uint64_t last_end = 0;
   for (int i = 0; i < chunks.size(); ++i) {
     const MemChunk* c = chunks[i];
     if (by_addr && i > 0 && last_end != c->address()) {
@@ -174,12 +174,12 @@ void PrintChunksBySize(const MemoryDump& md, bool by_age, bool freed_at) {
                     by_age, freed_at, false /*by_addr*/);
 }
 
-void PrintChunksByOpName(const MemoryDump& md, const string& op_name,
+void PrintChunksByOpName(const MemoryDump& md, const std::string& op_name,
                          bool by_age, bool freed_at) {
   printf("------------Chunks matching \"%s\":----------------------\n",
          op_name.c_str());
   MemoryDump filtered;
-  uint64 total_bytes = 0;
+  uint64_t total_bytes = 0;
   filtered.set_allocator_name(md.allocator_name());
   for (const auto& it : md.bin_summary()) {
     *filtered.add_bin_summary() = it;
@@ -203,7 +203,7 @@ void PrintChunksByOpName(const MemoryDump& md, const string& op_name,
 void PrintSizeHistory(const MemoryDump& md, bool by_age) {
   printf("------------Allocated Bytes by Action Count--------\n");
   printf("num snapshots: %d\n", md.snap_shot_size());
-  uint64 max_action_count = 0;
+  uint64_t max_action_count = 0;
   if (by_age) {
     for (auto& it : md.snap_shot()) {
       if (it.action_count() > max_action_count) {
diff --git a/tensorflow/core/debug/debug_callback_registry.cc b/tensorflow/core/debug/debug_callback_registry.cc
index 97967a3f040eca..5ee0d53d507624 100644
--- a/tensorflow/core/debug/debug_callback_registry.cc
+++ b/tensorflow/core/debug/debug_callback_registry.cc
@@ -28,20 +28,20 @@ DebugCallbackRegistry* DebugCallbackRegistry::singleton() {
   return instance_;
 }
 
-void DebugCallbackRegistry::RegisterCallback(const string& key,
+void DebugCallbackRegistry::RegisterCallback(const std::string& key,
                                              EventCallback callback) {
   mutex_lock lock(mu_);
   keyed_callback_[key] = std::move(callback);
 }
 
 DebugCallbackRegistry::EventCallback* DebugCallbackRegistry::GetCallback(
-    const string& key) {
+    const std::string& key) {
   mutex_lock lock(mu_);
   auto iter = keyed_callback_.find(key);
   return iter == keyed_callback_.end() ? nullptr : &iter->second;
 }
 
-void DebugCallbackRegistry::UnregisterCallback(const string& key) {
+void DebugCallbackRegistry::UnregisterCallback(const std::string& key) {
   mutex_lock lock(mu_);
   keyed_callback_.erase(key);
 }
diff --git a/tensorflow/core/debug/debug_callback_registry.h b/tensorflow/core/debug/debug_callback_registry.h
index 94b57401418eb9..c3cf8d665af9d9 100644
--- a/tensorflow/core/debug/debug_callback_registry.h
+++ b/tensorflow/core/debug/debug_callback_registry.h
@@ -45,14 +45,14 @@ class DebugCallbackRegistry {
   static DebugCallbackRegistry* singleton();
 
   // Returns the registered callback, or nullptr, for key.
-  EventCallback* GetCallback(const string& key);
+  EventCallback* GetCallback(const std::string& key);
 
   // Associates callback with key. This must be called by clients observing
   // nodes to be exported by this callback router before running a session.
-  void RegisterCallback(const string& key, EventCallback callback);
+  void RegisterCallback(const std::string& key, EventCallback callback);
 
   // Removes the callback associated with key.
-  void UnregisterCallback(const string& key);
+  void UnregisterCallback(const std::string& key);
 
  private:
  DebugCallbackRegistry();
@@ -61,7 +61,7 @@ class DebugCallbackRegistry {
   mutex mu_;
 
   // Maps debug_url keys to callbacks for routing observed tensors.
-  std::map<string, EventCallback> keyed_callback_ TF_GUARDED_BY(mu_);
+  std::map<std::string, EventCallback> keyed_callback_ TF_GUARDED_BY(mu_);
 
   static DebugCallbackRegistry* instance_;
 };
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index 10ee5a3d33b8ad..9b0fc5c517c170 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -30,8 +30,8 @@ namespace tensorflow {
 namespace {
 
 // TODO(cais): Switch to safe_strtob when available.
-absl::Status ParseBoolString(const string& bool_str, bool* bool_val) {
-  const string lower_bool_str = absl::AsciiStrToLower(bool_str);
+absl::Status ParseBoolString(const std::string& bool_str, bool* bool_val) {
+  const std::string lower_bool_str = absl::AsciiStrToLower(bool_str);
   if (lower_bool_str == "false" || lower_bool_str == "f" ||
       lower_bool_str == "0") {
     *bool_val = false;
@@ -60,15 +60,15 @@ absl::Status DebugNodeInserter::InsertNodes(
   }
 
   // Debug ops and URLs for wildcard node names (if any).
-  std::vector<string> default_debug_ops;
-  std::vector<string> default_debug_urls;
+  std::vector<std::string> default_debug_ops;
+  std::vector<std::string> default_debug_urls;
 
   // A map from tensor name (e.g., "node_a:0") to list of debug op names
   // (e.g., {"DebugIdentity", "DebugNanCount"})
-  std::unordered_map<string, std::vector<string>> tensor_watches;
+  std::unordered_map<std::string, std::vector<std::string>> tensor_watches;
   // A map from tensor name to debug_url.
-  std::unordered_map<string, std::vector<string>> tensor_watch_urls;
-  std::unordered_map<string, bool> tensor_tolerate_failures;
+  std::unordered_map<std::string, std::vector<std::string>> tensor_watch_urls;
+  std::unordered_map<std::string, bool> tensor_tolerate_failures;
 
   // Cache the proto content for fast lookup later
   for (const DebugTensorWatch& watch : watches) {
@@ -105,11 +105,11 @@ absl::Status DebugNodeInserter::InsertNodes(
       }
     }
 
-    string tensor_name =
+    std::string tensor_name =
         absl::StrCat(watch.node_name(), ":", watch.output_slot());
 
-    std::vector<string> debug_ops;
-    for (const string& debug_op : watch.debug_ops()) {
+    std::vector<std::string> debug_ops;
+    for (const std::string& debug_op : watch.debug_ops()) {
       debug_ops.push_back(debug_op);
     }
 
@@ -117,8 +117,8 @@ absl::Status DebugNodeInserter::InsertNodes(
     tensor_tolerate_failures[tensor_name] =
         watch.tolerate_debug_op_creation_failures();
 
-    std::vector<string> urls;
-    for (const string& url : watch.debug_urls()) {
+    std::vector<std::string> urls;
+    for (const std::string& url : watch.debug_urls()) {
       urls.push_back(url);
     }
     tensor_watch_urls[tensor_name] = urls;
@@ -148,7 +148,7 @@ absl::Status DebugNodeInserter::InsertNodes(
     // Iterate through all output slots of the node.
     for (int src_output_slot = 0; src_output_slot < src_node->num_outputs();
          ++src_output_slot) {
-      const string tensor_name =
+      const std::string tensor_name =
           absl::StrCat(src_node->name(), ":", src_output_slot);
       const bool explicit_tensor_match =
           tensor_watches.find(tensor_name) != tensor_watches.end();
@@ -176,10 +176,10 @@ absl::Status DebugNodeInserter::InsertNodes(
                                                  src_output_slot, &memory_type));
 
       // Create the copy node for the watched tensor.
-      const std::vector<string> debug_ops = explicit_tensor_match
-                                                ? tensor_watches[tensor_name]
-                                                : default_debug_ops;
-      const std::vector<string> debug_urls =
+      const std::vector<std::string> debug_ops =
+          explicit_tensor_match ? tensor_watches[tensor_name]
+                                : default_debug_ops;
+      const std::vector<std::string> debug_urls =
           explicit_tensor_match ? tensor_watch_urls[tensor_name]
                                 : default_debug_urls;
       Node* copy_node;
@@ -200,7 +200,7 @@ absl::Status DebugNodeInserter::InsertNodes(
       // Create all requested debug nodes and their edges to the Copy node.
       std::vector<Node*> debug_nodes;
       for (size_t i = 0; i < debug_ops.size(); ++i) {
-        const string& debug_op_name = debug_ops[i];
+        const std::string& debug_op_name = debug_ops[i];
 
         Node* debug_node;
         absl::Status debug_s = CreateDebugNode(
@@ -280,17 +280,17 @@ void DebugNodeInserter::DeparallelizeWhileLoops(Graph* graph, Device* device) {
 }
 
 // static
-const string DebugNodeInserter::GetCopyNodeName(const string& node_name,
-                                                const int output_slot) {
+const std::string DebugNodeInserter::GetCopyNodeName(
+    const std::string& node_name, const int output_slot) {
   // For example, if the watched node is named "node1" and the output slot
   // is 0, the debug node will be called: __copy_node1_0
   return absl::StrCat("__copy_", node_name, "_", output_slot);
 }
 
 // static
-const string DebugNodeInserter::GetDebugNodeName(const string& tensor_name,
-                                                 const int debug_op_num,
-                                                 const string& debug_op_name) {
+const std::string DebugNodeInserter::GetDebugNodeName(
+    const std::string& tensor_name, const int debug_op_num,
+    const std::string& debug_op_name) {
   // For example, if the watched node is named "node1" and the debug op that
   // watches the output slot of node1 is of the type "DebugNanCount", the
   // debug node will be called: __dbg_node1_0_0_DebugNanCount.
@@ -301,23 +301,24 @@ const string DebugNodeInserter::GetDebugNodeName(const string& tensor_name,
 // static
 absl::Status DebugNodeInserter::CreateCopyNode(
     Graph* graph, const DeviceType device_type, const bool is_host_memory,
-    const string& src_node_name, const int src_output, const DataType src_dt,
-    const string& tensor_name, const std::vector<string>& debug_ops,
-    const std::vector<string>& debug_urls, Node** copy_node) {
-  const string kGatedGrpcAttributeKey = "gated_grpc";
+    const std::string& src_node_name, const int src_output,
+    const DataType src_dt, const std::string& tensor_name,
+    const std::vector<std::string>& debug_ops,
+    const std::vector<std::string>& debug_urls, Node** copy_node) {
+  const std::string kGatedGrpcAttributeKey = "gated_grpc";
 
   NodeDef node_def;
   const KernelDef* kdef;
 
-  const string copy_op_name = is_host_memory ? "CopyHost" : "Copy";
-  const string copy_node_name = GetCopyNodeName(src_node_name, src_output);
+  const std::string copy_op_name = is_host_memory ? "CopyHost" : "Copy";
+  const std::string copy_node_name = GetCopyNodeName(src_node_name, src_output);
 
   // Cross debug_ops and debug_urls to get the list of debug ops and watches.
-  std::vector<string> debug_ops_spec;
-  for (const string& debug_op : debug_ops) {
-    for (const string& debug_url : debug_urls) {
-      string debug_op_name_proper;
-      std::unordered_map<string, string> custom_attributes;
+  std::vector<std::string> debug_ops_spec;
+  for (const std::string& debug_op : debug_ops) {
+    for (const std::string& debug_url : debug_urls) {
+      std::string debug_op_name_proper;
+      std::unordered_map<std::string, std::string> custom_attributes;
       TF_RETURN_IF_ERROR(ParseDebugOpName(debug_op, &debug_op_name_proper,
                                           &custom_attributes));
 
@@ -363,24 +364,25 @@ absl::Status DebugNodeInserter::CreateCopyNode(
 
 // static
 absl::Status DebugNodeInserter::ParseDebugOpName(
-    const string& debug_op_name, string* debug_op_name_proper,
-    std::unordered_map<string, string>* attributes) {
+    const std::string& debug_op_name, std::string* debug_op_name_proper,
+    std::unordered_map<std::string, std::string>* attributes) {
   const size_t l_index = debug_op_name.find('(');
   const size_t r_index = debug_op_name.find(')');
-  if (l_index == string::npos && r_index == string::npos) {
+  if (l_index == std::string::npos && r_index == std::string::npos) {
     *debug_op_name_proper = debug_op_name;
   } else {
-    if (l_index == string::npos || l_index == 0 ||
+    if (l_index == std::string::npos || l_index == 0 ||
         r_index != debug_op_name.size() - 1) {
       return absl::InvalidArgumentError(
           absl::StrCat("Malformed debug op name \"", debug_op_name, "\""));
     }
 
    *debug_op_name_proper = debug_op_name.substr(0, l_index);
-    string arguments = debug_op_name.substr(l_index + 1, r_index - l_index - 1);
+    std::string arguments =
+        debug_op_name.substr(l_index + 1, r_index - l_index - 1);
 
-    std::vector<string> attribute_segs = str_util::Split(arguments, ";");
-    for (const string& attribute_seg : attribute_segs) {
+    std::vector<std::string> attribute_segs = str_util::Split(arguments, ";");
+    for (const std::string& attribute_seg : attribute_segs) {
       absl::string_view seg(attribute_seg);
       str_util::RemoveWhitespaceContext(&seg);
       if (seg.empty()) {
@@ -388,13 +390,13 @@ absl::Status DebugNodeInserter::ParseDebugOpName(
       }
 
       const size_t eq_index = seg.find('=');
-      if (eq_index == string::npos) {
+      if (eq_index == std::string::npos) {
         return absl::InvalidArgumentError(absl::StrCat(
             "Malformed attributes in debug op name \"", debug_op_name, "\""));
       }
 
-      const string key(seg.substr(0, eq_index));
-      const string value(
+      const std::string key(seg.substr(0, eq_index));
+      const std::string value(
           seg.substr(eq_index + 1, attribute_seg.size() - eq_index - 1));
       if (key.empty() || value.empty()) {
         return absl::InvalidArgumentError(absl::StrCat(
@@ -415,17 +417,18 @@ absl::Status DebugNodeInserter::ParseDebugOpName(
 
 // static
 absl::Status DebugNodeInserter::SetDebugNodeAttributes(
-    Node* debug_node, const std::unordered_map<string, string>& attributes) {
-  std::unordered_set<string> unfulfilled_keys;
+    Node* debug_node,
+    const std::unordered_map<std::string, std::string>& attributes) {
+  std::unordered_set<std::string> unfulfilled_keys;
   for (const auto& item : attributes) {
     unfulfilled_keys.insert(item.first);
   }
 
   for (const auto& attr : debug_node->op_def().attr()) {
     if (attributes.find(attr.name()) != attributes.end()) {
-      const string& attr_value = attributes.at(attr.name());
+      const std::string& attr_value = attributes.at(attr.name());
       if (attr.type() == "string") {
-        debug_node->AddAttr<string>(attr.name(), attr_value);
+        debug_node->AddAttr<std::string>(attr.name(), attr_value);
       } else if (attr.type() == "float") {
        float float_value = 0.0;
        if (!absl::SimpleAtof(attr_value, &float_value)) {
@@ -472,19 +475,19 @@ absl::Status DebugNodeInserter::SetDebugNodeAttributes(
 
 // static
 absl::Status DebugNodeInserter::CreateDebugNode(
-    Graph* graph, const Device& device, const string& src_copy_node_name,
-    const DataType src_dt, const string& tensor_name,
-    const std::vector<string>& debug_urls, const int debug_op_num,
-    const string& debug_op_name, Node** debug_node) {
+    Graph* graph, const Device& device, const std::string& src_copy_node_name,
+    const DataType src_dt, const std::string& tensor_name,
+    const std::vector<std::string>& debug_urls, const int debug_op_num,
+    const std::string& debug_op_name, Node** debug_node) {
   NodeDef node_def;
   const KernelDef* kdef;
 
-  string debug_op_name_proper;
-  std::unordered_map<string, string> custom_attributes;
+  std::string debug_op_name_proper;
+  std::unordered_map<std::string, std::string> custom_attributes;
   TF_RETURN_IF_ERROR(ParseDebugOpName(debug_op_name, &debug_op_name_proper,
                                       &custom_attributes));
 
-  const string debug_node_name =
+  const std::string debug_node_name =
       GetDebugNodeName(tensor_name, debug_op_num, debug_op_name_proper);
   auto builder = NodeDefBuilder(debug_node_name, debug_op_name_proper)
                      .Input(src_copy_node_name, 0, src_dt)
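
Read together with the header comment in debug_graph_utils.h below, ParseDebugOpName accepts either a bare op name or a name with a parenthesized attribute list. For instance, parsing "DebugNumericSummary(mute_if_healthy=true; threshold=300.0)" yields the proper name "DebugNumericSummary" and the attribute map {mute_if_healthy: "true", threshold: "300.0"}; this is exactly the case exercised by the tests in debug_graph_utils_test.cc further down.
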
diff --git a/tensorflow/core/debug/debug_graph_utils.h b/tensorflow/core/debug/debug_graph_utils.h
index 27cfb357e2b9d9..9552becbe7b27c 100644
--- a/tensorflow/core/debug/debug_graph_utils.h
+++ b/tensorflow/core/debug/debug_graph_utils.h
@@ -82,20 +82,21 @@ class DebugNodeInserter {
   static void DeparallelizeWhileLoops(Graph* graph, Device* device);
 
   // Get canonical name of a copy node.
-  static const string GetCopyNodeName(const string& node_name,
-                                      const int output_slot);
+  static const std::string GetCopyNodeName(const std::string& node_name,
+                                           const int output_slot);
 
   // Get canonical name of a debug node.
-  static const string GetDebugNodeName(const string& tensor_name,
-                                       const int debug_op_num,
-                                       const string& debug_op_name);
+  static const std::string GetDebugNodeName(const std::string& tensor_name,
+                                            const int debug_op_num,
+                                            const std::string& debug_op_name);
 
  private:
  static absl::Status CreateCopyNode(
      Graph* graph, const DeviceType device_type, const bool is_host_memory,
-     const string& src_node_name, const int src_output, const DataType src_dt,
-     const string& tensor_name, const std::vector<string>& debug_ops,
-     const std::vector<string>& debug_urls, Node** copy_node);
+     const std::string& src_node_name, const int src_output,
+     const DataType src_dt, const std::string& tensor_name,
+     const std::vector<std::string>& debug_ops,
+     const std::vector<std::string>& debug_urls, Node** copy_node);
 
   // Parse the debug_op_name string to extract proper op name and attributes.
   // debug_op_name can be the proper op name only, e.g., "DebugNumericSummary".
   // It can also have the format
   // <op_name>(<attr_name_1>=<attr_value_1>;<attr_name_2>=<attr_value_2>;...),
   // in which the attribute name-value pairs are separated
   // with semicolons (";"), with optional whitespace in between, e.g.,
   // "DebugNumericSummary(mute_if_healthy=true, lower_bound=-100.0)".
@@ -104,17 +105,18 @@ class DebugNodeInserter {
   static absl::Status ParseDebugOpName(
-      const string& debug_op_name, string* debug_op_name_proper,
-      std::unordered_map<string, string>* attributes);
+      const std::string& debug_op_name, std::string* debug_op_name_proper,
+      std::unordered_map<std::string, std::string>* attributes);
 
   static absl::Status SetDebugNodeAttributes(
-      Node* debug_node, const std::unordered_map<string, string>& attributes);
+      Node* debug_node,
+      const std::unordered_map<std::string, std::string>& attributes);
 
   static absl::Status CreateDebugNode(
-      Graph* graph, const Device& device, const string& src_copy_node_name,
-      const DataType src_dt, const string& tensor_name,
-      const std::vector<string>& debug_urls, const int debug_op_num,
-      const string& debug_op_name, Node** debug_node);
+      Graph* graph, const Device& device, const std::string& src_copy_node_name,
+      const DataType src_dt, const std::string& tensor_name,
+      const std::vector<std::string>& debug_urls, const int debug_op_num,
+      const std::string& debug_op_name, Node** debug_node);
 
   // TODO(cais): Cut down the number of args to this method.
   friend class DebugGraphUtilsTest;
diff --git a/tensorflow/core/debug/debug_graph_utils_test.cc b/tensorflow/core/debug/debug_graph_utils_test.cc
index 207b8bc1b3c1f7..d1184d5d18c498 100644
--- a/tensorflow/core/debug/debug_graph_utils_test.cc
+++ b/tensorflow/core/debug/debug_graph_utils_test.cc
@@ -25,16 +25,16 @@ namespace tensorflow {
 class DebugGraphUtilsTest : public ::testing::Test {
  protected:
   absl::Status ParseDebugOpName(
-      const string& debug_op_name, string* debug_op_name_proper,
-      std::unordered_map<string, string>* attributes) {
+      const std::string& debug_op_name, std::string* debug_op_name_proper,
+      std::unordered_map<std::string, std::string>* attributes) {
     return DebugNodeInserter::ParseDebugOpName(
         debug_op_name, debug_op_name_proper, attributes);
   }
 };
 
 TEST_F(DebugGraphUtilsTest, TestParseNoAttributeDebugOpName) {
-  string debug_op_name_proper;
-  std::unordered_map<string, string> attributes;
+  std::string debug_op_name_proper;
+  std::unordered_map<std::string, std::string> attributes;
   TF_ASSERT_OK(
       ParseDebugOpName("DebugIdentity", &debug_op_name_proper, &attributes));
   ASSERT_EQ("DebugIdentity", debug_op_name_proper);
@@ -42,8 +42,8 @@ TEST_F(DebugGraphUtilsTest, TestParseNoAttributeDebugOpName) {
 }
 
 TEST_F(DebugGraphUtilsTest, TestMalformedDebugOpName) {
-  string debug_op_name_proper;
-  std::unordered_map<string, string> attributes;
+  std::string debug_op_name_proper;
+  std::unordered_map<std::string, std::string> attributes;
 
   absl::Status s = ParseDebugOpName("(mute_if_healthy=true)",
                                     &debug_op_name_proper, &attributes);
@@ -59,8 +59,8 @@ TEST_F(DebugGraphUtilsTest, TestMalformedDebugOpName) {
 }
 
 TEST_F(DebugGraphUtilsTest, TestDebugOpNameWithMalformedAttributes) {
-  string debug_op_name_proper;
-  std::unordered_map<string, string> attributes;
+  std::string debug_op_name_proper;
+  std::unordered_map<std::string, std::string> attributes;
 
   absl::Status s = ParseDebugOpName("DebugNumericSummary(=)",
                                     &debug_op_name_proper, &attributes);
@@ -89,8 +89,8 @@ TEST_F(DebugGraphUtilsTest, TestDebugOpNameWithMalformedAttributes) {
 }
 
 TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithSingleAttribute) {
-  string debug_op_name_proper;
-  std::unordered_map<string, string> attributes;
+  std::string debug_op_name_proper;
+  std::unordered_map<std::string, std::string> attributes;
 
   TF_ASSERT_OK(ParseDebugOpName("DebugNumericSummary()", &debug_op_name_proper,
                                 &attributes));
@@ -106,8 +106,8 @@ TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithSingleAttribute) {
 }
 
 TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithMoreThanOneAttributes) {
-  string debug_op_name_proper;
-  std::unordered_map<string, string> attributes;
+  std::string debug_op_name_proper;
+  std::unordered_map<std::string, std::string> attributes;
 
   TF_ASSERT_OK(ParseDebugOpName(
"DebugNumericSummary(mute_if_healthy=true; threshold=300.0)", &debug_op_name_proper, &attributes)); @@ -128,8 +128,8 @@ TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithMoreThanOneAttributes) { } TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithMoreDuplicateAttributes) { - string debug_op_name_proper; - std::unordered_map attributes; + std::string debug_op_name_proper; + std::unordered_map attributes; absl::Status s = ParseDebugOpName( "DebugNumericSummary(mute_if_healthy=true; lower_bound=3; " "mute_if_healthy=false;)", @@ -138,8 +138,8 @@ TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithMoreDuplicateAttributes) { } TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithWhitespaceInAttributes) { - string debug_op_name_proper; - std::unordered_map attributes; + std::string debug_op_name_proper; + std::unordered_map attributes; TF_ASSERT_OK(ParseDebugOpName( "DebugNumericSummary( mute_if_healthy=true; threshold=300.0 )", diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc index f6618666101361..19c79a04d2123d 100644 --- a/tensorflow/core/debug/debug_grpc_io_utils_test.cc +++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc @@ -34,7 +34,7 @@ class GrpcDebugTest : public ::testing::Test { protected: struct ServerData { int port; - string url; + std::string url; std::unique_ptr server; std::unique_ptr thread_pool; }; @@ -86,7 +86,7 @@ TEST_F(GrpcDebugTest, ConnectionTimeoutWorks) { SetChannelConnectionTimeoutMicros(kShortTimeoutMicros); ASSERT_EQ(kShortTimeoutMicros, GetChannelConnectionTimeoutMicros()); - const string& kInvalidGrpcUrl = + const std::string& kInvalidGrpcUrl = absl::StrCat("grpc://localhost:", testing::PickUnusedPortOrDie()); Tensor tensor(DT_FLOAT, TensorShape({1, 1})); tensor.flat()(0) = 42.0; @@ -98,10 +98,11 @@ TEST_F(GrpcDebugTest, ConnectionTimeoutWorks) { TF_ASSERT_OK(DebugIO::CloseDebugURL(kInvalidGrpcUrl)); ASSERT_FALSE(publish_status.ok()); - const string expected_error_msg = strings::StrCat( + const std::string expected_error_msg = strings::StrCat( "Failed to connect to gRPC channel at ", kInvalidGrpcUrl.substr(7), " within a timeout of ", kShortTimeoutMicros / 1e6, " s"); - ASSERT_NE(string::npos, publish_status.message().find(expected_error_msg)); + ASSERT_NE(std::string::npos, + publish_status.message().find(expected_error_msg)); } TEST_F(GrpcDebugTest, ConnectionToDelayedStartingServerWorks) { @@ -149,7 +150,7 @@ TEST_F(GrpcDebugTest, SendSingleDebugTensorViaGrpcTest) { TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) { Tensor tensor(DT_STRING, TensorShape({1, 1})); - tensor.flat()(0) = string(5000 * 1024, 'A'); + tensor.flat()(0) = std::string(5000 * 1024, 'A'); const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, "DebugIdentity"); const absl::Status status = DebugIO::PublishDebugTensor( @@ -158,14 +159,14 @@ TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) { ASSERT_NE(status.message().find("string value at index 0 from debug " "node foo_tensor:0:DebugIdentity does " "not fit gRPC message size limit"), - string::npos); + std::string::npos); TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url)); } TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex1ViaGrpcTest) { Tensor tensor(DT_STRING, TensorShape({1, 2})); tensor.flat()(0) = "A"; - tensor.flat()(1) = string(5000 * 1024, 'A'); + tensor.flat()(1) = std::string(5000 * 1024, 'A'); const DebugNodeKey 
kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, "DebugIdentity"); const absl::Status status = DebugIO::PublishDebugTensor( @@ -174,7 +175,7 @@ TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex1ViaGrpcTest) { ASSERT_NE(status.message().find("string value at index 1 from debug " "node foo_tensor:0:DebugIdentity does " "not fit gRPC message size limit"), - string::npos); + std::string::npos); TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url)); } @@ -197,7 +198,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) { int tensor_count TF_GUARDED_BY(mu) = 0; std::vector statuses TF_GUARDED_BY(mu); - const std::vector urls({server_data_.url}); + const std::vector urls({server_data_.url}); // Set up the concurrent tasks of sending Tensors via an Event stream to the // server. @@ -210,7 +211,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) { } // Different concurrent tasks will send different tensors. - const uint64 wall_time = Env::Default()->NowMicros(); + const uint64_t wall_time = Env::Default()->NowMicros(); absl::Status publish_status = DebugIO::PublishDebugTensor( DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", absl::StrCat("synchronized_node_", this_count), 0, @@ -247,7 +248,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) { // One prep tensor plus kSends concurrent tensors are expected. ASSERT_EQ(kSends, server_data_.server->node_names.size()); for (size_t i = 0; i < server_data_.server->node_names.size(); ++i) { - std::vector items = + std::vector items = str_util::Split(server_data_.server->node_names[i], '_'); int tensor_index; strings::safe_strto32(items[2], &tensor_index); @@ -267,10 +268,10 @@ TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUsingGrpcGating) { Tensor tensor(DT_INT32, TensorShape({1, 1})); tensor.flat()(0) = 42; - const std::vector urls({server_data_.url}); + const std::vector urls({server_data_.url}); for (int i = 0; i < 3; ++i) { server_data_.server->ClearReceivedDebugData(); - const uint64 wall_time = Env::Default()->NowMicros(); + const uint64_t wall_time = Env::Default()->NowMicros(); // On the 1st send (i == 0), gating is disabled, so data should be sent. // On the 2nd send (i == 1), gating is enabled, and the server has enabled @@ -315,10 +316,10 @@ TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUnderReadWriteMode) { Tensor tensor(DT_INT32, TensorShape({1, 1})); tensor.flat()(0) = 42; - const std::vector urls({server_data_.url}); + const std::vector urls({server_data_.url}); for (int i = 0; i < 3; ++i) { server_data_.server->ClearReceivedDebugData(); - const uint64 wall_time = Env::Default()->NowMicros(); + const uint64_t wall_time = Env::Default()->NowMicros(); // On the 1st send (i == 0), gating is disabled, so data should be sent. 
     // On the 2nd send (i == 1), gating is enabled, and the server has enabled
@@ -315,10 +316,10 @@ TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUnderReadWriteMode) {
   Tensor tensor(DT_INT32, TensorShape({1, 1}));
   tensor.flat<int32>()(0) = 42;
 
-  const std::vector<string> urls({server_data_.url});
+  const std::vector<std::string> urls({server_data_.url});
   for (int i = 0; i < 3; ++i) {
     server_data_.server->ClearReceivedDebugData();
-    const uint64 wall_time = Env::Default()->NowMicros();
+    const uint64_t wall_time = Env::Default()->NowMicros();
 
     // On the 1st send (i == 0), gating is disabled, so data should be sent.
     // On the 2nd send (i == 1), gating is enabled, and the server has enabled
@@ -367,8 +368,8 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) {
 }
 
 TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) {
-  const string kGrpcUrl1 = "grpc://localhost:3333";
-  const string kGrpcUrl2 = "grpc://localhost:3334";
+  const std::string kGrpcUrl1 = "grpc://localhost:3333";
+  const std::string kGrpcUrl2 = "grpc://localhost:3334";
 
   DebugGrpcIO::SetDebugNodeKeyGrpcState(
       kGrpcUrl1, "foo:0:DebugIdentity",
@@ -398,9 +399,9 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) {
 }
 
 TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
-  const string kGrpcUrl1 = "grpc://localhost:3333";
-  const string kGrpcUrl2 = "grpc://localhost:3334";
-  const string kGrpcUrl3 = "grpc://localhost:3335";
+  const std::string kGrpcUrl1 = "grpc://localhost:3333";
+  const std::string kGrpcUrl2 = "grpc://localhost:3334";
+  const std::string kGrpcUrl3 = "grpc://localhost:3335";
 
   DebugGrpcIO::SetDebugNodeKeyGrpcState(
       kGrpcUrl1, "foo:0:DebugIdentity",
@@ -434,14 +435,14 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
       "grpc://localhost:3333", "foo:0:DebugIdentity",
       EventReply::DebugOpStateChange::READ_ONLY);
 
-  std::vector<string> debug_urls_1;
+  std::vector<std::string> debug_urls_1;
   ASSERT_FALSE(
       DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", debug_urls_1));
 }
 
 TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
-  const string kGrpcUrl1 = "grpc://localhost:3333";
-  const string kWatch1 = "foo:0:DebugIdentity";
+  const std::string kGrpcUrl1 = "grpc://localhost:3333";
+  const std::string kWatch1 = "foo:0:DebugIdentity";
 
   ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
       {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
@@ -456,10 +457,10 @@ TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
 }
 
 TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) {
-  const string kGrpcUrl1 = "grpc://localhost:3333";
-  const string kGrpcUrl2 = "grpc://localhost:3334";
-  const string kWatch1 = "foo:0:DebugIdentity";
-  const string kWatch2 = "foo:1:DebugIdentity";
+  const std::string kGrpcUrl1 = "grpc://localhost:3333";
+  const std::string kGrpcUrl2 = "grpc://localhost:3334";
+  const std::string kWatch1 = "foo:0:DebugIdentity";
+  const std::string kWatch2 = "foo:1:DebugIdentity";
 
   DebugGrpcIO::SetDebugNodeKeyGrpcState(
       kGrpcUrl1, kWatch1, EventReply::DebugOpStateChange::READ_ONLY);
diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc
index 2b593ae6601cd1..0f3dfb8bb737f4 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.cc
+++ b/tensorflow/core/debug/debug_grpc_testlib.cc
@@ -44,11 +44,11 @@ ::grpc::Status TestEventListenerImpl::SendEvents(
     } else if (event.has_summary()) {
       const Summary::Value& val = event.summary().value(0);
 
-      std::vector<string> name_items =
+      std::vector<std::string> name_items =
           tensorflow::str_util::Split(val.node_name(), ':');
 
-      const string node_name = name_items[0];
-      const string debug_op = name_items[2];
+      const std::string node_name = name_items[0];
+      const std::string debug_op = name_items[2];
 
       const TensorProto& tensor_proto = val.tensor();
       Tensor tensor(tensor_proto.dtype());
@@ -156,7 +156,7 @@ void TestEventListenerImpl::StopServer() {
   }
 }
-bool PollTillFirstRequestSucceeds(const string& server_url,
+bool PollTillFirstRequestSucceeds(const std::string& server_url,
                                   const size_t max_attempts) {
   const int kSleepDurationMicros = 100 * 1000;
   size_t n_attempts = 0;
@@ -168,7 +168,7 @@ bool PollTillFirstRequestSucceeds(const string& server_url,
   prep_tensor.flat<float>()(0) = 42.0f;
 
   while (n_attempts++ < max_attempts) {
-    const uint64 wall_time = Env::Default()->NowMicros();
+    const uint64_t wall_time = Env::Default()->NowMicros();
     absl::Status publish_s = DebugIO::PublishDebugTensor(
         DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "prep_node", 0,
                      "DebugIdentity"),
diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h
index 2a57df8d866331..415ce6435c7bdf 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.h
+++ b/tensorflow/core/debug/debug_grpc_testlib.h
@@ -48,12 +48,12 @@ class TestEventListenerImpl final : public grpc::EventListener::Service {
       const EventReply::DebugOpStateChange::State new_state,
       const DebugNodeKey& debug_node_key);
 
-  std::vector<string> debug_metadata_strings;
-  std::vector<string> encoded_graph_defs;
-  std::vector<string> device_names;
-  std::vector<string> node_names;
-  std::vector<int32> output_slots;
-  std::vector<string> debug_ops;
+  std::vector<std::string> debug_metadata_strings;
+  std::vector<std::string> encoded_graph_defs;
+  std::vector<std::string> device_names;
+  std::vector<std::string> node_names;
+  std::vector<int32_t> output_slots;
+  std::vector<std::string> debug_ops;
   std::vector<Tensor> debug_tensors;
 
 private:
@@ -77,7 +77,7 @@ class TestEventListenerImpl final : public grpc::EventListener::Service {
 //
 // Returns:
 //   Whether the polling succeeded within max_attempts.
-bool PollTillFirstRequestSucceeds(const string& server_url,
+bool PollTillFirstRequestSucceeds(const std::string& server_url,
                                   const size_t max_attempts);
 
 }  // namespace test
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 50677be5fa3769..430bc36ea1a96c 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -66,8 +66,8 @@ constexpr absl::string_view kDumpSubDirName = "node-io-dump";
 // shape). It does not set the value.tensor field, which should be set by the
 // caller separately.
 Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key,
-                             const uint64 wall_time_us, const size_t num_chunks,
-                             const size_t chunk_index,
+                             const uint64_t wall_time_us,
+                             const size_t num_chunks, const size_t chunk_index,
                              const DataType& tensor_dtype,
                              const TensorShapeProto& tensor_shape) {
   Event event;
@@ -92,7 +92,7 @@ Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key,
   metadata.set_chunk_index(chunk_index);
 
   // Encode the data in JSON.
-  string json_output;
+  std::string json_output;
   tensorflow::protobuf::util::JsonPrintOptions json_options;
   json_options.always_print_fields_with_no_presence = true;
   auto status = tensorflow::protobuf::util::MessageToJsonString(
@@ -120,7 +120,7 @@ Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key,
 // (i.e., an estimate that is usually too large, but never too small under the
 // gRPC message size limit) of the Varint-encoded length, to workaround the lack
 // of a portable length function.
-const size_t StringValMaxBytesInProto(const string& str) {
+const size_t StringValMaxBytesInProto(const std::string& str) {
 #if defined(PLATFORM_GOOGLE)
   return str.size() + DebugGrpcIO::kGrpcMaxVarintLengthSize;
 #else
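
StringValMaxBytesInProto errs deliberately on the large side: the proto wire format prefixes each string with a varint length, and for any string the debugger handles that prefix fits in at most kGrpcMaxVarintLengthSize = 6 bytes (declared later in this diff), so size() + 6 may over-count but never under-counts against the message limit. Worked through for the 5000 * 1024-byte strings in the tests above: 5,120,000 + 6 bytes exceeds kGrpcMessageSizeLimitBytes = 4000 * 1024 = 4,096,000 bytes, which is why those publishes fail with the "does not fit gRPC message size limit" error.
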
@@ -131,11 +131,12 @@ const size_t StringValMaxBytesInProto(const string& str) {
 
 // Breaks a string Tensor (represented as a TensorProto) as a vector of Event
 // protos.
 absl::Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key,
-                                      const uint64 wall_time_us,
+                                      const uint64_t wall_time_us,
                                       const size_t chunk_size_limit,
                                       TensorProto* tensor_proto,
                                       std::vector<Event>* events) {
-  const protobuf::RepeatedPtrField<string>& strs = tensor_proto->string_val();
+  const protobuf::RepeatedPtrField<std::string>& strs =
+      tensor_proto->string_val();
   const size_t num_strs = strs.size();
   const size_t chunk_size_ub = chunk_size_limit > 0 ? chunk_size_limit
@@ -191,7 +192,8 @@ absl::Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key,
 // If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a
 // length-1 vector will be returned, regardless of the size of the tensor.
 absl::Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key,
-                                const Tensor& tensor, const uint64 wall_time_us,
+                                const Tensor& tensor,
+                                const uint64_t wall_time_us,
                                 const size_t chunk_size_limit,
                                 std::vector<Event>* events) {
   TensorProto tensor_proto;
@@ -237,10 +239,11 @@ absl::Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key,
 // TOCTOU race condition is not of concern here due to the fact that tfdbg
 // sets parallel_iterations attribute of all while_loops to 1 to prevent
 // the same node from being executed multiple times concurrently.
-string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
-  string out = absl::StrCat(in, "_", timestamp);
+std::string AppendTimestampToFilePath(const std::string& in,
+                                      const uint64_t timestamp) {
+  std::string out = absl::StrCat(in, "_", timestamp);
 
-  uint64 i = 1;
+  uint64_t i = 1;
   while (Env::Default()->FileExists(out).ok()) {
     out = strings::StrCat(in, "_", timestamp, "-", i);
     ++i;
@@ -251,11 +254,10 @@ string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
 #ifndef PLATFORM_WINDOWS
 // Publishes encoded GraphDef through a gRPC debugger stream, in chunks,
 // conforming to the gRPC message size limit.
-absl::Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
-                                            const string& device_name,
-                                            const int64_t wall_time,
-                                            const string& debug_url) {
-  const uint64 hash = ::tensorflow::Hash64(encoded_graph_def);
+absl::Status PublishEncodedGraphDefInChunks(
+    const std::string& encoded_graph_def, const std::string& device_name,
+    const int64_t wall_time, const std::string& debug_url) {
+  const uint64_t hash = ::tensorflow::Hash64(encoded_graph_def);
   const size_t total_length = encoded_graph_def.size();
   const size_t num_chunks =
       static_cast<size_t>(std::ceil(static_cast<float>(total_length) /
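
Both WrapStringTensorAsEvents and PublishEncodedGraphDefInChunks split oversized payloads the same way: the chunk count is the ceiling of the total byte length over the per-message budget (the exact divisor falls outside this hunk's context). As a worked example, a 10,485,760-byte serialized GraphDef against a 4,096,000-byte budget needs ceil(10,485,760 / 4,096,000) = 3 chunks.
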
@@ -297,11 +299,12 @@ const char* const DebugIO::kGraphTag = "graph_";
 
 const char* const DebugIO::kHashTag = "hash";
 
-absl::Status ReadEventFromFile(const string& dump_file_path, Event* event) {
+absl::Status ReadEventFromFile(const std::string& dump_file_path,
+                               Event* event) {
   Env* env(Env::Default());
 
-  string content;
-  uint64 file_size = 0;
+  std::string content;
+  uint64_t file_size = 0;
 
   absl::Status s = env->GetFileSize(dump_file_path, &file_size);
   if (!s.ok()) {
@@ -333,10 +336,11 @@ const char* const DebugIO::kMemoryURLScheme = "memcbk://";
 
 // Publishes debug metadata to a set of debug URLs.
 absl::Status DebugIO::PublishDebugMetadata(
     const int64_t global_step, const int64_t session_run_index,
-    const int64_t executor_step_index, const std::vector<string>& input_names,
-    const std::vector<string>& output_names,
-    const std::vector<string>& target_nodes,
-    const std::unordered_set<string>& debug_urls) {
+    const int64_t executor_step_index,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::string>& output_names,
+    const std::vector<std::string>& target_nodes,
+    const std::unordered_set<std::string>& debug_urls) {
   std::ostringstream oss;
 
   // Construct a JSON string to carry the metadata.
@@ -370,24 +374,24 @@ absl::Status DebugIO::PublishDebugMetadata(
   oss << "]";
   oss << "}";
 
-  const string json_metadata = oss.str();
+  const std::string json_metadata = oss.str();
   Event event;
   event.set_wall_time(static_cast<double>(Env::Default()->NowMicros()));
   LogMessage* log_message = event.mutable_log_message();
   log_message->set_message(json_metadata);
 
   absl::Status status;
-  for (const string& url : debug_urls) {
+  for (const std::string& url : debug_urls) {
     if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) {
 #ifndef PLATFORM_WINDOWS
       Event grpc_event;
 
       // Determine the path (if any) in the grpc:// URL, and add it as a field
       // of the JSON string.
-      const string address = url.substr(strlen(DebugIO::kFileURLScheme));
-      const string path = address.find('/') == string::npos
-                              ? ""
-                              : address.substr(address.find('/'));
+      const std::string address = url.substr(strlen(DebugIO::kFileURLScheme));
+      const std::string path = address.find('/') == std::string::npos
+                                   ? ""
+                                   : address.substr(address.find('/'));
       grpc_event.set_wall_time(event.wall_time());
       LogMessage* log_message_grpc = grpc_event.mutable_log_message();
       log_message_grpc->set_message(
@@ -400,8 +404,8 @@ absl::Status DebugIO::PublishDebugMetadata(
       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
     } else if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) {
-      const string dump_root_dir = url.substr(strlen(kFileURLScheme));
-      const string core_metadata_path = AppendTimestampToFilePath(
+      const std::string dump_root_dir = url.substr(strlen(kFileURLScheme));
+      const std::string core_metadata_path = AppendTimestampToFilePath(
           io::JoinPath(
               dump_root_dir,
              absl::StrCat(DebugNodeKey::kMetadataFilePrefix,
@@ -410,8 +414,8 @@ absl::Status DebugIO::PublishDebugMetadata(
                                         session_run_index)))),
          Env::Default()->NowMicros());
      status.Update(DebugFileIO::DumpEventProtoToFile(
-          event, string(io::Dirname(core_metadata_path)),
-          string(io::Basename(core_metadata_path))));
+          event, std::string(io::Dirname(core_metadata_path)),
+          std::string(io::Basename(core_metadata_path))));
     }
   }
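
PublishDebugMetadata, like the tensor and graph publishers below, dispatches on the URL scheme prefix and treats the remainder of a file:// URL as the dump root. A minimal sketch of that prefix handling; the URL is illustrative:

    #include <iostream>
    #include <string>
    #include "absl/strings/match.h"

    int main() {
      const std::string url = "file:///tmp/tfdbg_dumps";  // hypothetical URL
      const std::string kFileURLScheme = "file://";
      if (absl::StartsWith(url, kFileURLScheme)) {
        // Everything after the scheme is the dump root directory.
        std::cout << url.substr(kFileURLScheme.size()) << "\n";
        // -> /tmp/tfdbg_dumps
      }
    }
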
@@ -420,13 +424,13 @@ absl::Status DebugIO::PublishDebugMetadata(
 
 absl::Status DebugIO::PublishDebugTensor(
     const DebugNodeKey& debug_node_key, const Tensor& tensor,
-    const uint64 wall_time_us, const absl::Span<const string> debug_urls,
+    const uint64_t wall_time_us, const absl::Span<const std::string> debug_urls,
     const bool gated_grpc, const int64_t step_id) {
   int32_t num_failed_urls = 0;
   std::vector<absl::Status> fail_statuses;
-  for (const string& url : debug_urls) {
+  for (const std::string& url : debug_urls) {
     if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) {
-      const string dump_root_dir = url.substr(strlen(kFileURLScheme));
+      const std::string dump_root_dir = url.substr(strlen(kFileURLScheme));
 
       const int64_t tensorBytes =
           tensor.IsInitialized() ? tensor.TotalBytes() : 0;
@@ -465,7 +469,7 @@ absl::Status DebugIO::PublishDebugTensor(
       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
 #endif
     } else if (absl::StartsWith(absl::AsciiStrToLower(url), kMemoryURLScheme)) {
-      const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
+      const std::string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
 
       auto* callback_registry = DebugCallbackRegistry::singleton();
       auto* callback = callback_registry->GetCallback(dump_root_dir);
       CHECK(callback) << "No callback registered for: " << dump_root_dir;
@@ -479,7 +483,7 @@ absl::Status DebugIO::PublishDebugTensor(
   if (num_failed_urls == 0) {
     return absl::OkStatus();
   } else {
-    string error_message = strings::StrCat(
+    std::string error_message = strings::StrCat(
         "Publishing to ", num_failed_urls, " of ", debug_urls.size(),
         " debug target URLs failed, due to the following errors:");
     for (absl::Status& status : fail_statuses) {
@@ -492,18 +496,19 @@ absl::Status DebugIO::PublishDebugTensor(
 
 absl::Status DebugIO::PublishDebugTensor(
     const DebugNodeKey& debug_node_key, const Tensor& tensor,
-    const uint64 wall_time_us, const absl::Span<const string> debug_urls) {
+    const uint64_t wall_time_us,
+    const absl::Span<const std::string> debug_urls) {
   return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls,
                             false);
 }
 
 absl::Status DebugIO::PublishGraph(
-    const Graph& graph, const string& device_name,
-    const std::unordered_set<string>& debug_urls) {
+    const Graph& graph, const std::string& device_name,
+    const std::unordered_set<std::string>& debug_urls) {
   GraphDef graph_def;
   graph.ToGraphDef(&graph_def);
 
-  string buf;
+  std::string buf;
   graph_def.SerializeToString(&buf);
 
   const int64_t now_micros = Env::Default()->NowMicros();
@@ -512,13 +517,13 @@ absl::Status DebugIO::PublishGraph(
   event.set_graph_def(buf);
 
   absl::Status status = absl::OkStatus();
-  for (const string& debug_url : debug_urls) {
+  for (const std::string& debug_url : debug_urls) {
     if (absl::StartsWith(debug_url, kFileURLScheme)) {
-      const string dump_root_dir =
+      const std::string dump_root_dir =
           io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
                        DebugNodeKey::DeviceNameToDevicePath(device_name));
-      const uint64 graph_hash = ::tensorflow::Hash64(buf);
-      const string file_name =
+      const uint64_t graph_hash = ::tensorflow::Hash64(buf);
+      const std::string file_name =
           strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag,
                           DebugIO::kHashTag, graph_hash, "_", now_micros);
 
@@ -556,10 +561,10 @@ bool DebugIO::IsCopyNodeGateOpen(
 #endif
 }
 
-bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
-                                  const std::vector<string>& debug_urls) {
+bool DebugIO::IsDebugNodeGateOpen(const std::string& watch_key,
+                                  const std::vector<std::string>& debug_urls) {
 #ifndef PLATFORM_WINDOWS
-  for (const string& debug_url : debug_urls) {
+  for (const std::string& debug_url : debug_urls) {
     if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme),
                           DebugIO::kGrpcURLScheme)) {
       return true;
@@ -575,8 +580,8 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
 #endif
 }
 
-bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
-                                 const string& debug_url) {
+bool DebugIO::IsDebugURLGateOpen(const std::string& watch_key,
+                                 const std::string& debug_url) {
 #ifndef PLATFORM_WINDOWS
   if (debug_url != kGrpcURLScheme) {
     return true;
@@ -588,7 +593,7 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
 #endif
 }
 
-absl::Status DebugIO::CloseDebugURL(const string& debug_url) {
+absl::Status DebugIO::CloseDebugURL(const std::string& debug_url) {
   if (absl::StartsWith(debug_url, DebugIO::kGrpcURLScheme)) {
 #ifndef PLATFORM_WINDOWS
     return DebugGrpcIO::CloseGrpcStream(debug_url);
@@ -603,10 +608,10 @@ absl::Status DebugIO::CloseDebugURL(const string& debug_url) {
 
 absl::Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
                                           const Tensor& tensor,
-                                          const uint64 wall_time_us,
-                                          const string& dump_root_dir,
-                                          string* dump_file_path) {
-  const string file_path =
+                                          const uint64_t wall_time_us,
+                                          const std::string& dump_root_dir,
+                                          std::string* dump_file_path) {
+  const std::string file_path =
       GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us);
 
   if (dump_file_path != nullptr) {
@@ -618,9 +623,9 @@ absl::Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
 
 absl::Status DebugFileIO::DumpTensorToDirForNodeDumping(
     const DebugNodeKey& debug_node_key, const Tensor& tensor,
-    const uint64 wall_time_us, const string& dump_root_dir,
-    string* dump_file_path, const int64_t step_id) {
-  const string file_path = GetDumpFilePathForNodeDumping(
+    const uint64_t wall_time_us, const std::string& dump_root_dir,
+    std::string* dump_file_path, const int64_t step_id) {
+  const std::string file_path = GetDumpFilePathForNodeDumping(
       dump_root_dir, debug_node_key, wall_time_us, step_id);
   if (dump_file_path != nullptr) {
     *dump_file_path = file_path;
@@ -629,9 +634,9 @@ absl::Status DebugFileIO::DumpTensorToDirForNodeDumping(
   return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path);
 }
 
-string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
-                                    const DebugNodeKey& debug_node_key,
-                                    const uint64 wall_time_us) {
+std::string DebugFileIO::GetDumpFilePath(const std::string& dump_root_dir,
+                                         const DebugNodeKey& debug_node_key,
+                                         const uint64_t wall_time_us) {
   return AppendTimestampToFilePath(
       io::JoinPath(dump_root_dir, debug_node_key.device_path,
                    strings::StrCat(debug_node_key.node_name, "_",
@@ -640,9 +645,9 @@ string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
       wall_time_us);
 }
 
-string DebugFileIO::GetDumpFilePathForNodeDumping(
-    const string& dump_root_dir, const DebugNodeKey& debug_node_key,
-    const uint64 wall_time_us, const int64_t step_id) {
+std::string DebugFileIO::GetDumpFilePathForNodeDumping(
+    const std::string& dump_root_dir, const DebugNodeKey& debug_node_key,
+    const uint64_t wall_time_us, const int64_t step_id) {
   return AppendTimestampToFilePath(
       io::JoinPath(
           dump_root_dir, kDumpSubDirName, absl::StrCat("step-", step_id),
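
Putting the two path helpers together: GetDumpFilePath joins the dump root, the device path, and a node/slot/debug-op stem, then AppendTimestampToFilePath (above) suffixes the wall time, plus a "-<i>" counter on collision. With illustrative values, a DebugIdentity watch on node1:0 dumped at wall time 1700000000123456 under /tmp/dumps would produce a path shaped like /tmp/dumps/<device_path>/node1_0_DebugIdentity_1700000000123456.
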
@@ -654,8 +659,8 @@ string DebugFileIO::GetDumpFilePathForNodeDumping(
 }
 
 absl::Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
-                                               const string& dir_name,
-                                               const string& file_name) {
+                                               const std::string& dir_name,
+                                               const std::string& file_name) {
   Env* env(Env::Default());
 
   absl::Status s = RecursiveCreateDir(env, dir_name);
@@ -665,9 +670,9 @@ absl::Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
                      ", due to: ", s.message()));
   }
 
-  const string file_path = io::JoinPath(dir_name, file_name);
+  const std::string file_path = io::JoinPath(dir_name, file_name);
 
-  string event_str;
+  std::string event_str;
   event_proto.SerializeToString(&event_str);
 
   std::unique_ptr<WritableFile> f = nullptr;
@@ -680,21 +685,21 @@ absl::Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
 
 absl::Status DebugFileIO::DumpTensorToEventFile(
     const DebugNodeKey& debug_node_key, const Tensor& tensor,
-    const uint64 wall_time_us, const string& file_path) {
+    const uint64_t wall_time_us, const std::string& file_path) {
   std::vector<Event> events;
   TF_RETURN_IF_ERROR(
       WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events));
-  return DumpEventProtoToFile(events[0], string(io::Dirname(file_path)),
-                              string(io::Basename(file_path)));
+  return DumpEventProtoToFile(events[0], std::string(io::Dirname(file_path)),
+                              std::string(io::Basename(file_path)));
 }
 
-absl::Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
+absl::Status DebugFileIO::RecursiveCreateDir(Env* env, const std::string& dir) {
   if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
     // The path already exists as a directory. Return OK right away.
     return absl::OkStatus();
   }
 
-  string parent_dir(io::Dirname(dir));
+  std::string parent_dir(io::Dirname(dir));
   if (!env->FileExists(parent_dir).ok()) {
     // The parent path does not exist yet, create it first.
     absl::Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call
@@ -724,13 +729,13 @@ absl::Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
 }
 
 // Default total disk usage limit: 100 GBytes
-const uint64 DebugFileIO::kDefaultGlobalDiskBytesLimit = 107374182400L;
-uint64 DebugFileIO::global_disk_bytes_limit_ = 0;
-uint64 DebugFileIO::disk_bytes_used_ = 0;
+const uint64_t DebugFileIO::kDefaultGlobalDiskBytesLimit = 107374182400L;
+uint64_t DebugFileIO::global_disk_bytes_limit_ = 0;
+uint64_t DebugFileIO::disk_bytes_used_ = 0;
 
 mutex DebugFileIO::bytes_mu_(LINKER_INITIALIZED);
 
-bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
+bool DebugFileIO::requestDiskByteUsage(uint64_t bytes) {
   mutex_lock l(bytes_mu_);
   if (global_disk_bytes_limit_ == 0) {
     const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
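
requestDiskByteUsage lazily initializes the global limit from the TFDBG_DISK_BYTES_LIMIT environment variable, falling back to the 100-GByte default above, and then checks each request against the running total. A simplified sketch of that accounting, without the env-var lookup or the bytes_mu_ mutex:

    #include <cstdint>

    constexpr uint64_t kLimit = 107374182400ULL;  // 100 GBytes, as above.
    uint64_t bytes_used = 0;

    // Returns true and reserves the bytes if the dump fits in the budget.
    bool RequestBytes(uint64_t bytes) {
      if (bytes_used + bytes > kLimit) {
        return false;  // Over budget: the caller should skip this dump.
      }
      bytes_used += bytes;
      return true;
    }
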
@@ -760,13 +765,13 @@ void DebugFileIO::resetDiskByteUsage() {
 }
 
 #ifndef PLATFORM_WINDOWS
-DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
+DebugGrpcChannel::DebugGrpcChannel(const std::string& server_stream_addr)
     : server_stream_addr_(server_stream_addr),
       url_(absl::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}
 
 absl::Status DebugGrpcChannel::Connect(const int64_t timeout_micros) {
   ::grpc::ChannelArguments args;
-  args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
+  args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
+              std::numeric_limits<int32_t>::max());
   // Avoid problems where default reconnect backoff is too long (e.g., 20 s).
   args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
 
   channel_ = ::grpc::CreateCustomChannel(
@@ -801,9 +806,10 @@ void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) {
          ReadEventReply(&event_reply)) {
     for (const EventReply::DebugOpStateChange& debug_op_state_change :
          event_reply.debug_op_state_changes()) {
-      string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
-                                         debug_op_state_change.output_slot(),
-                                         ":", debug_op_state_change.debug_op());
+      std::string watch_key =
+          strings::StrCat(debug_op_state_change.node_name(), ":",
+                          debug_op_state_change.output_slot(), ":",
+                          debug_op_state_change.debug_op());
       DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
                                             debug_op_state_change.state());
     }
@@ -832,17 +838,17 @@ const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;
 
 const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;
 
-std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
+std::unordered_map<std::string, std::unique_ptr<DebugGrpcChannel>>*
 DebugGrpcIO::GetStreamChannels() {
-  static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
-      stream_channels =
-          new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>();
+  static std::unordered_map<
+      std::string, std::unique_ptr<DebugGrpcChannel>>* stream_channels =
+      new std::unordered_map<std::string, std::unique_ptr<DebugGrpcChannel>>();
   return stream_channels;
 }
 
 absl::Status DebugGrpcIO::SendTensorThroughGrpcStream(
     const DebugNodeKey& debug_node_key, const Tensor& tensor,
-    const uint64 wall_time_us, const string& grpc_stream_url,
+    const uint64_t wall_time_us, const std::string& grpc_stream_url,
     const bool gated) {
   if (gated &&
       !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
@@ -868,7 +874,7 @@ absl::Status DebugGrpcIO::SendTensorThroughGrpcStream(
 }
 
 absl::Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
-    EventReply* event_reply, const string& grpc_stream_url) {
+    EventReply* event_reply, const std::string& grpc_stream_url) {
   DebugGrpcChannel* debug_grpc_channel = nullptr;
   TF_RETURN_IF_ERROR(
       GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
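
GetStreamChannels uses a deliberately leaked function-local static map, a common pattern for process-lifetime registries because it sidesteps static destruction order problems at shutdown. The shape of the pattern, reduced to a sketch; the Channel type here is a stand-in for DebugGrpcChannel:

    #include <memory>
    #include <string>
    #include <unordered_map>

    struct Channel {};  // Stand-in for DebugGrpcChannel.

    std::unordered_map<std::string, std::unique_ptr<Channel>>* Cache() {
      // Never deleted: the map lives for the whole process, so lookups during
      // shutdown can never touch a destroyed container.
      static auto* cache =
          new std::unordered_map<std::string, std::unique_ptr<Channel>>();
      return cache;
    }
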
@@ -881,16 +887,16 @@ absl::Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
 }
 
 absl::Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
-    const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
-  const string addr_with_path =
+    const std::string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
+  const std::string addr_with_path =
       absl::StartsWith(grpc_stream_url, DebugIO::kGrpcURLScheme)
           ? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
           : grpc_stream_url;
-  const string server_stream_addr =
+  const std::string server_stream_addr =
       addr_with_path.substr(0, addr_with_path.find('/'));
   {
     mutex_lock l(streams_mu_);
-    std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
+    std::unordered_map<std::string, std::unique_ptr<DebugGrpcChannel>>*
         stream_channels = GetStreamChannels();
     if (stream_channels->find(grpc_stream_url) == stream_channels->end()) {
       std::unique_ptr<DebugGrpcChannel> channel(
@@ -905,7 +911,7 @@ absl::Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
 }
 
 absl::Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
-    const Event& event_proto, const string& grpc_stream_url,
+    const Event& event_proto, const std::string& grpc_stream_url,
     const bool receive_reply) {
   DebugGrpcChannel* debug_grpc_channel;
   TF_RETURN_IF_ERROR(
@@ -924,15 +930,15 @@ absl::Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
   return absl::OkStatus();
 }
 
-bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
-                                 const string& watch_key) {
+bool DebugGrpcIO::IsReadGateOpen(const std::string& grpc_debug_url,
+                                 const std::string& watch_key) {
   const DebugNodeName2State* enabled_node_to_state =
       GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
   return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
 }
 
-bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
-                                  const string& watch_key) {
+bool DebugGrpcIO::IsWriteGateOpen(const std::string& grpc_debug_url,
+                                  const std::string& watch_key) {
   const DebugNodeName2State* enabled_node_to_state =
       GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
   auto it = enabled_node_to_state->find(watch_key);
@@ -943,10 +949,10 @@ bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
   }
 }
 
-absl::Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
+absl::Status DebugGrpcIO::CloseGrpcStream(const std::string& grpc_stream_url) {
   mutex_lock l(streams_mu_);
 
-  std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
+  std::unordered_map<std::string, std::unique_ptr<DebugGrpcChannel>>*
       stream_channels = GetStreamChannels();
   if (stream_channels->find(grpc_stream_url) != stream_channels->end()) {
     // Stream of the specified address exists. Close it and remove it from
@@ -961,18 +967,18 @@ absl::Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
   }
 }
 
-std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
+std::unordered_map<std::string, DebugGrpcIO::DebugNodeName2State>*
 DebugGrpcIO::GetEnabledDebugOpStates() {
-  static std::unordered_map<string, DebugNodeName2State>*
+  static std::unordered_map<std::string, DebugNodeName2State>*
       enabled_debug_op_states =
-          new std::unordered_map<string, DebugNodeName2State>();
+          new std::unordered_map<std::string, DebugNodeName2State>();
   return enabled_debug_op_states;
 }
 
 DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
-    const string& grpc_debug_url) {
+    const std::string& grpc_debug_url) {
   static mutex* debug_ops_state_mu = new mutex();
-  std::unordered_map<string, DebugNodeName2State>* states =
+  std::unordered_map<std::string, DebugNodeName2State>* states =
       GetEnabledDebugOpStates();
 
   mutex_lock l(*debug_ops_state_mu);
@@ -984,7 +990,7 @@ DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
 }
 
 void DebugGrpcIO::SetDebugNodeKeyGrpcState(
-    const string& grpc_debug_url, const string& watch_key,
+    const std::string& grpc_debug_url, const std::string& watch_key,
     const EventReply::DebugOpStateChange::State new_state) {
   DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
   if (new_state == EventReply::DebugOpStateChange::DISABLED) {
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index 95864c714682b6..99107971f0f2b4 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -36,15 +36,15 @@ limitations under the License.
namespace tensorflow { -absl::Status ReadEventFromFile(const string& dump_file_path, Event* event); +absl::Status ReadEventFromFile(const std::string& dump_file_path, Event* event); struct DebugWatchAndURLSpec { - DebugWatchAndURLSpec(const string& watch_key, const string& url, + DebugWatchAndURLSpec(const std::string& watch_key, const std::string& url, const bool gated_grpc) : watch_key(watch_key), url(url), gated_grpc(gated_grpc) {} - const string watch_key; - const string url; + const std::string watch_key; + const std::string url; const bool gated_grpc; }; @@ -63,10 +63,11 @@ class DebugIO { static absl::Status PublishDebugMetadata( const int64_t global_step, const int64_t session_run_index, - const int64_t executor_step_index, const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - const std::unordered_set<string>& debug_urls); + const int64_t executor_step_index, + const std::vector<std::string>& input_names, + const std::vector<std::string>& output_names, + const std::vector<std::string>& target_nodes, + const std::unordered_set<std::string>& debug_urls); // Publishes a tensor to a debug target URL. // @@ -82,13 +83,15 @@ class DebugIO { // step_id: Step ID associated with the tensor. static absl::Status PublishDebugTensor( const DebugNodeKey& debug_node_key, const Tensor& tensor, - const uint64 wall_time_us, const absl::Span<const string> debug_urls, - bool gated_grpc, int64_t step_id = -1); + const uint64_t wall_time_us, + const absl::Span<const std::string> debug_urls, bool gated_grpc, + int64_t step_id = -1); // Convenience overload of the method above for no gated_grpc by default. static absl::Status PublishDebugTensor( const DebugNodeKey& debug_node_key, const Tensor& tensor, - const uint64 wall_time_us, const absl::Span<const string> debug_urls); + const uint64_t wall_time_us, + const absl::Span<const std::string> debug_urls); // Publishes a graph to a set of debug URLs. // @@ -96,8 +99,8 @@ class DebugIO { // graph: The graph to be published. // debug_urls: The set of debug URLs to publish the graph to. static absl::Status PublishGraph( - const Graph& graph, const string& device_name, - const std::unordered_set<string>& debug_urls); + const Graph& graph, const std::string& device_name, + const std::unordered_set<std::string>& debug_urls); // Determines whether a copy node needs to perform deep-copy of input tensor. // @@ -126,8 +129,8 @@ class DebugIO { // // Returns: // Whether this debug op should proceed. - static bool IsDebugNodeGateOpen(const string& watch_key, - const std::vector<string>& debug_urls); + static bool IsDebugNodeGateOpen(const std::string& watch_key, + const std::vector<std::string>& debug_urls); // Determines whether debug information should be sent through a grpc:// // debug URL given the current gRPC gating status. // @@ -141,10 +144,10 @@ // Returns: // Whether the sending of debug data to the debug_url should // proceed. - static bool IsDebugURLGateOpen(const string& watch_key, - const string& debug_url); + static bool IsDebugURLGateOpen(const std::string& watch_key, + const std::string& debug_url); - static absl::Status CloseDebugURL(const string& debug_url); + static absl::Status CloseDebugURL(const std::string& debug_url); }; // Helper class for debug ops. @@ -171,15 +174,15 @@ class DebugFileIO { // dump_file_path: The actual dump file path (passed as reference).
static absl::Status DumpTensorToDir(const DebugNodeKey& debug_node_key, const Tensor& tensor, - const uint64 wall_time_us, - const string& dump_root_dir, - string* dump_file_path); + const uint64_t wall_time_us, + const std::string& dump_root_dir, + std::string* dump_file_path); // Similar to the above, but for node inputs/outputs dumping feature. static absl::Status DumpTensorToDirForNodeDumping( const DebugNodeKey& debug_node_key, const Tensor& tensor, - uint64 wall_time_us, const string& dump_root_dir, string* dump_file_path, - int64_t step_id); + uint64_t wall_time_us, const std::string& dump_root_dir, + std::string* dump_file_path, int64_t step_id); // Get the full path to the dump file. // @@ -190,14 +193,14 @@ class DebugFileIO { // output_slot: Output slot index of the said node, e.g., 0. // debug_op: Name of the debug op, e.g., DebugIdentity. // wall_time_us: Time stamp of the dumped tensor, in microseconds (us). - static string GetDumpFilePath(const string& dump_root_dir, - const DebugNodeKey& debug_node_key, - const uint64 wall_time_us); + static std::string GetDumpFilePath(const std::string& dump_root_dir, + const DebugNodeKey& debug_node_key, + const uint64_t wall_time_us); // Similar to the above, but for node inputs/outputs dumping feature. - static string GetDumpFilePathForNodeDumping( - const string& dump_root_dir, const DebugNodeKey& debug_node_key, - uint64 wall_time_us, int64_t step_id); + static std::string GetDumpFilePathForNodeDumping( + const std::string& dump_root_dir, const DebugNodeKey& debug_node_key, + uint64_t wall_time_us, int64_t step_id); // Dumps an Event proto to a file. // @@ -206,8 +209,8 @@ class DebugFileIO { // dir_name: Directory path. // file_name: Base file name. static absl::Status DumpEventProtoToFile(const Event& event_proto, - const string& dir_name, - const string& file_name); + const std::string& dir_name, + const std::string& file_name); // Request additional bytes to be dumped to the file system. // @@ -222,31 +225,31 @@ class DebugFileIO { // Returns: // Whether the request is approved given the total dumping // limit. - static bool requestDiskByteUsage(uint64 bytes); + static bool requestDiskByteUsage(uint64_t bytes); // Reset the disk byte usage to zero. static void resetDiskByteUsage(); - static uint64 global_disk_bytes_limit_; + static uint64_t global_disk_bytes_limit_; private: // Encapsulates the Tensor in an Event protobuf and write it to file. static absl::Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key, const Tensor& tensor, - const uint64 wall_time_us, - const string& file_path); + const uint64_t wall_time_us, + const std::string& file_path); // Implemented ad hoc here for now. // TODO(cais): Replace with shared implementation once http://b/30497715 is // fixed. - static absl::Status RecursiveCreateDir(Env* env, const string& dir); + static absl::Status RecursiveCreateDir(Env* env, const std::string& dir); // Tracks how much disk has been used so far. - static uint64 disk_bytes_used_; + static uint64_t disk_bytes_used_; // Mutex for thread-safe access to disk_bytes_used_. static mutex bytes_mu_; // Default limit for the disk space. - static const uint64 kDefaultGlobalDiskBytesLimit; + static const uint64_t kDefaultGlobalDiskBytesLimit; friend class DiskUsageLimitTest; }; @@ -282,7 +285,7 @@ class DebugGrpcChannel { // server_stream_addr: Address (host name and port) of the debug stream // server implementing the EventListener service (see // debug_service.proto). E.g., "127.0.0.1:12345". 
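Aside: the requestDiskByteUsage()/resetDiskByteUsage() pair above describes a global byte budget for tensor dumps: each write asks for an allowance and is rejected once the limit is hit. A minimal sketch of such a gate, assuming a fixed limit; this is a simplification of the static-member machinery shown above, not the actual implementation:

#include <cstdint>
#include <mutex>

class DiskBudget {
 public:
  explicit DiskBudget(uint64_t limit) : limit_(limit) {}

  // Approves the request only if it still fits in the remaining budget.
  bool Request(uint64_t bytes) {
    std::lock_guard<std::mutex> l(mu_);
    if (used_ + bytes > limit_) return false;
    used_ += bytes;
    return true;
  }

  // Resets the running total, mirroring resetDiskByteUsage().
  void Reset() {
    std::lock_guard<std::mutex> l(mu_);
    used_ = 0;
  }

 private:
  const uint64_t limit_;
  uint64_t used_ = 0;  // bytes dumped so far
  std::mutex mu_;
};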
- explicit DebugGrpcChannel(const string& server_stream_addr); + explicit DebugGrpcChannel(const std::string& server_stream_addr); virtual ~DebugGrpcChannel() {} @@ -337,8 +340,8 @@ class DebugGrpcChannel { absl::Status ReceiveServerRepliesAndClose(); private: - string server_stream_addr_; - string url_; + std::string server_stream_addr_; + std::string url_; ::grpc::ClientContext ctx_; std::shared_ptr<::grpc::Channel> channel_; std::unique_ptr<grpc::EventListener::Stub> stub_; @@ -356,7 +359,7 @@ class DebugGrpcIO { // Sends a tensor through a debug gRPC stream. static absl::Status SendTensorThroughGrpcStream( const DebugNodeKey& debug_node_key, const Tensor& tensor, - const uint64 wall_time_us, const string& grpc_stream_url, + const uint64_t wall_time_us, const std::string& grpc_stream_url, const bool gated); // Sends an Event proto through a debug gRPC stream. @@ -373,40 +376,40 @@ class DebugGrpcIO { // Returns: // The Status of the operation. static absl::Status SendEventProtoThroughGrpcStream( - const Event& event_proto, const string& grpc_stream_url, + const Event& event_proto, const std::string& grpc_stream_url, const bool receive_reply = false); // Receive an EventReply proto through a debug gRPC stream. static absl::Status ReceiveEventReplyProtoThroughGrpcStream( - EventReply* event_reply, const string& grpc_stream_url); + EventReply* event_reply, const std::string& grpc_stream_url); // Check whether a debug watch key is read-activated at a given gRPC URL. - static bool IsReadGateOpen(const string& grpc_debug_url, - const string& watch_key); + static bool IsReadGateOpen(const std::string& grpc_debug_url, + const std::string& watch_key); // Check whether a debug watch key is write-activated (i.e., read- and // write-activated) at a given gRPC URL. - static bool IsWriteGateOpen(const string& grpc_debug_url, - const string& watch_key); + static bool IsWriteGateOpen(const std::string& grpc_debug_url, + const std::string& watch_key); // Closes a gRPC stream to the given address, if it exists. // Thread-safety: Safe with respect to other calls to the same method and // calls to SendTensorThroughGrpcStream(). - static absl::Status CloseGrpcStream(const string& grpc_stream_url); + static absl::Status CloseGrpcStream(const std::string& grpc_stream_url); // Set the gRPC state of a debug node key. // TODO(cais): Include device information in watch_key. static void SetDebugNodeKeyGrpcState( - const string& grpc_debug_url, const string& watch_key, + const std::string& grpc_debug_url, const std::string& watch_key, const EventReply::DebugOpStateChange::State new_state); private: using DebugNodeName2State = - std::unordered_map<string, EventReply::DebugOpStateChange::State>; + std::unordered_map<std::string, EventReply::DebugOpStateChange::State>; // Returns a global map from grpc debug URLs to the corresponding // DebugGrpcChannels. - static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* + static std::unordered_map<std::string, std::unique_ptr<DebugGrpcChannel>>* GetStreamChannels(); // Get a DebugGrpcChannel object at a given URL, creating one if necessary. @@ -420,15 +423,16 @@ class DebugGrpcIO { // Returns: // Status of this operation. static absl::Status GetOrCreateDebugGrpcChannel( - const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel); + const std::string& grpc_stream_url, + DebugGrpcChannel** debug_grpc_channel); // Returns a map from debug URL to a map from debug op name to enabled state. - static std::unordered_map<string, DebugNodeName2State>* + static std::unordered_map<std::string, DebugNodeName2State>* GetEnabledDebugOpStates(); // Returns a map from debug op names to enabled state, for a given debug URL.
static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl( - const string& grpc_debug_url); + const std::string& grpc_debug_url); // Clear enabled debug op state from all debug URLs (if any). static void ClearEnabledWatchKeys(); diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc index d09465d80e5a01..fde63f53331cf1 100644 --- a/tensorflow/core/debug/debug_io_utils_test.cc +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -107,7 +107,7 @@ TEST_F(DebugIOUtilsTest, DebugNodeKeysIsHashable) { TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) { Initialize(); - const string test_dir = + const std::string test_dir = absl::StrCat(testing::TmpDir(), "/DumpFloatTensorToFileSunnyDay"); if (!env_->FileExists(test_dir).ok()) { ASSERT_TRUE(env_->RecursivelyCreateDir(test_dir).ok()); @@ -115,11 +115,11 @@ TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) { // Append levels of nonexisting directories, to test that the function can // create directories. - const uint64 wall_time = env_->NowMicros(); + const uint64_t wall_time = env_->NowMicros(); const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo/bar/qux/tensor_a", 0, "DebugIdentity"); - string dump_file_path; + std::string dump_file_path; TF_ASSERT_OK(DebugFileIO::DumpTensorToDir( kDebugNodeKey, *tensor_a_, wall_time, test_dir, &dump_file_path)); @@ -154,16 +154,16 @@ TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) { TEST_F(DebugIOUtilsTest, DumpStringTensorToFileSunnyDay) { Initialize(); - const string test_dir = + const std::string test_dir = absl::StrCat(testing::TmpDir(), "/DumpStringTensorToFileSunnyDay"); if (!env_->FileExists(test_dir).ok()) { ASSERT_TRUE(env_->RecursivelyCreateDir(test_dir).ok()); } const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "quux/grault/tensor_b", 1, "DebugIdentity"); - const uint64 wall_time = env_->NowMicros(); + const uint64_t wall_time = env_->NowMicros(); - string dump_file_name; + std::string dump_file_name; absl::Status s = DebugFileIO::DumpTensorToDir( kDebugNodeKey, *tensor_b_, wall_time, test_dir, &dump_file_name); ASSERT_TRUE(s.ok()); @@ -209,17 +209,17 @@ TEST_F(DebugIOUtilsTest, DumpTensorToFileCannotCreateDirectory) { Initialize(); // First, create the file at the path. - const string test_dir = + const std::string test_dir = absl::StrCat(testing::TmpDir(), "/DumpTensorToFileCannotCreateDirectory"); if (!env_->FileExists(test_dir).ok()) { ASSERT_TRUE(env_->RecursivelyCreateDir(test_dir).ok()); } - const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0"; + const std::string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0"; const DebugNodeKey kDebugNodeKey(kDeviceName, "baz/tensor_a", 0, "DebugIdentity"); - const string txt_file_dir = + const std::string txt_file_dir = io::JoinPath(test_dir, DebugNodeKey::DeviceNameToDevicePath(kDeviceName)); - const string txt_file_name = io::JoinPath(txt_file_dir, "baz"); + const std::string txt_file_name = io::JoinPath(txt_file_dir, "baz"); if (!env_->FileExists(txt_file_dir).ok()) { ASSERT_TRUE(env_->RecursivelyCreateDir(txt_file_dir).ok()); } @@ -238,9 +238,9 @@ TEST_F(DebugIOUtilsTest, DumpTensorToFileCannotCreateDirectory) { // Second, try to dump the tensor to a path that requires "baz" to be a // directory, which should lead to an error. 
- const uint64 wall_time = env_->NowMicros(); + const uint64_t wall_time = env_->NowMicros(); - string dump_file_name; + std::string dump_file_name; absl::Status s = DebugFileIO::DumpTensorToDir( kDebugNodeKey, *tensor_a_, wall_time, test_dir, &dump_file_name); ASSERT_FALSE(s.ok()); @@ -261,13 +261,13 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) { const int kNumDumpRoots = 3; const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo/bar/qux/tensor_a", 0, "DebugIdentity"); - const uint64 wall_time = env_->NowMicros(); + const uint64_t wall_time = env_->NowMicros(); - std::vector<string> dump_roots; - std::vector<string> dump_file_paths; - std::vector<string> urls; + std::vector<std::string> dump_roots; + std::vector<std::string> dump_file_paths; + std::vector<std::string> urls; for (int i = 0; i < kNumDumpRoots; ++i) { - string dump_root = + std::string dump_root = absl::StrCat(testing::TmpDir(), "/PublicTensorToMultipleFileUrls_", i); dump_roots.push_back(dump_root); @@ -331,10 +331,10 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMemoryCallback) { const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo/bar/qux/tensor_a", 0, "DebugIdentity"); - const uint64 wall_time = env_->NowMicros(); + const uint64_t wall_time = env_->NowMicros(); bool called = false; - std::vector<string> urls = {"memcbk://test_callback"}; + std::vector<std::string> urls = {"memcbk://test_callback"}; ; auto* callback_registry = DebugCallbackRegistry::singleton(); @@ -367,8 +367,8 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) { thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", kConcurrentPubs); - const uint64 wall_time = env_->NowMicros(); - const string dump_root_base = + const uint64_t wall_time = env_->NowMicros(); + const std::string dump_root_base = absl::StrCat(testing::TmpDir(), "/PublishTensorConcurrentlyToPartiallyOverlappingPaths"); if (!env_->FileExists(dump_root_base).ok()) { @@ -376,8 +376,8 @@ } mutex mu; - std::vector<string> dump_roots TF_GUARDED_BY(mu); - std::vector<string> dump_file_paths TF_GUARDED_BY(mu); + std::vector<std::string> dump_roots TF_GUARDED_BY(mu); + std::vector<std::string> dump_file_paths TF_GUARDED_BY(mu); int dump_count TF_GUARDED_BY(mu) = 0; int done_count TF_GUARDED_BY(mu) = 0; @@ -387,8 +387,8 @@ &dump_file_paths, &wall_time, &kDebugNodeKey, &kConcurrentPubs, &all_done]() { // "gumpy" is the shared directory part of the path.
- string dump_root; - string debug_url; + std::string dump_root; + std::string debug_url; { mutex_lock l(mu); dump_root = @@ -401,7 +401,7 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) { debug_url = absl::StrCat("file://", dump_root); } - std::vector<string> urls; + std::vector<std::string> urls; urls.push_back(debug_url); absl::Status s = diff --git a/tensorflow/core/debug/debug_node_key.cc b/tensorflow/core/debug/debug_node_key.cc index 1fa51f138c2f6f..09510b8df1bfb8 100644 --- a/tensorflow/core/debug/debug_node_key.cc +++ b/tensorflow/core/debug/debug_node_key.cc @@ -26,9 +26,11 @@ const char* const DebugNodeKey::kMetadataFilePrefix = "_tfdbg_"; const char* const DebugNodeKey::kDeviceTag = "device_"; -DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name, - const int32_t output_slot, const string& debug_op, - const string& io_of_node, const bool is_input, +DebugNodeKey::DebugNodeKey(const std::string& device_name, + const std::string& node_name, + const int32_t output_slot, + const std::string& debug_op, + const std::string& io_of_node, const bool is_input, const int32_t io_index) : device_name(device_name), node_name(node_name), @@ -52,7 +54,8 @@ bool DebugNodeKey::operator!=(const DebugNodeKey& other) const { return !((*this) == other); } -const string DebugNodeKey::DeviceNameToDevicePath(const string& device_name) { +const std::string DebugNodeKey::DeviceNameToDevicePath( + const std::string& device_name) { return absl::StrCat(kMetadataFilePrefix, kDeviceTag, str_util::StringReplace( str_util::StringReplace(device_name, ":", "_", true), diff --git a/tensorflow/core/debug/debug_node_key.h b/tensorflow/core/debug/debug_node_key.h index 5decb5cc683643..867e0809314324 100644 --- a/tensorflow/core/debug/debug_node_key.h +++ b/tensorflow/core/debug/debug_node_key.h @@ -27,28 +27,29 @@ struct DebugNodeKey { static const char* const kMetadataFilePrefix; static const char* const kDeviceTag; - DebugNodeKey(const string& device_name, const string& node_name, - int32_t output_slot, const string& debug_op, - const string& io_of_node = "", bool is_input = false, + DebugNodeKey(const std::string& device_name, const std::string& node_name, + int32_t output_slot, const std::string& debug_op, + const std::string& io_of_node = "", bool is_input = false, int32_t io_index = -1); // Converts a device name string to a device path string. // E.g., /job:localhost/replica:0/task:0/cpu:0 will be converted to // ,job_localhost,replica_0,task_0,cpu_0.
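Aside: the comment above fully specifies the transformation (prefix with the metadata/device tags, then ":" becomes "_" and "/" becomes ","). A sketch that reproduces it with plain std::string operations; ReplaceAll is a local helper written for this example, not a TensorFlow API:

#include <string>

// Replaces every occurrence of `from` in `s` with `to`.
std::string ReplaceAll(std::string s, const std::string& from,
                       const std::string& to) {
  for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos;
       pos += to.size()) {
    s.replace(pos, from.size(), to);
  }
  return s;
}

// "/job:localhost/replica:0/task:0/cpu:0"
//   -> "_tfdbg_device_,job_localhost,replica_0,task_0,cpu_0"
std::string DeviceNameToDevicePathSketch(const std::string& device_name) {
  return "_tfdbg_device_" +
         ReplaceAll(ReplaceAll(device_name, ":", "_"), "/", ",");
}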
- static const string DeviceNameToDevicePath(const string& device_name); + static const std::string DeviceNameToDevicePath( + const std::string& device_name); bool operator==(const DebugNodeKey& other) const; bool operator!=(const DebugNodeKey& other) const; - const string device_name; - const string node_name; - const int32 output_slot; - const string debug_op; - const string debug_node_name; - const string device_path; - const string io_of_node; + const std::string device_name; + const std::string node_name; + const int32_t output_slot; + const std::string debug_op; + const std::string debug_node_name; + const std::string device_path; + const std::string io_of_node; const bool is_input; - const int32 io_index; + const int32_t io_index; }; } // namespace tensorflow diff --git a/tensorflow/core/debug/debugger_state_impl.cc b/tensorflow/core/debug/debugger_state_impl.cc index a1545ad1aa1516..23b70b431d8dd0 100644 --- a/tensorflow/core/debug/debugger_state_impl.cc +++ b/tensorflow/core/debug/debugger_state_impl.cc @@ -23,7 +23,7 @@ namespace tensorflow { DebuggerState::DebuggerState(const DebugOptions& debug_options) { for (const DebugTensorWatch& watch : debug_options.debug_tensor_watch_opts()) { - for (const string& url : watch.debug_urls()) { + for (const std::string& url : watch.debug_urls()) { debug_urls_.insert(url); } } @@ -33,16 +33,17 @@ } DebuggerState::~DebuggerState() { - for (const string& debug_url : debug_urls_) { + for (const std::string& debug_url : debug_urls_) { DebugIO::CloseDebugURL(debug_url).IgnoreError(); } } absl::Status DebuggerState::PublishDebugMetadata( const int64_t global_step, const int64_t session_run_index, - const int64_t executor_step_index, const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_names) { + const int64_t executor_step_index, + const std::vector<std::string>& input_names, + const std::vector<std::string>& output_names, + const std::vector<std::string>& target_names) { return DebugIO::PublishDebugMetadata(global_step, session_run_index, executor_step_index, input_names, output_names, target_names, debug_urls_); @@ -55,11 +56,11 @@ absl::Status DebugGraphDecorator::DecorateGraph(Graph* graph, Device* device) { } absl::Status DebugGraphDecorator::PublishGraph(const Graph& graph, - const string& device_name) { - std::unordered_set<string> debug_urls; + const std::string& device_name) { + std::unordered_set<std::string> debug_urls; for (const DebugTensorWatch& watch : debug_options_.debug_tensor_watch_opts()) { - for (const string& url : watch.debug_urls()) { + for (const std::string& url : watch.debug_urls()) { debug_urls.insert(url); } } diff --git a/tensorflow/core/debug/debugger_state_impl.h b/tensorflow/core/debug/debugger_state_impl.h index c34aa8bb51a917..73e74738d59d3c 100644 --- a/tensorflow/core/debug/debugger_state_impl.h +++ b/tensorflow/core/debug/debugger_state_impl.h @@ -34,12 +34,13 @@ class DebuggerState : public DebuggerStateInterface { // details.
absl::Status PublishDebugMetadata( const int64_t global_step, const int64_t session_run_count, - const int64_t executor_step_count, const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_names) override; + const int64_t executor_step_count, + const std::vector<std::string>& input_names, + const std::vector<std::string>& output_names, + const std::vector<std::string>& target_names) override; private: - std::unordered_set<string> debug_urls_; + std::unordered_set<std::string> debug_urls_; }; class DebugGraphDecorator : public DebugGraphDecoratorInterface { @@ -50,7 +51,7 @@ class DebugGraphDecorator : public DebugGraphDecoratorInterface { absl::Status DecorateGraph(Graph* graph, Device* device) override; absl::Status PublishGraph(const Graph& graph, - const string& device_name) override; + const std::string& device_name) override; private: DebugOptions debug_options_; diff --git a/tensorflow/core/debug/grpc_session_debug_test.cc b/tensorflow/core/debug/grpc_session_debug_test.cc index e5f5ef7620ab99..4e58928e5693dd 100644 --- a/tensorflow/core/debug/grpc_session_debug_test.cc +++ b/tensorflow/core/debug/grpc_session_debug_test.cc @@ -49,7 +49,7 @@ SessionOptions Devices(int num_cpus, int num_gpus) { return result; } -void CreateGraphDef(GraphDef* graph_def, string node_names[3]) { +void CreateGraphDef(GraphDef* graph_def, std::string node_names[3]) { Graph graph(OpRegistry::Global()); Tensor a_tensor(DT_FLOAT, TensorShape({1, 2})); @@ -77,7 +77,7 @@ void IsSingleFloatValue(const Tensor& val, float expected_val) { ASSERT_EQ(val.flat<float>()(0), expected_val); } -SessionOptions Options(const string& target, int placement_period) { +SessionOptions Options(const std::string& target, int placement_period) { SessionOptions options; // NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target // string.
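Aside: since GrpcSession only accepts grpc://-prefixed targets (per the NOTE above), the test's Options() helper has to build one from a bare address. A sketch under that assumption, with a simplified stand-in for SessionOptions:

#include <string>

// Hypothetical stand-in for the SessionOptions fields the test touches.
struct OptionsSketch {
  std::string target;
  int placement_period = 0;
};

// Prepends the grpc:// scheme so GrpcSession accepts the target.
OptionsSketch MakeOptions(const std::string& addr, int placement_period) {
  OptionsSketch o;
  o.target = "grpc://" + addr;
  o.placement_period = placement_period;
  return o;
}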
@@ -115,18 +115,19 @@ class GrpcSessionDebugTest : public ::testing::Test { } } - const string GetDebugURL() { return debug_url_; } + const std::string GetDebugURL() { return debug_url_; } - void LoadTensorDumps(const string& subdir, std::vector<Tensor>* tensors) { - const string dirpath = io::JoinPath(dump_dir_, subdir); + void LoadTensorDumps(const std::string& subdir, + std::vector<Tensor>* tensors) { + const std::string dirpath = io::JoinPath(dump_dir_, subdir); if (!(Env::Default()->IsDirectory(dirpath).ok())) { return; } - std::vector<string> filenames; + std::vector<std::string> filenames; TF_ASSERT_OK(Env::Default()->GetChildren(dirpath, &filenames)); - for (const string& filename : filenames) { + for (const std::string& filename : filenames) { Event event; TF_ASSERT_OK(ReadEventFromFile(io::JoinPath(dirpath, filename), &event)); if (event.summary().value().size() == 1) { @@ -144,13 +145,13 @@ class GrpcSessionDebugTest : public ::testing::Test { debug_url_ = absl::StrCat("file://", dump_dir_); } - string dump_dir_; - string debug_url_; + std::string dump_dir_; + std::string debug_url_; }; TEST_F(GrpcSessionDebugTest, FileDebugURL) { GraphDef graph; - string node_names[3]; + std::string node_names[3]; CreateGraphDef(&graph, node_names); std::unique_ptr<test::TestCluster> cluster; @@ -216,7 +217,8 @@ TEST_F(GrpcSessionDebugTest, FileDebugURL) { TF_CHECK_OK(session->Close()); } -void SetDevice(GraphDef* graph, const string& name, const string& dev) { +void SetDevice(GraphDef* graph, const std::string& name, + const std::string& dev) { for (size_t i = 0; i < graph->node_size(); ++i) { if (graph->node(i).name() == name) { graph->mutable_node(i)->set_device(dev); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 476ab423154c88..13d130d289418c 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -107,7 +107,7 @@ GraphMgr::Item::~Item() { // NOTE: node->device_name() is not set by GraphConstructor. We // expects that NodeDef in GraphDef given to workers fully specifies // device names. -static string SplitByDevice(const Node* node) { +static std::string SplitByDevice(const Node* node) { return node->assigned_device_name(); } @@ -144,7 +144,7 @@ absl::Status GraphMgr::DecorateAndPublishGraphForDebug( // // "executors" are filled with one executor per device if success and // the caller takes the ownership of returned executors. -absl::Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, +absl::Status GraphMgr::InitItem(const std::string& handle, const GraphDef& gdef, const GraphOptions& graph_options, const DebugOptions& debug_options, const ConfigProto& config_proto, @@ -187,14 +187,14 @@ absl::Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph)); // Splits "graph" into multiple subgraphs by device names.
- std::unordered_map<string, GraphDef> partitions; + std::unordered_map<std::string, GraphDef> partitions; PartitionOptions popts; popts.node_to_loc = SplitByDevice; - popts.new_name = [this](const string& prefix) { + popts.new_name = [this](const std::string& prefix) { mutex_lock l(mu_); return absl::StrCat(prefix, "_G", next_id_++); }; - popts.get_incarnation = [this](const string& name) -> int64 { + popts.get_incarnation = [this](const std::string& name) -> int64_t { Device* device = nullptr; absl::Status s = device_mgr_->LookupDevice(name, &device); if (s.ok()) { @@ -211,7 +211,7 @@ absl::Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions)); } - std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs; + std::unordered_map<std::string, std::unique_ptr<Graph>> partition_graphs; for (auto& partition : partitions) { std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global())); GraphConstructorOptions device_opts; @@ -236,7 +236,7 @@ absl::Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, const auto& optimizer_opts = graph_options.optimizer_options(); GraphOptimizer optimizer(optimizer_opts); for (auto& p : partition_graphs) { - const string& device_name = p.first; + const std::string& device_name = p.first; std::unique_ptr<Graph>& subgraph = p.second; item->units.resize(item->units.size() + 1); ExecutionUnit* unit = &(item->units.back()); @@ -316,14 +316,14 @@ absl::Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef, return absl::OkStatus(); } -absl::Status GraphMgr::Register(const string& handle, const GraphDef& gdef, +absl::Status GraphMgr::Register(const std::string& handle, const GraphDef& gdef, const GraphOptions& graph_options, const DebugOptions& debug_options, const ConfigProto& config_proto, int64_t collective_graph_key, WorkerSession* session, DistributedFunctionLibraryRuntime* cluster_flr, - string* graph_handle) { + std::string* graph_handle) { Item* item = new Item; absl::Status s = InitItem(handle, gdef, graph_options, debug_options, config_proto, @@ -344,7 +344,7 @@ absl::Status GraphMgr::Register(const string& handle, const GraphDef& gdef, return absl::OkStatus(); } -absl::Status GraphMgr::Deregister(const string& handle) { +absl::Status GraphMgr::Deregister(const std::string& handle) { Item* item = nullptr; // Removes one item from table_.
{ @@ -380,7 +380,7 @@ absl::Status GraphMgr::DeregisterAll() { absl::Status GraphMgr::SendInputs(const int64_t step_id, const NamedTensors& in) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id).release(); - std::vector<string> keys; + std::vector<std::string> keys; std::vector<Tensor> tensors_to_send; keys.reserve(in.size()); tensors_to_send.reserve(in.size()); @@ -419,7 +419,7 @@ absl::Status GraphMgr::RecvOutputs(const int64_t step_id, NamedTensors* out) { void GraphMgr::RecvOutputsAsync(const int64_t step_id, NamedTensors* out, StatusCallback done) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id).release(); - std::vector<string> keys; + std::vector<std::string> keys; std::vector<Tensor>* received_keys = new std::vector<Tensor>; keys.reserve(out->size()); received_keys->reserve(out->size()); @@ -443,13 +443,13 @@ void GraphMgr::RecvOutputsAsync(const int64_t step_id, NamedTensors* out, } void GraphMgr::ExecuteAsync( - const string& handle, const int64_t step_id, const ExecutorOpts& opts, + const std::string& handle, const int64_t step_id, const ExecutorOpts& opts, const NamedTensors& in, WorkerSession* session, StepStatsCollector* collector, MutableRunGraphResponseWrapper* response, CancellationManager* cancellation_manager, tsl::CoordinationServiceAgent* coordination_service_agent, StatusCallback done) { - const uint64 start_time_usecs = Env::Default()->NowMicros(); + const uint64_t start_time_usecs = Env::Default()->NowMicros(); tsl::profiler::TraceMeProducer activity( // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone. [step_id] { @@ -498,7 +498,7 @@ void GraphMgr::ExecuteAsync( // Sends values specified by the caller. size_t input_size = 0; if (s.ok()) { - std::vector<string> keys; + std::vector<std::string> keys; std::vector<Tensor> tensors_to_send; keys.reserve(in.size()); tensors_to_send.reserve(in.size()); @@ -543,17 +543,19 @@ } void GraphMgr::StartParallelExecutors( - const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous, - CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector, - CostGraphDef* cost_graph, CancellationManager* cancellation_manager, - WorkerSession* session, int64_t start_time_usecs, + const std::string& handle, int64_t step_id, Item* item, + Rendezvous* rendezvous, CollectiveExecutor::Handle* ce_handle, + StepStatsCollector* collector, CostGraphDef* cost_graph, + CancellationManager* cancellation_manager, WorkerSession* session, + int64_t start_time_usecs, tsl::CoordinationServiceAgent* coordination_service_agent, StatusCallback done) { const int num_units = item->units.size(); CHECK_GE(num_units, 1); - ScopedStepContainer* step_container = new ScopedStepContainer( - step_id, - [this](const string& name) { device_mgr_->ClearContainers({name}); }); + ScopedStepContainer* step_container = + new ScopedStepContainer(step_id, [this](const std::string& name) { + device_mgr_->ClearContainers({name}); + }); // NOTE: Transfer one ref of rendezvous and item.
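Aside: the ScopedStepContainer constructed above pairs a per-step resource-container name with a cleanup callback that is run when the step finishes. A rough sketch of that shape, under the assumption that the name is derived from the step id; the real class lives in core/framework/resource_mgr.h and does more:

#include <cstdint>
#include <functional>
#include <string>
#include <utility>

class StepContainerSketch {
 public:
  StepContainerSketch(int64_t step_id,
                      std::function<void(const std::string&)> cleanup)
      : name_("__per_step_" + std::to_string(step_id)),
        cleanup_(std::move(cleanup)) {}

  // On destruction the callback clears the container, mirroring the
  // device_mgr_->ClearContainers({name}) lambda in the diff above.
  ~StepContainerSketch() { cleanup_(name_); }

 private:
  const std::string name_;
  std::function<void(const std::string&)> cleanup_;
};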
ExecutorBarrier* barrier = new ExecutorBarrier(num_units, rendezvous, @@ -602,7 +604,7 @@ void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector, CostGraphDef* cost_graph) { if (collector && !skip_cost_models_) { // Build the cost model - std::unordered_map<string, Graph*> device_to_graph; + std::unordered_map<std::string, Graph*> device_to_graph; for (const auto& unit : item->units) { if (unit.build_cost_model > 0) { device_to_graph[unit.device->name()] = unit.graph.get(); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 5c8c7ce0f20c95..3458771a21e9b1 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -85,21 +85,21 @@ class GraphMgr { // Registers a graph. Fills in "handle". The registered graph retains a // reference to cluster_flr to do cross process function calls. - absl::Status Register(const string& handle, const GraphDef& gdef, + absl::Status Register(const std::string& handle, const GraphDef& gdef, const GraphOptions& graph_options, const DebugOptions& debug_options, const ConfigProto& config_proto, int64_t collective_graph_key, WorkerSession* session, DistributedFunctionLibraryRuntime* cluster_flr, - string* graph_handle); + std::string* graph_handle); // Executes one step of a registered graph "handle". // // If "out" is not nullptr, "out" specifies all keys the execution // should receive upon finish. - typedef std::map<string, Tensor> NamedTensors; + typedef std::map<std::string, Tensor> NamedTensors; typedef std::function<void(const absl::Status&)> StatusCallback; - void ExecuteAsync(const string& handle, const int64_t step_id, + void ExecuteAsync(const std::string& handle, const int64_t step_id, const ExecutorOpts& opts, const NamedTensors& in, WorkerSession* session, StepStatsCollector* collector, MutableRunGraphResponseWrapper* response, @@ -113,7 +113,7 @@ class GraphMgr { StatusCallback done); // Deregisters a graph. - absl::Status Deregister(const string& handle); + absl::Status Deregister(const std::string& handle); // Deregister all graphs. absl::Status DeregisterAll(); @@ -137,10 +137,10 @@ class GraphMgr { ~Item() override; // Session handle. - string session; + std::string session; // Graph handle. - string handle; + std::string handle; // Session configuration options for the graph. ConfigProto session_config; @@ -177,13 +177,14 @@ // TODO(zhifengc): If the client does not call Deregister, we'll // lose memory over time. We should implement a timeout-based // mechanism to gc these graphs.
- std::unordered_map<string, Item*> table_; + std::unordered_map<std::string, Item*> table_; void StartParallelExecutors( - const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous, - CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector, - CostGraphDef* cost_graph, CancellationManager* cancellation_manager, - WorkerSession* session, int64_t start_time_usecs, + const std::string& handle, int64_t step_id, Item* item, + Rendezvous* rendezvous, CollectiveExecutor::Handle* ce_handle, + StepStatsCollector* collector, CostGraphDef* cost_graph, + CancellationManager* cancellation_manager, WorkerSession* session, + int64_t start_time_usecs, tsl::CoordinationServiceAgent* coordination_service_agent, StatusCallback done); @@ -194,7 +195,7 @@ class GraphMgr { void BuildCostModel(Item* item, StepStatsCollector* collector, CostGraphDef* cost_graph); - absl::Status InitItem(const string& handle, const GraphDef& gdef, + absl::Status InitItem(const std::string& handle, const GraphDef& gdef, const GraphOptions& graph_options, const DebugOptions& debug_options, const ConfigProto& config_proto, diff --git a/tensorflow/core/distributed_runtime/scheduler.cc b/tensorflow/core/distributed_runtime/scheduler.cc index 95aed8f498efc8..5935465711a02e 100644 --- a/tensorflow/core/distributed_runtime/scheduler.cc +++ b/tensorflow/core/distributed_runtime/scheduler.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/scheduler.h" +#include <limits> #include <queue> #include "tensorflow/core/common_runtime/device.h" @@ -280,7 +281,7 @@ Microseconds GreedyScheduler::ComputeSchedule( const Node* GreedyScheduler::GetNodeWithHighestPriority( const std::vector<const Node*>& nodes) { const Node* curr_node = nullptr; - int64_t curr_priority = kint64max; + int64_t curr_priority = std::numeric_limits<int64_t>::max(); for (const Node* n : nodes) { if ((*priority_)[n->id()] < curr_priority) { curr_node = n; diff --git a/tensorflow/core/example/example_parser_configuration.cc b/tensorflow/core/example/example_parser_configuration.cc index 085d215656c978..7f3cbcd2a49936 100644 --- a/tensorflow/core/example/example_parser_configuration.cc +++ b/tensorflow/core/example/example_parser_configuration.cc @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { absl::Status FindNodeIndexByName(const tensorflow::GraphDef& graph, - const string& node_name, int* node_idx) { + const std::string& node_name, int* node_idx) { for (int i = 0; i < graph.node_size(); ++i) { const auto& node = graph.node(i); if (node.name() == node_name) { @@ -42,7 +42,7 @@ absl::Status FindNodeIndexByName(const tensorflow::GraphDef& graph, } absl::Status ExtractExampleParserConfiguration( - const tensorflow::GraphDef& graph, const string& node_name, + const tensorflow::GraphDef& graph, const std::string& node_name, tensorflow::Session* session, std::vector<FixedLenFeature>* fixed_len_features, std::vector<VarLenFeature>* var_len_features) { @@ -95,7 +95,7 @@ absl::Status ExtractExampleParserConfiguration( // We must fetch the configuration input tensors to the ParseExample op. // Skipping index = 0, which is the serialized proto input.
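Aside: the scheduler hunk above replaces the kint64max macro with the standard-library sentinel; the surrounding min-search idiom is unchanged. For illustration, the same idiom in isolation:

#include <cstdint>
#include <limits>
#include <vector>

// Returns the index of the smallest priority, or -1 for an empty list.
// Starting from numeric_limits<int64_t>::max() guarantees the first real
// element always wins the comparison.
int IndexOfLowestPriority(const std::vector<int64_t>& priorities) {
  int best = -1;
  int64_t best_priority = std::numeric_limits<int64_t>::max();
  for (int i = 0; i < static_cast<int>(priorities.size()); ++i) {
    if (priorities[i] < best_priority) {
      best = i;
      best_priority = priorities[i];
    }
  }
  return best;
}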
- std::vector<string> fetch_names(node.input_size() - 1); + std::vector<std::string> fetch_names(node.input_size() - 1); for (int i = 1; i < node.input_size(); ++i) { fetch_names[i - 1] = node.input(i); } @@ -134,7 +134,7 @@ absl::Status ExtractExampleParserConfiguration( int sparse_shapes_output_start = sparse_values_output_start + num_sparse; int dense_values_output_start = sparse_shapes_output_start + num_sparse; - string node_output_prefix = absl::StrCat(node_name, ":"); + std::string node_output_prefix = absl::StrCat(node_name, ":"); for (int i = 0; i < num_sparse; ++i) { VarLenFeature& config = (*var_len_features)[i]; @@ -166,7 +166,7 @@ absl::Status ExampleParserConfigurationProtoToFeatureVectors( std::vector<VarLenFeature>* var_len_features) { const auto& feature_map = config_proto.feature_map(); for (auto it = feature_map.cbegin(); it != feature_map.cend(); ++it) { - string key = it->first; + std::string key = it->first; const auto& config = it->second; if (config.has_fixed_len_feature()) { const auto& fixed_config = config.fixed_len_feature(); diff --git a/tensorflow/core/example/example_parser_configuration.h b/tensorflow/core/example/example_parser_configuration.h index dd2aacaee2c078..b202b035da16c5 100644 --- a/tensorflow/core/example/example_parser_configuration.h +++ b/tensorflow/core/example/example_parser_configuration.h @@ -38,7 +38,7 @@ namespace tensorflow { // Given a graph and the node_name of a ParseExample op, // extract the FixedLenFeature/VarLenFeature configurations. absl::Status ExtractExampleParserConfiguration( - const tensorflow::GraphDef& graph, const string& node_name, + const tensorflow::GraphDef& graph, const std::string& node_name, tensorflow::Session* session, std::vector<FixedLenFeature>* fixed_len_features, std::vector<VarLenFeature>* var_len_features); diff --git a/tensorflow/core/example/example_parser_configuration_test.cc b/tensorflow/core/example/example_parser_configuration_test.cc index 8abbd705cbcbe7..d83984d3373139 100644 --- a/tensorflow/core/example/example_parser_configuration_test.cc +++ b/tensorflow/core/example/example_parser_configuration_test.cc @@ -29,7 +29,8 @@ limitations under the License.
namespace tensorflow { namespace { -void ReadFileToStringOrDie(Env* env, const string& filename, string* output) { +void ReadFileToStringOrDie(Env* env, const std::string& filename, + std::string* output) { TF_CHECK_OK(ReadFileToString(env, filename, output)); } @@ -42,8 +43,8 @@ std::unique_ptr<Session> CreateSession() { class ExtractExampleParserConfigurationTest : public ::testing::Test { protected: void SetUp() override { - string proto_string; - string filename = + std::string proto_string; + std::string filename = io::JoinPath(testing::TensorFlowSrcRoot(), "core/example/testdata/parse_example_graph_def.pbtxt"); ReadFileToStringOrDie(Env::Default(), filename, &proto_string); diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc index 374bbe6093b717..8192c7b9ffa420 100644 --- a/tensorflow/core/example/feature_util_test.cc +++ b/tensorflow/core/example/feature_util_test.cc @@ -455,9 +455,9 @@ TEST(AppendFeatureValuesTest, StringValuesUsingInitializerList) { TEST(AppendFeatureValuesTest, StringVariablesUsingInitializerList) { Example example; - string string1("FOO"); - string string2("BAR"); - string string3("BAZ"); + std::string string1("FOO"); + std::string string2("BAR"); + std::string string3("BAZ"); AppendFeatureValues({string1, string2, string3}, "tag", &example); diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index e7aa3a0bf21c17..d8d38eb58e9ae2 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -96,7 +96,7 @@ constexpr int kMaxTensorNestDepth = 100; // to serialize, compute hash based on TensorProto string representation. // This approach may result different hash codes with identical Tensors if they // are defined with different TensorProto representations. -uint64 TensorProtoHash(const TensorProto& tp) { +uint64_t TensorProtoHash(const TensorProto& tp) { Tensor tensor(tp.dtype()); bool success = tensor.FromProto(tp); if (success) { @@ -112,7 +112,7 @@ // string representation. Tensors with identical content potentially can have a // different hash code if they are defined with different TensorProto // representations.
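Aside: the comment above describes a size dispatch in the function that follows: tensors too large to deserialize cheaply are hashed from their serialized proto (fast, but representation-sensitive), small ones from materialized content (representation-agnostic). A sketch of that dispatch; the 32mb threshold matches the warning in attr_value_util.h, and both hash callbacks are illustrative stand-ins:

#include <cstdint>
#include <functional>

uint64_t FastHashSketch(int64_t tensor_byte_size,
                        const std::function<uint64_t()>& content_hash,
                        const std::function<uint64_t()>& serialized_hash) {
  constexpr int64_t kMaxBytes = 32LL * 1024 * 1024;
  // Large tensors: hash the serialized proto; two equal tensors encoded
  // differently (e.g. tensor_content vs float_val) may hash differently.
  if (tensor_byte_size > kMaxBytes) return serialized_hash();
  // Small tensors: hash the materialized content.
  return content_hash();
}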
-uint64 FastTensorProtoHash(const TensorProto& tp) { +uint64_t FastTensorProtoHash(const TensorProto& tp) { if (attr_value_util_internal::TensorByteSize(tp) > kMaxAttrValueTensorByteSize) { return DeterministicProtoHash64(tp); @@ -180,15 +180,17 @@ bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs, return AreSerializedProtosEqual(lhs_tp, rhs_tp); } -using TensorProtoHasher = std::function<uint64(const TensorProto&)>; +using TensorProtoHasher = std::function<uint64_t(const TensorProto&)>; -uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) { +uint64_t AttrValueHash(const AttrValue& a, + const TensorProtoHasher& tensor_hash) { if (a.has_tensor()) return tensor_hash(a.tensor()); if (a.has_func()) { const NameAttrList& func = a.func(); - uint64 h = Hash64(func.name()); - std::map<string, AttrValue> map(func.attr().begin(), func.attr().end()); + uint64_t h = Hash64(func.name()); + std::map<std::string, AttrValue> map(func.attr().begin(), + func.attr().end()); for (const auto& pair : map) { h = Hash64(pair.first.data(), pair.first.size(), h); h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h); @@ -200,8 +202,8 @@ uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) { return DeterministicProtoHash64(a); } -string SummarizeString(const string& str) { - string escaped = absl::CEscape(str); +std::string SummarizeString(const std::string& str) { + std::string escaped = absl::CEscape(str); // If the string is long, replace the middle with ellipses. constexpr int kMaxStringSummarySize = 80; @@ -216,7 +218,7 @@ string SummarizeString(const string& str) { } } -string SummarizeTensor(const TensorProto& tensor_proto) { +std::string SummarizeTensor(const TensorProto& tensor_proto) { Tensor t; int64_t tensor_byte_size = attr_value_util_internal::TensorByteSize(tensor_proto); @@ -233,8 +235,8 @@ return t.DebugString(); } -string SummarizeFunc(const NameAttrList& func) { - std::vector<string> entries; +std::string SummarizeFunc(const NameAttrList& func) { + std::vector<std::string> entries; for (const auto& p : func.attr()) { entries.push_back(absl::StrCat(p.first, "=", SummarizeAttrValue(p.second))); } @@ -242,7 +244,8 @@ return absl::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]"); } -bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit, string to_parse) { +bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit, + std::string to_parse) { int nests = 0; int maxed_out = to_parse.length(); int open_curly = to_parse.find('{'); @@ -292,7 +295,7 @@ bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit, string to_parse) { } // namespace -string SummarizeAttrValue(const AttrValue& attr_value) { +std::string SummarizeAttrValue(const AttrValue& attr_value) { switch (attr_value.value_case()) { case AttrValue::kS: return SummarizeString(attr_value.s()); @@ -309,7 +312,7 @@ case AttrValue::kTensor: return SummarizeTensor(attr_value.tensor()); case AttrValue::kList: { - std::vector<string> pieces; + std::vector<std::string> pieces; if (attr_value.list().s_size() > 0) { for (int i = 0; i < attr_value.list().s_size(); ++i) { pieces.push_back(SummarizeString(attr_value.list().s(i))); @@ -472,7 +475,7 @@ absl::Status AttrValueHasType(const AttrValue& attr_value, bool ParseAttrValue(absl::string_view type, absl::string_view text, AttrValue* out) { // Parse type.
- string field_name; + std::string field_name; bool is_list = absl::ConsumePrefix(&type, "list("); if (absl::ConsumePrefix(&type, "string")) { field_name = "s"; @@ -500,7 +503,7 @@ bool ParseAttrValue(absl::string_view type, absl::string_view text, } // Construct a valid text proto message to parse. - string to_parse; + std::string to_parse; if (is_list) { // TextFormat parser considers "i: 7" to be the same as "i: [7]", // but we only want to allow list values with []. @@ -550,8 +553,8 @@ void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; } DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD) -DEFINE_SET_ATTR_VALUE_ONE(const string&, s) -DEFINE_SET_ATTR_VALUE_LIST(absl::Span<const string>, s) +DEFINE_SET_ATTR_VALUE_ONE(const std::string&, s) +DEFINE_SET_ATTR_VALUE_LIST(absl::Span<const std::string>, s) DEFINE_SET_ATTR_VALUE_BOTH(const char*, s) DEFINE_SET_ATTR_VALUE_BOTH(int64_t, i) DEFINE_SET_ATTR_VALUE_BOTH(int32_t, i) @@ -585,7 +588,7 @@ void SetAttrValue(const absl::Span<const NameAttrList> value, } } -void MoveAttrValue(std::vector<string>&& value, AttrValue* out) { +void MoveAttrValue(std::vector<std::string>&& value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (auto& v : value) { out->mutable_list()->add_s(std::move(v)); @@ -689,8 +692,8 @@ bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b, const NameAttrList& af = a.func(); const NameAttrList& bf = b.func(); if (af.name() != bf.name()) return false; - std::unordered_map<string, AttrValue> am(af.attr().begin(), - af.attr().end()); + std::unordered_map<std::string, AttrValue> am(af.attr().begin(), + af.attr().end()); for (const auto& bm_pair : bf.attr()) { const auto& iter = am.find(bm_pair.first); if (iter == am.end()) return false; @@ -708,11 +711,11 @@ return AreSerializedProtosEqual(a, b); } -uint64 AttrValueHash(const AttrValue& a) { +uint64_t AttrValueHash(const AttrValue& a) { return AttrValueHash(a, TensorProtoHash); } -uint64 FastAttrValueHash(const AttrValue& a) { +uint64_t FastAttrValueHash(const AttrValue& a) { return AttrValueHash(a, FastTensorProtoHash); } diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h index b6f7c972c71624..135bfe67231f37 100644 --- a/tensorflow/core/framework/attr_value_util.h +++ b/tensorflow/core/framework/attr_value_util.h @@ -76,12 +76,12 @@ void SetAttrValue(const Tensor& value, AttrValue* out); void SetAttrValue(const TensorProto& value, AttrValue* out); void SetAttrValue(const NameAttrList& value, AttrValue* out); -void SetAttrValue(absl::Span<const string> value, AttrValue* out); +void SetAttrValue(absl::Span<const std::string> value, AttrValue* out); void SetAttrValue(absl::Span<const tstring> value, AttrValue* out); void SetAttrValue(absl::Span<const char* const> value, AttrValue* out); void SetAttrValue(absl::Span<const StringPiece> value, AttrValue* out); void SetAttrValue(absl::Span<const int64_t> value, AttrValue* out); -void SetAttrValue(absl::Span<const int32> value, AttrValue* out); +void SetAttrValue(absl::Span<const int32_t> value, AttrValue* out); void SetAttrValue(absl::Span<const float> value, AttrValue* out); void SetAttrValue(absl::Span<const double> value, AttrValue* out); void SetAttrValue(absl::Span<const bool> value, AttrValue* out); @@ -97,7 +97,7 @@ void SetAttrValue(absl::Span<const NameAttrList> value, AttrValue* out); void SetAttrValue(const AttrValue& value, AttrValue* out); -void MoveAttrValue(std::vector<string>&& value, AttrValue* out); +void MoveAttrValue(std::vector<std::string>&& value, AttrValue* out); // Returns a hash of `a` that is consistent with AreAttrValuesEqual.
In other // words, if two AttrValues compare equal according to AreAttrValuesEqual, @@ -105,7 +105,7 @@ void MoveAttrValue(std::vector<string>&& value, AttrValue* out); // Similarly to protobuf deterministic serialization, hash value is // guaranteed to be stable only for a given binary. In particular, one should // probably not persist the returned value. -uint64 AttrValueHash(const AttrValue& a); +uint64_t AttrValueHash(const AttrValue& a); // WARNING: Equality check might return false-negative for large (> 32mb) // tensors defined with different TensorProto representations. @@ -117,7 +117,7 @@ uint64 AttrValueHash(const AttrValue& a); // bool_val), they will have different hash code and equals will return false. // Small (less than 32mb) tensors with different TensorProto representations // hashed/compared by their tensor content. -uint64 FastAttrValueHash(const AttrValue& a); +uint64_t FastAttrValueHash(const AttrValue& a); // Returns true if a and b have the same value. If false negatives are allowed, // then compares proto representation to avoid construction of large (> 32mb) // tensors. @@ -134,7 +134,7 @@ bool HasPlaceHolder(const AttrValue& val); // SubstituteFunc is given a placeholder string. If the placeholder is // unknown, SubstituteFunc returns false. Otherwise, overwrites the // attr value and returns true. -using SubstituteFunc = std::function<bool(const string&, AttrValue*)>; +using SubstituteFunc = std::function<bool(const std::string&, AttrValue*)>; bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value); } // namespace tensorflow diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc index 4e8daeb8f04dde..d6d685ef4c49f0 100644 --- a/tensorflow/core/framework/attr_value_util_test.cc +++ b/tensorflow/core/framework/attr_value_util_test.cc @@ -36,14 +36,14 @@ AttrValue V(T value) { return ret; } -AttrValue P(const string& p) { +AttrValue P(const std::string& p) { AttrValue ret; ret.set_placeholder(p); return ret; } -AttrValue F(const string& name, - std::vector<std::pair<string, AttrValue>> pairs) { +AttrValue F(const std::string& name, + std::vector<std::pair<std::string, AttrValue>> pairs) { AttrValue ret; ret.mutable_func()->set_name(name); ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end()); @@ -51,7 +51,8 @@ AttrValue F(const string& name, } AttrValue Fs( - std::vector<std::pair<string, std::vector<std::pair<string, AttrValue>>>> funcs) { + std::vector< + std::pair<std::string, std::vector<std::pair<std::string, AttrValue>>>> funcs) { AttrValue ret; for (const auto& func : funcs) { @@ -82,7 +83,7 @@ TEST(AttrValueUtil, HasType) { } SubstituteFunc ReplaceTWith(const AttrValue& val) { - return [val](const string& placeholder, AttrValue* target) { + return [val](const std::string& placeholder, AttrValue* target) { if (placeholder == "T") { *target = val; return true; @@ -142,14 +143,14 @@ TEST(AttrValueUtil, DeepAttr) { TEST(AttrValueUtil, SummarizeAttrValueDoesNotElideShortStrings) { AttrValue attr_value; - SetAttrValue(string(40, '-'), &attr_value); - EXPECT_EQ(absl::StrCat("\"", string(40, '-'), "\""), + SetAttrValue(std::string(40, '-'), &attr_value); + EXPECT_EQ(absl::StrCat("\"", std::string(40, '-'), "\""), SummarizeAttrValue(attr_value)); } TEST(AttrValueUtil, SummarizeAttrValueElidesLongStrings) { AttrValue attr_value; - SetAttrValue(string(80, '-'), &attr_value); + SetAttrValue(std::string(80, '-'), &attr_value); EXPECT_EQ("\"----------...----------\"", SummarizeAttrValue(attr_value)); } @@ -197,7 +198,7 @@ TEST(AttrValueUtil, TensorByteSizeShouldNotOverflow) { } } -AttrValue FromText(const string& text) { +AttrValue FromText(const std::string& text) { AttrValue attr;
EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &attr)); return attr; } diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index b2f4b856ec9feb..e4456d888df736 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -32,11 +32,11 @@ struct RegistrationInfo { // what is effectively a static instance of the collective implementation. // During param resolution of collective ops we return this static instance. // The actual op execution gets a fresh instance using `factory`. - RegistrationInfo(const string& n, CollectiveRegistry::Factory f) + RegistrationInfo(const std::string& n, CollectiveRegistry::Factory f) : name(n), factory(std::move(f)), param_resolver_instance(this->factory()) {} - string name; + std::string name; CollectiveRegistry::Factory factory; CollectiveImplementationInterface* param_resolver_instance; }; @@ -48,13 +48,13 @@ std::vector<RegistrationInfo>* MutableCollectiveRegistry() { } } // namespace -string CollGroupRuntimeDetails::ToString() const { +std::string CollGroupRuntimeDetails::ToString() const { return absl::StrCat("CollGroupRuntimeDetails {communicator_key=", absl::CEscape(communicator_key), "}"); } -string CollGroupParams::ToString() const { - string v = strings::StrCat( +std::string CollGroupParams::ToString() const { + std::string v = strings::StrCat( "CollGroupParams {group_key=", group_key, " group_size=", group_size, " device_type=", device_type.type_string(), " num_tasks=", num_tasks, " runtime_details=", runtime_details.ToString(), " devices {"); @@ -94,8 +94,8 @@ CollInstanceParams& CollInstanceParams::operator=( return *this; } -string CollInstanceParams::ToString() const { - string v = +std::string CollInstanceParams::ToString() const { + std::string v = strings::StrCat("CollInstanceParams { instance_key=", instance_key, " type=", type, " data_type=", DataTypeString(data_type), " shape=", shape.DebugString(), " devices {"); @@ -134,8 +134,9 @@ string CollInstanceParams::ToString() const { return v; } -string CollectiveParams::ToString() const { - string v = absl::StrCat("CollectiveParams ", name, " {", group.ToString()); +std::string CollectiveParams::ToString() const { + std::string v = + absl::StrCat("CollectiveParams ", name, " {", group.ToString()); absl::StrAppend(&v, " ", instance.ToString()); strings::StrAppend(&v, " default_rank=", default_rank, " is_source=", is_source, " source_rank=", source_rank, @@ -156,7 +157,7 @@ CollectiveContext::CollectiveContext( CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator, const DeviceMgr* dev_mgr, OpKernelContext* ctx, OpKernelContext::Params* op_params, const CollectiveParams* col_params, - const string& exec_key, int64_t step_id, const Tensor* input, + const std::string& exec_key, int64_t step_id, const Tensor* input, Tensor* output) : col_exec(col_exec), nccl_communicator(nccl_communicator), @@ -177,14 +178,14 @@ int64_t CollectiveExecutor::kInvalidId = -1; /*static*/ absl::Status CollectiveRegistry::Lookup( - const string& collective_name, + const std::string& collective_name, CollectiveImplementationInterface** implementation) { return LookupHelper(collective_name, implementation, false); } /*static*/ absl::Status CollectiveRegistry::LookupParamResolverInstance( - const string& collective_name, + const std::string& collective_name, CollectiveImplementationInterface** implementation) { return LookupHelper(collective_name, implementation, true); } @@ -198,7 +199,7 @@ void CollectiveRegistry::GetAll(
/*static*/ -absl::Status CollectiveRegistry::Register(const string& collective_name, +absl::Status CollectiveRegistry::Register(const std::string& collective_name, Factory factory) { std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); for (const RegistrationInfo& reg_info : *registry) { @@ -212,7 +213,7 @@ absl::Status CollectiveRegistry::Register(const string& collective_name, /*static*/ absl::Status CollectiveRegistry::LookupHelper( - const string& collective_name, + const std::string& collective_name, CollectiveImplementationInterface** implementation, bool param_resolver) { std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry(); for (const RegistrationInfo& reg_info : *registry) { diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index 8fca00f0e3b515..cdb22129e813d4 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -56,16 +56,16 @@ enum CollectiveType { // the OpKernel. Currently, this struct is used to set communicator key for // NCCL-based collective implementation. struct CollGroupRuntimeDetails { - string communicator_key; // for communicator-based techniques e.g. NCCL - string ToString() const; + std::string communicator_key; // for communicator-based techniques e.g. NCCL + std::string ToString() const; }; struct CollGroupMember { DeviceAttributes device; - string task; + std::string task; bool is_local; // User provided rank - int32 rank = -1; + int32_t rank = -1; }; // Data common to all members of a device group. @@ -73,8 +73,8 @@ struct CollGroupMember { // particular to an instance so it is stored there. struct CollGroupParams { // Inputs from Collective ops: - int32 group_key; - int32 group_size; + int32_t group_key; + int32_t group_size; DeviceType device_type; int user_specified_rank = -1; // rank provided by the user. // Generated from Collective Group Resolver: @@ -83,10 +83,10 @@ struct CollGroupParams { // True if every task has the same number of devices. bool same_num_devices_per_task = false; // Task -> number of devices on that task. - std::unordered_map<string, int32> num_devices_per_task; - int32 num_tasks; // number of distinct tasks in group + std::unordered_map<std::string, int32_t> num_devices_per_task; + int32_t num_tasks; // number of distinct tasks in group CollGroupRuntimeDetails runtime_details; - string ToString() const; + std::string ToString() const; CollGroupParams() : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {} }; @@ -99,7 +99,7 @@ struct CollGroupParams { // interpretation. On first execution the runtime will update this // structure with decisions that will guide all subsequent executions. struct CollImplDetails { - string collective_name; + std::string collective_name; std::vector<std::vector<int>> subdiv_permutations; // subdiv_offsets and max_subdivs_per_device are used together as follows: // When subdiv_offsets is provided (non-empty) it is used as is. When @@ -110,10 +110,10 @@ struct CollImplDetails { int max_subdivs_per_device = -1; // Upper bound on subdivisions per device. std::vector<int32> subdiv_offsets; std::vector<int32> subdiv_source_rank; // rank of source in each subdiv - std::vector<int32> - dependencies; // collective instances on which this node depends - string communication_hint; // user-supplied hint for implementation choice, - // e.g. ring or nccl + std::vector<int32> + dependencies; // collective instances on which this node depends + std::string communication_hint; // user-supplied hint for implementation // choice, e.g.
ring or nccl float timeout_seconds; // If non zero, set a completion timeout for the // collective op to detect staleness. }; @@ -122,16 +122,16 @@ struct CollImplDetails { // TODO(b/163171014) Refactor this struct to not be a union of all fields. struct CollInstanceParams { // Identifies all participating graph nodes. - int32 instance_key = -1; + int32_t instance_key = -1; // The full identifier includes both instance_key and step_id. int64_t step_id = 0; CollectiveType type = UNDEFINED_COLLECTIVE; DataType data_type = DT_FLOAT; TensorShape shape = {0}; CollImplDetails impl_details; - string ToString() const; + std::string ToString() const; CollInstanceParams& operator=(const struct CollInstanceParams& other); - std::vector devices; // permuter only + std::vector devices; // permuter only // For permuter only // Each rank in the permutation is a receiver. @@ -148,7 +148,7 @@ struct CollectiveParams : public core::RefCounted { CollGroupParams group; CollInstanceParams instance; - string name = ""; // node name used only for log or error messages + std::string name = ""; // node name used only for log or error messages int default_rank = -1; // index of this op within device_names bool is_source = false; // broadcast only int source_rank = -1; // broadcast only @@ -156,7 +156,7 @@ struct CollectiveParams : public core::RefCounted { std::vector subdiv_rank; OpKernel* merge_op = nullptr; // reduction only OpKernel* final_op = nullptr; // reduction only - string ToString() const; + std::string ToString() const; bool run_group_initialization = true; bool is_stateless = false; }; @@ -169,12 +169,12 @@ class DeviceResolverInterface { virtual ~DeviceResolverInterface() {} // Populates *attributes with the DeviceAttributes of the specified device. - virtual absl::Status GetDeviceAttributes(const string& device, + virtual absl::Status GetDeviceAttributes(const std::string& device, DeviceAttributes* attributes) = 0; // Returns all device attributes of a task. virtual absl::Status GetAllDeviceAttributes( - const string& task, std::vector* attributes) = 0; + const std::string& task, std::vector* attributes) = 0; // Updates device attributes. It returns error if any device already // exists in the DeviceResolver and has a different incarnation. 
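A note on the pattern running through the collective.h hunks above (and through the rest of this change): it swaps TensorFlow's legacy unqualified spellings for the standard ones. Assuming the legacy names are plain aliases, as declared in tensorflow/core/platform/types.h, the rewrite is purely textual with no behavioral or ABI effect. The snippet below is a self-contained illustration under that assumption; the `tf_sketch` namespace is a hypothetical mirror of those aliases, not the real header.

```cpp
// Self-contained illustration: if the legacy spellings are aliases, then
// rewriting `string` -> `std::string` and `int32` -> `int32_t` is a no-op.
#include <cstdint>
#include <string>
#include <type_traits>

namespace tf_sketch {
using string = std::string;    // assumed to match tensorflow::string
using int32 = std::int32_t;    // assumed to match tensorflow::int32
using uint64 = std::uint64_t;  // assumed to match tensorflow::uint64
}  // namespace tf_sketch

static_assert(std::is_same_v<tf_sketch::string, std::string>, "no-op rename");
static_assert(std::is_same_v<tf_sketch::int32, std::int32_t>, "no-op rename");
static_assert(std::is_same_v<tf_sketch::uint64, std::uint64_t>, "no-op rename");

int main() { return 0; }
```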
@@ -284,19 +284,17 @@ class CollectiveRemoteAccess { public: virtual ~CollectiveRemoteAccess() {} - virtual void RecvFromPeer(const string& peer_device, const string& peer_task, - bool peer_is_local, const string& key, - Device* to_device, DeviceContext* to_device_ctx, - const AllocatorAttributes& to_alloc_attr, - Tensor* to_tensor, - const DeviceLocality& client_locality, - int dev_to_dev_stream_index, - CancellationManager* cancellation_manager, - const StatusCallback& done) = 0; - - virtual void PostToPeer(const string& peer_device, const string& peer_task, - const string& key, Device* from_device, - DeviceContext* from_device_ctx, + virtual void RecvFromPeer( + const std::string& peer_device, const std::string& peer_task, + bool peer_is_local, const std::string& key, Device* to_device, + DeviceContext* to_device_ctx, const AllocatorAttributes& to_alloc_attr, + Tensor* to_tensor, const DeviceLocality& client_locality, + int dev_to_dev_stream_index, CancellationManager* cancellation_manager, + const StatusCallback& done) = 0; + + virtual void PostToPeer(const std::string& peer_device, + const std::string& peer_task, const std::string& key, + Device* from_device, DeviceContext* from_device_ctx, const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor, const DeviceLocality& client_locality, @@ -306,7 +304,8 @@ class CollectiveRemoteAccess { // Checks the health of a collective peer. It probes the peer to see if it is // alive. Note that if a peer has restarted, it's considered a different one, // so CheckPeerHealth fails. - virtual void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms, + virtual void CheckPeerHealth(const std::string& peer_task, + int64_t timeout_in_ms, const StatusCallback& done) = 0; virtual BufRendezvous* buf_rendezvous() = 0; @@ -322,7 +321,7 @@ class CollectiveExecutor : public core::RefCounted { virtual void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params, - const string& exec_key, StatusCallback done) { + const std::string& exec_key, StatusCallback done) { done(errors::Internal( "A collective Op has been called in a context in which " "a CollectiveExecutor has not been provided.")); @@ -404,27 +403,28 @@ struct CollectiveContext { OpKernelContext* op_ctx; // Not owned OpKernelContext::Params* op_params; // Not owned core::IntrusivePtr col_params; - const string exec_key; + const std::string exec_key; const int64_t step_id; const Tensor* input; // Not owned Tensor* output; // Not owned Device* device; // The device for which this instance labors - const string device_name; + const std::string device_name; DeviceLocality device_locality; CollectiveContext(CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator, const DeviceMgr* dev_mgr, OpKernelContext* ctx, OpKernelContext::Params* op_params, - const CollectiveParams* col_params, const string& exec_key, - int64_t step_id, const Tensor* input, Tensor* output); + const CollectiveParams* col_params, + const std::string& exec_key, int64_t step_id, + const Tensor* input, Tensor* output); }; class NcclCommunicatorInterface { public: virtual ~NcclCommunicatorInterface() = default; - virtual string GenerateCommunicatorKey() = 0; + virtual std::string GenerateCommunicatorKey() = 0; virtual void Enqueue(std::shared_ptr col_ctx, StatusCallback done) = 0; @@ -474,7 +474,7 @@ class CollectiveRegistry { // `collective_name`. If found, creates an instance of the implementation and // assign to `implementation`. 
static absl::Status Lookup( - const string& collective_name, + const std::string& collective_name, CollectiveImplementationInterface** implementation); // Looks up a previously registered CollectiveImplementation under @@ -482,7 +482,7 @@ class CollectiveRegistry { // implementation via `implementation`. This instance should only be used to // call InitializateCollectiveParams. static absl::Status LookupParamResolverInstance( - const string& collective_name, + const std::string& collective_name, CollectiveImplementationInterface** implementation); // Returns all registered collective implementations. @@ -496,10 +496,11 @@ class CollectiveRegistry { // the CollectiveImplementation. Also creates a static instance of the // implementation - this instance is used during param resolution and should // only be used to call InitializeCollectiveParams. - static absl::Status Register(const string& collective_name, Factory factory); + static absl::Status Register(const std::string& collective_name, + Factory factory); static absl::Status LookupHelper( - const string& collective_name, + const std::string& collective_name, CollectiveImplementationInterface** implementation, bool param_resolver); }; @@ -507,7 +508,7 @@ class CollectiveRegistry { // create a global static object. class CollectiveRegistration { public: - CollectiveRegistration(const string& collective_name, + CollectiveRegistration(const std::string& collective_name, CollectiveRegistry::Factory factory) { TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory)); } diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 0f495b17a69544..bcfd94424e59c7 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -164,10 +164,10 @@ absl::Status EinsumShape(shape_inference::InferenceContext* c) { // We assume that the equation has a valid format. Either (x),(y)->(z) // or (x)->(z), where each of (x), (y) and (z) are concatenation of zero or // more latin alphabets and contains at most one ellipsis ('...'). - string equation; + std::string equation; TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation)); - absl::InlinedVector input_labels; - string output_labels; + absl::InlinedVector input_labels; + std::string output_labels; TF_RETURN_IF_ERROR( ValidateEinsumEquation(equation, &input_labels, &output_labels)); @@ -391,7 +391,7 @@ absl::Status BiasAddShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape; // Fetch the data_format attribute, which may not exist. - string data_format; + std::string data_format; absl::Status s = c->GetAttr("data_format", &data_format); if (s.ok() && data_format == "NCHW") { @@ -449,7 +449,7 @@ absl::Status BiasAddShape(shape_inference::InferenceContext* c) { absl::Status BiasAddGradShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape; // Fetch the data_format attribute, which may not exist. - string data_format; + std::string data_format; absl::Status s = c->GetAttr("data_format", &data_format); if (s.ok() && data_format == "NCHW") { @@ -465,7 +465,7 @@ absl::Status BiasAddGradShape(shape_inference::InferenceContext* c) { absl::Status CheckFormatConstraintsOnShape( const TensorFormat tensor_format, const ShapeHandle shape_handle, - const string& tensor_name, shape_inference::InferenceContext* c) { + const std::string& tensor_name, shape_inference::InferenceContext* c) { if (tensor_format == FORMAT_NCHW_VECT_C) { // Check that the vect dim has size 4 or 32. 
const int num_dims = c->Rank(shape_handle); @@ -593,7 +593,7 @@ namespace { absl::Status Conv2DShapeImpl(shape_inference::InferenceContext* c, bool supports_explicit_padding) { - string data_format_str, filter_format_str; + std::string data_format_str, filter_format_str; if (!c->GetAttr("data_format", &data_format_str).ok()) { data_format_str = "NHWC"; } @@ -626,7 +626,7 @@ absl::Status Conv2DShapeImpl(shape_inference::InferenceContext* c, TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c)); - std::vector dilations; + std::vector dilations; TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); if (dilations.size() != 4) { @@ -635,7 +635,7 @@ absl::Status Conv2DShapeImpl(shape_inference::InferenceContext* c, dilations.size()); } - std::vector strides; + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C). @@ -808,7 +808,7 @@ absl::Status ConvShape(shape_inference::InferenceContext* c) { } // Default format is NHWC for 2D and NDHWC for 3D. - string data_format_str; + std::string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); bool channels_last_format; if (data_format_str == "CHANNELS_LAST") { @@ -827,7 +827,7 @@ absl::Status ConvShape(shape_inference::InferenceContext* c) { // Determine number of spatial dims. int spatial_dims = standard_input_rank - 2; - std::vector dilations; + std::vector dilations; TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); // Default case. if (dilations.empty()) { @@ -840,7 +840,7 @@ absl::Status ConvShape(shape_inference::InferenceContext* c) { " values, but got: ", dilations.size())); } - std::vector strides; + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != standard_input_rank) { return absl::InvalidArgumentError( @@ -1004,10 +1004,10 @@ absl::Status Conv3DShape(shape_inference::InferenceContext* c) { ShapeHandle filter_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape)); - string data_format; + std::string data_format; absl::Status s = c->GetAttr("data_format", &data_format); - std::vector dilations; + std::vector dilations; TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); if (dilations.size() != 5) { @@ -1016,7 +1016,7 @@ absl::Status Conv3DShape(shape_inference::InferenceContext* c) { dilations.size()); } - std::vector strides; + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 5) { return errors::InvalidArgument( @@ -1113,7 +1113,7 @@ absl::Status Conv3DShape(shape_inference::InferenceContext* c) { } absl::Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { - string data_format_str; + std::string data_format_str; if (!c->GetAttr("data_format", &data_format_str).ok()) { data_format_str = "NHWC"; } @@ -1188,7 +1188,7 @@ absl::Status Conv2DBackpropFilterWithBiasShape( shape_inference::InferenceContext* c) { ShapeHandle input_shape; // Fetch the data_format attribute, which may not exist. 
- string data_format; + std::string data_format; absl::Status s = c->GetAttr("data_format", &data_format); TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); @@ -1213,7 +1213,7 @@ absl::Status DepthwiseConv2DNativeShapeImpl( ShapeHandle filter_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); - std::vector strides; + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 4) { @@ -1223,7 +1223,7 @@ absl::Status DepthwiseConv2DNativeShapeImpl( strides.size()); } - std::vector dilations; + std::vector dilations; if (!c->GetAttr("dilations", &dilations).ok()) { dilations.resize(4, 1); } @@ -1235,7 +1235,7 @@ absl::Status DepthwiseConv2DNativeShapeImpl( dilations.size()); } - string data_format_str; + std::string data_format_str; absl::Status s = c->GetAttr("data_format", &data_format_str); TensorFormat data_format; if (!s.ok() || !FormatFromString(data_format_str, &data_format)) { @@ -1338,7 +1338,7 @@ absl::Status DepthwiseConv2DNativeShapeWithExplicitPadding( } absl::Status AvgPoolShape(shape_inference::InferenceContext* c) { - string data_format_str; + std::string data_format_str; TensorFormat data_format; absl::Status s = c->GetAttr("data_format", &data_format_str); if (s.ok()) { @@ -1354,7 +1354,7 @@ absl::Status AvgPoolShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); - std::vector strides; + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 4) { return errors::InvalidArgument( @@ -1362,7 +1362,7 @@ absl::Status AvgPoolShape(shape_inference::InferenceContext* c) { strides.size()); } - std::vector kernel_sizes; + std::vector kernel_sizes; TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); if (kernel_sizes.size() != 4) { return errors::InvalidArgument( @@ -1415,7 +1415,7 @@ absl::Status AvgPoolGradShape(shape_inference::InferenceContext* c) { } absl::Status FusedBatchNormShape(shape_inference::InferenceContext* c) { - string data_format_str; + std::string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); TensorFormat data_format; if (!FormatFromString(data_format_str, &data_format)) { @@ -1465,7 +1465,7 @@ absl::Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) { absl::Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(FusedBatchNormV3Shape(c)); - string data_format_str; + std::string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); TensorFormat data_format; if (!FormatFromString(data_format_str, &data_format)) { @@ -1488,7 +1488,7 @@ absl::Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { } absl::Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { - string data_format_str; + std::string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); TensorFormat data_format; if (!FormatFromString(data_format_str, &data_format)) { @@ -1537,7 +1537,7 @@ absl::Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { return absl::OkStatus(); } - string data_format_str; + std::string data_format_str; TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); TensorFormat data_format; if (!FormatFromString(data_format_str, &data_format)) { @@ -1565,19 +1565,20 @@ absl::Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { } absl::Status ReadDiagIndex(InferenceContext* c, 
const Tensor* diag_index_tensor, - int32* lower_diag_index, int32* upper_diag_index) { + int32_t* lower_diag_index, + int32_t* upper_diag_index) { // This function assumes that the shape of diag_index_tensor is fully defined. if (diag_index_tensor->dims() == 0) { - *lower_diag_index = diag_index_tensor->scalar()(); + *lower_diag_index = diag_index_tensor->scalar()(); *upper_diag_index = *lower_diag_index; } else { int32_t num_elements = diag_index_tensor->dim_size(0); if (num_elements == 1) { - *lower_diag_index = diag_index_tensor->vec()(0); + *lower_diag_index = diag_index_tensor->vec()(0); *upper_diag_index = *lower_diag_index; } else if (num_elements == 2) { - *lower_diag_index = diag_index_tensor->vec()(0); - *upper_diag_index = diag_index_tensor->vec()(1); + *lower_diag_index = diag_index_tensor->vec()(0); + *upper_diag_index = diag_index_tensor->vec()(1); } else { return errors::InvalidArgument( "diag_index must be a vector with one or two elements. It has ", @@ -1815,7 +1816,7 @@ absl::Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { absl::Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, bool supports_explicit_padding) { - string data_format_str; + std::string data_format_str; TensorFormat data_format; absl::Status s = c->GetAttr("data_format", &data_format_str); if (s.ok()) { @@ -1831,7 +1832,7 @@ absl::Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); - std::vector strides; + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 4) { return errors::InvalidArgument( @@ -1839,7 +1840,7 @@ absl::Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, strides.size()); } - std::vector kernel_sizes; + std::vector kernel_sizes; TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); if (kernel_sizes.size() != 4) { return errors::InvalidArgument( @@ -1924,7 +1925,7 @@ absl::Status MaxPoolShapeWithExplicitPadding( absl::Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { - string data_format_str; + std::string data_format_str; TensorFormat data_format; absl::Status s = c->GetAttr("data_format", &data_format_str); if (s.ok()) { @@ -1940,8 +1941,8 @@ absl::Status MaxPoolV2Shape(shape_inference::InferenceContext* c, TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); - std::vector kernel_sizes; - std::vector strides; + std::vector kernel_sizes; + std::vector strides; if (c->num_inputs() + 2 == num_inputs) { TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); @@ -1962,7 +1963,7 @@ absl::Status MaxPoolV2Shape(shape_inference::InferenceContext* c, return absl::OkStatus(); } kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); - auto kernel_sizes_vec = kernel_sizes_tensor->flat(); + auto kernel_sizes_vec = kernel_sizes_tensor->flat(); std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), kernel_sizes.begin()); @@ -1972,7 +1973,7 @@ absl::Status MaxPoolV2Shape(shape_inference::InferenceContext* c, return absl::OkStatus(); } strides.resize(strides_tensor->shape().num_elements()); - auto strides_vec = strides_tensor->flat(); + auto strides_vec = strides_tensor->flat(); std::copy_n(&strides_vec(0), strides.size(), strides.begin()); } @@ -2029,10 +2030,10 @@ absl::Status Pool3DShape(shape_inference::InferenceContext* c) { ShapeHandle input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); - string data_format; + 
std::string data_format; absl::Status s = c->GetAttr("data_format", &data_format); - std::vector strides; + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 5) { return errors::InvalidArgument( @@ -2041,7 +2042,7 @@ absl::Status Pool3DShape(shape_inference::InferenceContext* c) { strides.size()); } - std::vector kernel_sizes; + std::vector kernel_sizes; TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); if (kernel_sizes.size() != 5) { return errors::InvalidArgument( @@ -2181,8 +2182,8 @@ absl::Status ReductionShape(InferenceContext* c) { const int32_t input_rank = c->Rank(input); std::set true_indices; if (reduction_indices_t->dtype() == DataType::DT_INT32) { - TF_RETURN_IF_ERROR(ReductionShapeHelper(reduction_indices_t, - input_rank, &true_indices)); + TF_RETURN_IF_ERROR(ReductionShapeHelper( + reduction_indices_t, input_rank, &true_indices)); } else if (reduction_indices_t->dtype() == DataType::DT_INT64) { TF_RETURN_IF_ERROR(ReductionShapeHelper( reduction_indices_t, input_rank, &true_indices)); @@ -2247,13 +2248,13 @@ absl::Status ConcatShapeHelper(InferenceContext* c, int start_value_index, // shape. int64_t concat_dim; if (concat_dim_t->dtype() == DT_INT32) { - concat_dim = static_cast(concat_dim_t->flat()(0)); + concat_dim = static_cast(concat_dim_t->flat()(0)); } else { concat_dim = concat_dim_t->flat()(0); } // Minimum required number of dimensions. - const int64 min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1; + const int64_t min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1; ShapeHandle output_before; ShapeHandle output_after; @@ -2510,7 +2511,7 @@ absl::Status SliceShape(InferenceContext* c) { SliceHelper(c, begin_value, sizes_value, &dims)); } else { TF_RETURN_IF_ERROR( - SliceHelper(c, begin_value, sizes_value, &dims)); + SliceHelper(c, begin_value, sizes_value, &dims)); } c->set_output(0, c->MakeShape(dims)); return absl::OkStatus(); @@ -2749,7 +2750,7 @@ absl::Status SparseReduceShapeFn(InferenceContext* c) { const Tensor* axes_tensor = c->input_tensor(3); if (shape_tensor != nullptr && axes_tensor != nullptr) { auto shape_vec = shape_tensor->flat(); - auto axes_vec = axes_tensor->flat(); + auto axes_vec = axes_tensor->flat(); int64_t ndims = shape_vec.size(); absl::flat_hash_set axes; @@ -2797,7 +2798,7 @@ absl::Status QuantizedConv2DShape(InferenceContext* c) { } absl::Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { - std::vector fused_ops; + std::vector fused_ops; TF_RETURN_IF_ERROR(c->GetAttr("fused_ops", &fused_ops)); ShapeHandle unused, channel; bool fused_sum, fused_bias, fused_requantize; diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 9bc8a20208096f..a97915901cc027 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -220,7 +220,7 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) { TEST(CommonShapeFnsTest, Einsum_ShapeFn) { ShapeInferenceTestOp op("Einsum"); - auto set_equation = [&op](int n, string equation) { + auto set_equation = [&op](int n, std::string equation) { std::vector input_list; input_list.reserve(n); for (int i = 0; i < n; ++i) { @@ -629,8 +629,9 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { TEST(CommonShapeFnsTest, ConvTest) { ShapeInferenceTestOp op("Conv"); - auto set_op = [&op](const std::vector& strides, const string& padding, - string data_format, int batch_dims, int groups) { + auto set_op = [&op](const 
std::vector& strides, + const std::string& padding, std::string data_format, + int batch_dims, int groups) { TF_CHECK_OK(NodeDefBuilder("test", op.name) .Input("input", 0, DT_FLOAT) .Input("filter", 0, DT_FLOAT) @@ -715,9 +716,11 @@ TEST(CommonShapeFnsTest, ConvTest) { TEST(CommonShapeFnsTest, Conv2DFormatsTest) { ShapeInferenceTestOp op("Conv2D"); - auto set_op = [&op](const std::vector& strides, const string& padding, - const string& data_format, const string& filter_format, - const std::vector& explicit_paddings = {}) { + auto set_op = [&op](const std::vector& strides, + const std::string& padding, + const std::string& data_format, + const std::string& filter_format, + const std::vector& explicit_paddings = {}) { TF_CHECK_OK(NodeDefBuilder("test", op.name) .Input("input", 0, DT_FLOAT) .Input("filter", 0, DT_FLOAT) @@ -761,15 +764,17 @@ TEST(CommonShapeFnsTest, Conv2DFormatsTest) { INFER_OK(op, "[1,1,4,4,32];[32,1,2,1,32]", "[d0_0,1,3,2,d0_4]"); } -class Conv2DShapeTest : public ::testing::TestWithParam {}; +class Conv2DShapeTest : public ::testing::TestWithParam {}; TEST_P(Conv2DShapeTest, Conv2DShapeTest) { - const string op_name = GetParam(); + const std::string op_name = GetParam(); ShapeInferenceTestOp op(op_name); - auto set_op = [&op](const std::vector& strides, const string& padding, - const string& data_format, const string& filter_format, - const std::vector& explicit_paddings = {}) { - string format; + auto set_op = [&op](const std::vector& strides, + const std::string& padding, + const std::string& data_format, + const std::string& filter_format, + const std::vector& explicit_paddings = {}) { + std::string format; if (op.name == "Conv") format = (data_format == "NHWC") ? "CHANNELS_LAST" : "CHANNELS_FIRST"; else @@ -974,13 +979,14 @@ TEST_P(Conv2DShapeTest, Conv2DShapeTest) { } TEST_P(Conv2DShapeTest, Conv2DDilatedShapeTest) { - const string op_name = GetParam(); + const std::string op_name = GetParam(); ShapeInferenceTestOp op(op_name); - auto set_op = [&op](const std::vector& dilations, - const std::vector& strides, const string& padding, - const string& data_format, - const std::vector& explicit_paddings = {}) { - string format; + auto set_op = [&op](const std::vector& dilations, + const std::vector& strides, + const std::string& padding, + const std::string& data_format, + const std::vector& explicit_paddings = {}) { + std::string format; if (op.name == "Conv") format = (data_format == "NHWC") ? 
"CHANNELS_LAST" : "CHANNELS_FIRST"; else @@ -1129,8 +1135,8 @@ TEST(CommonShapeFnsTest, Conv3DShapeRankTest) { TEST(CommonShapeFnsTest, Conv3DGroupsTest) { ShapeInferenceTestOp op("Conv3D"); - auto set_op = [&op](const std::vector& strides, - const string& padding) { + auto set_op = [&op](const std::vector& strides, + const std::string& padding) { TF_CHECK_OK(NodeDefBuilder("test", "Conv3D") .Input("input", 0, DT_FLOAT) .Input("filter", 0, DT_FLOAT) @@ -1166,13 +1172,13 @@ TEST(CommonShapeFnsTest, Conv3DGroupsTest) { INSTANTIATE_TEST_SUITE_P(CommonShapeFnsTest, Conv2DShapeTest, ::testing::Values("Conv2D", "Conv")); -class Conv3DShapeTest : public ::testing::TestWithParam {}; +class Conv3DShapeTest : public ::testing::TestWithParam {}; TEST_P(Conv3DShapeTest, Conv3DShapeTest) { - const string op_name = GetParam(); + const std::string op_name = GetParam(); ShapeInferenceTestOp op(op_name); - auto set_op = [&op](const std::vector& strides, - const string& padding) { + auto set_op = [&op](const std::vector& strides, + const std::string& padding) { TF_CHECK_OK(NodeDefBuilder("test", op.name) .Input("input", 0, DT_FLOAT) .Input("filter", 0, DT_FLOAT) @@ -1245,11 +1251,11 @@ TEST_P(Conv3DShapeTest, Conv3DShapeTest) { } TEST_P(Conv3DShapeTest, Conv3DDilatedShapeTest) { - const string op_name = GetParam(); + const std::string op_name = GetParam(); ShapeInferenceTestOp op(op_name); - auto set_op = [&op](const std::vector& dilations, - const std::vector& strides, - const string& padding) { + auto set_op = [&op](const std::vector& dilations, + const std::vector& strides, + const std::string& padding) { TF_CHECK_OK(NodeDefBuilder("test", op.name) .Input("input", 0, DT_FLOAT) .Input("filter", 0, DT_FLOAT) @@ -1300,7 +1306,7 @@ INSTANTIATE_TEST_SUITE_P(CommonShapeFnsTest, Conv3DShapeTest, TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) { ShapeInferenceTestOp op("DepthwiseConv2dNative"); - std::vector strides = {{1, 1, 1, 1}}; + std::vector strides = {{1, 1, 1, 1}}; TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative") .Input("input", 0, DT_FLOAT) .Input("filter", 0, DT_FLOAT) @@ -1344,9 +1350,10 @@ TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) { TEST(CommonShapeFnsTest, AvgPool2DShapeTest) { ShapeInferenceTestOp op("AvgPool"); - auto set_op = [&op](const std::vector& strides, - const std::vector& ksizes, const string& padding, - const string& data_format) { + auto set_op = [&op](const std::vector& strides, + const std::vector& ksizes, + const std::string& padding, + const std::string& data_format) { TF_CHECK_OK(NodeDefBuilder("test", "AvgPool") .Input("input", 0, DT_FLOAT) .Attr("strides", strides) @@ -1390,9 +1397,10 @@ TEST(CommonShapeFnsTest, AvgPool2DShapeTest) { TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { ShapeInferenceTestOp op("MaxPool"); - auto set_op = [&op](const std::vector& strides, - const std::vector& ksizes, const string& padding, - const string& data_format) { + auto set_op = [&op](const std::vector& strides, + const std::vector& ksizes, + const std::string& padding, + const std::string& data_format) { TF_CHECK_OK(NodeDefBuilder("test", "MaxPool") .Input("input", 0, DT_FLOAT) .Attr("strides", strides) @@ -1426,9 +1434,10 @@ TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) { ShapeInferenceTestOp op("MaxPoolV2"); Tensor ksizes_tensor, strides_tensor; auto set_op = [&op, &ksizes_tensor, &strides_tensor]( - const std::vector& strides, - const std::vector& ksizes, const string& padding, - const string& data_format) { + const std::vector& strides, + const std::vector& ksizes, + 
const std::string& padding, + const std::string& data_format) { TF_CHECK_OK(NodeDefBuilder("test", "MaxPoolV2") .Input("input", 0, DT_FLOAT) .Input("ksize", 1, DT_INT32) @@ -1436,11 +1445,11 @@ TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) { .Attr("padding", padding) .Attr("data_format", data_format) .Finalize(&op.node_def)); - ksizes_tensor = test::AsTensor(ksizes); + ksizes_tensor = test::AsTensor(ksizes); op.input_tensors.resize(3); op.input_tensors[0] = nullptr; op.input_tensors[1] = &ksizes_tensor; - strides_tensor = test::AsTensor(strides); + strides_tensor = test::AsTensor(strides); op.input_tensors[2] = &strides_tensor; }; @@ -1466,8 +1475,9 @@ TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) { TEST(CommonShapeFnsTest, Pool3DShapeTest) { ShapeInferenceTestOp op("MaxPool3D"); - auto set_op = [&op](const std::vector& strides, - const std::vector& ksizes, const string& padding) { + auto set_op = [&op](const std::vector& strides, + const std::vector& ksizes, + const std::string& padding) { TF_CHECK_OK(NodeDefBuilder("test", "MaxPool3D") .Input("input", 0, DT_FLOAT) .Attr("strides", strides) @@ -1524,28 +1534,28 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) { INFER_OK(op, "[2,4,5];[2]", "?"); INFER_OK(op, "?;[2]", "?"); - Tensor indices = test::AsTensor({1, 2}); + Tensor indices = test::AsTensor({1, 2}); op.input_tensors[1] = &indices; // Reduction indices available INFER_OK(op, "[2,4,5];[2]", "[d0_0]"); // Wrapped indices - indices = test::AsTensor({-1, -2}); + indices = test::AsTensor({-1, -2}); op.input_tensors[1] = &indices; INFER_OK(op, "[2,4,5];[2]", "[d0_0]"); // Scalar - indices = test::AsScalar(0); + indices = test::AsScalar(0); op.input_tensors[1] = &indices; INFER_OK(op, "[2,4,5];[]", "[d0_1,d0_2]"); - indices = test::AsScalar(-4); + indices = test::AsScalar(-4); op.input_tensors[1] = &indices; INFER_ERROR("Invalid reduction dimension", op, "[2,4,5];[]"); // Empty reduction indices - indices = test::AsTensor({}); + indices = test::AsTensor({}); op.input_tensors[1] = &indices; INFER_OK(op, "[2,4,5];[0]", "[d0_0,d0_1,d0_2]"); @@ -1555,7 +1565,7 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) { .Input("reduction_indices", 1, DT_INT32) .Attr("keep_dims", true) .Finalize(&op.node_def)); - indices = test::AsTensor({-1, -2}); + indices = test::AsTensor({-1, -2}); op.input_tensors[1] = &indices; INFER_OK(op, "[2,4,5];[2]", "[d0_0, 1, 1]"); @@ -1572,9 +1582,9 @@ TEST(CommonShapeFnsTest, Reduce_ShapeFn) { INFER_OK(op, "[?,?,?];[?,?]", "[?,?,?]"); // And when the tensor is specified, it's still allowed. op.input_tensors[1] = &indices; - indices = test::AsTensor({-1, -2}, TensorShape({2, 1})); + indices = test::AsTensor({-1, -2}, TensorShape({2, 1})); INFER_OK(op, "[2,4,5];[2,1]", "[d0_0, 1, 1]"); - indices = test::AsTensor({-1, -2}, TensorShape({1, 2})); + indices = test::AsTensor({-1, -2}, TensorShape({1, 2})); INFER_OK(op, "[2,4,5];[1,2]", "[d0_0, 1, 1]"); } diff --git a/tensorflow/core/framework/control_flow.h b/tensorflow/core/framework/control_flow.h index 3cc270b323d92f..a70ecb85214e31 100644 --- a/tensorflow/core/framework/control_flow.h +++ b/tensorflow/core/framework/control_flow.h @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -const uint64 kIllegalFrameId = ~0uLL; +const uint64_t kIllegalFrameId = ~0uLL; const int64_t kIllegalIterId = -1; // For the purpose of control flow, every tensor produced by TensorFlow is @@ -30,12 +30,12 @@ const int64_t kIllegalIterId = -1; // 'frame_id' and an 'iter_id'. 
The tensor value it represents is produced // in the frame with frame_id at the iteration of iter_id. struct FrameAndIter { - uint64 frame_id = kIllegalFrameId; + uint64_t frame_id = kIllegalFrameId; int64_t iter_id = kIllegalIterId; FrameAndIter() {} - FrameAndIter(uint64 frame, int64_t iter) { + FrameAndIter(uint64_t frame, int64_t iter) { frame_id = frame; iter_id = iter; } @@ -48,7 +48,7 @@ struct FrameAndIter { struct FrameAndIterHash { size_t operator()(const FrameAndIter& key) const { // Make sure there are no padding bytes that we don't want - CHECK_EQ(sizeof(uint64) + sizeof(int64_t), sizeof(FrameAndIter)); + CHECK_EQ(sizeof(uint64_t) + sizeof(int64_t), sizeof(FrameAndIter)); return Hash64(reinterpret_cast(&key), sizeof(FrameAndIter)); } }; diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 491eb5293f22ad..69593a67d90352 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -52,8 +52,9 @@ static mutex* get_dataset_op_registry_lock() { return &dataset_op_registry_lock; } -static std::unordered_set* get_dataset_op_registry() { - static std::unordered_set* names = new std::unordered_set; +static std::unordered_set* get_dataset_op_registry() { + static std::unordered_set* names = + new std::unordered_set; return names; } @@ -97,8 +98,8 @@ class DatasetVariantWrapper { DatasetBase* get() const { return dataset_; } - string TypeName() const { return "tensorflow::DatasetVariantWrapper"; } - string DebugString() const { + std::string TypeName() const { return "tensorflow::DatasetVariantWrapper"; } + std::string DebugString() const { if (dataset_) { return dataset_->DebugString(); } else { @@ -131,9 +132,11 @@ class WrappedDatasetVariantWrapper { Tensor get() const { return ds_tensor_; } - string TypeName() const { return "tensorflow::WrappedDatasetVariantWrapper"; } + std::string TypeName() const { + return "tensorflow::WrappedDatasetVariantWrapper"; + } - string DebugString() const { + std::string DebugString() const { return "tensorflow::WrappedDatasetVariantWrapper::DebugString"; } @@ -324,7 +327,7 @@ absl::Status GraphDefBuilderWrapper::AddDataset( } absl::Status GraphDefBuilderWrapper::AddFunction( - SerializationContext* ctx, const string& function_name, + SerializationContext* ctx, const std::string& function_name, const FunctionLibraryDefinition& lib_def) { if (b_->HasFunction(function_name)) { VLOG(1) << "Function with name " << function_name << "already exists in" @@ -338,7 +341,7 @@ absl::Status GraphDefBuilderWrapper::AddFunction( } FunctionDefLibrary def; *def.add_function() = *f_def; - const string gradient_func = lib_def.FindGradient(function_name); + const std::string gradient_func = lib_def.FindGradient(function_name); if (!gradient_func.empty()) { GradientDef* g_def = def.add_gradient(); g_def->set_function_name(function_name); @@ -380,8 +383,8 @@ void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val, b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val)); } -bool GraphDefBuilderWrapper::HasAttr(const string& name, - const string& attr_name) const { +bool GraphDefBuilderWrapper::HasAttr(const std::string& name, + const std::string& attr_name) const { const OpDef* op_def = nullptr; absl::Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def); if (!s.ok() || op_def == nullptr) { @@ -535,11 +538,11 @@ absl::Status MemoryCheckpoint::Save(IteratorStateWriter* writer) const { absl::Status IteratorBase::InitializeBase(IteratorContext* ctx, const IteratorBase* 
parent) { parent_ = parent; - id_ = - Hash64CombineUnordered(Hash64(prefix()), reinterpret_cast(this)); + id_ = Hash64CombineUnordered(Hash64(prefix()), + reinterpret_cast(this)); if (parent_) { parent_id_ = Hash64CombineUnordered(Hash64(parent_->prefix()), - reinterpret_cast(parent_)); + reinterpret_cast(parent_)); // This block of code is executed only when `parent_` is not a `nullptr` // because we do not create a `Node` in the `Model` for `RootDataset`. if (const auto& model = ctx->model()) { @@ -626,17 +629,17 @@ std::string FullName(const std::string& prefix, const std::string& name) { return strings::StrCat(kFullNameRandomHex, kPipe, prefix, kColon, name); } -absl::Status ExtractIteratorPrefix(absl::string_view key, string* prefix) { +absl::Status ExtractIteratorPrefix(absl::string_view key, std::string* prefix) { if (!absl::StartsWith(key, data::kFullNameRandomHex)) { return errors::InvalidArgument("Key: ", key, " was not generated using full_name."); } - std::vector split_keys = str_util::Split(key, data::kPipe); + std::vector split_keys = str_util::Split(key, data::kPipe); if (split_keys.size() != 2) { return errors::InvalidArgument("Key: ", key, " was not generated using full_name."); } - string real_key = split_keys[1]; + std::string real_key = split_keys[1]; const int pos = real_key.rfind(kColon); *prefix = real_key.substr(0, pos); return absl::OkStatus(); @@ -811,10 +814,11 @@ absl::Status DatasetBase::ComputeNumSources() { return absl::OkStatus(); } -absl::Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { +absl::Status DatasetBase::CheckRandomAccessCompatible( + const int64_t index) const { CardinalityOptions options; options.set_compute_level(CardinalityOptions::CARDINALITY_COMPUTE_MODERATE); - int64 cardinality = Cardinality(options); + int64_t cardinality = Cardinality(options); if (cardinality == kInfiniteCardinality || cardinality == kUnknownCardinality) { return tensorflow::errors::FailedPrecondition( @@ -829,13 +833,13 @@ absl::Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { return absl::OkStatus(); } -absl::Status DatasetBase::Get(OpKernelContext* ctx, int64 index, +absl::Status DatasetBase::Get(OpKernelContext* ctx, int64_t index, std::vector* out_tensors) const { return errors::Unimplemented("Random access is not implemented for dataset ", DebugString()); } -absl::Status DatasetBase::Get(AnyContext ctx, int64 index, +absl::Status DatasetBase::Get(AnyContext ctx, int64_t index, std::vector* out_tensors) const { return errors::Unimplemented("Random access is not implemented for dataset ", DebugString()); @@ -876,7 +880,7 @@ absl::Status DatasetBase::MergeOptionsFromInputs() { absl::Status DatasetBase::MakeIterator( IteratorContext* ctx, const IteratorBase* parent, - const string& output_prefix, + const std::string& output_prefix, std::unique_ptr* iterator) const { if (type_string() == "OptionsDataset" || type_string() == "FinalizeDataset") { std::vector inputs; @@ -1103,8 +1107,8 @@ DatasetBaseIterator::~DatasetBaseIterator() { params_.dataset->Unref(); } -string DatasetBaseIterator::BuildTraceMeName() { - string result = +std::string DatasetBaseIterator::BuildTraceMeName() { + std::string result = strings::StrCat(params_.prefix, "#", traceme_metadata_, ",id=", id_); if (parent_) { absl::StrAppend(&result, ",parent_id=", parent_id_); @@ -1274,8 +1278,8 @@ void DatasetOpKernel::Compute(OpKernelContext* ctx) { } } -string DatasetOpKernel::TraceString(const OpKernelContext& ctx, - bool verbose) const { +std::string 
DatasetOpKernel::TraceString(const OpKernelContext& ctx, + bool verbose) const { return tsl::profiler::TraceMeOp(name_view(), type_string_view()); } @@ -1310,7 +1314,7 @@ bool DatasetOpKernel::IsDatasetOp(const OpDef& op_def) { // Check if the suffix matches "DatasetV[0-9]+". size_t index = op_name.length() - 1; - while (index >= 0 && isdigit(op_name[index])) { + while (index >= 0 && absl::ascii_isdigit(op_name[index])) { index--; } constexpr absl::string_view kDatasetPrefix = "DatasetV"; diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index b807208647c1cb..2471c3dc08cc0a 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -87,7 +87,7 @@ void MergeOptions(const protobuf::MessageLite& source, protobuf::MessageLite* destination); } // namespace internal -using TraceMeMetadata = std::vector>; +using TraceMeMetadata = std::vector>; // Maps the index of dataset elements to a globally shuffled index. See the // comment for IteratorContext::Params::index_mapper for more details. @@ -211,7 +211,7 @@ class IteratorStateWriter { std::string FullName(const std::string& prefix, const std::string& name); // Extracts iterator prefix from key generated by `FullName`. -absl::Status ExtractIteratorPrefix(absl::string_view key, string* prefix); +absl::Status ExtractIteratorPrefix(absl::string_view key, std::string* prefix); // Interface for objects that can be checkpointed. class Checkpointable { @@ -264,7 +264,7 @@ class GraphDefBuilderWrapper { return absl::OkStatus(); } - absl::Status AddVector(const std::vector& val, Node** output) { + absl::Status AddVector(const std::vector& val, Node** output) { Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({static_cast(val.size())})); for (size_t i = 0; i < val.size(); i++) { @@ -350,7 +350,7 @@ class GraphDefBuilderWrapper { // or any of its dependent functions are stateful, and the context does not // explicitly permit stateful functions, returns an InvalidArgument error. 
absl::Status AddFunction(SerializationContext* ctx, - const string& function_name, + const std::string& function_name, const FunctionLibraryDefinition& lib_def); template @@ -371,9 +371,10 @@ class GraphDefBuilderWrapper { private: void AddPlaceholderInternal(const Tensor& val, Node** output); void AddTensorInternal(const Tensor& val, Node** output); - bool HasAttr(const string& op_type_name, const string& attr_name) const; + bool HasAttr(const std::string& op_type_name, + const std::string& attr_name) const; - bool HasAttr(const OpDef* op_def, const string& attr_name) const { + bool HasAttr(const OpDef* op_def, const std::string& attr_name) const { for (const auto& attr : op_def->attr()) { if (attr.name() == attr_name) { return true; @@ -509,35 +510,35 @@ class MemoryCheckpoint final : public IteratorStateWriter { // BEGIN implementation of `IteratorStateWriter` interface absl::Status WriteScalar(absl::string_view key, int64_t val) override { - string prefix; + std::string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteScalar(prefix, key, val); } absl::Status WriteScalar(absl::string_view name, absl::string_view key, int64_t val) override { - auto id = id_registry_->Add(string(name), string(key)); + auto id = id_registry_->Add(std::string(name), std::string(key)); int_values_[id] = val; return absl::OkStatus(); } absl::Status WriteScalar(absl::string_view key, const tstring& val) override { - string prefix; + std::string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteScalar(prefix, key, val); } absl::Status WriteScalar(absl::string_view name, absl::string_view key, const tstring& val) override { - auto id = id_registry_->Add(string(name), string(key)); + auto id = id_registry_->Add(std::string(name), std::string(key)); str_values_[id] = val; return absl::OkStatus(); } absl::Status WriteTensor(absl::string_view key, const Tensor& val) override { - string prefix; + std::string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteTensor(prefix, key, val); } absl::Status WriteTensor(absl::string_view name, absl::string_view key, const Tensor& val) override { - auto id = id_registry_->Add(string(name), string(key)); + auto id = id_registry_->Add(std::string(name), std::string(key)); tensor_values_[id] = val; return absl::OkStatus(); } @@ -614,7 +615,8 @@ class SerializationContext { : resource_mgr(ctx->resource_manager()), device_name(ctx->device()->attributes().name()) {} - std::vector>* input_list = nullptr; // Not owned. + std::vector>* input_list = + nullptr; // Not owned. // Indicates what to do if the dataset depends on external state. ExternalStatePolicy external_state_policy = @@ -653,7 +655,7 @@ class SerializationContext { explicit SerializationContext(Params params) : params_(params) {} - std::vector>* input_list() { + std::vector>* input_list() { return params_.input_list; } @@ -773,7 +775,7 @@ class IteratorContext { // Records the number of ParallelInterleave operations in the path from the // root node to this node (not including this node) in the input pipeline // tree. - int64 interleave_depth = 0; + int64_t interleave_depth = 0; // Marks whether the iterator is restored from a checkpoint. bool is_restoring = false; @@ -795,7 +797,7 @@ class IteratorContext { std::function)> runner = nullptr; // Number of threads used for executing user-defined functions. - int32 runner_threadpool_size = 0; + int32_t runner_threadpool_size = 0; // Split providers indicating which splits to process. 
May be empty, // indicating that the iterator should process all splits. @@ -891,7 +893,7 @@ class IteratorContext { MemoryCheckpoint* checkpoint() { return &checkpoint_; } - int64 interleave_depth() { return params_.interleave_depth; } + int64_t interleave_depth() { return params_.interleave_depth; } bool is_restoring() { return params_.is_restoring; } @@ -909,7 +911,7 @@ class IteratorContext { return ¶ms_.runner; } - int32 runner_threadpool_size() { return params_.runner_threadpool_size; } + int32_t runner_threadpool_size() { return params_.runner_threadpool_size; } std::vector> split_providers() const { return params_.split_providers; @@ -949,7 +951,7 @@ class IteratorContext { params_.index_mapper = index_mapper; }; - std::unique_ptr CreateThreadPool(const string& name, + std::unique_ptr CreateThreadPool(const std::string& name, int num_threads) { if (params_.thread_pool) { // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which @@ -1010,7 +1012,7 @@ class IteratorContext { } } - std::unique_ptr StartThread(const string& name, + std::unique_ptr StartThread(const std::string& name, std::function fn) { if (params_.thread_factory) { return params_.thread_factory->StartThread(name, std::move(fn)); @@ -1133,7 +1135,7 @@ class IteratorBase : public Checkpointable { // Returns a string that identifies the sequence of iterators leading up to // this iterator. - virtual const string& prefix() const = 0; + virtual const std::string& prefix() const = 0; // Indicates whether the iterator is compatible with symbolic checkpointing. virtual bool SymbolicCheckpointCompatible() const { return false; } @@ -1248,9 +1250,9 @@ class IteratorBase : public Checkpointable { class DatasetContext { public: struct Params { - string type_string; // op type name of this dataset. - string node_name; // graph node name of this dataset op, uniquely - // identifying the dataset in the graph. + std::string type_string; // op type name of this dataset. + std::string node_name; // graph node name of this dataset op, uniquely + // identifying the dataset in the graph. }; explicit DatasetContext(Params params) : params_(std::move(params)) {} @@ -1260,8 +1262,8 @@ class DatasetContext { params_.node_name = ctx->op_kernel().name(); } - const string& type_string() const { return params_.type_string; } - const string& node_name() const { return params_.node_name; } + const std::string& type_string() const { return params_.type_string; } + const std::string& node_name() const { return params_.node_name; } private: Params params_; @@ -1304,11 +1306,11 @@ class DatasetBase : public core::RefCounted { : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {} // Op type name of this dataset. - const string& type_string() const { return type_string_; } + const std::string& type_string() const { return type_string_; } // Graph node name of this dataset op, uniquely identifying the dataset in // the graph. - const string& node_name() const { return node_name_; } + const std::string& node_name() const { return node_name_; } const Metadata& metadata() const { return metadata_; } @@ -1330,18 +1332,18 @@ class DatasetBase : public core::RefCounted { // The prefix identifies the sequence of iterators leading up to the newly // created iterator. 
absl::Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent, - const string& output_prefix, + const std::string& output_prefix, std::unique_ptr* iterator) const; absl::Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent, - const string& output_prefix, + const std::string& output_prefix, std::unique_ptr* iterator) const { return MakeIterator(&ctx, parent, output_prefix, iterator); } // Returns a new iterator restored from the checkpoint data in `reader`. absl::Status MakeIteratorFromCheckpoint( - IteratorContext* ctx, const string& output_prefix, + IteratorContext* ctx, const std::string& output_prefix, IteratorStateReader* reader, std::unique_ptr* iterator) const { std::unique_ptr it; @@ -1357,7 +1359,7 @@ class DatasetBase : public core::RefCounted { } absl::Status MakeIteratorFromCheckpoint( - IteratorContext&& ctx, const string& output_prefix, + IteratorContext&& ctx, const std::string& output_prefix, IteratorStateReader* reader, std::unique_ptr* iterator) const { return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator); @@ -1405,7 +1407,7 @@ class DatasetBase : public core::RefCounted { } // A human-readable debug string for this dataset. - virtual string DebugString() const = 0; + virtual std::string DebugString() const = 0; // Stores the dataset's input datasets in `*inputs`. The pointers stored in // `*inputs` are borrowed. The only valid non-ok return status is @@ -1423,16 +1425,16 @@ class DatasetBase : public core::RefCounted { virtual absl::Status CheckExternalState() const = 0; // Indicates whether the dataset is compatible with random access. - absl::Status CheckRandomAccessCompatible(const int64 index) const; + absl::Status CheckRandomAccessCompatible(const int64_t index) const; // Return the element at a particular index for a randomly accessible dataset. - virtual absl::Status Get(OpKernelContext* ctx, int64 index, + virtual absl::Status Get(OpKernelContext* ctx, int64_t index, std::vector* out_tensors) const; // Same as above, but with an `AnyContext`, which can be constructed from // either an `OpKernelContext` or `IteratorContext`. Used to support datasets // that provide random access through both the dataset and iterator APIs. - virtual absl::Status Get(AnyContext ctx, int64 index, + virtual absl::Status Get(AnyContext ctx, int64_t index, std::vector* out_tensors) const; // Returns true if the dataset and its inputs support random access. @@ -1487,7 +1489,7 @@ class DatasetBase : public core::RefCounted { Node** node) const = 0; virtual std::unique_ptr MakeIteratorInternal( - const string& prefix) const = 0; + const std::string& prefix) const = 0; void set_options(const Options& options) { options_ = options; } @@ -1505,8 +1507,8 @@ class DatasetBase : public core::RefCounted { // how they appear for this dataset. absl::Status MergeOptionsFromInputs(); - const string type_string_; - const string node_name_; + const std::string type_string_; + const std::string node_name_; Metadata metadata_; Options options_; mutable mutex mu_; @@ -1527,7 +1529,7 @@ class DatasetBaseIterator : public IteratorBase { const DatasetBase* dataset; // Identifies the sequence of iterators leading up to this iterator. 
- const string prefix; + const std::string prefix; }; explicit DatasetBaseIterator(const BaseParams& params); @@ -1544,13 +1546,13 @@ class DatasetBaseIterator : public IteratorBase { return params_.dataset->output_shapes(); } - const string& prefix() const override { return params_.prefix; } + const std::string& prefix() const override { return params_.prefix; } // Returns a name to be used for the TraceMe event. // // NOTE: TraceMe supports passing key-value pairs of "arguments" using the // following format "name#arg_1=value_,...,arg_n=value_n". - string BuildTraceMeName(); + std::string BuildTraceMeName(); absl::Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) final; @@ -1602,7 +1604,7 @@ class DatasetBaseIterator : public IteratorBase { virtual absl::Status SkipInternal(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence, int* num_skipped); - string full_name(const string& name) const { + std::string full_name(const std::string& name) const { return FullName(params_.prefix, name); } @@ -1693,7 +1695,7 @@ class DatasetBaseIterator : public IteratorBase { return ctx->model() && node_; } - string traceme_metadata_; + std::string traceme_metadata_; BaseParams params_; }; @@ -1707,7 +1709,7 @@ class DatasetIterator : public DatasetBaseIterator { const DatasetType* dataset; // Identifies the sequence of iterators leading up to this iterator. - const string prefix; + const std::string prefix; }; explicit DatasetIterator(const Params& params) @@ -1774,7 +1776,8 @@ class DatasetOpKernel : public OpKernel { // names that end with "Dataset" or "DatasetV[0-9]+". static bool IsDatasetOp(const OpDef& op_def); - string TraceString(const OpKernelContext& ctx, bool verbose) const override; + std::string TraceString(const OpKernelContext& ctx, + bool verbose) const override; protected: // Subclasses should implement this method. It will be called during Compute diff --git a/tensorflow/core/framework/dataset_stateful_op_allowlist.h b/tensorflow/core/framework/dataset_stateful_op_allowlist.h index cc25c801bf60b1..14b16b309ea5c1 100644 --- a/tensorflow/core/framework/dataset_stateful_op_allowlist.h +++ b/tensorflow/core/framework/dataset_stateful_op_allowlist.h @@ -25,17 +25,17 @@ namespace data { // See below macro for usage details. 
class AllowlistedStatefulOpRegistry { public: - absl::Status Add(string op_name) { + absl::Status Add(std::string op_name) { op_names_.insert(std::move(op_name)); return absl::OkStatus(); } - absl::Status Remove(string op_name) { + absl::Status Remove(std::string op_name) { op_names_.erase(op_name); return absl::OkStatus(); } - bool Contains(const string& op_name) { return op_names_.count(op_name); } + bool Contains(const std::string& op_name) { return op_names_.count(op_name); } static AllowlistedStatefulOpRegistry* Global() { static auto* reg = new AllowlistedStatefulOpRegistry; @@ -49,7 +49,7 @@ class AllowlistedStatefulOpRegistry { AllowlistedStatefulOpRegistry operator=( AllowlistedStatefulOpRegistry const& copy) = delete; - std::unordered_set op_names_; + std::unordered_set op_names_; }; } // namespace data diff --git a/tensorflow/core/framework/dataset_test.cc b/tensorflow/core/framework/dataset_test.cc index 66213ea5721b13..b572de72e54113 100644 --- a/tensorflow/core/framework/dataset_test.cc +++ b/tensorflow/core/framework/dataset_test.cc @@ -68,8 +68,8 @@ TEST_P(DatasetTestTotalBytes, TestTotalBytes) { } std::vector tensor_tf_int_32s() { - return {test::AsTensor({1, 2, 3, 4, 5}), - test::AsTensor({1, 2, 3, 4})}; + return {test::AsTensor({1, 2, 3, 4, 5}), + test::AsTensor({1, 2, 3, 4})}; } std::vector tensor_tf_int_64s() { diff --git a/tensorflow/core/framework/device.cc b/tensorflow/core/framework/device.cc index 59730e3ce1d436..1adb6e7eaf1641 100644 --- a/tensorflow/core/framework/device.cc +++ b/tensorflow/core/framework/device.cc @@ -41,8 +41,8 @@ void Device::Sync(const DoneCallback& done) { done(Sync()); } // static DeviceAttributes Device::BuildDeviceAttributes( - const string& name, DeviceType device, Bytes memory_limit, - const DeviceLocality& locality, const string& physical_device_desc) { + const std::string& name, DeviceType device, Bytes memory_limit, + const DeviceLocality& locality, const std::string& physical_device_desc) { DeviceAttributes da; da.set_name(name); do { diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc index 44db0a284f1f79..891d45f237e61e 100644 --- a/tensorflow/core/framework/device_base.cc +++ b/tensorflow/core/framework/device_base.cc @@ -66,7 +66,7 @@ const DeviceAttributes& DeviceBase::attributes() const { std::abort(); } -const string& DeviceBase::name() const { +const std::string& DeviceBase::name() const { LOG(FATAL) << "DeviceBase does not implement name()"; // Crash OK std::abort(); } diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index fe5099fa361429..15c4e6bba6ae9e 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -269,7 +269,7 @@ class DeviceBase { // device memory tagged with an earlier freed-at count is really unencumbered // by pending uses. For this to be useful the device memory allocator must // be tagging deallocated memory chunks using the same counter. - virtual uint64 SafeAllocFrontier(uint64 old_value) { return 0; } + virtual uint64_t SafeAllocFrontier(uint64_t old_value) { return 0; } // Copies `input_tensor` to `output_tensor`, where both tensors are on this // device. 
This function assumes that `output_tensor` has already been diff --git a/tensorflow/core/framework/device_factory.cc b/tensorflow/core/framework/device_factory.cc index 392b44f2eb177c..d6374d41a93bb7 100644 --- a/tensorflow/core/framework/device_factory.cc +++ b/tensorflow/core/framework/device_factory.cc @@ -47,14 +47,14 @@ struct FactoryItem { bool is_pluggable_device; }; -std::unordered_map<string, FactoryItem>& device_factories() { - static std::unordered_map<string, FactoryItem>* factories = - new std::unordered_map<string, FactoryItem>; +std::unordered_map<std::string, FactoryItem>& device_factories() { + static std::unordered_map<std::string, FactoryItem>* factories = + new std::unordered_map<std::string, FactoryItem>; return *factories; } -bool IsDeviceFactoryEnabled(const string& device_type) { - std::vector<string> enabled_devices; +bool IsDeviceFactoryEnabled(const std::string& device_type) { + std::vector<std::string> enabled_devices; TF_CHECK_OK(tensorflow::ReadStringsFromEnvVar( /*env_var_name=*/"TF_ENABLED_DEVICE_TYPES", /*default_val=*/"", &enabled_devices)); @@ -67,9 +67,9 @@ bool IsDeviceFactoryEnabled(const string& device_type) { } // namespace // static -int32 DeviceFactory::DevicePriority(const string& device_type) { +int32_t DeviceFactory::DevicePriority(const std::string& device_type) { tf_shared_lock l(*get_device_factory_lock()); - std::unordered_map<string, FactoryItem>& factories = device_factories(); + std::unordered_map<std::string, FactoryItem>& factories = device_factories(); auto iter = factories.find(device_type); if (iter != factories.end()) { return iter->second.priority; @@ -78,9 +78,9 @@ int32 DeviceFactory::DevicePriority(const string& device_type) { return -1; } -bool DeviceFactory::IsPluggableDevice(const string& device_type) { +bool DeviceFactory::IsPluggableDevice(const std::string& device_type) { tf_shared_lock l(*get_device_factory_lock()); - std::unordered_map<string, FactoryItem>& factories = device_factories(); + std::unordered_map<std::string, FactoryItem>& factories = device_factories(); auto iter = factories.find(device_type); if (iter != factories.end()) { return iter->second.is_pluggable_device; @@ -89,7 +89,7 @@ bool DeviceFactory::IsPluggableDevice(const string& device_type) { } // static -void DeviceFactory::Register(const string& device_type, +void DeviceFactory::Register(const std::string& device_type, std::unique_ptr<DeviceFactory> factory, int priority, bool is_pluggable_device) { if (!IsDeviceFactoryEnabled(device_type)) { @@ -98,7 +98,7 @@ void DeviceFactory::Register(const string& device_type, return; } mutex_lock l(*get_device_factory_lock()); - std::unordered_map<string, FactoryItem>& factories = device_factories(); + std::unordered_map<std::string, FactoryItem>& factories = device_factories(); auto iter = factories.find(device_type); if (iter == factories.end()) { factories[device_type] = {std::move(factory), priority, @@ -113,7 +113,7 @@ void DeviceFactory::Register(const string& device_type, } } -DeviceFactory* DeviceFactory::GetFactory(const string& device_type) { +DeviceFactory* DeviceFactory::GetFactory(const std::string& device_type) { tf_shared_lock l(*get_device_factory_lock()); auto it = device_factories().find(device_type); if (it == device_factories().end()) { @@ -128,7 +128,7 @@ DeviceFactory* DeviceFactory::GetFactory(const string& device_type) { } absl::Status DeviceFactory::ListAllPhysicalDevices( - std::vector<string>* devices) { + std::vector<std::string>* devices) { // CPU first. A CPU device is required. // TODO(b/183974121): Consider merge the logic into the loop below.
auto cpu_factory = GetFactory("CPU"); @@ -156,7 +156,7 @@ absl::Status DeviceFactory::ListAllPhysicalDevices( } absl::Status DeviceFactory::ListPluggablePhysicalDevices( - std::vector<string>* devices) { + std::vector<std::string>* devices) { tf_shared_lock l(*get_device_factory_lock()); for (auto& p : device_factories()) { if (p.second.is_pluggable_device) { @@ -168,7 +168,7 @@ absl::Status DeviceFactory::ListPluggablePhysicalDevices( } absl::Status DeviceFactory::GetAnyDeviceDetails( - int device_index, std::unordered_map<string, string>* details) { + int device_index, std::unordered_map<std::string, std::string>* details) { if (device_index < 0) { return errors::InvalidArgument("Device index out of bounds: ", device_index); @@ -183,7 +183,7 @@ absl::Status DeviceFactory::GetAnyDeviceDetails( } // TODO(b/183974121): Consider merge the logic into the loop below. - std::vector<string> devices; + std::vector<std::string> devices; TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&devices)); if (device_index < devices.size()) { return cpu_factory->GetDeviceDetails(device_index, details); @@ -211,7 +211,7 @@ absl::Status DeviceFactory::GetAnyDeviceDetails( } absl::Status DeviceFactory::AddCpuDevices( - const SessionOptions& options, const string& name_prefix, + const SessionOptions& options, const std::string& name_prefix, std::vector<std::unique_ptr<Device>>* devices) { auto cpu_factory = GetFactory("CPU"); if (!cpu_factory) { @@ -228,7 +228,7 @@ absl::Status DeviceFactory::AddCpuDevices( } absl::Status DeviceFactory::AddDevices( - const SessionOptions& options, const string& name_prefix, + const SessionOptions& options, const std::string& name_prefix, std::vector<std::unique_ptr<Device>>* devices) { // CPU first. A CPU device is required. // TODO(b/183974121): Consider merge the logic into the loop below. @@ -263,9 +263,9 @@ absl::Status DeviceFactory::AddDevices( return absl::OkStatus(); } -std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type, - const SessionOptions& options, - const string& name_prefix) { +std::unique_ptr<Device> DeviceFactory::NewDevice( + const std::string& type, const SessionOptions& options, + const std::string& name_prefix) { auto device_factory = GetFactory(type); if (!device_factory) { return nullptr; diff --git a/tensorflow/core/framework/device_factory.h b/tensorflow/core/framework/device_factory.h index 8b07d15cfc0dac..e30a4538fa939a 100644 --- a/tensorflow/core/framework/device_factory.h +++ b/tensorflow/core/framework/device_factory.h @@ -58,34 +58,35 @@ class DeviceFactory { // Helper for tests. Create a single device of type "type". The // returned device is always numbered zero, so if creating multiple // devices of the same type, supply distinct name_prefix arguments. - static std::unique_ptr<Device> NewDevice(const string& type, + static std::unique_ptr<Device> NewDevice(const std::string& type, const SessionOptions& options, - const string& name_prefix); + const std::string& name_prefix); // Iterate through all device factories and build a list of all of the // possible physical devices. // // CPU is are added first. - static absl::Status ListAllPhysicalDevices(std::vector<string>* devices); + static absl::Status ListAllPhysicalDevices(std::vector<std::string>* devices); // Iterate through all device factories and build a list of all of the // possible pluggable physical devices. static absl::Status ListPluggablePhysicalDevices( - std::vector<string>* devices); + std::vector<std::string>* devices); // Get details for a specific device among all device factories. // 'device_index' indexes into devices from ListAllPhysicalDevices.
static absl::Status GetAnyDeviceDetails( - int device_index, std::unordered_map<string, string>* details); + int device_index, std::unordered_map<std::string, std::string>* details); // For a specific device factory list all possible physical devices. - virtual absl::Status ListPhysicalDevices(std::vector<string>* devices) = 0; + virtual absl::Status ListPhysicalDevices( + std::vector<std::string>* devices) = 0; // Get details for a specific device for a specific factory. Subclasses // can store arbitrary device information in the map. 'device_index' indexes // into devices from ListPhysicalDevices. virtual absl::Status GetDeviceDetails( - int device_index, std::unordered_map<string, string>* details) { + int device_index, std::unordered_map<std::string, std::string>* details) { return absl::OkStatus(); } @@ -106,7 +107,7 @@ class DeviceFactory { // higher than the packaged devices. See calls to // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used // for built-in devices. - static int32 DevicePriority(const std::string& device_type); + static int32_t DevicePriority(const std::string& device_type); // Returns true if 'device_type' is registered from plugin. Returns false if // 'device_type' is a first-party device. diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc index ec424f890883eb..c295e18bca197b 100644 --- a/tensorflow/core/framework/fake_input.cc +++ b/tensorflow/core/framework/fake_input.cc @@ -36,7 +36,7 @@ class FakeInputImpl { absl::Status AddInputToBuilder(); private: - static string FakeNodeName(int in_index); + static std::string FakeNodeName(int in_index); absl::Status GetN(int* n) const; absl::Status GetDataType(DataType* dt) const; void NSources(int n, DataType dt) const; @@ -44,7 +44,7 @@ class FakeInputImpl { const OpDef* const op_def_; const OpDef::ArgDef* const arg_; - const string in_node_; + const std::string in_node_; const NodeDef* const node_def_; NodeDefBuilder* const builder_; @@ -120,9 +120,9 @@ absl::Status FakeInputImpl::AddInputToBuilder() { } // static -string FakeInputImpl::FakeNodeName(int in_index) { +std::string FakeInputImpl::FakeNodeName(int in_index) { char c = 'a' + (in_index % 26); - return string(&c, 1); + return std::string(&c, 1); } absl::Status FakeInputImpl::GetN(int* n) const { diff --git a/tensorflow/core/framework/full_type_inference_util.cc b/tensorflow/core/framework/full_type_inference_util.cc index 029ca251b536c2..2fc478466337e7 100644 --- a/tensorflow/core/framework/full_type_inference_util.cc +++ b/tensorflow/core/framework/full_type_inference_util.cc @@ -342,7 +342,7 @@ TypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx) { }; } -TypeInferenceFn FunctionCall(const string& func_attr_name) { +TypeInferenceFn FunctionCall(const std::string& func_attr_name) { return [func_attr_name](const TypeRefVector& input_types, const FunctionTypeInferrer& infer_function_rets) -> absl::StatusOr<FullTypeDef> { diff --git a/tensorflow/core/framework/full_type_inference_util.h b/tensorflow/core/framework/full_type_inference_util.h index 3117613bcd130f..211768f4a0083b 100644 --- a/tensorflow/core/framework/full_type_inference_util.h +++ b/tensorflow/core/framework/full_type_inference_util.h @@ -122,7 +122,7 @@ TypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx); // Helper for ops with semantics of calling a function. The function is // specified indirectly, as the name of an attribute that holds the actual // function name.
-TypeInferenceFn FunctionCall(const string& func_attr_name); +TypeInferenceFn FunctionCall(const std::string& func_attr_name); // Compose the type of a function by concatenating the outputs of multiple // type inference functions. If func_list is {type inference function 1, type diff --git a/tensorflow/core/framework/full_type_util.cc b/tensorflow/core/framework/full_type_util.cc index 0bc3e04faf0e49..ea5fad4f704ff3 100644 --- a/tensorflow/core/framework/full_type_util.cc +++ b/tensorflow/core/framework/full_type_util.cc @@ -51,7 +51,7 @@ OpTypeConstructor Nullary(FullTypeId t) { }; } -OpTypeConstructor Unary(FullTypeId t, const string& var_name) { +OpTypeConstructor Unary(FullTypeId t, const std::string& var_name) { return [t, var_name](OpDef* op_def) { FullTypeDef* tdef = op_def->mutable_output_arg(0)->mutable_experimental_full_type(); @@ -93,7 +93,8 @@ OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype) { }; } -OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name) { +OpTypeConstructor UnaryTensorContainer(FullTypeId t, + const std::string& var_name) { return [t, var_name](OpDef* op_def) { FullTypeDef* tdef = op_def->mutable_output_arg(0)->mutable_experimental_full_type(); @@ -110,7 +111,7 @@ OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name) { } OpTypeConstructor VariadicTensorContainer(FullTypeId t, - const string& var_name) { + const std::string& var_name) { return [t, var_name](OpDef* op_def) { FullTypeDef* tdef = op_def->mutable_output_arg(0)->mutable_experimental_full_type(); diff --git a/tensorflow/core/framework/full_type_util.h b/tensorflow/core/framework/full_type_util.h index 76045e1bf1a777..392871a189651e 100644 --- a/tensorflow/core/framework/full_type_util.h +++ b/tensorflow/core/framework/full_type_util.h @@ -52,7 +52,7 @@ OpTypeConstructor NoOutputs(); OpTypeConstructor Nullary(FullTypeId t); // Helper for a type constructor of <t>[FT_VAR[<var_name>]]. -OpTypeConstructor Unary(FullTypeId t, const string& var_name); +OpTypeConstructor Unary(FullTypeId t, const std::string& var_name); // Helper for a type constructor of <t>[FT_ANY]. OpTypeConstructor UnaryGeneric(FullTypeId t); @@ -61,7 +61,8 @@ OpTypeConstructor UnaryGeneric(FullTypeId t); OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype); // Helper for a type constructor of <t>[FT_VAR[<var_name>]]. -OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name); +OpTypeConstructor UnaryTensorContainer(FullTypeId t, + const std::string& var_name); // Helper for a type constructor of // <t>[FT_FOR_EACH[ @@ -69,7 +70,8 @@ OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name); // FT_TENSOR[FT_VAR[<var_name>]], // FT_VAR[<var_name>]]. // Multi-valued type variables will expand the template (see full_type.proto). -OpTypeConstructor VariadicTensorContainer(FullTypeId t, const string& var_name); +OpTypeConstructor VariadicTensorContainer(FullTypeId t, + const std::string& var_name); // Type specialization and inference logic. This function narrows the type // specified in an op definition.
Such types are usually generic and dependent diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 2b778ca0134c70..91653d2cb0936f 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -122,7 +122,7 @@ absl::Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, namespace { template <typename T> -void AddAttr(const string& name, const T& val, NodeDef* ndef) { +void AddAttr(const std::string& name, const T& val, NodeDef* ndef) { SetAttrValue(val, &((*ndef->mutable_attr())[name])); } @@ -207,7 +207,7 @@ class FunctionInstantiationHelper { "Expected arg_index to be equal to the number of nodes in result.", " Got ", arg_index, " and ", result_.nodes.size()); } - string name = arg_def.name(); + std::string name = arg_def.name(); if (dtypes.size() > 1) { absl::StrAppend(&name, "_", i); } @@ -259,7 +259,7 @@ class FunctionInstantiationHelper { ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes)); // Note that we rely on the backwards-compatibility test enforcing // that output_arg(*).name() doesn't change here. - const string base_name = + const std::string base_name = absl::StrCat(node.name(), ":", node_sig->output_arg(i).name()); TF_RETURN_IF_ERROR( AddItem(base_name, {false, arg_index, start, is_type_list, dtypes})); @@ -299,7 +299,7 @@ class FunctionInstantiationHelper { " >= ", fnode.input_size()); } // Look up the next input. - const string& input_name = fnode.input(fnode_arg_index); + const std::string& input_name = fnode.input(fnode_arg_index); const auto* item = GetItemOrNull(input_name); if (item == nullptr) { return errors::InvalidArgument( @@ -331,15 +331,15 @@ class FunctionInstantiationHelper { // Control deps. for (int i = fnode_arg_index; i < fnode.input_size(); ++i) { - const string& input = fnode.input(i); + const std::string& input = fnode.input(i); if (input.empty() || input[0] != '^') { return errors::InvalidArgument("Expected input[", i, "] == '", input, "' to be a control input."); } int nid = -1; - const string node_name = input.substr(1); - const string node_colon = node_name + ":"; - const string node_colon_bound = node_name + ";"; + const std::string node_name = input.substr(1); + const std::string node_colon = node_name + ":"; + const std::string node_colon_bound = node_name + ";"; // index_ is a map sorted lexicographically, so the key we are looking for // must lie in the range [node_name, node_colon_bound). auto it = index_.lower_bound(node_name); @@ -379,7 +379,7 @@ class FunctionInstantiationHelper { absl::Status AddReturnNode( const OpDef::ArgDef& ret_def, AttrSlice attrs, - const ::tensorflow::protobuf::Map<string, string>& ret_map, + const ::tensorflow::protobuf::Map<std::string, std::string>& ret_map, bool ints_on_device, int* ret_index) { auto ret_iter = ret_map.find(ret_def.name()); if (ret_iter == ret_map.end()) { @@ -401,7 +401,7 @@ class FunctionInstantiationHelper { DataTypeVectorString(item->dtypes)); } for (size_t i = 0; i < dtypes.size(); ++i) { - string name = absl::StrCat(ret_def.name(), "_RetVal"); + std::string name = absl::StrCat(ret_def.name(), "_RetVal"); if (dtypes.size() > 1) { absl::StrAppend(&name, "_", i); } @@ -456,7 +456,7 @@ class FunctionInstantiationHelper { }; // Adds an item into the input name index. - absl::Status AddItem(const string& name, const NameInfoItem& item) { + absl::Status AddItem(const std::string& name, const NameInfoItem& item) { if (!index_.insert({name, item}).second) { return errors::InvalidArgument( absl::StrCat("Duplicated ", item.is_func_arg ?
"arg" : "ret", @@ -466,20 +466,20 @@ class FunctionInstantiationHelper { return absl::OkStatus(); } - const NameInfoItem* GetItemOrNull(const string& name) const { + const NameInfoItem* GetItemOrNull(const std::string& name) const { return gtl::FindOrNull(index_, name); } - string Dep(int node_index) const { + std::string Dep(int node_index) const { return absl::StrCat("^", Name(node_index)); } - string Name(int node_index) const { + std::string Name(int node_index) const { CHECK_LT(node_index, nodes_.size()); return nodes_[node_index].name; } - string Name(int node_index, int output_index) const { + std::string Name(int node_index, int output_index) const { if (output_index == 0) { return Name(node_index); } else { @@ -487,7 +487,7 @@ class FunctionInstantiationHelper { } } - NodeDef* AddNode(const string& name) { + NodeDef* AddNode(const std::string& name) { result_.nodes.emplace_back(); NodeDef* gnode = &result_.nodes.back(); gnode->set_name(name); @@ -510,11 +510,11 @@ class FunctionInstantiationHelper { GetFunctionSignature get_function_; InstantiationResult& result_; // A small index for all names that can be used as a node's input arguments. - std::map index_; + std::map index_; // This contains information about a node in the new graph including the node // names and input nodes' indexes. struct NodeInfo { - string name; + std::string name; // Data inputs where means arg k of node n. std::vector> data_inputs; // Control inputs (dependencies). @@ -525,8 +525,8 @@ class FunctionInstantiationHelper { }; // Various helpers Print(proto) to print relevant protos to ascii. -string Print(const OpDef::ArgDef& arg) { - string out; +std::string Print(const OpDef::ArgDef& arg) { + std::string out; absl::StrAppend(&out, arg.name(), ":"); if (arg.is_ref()) absl::StrAppend(&out, "Ref("); if (!arg.number_attr().empty()) { @@ -545,13 +545,13 @@ string Print(const OpDef::ArgDef& arg) { // When hash_string_attrs = true, string attributes are hashed instead of being // truncated with ellipses. This is done to reduce the chance of collisions when // looking up functions using the canonical representation. -string Print(const AttrValue& attr_value, - const bool hash_string_attrs = false) { +std::string Print(const AttrValue& attr_value, + const bool hash_string_attrs = false) { if (attr_value.value_case() == AttrValue::kType) { return DataTypeString(attr_value.type()); } else if ((attr_value.value_case() == AttrValue::kList) && (attr_value.list().type_size() > 0)) { - string ret = "{"; + std::string ret = "{"; for (int i = 0; i < attr_value.list().type_size(); ++i) { if (i > 0) absl::StrAppend(&ret, ", "); absl::StrAppend(&ret, DataTypeString(attr_value.list().type(i))); @@ -562,7 +562,7 @@ string Print(const AttrValue& attr_value, if (attr_value.func().attr_size() == 0) { return attr_value.func().name(); } - std::vector entries; + std::vector entries; for (const auto& p : attr_value.func().attr()) { entries.push_back(absl::StrCat(p.first, "=", Print(p.second))); } @@ -576,11 +576,11 @@ string Print(const AttrValue& attr_value, } // TODO(josh11b): Merge this with SummarizeNodeDef(). 
-string Print(const NodeDef& n) { - string out; +std::string Print(const NodeDef& n) { + std::string out; absl::StrAppend(&out, n.name(), " = ", n.op()); if (n.attr_size() > 0) { - std::vector<string> entries; + std::vector<std::string> entries; for (auto& a : n.attr()) { entries.push_back(absl::StrCat(a.first, "=", Print(a.second))); } @@ -598,7 +598,7 @@ string Print(const NodeDef& n) { } absl::StrAppend(&out, "("); std::vector<absl::string_view> dat; - std::vector<string> dep; + std::vector<std::string> dep; for (absl::string_view s : n.input()) { if (absl::ConsumePrefix(&s, "^")) { dep.emplace_back(s); @@ -613,8 +613,8 @@ string Print(const NodeDef& n) { return out; } -string Print(const FunctionDef& fdef) { - string out; +std::string Print(const FunctionDef& fdef) { + std::string out; const OpDef& sig = fdef.signature(); absl::StrAppend(&out, "\n", sig.name()); if (sig.attr_size() > 0) { @@ -654,7 +654,7 @@ string Print(const FunctionDef& fdef) { return out; } -string Print(absl::Span<const NodeDef* const> nodes) { +std::string Print(absl::Span<const NodeDef* const> nodes) { std::vector<const NodeDef*> arg; std::vector<const NodeDef*> ret; std::vector<const NodeDef*> body; @@ -678,7 +678,7 @@ string Print(absl::Span<const NodeDef* const> nodes) { }; std::sort(arg.begin(), arg.end(), comp); std::sort(ret.begin(), ret.end(), comp); - string out; + std::string out; absl::StrAppend(&out, "\n("); auto get_type_and_device = [](const NodeDef& n) { DataType dt; @@ -714,7 +714,7 @@ // The _RetVal op should have a unique non-control input. We assert that // here and add it to the output. bool found_non_control_input = false; - for (const string& input : n->input()) { + for (const std::string& input : n->input()) { if (!input.empty() && input[0] != '^') { DCHECK_EQ(found_non_control_input, false) << "RetVal node has more than one non-control input: " @@ -735,7 +735,7 @@ string Print(absl::Span<const NodeDef* const> nodes) { return out; } -absl::Status AddDefaultAttrs(const string& op, +absl::Status AddDefaultAttrs(const std::string& op, const GetFunctionSignature& get_function, AttrValueMap* attrs) { const OpDef* op_def = nullptr; @@ -799,7 +799,8 @@ absl::Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, } } - auto substitute = [attr_values, &sig](const string& name, AttrValue* val) { + auto substitute = [attr_values, &sig](const std::string& name, + AttrValue* val) { // Look for a specified value...
if (const AttrValue* v = attr_values.FindByString(name)) { *val = *v; @@ -870,9 +871,9 @@ absl::Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, return absl::OkStatus(); } -string DebugString(const FunctionDef& func_def) { return Print(func_def); } +std::string DebugString(const FunctionDef& func_def) { return Print(func_def); } -string DebugString(const GraphDef& instantiated_func_def) { +std::string DebugString(const GraphDef& instantiated_func_def) { std::vector<const NodeDef*> ptrs; for (const NodeDef& n : instantiated_func_def.node()) { ptrs.push_back(&n); @@ -880,7 +881,7 @@ string DebugString(const GraphDef& instantiated_func_def) { return Print(ptrs); } -string DebugString(absl::Span<const NodeDef> instantiated_func_nodes) { +std::string DebugString(absl::Span<const NodeDef> instantiated_func_nodes) { std::vector<const NodeDef*> ptrs; for (const NodeDef& n : instantiated_func_nodes) { ptrs.push_back(&n); @@ -888,8 +889,8 @@ string DebugString(absl::Span<const NodeDef> instantiated_func_nodes) { return Print(ptrs); } -string DebugStringWhole(const GraphDef& gdef) { - string ret; +std::string DebugStringWhole(const GraphDef& gdef) { + std::string ret; for (const auto& fdef : gdef.library().function()) { absl::StrAppend(&ret, Print(fdef)); } @@ -905,8 +906,8 @@ namespace { // Returns the name -> attr mapping of fdef's attrs that have a value set. In // Python, it's possible to access unset attrs, which returns a default value // and adds an unset attr to the map. -std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) { - std::map<string, AttrValue> set_attrs; +std::map<std::string, AttrValue> GetSetAttrs(const FunctionDef& fdef) { + std::map<std::string, AttrValue> set_attrs; for (const auto& pair : fdef.attr()) { if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) { set_attrs[pair.first] = pair.second; @@ -920,8 +921,8 @@ std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) { bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { if (!OpDefEqual(f1.signature(), f2.signature())) return false; - std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1); - std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2); + std::map<std::string, AttrValue> f1_attrs = GetSetAttrs(f1); + std::map<std::string, AttrValue> f2_attrs = GetSetAttrs(f2); if (f1_attrs.size() != f2_attrs.size()) return false; for (const auto& iter1 : f1_attrs) { auto iter2 = f2_attrs.find(iter1.first); @@ -933,25 +934,25 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { return false; } - std::map<string, string> ret1(f1.ret().begin(), f1.ret().end()); - std::map<string, string> ret2(f2.ret().begin(), f2.ret().end()); + std::map<std::string, std::string> ret1(f1.ret().begin(), f1.ret().end()); + std::map<std::string, std::string> ret2(f2.ret().begin(), f2.ret().end()); if (ret1 != ret2) return false; - std::map<string, string> control_ret1(f1.control_ret().begin(), - f1.control_ret().end()); - std::map<string, string> control_ret2(f2.control_ret().begin(), - f2.control_ret().end()); + std::map<std::string, std::string> control_ret1(f1.control_ret().begin(), - f1.control_ret().end()); + std::map<std::string, std::string> control_ret2(f2.control_ret().begin(), + f2.control_ret().end()); if (control_ret1 != control_ret2) return false; return true; } -uint64 FunctionDefHash(const FunctionDef& fdef) { +uint64_t FunctionDefHash(const FunctionDef& fdef) { // signature - uint64 h = OpDefHash(fdef.signature()); + uint64_t h = OpDefHash(fdef.signature()); // attrs - std::map<string, AttrValue> attrs = GetSetAttrs(fdef); + std::map<std::string, AttrValue> attrs = GetSetAttrs(fdef); for (const auto& p : attrs) { h = Hash64(p.first.data(), p.first.size(), h); h = Hash64Combine(AttrValueHash(p.second), h); @@ -961,15 +962,15 @@ uint64 FunctionDefHash(const FunctionDef& fdef) { h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h); // output names - std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end()); + std::map<std::string, std::string>
ret(fdef.ret().begin(), fdef.ret().end()); for (const auto& p : ret) { h = Hash64(p.first.data(), p.first.size(), h); h = Hash64(p.second.data(), p.second.size(), h); } // control output names - std::map<string, string> control_ret(fdef.control_ret().begin(), - fdef.control_ret().end()); + std::map<std::string, std::string> control_ret(fdef.control_ret().begin(), + fdef.control_ret().end()); for (const auto& p : control_ret) { h = Hash64(p.first.data(), p.first.size(), h); h = Hash64(p.second.data(), p.second.size(), h); @@ -981,14 +982,14 @@ uint64 FunctionDefHash(const FunctionDef& fdef) { static constexpr const char* const kExecutorAttr = "_executor"; /* static */ -string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options, - AttrSlice attrs) { +std::string FunctionLibraryRuntime::ExecutorType( + const InstantiateOptions& options, AttrSlice attrs) { if (!options.executor_type.empty()) { return options.executor_type; } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) { return executor_attr->s(); } else { - return string(); + return std::string(); } } @@ -999,7 +1000,7 @@ class AttrKeyAndValue { kRaw, kCEscape, }; - AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value, + AttrKeyAndValue(absl::string_view key_name, int key_suffix, std::string value, ValueRepresentationOp value_op = kRaw) : key_name_(key_name), key_suffix_(key_suffix), @@ -1016,7 +1017,7 @@ class AttrKeyAndValue { } } - void AppendTo(bool first, string* s) const { + void AppendTo(bool first, std::string* s) const { absl::string_view v; bool add_escaped = false; if ((value_op_ == kCEscape) && NeedsEscaping(value_)) { @@ -1037,9 +1038,9 @@ class AttrKeyAndValue { } private: - static bool NeedsEscaping(const string& s) { + static bool NeedsEscaping(const std::string& s) { for (auto c : s) { - if (!isalnum(c) && (c != ' ')) { + if (!absl::ascii_isalnum(c) && (c != ' ')) { return true; } } @@ -1049,16 +1050,17 @@ class AttrKeyAndValue { absl::string_view key_name_; int key_suffix_; // -1 if missing ValueRepresentationOp value_op_; - string value_; + std::string value_; }; } // namespace -string GetFunctionResourceInputDevice( +std::string GetFunctionResourceInputDevice( const Tensor& input, const int arg_index, const FunctionDef& function_def, - absl::flat_hash_map<string, std::vector<string>>* composite_devices) { + absl::flat_hash_map<std::string, std::vector<std::string>>* + composite_devices) { const auto& handles = input.flat<ResourceHandle>(); const ResourceHandle& handle0 = handles(0); - string composite_device; + std::string composite_device; auto iter = function_def.arg_attr().find(arg_index); if (iter != function_def.arg_attr().end()) { auto arg_attr = iter->second.attr().find("_composite_device"); @@ -1078,8 +1080,9 @@ string GetFunctionResourceInputDevice( } } -string Canonicalize(const string& funcname, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options) { +std::string Canonicalize( + const std::string& funcname, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options) { absl::InlinedVector<AttrKeyAndValue, 8> entries; entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) + options.input_devices.size()); @@ -1118,12 +1121,13 @@ string Canonicalize(const string& funcname, AttrSlice attrs, entries.push_back( AttrKeyAndValue("_state_handle", -1, options.state_handle)); } - string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs); + std::string executor_type = + FunctionLibraryRuntime::ExecutorType(options, attrs); if (!executor_type.empty()) { entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type)); } if
(options.config_proto.ByteSize() > 0) { - string config_proto_serialized; + std::string config_proto_serialized; SerializeToStringDeterministic(options.config_proto, &config_proto_serialized); entries.push_back(AttrKeyAndValue("_config_proto", -1, @@ -1131,7 +1135,7 @@ string Canonicalize(const string& funcname, AttrSlice attrs, AttrKeyAndValue::kCEscape)); } std::sort(entries.begin(), entries.end()); - string result = absl::StrCat(funcname, "["); + std::string result = absl::StrCat(funcname, "["); bool first = true; for (const auto& entry : entries) { entry.AppendTo(first, &result); @@ -1141,7 +1145,7 @@ string Canonicalize(const string& funcname, AttrSlice attrs, return result; } -string Canonicalize(const string& funcname, AttrSlice attrs) { +std::string Canonicalize(const std::string& funcname, AttrSlice attrs) { static const FunctionLibraryRuntime::InstantiateOptions* kEmptyOptions = new FunctionLibraryRuntime::InstantiateOptions; return Canonicalize(funcname, attrs, *kEmptyOptions); @@ -1373,12 +1377,13 @@ void FunctionLibraryDefinition::Initialize( } } -bool FunctionLibraryDefinition::Contains(const string& func) const { +bool FunctionLibraryDefinition::Contains(const std::string& func) const { tf_shared_lock l(mu_); return records_.find(func) != records_.end(); } -const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const { +const FunctionDef* FunctionLibraryDefinition::Find( + const std::string& func) const { tf_shared_lock l(mu_); auto result = FindHelper(func); if (result) { @@ -1389,13 +1394,13 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const { } core::RefCountPtr<FunctionRecord> FunctionLibraryDefinition::FindRecord( - const string& func) const { + const std::string& func) const { tf_shared_lock l(mu_); return FindHelper(func); } core::RefCountPtr<FunctionRecord> FunctionLibraryDefinition::FindHelper( - const string& func) const { + const std::string& func) const { auto iter = records_.find(func); if (iter == records_.end()) { return nullptr; @@ -1474,7 +1479,7 @@ absl::Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, } absl::Status FunctionLibraryDefinition::CopyFunctionDefFrom( - const string& name, const FunctionLibraryDefinition& other) { + const std::string& name, const FunctionLibraryDefinition& other) { if (default_registry() != other.default_registry()) { return errors::InvalidArgument( "Cannot copy function '", name, @@ -1516,7 +1521,7 @@ absl::Status FunctionLibraryDefinition::AddGradientDef( absl::Status FunctionLibraryDefinition::AddGradientDefHelper( const GradientDef& grad, bool* added) { *added = false; - string* entry = &func_grad_[grad.function_name()]; + std::string* entry = &func_grad_[grad.function_name()]; if (!entry->empty()) { if (*entry != grad.gradient_func()) { return errors::InvalidArgument( @@ -1545,8 +1550,8 @@ absl::Status FunctionLibraryDefinition::AddLibrary( mutex_lock l2(other.mu_); // Remember the funcs and grads that we added successfully so that // we can roll them back on error. - std::vector<string> funcs; - std::vector<string> funcs_with_grads; + std::vector<std::string> funcs; + std::vector<std::string> funcs_with_grads; absl::Status s; bool added; for (const auto& [name, record] : other.records_) { @@ -1603,8 +1608,8 @@ absl::Status FunctionLibraryDefinition::AddLibrary( // Remember the funcs and grads that we added successfully so that // we can roll them back on error.
mutex_lock l(mu_); - std::vector<string> funcs; - std::vector<string> funcs_with_grads; + std::vector<std::string> funcs; + std::vector<std::string> funcs_with_grads; absl::Status s; bool added; for (FunctionDef& fdef : *lib_def.mutable_function()) { @@ -1641,7 +1646,7 @@ absl::Status FunctionLibraryDefinition::AddLibrary( } absl::Status FunctionLibraryDefinition::ReplaceFunction( - const string& func, const FunctionDef& fdef, + const std::string& func, const FunctionDef& fdef, const StackTracesMap& stack_traces) { mutex_lock l(mu_); bool added; @@ -1660,14 +1665,15 @@ absl::Status FunctionLibraryDefinition::ReplaceGradient( return absl::OkStatus(); } -absl::Status FunctionLibraryDefinition::RemoveFunction(const string& func) { +absl::Status FunctionLibraryDefinition::RemoveFunction( + const std::string& func) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); return absl::OkStatus(); } absl::Status FunctionLibraryDefinition::RemoveFunctionHelper( - const string& func) { + const std::string& func) { auto iter = records_.find(func); if (iter == records_.end()) { return errors::InvalidArgument("Tried to remove non-existent function '", @@ -1688,7 +1694,8 @@ void FunctionLibraryDefinition::Clear() { func_grad_.clear(); } -absl::Status FunctionLibraryDefinition::RemoveGradient(const string& func) { +absl::Status FunctionLibraryDefinition::RemoveGradient( + const std::string& func) { const auto& i = func_grad_.find(func); if (i == func_grad_.end()) { return errors::InvalidArgument("Tried to remove non-existent gradient '", @@ -1699,16 +1706,16 @@ absl::Status FunctionLibraryDefinition::RemoveGradient(const string& func) { } absl::Status FunctionLibraryDefinition::Remove( - const std::vector<string>& funcs, - const std::vector<string>& funcs_with_grads) { + const std::vector<std::string>& funcs, + const std::vector<std::string>& funcs_with_grads) { absl::Status s; - for (const string& f : funcs) { + for (const std::string& f : funcs) { s = RemoveFunctionHelper(f); if (!s.ok()) { return s; } } - for (const string& f : funcs_with_grads) { + for (const std::string& f : funcs_with_grads) { s = RemoveGradient(f); if (!s.ok()) { return s; @@ -1717,17 +1724,19 @@ absl::Status FunctionLibraryDefinition::Remove( return absl::OkStatus(); } -string FunctionLibraryDefinition::FindGradient(const string& func) const { +std::string FunctionLibraryDefinition::FindGradient( + const std::string& func) const { tf_shared_lock l(mu_); return gtl::FindWithDefault(func_grad_, func, ""); } -string FunctionLibraryDefinition::FindGradientHelper(const string& func) const { +std::string FunctionLibraryDefinition::FindGradientHelper( + const std::string& func) const { return gtl::FindWithDefault(func_grad_, func, ""); } absl::Status FunctionLibraryDefinition::LookUp( - const string& op, const OpRegistrationData** op_reg_data) const { + const std::string& op, const OpRegistrationData** op_reg_data) const { tf_shared_lock l(mu_); auto iter = records_.find(op); if (iter != records_.end()) { @@ -1737,11 +1746,11 @@ absl::Status FunctionLibraryDefinition::LookUp( return default_registry_->LookUp(op, op_reg_data); } -string FunctionLibraryDefinition::UniqueFunctionName( +std::string FunctionLibraryDefinition::UniqueFunctionName( absl::string_view prefix) const { tf_shared_lock l(mu_); int index = 0; - string name = absl::StrCat(prefix, index); + std::string name = absl::StrCat(prefix, index); while (records_.find(name) != records_.end()) { ++index; name = absl::StrCat(prefix, index); @@ -1763,8 +1772,8 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( if (!TryGetNodeAttr(ndef,
kFuncAttr, &forward_func_attrs)) { return nullptr; } - const string& func_name = forward_func_attrs->name(); - const string& grad_name = FindGradient(func_name); + const std::string& func_name = forward_func_attrs->name(); + const std::string& grad_name = FindGradient(func_name); // If 'func' has a user-defined gradient function, uses the grad // function's attrs to see if noinline is specified. Otherwise, // uses func's attrs. @@ -1782,8 +1791,8 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( } } -std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const { - std::vector<string> function_names; +std::vector<std::string> FunctionLibraryDefinition::ListFunctionNames() const { + std::vector<std::string> function_names; tf_shared_lock l(mu_); function_names.reserve(records_.size()); for (const auto& it : records_) { @@ -1808,7 +1817,7 @@ FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { template <typename T> absl::Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, - const string& attr, + const std::string& attr, T* value) const { const FunctionDef* fdef = GetAttrImpl(ndef); if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) { @@ -1819,7 +1828,7 @@ absl::Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, template <typename T> absl::Status FunctionLibraryDefinition::GetAttr(const Node& node, - const string& attr, + const std::string& attr, T* value) const { return GetAttr(node.def(), attr, value); } @@ -1839,25 +1848,25 @@ constexpr char kApiImplements[] = "api_implements"; template <typename NodeIter, typename OpTypeGetter, typename AttrGetter> -std::set<string> ReachableFunctions(const FunctionLibraryDefinition& flib, - NodeIter begin, NodeIter end, - OpTypeGetter op_type_getter, - AttrGetter attr_getter) { +std::set<std::string> ReachableFunctions(const FunctionLibraryDefinition& flib, + NodeIter begin, NodeIter end, + OpTypeGetter op_type_getter, + AttrGetter attr_getter) { // Functions that are reachable from the graph. - std::set<string> reachable_funcs; + std::set<std::string> reachable_funcs; // For any functions, if it has attribute "api_implements" = // "some_interface" and it is reachable, then it means any other // function with same attribute name and value could also be potentially // reachable, eg via implementation_selector swapping the nodedef. - absl::flat_hash_set<string> reachable_api_interface; + absl::flat_hash_set<std::string> reachable_api_interface; // Functions might be reachable from the nested function calls, so we keep a // queue of functions that we have to check. absl::InlinedVector<core::RefCountPtr<FunctionRecord>, 4> func_queue; // Add reachable and not already processed functions to the functions queue. - const auto add_to_func_queue = [&](const string& func_name) { + const auto add_to_func_queue = [&](const std::string& func_name) { auto record = flib.FindRecord(func_name); if (record && reachable_funcs.find(func_name) == reachable_funcs.end()) { func_queue.push_back(std::move(record)); @@ -1866,19 +1875,20 @@ std::set<string> ReachableFunctions(const FunctionLibraryDefinition& flib, // If any function with certain API name is reachable, all the other functions // with same API name should also be checked.
- const auto add_function_with_api_interface = [&](const string& api_name) { - if (!reachable_api_interface.contains(api_name)) { - reachable_api_interface.insert(api_name); - for (const auto& func_name : flib.ListFunctionNames()) { - const auto record = flib.FindRecord(func_name); - const auto attr_it = record->fdef().attr().find(kApiImplements); - if (attr_it != record->fdef().attr().end() && - attr_it->second.s() == api_name) { - add_to_func_queue(func_name); + const auto add_function_with_api_interface = + [&](const std::string& api_name) { + if (!reachable_api_interface.contains(api_name)) { + reachable_api_interface.insert(api_name); + for (const auto& func_name : flib.ListFunctionNames()) { + const auto record = flib.FindRecord(func_name); + const auto attr_it = record->fdef().attr().find(kApiImplements); + if (attr_it != record->fdef().attr().end() && + attr_it->second.s() == api_name) { + add_to_func_queue(func_name); + } + } } - } - } - }; + }; const auto process_attr_value = [&](const AttrValue& attr_value) { // 1. AttrValue.func @@ -1913,7 +1923,7 @@ std::set<string> ReachableFunctions(const FunctionLibraryDefinition& flib, auto func = std::move(func_queue.back()); func_queue.pop_back(); - const string& func_name = func->fdef().signature().name(); + const std::string& func_name = func->fdef().signature().name(); reachable_funcs.insert(func_name); const auto attr_it = func->fdef().attr().find(kApiImplements); @@ -1937,7 +1947,7 @@ std::set<string> ReachableFunctions(const FunctionLibraryDefinition& flib, std::for_each(func_body.begin(), func_body.end(), process_node_def); // Check if the function has a registered gradient. - const string grad_func_name = flib.FindGradient(func_name); + const std::string grad_func_name = flib.FindGradient(func_name); if (!grad_func_name.empty()) add_to_func_queue(grad_func_name); } @@ -1949,19 +1959,19 @@ template <typename NodeIter, typename OpTypeGetter, typename AttrGetter> - std::set<string> reachable_funcs = ReachableFunctions( + std::set<std::string> reachable_funcs = ReachableFunctions( flib, begin, end, op_type_getter, attr_getter); FunctionLibraryDefinition reachable_flib(flib.default_registry(), FunctionDefLibrary()); - for (const string& func_name : reachable_funcs) { + for (const std::string& func_name : reachable_funcs) { // This should never fail, because we copy functions from a valid flib and // use the same default registry. absl::Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib); TF_DCHECK_OK(added); - const string grad_func_name = flib.FindGradient(func_name); + const std::string grad_func_name = flib.FindGradient(func_name); if (!grad_func_name.empty()) { GradientDef grad; grad.set_function_name(func_name); @@ -1975,9 +1985,9 @@ FunctionLibraryDefinition ReachableFunctionLibraryDefinition( return reachable_flib; } -string AllocatorAttributesToString( +std::string AllocatorAttributesToString( const std::vector<AllocatorAttributes>& attrs) { - string result("["); + std::string result("["); // AllocatorAttribute::DebugString produces around 85 bytes now.
result.reserve(100 * attrs.size()); for (const AllocatorAttributes& attr : attrs) { @@ -2036,7 +2046,7 @@ FunctionLibraryDefinition::ReachableDefinitions( } } -string FunctionLibraryRuntime::Options::DebugString() const { +std::string FunctionLibraryRuntime::Options::DebugString() const { return absl::StrCat( "FLR::Options(step_id=", step_id, " rendezvous=", IsSet(rendezvous), " cancellation_manager=", IsSet(cancellation_manager), @@ -2060,8 +2070,8 @@ void FunctionDefHelper::AttrValueWrapper::InitFromString( } FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( - const string& name, - absl::Span<const std::pair<string, AttrValueWrapper>> attrs) { + const std::string& name, + absl::Span<const std::pair<std::string, AttrValueWrapper>> attrs) { AttrValueWrapper ret; ret.proto.mutable_func()->set_name(name); for (const auto& a : attrs) { @@ -2077,10 +2087,10 @@ NodeDef FunctionDefHelper::Node::ToNodeDef() const { for (const auto& a : this->attr) { n.mutable_attr()->insert({a.first, a.second.proto}); } - for (const string& a : this->arg) { + for (const std::string& a : this->arg) { n.add_input(a); } - for (const string& d : this->dep) { + for (const std::string& d : this->dep) { n.add_input(absl::StrCat("^", d)); } if (!this->device.empty()) { @@ -2099,11 +2109,11 @@ NodeDef FunctionDefHelper::Node::ToNodeDef() const { /* static */ FunctionDef FunctionDefHelper::Create( - const string& function_name, absl::Span<const string> in_def, - absl::Span<const string> out_def, absl::Span<const string> attr_def, - absl::Span<const Node> node_def, - absl::Span<const std::pair<string, string>> ret_def, - absl::Span<const std::pair<string, string>> control_ret_def) { + const std::string& function_name, absl::Span<const std::string> in_def, + absl::Span<const std::string> out_def, + absl::Span<const std::string> attr_def, absl::Span<const Node> node_def, + absl::Span<const std::pair<std::string, std::string>> ret_def, + absl::Span<const std::pair<std::string, std::string>> control_ret_def) { FunctionDef fdef; // Signature @@ -2149,19 +2159,19 @@ FunctionDef FunctionDefHelper::Create( /* static */ FunctionDef FunctionDefHelper::Create( - const string& function_name, absl::Span<const string> in_def, - absl::Span<const string> out_def, absl::Span<const string> attr_def, - absl::Span<const Node> node_def, - absl::Span<const std::pair<string, string>> ret_def) { + const std::string& function_name, absl::Span<const std::string> in_def, + absl::Span<const std::string> out_def, + absl::Span<const std::string> attr_def, absl::Span<const Node> node_def, + absl::Span<const std::pair<std::string, std::string>> ret_def) { return Create(function_name, in_def, out_def, attr_def, node_def, ret_def, /*control_ret_def=*/{}); } /* static */ -FunctionDef FunctionDefHelper::Define(const string& name, - absl::Span<const string> arg_def, - absl::Span<const string> ret_def, - absl::Span<const string> attr_def, +FunctionDef FunctionDefHelper::Define(const std::string& name, + absl::Span<const std::string> arg_def, + absl::Span<const std::string> ret_def, + absl::Span<const std::string> attr_def, absl::Span<const Node> node_def) { FunctionDef fdef; OpDefBuilder b(name); @@ -2174,7 +2184,7 @@ FunctionDef FunctionDefHelper::Define(const string& name, fdef.mutable_signature()->Swap(&op_reg_data.op_def); // Mapping from legacy output names to NodeDef outputs.
- std::unordered_map<string, string> ret_index; + std::unordered_map<std::string, std::string> ret_index; for (const auto& a : fdef.signature().input_arg()) { ret_index[a.name()] = a.name(); } @@ -2190,13 +2200,13 @@ FunctionDef FunctionDefHelper::Define(const string& name, for (const auto& a : src.attr) { n->mutable_attr()->insert({a.first, a.second.proto}); } - for (const string& a : src.arg) { + for (const std::string& a : src.arg) { const auto iter = ret_index.find(a); CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '" << n->name() << "' of " << name; n->add_input(iter->second); } - for (const string& d : src.dep) { + for (const std::string& d : src.dep) { n->add_input(absl::StrCat("^", d)); } @@ -2227,29 +2237,29 @@ FunctionDef FunctionDefHelper::Define(const string& name, return fdef; } -FunctionDef FunctionDefHelper::Define(absl::Span<const string> arg_def, - absl::Span<const string> ret_def, - absl::Span<const string> attr_def, +FunctionDef FunctionDefHelper::Define(absl::Span<const std::string> arg_def, + absl::Span<const std::string> ret_def, + absl::Span<const std::string> attr_def, absl::Span<const Node> node_def) { return Define("_", arg_def, ret_def, attr_def, node_def); } namespace gradient { -typedef std::unordered_map<string, Creator> OpGradFactory; +typedef std::unordered_map<std::string, Creator> OpGradFactory; OpGradFactory* GetOpGradFactory() { static OpGradFactory* factory = new OpGradFactory; return factory; } -bool RegisterOp(const string& op, Creator func) { +bool RegisterOp(const std::string& op, Creator func) { CHECK(GetOpGradFactory()->insert({op, func}).second) << "Duplicated gradient for " << op; return true; } -absl::Status GetOpGradientCreator(const string& op, Creator* creator) { +absl::Status GetOpGradientCreator(const std::string& op, Creator* creator) { auto fac = GetOpGradFactory(); auto iter = fac->find(op); if (iter == fac->end()) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 7fbf120afd6741..ed2ec8c075db08 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -125,7 +125,7 @@ class FunctionDefHelper { // Constructs an AttrValue.func given the "name" and "attrs". static AttrValueWrapper FunctionRef( const std::string& name, - absl::Span<const std::pair<string, AttrValueWrapper>> attrs); + absl::Span<const std::pair<std::string, AttrValueWrapper>> attrs); static AttrValueWrapper FunctionRef(const std::string& name) { return FunctionRef(name, {}); } @@ -141,11 +141,11 @@ class FunctionDefHelper { struct Node { // When constructing a NodeDef, the first entry in ret is used as // the node name, the remaining values are ignored. - std::vector<string> ret; + std::vector<std::string> ret; std::string op; - std::vector<string> arg; - std::vector<std::pair<string, AttrValueWrapper>> attr; - std::vector<string> dep; + std::vector<std::string> arg; + std::vector<std::pair<std::string, AttrValueWrapper>> attr; + std::vector<std::string> dep; std::string device; // Required if the op has zero outputs. Otherwise, ret[0] used as name if @@ -157,8 +157,8 @@ class FunctionDefHelper { CHECK(!ret.empty()); return ret[0]; } - std::vector<string> original_node_names; - std::vector<string> original_func_names; + std::vector<std::string> original_node_names; + std::vector<std::string> original_func_names; NodeDef ToNodeDef() const; }; @@ -170,33 +170,33 @@ class FunctionDefHelper { // - `control_ret_def` holds a mapping from the function control // output names to the nodes from `node_def`. static FunctionDef Create( - const std::string& function_name, absl::Span<const string> in_def, - absl::Span<const string> out_def, absl::Span<const string> attr_def, - absl::Span<const Node> node_def, - absl::Span<const std::pair<string, string>> ret_def, - absl::Span<const std::pair<string, string>> control_ret_def); + const std::string& function_name, absl::Span<const std::string> in_def, + absl::Span<const std::string> out_def, + absl::Span<const std::string> attr_def, absl::Span<const Node> node_def, + absl::Span<const std::pair<std::string, std::string>> ret_def, + absl::Span<const std::pair<std::string, std::string>> control_ret_def); // Creates a FunctionDef from the given parameters.
Node inputs must use // function encoding (node_name:output_name[:output_index]). // - `ret_def` holds a mapping from the function output names from `out_def` // to the node outputs from `node_def`. static FunctionDef Create( - const std::string& function_name, absl::Span<const string> in_def, - absl::Span<const string> out_def, absl::Span<const string> attr_def, - absl::Span<const Node> node_def, - absl::Span<const std::pair<string, string>> ret_def); + const std::string& function_name, absl::Span<const std::string> in_def, + absl::Span<const std::string> out_def, + absl::Span<const std::string> attr_def, absl::Span<const Node> node_def, + absl::Span<const std::pair<std::string, std::string>> ret_def); // TODO(josh11b): Get rid of these and transition to the one above. static FunctionDef Define(const std::string& function_name, - absl::Span<const string> arg_def, - absl::Span<const string> ret_def, - absl::Span<const string> attr_def, + absl::Span<const std::string> arg_def, + absl::Span<const std::string> ret_def, + absl::Span<const std::string> attr_def, absl::Span<const Node> node_def); // Defines an anonymous function. I.e., its name is not relevant. - static FunctionDef Define(absl::Span<const string> arg_def, - absl::Span<const string> ret_def, - absl::Span<const string> attr_def, + static FunctionDef Define(absl::Span<const std::string> arg_def, + absl::Span<const std::string> ret_def, + absl::Span<const std::string> attr_def, absl::Span<const Node> node_def); // Helpers to construct a constant scalar. @@ -258,7 +258,7 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( // GetFunctionSignature(func name, opdef) returns OK if the func name is found // and opdef is filled with a pointer to the corresponding signature // (a OpDef proto). Otherwise, returns an error. -typedef std::function<absl::Status(const string&, const OpDef**)> +typedef std::function<absl::Status(const std::string&, const OpDef**)> GetFunctionSignature; struct InstantiationResult { @@ -293,7 +293,7 @@ bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); // Return a hash of `fdef` that is consistent with FunctionDefsEqual method. // In other words, if two fdefs compare equal, their hash values will be the // same. -uint64 FunctionDefHash(const FunctionDef& fdef); +uint64_t FunctionDefHash(const FunctionDef& fdef); class CallFrameInterface { public: @@ -566,7 +566,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface { } // Returns all the function names in the FunctionLibraryDefinition. - std::vector<string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_); + std::vector<std::string> ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_); const OpRegistryInterface* default_registry() const { return default_registry_; @@ -658,7 +658,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface { void Initialize(const FunctionDefLibrary& library, const FunctionDefLibraryStackTraces& library_traces); - core::RefCountPtr<FunctionRecord> FindHelper(const string& func) const + core::RefCountPtr<FunctionRecord> FindHelper(const std::string& func) const TF_SHARED_LOCKS_REQUIRED(mu_); std::string FindGradientHelper(const std::string& func) const TF_SHARED_LOCKS_REQUIRED(mu_); @@ -681,8 +681,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Remove all functions in `funcs` and all gradients of functions in // `funcs_with_grads` from this library. - absl::Status Remove(const std::vector<string>& funcs, - const std::vector<string>& funcs_with_grads) + absl::Status Remove(const std::vector<std::string>& funcs, + const std::vector<std::string>& funcs_with_grads) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Remove `func` from the library. Returns non-OK Status unless `func` is in @@ -698,10 +698,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface { mutable mutex mu_; const OpRegistryInterface* default_registry_; - gtl::FlatMap<string, FunctionRecord*> records_ TF_GUARDED_BY(mu_); - gtl::FlatMap<string, string> func_grad_ TF_GUARDED_BY(mu_); + gtl::FlatMap<std::string, FunctionRecord*> records_ TF_GUARDED_BY(mu_); + gtl::FlatMap<std::string, std::string> func_grad_ TF_GUARDED_BY(mu_); // Maps from function name to optimized function graph.
- gtl::FlatMap<string, std::function<absl::StatusOr<OptimizedFunctionGraph>()>> + gtl::FlatMap<std::string, + std::function<absl::StatusOr<OptimizedFunctionGraph>()>> optimized_function_graph_creator_map_ TF_GUARDED_BY(mu_); }; @@ -752,7 +753,7 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // function's inputs. The device of resource inputs must be the device // backing the resource, not the CPU device backing the resource handle. // Must have the same length as number of inputs to the function. - std::vector<string> input_devices; + std::vector<std::string> input_devices; // For multi-device functions, a vector of canonical device names for // function's outputs. @@ -780,14 +781,15 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // resource output, and node producing that resource is a function call, // runtime will leave device specification empty and will rely on Placer to // infer correct device. - std::vector<string> output_devices; + std::vector<std::string> output_devices; // If set, it indicates the original output indices of a component function. absl::optional<std::vector<int>> ret_indices = absl::nullopt; // Maps from a CompositeDevice name to a list of underlying physical // devices. - absl::flat_hash_map<string, const std::vector<string>*> composite_devices; + absl::flat_hash_map<std::string, const std::vector<std::string>*> + composite_devices; // This interface is EXPERIMENTAL and subject to change. // @@ -836,8 +838,8 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // If provided, this optimization function will be invoked before // the placer for multi-device functions. - std::function<absl::Status(std::vector<string> /*ret_node_names*/, - std::vector<string> /*keep_node_names*/, + std::function<absl::Status(std::vector<std::string> /*ret_node_names*/, + std::vector<std::string> /*keep_node_names*/, FunctionLibraryDefinition*, const DeviceSet&, Device* /*cpu_device*/, std::unique_ptr<Graph>*)> optimize_graph_fn; @@ -899,7 +901,7 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // Instantiates the function enabling soft placement or outside compilation. bool allow_soft_placement = false; }; - typedef uint64 Handle; + typedef uint64_t Handle; virtual absl::Status Instantiate(const std::string& function_name, AttrSlice attrs, const InstantiateOptions& options, @@ -1055,7 +1057,7 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // Returns the graph version number. virtual int graph_def_version() const = 0; - typedef uint64 LocalHandle; + typedef uint64_t LocalHandle; // Creates a copy of ProcessFunctionLibraryRuntime (transferring ownership to // the caller), FunctionLibraryRuntime (owned by the returned @@ -1088,7 +1090,8 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // `composite_devices` if the input device is a composite device. std::string GetFunctionResourceInputDevice( const Tensor& input, const int arg_index, const FunctionDef& function_def, - absl::flat_hash_map<string, std::vector<string>>* composite_devices); + absl::flat_hash_map<std::string, std::vector<std::string>>* + composite_devices); // Returns a canonicalized string for the instantiation of the function of the // given "name", attributes "attrs", and "options". @@ -1173,7 +1176,7 @@ class DistributedFunctionLibraryRuntime { FunctionLibraryRuntime::DoneCallback done) = 0; // Clean up a previously instantiated function on remote worker.
- virtual void CleanUp(uint64 step_id, + virtual void CleanUp(uint64_t step_id, FunctionLibraryRuntime::LocalHandle handle, FunctionLibraryRuntime::DoneCallback done) = 0; diff --git a/tensorflow/core/framework/function_handle_cache.cc b/tensorflow/core/framework/function_handle_cache.cc index 6b9119b681af88..d0d995cbcc3712 100644 --- a/tensorflow/core/framework/function_handle_cache.cc +++ b/tensorflow/core/framework/function_handle_cache.cc @@ -33,10 +33,10 @@ FunctionHandleCache::~FunctionHandleCache() { } absl::Status FunctionHandleCache::Instantiate( - const string& function_name, AttrSlice attrs, + const std::string& function_name, AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions options, FunctionLibraryRuntime::Handle* handle) { - string key = Canonicalize(function_name, attrs, options); + std::string key = Canonicalize(function_name, attrs, options); FunctionLibraryRuntime::Handle h; { tf_shared_lock l(mu_); diff --git a/tensorflow/core/framework/function_handle_cache.h b/tensorflow/core/framework/function_handle_cache.h index 1bd67138d1964f..317c53823c1685 100644 --- a/tensorflow/core/framework/function_handle_cache.h +++ b/tensorflow/core/framework/function_handle_cache.h @@ -34,7 +34,7 @@ class FunctionHandleCache { // // The cache retains the ownership of the handle. In particular, the caller // should not invoke `ReleaseHandle`. - absl::Status Instantiate(const string& function_name, AttrSlice attrs, + absl::Status Instantiate(const std::string& function_name, AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions options, FunctionLibraryRuntime::Handle* handle); @@ -45,8 +45,8 @@ class FunctionHandleCache { private: mutex mu_; FunctionLibraryRuntime* lib_ = nullptr; // not owned - const string state_handle_; - std::unordered_map<string, FunctionLibraryRuntime::Handle> handles_ + const std::string state_handle_; + std::unordered_map<std::string, FunctionLibraryRuntime::Handle> handles_ TF_GUARDED_BY(mu_); }; diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 1a396876e00166..fcae39d0277bab 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -54,7 +54,7 @@ using ::testing::UnorderedElementsAreArray; class Attrs { public: Attrs(const std::initializer_list< // NOLINT(runtime/explicit) - std::pair<string, FunctionDefHelper::AttrValueWrapper>> + std::pair<std::string, FunctionDefHelper::AttrValueWrapper>> attrs) { for (const auto& aval : attrs) { map_.insert({aval.first, aval.second.proto}); } @@ -69,7 +69,7 @@ class Attrs { typedef FunctionDefHelper FDH; -absl::Status GetOpSig(const string& op, const OpDef** sig) { +absl::Status GetOpSig(const std::string& op, const OpDef** sig) { return OpRegistry::Global()->LookUpOpDef(op, sig); } @@ -490,7 +490,7 @@ WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) { } TEST(TFunc, Body_TypeList) { - const Tensor kZero = test::AsScalar<int32>(0); + const Tensor kZero = test::AsScalar<int32_t>(0); auto fdef = FDH::Create( // Name "Test", @@ -633,7 +633,7 @@ TEST(TFunc, IntsOnDeviceArgSet) { EXPECT_EQ("_DeviceRetval", result.nodes[4].op()); } -static void HasError(const absl::Status& s, const string& substr) { +static void HasError(const absl::Status& s, const std::string& substr) { EXPECT_TRUE(absl::StrContains(s.ToString(), substr)) << ">>" << s << "<<, expected substring >>" << substr << "<<"; } @@ -1229,7 +1229,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) { TF_EXPECT_OK(lib_def.AddLibrary(lib_def)); } -GradientDef MakeGradDef(const string& f, const string& g) { +GradientDef MakeGradDef(const std::string& f, const std::string& g) { GradientDef grad; grad.set_function_name(f); grad.set_gradient_func(g); @@
-1239,8 +1239,8 @@ GradientDef MakeGradDef(const string& f, const string& g) { TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) { // Create lib def containing two functions with equal names FunctionDefLibrary proto; - const string x2_name = test::function::XTimesTwo().signature().name(); - const string x4_name = test::function::XTimesFour().signature().name(); + const std::string x2_name = test::function::XTimesTwo().signature().name(); + const std::string x4_name = test::function::XTimesFour().signature().name(); *proto.add_function() = test::function::XTimesTwo(); FunctionDef fdef = test::function::XTimesFour(); fdef.mutable_signature()->set_name(x2_name); @@ -1275,9 +1275,9 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) { } TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) { - const string x2_name = test::function::XTimesTwo().signature().name(); - const string x4_name = test::function::XTimesFour().signature().name(); - const string wx_name = test::function::WXPlusB().signature().name(); + const std::string x2_name = test::function::XTimesTwo().signature().name(); + const std::string x4_name = test::function::XTimesFour().signature().name(); + const std::string wx_name = test::function::WXPlusB().signature().name(); // Create FunctionLibraryDefinition with // (func = XTimesTwo, grad = XTimesFour) @@ -1311,9 +1311,9 @@ TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) { } TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) { - const string x2_name = test::function::XTimesTwo().signature().name(); - const string x4_name = test::function::XTimesFour().signature().name(); - const string wx_name = test::function::WXPlusB().signature().name(); + const std::string x2_name = test::function::XTimesTwo().signature().name(); + const std::string x4_name = test::function::XTimesFour().signature().name(); + const std::string wx_name = test::function::WXPlusB().signature().name(); // Create FunctionLibraryDefinition with // (func = XTimesTwo, grad = XTimesFour) @@ -1372,8 +1372,8 @@ TEST(FunctionLibraryDefinitionTest, ListFunctionNames) { TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); TF_CHECK_OK(lib_def.AddFunctionDef(test::function::WXPlusB())); - const std::vector<string> function_names = lib_def.ListFunctionNames(); - const std::vector<string> expected = {"XTimesTwo", "WXPlusB"}; + const std::vector<std::string> function_names = lib_def.ListFunctionNames(); + const std::vector<std::string> expected = {"XTimesTwo", "WXPlusB"}; EXPECT_EQ(function_names, expected); } @@ -1399,7 +1399,7 @@ TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) { } template <typename T> -void SetAttrValue(FunctionDef* fdef, const string& attr, const T& value) { +void SetAttrValue(FunctionDef* fdef, const std::string& attr, const T& value) { AttrValue attr_value; SetAttrValue(value, &attr_value); fdef->mutable_attr()->insert({attr, attr_value}); @@ -1421,7 +1421,7 @@ TEST(FunctionLibraryDefinitionTest, GetAttr_FuncWithAttr) { TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation)); EXPECT_EQ(annotation, true); - string str; + std::string str; TF_EXPECT_OK(lib.GetAttr(ndef, "options", &str)); EXPECT_EQ(str, "some string data"); } @@ -1462,8 +1462,8 @@ TEST(FunctionLibraryDefinitionTest, ReachableDefinitions) { using ::tensorflow::test::function::NDef; using FDH = ::tensorflow::FunctionDefHelper; - const auto make_simple_fdef = [](const string& name, - const string& interface_name) { + const auto make_simple_fdef = [](const std::string& name, + const
std::string& interface_name) { auto func_def = FDH::Create( name, {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"}, {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}}, @@ -1616,7 +1616,7 @@ TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) { // Equal functions const FunctionDef fdef1 = test::function::XTimesTwo(); FunctionDef fdef2 = test::function::XTimesTwo(); - uint64 hash1 = FunctionDefHash(fdef1); + uint64_t hash1 = FunctionDefHash(fdef1); EXPECT_TRUE(FunctionDefsEqual(fdef1, fdef2)); EXPECT_EQ(hash1, FunctionDefHash(fdef2)); @@ -1760,7 +1760,7 @@ TEST(InstantiateFunctionTest, ResourceInputDevice) { *(*arg_attrs.mutable_attr())["_composite_device"].mutable_s() = "/device:COMPOSITE:0"; (*fdef.mutable_arg_attr())[0] = arg_attrs; - absl::flat_hash_map<string, std::vector<string>> composite_devices; + absl::flat_hash_map<std::string, std::vector<std::string>> composite_devices; Tensor arg0(DT_RESOURCE, TensorShape({2})); ResourceHandle resource_handle0; @@ -1773,9 +1773,9 @@ Tensor arg1(DT_RESOURCE, TensorShape({})); arg1.scalar<ResourceHandle>()() = resource_handle0; - const string device0 = GetFunctionResourceInputDevice( + const std::string device0 = GetFunctionResourceInputDevice( arg0, /*arg_index=*/0, fdef, &composite_devices); - const string device1 = GetFunctionResourceInputDevice( + const std::string device1 = GetFunctionResourceInputDevice( arg1, /*arg_index=*/1, fdef, &composite_devices); EXPECT_EQ(device0, "/device:COMPOSITE:0"); diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 5e5c64d2a2a5ee..1b968b939365a7 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -48,13 +48,14 @@ GraphDef GDef(absl::Span<const NodeDef> nodes, } // Helper to construct a NodeDef.
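// Illustrative sketch (assumed usage, not part of this patch): with the NDef
// helper whose definition follows, a test can build a NodeDef as
//
//   NodeDef n = test::function::NDef(
//       "x_times_two", "Mul", {"x", "two"}, {{"T", DT_FLOAT}});
//
// where the {"T", DT_FLOAT} attr pair goes through FDH::AttrValueWrapper's
// implicit conversion from DataType, and the device argument defaults to "".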
-NodeDef NDef(absl::string_view name, absl::string_view op, - absl::Span<const string> inputs, - absl::Span<const std::pair<string, FDH::AttrValueWrapper>> attrs, - const string& device) { +NodeDef NDef( + absl::string_view name, absl::string_view op, + absl::Span<const std::string> inputs, + absl::Span<const std::pair<std::string, FDH::AttrValueWrapper>> attrs, + const std::string& device) { NodeDef n; - n.set_name(string(name)); - n.set_op(string(op)); + n.set_name(name); + n.set_op(op); for (const auto& in : inputs) n.add_input(in); n.set_device(device); for (const auto& na : attrs) @@ -609,8 +610,8 @@ FunctionDef XYXLessThanOrEqualToN(int64_t N) { } FunctionDef RandomUniformLess() { - const Tensor kZero = test::AsScalar<int32>(0); - const Tensor kOne = test::AsScalar<int32>(1); + const Tensor kZero = test::AsScalar<int32_t>(0); + const Tensor kOne = test::AsScalar<int32_t>(1); const Tensor k005 = test::AsScalar<float>(0.05); return FDH::Define( @@ -703,8 +704,8 @@ FunctionDef MakeBatchDataset() { } FunctionDef MakeMapDataset(bool has_other_args) { - std::vector<string> args = {"input_dataset: variant"}; - std::vector<string> inputs = {"input_dataset"}; + std::vector<std::string> args = {"input_dataset: variant"}; + std::vector<std::string> inputs = {"input_dataset"}; if (has_other_args) { args.emplace_back("other_arguments: Targuments"); inputs.emplace_back("other_arguments"); diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index 93cae697e62d15..b4cbf057cbe0a8 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -34,15 +34,14 @@ namespace function { class Attrs { public: Attrs(const std::initializer_list< // NOLINT(runtime/explicit) - std::pair<string, FDH::AttrValueWrapper>>& attrs) { + std::pair<std::string, FDH::AttrValueWrapper>>& attrs) { for (const auto& aval : attrs) { map_.insert({aval.first, aval.second.proto}); } } - Attrs( - const std::vector<std::pair<string, FDH::AttrValueWrapper>>& - attrs) { + Attrs(const std::vector< + std::pair<std::string, FDH::AttrValueWrapper>>& attrs) { for (const auto& aval : attrs) { map_.insert({aval.first, aval.second.proto}); } @@ -55,12 +54,12 @@ class Attrs { }; // Helper to construct a NodeDef. -NodeDef NDef( - absl::string_view name, absl::string_view op, - absl::Span<const string> inputs, - absl::Span<const std::pair<string, FDH::AttrValueWrapper>> - attrs = {}, - const string& device = ""); +NodeDef NDef(absl::string_view name, absl::string_view op, + absl::Span<const std::string> inputs, + absl::Span<const std::pair<std::string, FDH::AttrValueWrapper>> + attrs = {}, + const std::string& device = ""); // Helper to construct a GraphDef proto. GraphDef GDef(absl::Span<const NodeDef> nodes, diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index c603ced808d370..9f54e3eecfdccd 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -35,8 +35,8 @@ limitations under the License.
namespace tensorflow { -string SummarizeGraphDef(const GraphDef& graph_def) { - string ret; +std::string SummarizeGraphDef(const GraphDef& graph_def) { + std::string ret; absl::StrAppend(&ret, "versions = ", graph_def.versions().ShortDebugString(), ";\n"); for (const NodeDef& node : graph_def.node()) { @@ -85,7 +85,7 @@ absl::Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, static absl::Status RemoveNewDefaultAttrsFromNodeDef( NodeDef* node_def, const OpRegistryInterface& consumer_op_registry, const OpRegistryInterface& producer_op_registry, - std::set<std::pair<string, string>>* op_attr_removed) { + std::set<std::pair<std::string, std::string>>* op_attr_removed) { const OpDef* producer_op_def; const OpDef* consumer_op_def; TF_RETURN_IF_ERROR( @@ -93,7 +93,7 @@ static absl::Status RemoveNewDefaultAttrsFromNodeDef( TF_RETURN_IF_ERROR( consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def)); - std::vector<string> to_remove; + std::vector<std::string> to_remove; for (const auto& attr : node_def->attr()) { // If the attr is not in consumer_op_def and doesn't start with '_'... if (!absl::StartsWith(attr.first, "_") && @@ -117,7 +117,7 @@ static absl::Status RemoveNewDefaultAttrsFromNodeDef( // We separate identifying which attrs should be removed from // actually removing them to avoid invalidating the loop iterators // above. - for (const string& attr_name : to_remove) { + for (const std::string& attr_name : to_remove) { node_def->mutable_attr()->erase(attr_name); if (op_attr_removed != nullptr) { op_attr_removed->insert(std::make_pair(node_def->op(), attr_name)); @@ -127,7 +127,7 @@ static absl::Status RemoveNewDefaultAttrsFromNodeDef( return absl::OkStatus(); } -static bool IsFunction(const GraphDef& graph_def, const string& op_name) { +static bool IsFunction(const GraphDef& graph_def, const std::string& op_name) { for (const auto& func_def : graph_def.library().function()) { if (op_name == func_def.signature().name()) return true; } @@ -137,7 +137,7 @@ static bool IsFunction(const GraphDef& graph_def, const string& op_name) { absl::Status RemoveNewDefaultAttrsFromGraphDef( GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, const OpRegistryInterface& producer_op_registry, - std::set<std::pair<string, string>>* op_attr_removed) { + std::set<std::pair<std::string, std::string>>* op_attr_removed) { // TODO(joshL): Make IsFunction() faster by collecting the names of // all functions as a preprocessing step. for (int n = 0; n < graph_def->node_size(); ++n) { @@ -184,7 +184,7 @@ void StripDefaultAttributes(const OpRegistryInterface& op_registry, for (const OpDef::AttrDef& attr_def : op_def->attr()) { if (attr_def.has_default_value()) { AttrValueMap* attrs = node->mutable_attr(); - const string& name = attr_def.name(); + const std::string& name = attr_def.name(); auto iter = attrs->find(name); if (iter != attrs->end()) { const AttrValue& default_value = attr_def.default_value(); @@ -202,9 +202,9 @@ void StripDefaultAttributes(const OpRegistryInterface& op_registry, } void OpsUsedByGraph(const GraphDef& graph_def, - std::set<string>* ops_used_in_graph) { + std::set<std::string>* ops_used_in_graph) { // Map function names to definitions. - std::unordered_map<string, const FunctionDef*> name_to_function; + std::unordered_map<std::string, const FunctionDef*> name_to_function; for (const auto& function : graph_def.library().function()) { name_to_function.insert( std::make_pair(function.signature().name(), &function)); @@ -212,11 +212,11 @@ void OpsUsedByGraph(const GraphDef& graph_def, // Collect the sorted list of op names. Since functions can reference // functions, we need a recursive traversal.
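// Illustrative note (consistent with graph_def_util_test.cc later in this
// patch): for a graph whose single node calls a function "F" whose body uses
// ops "B" and "C", OpsUsedByGraph() reports {"B", "C"}; the function name
// "F" itself is filtered out of the result before it is returned.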
- std::set<string> used_ops; // Includes both primitive ops and functions + std::set<std::string> used_ops; // Includes both primitive ops and functions std::vector<const FunctionDef*> functions_to_process; // A subset of used_ops // Collect the logic to mark an op in a lambda; it'll be used twice below. const auto mark_op_as_used = [&used_ops, &functions_to_process, - &name_to_function](const string& op) { + &name_to_function](const std::string& op) { if (used_ops.insert(op).second) { // If it's a function, we'll need to process further const auto it = name_to_function.find(op); @@ -239,7 +239,7 @@ void OpsUsedByGraph(const GraphDef& graph_def, // Filter out function names to produce output. // TODO(josh11b): Change the above code to produce this directly. ops_used_in_graph->clear(); - for (const string& op_name : used_ops) { + for (const std::string& op_name : used_ops) { if (name_to_function.find(op_name) == name_to_function.end()) { ops_used_in_graph->insert(op_name); } @@ -249,12 +249,12 @@ absl::Status StrippedOpListForGraph(const GraphDef& graph_def, const OpRegistryInterface& op_registry, OpList* stripped_op_list) { - std::set<string> used_ops; + std::set<std::string> used_ops; OpsUsedByGraph(graph_def, &used_ops); // Build the stripped op list in sorted order, ignoring functions. stripped_op_list->clear_op(); - for (const string& op_name : used_ops) { + for (const std::string& op_name : used_ops) { const OpDef* op_def; TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def)); OpDef* stripped_op = stripped_op_list->add_op(); diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h index a164ac310fe4ed..b3e335e776f3f6 100644 --- a/tensorflow/core/framework/graph_def_util.h +++ b/tensorflow/core/framework/graph_def_util.h @@ -29,7 +29,7 @@ class NodeDef; // Produce a human-readable version of a GraphDef that is more concise // than a text-format proto. -string SummarizeGraphDef(const GraphDef& graph_def); +std::string SummarizeGraphDef(const GraphDef& graph_def); // Validates the syntax of a GraphDef provided externally. // @@ -97,7 +97,7 @@ absl::Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, absl::Status RemoveNewDefaultAttrsFromGraphDef( GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, const OpRegistryInterface& producer_op_registry, - std::set<std::pair<string, string>>* op_attr_removed); + std::set<std::pair<std::string, std::string>>* op_attr_removed); // Goes over the `nodes` and removes attributes that are set to their // default values according to op_registry. @@ -115,7 +115,7 @@ void StripDefaultAttributes(const OpRegistryInterface& op_registry, // // This returns the ops used as a set of strings. void OpsUsedByGraph(const GraphDef& graph_def, - std::set<string>* ops_used_in_graph); + std::set<std::string>* ops_used_in_graph); // This function computes the stripped_op_list field of MetaGraphDef // and similar protos.
The op_registry should contain the ops used to diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc index 12a1ee29fe792e..503f2cc93af194 100644 --- a/tensorflow/core/framework/graph_def_util_test.cc +++ b/tensorflow/core/framework/graph_def_util_test.cc @@ -59,7 +59,7 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) { .Finalize(graph_def.add_node())); GraphDef expected_graph_def = graph_def; - std::set<std::pair<string, string>> op_attr_removed; + std::set<std::pair<std::string, std::string>> op_attr_removed; TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry, &op_attr_removed)); @@ -80,7 +80,7 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) { .Finalize(graph_def.add_node())); GraphDef expected_graph_def = graph_def; - std::set<std::pair<string, string>> op_attr_removed; + std::set<std::pair<std::string, std::string>> op_attr_removed; TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry, &op_attr_removed)); @@ -106,7 +106,7 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) { TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry) .Finalize(produced_graph_def.add_node())); - std::set<std::pair<string, string>> op_attr_removed; + std::set<std::pair<std::string, std::string>> op_attr_removed; TF_ASSERT_OK( RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, producer_registry, &op_attr_removed)); @@ -116,7 +116,8 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) { .Finalize(expected_graph_def.add_node())); TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def); - std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}}); + std::set<std::pair<std::string, std::string>> expected_removed( + {{"UsesDefault", "a"}}); EXPECT_EQ(expected_removed, op_attr_removed); } @@ -142,7 +143,7 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) { .Finalize(produced_graph_def.add_node())); GraphDef expected_graph_def = produced_graph_def; - std::set<std::pair<string, string>> op_attr_removed; + std::set<std::pair<std::string, std::string>> op_attr_removed; TF_ASSERT_OK( RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, producer_registry, &op_attr_removed)); @@ -174,7 +175,7 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) { .Finalize(produced_graph_def.add_node())); GraphDef expected_graph_def = produced_graph_def; - std::set<std::pair<string, string>> op_attr_removed; + std::set<std::pair<std::string, std::string>> op_attr_removed; TF_ASSERT_OK( RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, producer_registry, &op_attr_removed)); @@ -213,7 +214,7 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) { TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &function_registry) .Finalize(produced_graph_def.add_node())); - std::set<std::pair<string, string>> op_attr_removed; + std::set<std::pair<std::string, std::string>> op_attr_removed; TF_ASSERT_OK( RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry, producer_registry, &op_attr_removed)); @@ -231,7 +232,8 @@ TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) { EXPECT_EQ(expected_graph_def.library().DebugString(), produced_graph_def.library().DebugString()); - std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}}); + std::set<std::pair<std::string, std::string>> expected_removed( + {{"UsesDefault", "a"}}); EXPECT_EQ(expected_removed, op_attr_removed); } @@ -272,7 +274,7 @@ TEST(StripDefaultAttributesTest, NonDefaultNotStripped) { TEST(StrippedOpListForGraphTest, FlatTest) { // Make four ops OpList op_list; - for (const string& op : {"A", "B", "C", "D"}) { + for (const std::string& op : {"A", "B", "C", "D"}) { OpDef* op_def = op_list.add_op(); op_def->set_name(op); op_def->set_summary("summary"); @@ -282,7 +284,7 @@ TEST(StrippedOpListForGraphTest, FlatTest) { // Make a graph which uses
two ops once and twice, respectively. // The result should be independent of the ordering. - const string graph_ops[4][3] = { + const std::string graph_ops[4][3] = { {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}}; for (const bool use_function : {false, true}) { for (int order = 0; order < 4; order++) { @@ -290,13 +292,13 @@ TEST(StrippedOpListForGraphTest, FlatTest) { if (use_function) { FunctionDef* function_def = graph_def.mutable_library()->add_function(); function_def->mutable_signature()->set_name("F"); - for (const string& op : graph_ops[order]) { + for (const std::string& op : graph_ops[order]) { function_def->add_node_def()->set_op(op); } graph_def.add_node()->set_op("F"); } else { - for (const string& op : graph_ops[order]) { - string name = absl::StrCat("name", graph_def.node_size()); + for (const std::string& op : graph_ops[order]) { + std::string name = absl::StrCat("name", graph_def.node_size()); NodeDef* node = graph_def.add_node(); node->set_name(name); node->set_op(op); @@ -319,9 +321,9 @@ TEST(StrippedOpListForGraphTest, FlatTest) { } // Should get the same result using OpsUsedByGraph(). - std::set<string> used_ops; + std::set<std::string> used_ops; OpsUsedByGraph(graph_def, &used_ops); - ASSERT_EQ(std::set<string>({"B", "C"}), used_ops); + ASSERT_EQ(std::set<std::string>({"B", "C"}), used_ops); } } } @@ -356,9 +358,9 @@ TEST(StrippedOpListForGraphTest, NestedFunctionTest) { ASSERT_EQ(stripped_op_list.op(0).name(), "A"); // Should get the same result using OpsUsedByGraph(). - std::set<string> used_ops; + std::set<std::string> used_ops; OpsUsedByGraph(graph_def, &used_ops); - ASSERT_EQ(std::set<string>({"A"}), used_ops); + ASSERT_EQ(std::set<std::string>({"A"}), used_ops); } } diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc index 95b6287c4e56b6..b3226c6fac490b 100644 --- a/tensorflow/core/framework/graph_to_functiondef.cc +++ b/tensorflow/core/framework/graph_to_functiondef.cc @@ -51,45 +51,45 @@ class NodeNameMapping { // Normalize the input name and make it unique. This is the same as the // function for output, except that it adds a name mapping for the name. - string GetInputName(const string& name); + std::string GetInputName(const std::string& name); // Normalize the output name and make it unique. - string GetOutputName(const string& name); + std::string GetOutputName(const std::string& name); // Make the node name unique. - string Uniquify(const string& name); + std::string Uniquify(const std::string& name); // Records name as a used name. If this name is already used, // returns an error status. - absl::Status UseOutputName(const string& name); + absl::Status UseOutputName(const std::string& name); // Look up how a node name was previously normalized/uniquified. // Returns empty if name was never seen. - string Lookup(const string& name) const; + std::string Lookup(const std::string& name) const; private: - string UniquifyHelper(const string& name); - static string Normalize(string name); + std::string UniquifyHelper(const std::string& name); + static std::string Normalize(std::string name); // The normalized/uniquified names already used as // input names (in signature), output names (in signature), and node names // (in node_def). // This is a superset of values in name_mapping_. - absl::flat_hash_map<string, uint64> used_names_; + absl::flat_hash_map<std::string, uint64_t> used_names_; // Mapping from original node name from the graph to the normalized // and uniquified version of it.
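// Illustrative note (assumption based on UniquifyHelper() below): the first
// node named "add" keeps the name "add"; a second node with the same name is
// renamed to "add_0", the next to "add_1", and so on.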
- absl::flat_hash_map<string, string> name_mapping_; + absl::flat_hash_map<std::string, std::string> name_mapping_; }; -string NodeNameMapping::Normalize(string name) { +std::string NodeNameMapping::Normalize(std::string name) { // Convert letters to lowercase and non-alphanumeric characters to '_'. if (name.empty()) return "unknown"; const int n = name.size(); for (int i = 0; i < n; ++i) { char c = name[i]; - if (isalnum(c)) { - if (isupper(c)) { - name[i] = tolower(c); + if (absl::ascii_isalnum(c)) { + if (absl::ascii_isupper(c)) { + name[i] = absl::ascii_tolower(c); } } else { name[i] = '_'; @@ -99,45 +99,45 @@ string NodeNameMapping::Normalize(string name) { // Find the first letter and start with it. int i = 0; for (; i < n; ++i) { - if (isalpha(name[i])) break; + if (absl::ascii_isalpha(name[i])) break; } // Return "unknown" if none of the name's chars were letters. return i == n ? "unknown" : name.substr(i); } -string NodeNameMapping::UniquifyHelper(const string& name) { +std::string NodeNameMapping::UniquifyHelper(const std::string& name) { auto it = used_names_.emplace(name, 0); // If the name hasn't been used yet, use it as-is. if (it.second) return name; // Add a suffix to name to make it unique. while (true) { - const string candidate = absl::StrCat(name, "_", it.first->second); + const std::string candidate = absl::StrCat(name, "_", it.first->second); it.first->second++; if (used_names_.emplace(candidate, 0).second) return candidate; } } -string NodeNameMapping::GetInputName(const string& name) { - const string& input_name = UniquifyHelper(Normalize(name)); +std::string NodeNameMapping::GetInputName(const std::string& name) { + const std::string& input_name = UniquifyHelper(Normalize(name)); name_mapping_[name] = input_name; return input_name; } -string NodeNameMapping::GetOutputName(const string& name) { - const string& input_name = UniquifyHelper(Normalize(name)); +std::string NodeNameMapping::GetOutputName(const std::string& name) { + const std::string& input_name = UniquifyHelper(Normalize(name)); // Don't add it to name_mapping_ since this name is not for a node.
return input_name; } -string NodeNameMapping::Uniquify(const string& name) { - const string uniqued = UniquifyHelper(name); +std::string NodeNameMapping::Uniquify(const std::string& name) { + const std::string uniqued = UniquifyHelper(name); name_mapping_[name] = uniqued; return uniqued; } -absl::Status NodeNameMapping::UseOutputName(const string& name) { +absl::Status NodeNameMapping::UseOutputName(const std::string& name) { const auto& iter = used_names_.find(name); if (iter != used_names_.end()) { return errors::InvalidArgument( @@ -148,19 +148,19 @@ absl::Status NodeNameMapping::UseOutputName(const string& name) { return absl::OkStatus(); } -string NodeNameMapping::Lookup(const string& name) const { +std::string NodeNameMapping::Lookup(const std::string& name) const { const auto iter = name_mapping_.find(name); - if (iter == name_mapping_.end()) return string(); + if (iter == name_mapping_.end()) return std::string(); return iter->second; } absl::Status FillFunctionBody( - const string& fn_name, const NodeNameMapping& node_names, + const std::string& fn_name, const NodeNameMapping& node_names, const std::vector<const Node*>& body_nodes, - const absl::flat_hash_map<string, string>& tensor_renaming, + const absl::flat_hash_map<std::string, std::string>& tensor_renaming, bool set_stateful_from_nodes, bool copy_placeholder_attrs_from_nodes, bool allow_destructive_reads, FunctionDef* fdef) { - absl::flat_hash_set<string> func_attr_names; + absl::flat_hash_set<std::string> func_attr_names; for (const auto& func_attr : fdef->signature().attr()) { func_attr_names.insert(func_attr.name()); } @@ -263,7 +263,7 @@ absl::Status FillFunctionBody( for (const Edge* edge : control_edges) { // Add this control input only if the src node is in the body or a part of // the inputs. - const string normalized = node_names.Lookup(edge->src()->name()); + const std::string normalized = node_names.Lookup(edge->src()->name()); // If we did not find a name for the source of control edge, this // source must be outside of the body, and not an input. Raise an error. if (normalized.empty()) { @@ -322,15 +322,16 @@ absl::Status FillFunctionBody( } absl::Status GraphToFunctionDefHelper( - const Graph& fn_body, const string& fn_name, bool append_hash_to_fn_name, - bool set_stateful_from_nodes, bool copy_placeholder_attrs_from_nodes, + const Graph& fn_body, const std::string& fn_name, + bool append_hash_to_fn_name, bool set_stateful_from_nodes, + bool copy_placeholder_attrs_from_nodes, const std::vector<const Node*>& body_nodes, const std::vector<OutputTensor>& inputs, const std::vector<OutputTensor>& outputs, - const std::vector<string>& output_names, + const std::vector<std::string>& output_names, const std::vector<const Node*>& control_outputs, - const std::vector<string>& control_output_names, const char* description, - bool allow_destructive_reads, FunctionDef* fdef) { + const std::vector<std::string>& control_output_names, + const char* description, bool allow_destructive_reads, FunctionDef* fdef) { if (!output_names.empty()) { DCHECK_EQ(output_names.size(), outputs.size()); } @@ -350,7 +351,7 @@ absl::Status GraphToFunctionDefHelper( // - For tensors produced by nodes in function's body: // {flat_tensor_name -> nested_tensor_name} // e.g. {Add:3 -> add_0:z:1} - absl::flat_hash_map<string, string> tensor_renaming; + absl::flat_hash_map<std::string, std::string> tensor_renaming; // Fill outputs in function's signature.
// We fill the outputs first to prevent output_names from colliding @@ -380,7 +381,7 @@ absl::Status GraphToFunctionDefHelper( int idx = inputs[i].index; OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg(); argdef->set_type(node->output_type(idx)); - const string& input_name = node_names.GetInputName(node->name()); + const std::string& input_name = node_names.GetInputName(node->name()); argdef->set_name(input_name); FunctionDef::ArgAttrs arg_attrs; int64_t resource_arg_unique_id = -1; @@ -431,7 +432,7 @@ absl::Status GraphToFunctionDefHelper( // in tensor_renaming. for (const Node* node : body_nodes) { // Make sure node_name does not collide with an input or output name. - const string& node_name = node_names.Uniquify(node->name()); + const std::string& node_name = node_names.Uniquify(node->name()); // For each output_arg in the op_def, the output_ranges // map will have [start, end] range of indices that this arg produces // among all the output tensors of this op. @@ -443,8 +444,8 @@ absl::Status GraphToFunctionDefHelper( int index_start = output.second.first; int index_end = output.second.second; for (int i = index_start; i < index_end; ++i) { - const string& original_name = absl::StrCat(node->name(), ":", i); - const string& new_name = + const std::string& original_name = absl::StrCat(node->name(), ":", i); + const std::string& new_name = strings::StrCat(node_name, ":", output_name, ":", i - index_start); // Record the mapping if this tensor is not already mapped. // Tensor can be already mapped if it is used as an input. @@ -461,10 +462,10 @@ absl::Status GraphToFunctionDefHelper( // Remap return values. for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { - const string& ret_name = fdef->signature().output_arg(r).name(); + const std::string& ret_name = fdef->signature().output_arg(r).name(); // We convert this flat tensor name to the nested value // (e.g. `add:z:1`) that we stored in tensor_renaming. 
- string return_value; + std::string return_value; if (outputs[r].node->IsRetval()) { Edge const* edge; TF_RETURN_IF_ERROR(outputs[r].node->input_edge(0, &edge)); @@ -484,8 +485,8 @@ } if (append_hash_to_fn_name) { - const uint64 hash = FunctionDefHash(*fdef); - string encoded; + const uint64_t hash = FunctionDefHash(*fdef); + std::string encoded; TF_RETURN_IF_ERROR(Base64Encode( absl::string_view(reinterpret_cast<const char*>(&hash), sizeof(hash)), &encoded)); @@ -508,9 +509,9 @@ ") and the number of control output names (", control_output_names.size(), ") to match but they do not."); } - std::set<string> control_output_names_set; + std::set<std::string> control_output_names_set; for (int i = 0; i < control_outputs.size(); ++i) { - string signature_name; + std::string signature_name; if (!control_output_names.empty()) { signature_name = control_output_names[i]; } else { @@ -523,7 +524,7 @@ return errors::InvalidArgument("Repeated control output name: ", signature_name); } - const string control_output_node = + const std::string control_output_node = node_names.Lookup(control_outputs[i]->name()); if (control_output_node.empty()) { return errors::InvalidArgument( @@ -531,7 +532,7 @@ } (*fdef->mutable_control_ret())[signature_name] = control_output_node; } - for (const string& control_output : control_output_names_set) { + for (const std::string& control_output : control_output_names_set) { fdef->mutable_signature()->add_control_output(control_output); } @@ -539,9 +540,9 @@ } absl::Status GraphToFunctionDefHelper( - const Graph& graph, const string& name, - const std::function<absl::optional<string>(const Node*)>& control_ret, - const std::vector<string>& output_names, bool allow_destructive_reads, + const Graph& graph, const std::string& name, + const std::function<absl::optional<std::string>(const Node*)>& control_ret, + const std::vector<std::string>& output_names, bool allow_destructive_reads, FunctionDef* fdef) { auto add_arg_or_retval = [](Node* node, std::vector<OutputTensor>* args_or_retvals) { @@ -566,7 +567,7 @@ absl::Status GraphToFunctionDefHelper( std::vector<OutputTensor> inputs; std::vector<OutputTensor> outputs; std::vector<const Node*> control_outputs; - std::vector<string> control_output_names; + std::vector<std::string> control_output_names; for (Node* node : graph.op_nodes()) { if (node->IsArg()) { TF_RETURN_IF_ERROR(add_arg_or_retval(node, &inputs)); @@ -591,7 +592,7 @@ auto validate_args_retvals = [](const std::vector<OutputTensor>& args_or_retvals, - const string& op_type) { + const std::string& op_type) { for (int i = 0, e = args_or_retvals.size(); i < e; ++i) { if (args_or_retvals[i].node == nullptr) { return errors::InvalidArgument("Missing '", op_type, @@ -614,17 +615,17 @@ } // anonymous namespace -absl::Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, - bool append_hash_to_fn_name, - bool set_stateful_from_nodes, - bool copy_placeholder_attrs_from_nodes, - const std::vector<const Node*>& body_nodes, - const std::vector<OutputTensor>& inputs, - const std::vector<OutputTensor>& outputs, - const std::vector<string>& output_names, - const std::vector<const Node*>& control_outputs, - const std::vector<string>& control_output_names, - const char* description, FunctionDef* fdef) { +absl::Status GraphToFunctionDef( + const Graph& fn_body, const std::string& fn_name, + bool append_hash_to_fn_name, bool set_stateful_from_nodes, + bool copy_placeholder_attrs_from_nodes, + const std::vector<const Node*>& body_nodes, + const std::vector<OutputTensor>& inputs, + const
std::vector<OutputTensor>& outputs, + const std::vector<std::string>& output_names, + const std::vector<const Node*>& control_outputs, + const std::vector<std::string>& control_output_names, + const char* description, FunctionDef* fdef) { return GraphToFunctionDefHelper( fn_body, fn_name, append_hash_to_fn_name, set_stateful_from_nodes, copy_placeholder_attrs_from_nodes, body_nodes, inputs, outputs, @@ -634,20 +635,20 @@ absl::Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, } absl::Status GraphToFunctionDef( - const Graph& graph, const string& name, - const std::function<absl::optional<string>(const Node*)>& control_ret, + const Graph& graph, const std::string& name, + const std::function<absl::optional<std::string>(const Node*)>& control_ret, FunctionDef* fdef) { return GraphToFunctionDefHelper(graph, name, control_ret, /*output_names=*/{}, /*allow_destructive_reads=*/false, fdef); } -absl::Status GraphToFunctionDef(const Graph& graph, const string& name, +absl::Status GraphToFunctionDef(const Graph& graph, const std::string& name, FunctionDef* fdef) { return GraphToFunctionDef(graph, name, /*control_ret=*/nullptr, fdef); } -absl::Status GraphToFunctionDef(const Graph& graph, const string& name, +absl::Status GraphToFunctionDef(const Graph& graph, const std::string& name, const std::vector<string>& output_names, FunctionDef* fdef) { return GraphToFunctionDefHelper(graph, name, /*control_ret=*/nullptr, @@ -656,8 +657,8 @@ absl::Status GraphToFunctionDef(const Graph& graph, const string& name, } absl::Status GraphToFunctionDef( - std::unique_ptr<Graph> graph, const string& name, - const std::function<absl::optional<string>(const Node*)>& control_ret, + std::unique_ptr<Graph> graph, const std::string& name, + const std::function<absl::optional<std::string>(const Node*)>& control_ret, FunctionDef* fdef) { return GraphToFunctionDefHelper(*graph, name, control_ret, /*output_names=*/{}, diff --git a/tensorflow/core/framework/graph_to_functiondef.h b/tensorflow/core/framework/graph_to_functiondef.h index 369b86ecea5e03..4558af7938f312 100644 --- a/tensorflow/core/framework/graph_to_functiondef.h +++ b/tensorflow/core/framework/graph_to_functiondef.h @@ -29,17 +29,17 @@ namespace tensorflow { // Graph to FunctionDef conversion. This code is closely modeled on the Python // function graph_to_function_def(), which is located in // tensorflow/python/framework/graph_to_function_def.py. -absl::Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, - bool append_hash_to_fn_name, - bool set_stateful_from_nodes, - bool copy_placeholder_attrs_from_nodes, - const std::vector<const Node*>& body_nodes, - const std::vector<OutputTensor>& inputs, - const std::vector<OutputTensor>& outputs, - const std::vector<string>& output_names, - const std::vector<const Node*>& control_outputs, - const std::vector<string>& control_output_names, - const char* description, FunctionDef* fdef); +absl::Status GraphToFunctionDef( + const Graph& fn_body, const std::string& fn_name, + bool append_hash_to_fn_name, bool set_stateful_from_nodes, + bool copy_placeholder_attrs_from_nodes, + const std::vector<const Node*>& body_nodes, + const std::vector<OutputTensor>& inputs, + const std::vector<OutputTensor>& outputs, + const std::vector<std::string>& output_names, + const std::vector<const Node*>& control_outputs, + const std::vector<std::string>& control_output_names, + const char* description, FunctionDef* fdef); // Converts 'graph' to a FunctionDef 'fdef', with name 'name': // @@ -50,20 +50,20 @@ absl::Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, // `control_output` in Op definition (see OpDef). Control output name must // be unique for all control output nodes.
absl::Status GraphToFunctionDef( - const Graph& graph, const string& name, - const std::function<absl::optional<string>(const Node*)>& control_ret, + const Graph& graph, const std::string& name, + const std::function<absl::optional<std::string>(const Node*)>& control_ret, FunctionDef* fdef); -absl::Status GraphToFunctionDef(const Graph& graph, const string& name, +absl::Status GraphToFunctionDef(const Graph& graph, const std::string& name, FunctionDef* fdef); -absl::Status GraphToFunctionDef(const Graph& graph, const string& name, +absl::Status GraphToFunctionDef(const Graph& graph, const std::string& name, const std::vector<string>& output_names, FunctionDef* fdef); absl::Status GraphToFunctionDef( - std::unique_ptr<Graph> graph, const string& name, - const std::function<absl::optional<string>(const Node*)>& control_ret, + std::unique_ptr<Graph> graph, const std::string& name, + const std::function<absl::optional<std::string>(const Node*)>& control_ret, FunctionDef* fdef); } // namespace tensorflow diff --git a/tensorflow/core/framework/graph_to_functiondef_test.cc b/tensorflow/core/framework/graph_to_functiondef_test.cc index d71f6b9ff47a3b..719f9af233758e 100644 --- a/tensorflow/core/framework/graph_to_functiondef_test.cc +++ b/tensorflow/core/framework/graph_to_functiondef_test.cc @@ -47,7 +47,7 @@ FunctionDef RemoveDebugInfo(const FunctionDef& def) { } bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, - string* diff) { + std::string* diff) { // TODO(phawkins) use a more sophisticated equality test. if (a.DebugString() != b.DebugString()) { if (diff) { @@ -95,7 +95,7 @@ TEST(GraphToFunctionDefTest, Basics) { }, {{"h", "G:sum:0"}}); // return values - string diff; + std::string diff; bool fdefs_equal = EqualFunctionDef(fdef_expected, RemoveDebugInfo(fdef), &diff); @@ -119,7 +119,7 @@ TEST(GraphToFunctionDefTest, OverrideOutputNames) { {}, // body {{"b", "a"}}); // return values - string diff; + std::string diff; bool fdefs_equal = EqualFunctionDef(fdef_expected, RemoveDebugInfo(fdef), &diff); @@ -168,7 +168,7 @@ TEST(GraphToFunctionDefTest, ArgAttrShape) { attrs.mutable_attr()->insert({"_output_shapes", output_shapes}); (*fdef_expected.mutable_arg_attr())[0] = std::move(attrs); - string diff; + std::string diff; bool fdefs_equal = EqualFunctionDef(fdef_expected, RemoveDebugInfo(fdef), &diff); @@ -199,7 +199,7 @@ TEST(GraphToFunctionDefTest, ArgAttrPrivateAttr) { attrs.mutable_attr()->insert({"_name", private_attr}); (*fdef_expected.mutable_arg_attr())[0] = std::move(attrs); - string diff; + std::string diff; bool fdefs_equal = EqualFunctionDef(fdef_expected, RemoveDebugInfo(fdef), &diff); @@ -266,7 +266,7 @@ TEST(GraphToFunctionDefTest, ArgAttrConstInput) { (*fdef_expected.mutable_signature()->mutable_description()) = "ArgAttrConstInput"; - string diff; + std::string diff; bool fdefs_equal = EqualFunctionDef(fdef_expected, RemoveDebugInfo(fdef), &diff); @@ -374,7 +374,7 @@ TEST(GraphToFunctionDefTest, ControlDependencies) { }, {{"c", "b:y:0"}}); // return values - string diff; + std::string diff; bool fdefs_equal = EqualFunctionDef(fdef_expected, RemoveDebugInfo(fdef), &diff); @@ -395,8 +395,9 @@ TEST(GraphToFunctionDefTest, ControlOutputs) { TF_EXPECT_OK(ConvertGraphDefToGraph(options, graph_def, graph.get())); // Add a 'b' node to the control return set.
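// Illustrative note (inferred from this test's expectations below): mapping
// node "b" to "must_execute" makes GraphToFunctionDef produce a FunctionDef
// carrying, in textproto form,
//
//   signature { control_output: "must_execute" }
//   control_ret { key: "must_execute" value: "b" }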
- const auto control_ret = [](const Node* n) -> absl::optional<string> { - if (n->name() == "b") return absl::make_optional<string>("must_execute"); + const auto control_ret = [](const Node* n) -> absl::optional<std::string> { + if (n->name() == "b") + return absl::make_optional<std::string>("must_execute"); return absl::nullopt; }; @@ -415,7 +416,7 @@ TEST(GraphToFunctionDefTest, ControlOutputs) { {{"c", "b:y:0"}}, // return values {{"must_execute", "b"}}); // control returns - string diff; + std::string diff; bool fdefs_equal = EqualFunctionDef(fdef_expected, RemoveDebugInfo(fdef), &diff); diff --git a/tensorflow/core/framework/kernel_def_builder.cc b/tensorflow/core/framework/kernel_def_builder.cc index c9788b0a08c45f..7b7e90df8bab2a 100644 --- a/tensorflow/core/framework/kernel_def_builder.cc +++ b/tensorflow/core/framework/kernel_def_builder.cc @@ -55,8 +55,8 @@ KernelDefBuilder& KernelDefBuilder::AttrConstraint( } template <> -KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>( - const char* attr_name, absl::Span<const string> allowed) { +KernelDefBuilder& KernelDefBuilder::AttrConstraint<std::string>( + const char* attr_name, absl::Span<const std::string> allowed) { auto* constraint = kernel_def_->add_constraint(); constraint->set_name(attr_name); auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); @@ -67,11 +67,11 @@ KernelDefBuilder& KernelDefBuilder::AttrConstraint( } template <> -KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>( - const char* attr_name, string allowed) { - return AttrConstraint<string>( - attr_name, - absl::Span<const string>(std::initializer_list<string>({allowed}))); +KernelDefBuilder& KernelDefBuilder::AttrConstraint<std::string>( + const char* attr_name, std::string allowed) { + return AttrConstraint<std::string>(attr_name, + absl::Span<const std::string>( + std::initializer_list<std::string>({allowed}))); } template <> diff --git a/tensorflow/core/framework/kernel_def_builder_test.cc b/tensorflow/core/framework/kernel_def_builder_test.cc index fa37b114abbe22..eefa454beb763e 100644 --- a/tensorflow/core/framework/kernel_def_builder_test.cc +++ b/tensorflow/core/framework/kernel_def_builder_test.cc @@ -48,7 +48,7 @@ TEST(KernelDefBuilderTest, TypeConstraint) { def = KernelDefBuilder("C") .Device(DEVICE_GPU) - .TypeConstraint("U") + .TypeConstraint("V") .Build(); @@ -95,7 +95,7 @@ TEST(KernelDefBuilderTest, Int64Constraint) { .Device(DEVICE_GPU) .AttrConstraint("U", absl::Span<const int64_t>{int64_t{5}, int64_t{17}}) - .AttrConstraint("V", string("proto")) + .AttrConstraint("V", std::string("proto")) .Build(); protobuf::TextFormat::ParseFromString( @@ -136,7 +136,7 @@ TEST(KernelDefBuilderTest, StringConstraint) { def = KernelDefBuilder("C") .Device(DEVICE_GPU) .AttrConstraint("U", absl::Span{"boo", "ya"}) - .AttrConstraint("V", string("proto")) + .AttrConstraint("V", std::string("proto")) .Build(); protobuf::TextFormat::ParseFromString( diff --git a/tensorflow/core/framework/kernel_def_util_test.cc b/tensorflow/core/framework/kernel_def_util_test.cc index a2e4aa82fafd56..a15fa7b0cfbe0f 100644 --- a/tensorflow/core/framework/kernel_def_util_test.cc +++ b/tensorflow/core/framework/kernel_def_util_test.cc @@ -24,13 +24,13 @@ namespace tensorflow { namespace { -NodeDef NodeDefFromText(const string& text) { +NodeDef NodeDefFromText(const std::string& text) { NodeDef node_def; EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); return node_def; } -KernelDef KernelDefFromText(const string& text) { +KernelDef KernelDefFromText(const std::string& text) { KernelDef kernel_def; EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &kernel_def)); return kernel_def; @@ -38,8 +38,8 @@ KernelDef
KernelDefFromText(const string& text) { class AttrsMatchTest : public ::testing::Test { protected: - void ExpectStatus(const string& node_def_str, const string& kernel_def_str, - error::Code code) { + void ExpectStatus(const std::string& node_def_str, + const std::string& kernel_def_str, error::Code code) { bool match; auto status = KernelAttrsMatch(KernelDefFromText(kernel_def_str), NodeDefFromText(node_def_str), &match); @@ -53,7 +53,7 @@ class AttrsMatchTest : public ::testing::Test { }; TEST_F(AttrsMatchTest, ValidConstraint) { - string node_def_str = R"( + std::string node_def_str = R"( name: "ValidConstraint-op" op: "ValidConstraint" attr { @@ -63,7 +63,7 @@ TEST_F(AttrsMatchTest, ValidConstraint) { } } )"; - string kernel_def_str = R"( + std::string kernel_def_str = R"( op: "ValidConstraint" device_type: "CPU" constraint { @@ -79,7 +79,7 @@ TEST_F(AttrsMatchTest, ValidConstraint) { } TEST_F(AttrsMatchTest, BadConstraint) { - string node_def_str = R"( + std::string node_def_str = R"( name: "BadConstraint-op" op: "BadConstraint" attr { @@ -89,7 +89,7 @@ TEST_F(AttrsMatchTest, BadConstraint) { } } )"; - string kernel_def_str = R"( + std::string kernel_def_str = R"( op: "BadConstraint" device_type: "CPU" constraint { @@ -105,7 +105,7 @@ TEST_F(AttrsMatchTest, BadConstraint) { } TEST_F(AttrsMatchTest, Unimplemented) { - string node_def_str = R"( + std::string node_def_str = R"( name: "BadConstraint-op" op: "BadConstraint" attr { @@ -115,7 +115,7 @@ TEST_F(AttrsMatchTest, Unimplemented) { } } )"; - string kernel_def_str = R"( + std::string kernel_def_str = R"( op: "BadConstraint" device_type: "CPU" constraint { diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index a8ad5ba42069a7..df63471f59dff3 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -46,10 +46,10 @@ struct Library { absl::Status LoadDynamicLibrary(const char* library_filename, void** result, const void** buf, size_t* len) { static mutex mu(LINKER_INITIALIZED); - static std::unordered_map<string, Library> loaded_libs; + static std::unordered_map<std::string, Library> loaded_libs; Env* env = Env::Default(); Library library; - std::unordered_set<string> seen_op_names; + std::unordered_set<std::string> seen_op_names; { mutex_lock lock(mu); if (loaded_libs.find(library_filename) != loaded_libs.end()) { @@ -90,7 +90,7 @@ absl::Status LoadDynamicLibrary(const char* library_filename, void** result, loaded_libs[library_filename] = library; } } - string str; + std::string str; library.op_list.SerializeToString(&str); char* str_buf = reinterpret_cast<char*>(port::Malloc(str.length())); memcpy(str_buf, str.data(), str.length()); diff --git a/tensorflow/core/framework/local_rendezvous.cc b/tensorflow/core/framework/local_rendezvous.cc index fffc5f8864e992..6a56c1695d35b9 100644 --- a/tensorflow/core/framework/local_rendezvous.cc +++ b/tensorflow/core/framework/local_rendezvous.cc @@ -141,7 +141,7 @@ LocalRendezvous::~LocalRendezvous() { } namespace { -uint64 KeyHash(const absl::string_view& k) { +uint64_t KeyHash(const absl::string_view& k) { return Hash64(k.data(), k.size()); } } // namespace @@ -149,7 +149,7 @@ uint64 KeyHash(const absl::string_view& k) { absl::Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, const Rendezvous::Args& send_args, const Tensor& val, const bool is_dead) { - uint64 key_hash = KeyHash(key.FullKey()); + uint64_t key_hash = KeyHash(key.FullKey()); DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey(); if (is_dead) { @@ -158,7 +158,7 @@
absl::Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, "The number of dead values sent between a pair of devices.", "send_device", "recv_device"); rendezvous_dead_values_sent - ->GetCell(string(key.src_device), string(key.dst_device)) + ->GetCell(std::string(key.src_device), std::string(key.dst_device)) ->IncrementBy(1); } @@ -229,7 +229,7 @@ absl::Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) { - uint64 key_hash = KeyHash(key.FullKey()); + uint64_t key_hash = KeyHash(key.FullKey()); DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey(); tsl::core::RefCountPtr rc_keep_alive; diff --git a/tensorflow/core/framework/local_rendezvous.h b/tensorflow/core/framework/local_rendezvous.h index 332daaa6c02060..628bd4642f4762 100644 --- a/tensorflow/core/framework/local_rendezvous.h +++ b/tensorflow/core/framework/local_rendezvous.h @@ -82,7 +82,7 @@ class LocalRendezvous { Item* tail = nullptr; }; - typedef gtl::FlatMap<uint64, ItemQueue> Table; + typedef gtl::FlatMap<uint64_t, ItemQueue> Table; const int num_buckets_; // Pointer to the owner class of this LocalRendezvous if it is refcounted, diff --git a/tensorflow/core/framework/log_memory.cc b/tensorflow/core/framework/log_memory.cc index b168957ef7ed03..4fc2b86e18f156 100644 --- a/tensorflow/core/framework/log_memory.cc +++ b/tensorflow/core/framework/log_memory.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { -const string LogMemory::kLogMemoryLabel = "__LOG_MEMORY__"; +const std::string LogMemory::kLogMemoryLabel = "__LOG_MEMORY__"; bool LogMemory::IsEnabled() { return VLOG_IS_ON(2); } @@ -28,23 +28,23 @@ namespace { // Write the proto entry to LOG(INFO).
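// Illustrative note (assumption based on the helper below): OutputToLog()
// trims the proto package prefix, so recording a MemoryLogStep emits a line
// like
//
//   __LOG_MEMORY__ MemoryLogStep { step_id: 1 handle: "test" }
//
// i.e. the type name "tensorflow.MemoryLogStep" is shortened to
// "MemoryLogStep" before the entry is written to LOG(INFO).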
template <typename T> void OutputToLog(const T& proto) { - string type_name(proto.GetTypeName()); + std::string type_name(proto.GetTypeName()); const size_t index = type_name.find_last_of('.'); - if (index != string::npos) type_name = type_name.substr(index + 1); + if (index != std::string::npos) type_name = type_name.substr(index + 1); LOG(INFO) << LogMemory::kLogMemoryLabel << " " << type_name << " { " << proto.ShortDebugString() << " }"; } } // namespace -void LogMemory::RecordStep(const int64_t step_id, const string& handle) { +void LogMemory::RecordStep(const int64_t step_id, const std::string& handle) { MemoryLogStep step; step.set_step_id(step_id); step.set_handle(handle); OutputToLog(step); } -void LogMemory::RecordTensorAllocation(const string& kernel_name, +void LogMemory::RecordTensorAllocation(const std::string& kernel_name, const int64_t step_id, const Tensor& tensor) { MemoryLogTensorAllocation allocation; @@ -55,14 +55,14 @@ void LogMemory::RecordTensorAllocation(const string& kernel_name, } void LogMemory::RecordTensorDeallocation(const int64_t allocation_id, - const string& allocator_name) { + const std::string& allocator_name) { MemoryLogTensorDeallocation deallocation; deallocation.set_allocation_id(allocation_id); deallocation.set_allocator_name(allocator_name); OutputToLog(deallocation); } -void LogMemory::RecordTensorOutput(const string& kernel_name, +void LogMemory::RecordTensorOutput(const std::string& kernel_name, const int64_t step_id, const int index, const Tensor& tensor) { MemoryLogTensorOutput output; @@ -73,7 +73,7 @@ void LogMemory::RecordTensorOutput(const string& kernel_name, OutputToLog(output); } -void LogMemory::RecordRawAllocation(const string& operation, +void LogMemory::RecordRawAllocation(const std::string& operation, const int64_t step_id, size_t num_bytes, void* ptr, Allocator* allocator) { MemoryLogRawAllocation allocation; @@ -86,7 +86,7 @@ void LogMemory::RecordRawAllocation(const string& operation, OutputToLog(allocation); } -void LogMemory::RecordRawDeallocation(const string& operation, +void LogMemory::RecordRawDeallocation(const std::string& operation, const int64_t step_id, void* ptr, Allocator* allocator, bool deferred) { MemoryLogRawDeallocation deallocation; diff --git a/tensorflow/core/framework/logging.cc b/tensorflow/core/framework/logging.cc index 14f23b06d0e5e3..d10b4d555fd00f 100644 --- a/tensorflow/core/framework/logging.cc +++ b/tensorflow/core/framework/logging.cc @@ -36,13 +36,13 @@ bool RegisterListener(void (*listener)(const char*)) { return true; } -bool LogToListeners(string msg, string end) { +bool LogToListeners(std::string msg, std::string end) { auto listeners = logging::GetListeners(); if (listeners->empty()) { return false; } - string ended_msg = absl::StrCat(msg, end); + std::string ended_msg = absl::StrCat(msg, end); for (auto& listener : *listeners) { listener(ended_msg.c_str()); diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h index 06524726e5cfc1..ccc167ca91474e 100644 --- a/tensorflow/core/framework/lookup_interface.h +++ b/tensorflow/core/framework/lookup_interface.h @@ -133,7 +133,7 @@ class LookupInterface : public ResourceBase { absl::Status CheckFindArguments(const Tensor& keys, const Tensor& default_value); - string DebugString() const override { + std::string DebugString() const override { return absl::StrCat("A lookup table of size: ", size()); } diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index
8b187beb125740..11317fa9656c1f 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -43,7 +43,7 @@ int GetTotal(const NameRangeMap& name_map) { // to DEVICE_MEMORY except those args in host_memory_args. Removes // elements of host_memory_args that were used. void MemoryTypesHelper(const NameRangeMap& name_map, - std::vector<string>* host_memory_args, + std::vector<std::string>* host_memory_args, MemoryTypeVector* memory_types) { // Update args that have been marked as in "HOST_MEMORY". size_t keep = 0; @@ -62,10 +62,10 @@ void MemoryTypesHelper(const NameRangeMap& name_map, host_memory_args->resize(keep); } -bool IsFunctionCallOp(const string& op_type) { +bool IsFunctionCallOp(const std::string& op_type) { return op_type == "SymbolicGradient" || op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall" || op_type == "While" || - op_type == "StatelessWhile" || op_type == "If" || + op_type == "StatelessWhile" || op_type == "If" || op_type == "StatelessIf"; } @@ -110,11 +110,11 @@ absl::Status MemoryTypesForNode(const OpRegistryInterface* op_registry, bool is_fn = IsFunctionCallOp(ndef.op()); bool has_kernel_def = status.ok() && !is_fn; auto host_memory_required = [&](const DataType& dt) { - bool int32_on_device = + bool int32_on_device = has_kernel_def || device_type.type_string() == "TPU" || has_xla_compile; return DataTypeAlwaysOnHost(dt) || (dt == DT_INT32 && !int32_on_device); }; - + // Edge cases: // 1. If[Tcond=DT_BOOL, Tin=[DT_FLOAT,DT_INT32], Tout=[DT_FLOAT,DT_INT32]] // * Tcond marked HostMemory by kernel_def @@ -146,17 +146,16 @@ absl::Status MemoryTypesForNode(const OpRegistryInterface* op_registry, out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY); } - // Fills in host memory types based on the kernel def - if(kdef != nullptr) { // can this ever be false? - const auto& from_proto = kdef->host_memory_arg(); - std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); - MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes); - MemoryTypesHelper(out_names, &host_memory_args, out_mtypes); - if (!host_memory_args.empty()) { - return errors::InvalidArgument( - "HostMemory args '", absl::StrJoin(host_memory_args, "', '"), - "' not found in OpDef: ", SummarizeOpDef(*op_def)); - } + // Fills in host memory types based on the kernel def.
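// Illustrative sketch (assumed example, not part of this patch; MyRange and
// MyRangeOp are hypothetical): a kernel registration can pin arguments to
// host memory, e.g.
//
//   REGISTER_KERNEL_BUILDER(Name("MyRange")
//                               .Device(DEVICE_GPU)
//                               .HostMemory("start")
//                               .HostMemory("limit"),
//                           MyRangeOp);
//
// after which the matching below reports HOST_MEMORY for "start" and "limit"
// and leaves the remaining args as DEVICE_MEMORY.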
+ const auto& from_proto = kdef->host_memory_arg(); + std::vector<std::string> host_memory_args(from_proto.begin(), + from_proto.end()); + MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes); + MemoryTypesHelper(out_names, &host_memory_args, out_mtypes); + if (!host_memory_args.empty()) { + return errors::InvalidArgument( + "HostMemory args '", absl::StrJoin(host_memory_args, "', '"), + "' not found in OpDef: ", SummarizeOpDef(*op_def)); } } else { inp_mtypes->resize(inp_dtypes.size(), DEVICE_MEMORY); @@ -177,7 +176,7 @@ absl::Status MemoryTypesForNode(const OpRegistryInterface* op_registry, } } - std::vector<int32> hostmem_attr; + std::vector<int32_t> hostmem_attr; if (TryGetNodeAttr(ndef, "_input_hostmem", &hostmem_attr)) { for (int32_t i : hostmem_attr) { if (0 <= i && i < inp_mtypes->size()) { diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index dafcef280b48e5..c55d7e46a89140 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -305,7 +305,7 @@ auto* tf_data_pipeline_processing_time = tsl::monitoring::Gauge::New( "in microseconds", "id"); -auto* tf_data_auto_shard = tsl::monitoring::Gauge<int64, 2>::New( +auto* tf_data_auto_shard = tsl::monitoring::Gauge<int64_t, 2>::New( "/tensorflow/data/autoshard", "tf.data autoshard statistics.", "id", "name"); @@ -490,39 +490,41 @@ std::string GraphOptimizationSourceMapping(GraphOptimizationSource source) { } } -void RecordTFDataFetchOp(const string& name) { +void RecordTFDataFetchOp(const std::string& name) { tf_data_fetch_op_counter->GetCell(name)->IncrementBy(1); } -void RecordTFDataAutotune(const string& name) { +void RecordTFDataAutotune(const std::string& name) { tf_data_autotune_counter->GetCell(name)->IncrementBy(1); } tsl::monitoring::CounterCell* GetTFDataBytesConsumedCounter( - const string& name) { + const std::string& name) { return tf_data_bytes_consumed_counter->GetCell(name); } tsl::monitoring::CounterCell* GetTFDataBytesProducedCounter( - const string& name) { + const std::string& name) { return tf_data_bytes_produced_counter->GetCell(name); } -tsl::monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name) { +tsl::monitoring::CounterCell* GetTFDataBytesReadCounter( + const std::string& name) { return tf_data_bytes_read_counter->GetCell(name); } -tsl::monitoring::CounterCell* GetTFDataElementsCounter(const string& name) { +tsl::monitoring::CounterCell* GetTFDataElementsCounter( + const std::string& name) { return tf_data_elements_counter->GetCell(name); } tsl::monitoring::GaugeCell<std::function<std::string()>>* GetTFDataModelGauge( - const string& id) { + const std::string& id) { return tf_data_model_gauge->GetCell(id); } tsl::monitoring::GaugeCell<int64_t>* GetTFDataPipelineProcessingTimeGauge( - const string& id) { + const std::string& id) { return tf_data_pipeline_processing_time->GetCell(id); } @@ -530,23 +532,23 @@ void RecordTFDataBytesFetched(int64_t num_bytes) { tf_data_bytes_fetched_counter->GetCell()->IncrementBy(num_bytes); } -void RecordTFDataExperiment(const string& name) { +void RecordTFDataExperiment(const std::string& name) { tf_data_experiment_counter->GetCell(name)->IncrementBy(1); } -void RecordTFDataExperimentLive(const string& name) { +void RecordTFDataExperimentLive(const std::string& name) { tf_data_experiment_live_counter->GetCell(name)->IncrementBy(1); } -void RecordTFDataExperimentOptIn(const string& name) { +void RecordTFDataExperimentOptIn(const std::string& name) { tf_data_experiment_opt_in_counter->GetCell(name)->IncrementBy(1); } -void RecordTFDataExperimentOptOut(const string& name)
diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc
index dafcef280b48e5..c55d7e46a89140 100644
--- a/tensorflow/core/framework/metrics.cc
+++ b/tensorflow/core/framework/metrics.cc
@@ -305,7 +305,7 @@ auto* tf_data_pipeline_processing_time = tsl::monitoring::Gauge<double, 1>::New(
     "in microseconds", "id");
 
-auto* tf_data_auto_shard = tsl::monitoring::Gauge<int64, 2>::New(
+auto* tf_data_auto_shard = tsl::monitoring::Gauge<int64_t, 2>::New(
     "/tensorflow/data/autoshard", "tf.data autoshard statistics.", "id",
     "name");
 
@@ -490,39 +490,41 @@ std::string GraphOptimizationSourceMapping(GraphOptimizationSource source) {
   }
 }
 
-void RecordTFDataFetchOp(const string& name) {
+void RecordTFDataFetchOp(const std::string& name) {
   tf_data_fetch_op_counter->GetCell(name)->IncrementBy(1);
 }
 
-void RecordTFDataAutotune(const string& name) {
+void RecordTFDataAutotune(const std::string& name) {
   tf_data_autotune_counter->GetCell(name)->IncrementBy(1);
 }
 
 tsl::monitoring::CounterCell* GetTFDataBytesConsumedCounter(
-    const string& name) {
+    const std::string& name) {
   return tf_data_bytes_consumed_counter->GetCell(name);
 }
 
 tsl::monitoring::CounterCell* GetTFDataBytesProducedCounter(
-    const string& name) {
+    const std::string& name) {
   return tf_data_bytes_produced_counter->GetCell(name);
 }
 
-tsl::monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name) {
+tsl::monitoring::CounterCell* GetTFDataBytesReadCounter(
+    const std::string& name) {
   return tf_data_bytes_read_counter->GetCell(name);
 }
 
-tsl::monitoring::CounterCell* GetTFDataElementsCounter(const string& name) {
+tsl::monitoring::CounterCell* GetTFDataElementsCounter(
+    const std::string& name) {
   return tf_data_elements_counter->GetCell(name);
 }
 
 tsl::monitoring::GaugeCell<std::function<std::string()>>* GetTFDataModelGauge(
-    const string& id) {
+    const std::string& id) {
   return tf_data_model_gauge->GetCell(id);
 }
 
 tsl::monitoring::GaugeCell<double>* GetTFDataPipelineProcessingTimeGauge(
-    const string& id) {
+    const std::string& id) {
   return tf_data_pipeline_processing_time->GetCell(id);
 }
 
@@ -530,23 +532,23 @@ void RecordTFDataBytesFetched(int64_t num_bytes) {
   tf_data_bytes_fetched_counter->GetCell()->IncrementBy(num_bytes);
 }
 
-void RecordTFDataExperiment(const string& name) {
+void RecordTFDataExperiment(const std::string& name) {
   tf_data_experiment_counter->GetCell(name)->IncrementBy(1);
 }
 
-void RecordTFDataExperimentLive(const string& name) {
+void RecordTFDataExperimentLive(const std::string& name) {
   tf_data_experiment_live_counter->GetCell(name)->IncrementBy(1);
 }
 
-void RecordTFDataExperimentOptIn(const string& name) {
+void RecordTFDataExperimentOptIn(const std::string& name) {
   tf_data_experiment_opt_in_counter->GetCell(name)->IncrementBy(1);
 }
 
-void RecordTFDataExperimentOptOut(const string& name) {
+void RecordTFDataExperimentOptOut(const std::string& name) {
   tf_data_experiment_opt_out_counter->GetCell(name)->IncrementBy(1);
 }
 
-void RecordTFDataFingerprint(const string& name) {
+void RecordTFDataFingerprint(const std::string& name) {
   tf_data_fingerprint_counter->GetCell(name)->IncrementBy(1);
 }
 
@@ -557,18 +559,18 @@ void RecordTFDataServiceRuntimeCompressionDecision(bool compression_disabled) {
       ->IncrementBy(1);
 }
 
-void RecordTFDataServiceCompressionAction(const string& action) {
+void RecordTFDataServiceCompressionAction(const std::string& action) {
   tf_data_service_compression->GetCell(action)->IncrementBy(1);
 }
 
-void RecordTFDataServiceGetElementDuration(const string& data_transfer_protocol,
-                                           uint64 duration_us) {
+void RecordTFDataServiceGetElementDuration(
+    const std::string& data_transfer_protocol, uint64_t duration_us) {
   tf_data_service_get_element_duration_usecs_histogram
       ->GetCell(data_transfer_protocol)
       ->Add(duration_us);
 }
 
-void RecordTFDataGetNextDuration(uint64 duration_us) {
+void RecordTFDataGetNextDuration(uint64_t duration_us) {
   static auto* tf_data_get_next_duration_cell =
       tf_data_get_next_duration_usecs_histogram->GetCell();
   tf_data_get_next_duration_cell->Add(duration_us);
@@ -586,25 +588,25 @@ void RecordTFDataAutotuneMaxBufferBudgetRatio(const double ratio) {
   tf_data_buffered_vs_budget_ratio_histogram_cell->Add(ratio);
 }
 
-void RecordTFDataIteratorBusy(uint64 duration_us) {
+void RecordTFDataIteratorBusy(uint64_t duration_us) {
   static auto* tf_data_iterator_busy_cell =
       tf_data_iterator_busy_counter->GetCell();
   tf_data_iterator_busy_cell->IncrementBy(duration_us);
 }
 
-void RecordTFDataIteratorLifetime(uint64 duration_us) {
+void RecordTFDataIteratorLifetime(uint64_t duration_us) {
   static auto* tf_data_iterator_lifetime_cell =
      tf_data_iterator_lifetime_counter->GetCell();
   tf_data_iterator_lifetime_cell->IncrementBy(duration_us);
 }
 
-void RecordTFDataIteratorGap(uint64 duration_us) {
+void RecordTFDataIteratorGap(uint64_t duration_us) {
   static auto* tf_data_iterator_gap_msec_histogram_cell =
      tf_data_iterator_gap_msec_histogram->GetCell();
   tf_data_iterator_gap_msec_histogram_cell->Add(duration_us * 0.001);
 }
 
-void RecordTFDataOptimization(const string& name, int64_t num_changes) {
+void RecordTFDataOptimization(const std::string& name, int64_t num_changes) {
   tf_data_optimization_counter->GetCell(name)->IncrementBy(num_changes);
 }
 
@@ -641,7 +643,7 @@ void RecordTFDataServiceClientIterators(
 }
 
 void RecordTFDataServiceDataTransferProtocolUsed(
-    const string& data_transfer_protocol, bool user_specified) {
+    const std::string& data_transfer_protocol, bool user_specified) {
   std::string nature = user_specified ? "specified" : "default";
   tf_data_service_data_transfer_protocol_used_by_nature
       ->GetCell(data_transfer_protocol, nature)
@@ -649,16 +651,16 @@ void RecordTFDataServiceDataTransferProtocolUsed(
 }
 
 void RecordTFDataServiceDataTransferProtocolFallback(
-    const string& data_transfer_protocol, error::Code code,
-    const string& error_message) {
+    const std::string& data_transfer_protocol, error::Code code,
+    const std::string& error_message) {
   tf_data_service_data_transfer_protocol_fallback
       ->GetCell(data_transfer_protocol, error::Code_Name(code), error_message)
      ->IncrementBy(1);
 }
 
 void RecordTFDataServiceDataTransferProtocolError(
-    const string& data_transfer_protocol, error::Code code,
-    const string& error_message) {
+    const std::string& data_transfer_protocol, error::Code code,
+    const std::string& error_message) {
   tf_data_service_data_transfer_protocol_error
       ->GetCell(data_transfer_protocol, error::Code_Name(code), error_message)
      ->IncrementBy(1);
@@ -688,7 +690,8 @@ void RecordTFDataServiceOptimalNumberOfWorkers(int64_t number_of_workers) {
   tf_data_service_optimal_number_of_workers->GetCell()->Set(number_of_workers);
 }
 
-void RecordTFDataFilename(const string& name, const string& filename) {
+void RecordTFDataFilename(const std::string& name,
+                          const std::string& filename) {
   tf_data_filename_counter->GetCell(name, filename)->IncrementBy(1);
 }
 
@@ -697,7 +700,7 @@ void RecordTFDataFileLoggerAttempts() {
 }
 
 void RecordTFDataFileLoggerErrors(error::Code error_code,
-                                  const string& error_message) {
+                                  const std::string& error_message) {
   tf_data_file_logger_errors_counter
       ->GetCell(error::Code_Name(error_code), error_message)
      ->IncrementBy(1);
@@ -710,39 +713,40 @@ void RecordTFDataFileLoggerAttemptedNumFiles(size_t num_files) {
 
 void RecordTFDataFileLoggerErrorsNumFiles(size_t num_files,
                                           error::Code error_code,
-                                          const string& error_message) {
+                                          const std::string& error_message) {
   tf_data_file_logger_errors_num_files_counter
       ->GetCell(error::Code_Name(error_code), error_message)
      ->IncrementBy(num_files);
 }
 
-void RecordTFDataAutoShard(const string& id, data::AutoShardPolicy policy,
-                           int64 num_workers, int64 num_replicas) {
+void RecordTFDataAutoShard(const std::string& id, data::AutoShardPolicy policy,
+                           int64_t num_workers, int64_t num_replicas) {
   tf_data_auto_shard->GetCell(id, "policy")->Set(static_cast<int64_t>(policy));
   tf_data_auto_shard->GetCell(id, "num_workers")->Set(num_workers);
   tf_data_auto_shard->GetCell(id, "num_replicas")->Set(num_replicas);
 }
 
 void RecordTFDataAutoShardRewriteBatchSize(
-    bool eligible, const std::vector<string>& ineligible_reason) {
+    bool eligible, const std::vector<std::string>& ineligible_reason) {
   tf_data_auto_shard_rewrite_batch_size_eligible
       ->GetCell(eligible ? "true" : "false")
      ->IncrementBy(1);
-  for (const string& reason : ineligible_reason) {
+  for (const std::string& reason : ineligible_reason) {
    tf_data_auto_shard_rewrite_batch_size_reason->GetCell(reason)->IncrementBy(
        1);
  }
 }
 
-void RecordTFDataAutotuneStoppingCriteria(const string& name) {
+void RecordTFDataAutotuneStoppingCriteria(const std::string& name) {
   tf_data_autotune_stopping_criteria_counter->GetCell(name)->IncrementBy(1);
 }
 
-void RecordTFDataDebug(const string& event) {
+void RecordTFDataDebug(const std::string& event) {
   tf_data_debug->GetCell(event)->IncrementBy(1);
 }
 
-void RecordTFDataError(const string& error_type, const string& status_code) {
+void RecordTFDataError(const std::string& error_type,
+                       const std::string& status_code) {
   tf_data_error->GetCell(error_type, status_code)->IncrementBy(1);
 }
 
@@ -750,7 +754,7 @@ void RecordTFDataFrameworkType(const std::string& framework_type) {
   tf_data_framework_type->GetCell(framework_type)->IncrementBy(1);
 }
 
-void RecordParseDenseFeature(int64 num_features) {
+void RecordParseDenseFeature(int64_t num_features) {
   static auto* parse_dense_feature_counter_cell =
       parse_dense_feature_counter->GetCell();
   parse_dense_feature_counter_cell->IncrementBy(num_features);
@@ -797,7 +801,7 @@ void UpdateAotBefMlirLoadCount() {
   aot_bef_mlir_load_count_cell->IncrementBy(1);
 }
 
-void UpdateGraphExecTime(const uint64 running_time_usecs) {
+void UpdateGraphExecTime(const uint64_t running_time_usecs) {
   if (running_time_usecs > 0) {
     static auto* graph_runs_cell = graph_runs->GetCell();
     static auto* graph_run_time_usecs_cell = graph_run_time_usecs->GetCell();
@@ -809,13 +813,13 @@ void UpdateGraphExecTime(const uint64 running_time_usecs) {
   }
 }
 
-void UpdateGraphPendingQueueLength(uint64 len) {
+void UpdateGraphPendingQueueLength(uint64_t len) {
   static auto* graph_pending_queue_length_cell =
       graph_pending_queue_length_histogram->GetCell();
   graph_pending_queue_length_cell->Add(len);
 }
 
-void UpdateGraphBuildTime(const uint64 running_time_usecs) {
+void UpdateGraphBuildTime(const uint64_t running_time_usecs) {
   if (running_time_usecs > 0) {
     static auto* build_graph_calls_cell = build_graph_calls->GetCell();
     static auto* build_graph_time_usecs_cell =
@@ -825,7 +829,7 @@ void UpdateGraphBuildTime(const uint64 running_time_usecs) {
   }
 }
 
-void UpdateFunctionGraphOptimizationTime(const uint64 running_time_usecs) {
+void UpdateFunctionGraphOptimizationTime(const uint64_t running_time_usecs) {
   if (running_time_usecs > 0) {
     static auto* function_graph_optimization_time_usecs_cell =
         function_graph_optimization_time_usecs->GetCell();
@@ -834,7 +838,7 @@ void UpdateFunctionGraphOptimizationTime(const uint64 running_time_usecs) {
   }
 }
 
-void UpdateFunctionGraphOptimizationSavingTime(const uint64 saving_time_usecs,
+void UpdateFunctionGraphOptimizationSavingTime(const uint64_t saving_time_usecs,
                                                GraphOptimizationSource source) {
   if (saving_time_usecs > 0) {
     std::string mapped_source = GraphOptimizationSourceMapping(source);
@@ -845,7 +849,7 @@ void UpdateFunctionGraphOptimizationSavingTime(const uint64 saving_time_usecs,
   }
 }
 
-uint64 GetFunctionGraphOptimizationSavingTimeUsecs(
+uint64_t GetFunctionGraphOptimizationSavingTimeUsecs(
     GraphOptimizationSource source) {
   std::string mapped_source = GraphOptimizationSourceMapping(source);
   return graph_optimization_saving_time_usecs->GetCell(mapped_source)->value();
@@ -904,14 +908,14 @@ int64_t GetFunctionGraphOptimizationCacheLoadCount(
   return graph_optimization_cache_load_count->GetCell(mapped_source)->value();
 }
 
-void UpdateTpuVariableDistributionTime(const uint64 distribution_time_usecs) {
+void UpdateTpuVariableDistributionTime(const uint64_t distribution_time_usecs) {
   if (distribution_time_usecs > 0) {
    tpu_variable_distribution_time_usecs->GetCell()->IncrementBy(
        distribution_time_usecs);
  }
 }
 
-void UpdateXlaCompilationTime(const uint64 compilation_time_usecs) {
+void UpdateXlaCompilationTime(const uint64_t compilation_time_usecs) {
   if (compilation_time_usecs > 0) {
     static auto* xla_compilations_cell = xla_compilations->GetCell();
     static auto* xla_compilation_time_usecs_cell =
@@ -921,32 +925,32 @@ void UpdateXlaCompilationTime(const uint64 compilation_time_usecs) {
   }
 }
 
-void RecordUnusedOutput(const string& op_name) {
+void RecordUnusedOutput(const std::string& op_name) {
   graph_unused_outputs->GetCell(op_name)->IncrementBy(1);
 }
 
-void RecordPipelineProcessingTime(const string& id,
+void RecordPipelineProcessingTime(const std::string& id,
                                   double pipeline_processing_time_usec) {
   GetTFDataPipelineProcessingTimeGauge(id)->Set(pipeline_processing_time_usec);
 }
 
-void IncrementTestCounter(const string& name, const string& label) {
+void IncrementTestCounter(const std::string& name, const std::string& label) {
   test_counters->GetCell(name, label)->IncrementBy(1);
 }
 
-const tsl::monitoring::CounterCell* TestCounter(const string& name,
-                                                const string& label) {
+const tsl::monitoring::CounterCell* TestCounter(const std::string& name,
+                                                const std::string& label) {
   return test_counters->GetCell(name, label);
 }
 
-TestDelta::TestDelta(const string& name, const string& label)
+TestDelta::TestDelta(const std::string& name, const std::string& label)
     : cell_(TestCounter(name, label)) {
   Reset();
 }
 
 void TestDelta::Reset() { last_value_ = cell_->value(); }
 
-int64 TestDelta::Get() { return cell_->value() - last_value_; }
+int64_t TestDelta::Get() { return cell_->value() - last_value_; }
 
 void UpdateTfMlirBridgeFirstPhaseCounter(const std::string& bridge_type,
                                          const std::string& bridge_version,
@@ -1020,12 +1024,13 @@ void IncrementPhase2XlaCompilerCounter(Phase2XlaCompilerMetric metric) {
      ->IncrementBy(1);
 }
 
-void UpdateTpuErrorCounter(const string& op, const string& error_type) {
+void UpdateTpuErrorCounter(const std::string& op,
+                           const std::string& error_type) {
   tpu_op_error_counter->GetCell(op, error_type)->IncrementBy(1);
 }
 
-void UpdateEagerClientErrorCounter(const string& error_source,
-                                   const string& error_type) {
+void UpdateEagerClientErrorCounter(const std::string& error_source,
+                                   const std::string& error_type) {
   eager_client_error_counter->GetCell(error_source, error_type)->IncrementBy(1);
 }
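Every RecordTFData* helper changed above follows the same pattern: resolve a labeled cell on a process-wide counter or gauge, then bump it. A hedged sketch of that shape with a plain map standing in for tsl::monitoring (the real Counter is thread-safe and registered with the metrics collector; FakeCounter here is purely illustrative):

#include <cstdint>
#include <map>
#include <string>

// Illustrative stand-in for a one-label monitoring counter.
class FakeCounter {
 public:
  // Returns the cell for `label`, creating it on first use.
  int64_t* GetCell(const std::string& label) { return &cells_[label]; }

 private:
  std::map<std::string, int64_t> cells_;
};

// Shaped like RecordTFDataFetchOp(name) above: one increment per call.
void RecordFetchOp(FakeCounter& counter, const std::string& name) {
  *counter.GetCell(name) += 1;
}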
"Batch" or "Map"). -monitoring::CounterCell* GetTFDataBytesConsumedCounter(const string& name); +monitoring::CounterCell* GetTFDataBytesConsumedCounter(const std::string& name); // Returns a counter that can be used to record the number of bytes produced by // a tf.data.Dataset. // // The `name` argument identifies the Dataset type (e.g. "Batch" or "Map"). -monitoring::CounterCell* GetTFDataBytesProducedCounter(const string& name); +monitoring::CounterCell* GetTFDataBytesProducedCounter(const std::string& name); // Returns a counter than can be used to record the number of bytes read from // the filesystem by a tf.data.Dataset source. @@ -64,43 +64,43 @@ monitoring::CounterCell* GetTFDataBytesProducedCounter(const string& name); // The `name` argument identifies the Dataset type (e.g. "TFRecordDataset"). // // TODO(jsimsa): Remove this now that we have GetTFDataBytesConsumedCounter? -monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name); +monitoring::CounterCell* GetTFDataBytesReadCounter(const std::string& name); // Returns a counter than can be used to record the number of elements produced // by a tf.data.Dataset. // // The `name` argument identifies the Dataset type (e.g. "Batch" or "Map"). -monitoring::CounterCell* GetTFDataElementsCounter(const string& name); +monitoring::CounterCell* GetTFDataElementsCounter(const std::string& name); // Returns a gauge than can be used to record the performance model information. // // The `id` argument represents the (unique) model ID. monitoring::GaugeCell>* GetTFDataModelGauge( - const string& id); + const std::string& id); // Records the number of bytes fetched from tf.data.Dataset iterator. void RecordTFDataBytesFetched(int64_t num_bytes); // Records the number of times a tf.data experiment was applied. -void RecordTFDataExperiment(const string& name); +void RecordTFDataExperiment(const std::string& name); // Records the number of times a tf.data experiment could have been applied. -void RecordTFDataExperimentLive(const string& name); +void RecordTFDataExperimentLive(const std::string& name); // Records the number of times a tf.data experiment was opted into. -void RecordTFDataExperimentOptIn(const string& experiment_name); +void RecordTFDataExperimentOptIn(const std::string& experiment_name); // Records the number of times a tf.data experiment was opted out of. -void RecordTFDataExperimentOptOut(const string& experiment_name); +void RecordTFDataExperimentOptOut(const std::string& experiment_name); // Records the time (in microseconds) spent generating an element and // transferring it over the network for the given protocol. -void RecordTFDataServiceGetElementDuration(const string& data_transfer_protocol, - uint64 duration_us); +void RecordTFDataServiceGetElementDuration( + const std::string& data_transfer_protocol, uint64_t duration_us); // Records the time (in microseconds) spent in a single invocation of // `ItertatorResource::GetNext()`. -void RecordTFDataGetNextDuration(uint64 duration_us); +void RecordTFDataGetNextDuration(uint64_t duration_us); // Records the histogram of ratios of tf.data autotune algorithm used RAM over // the ram budget. @@ -115,7 +115,7 @@ void RecordTFDataAutotuneMaxBufferBudgetRatio(const double ratio); // // The `name` argument identifies the Dataset graph fingerprint, // created using GraphHash(). -void RecordTFDataFingerprint(const string& name); +void RecordTFDataFingerprint(const std::string& name); // Records the event of a tf.data service pipeline getting a runtime // compression decision. 
@@ -123,26 +123,26 @@ void RecordTFDataServiceRuntimeCompressionDecision(bool compression_decision); // Records the event of a tf.data service pipeline making the compression // related action. -void RecordTFDataServiceCompressionAction(const string& action); +void RecordTFDataServiceCompressionAction(const std::string& action); // Records the time (in microseconds) during which `IteratorResource` was busy // processing at least one `GetNext()` request. -void RecordTFDataIteratorBusy(uint64 duration_us); +void RecordTFDataIteratorBusy(uint64_t duration_us); // Records the time (in microseconds) between `IteratorResource` receiving the // first `GetNext()` request and responding to the last `GetNext()` request. -void RecordTFDataIteratorLifetime(uint64 duration_us); +void RecordTFDataIteratorLifetime(uint64_t duration_us); // Records the time histogram (in microseconds) between `IteratorResource` // responding to a `GetNext()` request and receiving the next `GetNext()` // request. -void RecordTFDataIteratorGap(uint64 duration_us); +void RecordTFDataIteratorGap(uint64_t duration_us); // Records the number of independent graph changes resulting from the // application of a tf.data optimization. // // The `name` argument identifies the optimization (e.g. "noop_elimination"). -void RecordTFDataOptimization(const string& name, int64_t num_changes); +void RecordTFDataOptimization(const std::string& name, int64_t num_changes); // Records that a tf.data service worker has been created. void RecordTFDataServiceWorkerCreated(); @@ -160,21 +160,21 @@ void RecordTFDataServiceClientIterators( // `data_transfer_protocol` to get data from the worker server and whether or // not the user explicitly specified the protocol. void RecordTFDataServiceDataTransferProtocolUsed( - const string& data_transfer_protocol, bool user_specified); + const std::string& data_transfer_protocol, bool user_specified); // Records that a tf.data service worker client fell back to gRPC rather than // use `data_transfer_protocol` because of an error of type `code` with message // `error_message`. void RecordTFDataServiceDataTransferProtocolFallback( - const string& data_transfer_protocol, error::Code code, - const string& error_message); + const std::string& data_transfer_protocol, error::Code code, + const std::string& error_message); // Records that a tf.data service worker client got an error of non-retriable // type `code` with message `error_message` when trying to transfer data over // `data_transfer_protocol`. void RecordTFDataServiceDataTransferProtocolError( - const string& data_transfer_protocol, error::Code code, - const string& error_message); + const std::string& data_transfer_protocol, error::Code code, + const std::string& error_message); // Records tf.data service cross-trainer cache queries. void RecordTFDataServiceCrossTrainerCacheQuery(bool cache_hit); @@ -195,7 +195,7 @@ void RecordTFDataServiceOptimalNumberOfWorkers(int64_t number_of_workers); // Records the file name read by a tf.data Dataset. // // The `name` argument identifies the Dataset type (e.g. "TFRecordDataset"). -void RecordTFDataFilename(const string& name, const string& filename); +void RecordTFDataFilename(const std::string& name, const std::string& filename); // Records the total attempts made by file logger. void RecordTFDataFileLoggerAttempts(); @@ -203,7 +203,7 @@ void RecordTFDataFileLoggerAttempts(); // Records an error of type `code` with message `error_message` encountered by // file logger. 
void RecordTFDataFileLoggerErrors(error::Code code, - const string& error_message); + const std::string& error_message); // Records the total number of files attempted to be logged by file logger. void RecordTFDataFileLoggerAttemptedNumFiles(size_t num_files); @@ -212,15 +212,15 @@ void RecordTFDataFileLoggerAttemptedNumFiles(size_t num_files); // `code` with message `error_message` during logging by file logger with this // error code. void RecordTFDataFileLoggerErrorsNumFiles(size_t num_files, error::Code code, - const string& error_message); + const std::string& error_message); // Records statistics of tf.data auto sharding. // // The `id` is a unique identifier of the input pipeline. The `policy` // identifies the auto-sharding policy used, the `num_workers` identifies the // number of workers, and `num_replicas` identifies the number of replicas. -void RecordTFDataAutoShard(const string& id, data::AutoShardPolicy policy, - int64 num_workers, int64 num_replicas); +void RecordTFDataAutoShard(const std::string& id, data::AutoShardPolicy policy, + int64_t num_workers, int64_t num_replicas); // Records statistics of whether we can rewrite batch size in tf.data auto // sharding. @@ -229,26 +229,27 @@ void RecordTFDataAutoShard(const string& id, data::AutoShardPolicy policy, // indicates whether the input pipeline is eligible for the rewrite. The // `ineligible_reason` is the reason if the input pipeline is ineligible. void RecordTFDataAutoShardRewriteBatchSize( - bool eligible, const std::vector& ineligible_reason); + bool eligible, const std::vector& ineligible_reason); // Records the number of times each tf.data autotuning algorithm stopping // criterion is met. -void RecordTFDataAutotuneStoppingCriteria(const string& name); +void RecordTFDataAutotuneStoppingCriteria(const std::string& name); // Records the number of times this event occured, for debugging. -void RecordTFDataDebug(const string& event); +void RecordTFDataDebug(const std::string& event); // Records the number of times an error of this type occurred with this status // code. -void RecordTFDataError(const string& error_type, const string& error_code); +void RecordTFDataError(const std::string& error_type, + const std::string& error_code); // Records the framework type used to build the tf.data.Dataset. void RecordTFDataFrameworkType(const std::string& framework_type); // Records the number of times tf.data file logger encountered an error of this // type occurred with this status code. -void RecordTFDataFileLoggerError(const string& error_type, - const string& error_code); +void RecordTFDataFileLoggerError(const std::string& error_type, + const std::string& error_code); // Records parsing of dense tensor features. void RecordParseDenseFeature(int64_t num_features); @@ -266,14 +267,14 @@ void RecordGraphOutputTensors(const size_t size); // Records the number of cores requested by graphs with XLA SPMD enabled. void RecordTPUXlaSpmdCoresPerReplica(int64_t cores_per_replica); -void UpdateGraphExecTime(const uint64 running_time_usecs); -void UpdateGraphPendingQueueLength(uint64 len); +void UpdateGraphExecTime(const uint64_t running_time_usecs); +void UpdateGraphPendingQueueLength(uint64_t len); // Records that one output of an op of type `op_name` was unused. 
-void RecordUnusedOutput(const string& op_name); +void RecordUnusedOutput(const std::string& op_name); // Records the pipeline processing time in microseconds -void RecordPipelineProcessingTime(const string& id, +void RecordPipelineProcessingTime(const std::string& id, double pipeline_processing_time_usec); // Increments the count of binaries loaded from the persistent cache. @@ -295,17 +296,17 @@ void UpdateAotBefMlirLoadCount(); // When executing eagerly, this will not record any activity. // // TODO(jtkeeling): Should we record building/optimizing tf.functions? -void UpdateGraphBuildTime(const uint64 running_time_usecs); +void UpdateGraphBuildTime(const uint64_t running_time_usecs); // Updates the metric stored for time spent optimizing function graphs. -void UpdateFunctionGraphOptimizationTime(const uint64 running_time_usecs); +void UpdateFunctionGraphOptimizationTime(const uint64_t running_time_usecs); // Updates the metric stored for time saved by caching graph optimization. -void UpdateFunctionGraphOptimizationSavingTime(uint64 saving_time_usec, +void UpdateFunctionGraphOptimizationSavingTime(uint64_t saving_time_usec, GraphOptimizationSource source); // Retrieves the total time saved by the graph optimization caching. -uint64 GetFunctionGraphOptimizationSavingTimeUsecs( +uint64_t GetFunctionGraphOptimizationSavingTimeUsecs( GraphOptimizationSource source); // Increments the hit count for the graph optimization cache. @@ -463,10 +464,10 @@ class ScopedCounter final { // Returns duration of the current interval in case the timer has started. // Returns nullopt otherwise. - std::optional DurationMicroSec() const { - return started_ ? std::optional(accumulated_time_ + - Env::Default()->NowMicros() - - start_time_) + std::optional DurationMicroSec() const { + return started_ ? std::optional(accumulated_time_ + + Env::Default()->NowMicros() - + start_time_) : std::nullopt; } @@ -492,7 +493,7 @@ class ScopedCounter final { private: template void ReportInternal(std::index_sequence) { - uint64 time_interval = Env::Default()->NowMicros() - start_time_; + uint64_t time_interval = Env::Default()->NowMicros() - start_time_; time_interval += accumulated_time_; if (time_interval > 0) { counter_->GetCell(labels_[S]...)->IncrementBy(time_interval); @@ -508,8 +509,8 @@ class ScopedCounter final { monitoring::Counter* counter_; std::array labels_; bool started_{false}; - uint64 start_time_; - uint64 accumulated_time_; + uint64_t start_time_; + uint64_t accumulated_time_; }; // Returns a counter used to capture timing metrics for graph optimization @@ -517,32 +518,33 @@ class ScopedCounter final { monitoring::Counter<2>* GetGraphOptimizationCounter(); // Updates metrics for time to distribute variables to all TPU hosts. -void UpdateTpuVariableDistributionTime(const uint64 distribution_time_usecs); +void UpdateTpuVariableDistributionTime(const uint64_t distribution_time_usecs); // Updates the metrics stored about time XLA spents compiling graphs. -void UpdateXlaCompilationTime(const uint64 compilation_time_usecs); +void UpdateXlaCompilationTime(const uint64_t compilation_time_usecs); // Increments (by 1) a simple integer counter that is exposed for testing. -void IncrementTestCounter(const string& name, const string& label); +void IncrementTestCounter(const std::string& name, const std::string& label); // Read-only access to a counter for testing. 
-const monitoring::CounterCell* TestCounter(const string& name, - const string& label); +const monitoring::CounterCell* TestCounter(const std::string& name, + const std::string& label); // Read-only wrapper for a TestCounter to track increments between calls. class TestDelta { public: - TestDelta(const string& name, const string& label); + TestDelta(const std::string& name, const std::string& label); void Reset(); - int64 Get(); + int64_t Get(); private: const monitoring::CounterCell* cell_; - int64 last_value_; + int64_t last_value_; }; -void UpdateTpuErrorCounter(const string& op, const string& error_type); -void UpdateEagerClientErrorCounter(const string& error_source, - const string& error_type); +void UpdateTpuErrorCounter(const std::string& op, + const std::string& error_type); +void UpdateEagerClientErrorCounter(const std::string& error_source, + const std::string& error_type); } // namespace metrics } // namespace tensorflow diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 654b15b9ac0201..0d05c8d72b69d7 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -240,7 +240,7 @@ bool AreAllParametersMax(const Model::ModelParameters& parameters) { } // Records the ram usage of hill climbing algorithm. -void RecordAutotuneRamUsage(int64 ram_budget, double max_buffered_bytes) { +void RecordAutotuneRamUsage(int64_t ram_budget, double max_buffered_bytes) { if (ram_budget == 0) { return; } @@ -1227,8 +1227,8 @@ class UnknownRatio : public Node { // The processing time is the sum of the self processing time and the product // of the ratio estimate and the sum of processing times of inputs. void TotalProcessingTimeLocked( - absl::flat_hash_map* processing_times, - absl::flat_hash_map* total_processing_times) override + absl::flat_hash_map* processing_times, + absl::flat_hash_map* total_processing_times) override TF_SHARED_LOCKS_REQUIRED(mu_) { double self_processing_time = SelfProcessingTimeLocked(); if (processing_times) { @@ -1400,13 +1400,13 @@ class AsyncUnknownRatio : public AsyncRatio { thread_local int64_t Node::work_start_; -std::shared_ptr MakeParameter(const string& name, +std::shared_ptr MakeParameter(const std::string& name, std::shared_ptr state, double min, double max) { return std::make_shared(name, state, min, max); } -std::shared_ptr MakeParameter(const string& name, +std::shared_ptr MakeParameter(const std::string& name, std::shared_ptr state, double min, double max, double value) { std::shared_ptr parameter = @@ -1415,7 +1415,7 @@ std::shared_ptr MakeParameter(const string& name, return parameter; } -std::shared_ptr MakeNonTunableParameter(const string& name, +std::shared_ptr MakeNonTunableParameter(const std::string& name, double value) { return std::make_shared(name, nullptr, /*min=*/value, /*max=*/value); @@ -1649,8 +1649,8 @@ Node::ModelParameters Node::CollectNodeTunableParameters() const { return parameters; } -string Node::DebugString() const { - absl::flat_hash_map debug_strings; +std::string Node::DebugString() const { + absl::flat_hash_map debug_strings; tf_shared_lock l(mu_); // Build up the debug string from the leaves of the nodes tree to the root. 
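`ScopedCounter` above is an RAII timer: it samples `Env::Default()->NowMicros()` on start and reports the accumulated interval to a labeled counter cell when the scope ends. A self-contained sketch of the same idea using std::chrono (the real class also supports label arrays and pausing via `accumulated_time_`, which this omits):

#include <chrono>
#include <cstdint>

// Adds the scope's elapsed microseconds to *sink on destruction.
class ScopedTimerSketch {
 public:
  explicit ScopedTimerSketch(uint64_t* sink)
      : sink_(sink), start_(std::chrono::steady_clock::now()) {}
  ~ScopedTimerSketch() {
    const auto elapsed = std::chrono::steady_clock::now() - start_;
    *sink_ +=
        std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
  }

 private:
  uint64_t* sink_;
  std::chrono::steady_clock::time_point start_;
};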
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index 654b15b9ac0201..0d05c8d72b69d7 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -240,7 +240,7 @@ bool AreAllParametersMax(const Model::ModelParameters& parameters) {
 }
 
 // Records the ram usage of hill climbing algorithm.
-void RecordAutotuneRamUsage(int64 ram_budget, double max_buffered_bytes) {
+void RecordAutotuneRamUsage(int64_t ram_budget, double max_buffered_bytes) {
   if (ram_budget == 0) {
     return;
   }
@@ -1227,8 +1227,8 @@ class UnknownRatio : public Node {
   // The processing time is the sum of the self processing time and the product
   // of the ratio estimate and the sum of processing times of inputs.
   void TotalProcessingTimeLocked(
-      absl::flat_hash_map<string, double>* processing_times,
-      absl::flat_hash_map<string, double>* total_processing_times) override
+      absl::flat_hash_map<std::string, double>* processing_times,
+      absl::flat_hash_map<std::string, double>* total_processing_times) override
       TF_SHARED_LOCKS_REQUIRED(mu_) {
     double self_processing_time = SelfProcessingTimeLocked();
     if (processing_times) {
@@ -1400,13 +1400,13 @@ class AsyncUnknownRatio : public AsyncRatio {
 
 thread_local int64_t Node::work_start_;
 
-std::shared_ptr<Parameter> MakeParameter(const string& name,
+std::shared_ptr<Parameter> MakeParameter(const std::string& name,
                                          std::shared_ptr<SharedState> state,
                                          double min, double max) {
   return std::make_shared<Parameter>(name, state, min, max);
 }
 
-std::shared_ptr<Parameter> MakeParameter(const string& name,
+std::shared_ptr<Parameter> MakeParameter(const std::string& name,
                                          std::shared_ptr<SharedState> state,
                                          double min, double max, double value) {
   std::shared_ptr<Parameter> parameter =
@@ -1415,7 +1415,7 @@ std::shared_ptr<Parameter> MakeParameter(const string& name,
   return parameter;
 }
 
-std::shared_ptr<Parameter> MakeNonTunableParameter(const string& name,
+std::shared_ptr<Parameter> MakeNonTunableParameter(const std::string& name,
                                                    double value) {
   return std::make_shared<Parameter>(name, nullptr, /*min=*/value,
                                      /*max=*/value);
@@ -1649,8 +1649,8 @@ Node::ModelParameters Node::CollectNodeTunableParameters() const {
   return parameters;
 }
 
-string Node::DebugString() const {
-  absl::flat_hash_map<string, string> debug_strings;
+std::string Node::DebugString() const {
+  absl::flat_hash_map<std::string, std::string> debug_strings;
   tf_shared_lock l(mu_);
   // Build up the debug string from the leaves of the nodes tree to the root.
   for (const auto& node :
@@ -2035,9 +2035,10 @@ void Node::CollectTunableParametersHelper(
   }
 }
 
-void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
-    const TF_SHARED_LOCKS_REQUIRED(mu_) {
-  string result;
+void Node::DebugStringHelper(
+    absl::flat_hash_map<std::string, std::string>* debug_strings) const
+    TF_SHARED_LOCKS_REQUIRED(mu_) {
+  std::string result;
   absl::StrAppend(&result, long_name(), ":\n");
   absl::StrAppend(&result, "  autotune=", autotune_.load(), "\n");
   absl::StrAppend(&result, "  buffered_bytes=", buffered_bytes_.load(), "\n");
@@ -2047,7 +2048,7 @@ void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
   absl::StrAppend(&result, "  bytes_produced=", bytes_produced_.load(), "\n");
   absl::StrAppend(&result, "  processing_time=", processing_time_.load(), "\n");
   absl::StrAppend(&result, "  num_elements=", num_elements_.load(), "\n");
-  string inputs;
+  std::string inputs;
   for (auto& input : inputs_) {
     absl::StrAppend(&inputs, input->long_name(), ",");
   }
@@ -2080,7 +2081,7 @@ std::shared_ptr<Node> Node::SnapshotHelper(
   {
     mutex_lock l2(cloned_current->mu_);
     cloned_current->parameters_ =
-        absl::flat_hash_map<string, std::shared_ptr<Parameter>>();
+        absl::flat_hash_map<std::string, std::shared_ptr<Parameter>>();
     for (const auto& [parameter_name, parameter_ptr] : parameters_) {
       cloned_current->parameters_[parameter_name] =
          std::make_shared<Parameter>(parameter_ptr);
@@ -2257,7 +2258,7 @@ Model::Model(std::optional<std::string> dataset_name)
     : dataset_name_(std::move(dataset_name)),
       optimization_period_ms_(kOptimizationPeriodMinMs),
       safe_to_collect_metrics_(std::make_shared<GuardedBool>(true)) {
-  model_id_ = absl::StrCat(reinterpret_cast<uint64>(this));
+  model_id_ = absl::StrCat(reinterpret_cast<uint64_t>(this));
   model_gauge_cell_ = metrics::GetTFDataModelGauge(model_id_);
 
   // Capture `safe_to_collect_metrics_` by value to avoid use-after-free issues
@@ -2297,7 +2298,7 @@ Model::~Model() {
   metrics::RecordPipelineProcessingTime(model_id_, 0);
 }
 
-void Model::AddNode(Node::Factory factory, const string& name,
+void Model::AddNode(Node::Factory factory, const std::string& name,
                     std::shared_ptr<Node> parent,
                     std::shared_ptr<Node>* out_node) {
   // The name captures the sequence of iterators joined by `::`. We only use the
@@ -2935,7 +2936,7 @@ void Model::OptimizeStageBasedNonAsyncInterleaveManyNodes(
                              node_tunable_parameters.end());
   }
   // Initialize the parallelism parameter values to minimal before tuning.
-  for (std::pair<string, std::shared_ptr<Parameter>>& pair :
+  for (std::pair<std::string, std::shared_ptr<Parameter>>& pair :
        tunable_parameters) {
     if (pair.second->name != kParallelism) {
       continue;
@@ -3206,7 +3207,8 @@ absl::Status Model::FromProto(ModelProto model_proto,
   return absl::OkStatus();
 }
 
-absl::Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
+absl::Status Model::Save(const std::string& fname,
+                         std::shared_ptr<Node> snapshot,
                          const OptimizationParams& optimization_params) {
   ModelProto model_proto;
   std::unique_ptr<Model> model_snapshot = std::make_unique<Model>();
@@ -3222,7 +3224,8 @@ absl::Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
   return WriteBinaryProto(Env::Default(), fname, model_proto);
 }
 
-absl::Status Model::Load(const string& fname, std::unique_ptr<Model>* model,
+absl::Status Model::Load(const std::string& fname,
+                         std::unique_ptr<Model>* model,
                          OptimizationParams* optimization_params) {
   ModelProto model_proto;
   TF_RETURN_IF_ERROR(
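One behavioral detail in the `SnapshotHelper` hunk above: the clone rebuilds `parameters_` with a fresh `Parameter` per entry rather than copying the map of shared_ptrs, so later mutations of the live node do not leak into the snapshot. A toy illustration of that deep-copy-versus-alias distinction (Param is a stand-in, not the model.h type):

#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Param {
  double value;
};

int main() {
  std::map<std::string, std::shared_ptr<Param>> live{
      {"parallelism", std::make_shared<Param>(Param{4.0})}};

  // Deep copy, as in SnapshotHelper: every entry gets its own object.
  std::map<std::string, std::shared_ptr<Param>> snapshot;
  for (const auto& [name, p] : live) {
    snapshot[name] = std::make_shared<Param>(*p);
  }

  live["parallelism"]->value = 8.0;                     // Mutate the live map.
  std::cout << snapshot["parallelism"]->value << "\n";  // Still prints 4.
}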
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index fd47c91842721c..c8c39768dc2e6a 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -97,8 +97,8 @@ struct SharedState {
 
 // Represents a parameter.
 struct Parameter {
-  Parameter(const string& name, std::shared_ptr<SharedState> state, double min,
-            double max)
+  Parameter(const std::string& name, std::shared_ptr<SharedState> state,
+            double min, double max)
       : name(name),
         // Sometimes non-autotune nodes (with `autotune_=false`) may contain
         // parameters (for example inputs of parallel interleave dataset which
@@ -121,7 +121,7 @@ struct Parameter {
         state(parameter->state) {}
 
   // Human-readable name of the parameter.
-  const string name;
+  const std::string name;
 
   // Identifies the model value of the parameter. This can be different from
   // the actual value (e.g. during optimization search).
@@ -138,18 +138,18 @@ struct Parameter {
 };
 
 // Returns a new tunable parameter with the value set to `min`.
-std::shared_ptr<Parameter> MakeParameter(const string& name,
+std::shared_ptr<Parameter> MakeParameter(const std::string& name,
                                          std::shared_ptr<SharedState> state,
                                          double min, double max);
 
 // Returns a new tunable parameter with the value set to `value` instead
 // of `min`.
-std::shared_ptr<Parameter> MakeParameter(const string& name,
+std::shared_ptr<Parameter> MakeParameter(const std::string& name,
                                          std::shared_ptr<SharedState> state,
                                          double min, double max, double value);
 
 // Returns a new non-tunable parameter.
-std::shared_ptr<Parameter> MakeNonTunableParameter(const string& name,
+std::shared_ptr<Parameter> MakeNonTunableParameter(const std::string& name,
                                                    double value);
 
 // Class for managing the ram budget of an iterator. This is necessary for
@@ -283,7 +283,7 @@ class Node {
   // Arguments for `Node` constructor.
   struct Args {
     int64_t id;
-    string name;
+    std::string name;
    std::shared_ptr<Node> output;
  };
 
@@ -292,10 +292,10 @@ class Node {
  using NodePairList =
      std::list<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>>;
  using ModelParameters =
-      std::vector<std::pair<string, std::shared_ptr<Parameter>>>;
-  using NodeValues = absl::flat_hash_map<string, double>;
+      std::vector<std::pair<std::string, std::shared_ptr<Parameter>>>;
+  using NodeValues = absl::flat_hash_map<std::string, double>;
  using ParameterGradients =
-      absl::flat_hash_map<std::pair<string, string>, double>;
+      absl::flat_hash_map<std::pair<std::string, std::string>, double>;
 
  explicit Node(Args args)
      : id_(args.id),
@@ -413,10 +413,12 @@ class Node {
  }
 
  // Returns a longer node name that is guaranteed to be unique.
-  string long_name() const { return absl::StrCat(name_, "(id:", id_, ")"); }
+  std::string long_name() const {
+    return absl::StrCat(name_, "(id:", id_, ")");
+  }
 
  // Returns the node name.
-  const string& name() const { return name_; }
+  const std::string& name() const { return name_; }
 
  // Returns the number of elements produced by the node.
  int64_t num_elements() const TF_LOCKS_EXCLUDED(mu_) { return num_elements_; }
@@ -426,7 +428,7 @@ class Node {
  std::shared_ptr<Node> output_shared() { return output_weak_ptr_.lock(); }
 
  // Returns the parameter value.
-  double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) {
+  double parameter_value(const std::string& name) const TF_LOCKS_EXCLUDED(mu_) {
    tf_shared_lock l(mu_);
    return parameters_.at(name)->state->value;
  }
@@ -564,7 +566,7 @@ class Node {
  ModelParameters CollectNodeTunableParameters() const TF_LOCKS_EXCLUDED(mu_);
 
  // Returns a human-readable representation of this node.
-  string DebugString() const TF_LOCKS_EXCLUDED(mu_);
+  std::string DebugString() const TF_LOCKS_EXCLUDED(mu_);
 
  // Flushes the metrics recorded by this node.
  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
@@ -645,7 +647,7 @@ class Node {
  // Used for (incrementally) recording metrics. The class is thread-safe.
  class Metrics {
   public:
-    explicit Metrics(const string& name)
+    explicit Metrics(const std::string& name)
        : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)),
          bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)),
          num_elements_counter_(metrics::GetTFDataElementsCounter(name)),
@@ -787,8 +789,9 @@ class Node {
      TF_SHARED_LOCKS_REQUIRED(mu_);
 
  // Build up debug string for the node and store in the debug strings map.
-  void DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
-      const TF_SHARED_LOCKS_REQUIRED(mu_);
+  void DebugStringHelper(
+      absl::flat_hash_map<std::string, std::string>* debug_strings) const
+      TF_SHARED_LOCKS_REQUIRED(mu_);
 
  // Copy the node and add the (input, copy) pairs to the NodePairList.
  std::shared_ptr<Node> SnapshotHelper(std::shared_ptr<Node> cloned_output,
                                       NodePairList* node_pairs) const;
@@ -827,7 +830,7 @@ class Node {
  mutable mutex mu_;
  const int64_t id_;
-  const string name_;
+  const std::string name_;
 
  // Indicates whether the subtree rooted in this node should be included in
  // autotuning. In particular, if this is `false`, then the subtree is excluded
@@ -844,7 +847,7 @@ class Node {
  std::atomic<int64_t> processing_time_;
  std::atomic<bool> record_metrics_;
  Metrics metrics_;
-  absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
+  absl::flat_hash_map<std::string, std::shared_ptr<Parameter>> parameters_
      TF_GUARDED_BY(mu_);
 
  // Statistic of inputs processing time history.
@@ -952,7 +955,7 @@ class Model {
  }
 
  // Adds a node with the given name and given parent.
-  void AddNode(Node::Factory factory, const string& name,
+  void AddNode(Node::Factory factory, const std::string& name,
               std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
      TF_LOCKS_EXCLUDED(mu_);
@@ -1014,12 +1017,13 @@ class Model {
 
  // Saves this model with a given snapshot and its optimization parameters to a
  // file. Note that the file directory must already exist.
-  absl::Status Save(const string& fname, std::shared_ptr<Node> snapshot,
+  absl::Status Save(const std::string& fname, std::shared_ptr<Node> snapshot,
                    const OptimizationParams& optimization_params);
 
  // Loads a model and its optimization parameters from a file with the given
  // name.
-  static absl::Status Load(const string& fname, std::unique_ptr<Model>* model,
+  static absl::Status Load(const std::string& fname,
+                           std::unique_ptr<Model>* model,
                           OptimizationParams* optimization_params);
 
  // Records gap time between consecutive `GetNext()` calls.
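The reflowed `Save`/`Load` signatures above keep the usage shape that the test below exercises. Roughly, under the assumption of a full TensorFlow build with `model`, `snapshot`, and `optimization_params` already populated (a hedged sketch, not the test's exact code):

// Save() serializes a binary ModelProto; the static Load() reads it back.
std::string fname;
CHECK(Env::Default()->LocalTempFilename(&fname));
TF_CHECK_OK(model->Save(fname, snapshot, optimization_params));

std::unique_ptr<model::Model> restored;
OptimizationParams restored_params;
TF_CHECK_OK(model::Model::Load(fname, &restored, &restored_params));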
diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc
index b7d42eaa0522d3..6ad728f1a0de2c 100644
--- a/tensorflow/core/framework/model_test.cc
+++ b/tensorflow/core/framework/model_test.cc
@@ -54,7 +54,7 @@ std::function<int64_t(int64_t)> RamBudgetFunc(int64_t budget) {
   return [budget](int64_t) { return budget; };
 }
 
-int64_t CountParametersOnNode(const string& node_name,
+int64_t CountParametersOnNode(const std::string& node_name,
                               const Model::ModelParameters& parameters) {
   int64_t cnt = 0;
   for (const auto& pair : parameters) {
@@ -865,10 +865,11 @@ TEST(AsyncInterleaveManyGradientTest, Model) {
               (new_output_time - output_time) / kParameterStep,
               kComparisonPrecision);
 }
 
-class AsyncKnownRatioGradientTest : public ::testing::TestWithParam<string> {};
+class AsyncKnownRatioGradientTest
+    : public ::testing::TestWithParam<std::string> {};
 
 TEST_P(AsyncKnownRatioGradientTest, Model) {
-  const string parameter_name = GetParam();
+  const std::string parameter_name = GetParam();
   const double input_time = 100;
   const int64_t num_inputs_per_output = 2;
@@ -1165,7 +1166,7 @@ TEST(SaveModelTest, Model) {
 
   // Make Save->Load roundtrip.
   Env* env = Env::Default();
-  string tmpFile;
+  std::string tmpFile;
   EXPECT_TRUE(env->LocalTempFilename(&tmpFile));
   tmpFile += "_autotune_model_test";
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index 92fb66395efbf8..fcbb4b7d3672a3 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -32,7 +32,7 @@ NodeDefBuilder::NodeOut::NodeOut() {
 }
 
 void NodeDefBuilder::NodeOut::Reset(absl::string_view n, int i, DataType dt) {
-  node = string(n);
+  node = std::string(n);
   index = i;
   data_type = dt;
 }
@@ -41,9 +41,9 @@ NodeDefBuilder::NodeDefBuilder(absl::string_view name,
                                absl::string_view op_name,
                                const OpRegistryInterface* op_registry,
                                const NodeDebugInfo* debug) {
-  node_def_.set_name(string(name));
+  node_def_.set_name(name);
   const absl::Status status =
-      op_registry->LookUpOpDef(string(op_name), &op_def_);
+      op_registry->LookUpOpDef(std::string(op_name), &op_def_);
   if (status.ok()) {
     Initialize();
   } else {
@@ -62,7 +62,7 @@ NodeDefBuilder::NodeDefBuilder(absl::string_view name,
 
 NodeDefBuilder::NodeDefBuilder(absl::string_view name, const OpDef* op_def)
     : op_def_(op_def) {
-  node_def_.set_name(string(name));
+  node_def_.set_name(name);
   Initialize();
 }
 
@@ -182,7 +182,7 @@ void NodeDefBuilder::AddInput(absl::string_view src_node, int src_index) {
   } else if (src_index > 0) {
     node_def_.add_input(absl::StrCat(src_node, ":", src_index));
   } else {
-    node_def_.add_input(string(src_node));
+    node_def_.add_input(std::string(src_node));
   }
 }
 
@@ -210,13 +210,13 @@ NodeDefBuilder& NodeDefBuilder::ControlInput(absl::string_view src_node) {
 }
 
 NodeDefBuilder& NodeDefBuilder::Device(absl::string_view device_spec) {
-  node_def_.set_device(string(device_spec));
+  node_def_.set_device(device_spec);
   return *this;
 }
 
 absl::Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
-  const std::vector<string>* errors_ptr = &errors_;
-  std::vector<string> errors_storage;
+  const std::vector<std::string>* errors_ptr = &errors_;
+  std::vector<std::string> errors_storage;
   if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
     // Since this is a const method, to add an error, we have to make
     // a copy of the existing errors.
@@ -318,9 +318,9 @@ ATTR(const TensorProto&)
 ATTR(const NameAttrList&)
 ATTR(absl::Span<const StringPiece>)
 ATTR(absl::Span<const char* const>)
-ATTR(absl::Span<const string>)
+ATTR(absl::Span<const std::string>)
 ATTR(absl::Span<const tstring>)
-ATTR(absl::Span<const int32>)
+ATTR(absl::Span<const int32_t>)
 ATTR(absl::Span<const int64_t>)
 ATTR(absl::Span<const float>)
 ATTR(absl::Span<const bool>)
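For orientation, the builder whose internals are cleaned up above is normally driven as a fluent chain ending in `Finalize()`. A hedged usage sketch, assuming a TensorFlow build and a registered "Add" op:

// Each setter returns *this; errors are deferred and surfaced by Finalize().
NodeDef node_def;
absl::Status status = NodeDefBuilder("my_add", "Add")
                          .Input("x", 0, DT_FLOAT)
                          .Input("y", 0, DT_FLOAT)
                          .Device("/cpu:0")
                          .Finalize(&node_def);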
diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h
index 47b14f185800cf..6b74b20fd85ad3 100644
--- a/tensorflow/core/framework/node_def_builder.h
+++ b/tensorflow/core/framework/node_def_builder.h
@@ -56,7 +56,7 @@ class NodeDefBuilder {
     NodeOut(absl::string_view n, int i, DataType dt);
     NodeOut();  // uninitialized, call Reset() before use.
     void Reset(absl::string_view n, int i, DataType dt);
-    string node;
+    std::string node;
     int index;
     DataType data_type;
   };
@@ -112,9 +112,10 @@ class NodeDefBuilder {
                        absl::Span<const char* const> value);
   NodeDefBuilder& Attr(absl::string_view name, absl::Span<const tstring> value);
-  NodeDefBuilder& Attr(absl::string_view name, absl::Span<const string> value);
+  NodeDefBuilder& Attr(absl::string_view name,
+                       absl::Span<const std::string> value);
   NodeDefBuilder& Attr(absl::string_view name, absl::Span<const int64_t> value);
-  NodeDefBuilder& Attr(absl::string_view name, absl::Span<const int32> value);
+  NodeDefBuilder& Attr(absl::string_view name, absl::Span<const int32_t> value);
   NodeDefBuilder& Attr(absl::string_view name, absl::Span<const float> value);
   NodeDefBuilder& Attr(absl::string_view name, absl::Span<const bool> value);
   NodeDefBuilder& Attr(absl::string_view name, absl::Span<const DataType> value);
@@ -145,7 +146,7 @@ class NodeDefBuilder {
   absl::Status Finalize(NodeDef* node_def, bool consume = false);
 
   // Accessors for the values set in the constructor.
-  const string& node_name() const { return node_def_.name(); }
+  const std::string& node_name() const { return node_def_.name(); }
   const OpDef& op_def() const { return *op_def_; }
 
 private:
@@ -189,8 +190,8 @@ class NodeDefBuilder {
   const OpDef* op_def_;
   NodeDef node_def_;
   int inputs_specified_;
-  std::vector<string> control_inputs_;
-  std::vector<string> errors_;
+  std::vector<std::string> control_inputs_;
+  std::vector<std::string> errors_;
 };
 
}  // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc
index b5429579bc889b..c769537ab13d94 100644
--- a/tensorflow/core/framework/node_def_builder_test.cc
+++ b/tensorflow/core/framework/node_def_builder_test.cc
@@ -79,12 +79,12 @@ class NodeDefBuilderTest : public ::testing::Test {
   // Calls Finalize() and verifies it returns an error.
   // Each message must appear as a substring of the error.
   void ExpectFailures(NodeDefBuilder& builder,  // NOLINT
-                      const std::vector<string>& messages) {
+                      const std::vector<std::string>& messages) {
     NodeDef node_def;
     absl::Status status = builder.Finalize(&node_def);
     EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
     if (status.ok()) return;
-    for (const string& message : messages) {
+    for (const std::string& message : messages) {
       EXPECT_TRUE(absl::StrContains(status.message(), message))
          << status << ", " << message;
    }
@@ -93,14 +93,14 @@ class NodeDefBuilderTest : public ::testing::Test {
   // Calls Finalize() and verifies it returns an error.
   // Message must appear as a substring of the error.
   void ExpectFailure(NodeDefBuilder& builder,  // NOLINT
-                     const string& message) {
+                     const std::string& message) {
     ExpectFailures(builder, {message});
   }
 
   // Like ExpectFailure(), except that the error can come from
   // ValidateNodeDef().
   void ExpectInvalid(NodeDefBuilder& builder,  // NOLINT
-                     const string& message) {
+                     const std::string& message) {
     NodeDef node_def;
     absl::Status status = builder.Finalize(&node_def);
     if (status.ok()) {
@@ -822,9 +822,9 @@ TEST_F(NodeDefBuilderTest, AttrManyDefault) {
              .Input(FakeInput(DT_FLOAT))
              .Attr("a", "foo")
              .Attr("e", "foo")
-              .Attr("b", std::vector<string>({"bar", "baz"}))
+              .Attr("b", std::vector<std::string>({"bar", "baz"}))
              .Attr("f", 1.0f),
-          {DT_FLOAT}, {}, R"proto(
+          {DT_FLOAT}, {}, R"pb(
            op: "AttrManyDefaultAndInferred"
            input: "a"
            attr {
@@ -854,7 +854,7 @@ TEST_F(NodeDefBuilderTest, AttrManyDefault) {
            attr {
              key: "d"
              value { f: 0.3 }
-            })proto");
+            })pb");
 }
 
 TEST_F(NodeDefBuilderTest, AttrListDefault) {
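The `R"proto(` → `R"pb(` edits above change only the raw-string delimiter tag (which formatting and proto tooling key off), not the string contents. A minimal standalone illustration of C++ raw-string delimiters:

#include <cassert>
#include <string>

int main() {
  // Identical contents; only the arbitrary delimiter tag differs.
  std::string a = R"proto(key: "value")proto";
  std::string b = R"pb(key: "value")pb";
  assert(a == b);
  return 0;
}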
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 507e2275afc3b5..42c5e841c99417 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -64,18 +64,18 @@ AttrSlice::AttrSlice(const NodeDef& node_def)
 
 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
 
-string SummarizeAttrsHelper(AttrSlice attrs, absl::string_view device) {
-  string ret;
+std::string SummarizeAttrsHelper(AttrSlice attrs, absl::string_view device) {
+  std::string ret;
 
   // We sort the attrs so the output is deterministic.
-  std::vector<string> attr_names;
+  std::vector<std::string> attr_names;
   attr_names.reserve(attrs.size());
   for (const auto& attr : attrs) {
     attr_names.push_back(attr.first);
   }
   std::sort(attr_names.begin(), attr_names.end());
   bool first = true;
-  for (const string& attr_name : attr_names) {
+  for (const std::string& attr_name : attr_names) {
     if (!first) absl::StrAppend(&ret, ", ");
     first = false;
     absl::StrAppend(&ret, attr_name, "=",
@@ -91,18 +91,18 @@ string SummarizeAttrsHelper(AttrSlice attrs, absl::string_view device) {
   return ret;
 }
 
-string AttrSlice::SummarizeNode() const {
+std::string AttrSlice::SummarizeNode() const {
   return ndef_ ? SummarizeNodeDef(*ndef_)
               : absl::StrCat(
                     "[", SummarizeAttrsHelper(*this, absl::string_view()), "]");
 }
 
-string AttrSlice::DebugString() const {
-  std::vector<string> attr_key_vals;
+std::string AttrSlice::DebugString() const {
+  std::vector<std::string> attr_key_vals;
   attr_key_vals.reserve(attrs()->size());
   for (const auto& it : *this) {
-    const string& name = it.first;
+    const std::string& name = it.first;
     const AttrValue& attr_value = it.second;
     attr_key_vals.push_back(
        absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
@@ -110,15 +110,17 @@ string AttrSlice::DebugString() const {
   return absl::StrJoin(attr_key_vals, ", ");
 }
 
-string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
-  string ret = absl::StrCat(errors::FormatNodeNameForError(node_def.name()),
-                            " = ", node_def.op(), "[");
+std::string SummarizeNodeDef(const NodeDef& node_def,
+                             int max_inputs_in_summary) {
+  std::string ret =
+      absl::StrCat(errors::FormatNodeNameForError(node_def.name()), " = ",
+                   node_def.op(), "[");
   absl::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
   absl::StrAppend(&ret, "](");
 
   // Output inputs, including control inputs, verbatim.
   bool first = true;
-  for (const string& input : node_def.input()) {
+  for (const std::string& input : node_def.input()) {
     if (!first) absl::StrAppend(&ret, ", ");
     first = false;
     if (max_inputs_in_summary-- == 0) {
@@ -131,22 +133,22 @@ string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
   return ret;
 }
 
-string SummarizeAttrs(const NodeDef& node_def) {
+std::string SummarizeAttrs(const NodeDef& node_def) {
   return SummarizeAttrsHelper(node_def, node_def.device());
 }
 
-string FormatNodeDefForError(
+std::string FormatNodeDefForError(
     absl::string_view node_name, bool has_experimental_debug_info,
     const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
   return !has_experimental_debug_info ||
                experimental_debug_info.original_node_names().empty()
-             ? errors::FormatNodeNameForError(string(node_name))
+             ? errors::FormatNodeNameForError(node_name)
            : errors::FormatOriginalNodeLocationForError(
                  experimental_debug_info.original_node_names(),
                  experimental_debug_info.original_func_names());
 }
 
-string FormatNodeDefForError(const NodeDef& node_def) {
+std::string FormatNodeDefForError(const NodeDef& node_def) {
   return FormatNodeDefForError(node_def.name(),
                               node_def.has_experimental_debug_info(),
                               node_def.experimental_debug_info());
@@ -174,7 +176,7 @@ const AttrValue* AttrSlice::Find(absl::string_view attr_name) const {
   return nullptr;
 }
 
-const AttrValue* AttrSlice::FindByString(const string& attr_name) const {
+const AttrValue* AttrSlice::FindByString(const std::string& attr_name) const {
   auto iter = attrs()->find(attr_name);
   if (iter != attrs()->end()) {
     return &iter->second;
@@ -205,7 +207,7 @@ absl::Status AttrSlice::Find(absl::string_view attr_name,
   return CheckFind(attr_name, *attr_value);
 }
 
-absl::Status AttrSlice::FindByString(const string& attr_name,
+absl::Status AttrSlice::FindByString(const std::string& attr_name,
                                      const AttrValue** attr_value) const {
   *attr_value = FindByString(attr_name);
   return CheckFind(attr_name, *attr_value);
@@ -288,19 +290,19 @@ bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
 }
 
 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
-DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
-DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
+DEFINE_GET_ATTR(std::string, s, "string", emplace_back, v, ;)
+DEFINE_TRY_GET_ATTR(std::string, s, "string", emplace_back, v, ;)
 DEFINE_GET_ATTR(int64_t, i, "int", emplace_back, v, ;)
 DEFINE_TRY_GET_ATTR(int64_t, i, "int", emplace_back, v, ;)
 DEFINE_GET_ATTR(
-    int32, i, "int", emplace_back, static_cast<int32>(v),
-    if (static_cast<int64_t>(static_cast<int32>(v)) != v) {
+    int32_t, i, "int", emplace_back, static_cast<int32_t>(v),
+    if (static_cast<int64_t>(static_cast<int32_t>(v)) != v) {
      return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
                                     " out of range for an int32");
    })
 DEFINE_TRY_GET_ATTR(
-    int32, i, "int", emplace_back, static_cast<int32>(v),
-    if (static_cast<int64_t>(static_cast<int32>(v)) != v) {
+    int32_t, i, "int", emplace_back, static_cast<int32_t>(v),
+    if (static_cast<int64_t>(static_cast<int32_t>(v)) != v) {
      static int log_counter = 0;
      if (log_counter < 10) {
        log_counter++;
@@ -345,13 +347,13 @@ DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
 #undef DEFINE_GET_ATTR
 
 bool HasNodeAttr(const NodeDef& node_def, absl::string_view attr_name) {
-  return node_def.attr().find(string(attr_name)) != node_def.attr().end();
+  return node_def.attr().find(std::string(attr_name)) != node_def.attr().end();
 }
 
-static const string& kEmptyString = *new string();
+static const std::string& kEmptyString = *new std::string();
 
-const string& GetNodeAttrString(const AttrSlice& attrs,
-                                absl::string_view attr_name) {
+const std::string& GetNodeAttrString(const AttrSlice& attrs,
+                                     absl::string_view attr_name) {
   const AttrValue* attr_value = attrs.Find(attr_name);
   if (attr_value == nullptr) {
     return kEmptyString;
@@ -364,7 +366,7 @@ const string& GetNodeAttrString(const AttrSlice& attrs,
 }
 
 bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name,
-                    std::vector<const string*>* value) {
+                    std::vector<const std::string*>* value) {
   const AttrValue* attr_value = attrs.Find(attr_name);
   if (attr_value == nullptr) {
     return false;
@@ -456,7 +458,7 @@ bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name,
 
 absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name,
                          Padding* value) {
-  string str_value;
+  std::string str_value;
   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value));
   return GetPaddingFromString(str_value, value);
 }
@@ -473,7 +475,7 @@ absl::Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
      // We can't handle outputs that are larger than int32 sizes.
-      if (static_cast<int64_t>(static_cast<int32>(repeats)) != repeats) {
+      if (static_cast<int64_t>(static_cast<int32_t>(repeats)) != repeats) {
        return errors::InvalidArgument("Number of outputs is too big: ", repeats);
      }
      if (repeats < 0) {
@@ -645,10 +647,10 @@ absl::Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
   bool seen_control = false;
   size_t num_inputs = 0;
   // TODO(josh11b): Unify the input field validation.
-  for (const string& input : node_def.input()) {
+  for (const std::string& input : node_def.input()) {
     if (absl::StartsWith(input, "^")) {
       seen_control = true;
-      if (input.find(':') != string::npos) {
+      if (input.find(':') != std::string::npos) {
        return errors::InvalidArgument("Control input '", input,
                                       "' must not have ':' in NodeDef: ",
                                       FormatNodeDefForError(node_def));
@@ -662,7 +664,7 @@ absl::Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
    }
  }
 
-  std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
+  std::unordered_map<std::string, const OpDef::AttrDef*> op_attrs;
   for (const auto& attr : op_def.attr()) {
     if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
       return errors::InvalidArgument("OpDef has duplicate attr name '",
@@ -700,7 +702,7 @@ absl::Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
   // Were all attrs in the OpDef found in the NodeDef?
   if (!op_attrs.empty()) {
-    string attrs;
+    std::string attrs;
    for (const auto& attr_pair : op_attrs) {
      if (!attrs.empty()) absl::StrAppend(&attrs, "', '");
      absl::StrAppend(&attrs, attr_pair.first);
@@ -870,7 +872,8 @@ const absl::string_view kColocationGroupPrefixStringPiece(
 
}  // namespace
 
-absl::Status ValidateOpInput(const string& input_name, bool* is_control_input) {
+absl::Status ValidateOpInput(const std::string& input_name,
+                             bool* is_control_input) {
   *is_control_input = false;
   if (IsValidDataInputName(input_name)) {
     return absl::OkStatus();
@@ -882,7 +885,7 @@ absl::Status ValidateOpInput(const string& input_name, bool* is_control_input) {
   }
 }
 
-absl::Status ValidateNodeName(const string& node_name) {
+absl::Status ValidateNodeName(const std::string& node_name) {
   if (IsValidNodeName(node_name)) {
     return absl::OkStatus();
   } else {
@@ -896,7 +899,7 @@ absl::Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
     return AttachDef(s, node_def);
   }
   bool in_control_inputs = false;
-  for (const string& input_name : node_def.input()) {
+  for (const std::string& input_name : node_def.input()) {
     bool is_control_input;
     s = ValidateOpInput(input_name, &is_control_input);
     if (!s.ok()) {
@@ -915,7 +918,7 @@ absl::Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
 
 absl::Status AttachDef(const absl::Status& status, const NodeDef& node_def,
                        bool allow_multiple_formatted_node) {
-  string node_error;
+  std::string node_error;
   if (!allow_multiple_formatted_node &&
       absl::StrContains(status.message(), "{{node ")) {
     node_error = node_def.name();
@@ -930,11 +933,11 @@ absl::Status AttachDef(const absl::Status& status, const NodeDef& node_def,
 void AddNodeAttr(absl::string_view name, const AttrValue& value,
                 NodeDef* node_def) {
   node_def->mutable_attr()->insert(
-      AttrValueMap::value_type(string(name), value));
+      AttrValueMap::value_type(std::string(name), value));
 }
 
 void AddNodeAttr(absl::string_view name, AttrValue&& value, NodeDef* node_def) {
-  (*node_def->mutable_attr())[string(name)] = std::move(value);
+  (*node_def->mutable_attr())[std::string(name)] = std::move(value);
 }
 
 #define ADD_NODE_ATTR(T)                                   \
@@ -957,8 +960,8 @@ ADD_NODE_ATTR(const TensorProto&)
 ADD_NODE_ATTR(const NameAttrList&)
 ADD_NODE_ATTR(absl::Span<const StringPiece>)
 ADD_NODE_ATTR(absl::Span<const char* const>)
-ADD_NODE_ATTR(absl::Span<const string>)
-ADD_NODE_ATTR(absl::Span<const int32>)
+ADD_NODE_ATTR(absl::Span<const std::string>)
+ADD_NODE_ATTR(absl::Span<const int32_t>)
 ADD_NODE_ATTR(absl::Span<const int64_t>)
 ADD_NODE_ATTR(absl::Span<const float>)
 ADD_NODE_ATTR(absl::Span<const bool>)
@@ -973,7 +976,7 @@ ADD_NODE_ATTR(absl::Span<const NameAttrList>)
 
 void AddAttr(absl::string_view name, const AttrValue& value,
             AttrValueMap* map) {
-  map->insert(AttrValueMap::value_type(string(name), value));
+  map->insert(AttrValueMap::value_type(std::string(name), value));
 }
 
 #define ADD_ATTR(T)                                        \
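Note the asymmetry the two `AddNodeAttr` overloads above preserve: the copying overload uses `insert`, which keeps an existing attr under the same name, while the move overload assigns through `operator[]`, which overwrites. A standalone illustration of that map-insert distinction (std::map shown; protobuf's Map insert behaves the same way here):

#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> attrs{{"N", 1}};
  attrs.insert({"N", 2});           // No effect: the key already exists.
  attrs["N"] = 3;                   // operator[] overwrites.
  std::cout << attrs["N"] << "\n";  // Prints 3.
}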
if (uniquify_frame_name && (node_def->op() == "Enter" || node_def->op() == "RefEnter")) { - string frame_name; + std::string frame_name; TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name)); AttrValue& attr = (*node_def->mutable_attr())["frame_name"]; frame_name = absl::StrCat(prefix, frame_name, suffix); @@ -1005,7 +1008,7 @@ } absl::Status MaybeAddPrefixToColocationConstraints( - const std::unordered_set<string>& match, absl::string_view prefix, + const std::unordered_set<std::string>& match, absl::string_view prefix, NodeDef* node_def) { auto attr = node_def->mutable_attr()->find(kColocationAttrName); if (attr == node_def->mutable_attr()->end()) { @@ -1016,7 +1019,7 @@ absl::Status MaybeAddPrefixToColocationConstraints( for (size_t i = 0; i < constraints_size; ++i) { absl::string_view original(constraints_list->s(i)); if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) { - if (match.find(string(original)) != match.end()) { + if (match.find(std::string(original)) != match.end()) { (*constraints_list->mutable_s(i)) = absl::StrCat(kColocationGroupPrefix, prefix, original); } diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 2b82c596fee301..1dd97f9e4137db 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -81,7 +81,7 @@ std::string FormatNodeDefForError( absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info); -typedef protobuf::Map<string, AttrValue> AttrValueMap; +typedef protobuf::Map<std::string, AttrValue> AttrValueMap; // Adds an attr with name and value to *node_def. // The type of the attr is based on the type of value. @@ -109,9 +109,9 @@ void AddNodeAttr(absl::string_view name, absl::Span<const StringPiece> value, NodeDef* node_def); void AddNodeAttr(absl::string_view name, absl::Span<const char* const> value, NodeDef* node_def); -void AddNodeAttr(absl::string_view name, absl::Span<const string> value, +void AddNodeAttr(absl::string_view name, absl::Span<const std::string> value, NodeDef* node_def); -void AddNodeAttr(absl::string_view name, absl::Span<const int32> value, +void AddNodeAttr(absl::string_view name, absl::Span<const int32_t> value, NodeDef* node_def); void AddNodeAttr(absl::string_view name, absl::Span<const int64_t> value, NodeDef* node_def); @@ -221,7 +221,7 @@ absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, int64_t* value); // type: "int" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, - int32* value); // type: "int" + int32_t* value); // type: "int" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, float* value); // type: "float" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, @@ -236,14 +236,15 @@ absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, PartialTensorShape* value); // type: "shape" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, Tensor* value); // type: "tensor" -absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, - std::vector<string>* value); // type "list(string)" +absl::Status GetNodeAttr( + const AttrSlice& attrs, absl::string_view attr_name, + std::vector<std::string>* value); // type "list(string)" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector<tstring>* value); // type "list(tstring)" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view
attr_name, std::vector<int64_t>* value); // type "list(int)" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, - std::vector<int32>* value); // type "list(int)" + std::vector<int32_t>* value); // type "list(int)" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector<float>* value); // type "list(float)" absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, @@ -302,7 +303,7 @@ bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector<int64_t>* value); // type: "int" bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, - int32* value); // type: "int" + int32_t* value); // type: "int" bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, float* value); // type: "float" bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, @@ -313,11 +314,11 @@ bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, TensorShape* value); // type: "shape" bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, - std::vector<string>* value); // type: "list(string)" + std::vector<std::string>* value); // type: "list(string)" bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector<tstring>* value); // type: "list(tstring)" bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, - std::vector<int32>* value); // type: "list(int)" + std::vector<int32_t>* value); // type: "list(int)" bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, std::vector<float>* value); // type: "list(float)" bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, @@ -329,8 +330,8 @@ bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, // Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute // values. -bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, - std::vector<const string*>* value); // type: "list(string)" +bool TryGetNodeAttr( + const AttrSlice& attrs, absl::string_view attr_name, + std::vector<const std::string*>* value); // type: "list(string)" bool TryGetNodeAttr( const AttrSlice& attrs, absl::string_view attr_name, std::vector<const TensorShapeProto*>* value); // type: "list(shape)" @@ -442,7 +444,7 @@ absl::Status AddPrefixAndSuffixToNode(absl::string_view prefix, // Appends the given prefix to the colocation group name if the name exists // in `to_match`.
absl::Status MaybeAddPrefixToColocationConstraints( - const std::unordered_set<string>& match, absl::string_view prefix, + const std::unordered_set<std::string>& match, absl::string_view prefix, NodeDef* node_def); // Updates the colocation constraint name with the one provided in the map (if diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 5296dcc7075dc6..66a37a41ee3f8a 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -37,7 +37,7 @@ OpDef ToOpDef(const OpDefBuilder& builder) { return op_reg_data.op_def; } -NodeDef ToNodeDef(const string& text) { +NodeDef ToNodeDef(const std::string& text) { NodeDef node_def; EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); return node_def; @@ -56,7 +56,7 @@ void ExpectSuccess(const NodeDef& good, const OpDef& op_def) { } void ExpectFailure(const NodeDef& bad, const OpDef& op_def, - const string& message) { + const std::string& message) { absl::Status status = ValidateNodeDef(bad, op_def); EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad) @@ -322,7 +322,7 @@ void ExpectValidSyntax(const NodeDef& good) { << "NodeDef: " << SummarizeNodeDef(good); } -void ExpectInvalidSyntax(const NodeDef& bad, const string& message) { +void ExpectInvalidSyntax(const NodeDef& bad, const std::string& message) { absl::Status status = ValidateExternalNodeDefSyntax(bad); ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad); @@ -761,11 +761,11 @@ TEST(AddPrefixAndSuffixToNode, Enter) { node_def.set_name("enter"); node_def.set_op("Enter"); AddNodeAttr("frame_name", "test_frame", &node_def); - const string prefix = "prefix/"; - const string suffix = "/suffix"; + const std::string prefix = "prefix/"; + const std::string suffix = "/suffix"; TF_ASSERT_OK(AddPrefixAndSuffixToNode(prefix, suffix, &node_def)); EXPECT_EQ("prefix/enter/suffix", node_def.name()); - string frame_name; + std::string frame_name; TF_ASSERT_OK(GetNodeAttr(node_def, "frame_name", &frame_name)); EXPECT_EQ("prefix/test_frame/suffix", frame_name); } @@ -780,15 +780,15 @@ TEST(MaybeAddPrefixToColocationConstraints, Basic) { absl::StrCat(kColocationGroupPrefix, "Node3")}, &node_def); - std::unordered_set<string> match; + std::unordered_set<std::string> match; match.insert("Node1"); match.insert("Node3"); TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def)); - std::vector<string> coloc_constraints; + std::vector<std::string> coloc_constraints; TF_ASSERT_OK(GetNodeAttr(node_def, kColocationAttrName, &coloc_constraints)); - EXPECT_EQ( - coloc_constraints, - std::vector<string>({"loc:@fn/Node1", "loc:@Node2", "loc:@fn/Node3"})); + EXPECT_EQ(coloc_constraints, + std::vector<std::string>( + {"loc:@fn/Node1", "loc:@Node2", "loc:@fn/Node3"})); } TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) { @@ -796,7 +796,7 @@ TEST(MaybeAddPrefixToColocationConstraints, NoConstraints) { node_def.set_name("Identity"); node_def.set_op("Identity"); - std::unordered_set<string> match; + std::unordered_set<std::string> match; match.insert("Node1"); match.insert("Node3"); TF_ASSERT_OK(MaybeAddPrefixToColocationConstraints(match, "fn/", &node_def)); @@ -817,10 +817,10 @@ TEST(MaybeUpdateColocationConstraintsWithMap, Basic) { node_map["Node1"] = "Node4"; node_map["Invalid"] = "Node5"; TF_ASSERT_OK(MaybeUpdateColocationConstraintsWithMap(node_map, &node_def)); - std::vector<string> coloc_constraints; + std::vector<std::string> coloc_constraints; TF_ASSERT_OK(GetNodeAttr(node_def, kColocationAttrName, &coloc_constraints)); - EXPECT_EQ(coloc_constraints, -
std::vector({"loc:@Node4", "loc:@Node2", "loc:@Node3"})); + EXPECT_EQ(coloc_constraints, std::vector( + {"loc:@Node4", "loc:@Node2", "loc:@Node3"})); } TEST(MaybeUpdateColocationConstraintsWithMap, NoConstraints) { diff --git a/tensorflow/core/framework/node_properties_test.cc b/tensorflow/core/framework/node_properties_test.cc index 8e1dd344e91261..28f992c4e4dff4 100644 --- a/tensorflow/core/framework/node_properties_test.cc +++ b/tensorflow/core/framework/node_properties_test.cc @@ -40,7 +40,7 @@ class MockOpRegistry : public OpRegistryInterface { // Returns an error status and sets *op_reg_data to nullptr if no OpDef is // registered under that name, otherwise returns the registered OpDef. // Caller must not delete the returned pointer. - absl::Status LookUp(const string& op_type_name, + absl::Status LookUp(const std::string& op_type_name, const OpRegistrationData** op_reg_data) const override { if (op_type_name == "Foo") { *op_reg_data = &op_reg_; diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 31aeb2421bc652..7688578d8513f5 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -39,7 +39,7 @@ absl::Status DefaultValidator(const OpRegistryInterface& op_registry) { // OpRegistry ----------------------------------------------------------------- -absl::Status OpRegistryInterface::LookUpOpDef(const string& op_type_name, +absl::Status OpRegistryInterface::LookUpOpDef(const std::string& op_type_name, const OpDef** op_def) const { *op_def = nullptr; const OpRegistrationData* op_reg_data = nullptr; @@ -62,7 +62,7 @@ void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) { namespace { // Helper function that returns Status message for failed LookUp. -absl::Status OpNotFound(const string& op_type_name) { +absl::Status OpNotFound(const std::string& op_type_name) { absl::Status status = errors::NotFound( "Op type not registered '", op_type_name, "' in binary running on ", port::Hostname(), ". 
", @@ -76,13 +76,14 @@ absl::Status OpNotFound(const string& op_type_name) { } } // namespace -absl::Status OpRegistry::LookUp(const string& op_type_name, +absl::Status OpRegistry::LookUp(const std::string& op_type_name, const OpRegistrationData** op_reg_data) const { if ((*op_reg_data = LookUp(op_type_name))) return absl::OkStatus(); return OpNotFound(op_type_name); } -const OpRegistrationData* OpRegistry::LookUp(const string& op_type_name) const { +const OpRegistrationData* OpRegistry::LookUp( + const std::string& op_type_name) const { { tf_shared_lock l(mu_); if (initialized_) { @@ -96,7 +97,7 @@ const OpRegistrationData* OpRegistry::LookUp(const string& op_type_name) const { } const OpRegistrationData* OpRegistry::LookUpSlow( - const string& op_type_name) const { + const std::string& op_type_name) const { const OpRegistrationData* res = nullptr; bool first_call = false; @@ -195,10 +196,10 @@ absl::Status OpRegistry::ProcessRegistrations() const { return CallDeferred(); } -string OpRegistry::DebugString(bool include_internal) const { +std::string OpRegistry::DebugString(bool include_internal) const { OpList op_list; Export(include_internal, &op_list); - string ret; + std::string ret; for (const auto& op : op_list.op()) { absl::StrAppend(&ret, SummarizeOpDef(op), "\n"); } @@ -268,7 +269,7 @@ OpListOpRegistry::OpListOpRegistry(const OpList* op_list) { } const OpRegistrationData* OpListOpRegistry::LookUp( - const string& op_type_name) const { + const std::string& op_type_name) const { auto iter = index_.find(op_type_name); if (iter == index_.end()) { return nullptr; @@ -277,7 +278,8 @@ const OpRegistrationData* OpListOpRegistry::LookUp( } absl::Status OpListOpRegistry::LookUp( - const string& op_type_name, const OpRegistrationData** op_reg_data) const { + const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const { if ((*op_reg_data = LookUp(op_type_name))) return absl::OkStatus(); return OpNotFound(op_type_name); } diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 41b39fc2076469..251d58bdd01a15 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -165,7 +165,8 @@ class OpRegistry : public OpRegistryInterface { // Functions in deferred_ may only be called with mu_ held. mutable std::vector deferred_ TF_GUARDED_BY(mu_); // Values are owned. - mutable absl::flat_hash_map> + mutable absl::flat_hash_map> registry_ TF_GUARDED_BY(mu_); mutable bool initialized_ TF_GUARDED_BY(mu_); @@ -193,7 +194,8 @@ class OpListOpRegistry : public OpRegistryInterface { private: // Values are owned. - absl::flat_hash_map> index_; + absl::flat_hash_map> + index_; }; // Support for defining the OpDef (specifying the semantics of the Op and how diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc index da11e32498becf..f6087d6d5f33ed 100644 --- a/tensorflow/core/framework/op_compatibility_test.cc +++ b/tensorflow/core/framework/op_compatibility_test.cc @@ -88,10 +88,10 @@ class OpCompatibilityTest : public OpsTestBase { TF_ASSERT_OK(RunOpKernel()); } - string Result() { return GetOutput(0)->scalar()(); } + std::string Result() { return GetOutput(0)->scalar()(); } void ExpectIncompatible(const OpDef& old_op_def, const OpDef& new_op_def, - const string& error) { + const std::string& error) { // Test OpDefCompatible gives the same answer without the node_def. 
absl::Status status = OpDefCompatible(old_op_def, new_op_def); if (status.ok()) { @@ -103,8 +103,9 @@ class OpCompatibilityTest : public OpsTestBase { } } - void ExpectInvalid(const OpDef& old_op_def, const string& validation_error, - const string& compatibility_error) { + void ExpectInvalid(const OpDef& old_op_def, + const std::string& validation_error, + const std::string& compatibility_error) { // Record the original signature before we change *node_def(). DataTypeVector old_in_types, old_out_types; TF_ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types, @@ -127,7 +128,7 @@ class OpCompatibilityTest : public OpsTestBase { } void ExpectTypeMismatch(const OpDef& old_op_def, - const string& compatibility_error) { + const std::string& compatibility_error) { // Record the original signature before we change *node_def(). DataTypeVector old_in_types, old_out_types; TF_ASSERT_OK(InOutTypesForNode(*node_def(), old_op_def, &old_in_types, @@ -153,7 +154,7 @@ class OpCompatibilityTest : public OpsTestBase { } void ExpectRenameFailure(const OpDef& old_op_def, - const string& compatibility_error) { + const std::string& compatibility_error) { // This should be all that is needed to get compatibility. const OpDef* new_op_def = RegisteredOpDef(); AddDefaultsToNodeDef(*new_op_def, node_def()); @@ -166,7 +167,7 @@ class OpCompatibilityTest : public OpsTestBase { } void ExpectDefaultChangeFailure(const OpDef& old_op_def, - const string& compatibility_error) { + const std::string& compatibility_error) { // This should be all that is needed to get compatibility. const OpDef* new_op_def = RegisteredOpDef(); AddDefaultsToNodeDef(*new_op_def, node_def()); diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index e4ec9e50497d73..9265f5b10ed7e4 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -36,7 +36,7 @@ namespace tensorflow { namespace { -string AttrError(absl::string_view orig, const string& op_name) { +std::string AttrError(absl::string_view orig, const std::string& op_name) { return absl::StrCat(" from Attr(\"", orig, "\") for Op ", op_name); } @@ -62,7 +62,7 @@ bool ConsumeListPrefix(absl::string_view* sp) { bool ConsumeQuotedString(char quote_ch, absl::string_view* sp, absl::string_view* out) { - const string quote_str(1, quote_ch); + const std::string quote_str(1, quote_ch); return Scanner(*sp) .OneLiteral(quote_str.c_str()) .RestartCapture() @@ -150,7 +150,7 @@ bool ProcessCompoundType(const absl::string_view type_string, } void FinalizeAttr(absl::string_view spec, bool allow_attr_type_any, - OpDef* op_def, std::vector<string>* errors) { + OpDef* op_def, std::vector<std::string>* errors) { OpDef::AttrDef* attr = op_def->add_attr(); absl::string_view orig(spec); @@ -161,7 +161,7 @@ void FinalizeAttr(absl::string_view spec, bool allow_attr_type_any, // Read "" or "list()".
bool is_list = ConsumeListPrefix(&spec); - string type; + std::string type; absl::string_view type_string; // Used if type == "type" if (absl::ConsumePrefix(&spec, "string")) { type = "string"; @@ -197,8 +197,8 @@ void FinalizeAttr(absl::string_view spec, bool allow_attr_type_any, VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) || ConsumeQuotedString('\'', &spec, &escaped_string), "Trouble parsing allowed string at '", spec, "'"); - string unescaped; - string error; + std::string unescaped; + std::string error; VERIFY(absl::CUnescape(escaped_string, &unescaped, &error), "Trouble unescaping \"", escaped_string, "\", got error: ", error); @@ -274,8 +274,8 @@ void FinalizeAttr(absl::string_view spec, bool allow_attr_type_any, #undef VERIFY -string InOutError(bool is_output, absl::string_view orig, - const string& op_name) { +std::string InOutError(bool is_output, absl::string_view orig, + const std::string& op_name) { return strings::StrCat(" from ", is_output ? "Output" : "Input", "(\"", orig, "\") for Op ", op_name); } @@ -343,7 +343,7 @@ bool ConsumeControlOutName(absl::string_view* sp, absl::string_view* out) { } while (false) void FinalizeInputOrOutput(absl::string_view spec, bool is_output, - OpDef* op_def, std::vector<string>* errors) { + OpDef* op_def, std::vector<std::string>* errors) { OpDef::ArgDef* arg = is_output ? op_def->add_output_arg() : op_def->add_input_arg(); @@ -426,12 +426,13 @@ void FinalizeInputOrOutput(absl::string_view spec, bool is_output, #undef VERIFY -string ControlOutError(absl::string_view orig, const string& op_name) { +std::string ControlOutError(absl::string_view orig, + const std::string& op_name) { return absl::StrCat(" from ControlOutput(\"", orig, "\") for Op ", op_name); } void FinalizeControlOutput(absl::string_view name, OpDef* op_def, - std::vector<string>* errors) { + std::vector<std::string>* errors) { absl::string_view orig(name); // Parse control output name. @@ -441,7 +442,7 @@ void FinalizeControlOutput(absl::string_view name, OpDef* op_def, ControlOutError(orig, op_def->name()))); } - *op_def->add_control_output() = string(tmp_name.data(), tmp_name.size()); + *op_def->add_control_output() = std::string(tmp_name.data(), tmp_name.size()); } int num_leading_spaces(absl::string_view s) { @@ -467,12 +468,12 @@ bool IsDocNameColon(absl::string_view s) { return ConsumeDocNameColon(&s, nullptr /* out */); } -void FinalizeDoc(const string& text, OpDef* op_def, - std::vector<string>* errors) { - std::vector<string> lines = str_util::Split(text, '\n'); +void FinalizeDoc(const std::string& text, OpDef* op_def, + std::vector<std::string>* errors) { + std::vector<std::string> lines = str_util::Split(text, '\n'); // Remove trailing spaces. - for (string& line : lines) { + for (std::string& line : lines) { absl::StripTrailingAsciiWhitespace(&line); } @@ -493,8 +494,9 @@ void FinalizeDoc(const string& text, OpDef* op_def, int end_l = l; // Trim trailing blank lines from the description. while (start_l < end_l && lines[end_l - 1].empty()) --end_l; - string desc = absl::StrJoin( - absl::Span<const string>(lines.data() + start_l, end_l - start_l), "\n"); + std::string desc = absl::StrJoin( + absl::Span<const std::string>(lines.data() + start_l, end_l - start_l), + "\n"); if (!desc.empty()) op_def->set_description(desc); // name: description @@ -528,7 +530,7 @@ void FinalizeDoc(const string& text, OpDef* op_def, if (!description[i].empty()) description[i].remove_prefix(min_indent); } // Concatenate lines into a single string. - const string complete(absl::StrJoin(description, "\n")); + const std::string complete(absl::StrJoin(description, "\n")); // Find name.
bool found = false; @@ -561,31 +563,31 @@ } // namespace -OpDefBuilder::OpDefBuilder(string op_name) { +OpDefBuilder::OpDefBuilder(std::string op_name) { op_def()->set_name(std::move(op_name)); } -OpDefBuilder& OpDefBuilder::Attr(string spec) { +OpDefBuilder& OpDefBuilder::Attr(std::string spec) { attrs_.push_back(std::move(spec)); return *this; } -OpDefBuilder& OpDefBuilder::Input(string spec) { +OpDefBuilder& OpDefBuilder::Input(std::string spec) { inputs_.push_back(std::move(spec)); return *this; } -OpDefBuilder& OpDefBuilder::Output(string spec) { +OpDefBuilder& OpDefBuilder::Output(std::string spec) { outputs_.push_back(std::move(spec)); return *this; } -OpDefBuilder& OpDefBuilder::ControlOutput(string name) { +OpDefBuilder& OpDefBuilder::ControlOutput(std::string name) { control_outputs_.push_back(std::move(name)); return *this; } -OpDefBuilder& OpDefBuilder::Doc(string text) { +OpDefBuilder& OpDefBuilder::Doc(std::string text) { #ifndef TF_LEAN_BINARY if (!doc_.empty()) { errors_.push_back( @@ -622,7 +624,7 @@ OpDefBuilder& OpDefBuilder::SetIsDistributedCommunication() { return *this; } -OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) { +OpDefBuilder& OpDefBuilder::Deprecated(int version, std::string explanation) { if (op_def()->has_deprecation()) { errors_.push_back( absl::StrCat("Deprecated called twice for Op ", op_def()->name())); @@ -667,7 +669,7 @@ OpDefBuilder& OpDefBuilder::AllowAttrTypeAny() { } absl::Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { - std::vector<string> errors = errors_; + std::vector<std::string> errors = errors_; *op_reg_data = op_reg_data_; OpDef* op_def = &op_reg_data->op_def; diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h index 8009135d584188..3df88e028c2bd2 100644 --- a/tensorflow/core/framework/op_def_builder.h +++ b/tensorflow/core/framework/op_def_builder.h @@ -40,7 +40,7 @@ typedef std::vector<std::reference_wrapper<const FullTypeDef>> TypeRefVector; // A callback into the type inference process, allowing type inference functions // to request inferring the type of some function (assumed to exist in the // runtime). The function is specified by name.
-typedef std::function<absl::StatusOr<FullTypeDef>(const string&, +typedef std::function<absl::StatusOr<FullTypeDef>(const std::string&, const TypeRefVector&)> FunctionTypeInferrer; @@ -266,12 +266,12 @@ class OpDefBuilder { OpDef* op_def() { return &op_reg_data_.op_def; } OpRegistrationData op_reg_data_; - std::vector<string> attrs_; - std::vector<string> inputs_; - std::vector<string> outputs_; - std::vector<string> control_outputs_; + std::vector<std::string> attrs_; + std::vector<std::string> inputs_; + std::vector<std::string> outputs_; + std::vector<std::string> control_outputs_; std::string doc_; - std::vector<string> errors_; + std::vector<std::string> errors_; bool allow_attr_type_any_ = false; }; diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index 3e8c805bcb419f..8dad7a721dad34 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -74,7 +74,7 @@ class OpDefBuilderTest : public ::testing::Test { } } - void ExpectFailure(const OpDefBuilder& builder, const string& error) { + void ExpectFailure(const OpDefBuilder& builder, const std::string& error) { OpRegistrationData op_reg_data; absl::Status status = builder.Finalize(&op_reg_data); EXPECT_FALSE(status.ok()); diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index e228d1f4969a7c..b11360b68bb4a6 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -48,7 +48,7 @@ absl::Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { return absl::OkStatus(); } } - string allowed_str; + std::string allowed_str; for (int i = 0; i < allowed_values.list().type_size(); ++i) { if (!allowed_str.empty()) { absl::StrAppend(&allowed_str, ", "); } @@ -61,15 +61,16 @@ absl::Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { " is not in the list of allowed values: ", allowed_str); } -absl::Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { +absl::Status AllowedStringValue(const std::string& str, + const OpDef::AttrDef& attr) { const AttrValue& allowed_values(attr.allowed_values()); for (const auto& allowed : allowed_values.list().s()) { if (str == allowed) { return absl::OkStatus(); } } - string allowed_str; - for (const string& allowed : allowed_values.list().s()) { + std::string allowed_str; + for (const std::string& allowed : allowed_values.list().s()) { if (!allowed_str.empty()) { absl::StrAppend(&allowed_str, ", "); } @@ -135,7 +136,7 @@ absl::Status ValidateAttrValue(const AttrValue& attr_value, } else if (attr.type() == "string") { TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr)); } else if (attr.type() == "list(string)") { - for (const string& str : attr_value.list().s()) { + for (const std::string& str : attr_value.list().s()) { TF_RETURN_IF_ERROR(AllowedStringValue(str, attr)); } } else { @@ -193,7 +194,7 @@ const ApiDef::Arg* FindInputArg(absl::string_view name, const ApiDef& api_def) { static absl::Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, bool output, absl::flat_hash_set<absl::string_view>* names) { - const string suffix = + const std::string suffix = absl::StrCat(output ? " for output '" : " for input '", arg.name(), "'"); VALIDATE(names->emplace(arg.name()).second, "Duplicate name: ", arg.name()); VALIDATE(HasAttrStyleType(arg), "Missing type", suffix); @@ -320,7 +321,7 @@ absl::Status ValidateOpDef(const OpDef& op_def) { // Validate allowed_values if (attr.has_allowed_values()) { - const string list_type = + const std::string list_type = is_list ?
attr.type() : absl::StrCat("list(", attr.type(), ")"); TF_RETURN_WITH_CONTEXT_IF_ERROR( AttrValueHasType(attr.allowed_values(), list_type), " for attr '", @@ -360,7 +361,7 @@ absl::Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { } else { // Warn only once for each op name, and do it in a threadsafe manner. static mutex mu(LINKER_INITIALIZED); - static auto* warned = new absl::flat_hash_set<string>(); + static auto* warned = new absl::flat_hash_set<std::string>(); bool warn; { mutex_lock lock(mu); @@ -378,8 +379,9 @@ absl::Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { namespace { -string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { - string ret; +std::string SummarizeArgs( + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { + std::string ret; for (const OpDef::ArgDef& arg : args) { if (!ret.empty()) absl::StrAppend(&ret, ", "); absl::StrAppend(&ret, arg.name(), ":"); @@ -399,8 +401,8 @@ string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { } // namespace -string SummarizeOpDef(const OpDef& op_def) { - string ret = absl::StrCat("Op<name=", op_def.name()); +std::string SummarizeOpDef(const OpDef& op_def) { + std::string ret = absl::StrCat("Op<name=", op_def.name()); absl::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()), " -> ", SummarizeArgs(op_def.output_arg())); for (int i = 0; i < op_def.attr_size(); ++i) { @@ -474,12 +476,12 @@ bool MoreRestrictive(const OpDef::AttrDef& old_attr, return false; } -string AllowedStr(const OpDef::AttrDef& attr) { +std::string AllowedStr(const OpDef::AttrDef& attr) { if (!attr.has_allowed_values()) return "no restriction"; return SummarizeAttrValue(attr.allowed_values()); } -string DefaultAttrStr(const OpDef::AttrDef& attr) { +std::string DefaultAttrStr(const OpDef::AttrDef& attr) { if (!attr.has_default_value()) return "no default"; return SummarizeAttrValue(attr.default_value()); } @@ -495,7 +497,7 @@ bool HigherMinimum(const OpDef::AttrDef& old_attr, return new_attr.minimum() > old_attr.minimum(); } -string MinStr(const OpDef::AttrDef& attr) { +std::string MinStr(const OpDef::AttrDef& attr) { if (!attr.has_minimum()) return "no minimum"; return absl::StrCat(attr.minimum()); } @@ -509,7 +511,7 @@ void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) { // Add a comma to *s every call but the first (*add_comma should be // initialized to false). -void AddComma(string* s, bool* add_comma) { +void AddComma(std::string* s, bool* add_comma) { if (*add_comma) { absl::StrAppend(s, ", "); } else { @@ -518,7 +520,7 @@ } // Will add the `name` from arg if name is true. -void AddName(string* s, bool name, const OpDef::ArgDef& arg) { +void AddName(std::string* s, bool name, const OpDef::ArgDef& arg) { if (name) { absl::StrAppend(s, arg.name(), ":"); } @@ -535,11 +537,11 @@ void AddName(string* s, bool name, const OpDef::ArgDef& arg) { // // We get the types by either using the attrs in args if they are in // old_attrs, or substituting the default value from new_attrs. -string ComputeArgSignature( +std::string ComputeArgSignature( const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref, bool names) { - string s; + std::string s; bool add_comma = false; for (const OpDef::ArgDef& arg : args) { if (!arg.type_list_attr().empty()) { @@ -568,7 +570,7 @@ string ComputeArgSignature( } } else { int num = 1; // How many input/outputs does this represent? - string type; // What is the type of this arg? + std::string type; // What is the type of this arg? AddName(&type, names, arg); if (!arg.number_attr().empty()) { // N * type case.
@@ -655,9 +657,9 @@ absl::Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { } std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref; - const string old_in_sig = ComputeArgSignature( + const std::string old_in_sig = ComputeArgSignature( old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */); - const string new_in_sig = ComputeArgSignature( + const std::string new_in_sig = ComputeArgSignature( new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */); VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig, "' vs. '", new_in_sig, "'"); @@ -669,10 +671,10 @@ absl::Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { " changed from non-ref to ref"); } - const string old_out_sig = + const std::string old_out_sig = ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs, &old_out_ref, true /* names */); - const string new_out_sig = + const std::string new_out_sig = ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs, &new_out_ref, true /* names */); VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '", @@ -805,13 +807,13 @@ bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) { return true; } -uint64 AttrDefHash(const OpDef::AttrDef& a) { - uint64 h = Hash64(a.name()); +uint64_t AttrDefHash(const OpDef::AttrDef& a) { + uint64_t h = Hash64(a.name()); h = Hash64(a.type().data(), a.type().size(), h); h = Hash64Combine(AttrValueHash(a.default_value()), h); h = Hash64(a.description().data(), a.description().size(), h); - h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h); - h = Hash64Combine(static_cast<uint64>(a.minimum()), h); + h = Hash64Combine(static_cast<uint64_t>(a.has_minimum()), h); + h = Hash64Combine(static_cast<uint64_t>(a.minimum()), h); h = Hash64Combine(AttrValueHash(a.allowed_values()), h); return h; } @@ -837,7 +839,7 @@ bool RepeatedAttrDefEqual( return true; } -uint64 RepeatedAttrDefHash( +uint64_t RepeatedAttrDefHash( const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) { // Insert AttrDefs into map to deterministically sort by name std::vector<const OpDef::AttrDef*> a_sorted; @@ -850,7 +852,7 @@ uint64 RepeatedAttrDefHash( return lhs->name() < rhs->name(); }); // Iterate and combines hashes of keys and values - uint64 h = 0xDECAFCAFFE; + uint64_t h = 0xDECAFCAFFE; for (const auto& def : a_sorted) { h = Hash64(def->name().data(), def->name().size(), h); h = Hash64Combine(AttrDefHash(*def), h); @@ -884,8 +886,8 @@ bool OpDefEqual(const OpDef& o1, const OpDef& o2) { return AreSerializedProtosEqual(o1_copy, o2_copy); } -uint64 OpDefHash(const OpDef& o) { - uint64 h = RepeatedAttrDefHash(o.attr()); +uint64_t OpDefHash(const OpDef& o) { + uint64_t h = RepeatedAttrDefHash(o.attr()); // Compute deterministic order-independent control outputs hash. std::vector<string> control_output; diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h index be1f08225c0e2e..abaaeefb03c9a8 100644 --- a/tensorflow/core/framework/op_def_util.h +++ b/tensorflow/core/framework/op_def_util.h @@ -88,7 +88,7 @@ void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def); bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2); // Returns hash of `a` that is consistent with AttrDefEqual. -uint64 AttrDefHash(const OpDef::AttrDef& a); +uint64_t AttrDefHash(const OpDef::AttrDef& a); // Returns true if all AttrDefs in `a1` equal corresponding AttrDefs in // `a2`. Correspondence is established by name.
@@ -96,14 +96,15 @@ bool RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1, const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2); // Returns hash of `a` that is consistent with RepeatedAttrDefEqual -uint64 RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef>& a); +uint64_t RepeatedAttrDefHash( + const protobuf::RepeatedPtrField<OpDef::AttrDef>& a); // Returns true if `o1` is equal to `o2`. // Equality includes all the fields. OpDef.attr field is treated as a set. bool OpDefEqual(const OpDef& o1, const OpDef& o2); // Returns hash of `o` that is consistent with AttrDefEqual. -uint64 OpDefHash(const OpDef& o); +uint64_t OpDefHash(const OpDef& o); } // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc index 333a103cef7e65..41fd90d4e79fcf 100644 --- a/tensorflow/core/framework/op_def_util_test.cc +++ b/tensorflow/core/framework/op_def_util_test.cc @@ -27,13 +27,13 @@ limitations under the License. namespace tensorflow { namespace { -OpDef FromText(const string& text) { +OpDef FromText(const std::string& text) { OpDef op_def; EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &op_def)); return op_def; } -OpDef::AttrDef ADef(const string& text) { +OpDef::AttrDef ADef(const std::string& text) { OpDef::AttrDef attr_def; EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &attr_def)); return attr_def; @@ -41,7 +41,7 @@ OpDef::AttrDef ADef(const string& text) { class ValidateOpDefTest : public ::testing::Test { protected: - absl::Status TestProto(const string& text) { + absl::Status TestProto(const std::string& text) { return ValidateOpDef(FromText(text)); } @@ -58,7 +58,7 @@ class ValidateOpDefTest : public ::testing::Test { }; namespace { -void ExpectFailure(const absl::Status& status, const string& message) { +void ExpectFailure(const absl::Status& status, const std::string& message) { EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; if (!status.ok()) { LOG(INFO) << "message: " << status; @@ -516,9 +516,9 @@ void ExpectDifferent(const OpDef& o1, const OpDef& o2) { } TEST(OpDefEqualityTest, EqualAndHash) { - string a1 = "attr { name: 'a' type: 'string' } "; - string a2 = "attr { name: 'b' type: 'string' } "; - string a3 = "attr { name: 'c' type: 'int32' } "; + std::string a1 = "attr { name: 'a' type: 'string' } "; + std::string a2 = "attr { name: 'b' type: 'string' } "; + std::string a3 = "attr { name: 'c' type: 'int32' } "; OpDef o1 = FromText(absl::StrCat("name: 'MatMul' ", a1)); OpDef o2 = FromText(absl::StrCat("name: 'MatMul' ", a2)); OpDef o3 = FromText(absl::StrCat("name: 'MatMul' ", a1, a2)); diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index 026b8e677ac668..79766a2d187d93 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -30,10 +30,11 @@ limitations under the License. namespace tensorflow { -string WordWrap(absl::string_view prefix, absl::string_view str, int width) { - const string indent_next_line = "\n" + Spaces(prefix.size()); +std::string WordWrap(absl::string_view prefix, absl::string_view str, + int width) { + const std::string indent_next_line = "\n" + Spaces(prefix.size()); width -= prefix.size(); - string result; + std::string result; absl::StrAppend(&result, prefix); while (!str.empty()) { @@ -100,8 +101,8 @@ static bool SplitAt(char split_ch, absl::string_view* orig, // Does this line start with "<field>:" where "<field>" is // in multi_line_fields? Sets *colon_pos to the position of the colon.
-static bool StartsWithFieldName(absl::string_view line, - const std::vector<string>& multi_line_fields) { +static bool StartsWithFieldName( + absl::string_view line, const std::vector<std::string>& multi_line_fields) { absl::string_view up_to_colon; if (!SplitAt(':', &line, &up_to_colon)) return false; while (absl::ConsumePrefix(&up_to_colon, " ")) @@ -115,8 +116,8 @@ static bool StartsWithFieldName(absl::string_view line, } static bool ConvertLine(absl::string_view line, - const std::vector<string>& multi_line_fields, - string* ml) { + const std::vector<std::string>& multi_line_fields, + std::string* ml) { // Is this a field we should convert? if (!StartsWithFieldName(line, multi_line_fields)) { return false; } @@ -140,7 +141,7 @@ static bool ConvertLine(absl::string_view line, absl::string_view suffix = after_colon.substr(last_quote + 1); // We've now parsed line into '<name>: "<escaped>"<suffix>' - string unescaped; + std::string unescaped; if (!absl::CUnescape(escaped, &unescaped, nullptr)) { // Error unescaping, abort the conversion. return false; } @@ -148,8 +149,8 @@ static bool ConvertLine(absl::string_view line, // No more errors possible at this point. // Find a string to mark the end that isn't in unescaped. - string end = "END"; - for (int s = 0; unescaped.find(end) != string::npos; ++s) { + std::string end = "END"; + for (int s = 0; unescaped.find(end) != std::string::npos; ++s) { end = absl::StrCat("END", s); } @@ -163,9 +164,10 @@ static bool ConvertLine(absl::string_view line, return true; } -string PBTxtToMultiline(absl::string_view pbtxt, - const std::vector<string>& multi_line_fields) { - string ml; +std::string PBTxtToMultiline( + absl::string_view pbtxt, + const std::vector<std::string>& multi_line_fields) { + std::string ml; // Probably big enough, since the input and output are about the // same size, but just a guess. ml.reserve(pbtxt.size() * (17.
/ 16)); @@ -184,20 +186,21 @@ string PBTxtToMultiline(absl::string_view pbtxt, // Given a single line of text `line` with first : at `colon`, determine if // there is an "<<END" expression after the colon. -static void StringReplace(const string& from, const string& to, string* s) { +static void StringReplace(const std::string& from, const std::string& to, + std::string* s) { // Split *s into pieces delimited by `from`. - std::vector<string> split; + std::vector<std::string> split; - string::size_type pos = 0; + std::string::size_type pos = 0; while (pos < s->size()) { auto found = s->find(from, pos); - if (found == string::npos) { + if (found == std::string::npos) { split.push_back(s->substr(pos)); break; } else { @@ -271,10 +275,10 @@ static void StringReplace(const string& from, const string& to, string* s) { *s = absl::StrJoin(split, to); } -static void RenameInDocs(const string& from, const string& to, +static void RenameInDocs(const std::string& from, const std::string& to, ApiDef* api_def) { - const string from_quoted = absl::StrCat("`", from, "`"); - const string to_quoted = absl::StrCat("`", to, "`"); + const std::string from_quoted = absl::StrCat("`", from, "`"); + const std::string to_quoted = absl::StrCat("`", to, "`"); for (int i = 0; i < api_def->in_arg_size(); ++i) { if (!api_def->in_arg(i).description().empty()) { StringReplace(from_quoted, to_quoted, @@ -480,17 +484,17 @@ ApiDefMap::ApiDefMap(const OpList& op_list) { ApiDefMap::~ApiDefMap() {} -absl::Status ApiDefMap::LoadFileList(Env* env, - const std::vector<string>& filenames) { +absl::Status ApiDefMap::LoadFileList( + Env* env, const std::vector<std::string>& filenames) { for (const auto& filename : filenames) { TF_RETURN_IF_ERROR(LoadFile(env, filename)); } return absl::OkStatus(); } -absl::Status ApiDefMap::LoadFile(Env* env, const string& filename) { +absl::Status ApiDefMap::LoadFile(Env* env, const std::string& filename) { if (filename.empty()) return absl::OkStatus(); - string contents; + std::string contents; TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); absl::Status status = LoadApiDef(contents); if (!status.ok()) { @@ -502,8 +506,8 @@ absl::Status ApiDefMap::LoadFile(Env* env, const string& filename) { return absl::OkStatus(); } -absl::Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { - const string contents = PBTxtFromMultiline(api_def_file_contents); +absl::Status ApiDefMap::LoadApiDef(const std::string& api_def_file_contents) { + const std::string contents = PBTxtFromMultiline(api_def_file_contents); ApiDefs api_defs; TF_RETURN_IF_ERROR( proto_utils::ParseTextFormatFromString(contents, &api_defs)); @@ -522,7 +526,7 @@ void ApiDefMap::UpdateDocs() { for (auto& name_and_api_def : map_) { auto& api_def = name_and_api_def.second; CHECK_GT(api_def.endpoint_size(), 0); - const string canonical_name = api_def.endpoint(0).name(); + const std::string canonical_name = api_def.endpoint(0).name(); if (api_def.graph_op_name() != canonical_name) { RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def); } @@ -544,7 +548,7 @@ void ApiDefMap::UpdateDocs() { } } -const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const { +const tensorflow::ApiDef* ApiDefMap::GetApiDef(const std::string& name) const { return gtl::FindOrNull(map_, name); } } // namespace tensorflow diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h index 27ffe522a6dd35..e5414c031abdca 100644 --- a/tensorflow/core/framework/op_gen_lib.h +++ b/tensorflow/core/framework/op_gen_lib.h @@ -29,13 +29,14 @@ namespace tensorflow { // Forward declare protos so their symbols can be removed from .so exports class OpDef; -inline string Spaces(int n) { return string(n, ' '); } +inline std::string Spaces(int n) { return std::string(n, ' '); } // Wrap prefix + str to be at
most width characters, indenting every line // after the first by prefix.size() spaces. Intended use case is something // like prefix = " Foo(" and str is a list of arguments (terminated by a ")"). // TODO(josh11b): Option to wrap on ", " instead of " " when possible. -string WordWrap(absl::string_view prefix, absl::string_view str, int width); +std::string WordWrap(absl::string_view prefix, absl::string_view str, + int width); // Looks for an "=" at the beginning of *description. If found, strips it off // (and any following spaces) from *description and return true. Otherwise @@ -43,9 +44,9 @@ string WordWrap(absl::string_view prefix, absl::string_view str, int width); bool ConsumeEquals(absl::string_view* description); // Convert text-serialized protobufs to/from multiline format. -string PBTxtToMultiline(absl::string_view pbtxt, - const std::vector<string>& multi_line_fields); -string PBTxtFromMultiline(absl::string_view multiline_pbtxt); +std::string PBTxtToMultiline(absl::string_view pbtxt, + const std::vector<std::string>& multi_line_fields); +std::string PBTxtFromMultiline(absl::string_view multiline_pbtxt); // Takes a list of files with ApiDefs text protos, and allows you to // look up the specific ApiDef for any given op. @@ -62,20 +63,21 @@ class ApiDefMap { // definitions take precedence. // ApiDefs loaded from files must contain a subset of ops defined // in the OpList passed to the constructor. - absl::Status LoadFileList(Env* env, const std::vector<string>& filenames); + absl::Status LoadFileList(Env* env, + const std::vector<std::string>& filenames); // Load a single file. Api definitions are merged if the same // op definition is loaded multiple times. Later-loaded // definitions take precedence. // ApiDefs loaded from file must contain a subset of ops defined // in the OpList passed to the constructor. - absl::Status LoadFile(Env* env, const string& filename); + absl::Status LoadFile(Env* env, const std::string& filename); // Load ApiDefs from string containing ApiDefs text proto. // api_def_file_contents is expected to be in "multiline format". // ApiDefs must contain a subset of ops defined in OpsList // passed to the constructor. - absl::Status LoadApiDef(const string& api_def_file_contents); + absl::Status LoadApiDef(const std::string& api_def_file_contents); // Updates ApiDef docs. For example, if ApiDef renames an argument // or attribute, applies these renames to descriptions as well. @@ -89,10 +91,10 @@ class ApiDefMap { // Note: Returned ApiDef pointer should stay valid even after calling // Load* functions defined above. Subsequent calls to Load* might modify // returned ApiDef contents, but should never remove the ApiDef itself. - const ApiDef* GetApiDef(const string& name) const; + const ApiDef* GetApiDef(const std::string& name) const; private: - std::unordered_map<string, ApiDef> map_; + std::unordered_map<std::string, ApiDef> map_; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/op_gen_lib_test.cc b/tensorflow/core/framework/op_gen_lib_test.cc index b08c77ca83221c..b06646d9fc51bd 100644 --- a/tensorflow/core/framework/op_gen_lib_test.cc +++ b/tensorflow/core/framework/op_gen_lib_test.cc @@ -72,7 +72,7 @@ END TEST(OpGenLibTest, MultilinePBTxt) { // Non-multiline pbtxt - const string pbtxt = R"(foo: "abc" + const std::string pbtxt = R"(foo: "abc" foo: "" foo: "\n\n" foo: "abc\nEND" bar: "quotes:\"" )"; // Field "foo" converted to multiline but not "bar". - const string ml_foo = R"(foo: <