-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
441 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,6 @@ | |
*.ptx | ||
*.cubin | ||
*.fatbin | ||
|
||
build/ | ||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
cmake_minimum_required(VERSION 3.5) | ||
|
||
#get the include directory for tensorflow | ||
#execute_process(COMMAND python3 -c "import tensorflow as tf; print(tf.sysconfig.get_include())" OUTPUT_VARIABLE Tensorflow_INCLUDE_DIRS) | ||
execute_process(COMMAND python3 -c "import tensorflow as tf; print(tf.sysconfig.get_include(), end='')" OUTPUT_VARIABLE Tensorflow_INCLUDE_DIRS) | ||
execute_process(COMMAND python3 -c "import tensorflow as tf; print(tf.sysconfig.get_lib(), end='')" OUTPUT_VARIABLE Tensorflow_LIB_DIRS) | ||
|
||
message("tensorflow include dir: ${Tensorflow_INCLUDE_DIRS}") | ||
message("tensorflow link dir: ${Tensorflow_LIB_DIRS}") | ||
|
||
include_directories(${Tensorflow_INCLUDE_DIRS}) | ||
include_directories("/usr/local/") | ||
link_directories(${Tensorflow_LIB_DIRS}) | ||
|
||
find_package(CUDA) | ||
|
||
#set flags based on tutorial | ||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11 -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 -D GOOGLE_CUDA=1 -DNDEBUG") | ||
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --expt-relaxed-constexpr" ) | ||
|
||
set(CMAKE_BUILD_TYPE Debug) | ||
|
||
#pass flags to c++ compiler | ||
SET(CUDA_PROPAGATE_HOST_FLAGS ON) | ||
|
||
include_directories(include) | ||
|
||
#create library | ||
cuda_add_library( | ||
cholesky_update SHARED | ||
src/cholesky_update.cu | ||
src/cholesky_update.cc | ||
) | ||
|
||
target_link_libraries(cholesky_update "tensorflow_framework") | ||
|
||
#copy python files to build folder (for easy testing) | ||
file(GLOB PY_FILES | ||
"src/*.py" | ||
) | ||
file(COPY ${PY_FILES} DESTINATION .) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
|
||
try: | ||
_tutorial = tf.load_op_library('./build/libcholesky_update.so') | ||
except Exception as e: | ||
_tutorial = tf.load_op_library('./libcholesky_update.so') | ||
chol_update = _tutorial.chol_update | ||
|
||
n = 100 | ||
m = 10 | ||
k = 19 | ||
|
||
chol = [np.eye(k)*1e-5 for _ in range(m)] | ||
data = np.random.randint(0, 10, (n,m,k)) | ||
mean = np.mean(data,0) | ||
|
||
R = tf.placeholder(tf.float32, shape=[m,k,k], name="R") | ||
x = tf.placeholder(tf.float32, shape=[m,k], name="x") | ||
|
||
update_op = chol_update(R,x) | ||
#print(update_op) | ||
|
||
config = tf.ConfigProto(log_device_placement = True) | ||
#config.graph_options.optimizer_options.opt_level = -1 | ||
|
||
with tf.Session(config=config) as sess: | ||
for i in range(n): | ||
feed = {R: chol, x: data[i] - mean} | ||
chol = sess.run(update_op, feed_dict=feed) | ||
|
||
expected = np.mean([np.cov(data[:,i,:].T) for i in range(m)],0) | ||
result = np.mean([np.matmul(c.T, c) for c in chol],0)/(n-1) | ||
|
||
abs_diff = np.abs(expected - result) | ||
|
||
print("max:", np.max(abs_diff)) | ||
print("mean:", np.mean(abs_diff)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#pragma once | ||
#define EIGEN_USE_GPU | ||
|
||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" | ||
#include <unsupported/Eigen/CXX11/Tensor> | ||
#include <unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h> | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
#include "tensorflow/core/framework/tensor_types.h" | ||
#include "tensorflow/core/platform/types.h" | ||
|
||
#define RIND(b,x,y) ((b)*dim + (y))*dim + (x) | ||
#define XIND(b,x) (b)*dim + (x) | ||
|
||
using namespace tensorflow; | ||
|
||
template <typename Device, typename dtype> | ||
struct launchCholUpdateKernel { | ||
void operator()(const Device& d, | ||
typename TTypes<dtype>::Flat output, typename TTypes<dtype>::Flat x_workspace, | ||
const typename TTypes<dtype>::ConstFlat R, const typename TTypes<dtype>::ConstFlat x, | ||
int batch_size, int dim); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
#include "tensorflow/core/framework/op.h" | ||
#include "tensorflow/core/framework/shape_inference.h" | ||
#include "tensorflow/core/framework/op_kernel.h" | ||
#include "tensorflow/core/platform/default/integral_types.h" | ||
#include "tensorflow/core/util/tensor_format.h" | ||
#include "tensorflow/core/util/padding.h" | ||
|
||
#include <iostream> | ||
#include <cuda.h> | ||
|
||
#include "cholesky_update.h" | ||
|
||
using namespace tensorflow; | ||
using namespace std; | ||
using namespace shape_inference; | ||
|
||
using CPUDevice = Eigen::ThreadPoolDevice; | ||
using GPUDevice = Eigen::GpuDevice; | ||
|
||
//For now only accept 3 and 2 for rank | ||
//i.e. batch of matrices and batch of vectors | ||
Status ShapeFn(InferenceContext* c) | ||
{ | ||
//check input shape has 3 dimensions (batch, d, d) | ||
ShapeHandle r_shape; | ||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &r_shape)); | ||
|
||
//check indices has 2 dimensions (batch, d) | ||
ShapeHandle x_shape; | ||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &x_shape)); | ||
|
||
int r_rank = c->Rank(r_shape); | ||
int x_rank = c->Rank(x_shape); | ||
//R must be square | ||
if (c->Value(c->Dim(r_shape,r_rank - 1)) != c->Value(c->Dim(r_shape, r_rank - 2))) | ||
return errors::InvalidArgument("a must be square"); | ||
|
||
//R must match shape of xx^T | ||
for(int i = 0; i < 2; i++) | ||
{ | ||
DimensionHandle r_dim = c->Dim(r_shape,i); | ||
DimensionHandle x_dim = c->Dim(x_shape,i); | ||
|
||
if (c->Value(r_dim) != c->Value(x_dim)) | ||
return errors::InvalidArgument( | ||
"R and x must have same dims"); | ||
} | ||
|
||
//set output size | ||
c->set_output(0, c->input(0)); | ||
|
||
return Status::OK(); | ||
} | ||
|
||
/** | ||
* register the operation with necessary options | ||
*/ | ||
REGISTER_OP("CholUpdate") | ||
.Input("r: T") | ||
.Input("x: T") | ||
.Output("c: T") | ||
.Attr(GetConvnetDataFormatAttrString()) | ||
.Attr("T: {int32, float32, float64}") | ||
.Attr("use_locking: bool = true") | ||
.SetShapeFn(ShapeFn); | ||
|
||
// int to_ind(int b, int h, int w, int batch, int dim) | ||
// { | ||
// return (b*dim + h)*dim + w; | ||
// } | ||
|
||
template <typename dtype> | ||
struct launchCholUpdateKernel<CPUDevice, dtype> { | ||
void operator()(const CPUDevice& d, | ||
typename TTypes<dtype>::Flat R, typename TTypes<dtype>::Flat x, | ||
typename TTypes<dtype>::ConstFlat R_in, typename TTypes<dtype>::ConstFlat x_in, | ||
int batch_size, int dim) { | ||
//based on https://stackoverflow.com/a/16160905/1097517 | ||
R.setZero(); | ||
R += R_in; | ||
x.setZero(); | ||
x += x_in; | ||
|
||
for(int b = 0; b < batch_size; b++) | ||
{ | ||
for(int k = 0; k < dim; k++) | ||
{ | ||
dtype Rkk = R(RIND(b,k,k)); | ||
dtype xk = x(XIND(b,k)); | ||
|
||
dtype r = sqrt(Rkk*Rkk + xk*xk); | ||
dtype c = r/Rkk; | ||
dtype s = xk/Rkk; | ||
R(RIND(b,k,k)) = r; | ||
for(int i = k+1; i < dim; i++) | ||
{ | ||
R(RIND(b,i,k)) = (R(RIND(b,i,k)) + s*x(XIND(b,i)))/c; | ||
x(XIND(b,i)) = c*x(XIND(b,i)) - s*R(RIND(b,i,k)); | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
template <typename Device, typename dtype> | ||
class CholUpdateOp : public OpKernel { | ||
public: | ||
|
||
explicit CholUpdateOp(OpKernelConstruction* context) | ||
: OpKernel(context) | ||
{ | ||
string data_format; | ||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); | ||
OP_REQUIRES(context, FormatFromString(data_format, &data_format_), | ||
errors::InvalidArgument("Invalid data format")); | ||
|
||
//only nhwc supported | ||
OP_REQUIRES(context, data_format_ == FORMAT_NHWC, | ||
errors::InvalidArgument("CholUpdate only supports NHWC ", | ||
"on device type ", | ||
DeviceTypeString(context->device_type()))); | ||
|
||
// const DataType dt = DataTypeToEnum<dtype>::v(); | ||
// OP_REQUIRES_OK(context, | ||
// context->MatchSignature({MakeRefType(dt), dt}, | ||
// {MakeRefType(dt)})); | ||
|
||
// OP_REQUIRES_OK(context, | ||
// context->GetAttr("use_locking", &use_exclusive_lock_)); | ||
|
||
} | ||
|
||
void Compute(OpKernelContext* context) override { | ||
// Grab the input tensor | ||
const Tensor& r_tensor = context->input(0); | ||
OP_REQUIRES(context, r_tensor.IsInitialized(), | ||
errors::FailedPrecondition("Attempting to use uninitialized " | ||
"parameters: ", | ||
requested_input(0))); | ||
const Tensor& x_tensor = context->input(1); | ||
|
||
Tensor x_workspace; | ||
OP_REQUIRES_OK(context, context->allocate_temp( | ||
DataTypeToEnum<dtype>::v(), | ||
x_tensor.shape(), &x_workspace)); | ||
|
||
int batch_size = GetTensorDim(r_tensor.shape(), data_format_, 'N'); | ||
int dim = GetTensorDim(r_tensor.shape(), data_format_, 'H'); | ||
|
||
//flatten tensors | ||
auto x_work_flat = x_workspace.flat<dtype>(); | ||
auto r_flat = r_tensor.flat<dtype>(); | ||
auto x_flat = x_tensor.flat<dtype>(); | ||
|
||
// std::cout << "here" << __LINE__ << std::endl; | ||
|
||
// x_work_flat.device(context->eigen_device<Device>()).setZero(); | ||
// std::cout << "here" << __LINE__ << std::endl; | ||
// x_work_flat += x_flat; //how do you just copy? | ||
// std::cout << "here" << __LINE__ << std::endl; | ||
|
||
Tensor* output_tensor = nullptr; | ||
OP_REQUIRES_OK(context, | ||
context->allocate_output(0, | ||
r_tensor.shape(),&output_tensor)); | ||
|
||
//auto output = context->output(0); | ||
//out->template flat<Scalar>() | ||
auto output_flat = output_tensor->flat<dtype>(); | ||
// std::cout << "here" << __LINE__ << std::endl; | ||
// output.setZero(); | ||
// std::cout << "here" << __LINE__ << std::endl; | ||
// output = r_flat; //how do you just copy? | ||
// //const int N = output.size(); | ||
// std::cout << "here" << __LINE__ << std::endl; | ||
|
||
// Call the cuda kernel launcher | ||
launchCholUpdateKernel<Device, dtype>()( | ||
context->eigen_device<Device>(), | ||
output_flat, | ||
x_work_flat, | ||
r_flat, | ||
x_flat, | ||
batch_size, dim); | ||
} | ||
private: | ||
TensorFormat data_format_; | ||
bool use_exclusive_lock_; | ||
}; | ||
|
||
//register kernel with types needed | ||
#define REGISTER_GPU(type) \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name("CholUpdate") \ | ||
.Device(DEVICE_GPU) \ | ||
.TypeConstraint<type>("T"), \ | ||
CholUpdateOp<GPUDevice, type>) \ | ||
|
||
REGISTER_GPU(float); | ||
// REGISTER_GPU(double); | ||
|
||
#undef REGISTER_GPU | ||
|
||
#define REGISTER_CPU(type) \ | ||
REGISTER_KERNEL_BUILDER( \ | ||
Name("CholUpdate") \ | ||
.Device(DEVICE_CPU) \ | ||
.TypeConstraint<type>("T"), \ | ||
CholUpdateOp<CPUDevice, type>) \ | ||
|
||
REGISTER_CPU(float); | ||
// REGISTER_CPU(double); | ||
|
||
#undef REGISTER_CPU |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#define EIGEN_USE_GPU | ||
#include <cuda.h> | ||
#include <stdio.h> | ||
#include "tensorflow/core/util/cuda_device_functions.h" | ||
|
||
#include "cholesky_update.h" | ||
|
||
using GPUDevice = Eigen::GpuDevice; | ||
|
||
template <typename dtype> | ||
__global__ void CholUpdateKernel(dtype* R, dtype* x, int batch_size, int dim){ | ||
for(int b : tensorflow::CudaGridRangeX(batch_size)) | ||
{ | ||
for(int k = 0; k < dim; k++) | ||
{ | ||
dtype Rkk = R[RIND(b,k,k)]; | ||
dtype xk = x[XIND(b,k)]; | ||
|
||
dtype r = sqrt(Rkk*Rkk + xk*xk); | ||
dtype c = r/Rkk; | ||
dtype s = xk/Rkk; | ||
R[RIND(b,k,k)] = r; | ||
for(int i = k+1; i < dim; i++) | ||
{ | ||
R[RIND(b,i,k)] = (R[RIND(b,i,k)] + s*x[XIND(b,i)])/c; | ||
x[XIND(b,i)] = c*x[XIND(b,i)] - s*R[RIND(b,i,k)]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
template <typename dtype> | ||
struct launchCholUpdateKernel<GPUDevice, dtype> { | ||
void operator()(const GPUDevice& d, | ||
typename TTypes<dtype>::Flat R, typename TTypes<dtype>::Flat x, | ||
typename TTypes<dtype>::ConstFlat R_in, typename TTypes<dtype>::ConstFlat x_in, | ||
int batch_size, int dim) { | ||
|
||
To32Bit(R).device(d) = To32Bit(R_in); | ||
To32Bit(x).device(d) = To32Bit(x_in); | ||
|
||
const int kThreadsPerBlock = 1024; | ||
|
||
CholUpdateKernel<dtype><<<(dim*batch_size + kThreadsPerBlock - 1) / kThreadsPerBlock, | ||
kThreadsPerBlock, 0, d.stream()>>>( | ||
R.data(), x.data(), batch_size, dim); | ||
|
||
cudaError_t cudaerr = cudaDeviceSynchronize(); | ||
if (cudaerr != cudaSuccess) | ||
printf("kernel launch failed with error \"%s\".\n", | ||
cudaGetErrorString(cudaerr)); | ||
} | ||
}; | ||
|
||
//forward declaration for all the types needed | ||
typedef Eigen::GpuDevice GPUDevice; | ||
#define ADD_KERNEL_TYPE(type) \ | ||
template struct launchCholUpdateKernel<GPUDevice, type>; \ | ||
|
||
ADD_KERNEL_TYPE(float); | ||
ADD_KERNEL_TYPE(double); | ||
|
||
#undef ADD_KERNEL_TYPE |
Oops, something went wrong.