Add check of placement constructor (#6991)
* add_check_of_placement_constructor

* move CheckDeviceIdsIsValid to runtime

* handle comment

* fix error

* fix error

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
clackhan and oneflow-ci-bot authored Dec 16, 2021
1 parent 367e3c4 commit bbd0d1d
Showing 10 changed files with 91 additions and 0 deletions.
1 change: 1 addition & 0 deletions oneflow/api/python/framework/op_expr.cpp
@@ -51,6 +51,7 @@ Maybe<one::TensorTuple> Interpret(const one::OpExpr& op,
Maybe<one::TensorTuple> Interpret(const one::OpExpr& op, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const AttrMap& attrs) {
JUST(CheckDeviceIdsIsValid(placement));
CHECK_EQ_OR_RETURN(op.input_size(), 0)
<< " the op : " << op.op_type_name()
<< " is NOT source op with input_size = " << op.input_size();
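Every call site in this commit adds the same one-line guard, `JUST(CheckDeviceIdsIsValid(placement))`, at the top of a consistent-tensor entry point; the check returns `Maybe<void>`, so a bad placement fails fast instead of reaching op execution. A minimal sketch of the `Maybe`/`JUST` error-propagation idiom behind that line, using simplified stand-in types rather than OneFlow's real `Maybe` implementation:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Stand-in for oneflow::Maybe<void>: success, or an error message.
struct MaybeVoid {
  std::optional<std::string> error;  // nullopt means Ok
  static MaybeVoid Ok() { return {std::nullopt}; }
  static MaybeVoid Err(std::string msg) { return {std::move(msg)}; }
};

// Hypothetical stand-in for CheckDeviceIdsIsValid().
MaybeVoid CheckDeviceId(int64_t dev_id, int64_t device_count) {
  if (dev_id >= device_count) {
    return MaybeVoid::Err("device id " + std::to_string(dev_id) +
                          " is out of range for " +
                          std::to_string(device_count) + " device(s)");
  }
  return MaybeVoid::Ok();
}

// Conceptually, JUST(expr) unwraps a Maybe and early-returns any error.
MaybeVoid Interpret(int64_t dev_id, int64_t device_count) {
  MaybeVoid checked = CheckDeviceId(dev_id, device_count);
  if (checked.error) { return checked; }  // what JUST(...) does on failure
  // ... interpret the op only after the placement check passed ...
  return MaybeVoid::Ok();
}

int main() {
  MaybeVoid result = Interpret(/*dev_id=*/5, /*device_count=*/2);
  if (result.error) { std::cout << *result.error << "\n"; }  // prints the error
  return 0;
}
```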
4 changes: 4 additions & 0 deletions oneflow/api/python/functional/tensor_api.cpp
@@ -75,6 +75,7 @@ class ConsistentTensorWithDataFunctor {
const bool& requires_grad) const {
// NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
JUST(CheckDeviceIdsIsValid(placement));

if (PyTensorCheck(data)) {
// Throw warnings like PyTorch.
@@ -107,6 +108,7 @@ class ConsistentTensorEmptyCtorFunctor {
Maybe<Tensor> operator()(const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
Shape shape(DimVector{0});
JUST(CheckDeviceIdsIsValid(placement));
return ConsistentTensorWithShapeCtor(shape, placement, sbp_tuple);
}
};
@@ -148,6 +150,7 @@ class ConsistentTensorWithDataCtorFunctor {
public:
Maybe<Tensor> operator()(PyObject* data, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
JUST(CheckDeviceIdsIsValid(placement));
// Treat the single long as shape.
if (PyLong_Check(data)) {
int64_t size = PyLong_AsLongLong(data);
@@ -190,6 +193,7 @@ class ConsistentTensorWithShapeCtorFunctor {
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
// NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now.
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false);
JUST(CheckDeviceIdsIsValid(placement));
return functional::ConsistentEmpty(shape, DType::Float(), placement, sbp_tuple);
}
};
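For context, the data ctor above special-cases a lone Python integer as a 1-D shape (the `PyLong_Check` branch). A plain-C++ sketch of that dispatch rule, with simplified types standing in for the Python objects:

```cpp
#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

// Stand-in for the PyObject* data argument: a single integer or tensor data.
using CtorArg = std::variant<int64_t, std::vector<double>>;

void DescribeCtor(const CtorArg& data) {
  if (std::holds_alternative<int64_t>(data)) {
    // Mirrors the PyLong_Check branch: flow.Tensor(5, ...) means shape {5}.
    std::cout << "empty tensor of shape {" << std::get<int64_t>(data) << "}\n";
  } else {
    // Otherwise the argument is parsed as tensor data.
    std::cout << "tensor built from "
              << std::get<std::vector<double>>(data).size() << " values\n";
  }
}

int main() {
  DescribeCtor(int64_t{5});                     // empty tensor of shape {5}
  DescribeCtor(std::vector<double>{1.0, 2.0});  // tensor built from 2 values
  return 0;
}
```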
20 changes: 20 additions & 0 deletions oneflow/api/python/symbol/placement_symbol.cpp
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <algorithm>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/operators.h>
@@ -30,13 +31,26 @@ limitations under the License.
#include "oneflow/core/job/placement.cfg.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#ifdef WITH_CUDA
#include <cuda.h>
#endif // WITH_CUDA

namespace py = pybind11;

namespace oneflow {

namespace {

int64_t GetGpuDeviceNum() {
#ifndef WITH_CUDA
return 0;
#else
int device_count = 0;
cudaGetDeviceCount(&device_count);
return device_count;
#endif
}

Maybe<Shape> MakeShape(const py::tuple& py_shape) {
DimVector shape_dims{};
for (const auto& dim : py_shape) { shape_dims.emplace_back(dim.cast<int64_t>()); }
@@ -150,6 +164,12 @@ struct PlacementSymbolExportUtil {
if (iter == device_tag2placement.end()) {
int64_t node_size = GlobalProcessCtx::NodeSize();
int64_t device_num = GlobalProcessCtx::NumOfProcessPerNode();
if (device_tag == "gpu") {
const int64_t gpu_device_num = GetGpuDeviceNum();
CHECK_NE(gpu_device_num, 0)
<< "Can\'t construct placment with \"cuda\" type because there is no CUDA device!";
device_num = std::min(device_num, gpu_device_num);
}
std::vector<std::string> machine_device_ids;
for (int64_t node_id = 0; node_id < node_size; ++node_id) {
std::string device_name = std::to_string(node_id) + ":0-" + std::to_string(device_num - 1);
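A standalone sketch of how the constructor path above derives per-node device ranges once `device_num` has been clamped to the visible GPU count (the counts below are assumed example values):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  int64_t node_size = 2;       // e.g. GlobalProcessCtx::NodeSize()
  int64_t device_num = 8;      // e.g. GlobalProcessCtx::NumOfProcessPerNode()
  int64_t gpu_device_num = 4;  // e.g. GetGpuDeviceNum()
  // The hunk above clamps to the devices that actually exist:
  device_num = std::min(device_num, gpu_device_num);

  std::vector<std::string> machine_device_ids;
  for (int64_t node_id = 0; node_id < node_size; ++node_id) {
    machine_device_ids.push_back(std::to_string(node_id) + ":0-" +
                                 std::to_string(device_num - 1));
  }
  // Prints "0:0-3" and "1:0-3": each node uses only the 4 visible GPUs.
  for (const auto& id : machine_device_ids) { std::cout << id << "\n"; }
  return 0;
}
```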
2 changes: 2 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -130,6 +130,7 @@ class ConsistentConstantFunctor {
Maybe<Tensor> operator()(const Shape& shape, const Scalar& value, const Symbol<DType>& dtype,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
JUST(CheckDeviceIdsIsValid(placement));
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<DataType>("dtype", dtype->data_type()));
@@ -210,6 +211,7 @@ class ConsistentEmptyFunctor {
Maybe<Tensor> operator()(const Shape& shape, const Symbol<DType>& dtype,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
JUST(CheckDeviceIdsIsValid(placement));
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<DataType>("dtype", dtype->data_type()));
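The functors above (and the math/random functors below) all stage op attributes through `MutableAttrMap` before dispatching. A simplified sketch of that idiom — stand-in types only; the real `SetAttr` returns `Maybe<void>` and is wrapped in `JUST`:

```cpp
#include <any>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for MutableAttrMap: typed op attributes keyed by name.
struct MutableAttrMap {
  std::unordered_map<std::string, std::any> attrs;

  template <typename T>
  void SetAttr(const std::string& name, const T& value) {
    attrs[name] = value;  // OneFlow's version returns Maybe<void>
  }

  template <typename T>
  T GetAttr(const std::string& name) const {
    return std::any_cast<T>(attrs.at(name));
  }
};

int main() {
  MutableAttrMap attrs;
  // Mirrors ConsistentConstantFunctor above, with simplified attribute types:
  attrs.SetAttr<std::vector<int64_t>>("shape", {2, 3});
  attrs.SetAttr<std::string>("dtype", "float32");
  const auto shape = attrs.GetAttr<std::vector<int64_t>>("shape");
  std::cout << "rank = " << shape.size() << "\n";  // prints "rank = 2"
  return 0;
}
```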
2 changes: 2 additions & 0 deletions oneflow/core/functional/impl/consistent_cast.cpp
@@ -294,6 +294,7 @@ class LocalToConsistentFunctor {
Symbol<ParallelDesc> parallel_desc,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_parallels,
const Shape& shape, const Symbol<DType>& dtype) const {
JUST(CheckDeviceIdsIsValid(parallel_desc));
CHECK_OR_RETURN(x->is_local());
std::shared_ptr<one::Tensor> input = x;
// copy to right device first if input's device type is wrong
@@ -336,6 +337,7 @@ class ToConsistentFunctor {
Symbol<ParallelDesc> parallel_desc,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_parallels,
const std::vector<Symbol<cfg::SbpParallel>>& grad_sbp_parallels) const {
JUST(CheckDeviceIdsIsValid(parallel_desc));
std::shared_ptr<Tensor> tensor;
if (x->is_consistent()) {
tensor = JUST(ConsistentToConsistent(x, parallel_desc, sbp_parallels, grad_sbp_parallels));
1 change: 1 addition & 0 deletions oneflow/core/functional/impl/dataset_functor.cpp
@@ -106,6 +106,7 @@ class ReadOneRecFunctor {
JUST(attrs.SetAttr<bool>("verify_example", verify_example));

if (placement.has_value()) {
JUST(CheckDeviceIdsIsValid(JUST(placement)));
CHECK_OR_RETURN(sbp.has_value())
<< "placement is not None, but sbp is None. It's not allowed.";
AttrMap attrmap(attrs);
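The `ReadOneRecFunctor` hunk also enforces that `placement` and `sbp` come as a pair. A tiny sketch of that rule, with simplified stand-in types for the two optional arguments:

```cpp
#include <optional>
#include <stdexcept>
#include <string>

// placement given without sbp is rejected, mirroring the CHECK_OR_RETURN
// message in the diff above (exception instead of Maybe for brevity).
void ValidatePlacementSbpPair(const std::optional<std::string>& placement,
                              const std::optional<std::string>& sbp) {
  if (placement.has_value() && !sbp.has_value()) {
    throw std::invalid_argument(
        "placement is not None, but sbp is None. It's not allowed.");
  }
}
```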
3 changes: 3 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp
@@ -622,6 +622,7 @@ class ConsistentEyeFunctor {
const Optional<Symbol<DType>>& dtype,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
JUST(CheckDeviceIdsIsValid(placement));
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("rows", JUST(rows.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("cols", JUST(cols.value_or(rows).As<int64_t>())));
@@ -732,6 +733,7 @@ class ConsistentArangeFunctor {
const Optional<Symbol<DType>>& dtype,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
JUST(CheckDeviceIdsIsValid(placement));
MutableAttrMap attrs;
if (dtype.has_value()) {
const DataType range_dtype = JUST(dtype)->data_type();
@@ -781,6 +783,7 @@ class ConsistentArange2Functor {
Maybe<Tensor> operator()(const Scalar& limit, const Symbol<DType>& dtype,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
JUST(CheckDeviceIdsIsValid(placement));
return ConsistentArange(Scalar(0), limit, Scalar(1), dtype, placement, sbp_tuple);
}
};
5 changes: 5 additions & 0 deletions oneflow/core/functional/impl/random_functor.cpp
@@ -108,6 +108,7 @@ class ConsistentRandFunctor {
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
JUST(CheckDeviceIdsIsValid(placement));
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype)->data_type();
@@ -182,6 +183,7 @@ class ConsistentRandNFunctor {
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
JUST(CheckDeviceIdsIsValid(placement));
DataType dtype_val = DataType::kFloat;
if (dtype) { dtype_val = JUST(dtype)->data_type(); }
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
@@ -269,6 +271,7 @@ class ConsistentRandIntFunctor {
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
JUST(CheckDeviceIdsIsValid(placement));
DataType dtype_val = DataType::kInt64;
if (dtype) { dtype_val = JUST(dtype)->data_type(); }

@@ -305,6 +308,7 @@ class ConsistentRandInt2Functor {
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
JUST(CheckDeviceIdsIsValid(placement));
return ConsistentRandInt(/*low*/ 0, high, shape, placement, sbp_tuple, dtype, generator,
requires_grad);
}
@@ -344,6 +348,7 @@ class ConsistentRandPermFunctor {
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<one::Generator>& generator, const Symbol<DType>& dtype,
const bool& requires_grad) const {
JUST(CheckDeviceIdsIsValid(placement));
const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("n", n));
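One detail in `ConsistentRandPermFunctor` above worth spelling out is `generator.value_or(JUST(one::DefaultAutoGenerator()))`: use the caller-supplied RNG if present, else fall back to a shared default. A sketch with standard-library stand-ins for OneFlow's `one::Generator`:

```cpp
#include <memory>
#include <optional>
#include <random>

using Generator = std::mt19937_64;  // stand-in for one::Generator

// Stand-in for one::DefaultAutoGenerator(): a process-wide default RNG.
std::shared_ptr<Generator> DefaultAutoGenerator() {
  static auto gen = std::make_shared<Generator>(/*seed=*/0);
  return gen;
}

// Mirrors: const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator()));
std::shared_ptr<Generator> ResolveGenerator(
    const std::optional<std::shared_ptr<Generator>>& generator) {
  return generator.value_or(DefaultAutoGenerator());
}
```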
51 changes: 51 additions & 0 deletions oneflow/core/job/parallel_desc.cpp
@@ -13,22 +13,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <algorithm>
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/placement.cfg.h"
#include "oneflow/core/common/decorator.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/multi_client.h"
#include "oneflow/core/common/cpp_attribute.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/framework/parallel_conf_util.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/vm/vm_util.h"
#ifdef WITH_CUDA
#include <cuda.h>
#endif // WITH_CUDA

namespace oneflow {

namespace {

int64_t GetGpuDeviceNum() {
#ifndef WITH_CUDA
return 0;
#else
int device_count = 0;
cudaGetDeviceCount(&device_count);
return device_count;
#endif
}

using MachineId2DeviceIdList =
std::shared_ptr<HashMap<int64_t, std::shared_ptr<std::vector<int64_t>>>>;

@@ -302,6 +318,34 @@ Maybe<void> ParallelDesc::CheckWithResourceDesc(const ResourceDesc& resource_des
return Maybe<void>::Ok();
}

Maybe<void> ParallelDesc::CheckDeviceIdsIsValid() const {
if (likely(JUST(IsMultiClient()))) {
const auto& sorted_dev_phy_ids_iter =
machine_id2sorted_dev_phy_ids_->find(GlobalProcessCtx::Rank());
if (sorted_dev_phy_ids_iter != machine_id2sorted_dev_phy_ids_->end()) {
for (int64_t dev_phy_id : *sorted_dev_phy_ids_iter->second) {
if (device_type_ == DeviceType::kCUDA) {
const int64_t gpu_device_num = GetGpuDeviceNum();
CHECK_NE_OR_RETURN(gpu_device_num, 0)
<< "Placment with \"cuda\" type is invalid because there is no CUDA device!";
int64_t device_num = std::min(GlobalProcessCtx::NumOfProcessPerNode(), gpu_device_num);
CHECK_LT_OR_RETURN(dev_phy_id, device_num)
<< "Placment is invalid because device id must be less than "
<< (gpu_device_num < GlobalProcessCtx::NumOfProcessPerNode()
? "num of CUDA devices on node"
: "num of process per node");
} else if (device_type_ == DeviceType::kCPU) {
CHECK_LT_OR_RETURN(dev_phy_id, GlobalProcessCtx::NumOfProcessPerNode())
<< "Placment is invalid because device id must be less than num of process per node";
} else {
OF_UNIMPLEMENTED();
}
}
}
}
return Maybe<void>::Ok();
}

ParallelConf ParallelDesc::GetParallelIdOnlyParallelConf(int64_t parallel_id) const {
ParallelConf parallel_conf;
std::string rank = std::to_string(CHECK_JUST(MachineId4ParallelId(parallel_id)));
@@ -456,6 +500,11 @@ Maybe<Symbol<ParallelDesc>> RawTxtStringToPlacement(const std::string& parallel_
return SymbolOf(ParallelDesc(parallel_conf));
}

Maybe<void> RawCheckDeviceIdsIsValid(Symbol<ParallelDesc> placement) {
JUST(placement->CheckDeviceIdsIsValid());
return Maybe<void>::Ok();
}

} // namespace

decltype(GetParallelId4CurrentProcessCtx) GetParallelId4CurrentProcessCtx =
@@ -467,5 +516,7 @@ decltype(PlacementToString) PlacementToString = DECORATE(&RawPlacementToString,
decltype(GetTensorDevice) GetTensorDevice = DECORATE(&RawGetTensorDevice, ThreadLocal);
decltype(TxtStringToPlacement) TxtStringToPlacement =
DECORATE(&RawTxtStringToPlacement, ThreadLocalCopiable);
decltype(CheckDeviceIdsIsValid) CheckDeviceIdsIsValid =
DECORATE(&RawCheckDeviceIdsIsValid, ThreadLocal);

} // namespace oneflow
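Restated as a condensed standalone sketch, the rule `ParallelDesc::CheckDeviceIdsIsValid` enforces above (simplified signature; exceptions instead of `Maybe` for brevity): a device id on the current rank must be below the per-node process count, and for CUDA placements the bound is further clamped to the number of visible GPUs.

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <string>

enum class DeviceType { kCPU, kCUDA };

void CheckDeviceIdIsValid(int64_t dev_phy_id, DeviceType device_type,
                          int64_t procs_per_node, int64_t gpu_device_num) {
  int64_t bound = procs_per_node;
  if (device_type == DeviceType::kCUDA) {
    // Zero visible GPUs makes any "cuda" placement invalid outright.
    if (gpu_device_num == 0) {
      throw std::invalid_argument("no CUDA device available");
    }
    bound = std::min(bound, gpu_device_num);
  }
  if (dev_phy_id >= bound) {
    throw std::invalid_argument("device id " + std::to_string(dev_phy_id) +
                                " must be less than " + std::to_string(bound));
  }
}
```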
2 changes: 2 additions & 0 deletions oneflow/core/job/parallel_desc.h
@@ -107,6 +107,7 @@ class ParallelDesc final {

std::shared_ptr<cfg::ParallelConf> cfg_parallel_conf() const { return cfg_parallel_conf_; }
bool TryGetParallelId(int64_t machine_id, int64_t device_id, int64_t* parallel_id) const;
Maybe<void> CheckDeviceIdsIsValid() const;

private:
friend Maybe<OFRecord> ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf);
@@ -149,6 +150,7 @@ extern Maybe<Symbol<ParallelDesc>> (*ReplaceDeviceType)(Symbol<ParallelDesc>, De
extern Maybe<std::string> (*PlacementToString)(Symbol<ParallelDesc> placement);
extern Maybe<Symbol<Device>> (*GetTensorDevice)(Symbol<ParallelDesc> parallel_desc);
extern Maybe<Symbol<ParallelDesc>> (*TxtStringToPlacement)(const std::string& parallel_conf_str);
extern Maybe<void> (*CheckDeviceIdsIsValid)(Symbol<ParallelDesc> placement);

inline bool operator==(const ParallelConf& lhs, const ParallelConf& rhs) {
return ParallelDesc(lhs) == ParallelDesc(rhs);
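The header exposes `CheckDeviceIdsIsValid` as a function pointer that parallel_desc.cpp binds via `DECORATE(&RawCheckDeviceIdsIsValid, ThreadLocal)`; by the apparent convention of the neighboring decorated functions, the decorator memoizes results per thread and per argument, so a given placement symbol is validated at most once per thread. A sketch of that caching shape (assumed semantics, not OneFlow's actual decorator):

```cpp
#include <functional>
#include <map>

// Sketch of a ThreadLocal-style memoizing decorator: cache fn's result per
// thread, keyed by argument, so repeated validations of the same placement
// symbol are effectively free. Arg must be copyable and ordered.
template <typename Ret, typename Arg>
std::function<Ret(Arg)> ThreadLocalCached(Ret (*fn)(Arg)) {
  return [fn](Arg arg) -> Ret {
    thread_local std::map<Arg, Ret> cache;
    auto it = cache.find(arg);
    if (it == cache.end()) { it = cache.emplace(arg, fn(arg)).first; }
    return it->second;
  };
}

// Hypothetical usage, mirroring the decltype(...) definition above:
//   auto Check = ThreadLocalCached(&RawCheckDeviceIdsIsValid);
```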
