Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Large Tensor] Add support to Random Sample & Pdf ops #17445

Merged
merged 5 commits into from
Jan 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/operator/random/multisample_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline bool MultiSampleOpShape(const nnvm::NodeAttrs& attrs,
// Promote down by removing last dimensions which represent the samples.
tshape = mxnet::TShape(tshape.begin(), tshape.begin()+(tshape.ndim()-sshape.ndim()));
}
// Shape assignemnt/checking for inputs.
// Shape assignment/checking for inputs.
for (const auto& in_attr : *in_attrs) {
if ( !shape_assign(&tshape, in_attr)) return false;
}
Expand All @@ -86,7 +86,7 @@ inline bool MultiSampleOpShape(const nnvm::NodeAttrs& attrs,
}
if (tshape.ndim() > 0) {
// Shape assignment/check for propagation from inputs to output.
std::vector<int> cshape(tshape.begin(), tshape.end());
std::vector<index_t> cshape(tshape.begin(), tshape.end());
cshape.insert(cshape.end(), sshape.begin(), sshape.end());
mxnet::TShape oshape(cshape.begin(), cshape.end());
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
Expand Down
150 changes: 75 additions & 75 deletions src/operator/random/pdf_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ MSHADOW_XINLINE mshadow::half::half_t ceph_psi(mshadow::half::half_t val) {
template<bool logpdf>
struct PDF_Uniform {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
DType *out, IType1 *sample, IType2 *lower, IType2 *upper) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType l(lower[index]), h(upper[index]);
const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
// No check whether sample is in the support.
out[i] = logpdf ? -DType(log(h - l)) : DType(1.0) / (h - l);
}
Expand All @@ -63,14 +63,14 @@ struct PDF_Uniform {
template<bool logpdf>
struct PDF_Uniform_Grad {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size, OpReqType req,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
DType *out, IType1 *sample, IType2 *lower, IType2 *upper,
DType *grad_out, IType1 *grad_sample, IType2 *grad_lower, IType2 *grad_upper) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType l(lower[index]), h(upper[index]);

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
grad_lower[i] = scaling / (h - l);
grad_upper[i] = scaling / (l - h);
Expand All @@ -82,14 +82,14 @@ struct PDF_Uniform_Grad {
template<bool logpdf>
struct PDF_Normal {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
DType *out, IType1 *sample, IType2 *loc, IType2 *scale) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType u(loc[index]), s(scale[index]), sq(s * s);
const DType normalizer(sqrt(2.0 * mxnet_op::PI) * s);

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
const DType exponent((DType(-0.5) * (x - u) * (x - u)) / (sq));
out[i] = logpdf ? exponent - log(normalizer) : exp(exponent) / normalizer;
Expand All @@ -100,14 +100,14 @@ struct PDF_Normal {
template<bool logpdf>
struct PDF_Normal_Grad {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size, OpReqType req,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
DType *out, IType1 *sample, IType2 *loc, IType2 *scale,
DType *grad_out, IType1 *grad_sample, IType2 *grad_loc, IType2 *grad_scale) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType u(loc[index]), s(scale[index]), s_squared(s * s), s_cubed(s_squared * s);

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
grad_loc[i] = scaling * (x - u) / s_squared;
Expand All @@ -120,13 +120,13 @@ struct PDF_Normal_Grad {
template<bool logpdf>
struct PDF_Gamma {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
DType *out, IType1 *sample, IType2 *alpha, IType2 *beta) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType a(alpha[index]), b(beta[index]), lgamma_a(lgamma(a)), a_log_b(a * log(b));

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
const DType lpdf(a_log_b + (a - 1) * log(x) - b * x - lgamma_a);
out[i] = logpdf ? lpdf : DType(exp(lpdf));
Expand All @@ -137,14 +137,14 @@ struct PDF_Gamma {
template<bool logpdf>
struct PDF_Gamma_Grad {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size, OpReqType req,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
DType *out, IType1 *sample, IType2 *alpha, IType2 *beta,
DType *grad_out, IType1 *grad_sample, IType2 *grad_alpha, IType2 *grad_beta) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType a(alpha[index]), b(beta[index]), log_b(log(b)), ceph_psi_a(ceph_psi(a));

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
grad_alpha[i] = scaling * (log_b + log(x) - ceph_psi_a);
Expand All @@ -157,13 +157,13 @@ struct PDF_Gamma_Grad {
template<bool logpdf>
struct PDF_Exponential {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
DType *out, IType1 *sample, IType2 *lambda) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType l(lambda[index]), log_l(log(l));

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
out[i] = logpdf ? log_l - l * x : l * exp(-l * x);
}
Expand All @@ -173,14 +173,14 @@ struct PDF_Exponential {
template<bool logpdf>
struct PDF_Exponential_Grad {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size, OpReqType req,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
DType *out, IType1 *sample, IType2 *lambda,
DType *grad_out, IType1 *grad_sample, IType2 *grad_lambda) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType l(lambda[index]);

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
grad_lambda[i] = scaling * (DType(1) / l - x);
Expand All @@ -192,13 +192,13 @@ struct PDF_Exponential_Grad {
template<bool logpdf>
struct PDF_Poisson {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
DType *out, IType1 *sample, IType2 *lambda) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType l(lambda[index]), log_l(log(l));

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
const DType lpdf((x * log_l - lgamma(x + 1)) - l);
out[i] = logpdf ? lpdf : DType(exp(lpdf));
Expand All @@ -209,14 +209,14 @@ struct PDF_Poisson {
template<bool logpdf>
struct PDF_Poisson_Grad {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size, OpReqType req,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
DType *out, IType1 *sample, IType2 *lambda,
DType *grad_out, IType1 *grad_sample, IType2 *grad_lambda) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType l(lambda[index]);

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
grad_lambda[i] = scaling * (x / l - DType(1));
Expand All @@ -229,13 +229,13 @@ struct PDF_Poisson_Grad {
template<bool logpdf>
struct PDF_NegativeBinomial {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
DType *out, IType1 *sample, IType2 *limit, IType2 *prob) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType l(limit[index]), p(prob[index]), lgamma_l(lgamma(l));

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType x(sample[i]);
const DType lpdf((lgamma(x + l) - lgamma(x + 1) - lgamma_l) + l * log(p) + x * log(1 - p));
out[i] = logpdf ? lpdf : DType(exp(lpdf));
Expand All @@ -252,12 +252,12 @@ struct PDF_NegativeBinomial {
template<bool logpdf>
struct PDF_NegativeBinomial_Grad {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size, OpReqType req,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
DType *out, IType1 *sample, IType2 *limit, IType2 *prob,
DType *grad_out, IType1 *grad_sample, IType2 *grad_limit, IType2 *grad_prob) {
const int index(start / sample_size);
const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t index(start / sample_size);
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
DType grad_l(0), grad_p(0);
LPDF_GRAD(DType(limit[index]), DType(prob[index]),
DType(sample[i]), out[i],
Expand All @@ -280,15 +280,15 @@ struct PDF_NegativeBinomial_Grad {
template<bool logpdf>
struct PDF_GeneralizedNegativeBinomial {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
DType *out, IType1 *sample, IType2 *mu, IType2 *alpha) {
const int index(start / sample_size);
const index_t index(start / sample_size);

// Reparameterize with limit = 1 / alpha, prob = 1 / (mu * alpha + 1)
const DType limit(1.0 / alpha[index]), prob(1.0 / (mu[index]*alpha[index]+1.0));

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const DType lpdf(PDF_NegativeBinomial<logpdf>::LPDF(limit, prob, DType(sample[i])));
out[i] = logpdf ? lpdf : DType(exp(lpdf));
}
Expand All @@ -298,17 +298,17 @@ struct PDF_GeneralizedNegativeBinomial {
template<bool logpdf>
struct PDF_GeneralizedNegativeBinomial_Grad {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size, OpReqType req,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
DType *out, IType1 *sample, IType2 *mu, IType2 *alpha,
DType *grad_out, IType1 *grad_sample, IType2 *grad_mu, IType2 *grad_alpha) {
const int index(start / sample_size);
const index_t index(start / sample_size);
const DType fmu(mu[index]), falpha(alpha[index]), den(fmu * falpha + 1.0);

// Reparameterize with limit = 1 / alpha, prob = 1 / (mu * alpha + 1)
const DType limit(1.0 / falpha), prob(1.0 / (fmu * falpha + 1.0));

const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
// Grad returned as d_limit, d_prob
DType grad_l(0), grad_p(0);
PDF_NegativeBinomial_Grad<logpdf>::LPDF_GRAD(limit, prob,
Expand All @@ -324,15 +324,15 @@ struct PDF_GeneralizedNegativeBinomial_Grad {
template<bool logpdf>
struct PDF_Dirichlet {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size, int k,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, index_t k,
DType *out, IType1 *sample, IType2 *alpha) {
const int index(start / sample_size);
const int end = start + length;
for (int i = start; i < end; ++i) {
const index_t index(start / sample_size);
const index_t end = start + length;
for (index_t i = start; i < end; ++i) {
const IType1 *cur_sample = sample + i * k;
const IType2 *cur_alpha = alpha + index * k;
DType sum_alpha(0), sum_lgamma(0), sum_sample(0);
for (int j = 0; j < k; ++j) {
for (index_t j = 0; j < k; ++j) {
sum_alpha += cur_alpha[j];
sum_lgamma += lgamma(cur_alpha[j]);
sum_sample += (cur_alpha[j]-1) * log(cur_sample[j]);
Expand All @@ -347,27 +347,27 @@ struct PDF_Dirichlet {
template<bool logpdf>
struct PDF_Dirichlet_Grad {
template<typename DType, typename IType1, typename IType2>
MSHADOW_XINLINE static void Map(int start, int length, int sample_size,
OpReqType req, int k,
MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
OpReqType req, index_t k,
DType *out, IType1 *sample, IType2 *alpha,
DType *grad_out, IType1 *grad_sample, IType2 *grad_alpha
) {
const int index(start / sample_size);
const int end = start + length;
const index_t index(start / sample_size);
const index_t end = start + length;

for (int i = start; i < end; ++i) {
for (index_t i = start; i < end; ++i) {
// Digamma function
const IType1 *cur_sample = sample + i * k;
const IType2 *cur_alpha = alpha + index * k;

const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
DType sum_alpha(0);
for (int j = 0; j < k; ++j) {
for (index_t j = 0; j < k; ++j) {
sum_alpha += cur_alpha[j];
}
const DType psi_sum(ceph_psi(sum_alpha));

for (int j = 0; j < k; ++j) {
for (index_t j = 0; j < k; ++j) {
size_t grad_alpha_index = i%sample_size + sample_size * (j + k * index);
size_t grad_sample_index = i * k + j;

Expand Down Expand Up @@ -433,25 +433,25 @@ inline bool PdfOpShape(const nnvm::NodeAttrs& attrs,
template<typename OP>
struct LaunchExWrapper {
template<typename ...Args>
MSHADOW_XINLINE static void Map(const int start, const int length, const int sample_size,
Args... args) {
MSHADOW_XINLINE static void Map(const index_t start, const index_t length,
const index_t sample_size, Args... args) {
// Apply the operator to the sample in strides of sample_size, so that
// the operators can assume that their distribution parameters are constant.
int i = start;
index_t i = start;

// Get aligned
const int align_step = sample_size - (i % sample_size);
const int first_stride = length > align_step ? align_step : length;
const index_t align_step = sample_size - (i % sample_size);
const index_t first_stride = length > align_step ? align_step : length;
OP::Map(i, first_stride, sample_size, args...);
i += first_stride;

const int end = start + length - sample_size;
const index_t end = start + length - sample_size;
for (; i < end; i += sample_size) {
OP::Map(i, sample_size, sample_size, args...);
}

// Last stride might not be aligned either
const int last_stride = start + length - i;
const index_t last_stride = start + length - i;
if (last_stride > 0) { // Don't overstep even if length <= sample_size
OP::Map(i, last_stride, sample_size, args...);
}
Expand Down