Skip to content

Commit

Permalink
Merge pull request #3 from tensorflow/master
Browse files Browse the repository at this point in the history
update
  • Loading branch information
zhizunbao-y authored Sep 18, 2019
2 parents 0a1c1b9 + a5c121e commit f30ffc2
Show file tree
Hide file tree
Showing 1,825 changed files with 59,592 additions and 23,614 deletions.
4 changes: 2 additions & 2 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -1466,7 +1466,7 @@ See also [TensorBoard 0.1.4](https://github.com/tensorflow/tensorboard/releases/
* TensorForest multi-regression bug fix.
* Framework now supports armv7, cocoapods.org now displays correct page.
* Script to create iOS framework for CocoaPods.
* Android releases of TensorFlow are now pushed to jcenter for easier integration into apps. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md for more details.
* Android releases of TensorFlow are now pushed to jcenter for easier integration into apps. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/android/inference_interface/README.md for more details.
* TensorFlow Debugger (tfdbg):
* Fixed a bug that prevented tfdbg from functioning with multi-GPU setups.
* Fixed a bug that prevented tfdbg from working with `tf.Session.make_callable`.
Expand Down Expand Up @@ -1569,7 +1569,7 @@ answered questions, and were part of inspiring discussions.
* [`SavedModel CLI`](https://www.tensorflow.org/versions/master/guide/saved_model_cli) tool available to inspect and execute MetaGraph in SavedModel
* Android releases of TensorFlow are now pushed to jcenter for easier
integration into apps. See
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/android/README.md
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/android/inference_interface/README.md
for more details.

## Deprecations
Expand Down
96 changes: 53 additions & 43 deletions tensorflow/c/c_api.cc

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions tensorflow/c/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,15 @@ TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper,
// producer.index) to consumer.oper's input (given by consumer.index).
TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in);

// Get the list of all inputs of a specific operation. `inputs` must point to
// an array of length at least `max_inputs` (ideally set to
// TF_OperationNumInputs(oper)). Beware that a concurrent
// modification of the graph can increase the number of inputs of
// an operation.
TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper,
TF_Output* inputs,
int max_inputs);

// Get the number of current consumers of a specific output of an
// operation. Note that this number can change when new operations
// are added to the graph.
Expand Down
4 changes: 0 additions & 4 deletions tensorflow/c/c_api_experimental.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,6 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}

static void CheckOk(TF_Status* status) {
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}

void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
auto* status = TF_NewStatus();
if (!TFE_TensorHandleIsConcrete(handle)) {
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/c/eager/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ load(
)
load(
"//tensorflow/core/platform:default/build_config.bzl",
"tf_additional_device_tracer_test_flags",
"tf_kernel_tests_linkstatic",
)
load(
Expand All @@ -27,6 +26,7 @@ tf_cuda_library(
"c_api.cc",
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.cc",
"c_api_internal.h",
],
hdrs = ["c_api.h"],
Expand Down Expand Up @@ -237,8 +237,7 @@ tf_cuda_cc_test(
srcs = [
"c_api_experimental_test.cc",
],
args =
["--heap_check=local"] + tf_additional_device_tracer_test_flags(),
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
Expand Down
105 changes: 44 additions & 61 deletions tensorflow/c/eager/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/host_info.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
Expand Down Expand Up @@ -61,6 +61,7 @@ limitations under the License.
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
Expand Down Expand Up @@ -100,32 +101,34 @@ string DeviceName(const tensorflow::Device* d) {
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) {
tensorflow::Notification n;
tensorflow::mutex remote_devices_mu;
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
tensorflow::NewRemoteDevices(
tensorflow::Env::Default(), worker_cache, remote_worker,
[&status, &n, &remote_devices](
tensorflow::Env::Default(), worker_cache, remote_workers[i],
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
const tensorflow::Status& s,
std::vector<tensorflow::Device*>* devices) {
status = s;
statuses[i] = s;
if (s.ok()) {
tensorflow::mutex_lock l(remote_devices_mu);
for (tensorflow::Device* d : *devices) {
remote_devices.emplace_back(d);
}
}
n.Notify();
counter.DecrementCount();
});
n.WaitForNotification();
}
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
new tensorflow::StaticDeviceMgr(std::move(remote_devices)));

TF_RETURN_IF_ERROR(status);

counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
*device_mgr = std::move(remote_device_mgr);
return tensorflow::Status::OK();
}
Expand All @@ -135,11 +138,15 @@ tensorflow::Status CreateRemoteContexts(
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const tensorflow::eager::CreateContextRequest& base_request) {
for (int i = 0; i < remote_workers.size(); i++) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];

tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextResponse response;
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
Expand All @@ -159,16 +166,17 @@ tensorflow::Status CreateRemoteContexts(
return tensorflow::errors::Internal(
"Cannot find a client for the given target:", remote_worker);
}
tensorflow::Notification n;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
eager_client->CreateContextAsync(
&request, &response, [&status, &n](const tensorflow::Status& s) {
status = s;
n.Notify();
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
n.WaitForNotification();
TF_RETURN_IF_ERROR(status);
}
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
return tensorflow::Status::OK();
}
Expand Down Expand Up @@ -215,7 +223,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
remote_workers.end());

std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
std::unique_ptr<tensorflow::DynamicDeviceMgr> remote_device_mgr;
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
remote_workers, grpc_server->master_env()->worker_cache,
&remote_device_mgr));
Expand Down Expand Up @@ -247,7 +255,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
LOG_AND_RETURN_IF_ERROR(
CreateRemoteContexts(remote_workers, context_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor()->Async(), base_request));
ctx->context->Executor().Async(), base_request));

tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
Expand Down Expand Up @@ -564,7 +572,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), handle->Context()->Executor(),
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
return nullptr;
Expand Down Expand Up @@ -596,33 +604,8 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {

TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
const char* name = op_or_function_name; // Shorthand
const tensorflow::AttrTypeMap* types;
bool is_function = false;
status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
if (!status->status.ok()) {
return nullptr;
}
if (!is_function) {
const tensorflow::OpDef* op_def;
status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
if (!status->status.ok()) {
return nullptr;
}
return new TFE_Op(ctx, name, false, types,
new TFE_OpInferenceContext(op_def));
}
if (!ctx->context->FindFunctionByName(name)) {
status->status = tensorflow::errors::NotFound(
"'", name,
"' is neither a type of a primitive operation nor a name "
"of a function registered in binary running on ",
tensorflow::port::Hostname(),
". Make sure the operation or function is "
"registered in the binary running in this process.");
return nullptr;
}
return new TFE_Op(ctx, name, true, types, nullptr);
return NewOrResetOp(ctx, op_or_function_name, status,
/* op_to_reset= */ nullptr);
}

// Releases `op` and everything it owns. Deleting a null pointer is a no-op.
void TFE_DeleteOp(TFE_Op* op) { delete op; }
Expand Down Expand Up @@ -916,7 +899,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
ctx->context->Executor(),
&ctx->context->Executor(),
device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
Expand Down Expand Up @@ -967,7 +950,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
status->status = ctx->context->Executor()->WaitForAllPendingNodes();
status->status = ctx->context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
Expand All @@ -979,9 +962,9 @@ TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
TF_Status* status) {
TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
for (const auto& attr : func.attr()) {
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (!status->status.ok()) return nullptr;
}
return func_op;
}
Expand Down Expand Up @@ -1029,7 +1012,7 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
} break;
case tensorflow::AttrValue::kFunc: {
const auto func_op = GetFunc(ctx, default_value.func(), status);
if (TF_GetCode(status) != TF_OK) return;
if (!status->status.ok()) return;
// TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
// require TFE_Op* and just convert it internally a NameAttrValue, so
// consider adding an overload to the C API to make this case easier.
Expand Down
12 changes: 11 additions & 1 deletion tensorflow/c/eager/c_api_experimental.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ limitations under the License.

using tensorflow::string;

// Re-initializes `op_to_reset` in place for a new op or function named
// `op_or_function_name`, reusing the existing TFE_Op allocation instead of
// creating a fresh one. `op_to_reset` must be non-null; passing nullptr
// reports TF_INVALID_ARGUMENT via `status`.
void TFE_OpReset(TFE_Context* ctx, const char* op_or_function_name,
                 TF_Status* status, TFE_Op* op_to_reset) {
  if (op_to_reset == nullptr) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "op_to_reset should not be nullptr");
    return;
  }
  NewOrResetOp(ctx, op_or_function_name, status, op_to_reset);
}

void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
op->operation.ConsumeInput(h->handle);
}
Expand Down Expand Up @@ -597,5 +607,5 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
}

TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(ctx->context->Executor());
return new TFE_Executor(&ctx->context->Executor());
}
4 changes: 4 additions & 0 deletions tensorflow/c/eager/c_api_experimental.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ limitations under the License.
extern "C" {
#endif

TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Context* ctx,
const char* op_or_function_name,
TF_Status* status, TFE_Op* op_to_reset);

TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);

Expand Down
58 changes: 58 additions & 0 deletions tensorflow/c/eager/c_api_internal.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_internal.h"

#include "tensorflow/core/platform/host_info.h"

// Creates a TFE_Op for `op_or_function_name`, or — when `op_to_reset` is
// non-null — re-initializes that existing op in place instead of allocating.
// On failure sets `status` and returns nullptr. Failure modes: the name maps
// to neither a registered primitive op nor a function known to `ctx`.
TFE_Op* NewOrResetOp(TFE_Context* ctx, const char* op_or_function_name,
                     TF_Status* status, TFE_Op* op_to_reset) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  bool is_function = false;
  status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
  if (!status->status.ok()) return nullptr;

  // Resolve the inference context (primitive ops only) before touching
  // `op_to_reset`, so a lookup failure leaves the caller's op untouched.
  TFE_OpInferenceContext* inference_ctx = nullptr;
  if (!is_function) {
    const tensorflow::OpDef* op_def;
    status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
    if (!status->status.ok()) return nullptr;
    inference_ctx = new TFE_OpInferenceContext(op_def);
  } else if (!ctx->context->FindFunctionByName(name)) {
    status->status = tensorflow::errors::NotFound(
        "'", name,
        "' is neither a type of a primitive operation nor a name "
        "of a function registered in binary running on ",
        tensorflow::port::Hostname(),
        ". Make sure the operation or function is "
        "registered in the binary running in this process.");
    return nullptr;
  }

  if (op_to_reset != nullptr) {
    op_to_reset->Reset(ctx, name, is_function, types, inference_ctx);
    return op_to_reset;
  }
  return new TFE_Op(ctx, name, is_function, types, inference_ctx);
}
Loading

0 comments on commit f30ffc2

Please sign in to comment.