Commit

initial commit
mattangus committed Nov 21, 2018
1 parent 245c7ca commit 4991bfc
Showing 7 changed files with 441 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -4,3 +4,6 @@
*.ptx
*.cubin
*.fatbin

build/
.vscode/
41 changes: 41 additions & 0 deletions CMakeLists.txt
@@ -0,0 +1,41 @@
cmake_minimum_required(VERSION 3.5)

#get the include directory for tensorflow
#execute_process(COMMAND python3 -c "import tensorflow as tf; print(tf.sysconfig.get_include())" OUTPUT_VARIABLE Tensorflow_INCLUDE_DIRS)
execute_process(COMMAND python3 -c "import tensorflow as tf; print(tf.sysconfig.get_include(), end='')" OUTPUT_VARIABLE Tensorflow_INCLUDE_DIRS)
execute_process(COMMAND python3 -c "import tensorflow as tf; print(tf.sysconfig.get_lib(), end='')" OUTPUT_VARIABLE Tensorflow_LIB_DIRS)

message("tensorflow include dir: ${Tensorflow_INCLUDE_DIRS}")
message("tensorflow link dir: ${Tensorflow_LIB_DIRS}")

include_directories(${Tensorflow_INCLUDE_DIRS})
include_directories("/usr/local/")
link_directories(${Tensorflow_LIB_DIRS})

find_package(CUDA)

#set flags based on tutorial
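#note (assumption): -D_GLIBCXX_USE_CXX11_ABI=0 below matches TensorFlow pip wheels of
#this era, which were built with the pre-C++11 libstdc++ ABI; a TensorFlow built from
#source with the new ABI would need that define removed.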
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11 -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 -D GOOGLE_CUDA=1 -DNDEBUG")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --expt-relaxed-constexpr" )

set(CMAKE_BUILD_TYPE Debug)

#pass flags to c++ compiler
SET(CUDA_PROPAGATE_HOST_FLAGS ON)

include_directories(include)

#create library
cuda_add_library(
cholesky_update SHARED
src/cholesky_update.cu
src/cholesky_update.cc
)

target_link_libraries(cholesky_update "tensorflow_framework")

#copy python files to build folder (for easy testing)
file(GLOB PY_FILES
"src/*.py"
)
file(COPY ${PY_FILES} DESTINATION .)
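With this configuration, an out-of-source build (for example from a directory named build/, which is what chol_as_cov.py assumes when it loads './build/libcholesky_update.so') should produce libcholesky_update.so alongside copies of the src/*.py scripts.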
38 changes: 38 additions & 0 deletions chol_as_cov.py
@@ -0,0 +1,38 @@
import tensorflow as tf
import numpy as np

try:
    _tutorial = tf.load_op_library('./build/libcholesky_update.so')
except Exception as e:
    _tutorial = tf.load_op_library('./libcholesky_update.so')
chol_update = _tutorial.chol_update

n = 100
m = 10
k = 19

chol = [np.eye(k)*1e-5 for _ in range(m)]
data = np.random.randint(0, 10, (n,m,k))
mean = np.mean(data,0)

R = tf.placeholder(tf.float32, shape=[m,k,k], name="R")
x = tf.placeholder(tf.float32, shape=[m,k], name="x")

update_op = chol_update(R,x)
#print(update_op)

config = tf.ConfigProto(log_device_placement = True)
#config.graph_options.optimizer_options.opt_level = -1

with tf.Session(config=config) as sess:
    for i in range(n):
        feed = {R: chol, x: data[i] - mean}
        chol = sess.run(update_op, feed_dict=feed)

expected = np.mean([np.cov(data[:,i,:].T) for i in range(m)],0)
result = np.mean([np.matmul(c.T, c) for c in chol],0)/(n-1)

abs_diff = np.abs(expected - result)

print("max:", np.max(abs_diff))
print("mean:", np.mean(abs_diff))
22 changes: 22 additions & 0 deletions include/cholesky_update.h
@@ -0,0 +1,22 @@
#pragma once
#define EIGEN_USE_GPU

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include <unsupported/Eigen/CXX11/Tensor>
#include <unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

//flat row-major index of element [b, y, x] in a [batch, dim, dim] tensor
#define RIND(b,x,y) ((b)*dim + (y))*dim + (x)
//flat index of element [b, x] in a [batch, dim] tensor
#define XIND(b,x) (b)*dim + (x)

using namespace tensorflow;

template <typename Device, typename dtype>
struct launchCholUpdateKernel {
void operator()(const Device& d,
typename TTypes<dtype>::Flat output, typename TTypes<dtype>::Flat x_workspace,
const typename TTypes<dtype>::ConstFlat R, const typename TTypes<dtype>::ConstFlat x,
int batch_size, int dim);
};
214 changes: 214 additions & 0 deletions src/cholesky_update.cc
@@ -0,0 +1,214 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/default/integral_types.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/padding.h"

#include <iostream>
#include <cuda.h>

#include "cholesky_update.h"

using namespace tensorflow;
using namespace std;
using namespace shape_inference;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

//For now only accept 3 and 2 for rank
//i.e. batch of matrices and batch of vectors
Status ShapeFn(InferenceContext* c)
{
//check input shape has 3 dimensions (batch, d, d)
ShapeHandle r_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &r_shape));

//check indices has 2 dimensions (batch, d)
ShapeHandle x_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &x_shape));

int r_rank = c->Rank(r_shape);
int x_rank = c->Rank(x_shape);
//R must be square
if (c->Value(c->Dim(r_shape,r_rank - 1)) != c->Value(c->Dim(r_shape, r_rank - 2)))
return errors::InvalidArgument("a must be square");

//batch size and dimension of R must match x
for(int i = 0; i < 2; i++)
{
DimensionHandle r_dim = c->Dim(r_shape,i);
DimensionHandle x_dim = c->Dim(x_shape,i);

if (c->Value(r_dim) != c->Value(x_dim))
return errors::InvalidArgument(
"R and x must have same dims");
}

//set output size
c->set_output(0, c->input(0));

return Status::OK();
}

/**
* register the operation with necessary options
*/
REGISTER_OP("CholUpdate")
.Input("r: T")
.Input("x: T")
.Output("c: T")
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {int32, float32, float64}")
.Attr("use_locking: bool = true")
.SetShapeFn(ShapeFn);

// int to_ind(int b, int h, int w, int batch, int dim)
// {
// return (b*dim + h)*dim + w;
// }

template <typename dtype>
struct launchCholUpdateKernel<CPUDevice, dtype> {
void operator()(const CPUDevice& d,
typename TTypes<dtype>::Flat R, typename TTypes<dtype>::Flat x,
typename TTypes<dtype>::ConstFlat R_in, typename TTypes<dtype>::ConstFlat x_in,
int batch_size, int dim) {
//based on https://stackoverflow.com/a/16160905/1097517
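//rank-one update: given upper-triangular R with A = R^T R and a vector x,
//this produces R' with R'^T R' = A + x x^T. For each k:
//  r = sqrt(R_kk^2 + x_k^2), c = r/R_kk, s = x_k/R_kk, R_kk <- r
//  then for i > k: R_ki <- (R_ki + s*x_i)/c, followed by x_i <- c*x_i - s*R_ki
//(the same sweep as the CUDA kernel in src/cholesky_update.cu)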
R.setZero();
R += R_in;
x.setZero();
x += x_in;

for(int b = 0; b < batch_size; b++)
{
for(int k = 0; k < dim; k++)
{
dtype Rkk = R(RIND(b,k,k));
dtype xk = x(XIND(b,k));

dtype r = sqrt(Rkk*Rkk + xk*xk);
dtype c = r/Rkk;
dtype s = xk/Rkk;
R(RIND(b,k,k)) = r;
for(int i = k+1; i < dim; i++)
{
R(RIND(b,i,k)) = (R(RIND(b,i,k)) + s*x(XIND(b,i)))/c;
x(XIND(b,i)) = c*x(XIND(b,i)) - s*R(RIND(b,i,k));
}
}
}
}
};

template <typename Device, typename dtype>
class CholUpdateOp : public OpKernel {
public:

explicit CholUpdateOp(OpKernelConstruction* context)
: OpKernel(context)
{
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));

//only nhwc supported
OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
errors::InvalidArgument("CholUpdate only supports NHWC ",
"on device type ",
DeviceTypeString(context->device_type())));

// const DataType dt = DataTypeToEnum<dtype>::v();
// OP_REQUIRES_OK(context,
// context->MatchSignature({MakeRefType(dt), dt},
// {MakeRefType(dt)}));

// OP_REQUIRES_OK(context,
// context->GetAttr("use_locking", &use_exclusive_lock_));

}

void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& r_tensor = context->input(0);
OP_REQUIRES(context, r_tensor.IsInitialized(),
errors::FailedPrecondition("Attempting to use uninitialized "
"parameters: ",
requested_input(0)));
const Tensor& x_tensor = context->input(1);

Tensor x_workspace;
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<dtype>::v(),
x_tensor.shape(), &x_workspace));

int batch_size = GetTensorDim(r_tensor.shape(), data_format_, 'N');
int dim = GetTensorDim(r_tensor.shape(), data_format_, 'H');

//flatten tensors
auto x_work_flat = x_workspace.flat<dtype>();
auto r_flat = r_tensor.flat<dtype>();
auto x_flat = x_tensor.flat<dtype>();

// std::cout << "here" << __LINE__ << std::endl;

// x_work_flat.device(context->eigen_device<Device>()).setZero();
// std::cout << "here" << __LINE__ << std::endl;
// x_work_flat += x_flat; //how do you just copy?
// std::cout << "here" << __LINE__ << std::endl;

Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0,
r_tensor.shape(),&output_tensor));

//auto output = context->output(0);
//out->template flat<Scalar>()
auto output_flat = output_tensor->flat<dtype>();
// std::cout << "here" << __LINE__ << std::endl;
// output.setZero();
// std::cout << "here" << __LINE__ << std::endl;
// output = r_flat; //how do you just copy?
// //const int N = output.size();
// std::cout << "here" << __LINE__ << std::endl;

// Call the cuda kernel launcher
launchCholUpdateKernel<Device, dtype>()(
context->eigen_device<Device>(),
output_flat,
x_work_flat,
r_flat,
x_flat,
batch_size, dim);
}
private:
TensorFormat data_format_;
bool use_exclusive_lock_;
};

//register kernel with types needed
#define REGISTER_GPU(type) \
REGISTER_KERNEL_BUILDER( \
Name("CholUpdate") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T"), \
CholUpdateOp<GPUDevice, type>) \

REGISTER_GPU(float);
// REGISTER_GPU(double);

#undef REGISTER_GPU

#define REGISTER_CPU(type) \
REGISTER_KERNEL_BUILDER( \
Name("CholUpdate") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
CholUpdateOp<CPUDevice, type>) \

REGISTER_CPU(float);
// REGISTER_CPU(double);

#undef REGISTER_CPU
63 changes: 63 additions & 0 deletions src/cholesky_update.cu
@@ -0,0 +1,63 @@
#define EIGEN_USE_GPU
#include <cuda.h>
#include <stdio.h>
#include "tensorflow/core/util/cuda_device_functions.h"

#include "cholesky_update.h"

using GPUDevice = Eigen::GpuDevice;

template <typename dtype>
__global__ void CholUpdateKernel(dtype* R, dtype* x, int batch_size, int dim){
for(int b : tensorflow::CudaGridRangeX(batch_size))
{
for(int k = 0; k < dim; k++)
{
dtype Rkk = R[RIND(b,k,k)];
dtype xk = x[XIND(b,k)];

dtype r = sqrt(Rkk*Rkk + xk*xk);
dtype c = r/Rkk;
dtype s = xk/Rkk;
R[RIND(b,k,k)] = r;
for(int i = k+1; i < dim; i++)
{
R[RIND(b,i,k)] = (R[RIND(b,i,k)] + s*x[XIND(b,i)])/c;
x[XIND(b,i)] = c*x[XIND(b,i)] - s*R[RIND(b,i,k)];
}
}
}
}

template <typename dtype>
struct launchCholUpdateKernel<GPUDevice, dtype> {
void operator()(const GPUDevice& d,
typename TTypes<dtype>::Flat R, typename TTypes<dtype>::Flat x,
typename TTypes<dtype>::ConstFlat R_in, typename TTypes<dtype>::ConstFlat x_in,
int batch_size, int dim) {

To32Bit(R).device(d) = To32Bit(R_in);
To32Bit(x).device(d) = To32Bit(x_in);
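//note: CholUpdateKernel is parallel over the batch dimension only; each thread
//runs the full sequential sweep over k for one matrix. The grid below is sized
//for dim*batch_size threads, so threads with index >= batch_size fall outside
//CudaGridRangeX(batch_size) and do no work.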

const int kThreadsPerBlock = 1024;

CholUpdateKernel<dtype><<<(dim*batch_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(
R.data(), x.data(), batch_size, dim);

cudaError_t cudaerr = cudaDeviceSynchronize();
if (cudaerr != cudaSuccess)
printf("kernel launch failed with error \"%s\".\n",
cudaGetErrorString(cudaerr));
}
};

//forward declaration for all the types needed
typedef Eigen::GpuDevice GPUDevice;
#define ADD_KERNEL_TYPE(type) \
template struct launchCholUpdateKernel<GPUDevice, type>; \

ADD_KERNEL_TYPE(float);
ADD_KERNEL_TYPE(double);

#undef ADD_KERNEL_TYPE