Skip to content

Commit

Permalink
optimize format
Browse files Browse the repository at this point in the history
  • Loading branch information
pangyoki committed Sep 29, 2020
1 parent 3cc41b2 commit b5b9903
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 125 deletions.
9 changes: 1 addition & 8 deletions paddle/fluid/operators/multinomial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "A tensor contains probabilities of categories");
AddOutput("Out", "The output tensor of multinomial op");
// AddOutput("yokiOut", "yoki");
AddAttr<int>("num_samples", "number of the generated samples")
.SetDefault(1);
AddAttr<bool>("replacement", "can a category be sampled more than once")
Expand All @@ -50,7 +49,7 @@ class MultinomialOp : public framework::OperatorWithKernel {

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Multinomial");
// OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial");

auto x_dim = ctx->GetInputDim("X");
int64_t x_rank = x_dim.size();
Expand All @@ -63,7 +62,6 @@ class MultinomialOp : public framework::OperatorWithKernel {
out_dims[x_rank - 1] = num_samples;

ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
// ctx->SetOutputDim("yokiOut", x_dim);
}
};

Expand All @@ -74,16 +72,11 @@ class MultinomialOpKernel<platform::CPUDeviceContext, T>
void Compute(const framework::ExecutionContext &ctx) const override {
const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
// auto yokiout = ctx.Output<framework::Tensor>("yokiOut");
const int64_t num_samples = ctx.Attr<int>("num_samples");
const bool replacement = ctx.Attr<bool>("replacement");

auto *in_data = x->data<T>();
int64_t *out_data = out->mutable_data<int64_t>(ctx.GetPlace());
/*auto *yokiout_data = yokiout->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < x->numel(); i++) {
yokiout_data[i] = in_data[i];
}*/

auto in_dims = x->dims();
int64_t in_rank = in_dims.size();
Expand Down
161 changes: 79 additions & 82 deletions paddle/fluid/operators/multinomial_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
}

template <typename T>
__global__ void Cumsum(T* norm_probs_data, int64_t num_distributions,
int64_t num_categories, T* cumulative_probs) {
__global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t num_distributions,
int64_t num_categories,
T* cumulative_probs) {
for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
thrust::inclusive_scan(thrust::device,
norm_probs_data + id * num_categories,
Expand All @@ -60,33 +62,27 @@ struct RandomGeneratorCudaFunctor {
};

template <typename T>
__device__ int binarySearchFunctor(T* cumdist, T* dist, int size, T val) {
__device__ int binarySearchFunctor(T* cumulative_probs, T* norm_probs_data,
int num_categories, T rng_number) {
int left = 0;
int right = size;
// cumdist[size - 1] = 0 => all zero prob dist
// CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<T>(0));
int right = num_categories;

while (right - left > 0) {
int mid = left + (right - left) / 2;

T midVal = cumdist[mid];
if (midVal < val) {
T temp_prob = cumulative_probs[mid];
if (temp_prob < rng_number) {
left = mid + 1;
} else {
right = mid;
}
}

if (left == size) {
// No probability mass or precision problems; just return the
// first non-zero element by setting left to size-1 here,
// the code below will move it to the last non-zero probability
// this actually can happen when the random number is 1
// (github pytorch issue #4858).
left = size - 1;
if (left == num_categories) {
left = num_categories - 1;
}

while (left >= 1 && dist[left] == 0) left--;
while (left >= 1 && norm_probs_data[left] == 0) left--;

return left;
}
Expand All @@ -96,36 +92,25 @@ __global__ void sampleMultinomialWithReplacement(
T* rng_data, const int64_t num_samples, int64_t* out_data,
const int64_t num_distributions, const int64_t num_categories,
T* cumulative_probs, T* norm_probs_data) {
// At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on.

// global index formula for 2D grid of 1D blocks
// int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x +
// threadIdx.x;

// int idx = blockIdx.x * blockDim.x + threadIdx.x;
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].

int idx = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;

for (int curDist = blockIdx.y; curDist < num_distributions;
curDist += gridDim.y) {
// for every distribution
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
// for every sample
for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
sample < num_samples; sample += blockDim.x * gridDim.x) {
// we are losing 3 out of 4 generated numbers but it's ok
// this kernel is not very efficient anyway

// T uniform_random = dist(rng);
T uniform_random = rng_data[sample + curDist * num_samples];
T rng_number = rng_data[sample + dist * num_samples];

// Find the bucket that a uniform sample lies in
int choice =
binarySearchFunctor<T>(cumulative_probs + curDist * num_categories,
norm_probs_data + curDist * num_categories,
num_categories, uniform_random);
// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);

out_data[sample + curDist * num_samples] = choice;
out_data[sample + dist * num_samples] = selected_category;
}
}
}
Expand All @@ -149,9 +134,14 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
const int64_t num_categories = in_dims[in_rank - 1];
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;

// If replacement is False, the sample is drawn without replacement: every
// category can be selected at most once, so the distribution changes after
// each draw and the sampling can't be parallelized. Fall back to the CPU
// implementation ``MultinomialFunctor`` to sample the distribution.
if (!replacement) {
int in_data_numel = x->numel();
int out_data_numel = out->numel();
int64_t in_data_numel = x->numel();
int64_t out_data_numel = out->numel();

T* cpu_in_data = new T[in_data_numel];
int64_t* cpu_out_data = new int64_t[out_data_numel];
Expand All @@ -169,71 +159,78 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
return;
}

framework::Tensor sum_rows_t;
// Sum of input may not be 1. To get probability in range [0, 1], calculate
// sum of each row of input, and then use the sum to normalize the input.
// sum_row_data: sum of each row
framework::Tensor sum_rows_tensor;
auto* sum_rows_data =
sum_rows_t.mutable_data<T>({num_distributions}, ctx.GetPlace());
sum_rows_tensor.mutable_data<T>({num_distributions}, ctx.GetPlace());

auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();

if (num_distributions == 1) {
auto eigen_input = framework::EigenVector<T>::Flatten(*x);
auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
eigen_sum_rows.device(place) =
eigen_input.sum(Eigen::DSizes<int, 1>(1))
.eval()
.reshape(Eigen::DSizes<int, 1>(sum_rows_t.dims()[0]));
.reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
} else {
auto eigen_input = framework::EigenMatrix<T>::From(*x);
auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
}

framework::Tensor norm_probs_t;
auto* norm_probs_data = norm_probs_t.mutable_data<T>(
// Normalize each row of the input to get probabilities in the range [0, 1].
// norm_probs_data: the normalized probabilities of each distribution
framework::Tensor norm_probs_tensor;
auto* norm_probs_data = norm_probs_tensor.mutable_data<T>(
{num_distributions, num_categories}, ctx.GetPlace());

dim3 block(num_categories < 512 ? num_categories : 512);
dim3 grid((num_categories - 1) / block.x + 1, num_distributions);
// number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512);
dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
NormalizeProbability<
T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, in_data, sum_rows_data);

framework::Tensor cumulative_probs_t;
auto* cumulative_probs = cumulative_probs_t.mutable_data<T>(
// Get the cumulative probability of each distribution; this is the same
// computation as the ``cumsum`` op.
framework::Tensor cumulative_probs_tensor;
auto* cumulative_probs = cumulative_probs_tensor.mutable_data<T>(
{num_distributions, num_categories}, ctx.GetPlace());
dim3 block1(1);
dim3 grid1(num_distributions);
Cumsum<T><<<grid1, block1, 0, ctx.cuda_device_context().stream()>>>(
dim3 block_cumsum(1);
dim3 grid_cumsum(num_distributions);
GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0,
ctx.cuda_device_context().stream()>>>(
norm_probs_data, num_distributions, num_categories, cumulative_probs);

VLOG(3) << "Print cumsum " << cumulative_probs << "\n";

if (replacement) {
dim3 block(128);
// int grid_y = 1;
dim3 grid((num_samples - 1) / block.x + 1, num_distributions);

std::random_device rd;
auto seed = rd();

framework::Tensor rng_data_t;
auto* rng_data = rng_data_t.mutable_data<T>(
{num_distributions, num_samples}, ctx.GetPlace());

thrust::counting_iterator<unsigned int> index_sequence_begin(0);
platform::Transform<platform::CUDADeviceContext> trans;
auto* context = static_cast<const platform::CUDADeviceContext*>(
&ctx.device_context());
trans(*context, index_sequence_begin,
index_sequence_begin + num_distributions * num_samples, rng_data,
RandomGeneratorCudaFunctor<T>(seed));

sampleMultinomialWithReplacement<
T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
rng_data, num_samples, out_data, num_distributions, num_categories,
cumulative_probs, norm_probs_data);
}
// Generate random number for each sample.
std::random_device rd;
auto seed = rd();

framework::Tensor rng_data_tensor;
auto* rng_data = rng_data_tensor.mutable_data<T>(
{num_distributions, num_samples}, ctx.GetPlace());

thrust::counting_iterator<unsigned int> index_sequence_begin(0);
platform::Transform<platform::CUDADeviceContext> trans;
auto* context =
static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
trans(*context, index_sequence_begin,
index_sequence_begin + num_distributions * num_samples, rng_data,
RandomGeneratorCudaFunctor<T>(seed));

// Sample the multinomial distributions.
dim3 block_sample(128);
dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
sampleMultinomialWithReplacement<T><<<grid_sample, block_sample, 0,
ctx.cuda_device_context().stream()>>>(
rng_data, num_samples, out_data, num_distributions, num_categories,
cumulative_probs, norm_probs_data);
}
};

Expand Down
37 changes: 3 additions & 34 deletions python/paddle/fluid/tests/unittests/test_multinomial_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,11 @@ def setUp(self):
self.init_data()
self.inputs = {"X": self.input_np}

"""
def init_data(self):
# input probability is a vector, and replacement is True
self.input_np = np.random.rand(4)
self.outputs = {"Out": np.zeros(100000).astype("int64")}
self.attrs = {"num_samples": 100000, "replacement": True}
"""

def init_data(self):
# input probability is a vector, and replacement is True
self.input_np = np.random.rand(4)
self.outputs = {"Out": np.zeros(100000).astype("int64")}
# self.outputs = {"yokiOut": np.zeros(4).astype("int64")}
self.attrs = {"num_samples": 100000, "replacement": True}

def test_check_output(self):
self.check_output_customized(self.verify_output)
Expand All @@ -55,9 +46,6 @@ def verify_output(self, outs):
# normalize the input to get the probability
prob = self.input_np / self.input_np.sum(axis=-1, keepdims=True)
sample_prob = self.sample_output(np.array(outs[0]))
# sample_prob = np.array(outs[0])
# print("input", self.input_np)
# print("sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
self.assertTrue(
np.allclose(
sample_prob, prob, rtol=0, atol=0.01),
Expand All @@ -69,7 +57,6 @@ def init_data(self):
# input probability is a matrix
self.input_np = np.random.rand(3, 4)
self.outputs = {"Out": np.zeros((3, 100000)).astype("int64")}
# self.outputs = {"yokiOut": np.zeros((3, 4)).astype("int64")}
self.attrs = {"num_samples": 100000, "replacement": True}

def sample_output(self, out):
Expand All @@ -92,25 +79,15 @@ def init_data(self):

def verify_output(self, outs):
out = np.array(outs[0])
# print("op3out", out)
unique_out = np.unique(out)
self.assertEqual(
len(unique_out), 100,
"replacement is False. categories can't be sampled repeatedly")


"""
class TestReplacementError(unittest.TestCase):
def init_data(self):
# replacement is False. if number of samples is larger than number of categories, raise error.
self.input_np = np.random.rand(4)
self.outputs = {"Out": np.zeros(10).astype("int64")}
self.attrs = {"num_samples": 10, "replacement": False}
"""


class TestMultinomialApi(unittest.TestCase):
def test_dygraph(self):
# input probability is a vector, and replacement is True
paddle.disable_static()
x = paddle.rand([4])
out = paddle.multinomial(x, num_samples=100000, replacement=True)
Expand All @@ -128,6 +105,7 @@ def test_dygraph(self):
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))

def test_dygraph2(self):
# input probability is a matrix, and replacement is True
paddle.disable_static()
x = paddle.rand([3, 4])
out = paddle.multinomial(x, num_samples=100000, replacement=True)
Expand All @@ -149,6 +127,7 @@ def test_dygraph2(self):
paddle.enable_static()

def test_dygraph3(self):
# replacement is False. number of samples must be less than number of categories.
paddle.disable_static()
x = paddle.rand([1000])
out = paddle.multinomial(x, num_samples=100, replacement=False)
Expand Down Expand Up @@ -186,16 +165,6 @@ def test_static(self):
sample_prob, prob, rtol=0, atol=0.01),
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))

"""
def test_replacement_error(self):
def test_error():
paddle.disable_static()
x = paddle.rand([5])
out = paddle.multinomial(x, num_samples=10, replacement=False)
self.assertRaises(paddle.fluid.core.EnforceNotMet, test_error)
"""


class TestMultinomialAlias(unittest.TestCase):
def test_alias(self):
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
x(Tensor): A tensor with probabilities for generating the random number. The data type
should be float32, float64.
num_samples(int, optional): Number of samples, default is 1.
replacement(bool, optional): whether it is a replaceable sample, default is False.
replacement(bool, optional): Whether it is a replaceable sample, default is False.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Expand Down

0 comments on commit b5b9903

Please sign in to comment.