Skip to content

Commit

Permalink
Merge pull request #1 from tensorflow/master
Browse files Browse the repository at this point in the history
sync w/ master
  • Loading branch information
plopresti authored May 20, 2019
2 parents 089c7ec + 935ac2b commit 7d1ba6d
Show file tree
Hide file tree
Showing 160 changed files with 31,509 additions and 20,932 deletions.
25 changes: 24 additions & 1 deletion tensorflow/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ filegroup(
"c_api.h",
"c_api_experimental.h",
"tf_attrtype.h",
"tf_status.h",
],
visibility = ["//tensorflow:__subpackages__"],
)
Expand Down Expand Up @@ -52,6 +53,7 @@ tf_cuda_library(
hdrs = [
"c_api.h",
"c_api_internal.h",
"tf_status.h",
],
visibility = [
"//tensorflow:internal",
Expand Down Expand Up @@ -84,6 +86,7 @@ tf_cuda_library(
hdrs = [
"c_api.h",
"tf_attrtype.h",
"tf_status.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
Expand All @@ -106,7 +109,9 @@ tf_cuda_library(
"c_api.cc",
"c_api_function.cc",
],
hdrs = ["c_api.h"],
hdrs = [
"c_api.h",
],
copts = tf_copts(),
visibility = ["//tensorflow/c:__subpackages__"],
deps = [
Expand All @@ -117,6 +122,7 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":tf_status",
"@com_google_absl//absl/strings",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients",
Expand All @@ -137,6 +143,22 @@ tf_cuda_library(
}),
)

cc_library(
name = "tf_status",
srcs = ["tf_status.cc"],
hdrs = ["tf_status.h"],
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/c:c_api_internal",
"//tensorflow/core:lib",
],
}),
)

tf_cuda_library(
name = "c_api_experimental",
srcs = [
Expand All @@ -150,6 +172,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
":checkpoint_reader",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/compiler/jit:flags",
Expand Down
22 changes: 0 additions & 22 deletions tensorflow/c/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,28 +116,6 @@ size_t TF_DataTypeSize(TF_DataType dt) {

// --------------------------------------------------------------------------

TF_Status* TF_NewStatus() { return new TF_Status; }

void TF_DeleteStatus(TF_Status* s) { delete s; }

void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
if (code == TF_OK) {
s->status = Status::OK();
return;
}
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}

TF_Code TF_GetCode(const TF_Status* s) {
return static_cast<TF_Code>(s->status.code());
}

const char* TF_Message(const TF_Status* s) {
return s->status.error_message().c_str();
}

// --------------------------------------------------------------------------

namespace {
class TF_ManagedBuffer : public TensorBuffer {
public:
Expand Down
48 changes: 1 addition & 47 deletions tensorflow/c/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>

#include "tensorflow/c/tf_attrtype.h"
#include "tensorflow/c/tf_status.h"

// --------------------------------------------------------------------------
// C API for TensorFlow.
Expand Down Expand Up @@ -130,53 +131,6 @@ typedef enum TF_DataType {
// (eg. TF_STRING) or on failure.
TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt);

// --------------------------------------------------------------------------
// TF_Code holds an error code. The enum values here are identical to
// corresponding values in error_codes.proto.
typedef enum TF_Code {
TF_OK = 0,
TF_CANCELLED = 1,
TF_UNKNOWN = 2,
TF_INVALID_ARGUMENT = 3,
TF_DEADLINE_EXCEEDED = 4,
TF_NOT_FOUND = 5,
TF_ALREADY_EXISTS = 6,
TF_PERMISSION_DENIED = 7,
TF_UNAUTHENTICATED = 16,
TF_RESOURCE_EXHAUSTED = 8,
TF_FAILED_PRECONDITION = 9,
TF_ABORTED = 10,
TF_OUT_OF_RANGE = 11,
TF_UNIMPLEMENTED = 12,
TF_INTERNAL = 13,
TF_UNAVAILABLE = 14,
TF_DATA_LOSS = 15,
} TF_Code;

// --------------------------------------------------------------------------
// TF_Status holds error information. It either has an OK code, or
// else an error code with an associated error message.
typedef struct TF_Status TF_Status;

// Return a new status object.
TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void);

// Delete a previously created status object.
TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*);

// Record <code, msg> in *s. Any previous information is lost.
// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code,
const char* msg);

// Return the code record in *s.
TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s);

// Return a pointer to the (null-terminated) error message in *s. The
// return value points to memory that is only usable until the next
// mutation to *s. Always returns an empty string if TF_GetCode(s) is
// TF_OK.
TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s);

// --------------------------------------------------------------------------
// TF_Buffer holds a pointer to a block of data and its associated length.
Expand Down
69 changes: 69 additions & 0 deletions tensorflow/c/c_api_experimental.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "absl/strings/substitute.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
Expand All @@ -37,6 +38,7 @@ using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;

namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
Expand Down Expand Up @@ -576,6 +578,73 @@ void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}

struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
std::vector<std::string> variable_list;
};

TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
TF_Status* status) {
TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
if (!status->status.ok()) return nullptr;
const auto& m = reader->GetVariableToDataTypeMap();
for (auto it = m.begin(); it != m.end(); ++it)
reader->variable_list.push_back(it->first);
std::sort(reader->variable_list.begin(), reader->variable_list.end());
return reader;
}

void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
const char* name) {
return reader->HasTensor(name);
}

const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
int index) {
return reader->variable_list[index].c_str();
}

int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
return reader->variable_list.size();
}

TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
const char* name) {
const auto& m = reader->GetVariableToDataTypeMap();
return static_cast<TF_DataType>(m.at(name));
}

TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
const char* name, TF_Status* status) {
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, status);
if (!status->status.ok()) return nullptr;
return tensorflow::TF_TensorFromTensor(*tensor.get(), status);
}

void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
const char* name, int64_t* dims,
int num_dims, TF_Status* status) {
const auto& shape = reader->GetVariableToShapeMap().at(name);
int rank = shape.dims();
if (num_dims != rank) {
status->status = InvalidArgument("Expected rank is ", num_dims,
" but actual rank is ", rank);
return;
}
for (int i = 0; i < num_dims; i++) {
dims[i] = shape.dim_size(i);
}
}

int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
const char* name) {
const auto& m = reader->GetVariableToShapeMap();
return m.at(name).dims();
}

// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
using tensorflow::AttrBuilder::AttrBuilder;
Expand Down
28 changes: 28 additions & 0 deletions tensorflow/c/c_api_experimental.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,34 @@ TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);

// TF_NewCheckpointReader() return the CheckpointReader that can be use to
// investigate or load the variable from the checkpoint file
typedef struct TF_CheckpointReader TF_CheckpointReader;
TF_CAPI_EXPORT extern TF_CheckpointReader* TF_NewCheckpointReader(
const char* filename, TF_Status* status);
TF_CAPI_EXPORT extern void TF_DeleteCheckpointReader(
TF_CheckpointReader* reader);
TF_CAPI_EXPORT extern int TF_CheckpointReaderHasTensor(
TF_CheckpointReader* reader, const char* name);
// Get the variable name at the given index
TF_CAPI_EXPORT extern const char* TF_CheckpointReaderGetVariable(
TF_CheckpointReader* reader, int index);
// Get the number of variable in the checkpoint
TF_CAPI_EXPORT extern int TF_CheckpointReaderSize(TF_CheckpointReader* reader);
// Get the DataType of a variable
TF_CAPI_EXPORT extern TF_DataType TF_CheckpointReaderGetVariableDataType(
TF_CheckpointReader* reader, const char* name);
// Read the shape of a variable and write to `dims`
TF_CAPI_EXPORT extern void TF_CheckpointReaderGetVariableShape(
TF_CheckpointReader* reader, const char* name, int64_t* dims, int num_dims,
TF_Status* status);
// Get the number of dimension of a variable
TF_CAPI_EXPORT extern int TF_CheckpointReaderGetVariableNumDims(
TF_CheckpointReader* reader, const char* name);
// Load the weight of a variable
TF_CAPI_EXPORT extern TF_Tensor* TF_CheckpointReaderGetTensor(
TF_CheckpointReader* reader, const char* name, TF_Status* status);

// TF_NewAttrBuilder() returns an object that you can set attributes on as
// though it were an op. This allows querying properties of that op for
// type-checking purposes like if the op will run on a particular device type.
Expand Down
42 changes: 42 additions & 0 deletions tensorflow/c/tf_status.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/tf_status.h"

#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/lib/core/status.h"

using ::tensorflow::Status;
using ::tensorflow::error::Code;

TF_Status* TF_NewStatus() { return new TF_Status; }

void TF_DeleteStatus(TF_Status* s) { delete s; }

void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
if (code == TF_OK) {
s->status = Status::OK();
return;
}
s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}

TF_Code TF_GetCode(const TF_Status* s) {
return static_cast<TF_Code>(s->status.code());
}

const char* TF_Message(const TF_Status* s) {
return s->status.error_message().c_str();
}
Loading

0 comments on commit 7d1ba6d

Please sign in to comment.