Skip to content

Commit

Permalink
Support column major array. (#6765)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 19, 2021
1 parent f6fe15d commit 4ee8340
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 151 deletions.
1 change: 0 additions & 1 deletion python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,6 @@ def _transform_cupy_array(data):
data, '__array__'):
import cupy # pylint: disable=import-error
data = cupy.array(data, copy=False)
data = data.astype(dtype=data.dtype, order='C', copy=False)
return data


Expand Down
37 changes: 25 additions & 12 deletions src/data/adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,22 +234,22 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
class Line {
ArrayInterface array_interface_;
size_t ridx_;

public:
Line(ArrayInterface array_interface, size_t ridx)
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}

// Number of elements in this line, taken from the column count of the wrapped array interface.
size_t Size() const { return array_interface_.num_cols; }

COOTuple GetElement(size_t idx) const {
return {ridx_, idx, array_interface_.GetElement(idx)};
return {ridx_, idx, array_interface_.GetElement(ridx_, idx)};
}
};

public:
ArrayAdapterBatch() = default;
Line const GetLine(size_t idx) const {
auto line = array_interface_.SliceRow(idx);
return Line{line, idx};
return Line{array_interface_, idx};
}

explicit ArrayAdapterBatch(ArrayInterface array_interface)
Expand Down Expand Up @@ -286,14 +286,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
ArrayInterface indices_;
ArrayInterface values_;
size_t ridx_;
size_t offset_;

public:
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
Line(ArrayInterface indices, ArrayInterface values, size_t ridx,
size_t offset)
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
offset_{offset} {}

COOTuple GetElement(size_t idx) const {
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
return {ridx_, indices_.GetElement<size_t>(offset_ + idx, 0),
values_.GetElement(offset_ + idx, 0)};
}

// Number of entries in this line: rows times columns of the values array.
size_t Size() const {
  auto const n_rows = values_.num_rows;
  auto const n_cols = values_.num_cols;
  return n_rows * n_cols;
}
Expand All @@ -304,7 +309,11 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
ArrayInterface values)
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
values_{std::move(values)} {}
values_{std::move(values)} {
indptr_.AsColumnVector();
values_.AsColumnVector();
indices_.AsColumnVector();
}

size_t Size() const {
size_t size = indptr_.num_rows * indptr_.num_cols;
Expand All @@ -313,15 +322,19 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
}

Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx);
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
auto indices = indices_.SliceOffset(begin_offset);
auto values = values_.SliceOffset(begin_offset);
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
auto end_offset = indptr_.GetElement<size_t>(idx + 1, 0);

auto indices = indices_;
auto values = values_;

values.num_cols = end_offset - begin_offset;
values.num_rows = 1;

indices.num_cols = values.num_cols;
indices.num_rows = values.num_rows;
return Line{indices, values, idx};

return Line{indices, values, idx, begin_offset};
}
};

Expand Down
139 changes: 53 additions & 86 deletions src/data/array_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
#define XGBOOST_DATA_ARRAY_INTERFACE_H_

#include <algorithm>
#include <cinttypes>
#include <map>
#include <string>
Expand Down Expand Up @@ -40,7 +41,7 @@ struct ArrayInterfaceErrors {
return str.c_str();
}
static char const* Version() {
return "Only version 1 of `__cuda_array_interface__' is supported.";
return "Only version 1 and 2 of `__cuda_array_interface__' are supported.";
}
static char const* OfType(std::string const& type) {
static std::string str;
Expand Down Expand Up @@ -191,43 +192,46 @@ class ArrayInterfaceHandler {
std::map<std::string, Json> const& column) {
auto j_shape = get<Array const>(column.at("shape"));
auto typestr = get<String const>(column.at("typestr"));
if (column.find("strides") != column.cend()) {
if (!IsA<Null>(column.at("strides"))) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), j_shape.size())
<< ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
<< ArrayInterfaceErrors::Contigious();
}
}

if (j_shape.size() == 1) {
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
} else {
CHECK_EQ(j_shape.size(), 2)
<< "Only 1D or 2-D arrays currently supported.";
CHECK_EQ(j_shape.size(), 2) << "Only 1-D and 2-D arrays are supported.";
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
}
}

template <typename T>
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
Validate(column);

auto typestr = get<String const>(column.at("typestr"));
CHECK_EQ(typestr.at(1), TypeChar<T>())
<< "Input data type and typestr mismatch. typestr: " << typestr;
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
<< "Input data type and typestr mismatch. typestr: " << typestr;

auto shape = ExtractShape(column);
static void ExtractStride(std::map<std::string, Json> const &column,
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
auto strides_it = column.find("strides");
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
// default strides
strides[0] = cols;
strides[1] = 1;
} else {
// strides specified by the array interface
auto const &j_strides = get<Array const>(strides_it->second);
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
size_t n = 1;
if (j_strides.size() == 2) {
n = get<Integer const>(j_strides[1]) / itemsize;
}
strides[1] = n;
}
auto valid = (rows - 1) * strides[0] + (cols - 1) * strides[1] == (rows * cols) - 1;
CHECK(valid) << "Invalid strides in array.";
}

T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
static void* ExtractData(std::map<std::string, Json> const &column,
StringView typestr,
std::pair<size_t, size_t> shape) {
Validate(column);
void* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
if (!p_data) {
CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
}
return common::Span<T>{p_data, shape.first * shape.second};
return p_data;
}
};

Expand All @@ -236,11 +240,15 @@ class ArrayInterface {
void Initialize(std::map<std::string, Json> const &column,
bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
auto typestr = get<String const>(column.at("typestr"));
this->AssignType(StringView{typestr});

auto shape = ArrayInterfaceHandler::ExtractShape(column);
num_rows = shape.first;
num_cols = shape.second;

data = ArrayInterfaceHandler::ExtractData(column, StringView{typestr}, shape);

if (allow_mask) {
common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
Expand All @@ -257,8 +265,8 @@ class ArrayInterface {
<< "Masked array is not yet supported.";
}

auto typestr = get<String const>(column.at("typestr"));
this->AssignType(StringView{typestr});
ArrayInterfaceHandler::ExtractStride(column, strides, num_rows, num_cols,
typestr[2] - '0');
}

public:
Expand Down Expand Up @@ -288,6 +296,15 @@ class ArrayInterface {
}
}

/*!
 * \brief Reinterpret a vector shaped array (1 x n or n x 1) as an n x 1 column vector.
 *
 * After the call, element i of the vector is addressed as (i, 0).  The row
 * stride is taken from whichever axis actually holds the elements: picking the
 * larger of the two strides is wrong for a 1 x n row-major array (strides
 * {n, 1}) and for an n x 1 array with explicit column-major strides
 * (strides {1, n}), both of which need a row stride of 1.
 *
 * The length of the resulting vector is the number of elements of the input,
 * so an empty input stays empty.
 */
void AsColumnVector() {
  CHECK(num_rows == 1 || num_cols == 1) << "Array should be a vector instead of matrix.";
  if (num_rows == 1) {
    // Row vector (or a single element): the elements are laid out along the
    // column axis, hence its stride and extent become those of the row axis.
    strides[0] = strides[1];
    num_rows = num_cols;
  }
  // Otherwise the array already is a column vector and the row axis is kept.
  num_cols = 1;
  strides[1] = 1;
}

void AssignType(StringView typestr) {
if (typestr[1] == 'f' && typestr[2] == '4') {
type = kF4;
Expand Down Expand Up @@ -320,95 +337,45 @@ class ArrayInterface {
switch (type) {
case kF4:
return func(reinterpret_cast<float *>(data));
break;
case kF8:
return func(reinterpret_cast<double *>(data));
break;
case kI1:
return func(reinterpret_cast<int8_t *>(data));
break;
case kI2:
return func(reinterpret_cast<int16_t *>(data));
break;
case kI4:
return func(reinterpret_cast<int32_t *>(data));
break;
case kI8:
return func(reinterpret_cast<int64_t *>(data));
break;
case kU1:
return func(reinterpret_cast<uint8_t *>(data));
break;
case kU2:
return func(reinterpret_cast<uint16_t *>(data));
break;
case kU4:
return func(reinterpret_cast<uint32_t *>(data));
break;
case kU8:
return func(reinterpret_cast<uint64_t *>(data));
break;
}
SPAN_CHECK(false);
return func(reinterpret_cast<uint64_t *>(data));
}

/*!
 * \brief Create a copy of this interface whose data pointer is advanced by `offset` elements.
 *
 * Only the `data` member of the returned copy differs from this object; the
 * shape and every other member are copied unchanged.
 *
 * \param offset Number of elements (not bytes) to advance the data pointer by.
 * \return A copy of this interface pointing at the shifted location.
 */
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
void* p_values{nullptr};
// Dispatch on the element type so that the pointer arithmetic below is done on a typed pointer.
this->DispatchCall([&p_values, offset](auto *ptr) {
p_values = ptr + offset;
});

ArrayInterface ret = *this;
ret.data = p_values;
return ret;
}

/*!
 * \brief Obtain a view of a single row.
 *
 * The data pointer of the returned interface is advanced by idx * num_cols
 * elements and its number of rows is set to 1.
 *
 * \param idx Index of the row.
 * \return Array interface describing the selected row.
 */
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
  auto row = this->SliceOffset(idx * num_cols);
  row.num_rows = 1;
  return row;
}

/*!
 * \brief Read the element at flat index `idx` and convert it to T.
 *
 * The buffer is cast to the element type recorded in `type` and indexed
 * directly with `idx`; strides are not consulted.
 *
 * \tparam T  Type of the returned value, float by default.
 * \param idx Flat element index, checked against num_cols * num_rows by SPAN_CHECK.
 * \return The element converted to T.
 */
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t idx) const {
SPAN_CHECK(idx < num_cols * num_rows);
switch (type) {
case kF4:
return reinterpret_cast<float*>(data)[idx];
case kF8:
return reinterpret_cast<double*>(data)[idx];
case kI1:
return reinterpret_cast<int8_t*>(data)[idx];
case kI2:
return reinterpret_cast<int16_t*>(data)[idx];
case kI4:
return reinterpret_cast<int32_t*>(data)[idx];
case kI8:
return reinterpret_cast<int64_t*>(data)[idx];
case kU1:
return reinterpret_cast<uint8_t*>(data)[idx];
case kU2:
return reinterpret_cast<uint16_t*>(data)[idx];
case kU4:
return reinterpret_cast<uint32_t*>(data)[idx];
case kU8:
return reinterpret_cast<uint64_t*>(data)[idx];
}
// Reached only when `type` matches none of the cases above, every one of which returns.
SPAN_CHECK(false);
return reinterpret_cast<float*>(data)[idx];
}

/*!
 * \brief Size of one element of the underlying buffer.
 * \return sizeof the value type that DispatchCall selects for this array.
 */
XGBOOST_DEVICE size_t ElementSize() {
  auto size_of_pointee = [](auto *p_values) {
    using ValueT = std::remove_pointer_t<decltype(p_values)>;
    return sizeof(ValueT);
  };
  return this->DispatchCall(size_of_pointee);
}

/*!
 * \brief Read the element at row r and column c and convert it to T.
 *
 * The element is located at strides[0] * r + strides[1] * c, counted in
 * elements from the start of the buffer.
 *
 * \tparam T Type of the returned value, float by default.
 * \param r  Row index.
 * \param c  Column index.
 * \return The element converted to T.
 */
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
  auto const offset = strides[0] * r + strides[1] * c;
  return this->DispatchCall([offset](auto *p_values) -> T { return p_values[offset]; });
}

RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;
size_t strides[2]{0, 0};
void* data;

Type type;
Expand Down
Loading

0 comments on commit 4ee8340

Please sign in to comment.