Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Sep 19, 2022
1 parent 08714e9 commit f2d5f12
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
16 changes: 7 additions & 9 deletions cinn/frontend/net_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,12 @@ TEST(net_build, program_execute_gather) {
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input1_tensor = scope->GetTensor(std::string(input1.id()));
// SetRandData<float>(input1_tensor, target);
SetRandData<float>(input1_tensor, target);
float* input1_data = input1_tensor->mutable_data<float>(target);

auto input2_tensor = scope->GetTensor(std::string(input2.id()));
// SetRandData<int>(input2_tensor, target);
SetRandInt(input2_tensor, target);
int* input2_data = input2_tensor->mutable_data<int>(target);

memset(input1_data, 0, sizeof(float) * B * H_IN1);
memset(input2_data, 0, sizeof(int) * B * H_IN2);
runtime_program->Execute();

Expand Down Expand Up @@ -329,7 +327,7 @@ TEST(net_build, program_execute_gather_nd) {
float* input1_data = input1_tensor->mutable_data<float>(target);

auto input2_tensor = scope->GetTensor(std::string(input2.id()));
SetRandData<int>(input2_tensor, target);
SetRandInt(input2_tensor, target);
int* input2_data = input2_tensor->mutable_data<int>(target);

runtime_program->Execute();
Expand Down Expand Up @@ -380,11 +378,11 @@ TEST(net_build, program_execute_scatter) {
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input1_tensor = scope->GetTensor(std::string(input1.id()));
// SetRandData<float>(input1_tensor, target);
SetRandData<float>(input1_tensor, target);
float* input1_data = input1_tensor->mutable_data<float>(target);

auto input2_tensor = scope->GetTensor(std::string(input2.id()));
// SetRandData<int>(input2_tensor, target);
// SetRandInt(input2_tensor, target);
int* input2_data = input2_tensor->mutable_data<int>(target);

runtime_program->Execute();
Expand Down Expand Up @@ -452,7 +450,7 @@ TEST(net_build, program_execute_scatter_nd) {
SetRandData<float>(input1_tensor, target);

auto input2_tensor = scope->GetTensor(std::string(input2.id()));
SetRandData<int>(input2_tensor, target);
SetRandInt(input2_tensor, target);

runtime_program->Execute();

Expand Down Expand Up @@ -517,7 +515,7 @@ TEST(net_build, program_execute_cast) {
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<int>(input_tensor, target);
SetRandInt(input_tensor, target);
int* input_data = input_tensor->mutable_data<int>(target);

runtime_program->Execute();
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/op/contrib/gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ ir::Tensor Gather(const ir::Tensor &A, const ir::Tensor &B, const int &axis, con
for (int i = 0; i < axis; ++i) {
A_indices.push_back(indices[i]);
}
A_indices.push_back(ir::Cast::Make(Int(32), B(indices)));
A_indices.push_back(B(indices));
for (size_t i = axis + 1; i < A->shape.size(); ++i) {
A_indices.push_back(indices[i]);
}
Expand Down
14 changes: 8 additions & 6 deletions cinn/utils/data_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,33 @@

#include "cinn/utils/data_util.h"

#include "iostream"

namespace cinn {

template <typename T>
void SetRandInt<int>(hlir::framework::Tensor tensor, const common::Target& target, int seed) {
void SetRandInt(hlir::framework::Tensor tensor, const common::Target& target, int seed) {
if (seed == -1) {
std::random_device rd;
seed = rd();
}
std::default_random_engine engine(seed);
std::uniform_int_distribution<int> dist(1, 10);
size_t num_ele = tensor->shape().numel();
std::vector<T> random_data(num_ele);
std::vector<int> random_data(num_ele);
for (size_t i = 0; i < num_ele; i++) {
random_data[i] = static_cast<T> dist(engine); // All random data
random_data[i] = static_cast<int>(dist(engine)); // All random data
}

auto* data = tensor->mutable_data<T>(target);
auto* data = tensor->mutable_data<int>(target);
#ifdef CINN_WITH_CUDA
if (target == common::DefaultNVGPUTarget()) {
cudaMemcpy(data, random_data.data(), num_ele * sizeof(T), cudaMemcpyHostToDevice);
cudaMemcpy(data, random_data.data(), num_ele * sizeof(int), cudaMemcpyHostToDevice);
return;
}
#endif
CHECK(target == common::DefaultHostTarget());
std::copy(random_data.begin(), random_data.end(), data);
std::cout << "success" << std::endl;
}

template <>
Expand Down
1 change: 0 additions & 1 deletion cinn/utils/data_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#endif

namespace cinn {
template <typename T>
void SetRandInt(hlir::framework::Tensor tensor, const common::Target& target, int seed = -1);

template <typename T>
Expand Down

0 comments on commit f2d5f12

Please sign in to comment.