Skip to content

Commit

Permalink
add rocm schedules to topi C++ (#4507)
Browse files Browse the repository at this point in the history
This imports the CUDA schedules to rocm.
  • Loading branch information
t-vi authored and masahi committed Dec 12, 2019
1 parent 40f1886 commit fd6560e
Show file tree
Hide file tree
Showing 5 changed files with 270 additions and 0 deletions.
65 changes: 65 additions & 0 deletions topi/include/topi/rocm/injective.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file rocm/injective.h
* \brief rocm schedule for injective operations
*/
#ifndef TOPI_ROCM_INJECTIVE_H_
#define TOPI_ROCM_INJECTIVE_H_

#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"

#include "topi/cuda/injective.h"

namespace topi {
using namespace tvm;

namespace rocm {

/*!
* \brief Updates an existing schedule for the given injective ops.
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
*
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
return topi::cuda::schedule_injective_from_existing(sch, out);
}

/*!
* \brief Create a rocm schedule for the given output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_injective(target, outs);
}

} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_INJECTIVE_H_
66 changes: 66 additions & 0 deletions topi/include/topi/rocm/pooling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file rocm/pooling.h
* \brief rocm schedule for pooling operations
*/
#ifndef TOPI_ROCM_POOLING_H_
#define TOPI_ROCM_POOLING_H_

#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "topi/detail/array_utils.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"

#include "topi/cuda/pooling.h"

namespace topi {
using namespace tvm;

namespace rocm {

/*!
* \brief Create a rocm schedule for pool
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_pool(target, outs);
}

/*!
* \brief Create a rocm schedule for global_pool
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_global_pool(target, outs);
}

} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_POOLING_H_
52 changes: 52 additions & 0 deletions topi/include/topi/rocm/reduction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file rocm/reduction.h
* \brief rocm schedule for reduction operations
*/
#ifndef TOPI_ROCM_REDUCTION_H_
#define TOPI_ROCM_REDUCTION_H_

#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"

#include "topi/cuda/reduction.h"

namespace topi {
using namespace tvm;

namespace rocm {
/*!
* \brief Create a rocm schedule for a reduce operation.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_reduce(const Target& target, Array<Tensor> outs) {
return topi::cuda::schedule_reduce(target, outs);
}

} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_REDUCTION_H_
53 changes: 53 additions & 0 deletions topi/include/topi/rocm/softmax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file rocm/injective.h
* \brief ROCM schedule for injective operations
*/
#ifndef TOPI_ROCM_SOFTMAX_H_
#define TOPI_ROCM_SOFTMAX_H_

#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"

#include "topi/cuda/softmax.h"

namespace topi {
using namespace tvm;

namespace rocm {

/*!
* \brief Create a rocm schedule for the given softmax output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_softmax(target, outs);
}

} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_SOFTMAX_H_
34 changes: 34 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
#include <topi/x86/injective.h>

#include <topi/rocm/dense.h>
#include <topi/rocm/injective.h>
#include <topi/rocm/pooling.h>
#include <topi/rocm/reduction.h>
#include <topi/rocm/softmax.h>
#include <topi/rocm/normalization.h>

namespace topi {
Expand Down Expand Up @@ -638,6 +642,36 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
*rv = topi::rocm::schedule_dense(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_injective(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_pool(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_global_pool(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_reduce(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_softmax(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_lrn(args[0], args[1]);
Expand Down

0 comments on commit fd6560e

Please sign in to comment.