This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit e0e5238: fix bugs
zrr1999 committed Sep 17, 2022
1 parent 7d294fb commit e0e5238
Showing 6 changed files with 123 additions and 150 deletions.
92 changes: 34 additions & 58 deletions cinn/hlir/op/contrib/argmax.cc
@@ -1,3 +1,17 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/argmax.h"

#include <iostream>
@@ -23,7 +37,11 @@ using common::CINNValue;
using framework::shape_t;
using ir::Tensor;

Tensor Argmax(const Tensor &in_tensor, const int &axis, const bool keep_dims, const std::string &output_name) {
Tensor Argmax(const Tensor &in_tensor,
const int &axis,
const bool keep_dims,
poly::StageMap stages,
const std::string &output_name) {
auto shape = in_tensor->shape;
auto ndim = shape.size();
CHECK_GT(ndim, 0) << "tensor's dim must be more than 0";
@@ -52,62 +70,20 @@ Tensor Argmax(const Tensor &in_tensor, const int &axis, const bool keep_dims, co
output_shape.push_back(Expr(1));
}

auto compute = [=](const std::vector<Expr> &indices) -> Expr {
std::vector<Expr> eval_indices(indices);

if (!keep_dims) {
eval_indices.insert(eval_indices.begin() + real_axis, Expr(1));
}
CHECK_EQ(eval_indices.size(), ndim);

// Var loop_var("k0");
// eval_indices[real_axis] = i;
// auto value = in_tensor(eval_indices);
// auto update = ir::LT::Make(value, current[1]);
// auto c1 = ir::Select::Make(update, Expr(i), current[0]);
// auto c2 = ir::Select::Make(update, value, current[1]);
// current[0] = c1;
// current[1] = c2;
// auto for_loop = ir::For::Make(i, Expr(0), current[0]);

Placeholder<float> p_max_value("max_value", {shape[real_axis]});
Placeholder<int32_t> p_max_index("max_index", {shape[real_axis]});
auto max_value = ir::Tensor(p_max_value);
auto max_index = ir::Tensor(p_max_index);

// max_value = lang::Identity(ir::Store::Make(min_value, Expr(-3.402823e+38f), {Expr(0)}));

Var loop_var("k0", Int(32));
Expr loop_expr = Expr(loop_var);
eval_indices[real_axis] = Expr(loop_var);

auto value = lang::Identity(in_tensor(eval_indices));

CHECK_EQ(value->type(), Expr(0.0f)->type());
auto update = ir::LT::Make(value, Expr(0.0f));
// auto update = ir::LT::Make(value, ir::Load::Make(max_value, {Expr(loop_var)}));
auto c_v = ir::Select::Make(update, value, Expr(0.0f));
auto c_i = ir::Select::Make(update, Expr(loop_var), Expr(0));
// auto c_v = ir::Select::Make(update, value, ir::Load::Make(max_value, {Expr(loop_var)}));
// auto c_i = ir::Select::Make(update, Expr(loop_var), ir::Load::Make(max_index, {Expr(loop_var)}));

Expr body1 = ir::Store::Make(max_value, c_v, {Expr(loop_var) + 1});
// Expr body2 = ir::Store::Make(max_index, c_i, {Expr(loop_var)+1});

Expr body = ir::Block::Make({body1});

auto output = ir::For::Make(
loop_var, common::make_const(0), shape[real_axis] - 1, ir::ForType::Serial, ir::DeviceAPI::Host, body);

// for (int i = 0; i<shape[real_axis]; i++){
// }

return ir::Load::Make(output, {shape[real_axis] - 1});
// return lang::Identity(eval_indices[0]);
// return ir::Load::Make(output, {shape[real_axis]-1});
};

Tensor res = Compute(output_shape, compute, output_name);
auto sort_index = ArgSort(A, target, pos_axis, false, name + "_index");
auto res = Compute(
output_shape,
[=](const std::vector<Expr> &indices) {
std::vector<Expr> eval_indices(indices);
if (!keep_dims) {
eval_indices.insert(eval_indices.begin() + real_axis, Expr(0));
} else {
eval_indices[real_axis] = Expr(0);
}
return sort_index[eval_indices];
},
name);
stages->InsertLazily(sort_index);
return res;
}

@@ -137,7 +113,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgmax(const framework::NodeAt
CHECK(in_expr.as_tensor());
Tensor in_tensor = in_expr.as_tensor_ref();
auto stages = CreateStages({in_tensor});
auto out_tensor = Argmax(in_tensor, axis, keep_dims, tensor_name);
auto out_tensor = Argmax(in_tensor, axis, keep_dims, stages, tensor_name);

stages->InsertLazily(out_tensor);
std::vector<CINNValue> cinn_values{CINNValue(out_tensor), CINNValue(stages)};
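Note on the new implementation above: Argmax is now expressed as reading position 0 of an ArgSort along the reduced axis (the `false` flag passed to ArgSort here presumably selects descending order, so position 0 holds the largest element's index). A minimal standalone C++ sketch of that technique, with hypothetical names and no CINN dependencies:

    #include <algorithm>
    #include <cstddef>
    #include <numeric>
    #include <vector>

    // Argmax of a non-empty 1-D slice via argsort: sort the indices by value
    // (descending), then take the first sorted index. This mirrors reading
    // sort_index[eval_indices] with the reduced axis pinned to 0.
    std::size_t ArgmaxViaArgsort(const std::vector<float>& slice) {
      std::vector<std::size_t> idx(slice.size());
      std::iota(idx.begin(), idx.end(), std::size_t{0});  // idx = 0, 1, ..., n-1
      std::stable_sort(idx.begin(), idx.end(),
                       [&](std::size_t a, std::size_t b) { return slice[a] > slice[b]; });
      return idx[0];  // index of the largest element (first occurrence, by stability)
    }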
16 changes: 15 additions & 1 deletion cinn/hlir/op/contrib/argmax.h
@@ -1,3 +1,17 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
@@ -24,6 +38,6 @@ ir::Tensor Argmax(const ir::Tensor& A,
const int& axis,
const bool keep_dims = false,
const std::string& output_name = "T_Argmax_out");
} // namespace pe
} // namespace op
} // namespace hlir
} // namespace cinn
19 changes: 16 additions & 3 deletions cinn/hlir/op/contrib/argmax_test.cc
@@ -1,3 +1,17 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/argmax.h"

#include <glog/logging.h>
@@ -31,10 +45,9 @@ TEST(GenerateCode_Cpu, Argmax_Keep) {
ir::Expr w(28);

lang::Placeholder<float> in("in", {n, in_c, h, w});
lang::Placeholder<float> out("out", {n, out_c, h, w});
ir::Tensor res = Argmax(in, axis, true, "test_argmax_in");
poly::StageMap stages = poly::CreateStages({in});
ir::Tensor res = Argmax(in, axis, true, stages, "test_argmax_in");

poly::StageMap stages = poly::CreateStages({res});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestGenerateCodeCpu_Argmax_Keep", stages, {res}, {}, {}, nullptr, target, true);

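The reordering in this test tracks the new data flow: Argmax now calls stages->InsertLazily(sort_index) internally, so the StageMap must exist before Argmax is called rather than being created from its result afterwards. Restating the intended call order from the new test, with comments (a sketch against the APIs used in this diff, not additional API surface):

    lang::Placeholder<float> in("in", {n, in_c, h, w});
    // Create the StageMap from the input first; Argmax will lazily insert
    // its internal sort_index tensor into it.
    poly::StageMap stages = poly::CreateStages({in});
    ir::Tensor res = Argmax(in, axis, /*keep_dims=*/true, stages, "test_argmax_in");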
112 changes: 27 additions & 85 deletions cinn/hlir/op/contrib/argmin.cc
@@ -1,3 +1,17 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/argmin.h"

#include <iostream>
@@ -24,7 +38,11 @@ using common::CINNValue;
using framework::shape_t;
using ir::Tensor;

Tensor Argmin(const Tensor &in_tensor, const int &axis, const bool keep_dims, const std::string &output_name) {
Tensor Argmin(const Tensor &in_tensor,
const int &axis,
const bool keep_dims,
poly::StageMap stages,
const std::string &output_name) {
auto shape = in_tensor->shape;
auto ndim = shape.size();
CHECK_GT(ndim, 0) << "tensor's dim must be more than 0";
@@ -50,96 +68,20 @@ Tensor Argmin(const Tensor &in_tensor, const int &axis, const bool keep_dims, co
if (output_shape.empty()) {
output_shape.push_back(Expr(1));
}

std::string extern_fun_name;
if (target.arch == common::Target::Arch::NVGPU) {
extern_fun_name.assign("cinn_cuda_");
} else if (target.arch == common::Target::Arch::X86) {
extern_fun_name.assign("cinn_host_");
} else {
LOG(FATAL) << "Argmin only supports X86 and NVGPU ! Please Check.\n";
}
if (true) {
extern_fun_name.append("lt_num_float");
} else {
extern_fun_name.append("gt_num_float");
}

auto res = Compute(
auto sort_index = ArgSort(A, target, pos_axis, true, name + "_index");
auto res = Compute(
output_shape,
[=](const std::vector<Expr> &indices) {
std::vector<Expr> eval_indices(indices);
if (!keep_dims) {
eval_indices.insert(eval_indices.begin() + real_axis, Expr(1));
}
Expr offset(0);
Expr stride(1);
for (int i = 0; i < indices.size(); i++) {
if (i < pos_axis) {
offset = offset * A->shape[i] + indices[i];
} else if (i == pos_axis) {
offset = offset * A->shape[i];
} else {
offset = offset * A->shape[i] + indices[i];
stride = stride * A->shape[i];
}
eval_indices.insert(eval_indices.begin() + real_axis, Expr(0));
} else {
eval_indices[real_axis] = Expr(0);
}
offset = common::AutoSimplify(offset);
stride = common::AutoSimplify(stride);
auto A_shape_axis = A->shape[pos_axis];
return lang::CallExtern(extern_fun_name, {A, A_shape_axis, A(indices), offset, stride});
return sort_index[eval_indices];
},
name);
return res;

auto compute = [=](const std::vector<Expr> &indices) -> Expr {
std::vector<Expr> eval_indices(indices);
if (!keep_dims) {
eval_indices.insert(eval_indices.begin() + real_axis, Expr(1));
}

// Var loop_var("k0");
// eval_indices[real_axis] = i;
// auto value = in_tensor(eval_indices);
// auto update = ir::LT::Make(value, current[1]);
// auto c1 = ir::Select::Make(update, Expr(i), current[0]);
// auto c2 = ir::Select::Make(update, value, current[1]);
// current[0] = c1;
// current[1] = c2;
// auto for_loop = ir::For::Make(i, Expr(0), current[0]);

Placeholder<float> p_min_value("min_value", {shape[real_axis] + 1});
Placeholder<int32_t> p_min_index("min_index", {shape[real_axis] + 1});
auto min_value = ir::Tensor(p_min_value);
auto min_index = ir::Tensor(p_min_index);

Var loop_var("k0", Int(32));
Expr loop_expr = Expr(loop_var);
eval_indices[real_axis] = loop_expr;

auto value = lang::Identity(in_tensor(eval_indices));
CHECK_EQ(min_value->type(), Float(32));
// ir::Store::Make(min_value, Expr(-3.402823e+38f), {Expr(int32_t(0))});

// auto update = ir::GT::Make(value, Expr(0));
auto update = ir::GT::Make(value, ir::Load::Make(min_value, {loop_expr}));
CHECK_EQ(min_index->type(), Int(32));
auto c_v = ir::Select::Make(update, value, ir::Load::Make(min_value, {loop_expr}));
auto c_i = ir::Select::Make(update, loop_expr, ir::Load::Make(min_index, {loop_expr}));

Expr init = ir::Store::Make(min_value, Expr(-3.402823e+38f), {Expr(int32_t(0))});
Expr body1 = ir::Store::Make(min_value, c_v, {loop_expr + 1});
Expr body2 = ir::Store::Make(min_index, c_i, {loop_expr + 1});

Expr body = ir::Block::Make({init, body1, body2});

auto output = ir::For::Make(
loop_var, common::make_const(0), shape[real_axis], ir::ForType::Serial, ir::DeviceAPI::Host, body);

return ir::Load::Make(output, {shape[real_axis]});
};

Tensor res = Compute(output_shape, compute, output_name);
stages->InsertLazily(sort_index);
return res;
}

@@ -169,7 +111,7 @@ std::shared_ptr<framework::OpStrategy> StrategyForArgmin(const framework::NodeAt
CHECK(in_expr.as_tensor());
Tensor in_tensor = in_expr.as_tensor_ref();
auto stages = CreateStages({in_tensor});
auto out_tensor = Argmin(in_tensor, axis, keep_dims, tensor_name);
auto out_tensor = Argmin(in_tensor, axis, keep_dims, stages, tensor_name);

stages->InsertLazily(out_tensor);
std::vector<CINNValue> cinn_values{CINNValue(out_tensor), CINNValue(stages)};
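Argmin above mirrors Argmax but passes `true` to ArgSort (presumably ascending order), so position 0 along the reduced axis holds the smallest element's index. The index remapping both ops share can be sketched standalone as follows (illustrative only, not CINN code):

    #include <vector>

    // Map output indices to indices into the sorted-index tensor: re-insert
    // the reduced axis (it was dropped when keep_dims == false) or pin it
    // (keep_dims == true, axis kept with extent 1), always selecting
    // position 0 along that axis.
    std::vector<int> EvalIndices(std::vector<int> out_indices, int real_axis, bool keep_dims) {
      if (!keep_dims) {
        out_indices.insert(out_indices.begin() + real_axis, 0);
      } else {
        out_indices[real_axis] = 0;
      }
      return out_indices;
    }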
16 changes: 15 additions & 1 deletion cinn/hlir/op/contrib/argmin.h
@@ -1,3 +1,17 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>
@@ -24,6 +38,6 @@ ir::Tensor Argmin(const ir::Tensor& A,
const int& axis,
const bool keep_dims = false,
const std::string& output_name = "T_Argmin_out");
} // namespace pe
} // namespace op
} // namespace hlir
} // namespace cinn
18 changes: 16 additions & 2 deletions cinn/hlir/op/contrib/argmin_test.cc
@@ -1,3 +1,17 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/argmin.h"

#include <glog/logging.h>
@@ -31,8 +45,8 @@ TEST(GenerateCode_Cpu, Argmin_Keep) {
ir::Expr w(28);

lang::Placeholder<float> in("in", {n, in_c, h, w});
lang::Placeholder<float> out("out", {n, out_c, h, w});
ir::Tensor res = Argmin(in, axis, true, "test_argmin_in");
poly::StageMap stages = poly::CreateStages({in});
ir::Tensor res = Argmin(in, axis, true, stages, "test_argmin_in");

poly::StageMap stages = poly::CreateStages({res});
std::vector<ir::LoweredFunc> funcs =
