-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Enhancements for MXTensor for custom operators #17204
Changes from 6 commits
0c6a300
bfe1195
fa80372
9cb3bd0
85d2879
f8b3035
2eac368
15679f6
bbf8e63
989b394
8d4ea67
789302a
e6cd9ac
3f001aa
c310281
d139686
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,7 +37,7 @@ | |
#include <utility> | ||
#include <stdexcept> | ||
|
||
#define MX_LIBRARY_VERSION 1 | ||
#define MX_LIBRARY_VERSION 2 | ||
|
||
/* | ||
 * Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h | ||
|
@@ -198,6 +198,7 @@ enum MXDType { | |
kInt32 = 4, | ||
kInt8 = 5, | ||
kInt64 = 6, | ||
kUNSET = 100, | ||
}; | ||
|
||
enum MXReturnValue { | ||
|
@@ -209,10 +210,15 @@ enum MXReturnValue { | |
* \brief Tensor data structure used by custom operator | ||
*/ | ||
struct MXTensor { | ||
MXTensor() : data_ptr(NULL) {} | ||
MXTensor() : data_ptr(NULL), dtype(kUNSET), version(0) {} | ||
|
||
MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype) | ||
: data_ptr(data_ptr), shape(shape), dtype(dtype) {} | ||
MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype, | ||
size_t ID) | ||
: data_ptr(data_ptr), shape(shape), dtype(dtype), version(ID) {} | ||
|
||
void update(void *dptr, MXDType type, size_t ver) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we really need this function? it doesn't have any checks, only copies pointers. I think we can copy them line by line in lib_api.h and keep MXTensor as simple as possible There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. per our discussion, let's move the for loop to copy shape and the call to setDLTensor inside this function. Change name to "setTensor" |
||
data_ptr = dptr; dtype = type; version = ver; | ||
} | ||
|
||
/*! \brief populate DLTensor fields */ | ||
void setDLTensor() { | ||
|
@@ -277,6 +283,14 @@ struct MXTensor { | |
return size; | ||
} | ||
|
||
/*! \brief helper function to compare two MXTensors */ | ||
inline bool isSame(const MXTensor &oth) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we override operator==? since we won't support C anyway There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think operator== is confusing. For a tensor object, == usually means value comparison. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. comparing object is how c++ is doing for vector and usually for struct, and in NDArray we don't have operators !=, <, > either, so I don't think it is going to be confusing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I agree with @wkcn here, == should compare the values of the tensor not the "state" of the tensor (data_ptr, versionID, etc) @mseth10 @eric-haibin-lin @haojin2 what do you guys think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For C++ vector container, #include <iostream>
#include <vector>
using namespace std;
int main() {
vector<int> a{1,2,3};
vector<int> b{1,2,3};
vector<int> c{1,2,4};
cout << (a == b) << endl; // 1
cout << (a == c) << endl; // 0
cout << (a.data() == b.data()) << endl; // 0
return 0;
} Although the underlying data pointers differ (`a.data() != b.data()`), `operator==` still compares the element values, so `a == b` is true. |
||
return data_ptr == oth.data_ptr && | ||
dtype == oth.dtype && | ||
version == oth.version && | ||
shape == oth.shape; | ||
} | ||
|
||
// data is flattened 1D repr of tensor, elements are in contiguous memory | ||
// user can access each element using the shape of tensor | ||
void *data_ptr; | ||
|
@@ -287,6 +301,9 @@ struct MXTensor { | |
// type can only be MXDType enum types | ||
MXDType dtype; | ||
|
||
// version number updated if the tensor has changed since the last use by custom op | ||
size_t version; | ||
|
||
// corresponding DLTensor repr of MXTensor | ||
// easy way to reuse functions taking DLTensor | ||
DLTensor dltensor; | ||
|
@@ -684,15 +701,9 @@ typedef int (*opCallInferType_t)(inferType_t, const char* const*, const char* co | |
|
||
#define MXLIB_OPCALLFCOMP_STR "_opCallFCompute" | ||
typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int, | ||
const int64_t**, int*, void**, int*, int, | ||
const int64_t**, int*, void**, int*, int, | ||
xpu_malloc_t, void*); | ||
|
||
#define MXLIB_OPCALLBKWD_STR "_opCallBackward" | ||
typedef int (*opCallBkwd_t)(fcomp_t, const char* const*, const char* const*, int, | ||
const int64_t**, int*, void**, int*, int, | ||
const int64_t**, int*, void**, int*, int, | ||
xpu_malloc_t, void*); | ||
const int64_t**, int*, void**, int*, size_t*, int, | ||
const int64_t**, int*, void**, int*, size_t*, int, | ||
xpu_malloc_t, void*); | ||
|
||
#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs" | ||
typedef int (*opCallMutateInputs_t)(mutateInputs_t, const char* const*, const char* const*, int, | ||
|
@@ -703,9 +714,9 @@ typedef int (*opCallCreateOpState_t)(createOpState_t, const char* const*, const | |
void**); | ||
|
||
#define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" | ||
typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, int, | ||
const int64_t**, int*, void**, int*, int, | ||
xpu_malloc_t, void*); | ||
typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, size_t*, | ||
int, const int64_t**, int*, void**, int*, size_t*, | ||
int, xpu_malloc_t, void*); | ||
|
||
#define MXLIB_INITIALIZE_STR "initialize" | ||
typedef int (*initialize_t)(int); | ||
|
@@ -876,9 +887,9 @@ extern "C" { | |
_opCallFCompute(fcomp_t fcomp, const char* const* keys, | ||
const char* const* vals, int num, | ||
const int64_t** inshapes, int* indims, | ||
void** indata, int* intypes, int num_in, | ||
void** indata, int* intypes, size_t* inIDs, int num_in, | ||
const int64_t** outshapes, int* outdims, | ||
void** outdata, int* outtypes, int num_out, | ||
void** outdata, int* outtypes, size_t* outIDs, int num_out, | ||
xpu_malloc_t cpu_malloc, void* cpu_alloc) { | ||
// create map of attributes from list | ||
std::map<std::string, std::string> attrs; | ||
|
@@ -889,8 +900,7 @@ extern "C" { | |
// create a vector of tensors for inputs | ||
std::vector<MXTensor> inputs(num_in); | ||
for (int i = 0; i < num_in; i++) { | ||
inputs[i].data_ptr = indata[i]; | ||
inputs[i].dtype = (MXDType)intypes[i]; | ||
inputs[i].update(indata[i], (MXDType)intypes[i], inIDs[i]); | ||
for (int j = 0; j < indims[i]; j++) { | ||
inputs[i].shape.push_back(inshapes[i][j]); | ||
} | ||
|
@@ -900,8 +910,7 @@ extern "C" { | |
// create a vector of tensors for outputs | ||
std::vector<MXTensor> outputs(num_out); | ||
for (int i = 0; i < num_out; i++) { | ||
outputs[i].data_ptr = outdata[i]; | ||
outputs[i].dtype = (MXDType) outtypes[i]; | ||
outputs[i].update(outdata[i], (MXDType)outtypes[i], outIDs[i]); | ||
for (int j = 0; j < outdims[i]; j++) { | ||
outputs[i].shape.push_back(outshapes[i][j]); | ||
} | ||
|
@@ -973,15 +982,14 @@ extern "C" { | |
#endif | ||
_opCallFStatefulCompute(bool is_forward, void* state_op, | ||
const int64_t** inshapes, int* indims, | ||
void** indata, int* intypes, int num_in, | ||
void** indata, int* intypes, size_t* inIDs, int num_in, | ||
const int64_t** outshapes, int* outdims, | ||
void** outdata, int* outtypes, int num_out, | ||
void** outdata, int* outtypes, size_t* outIDs, int num_out, | ||
xpu_malloc_t cpu_malloc, void* cpu_alloc) { | ||
// create a vector of tensors for inputs | ||
std::vector<MXTensor> inputs(num_in); | ||
for (int i = 0; i < num_in; i++) { | ||
inputs[i].data_ptr = indata[i]; | ||
inputs[i].dtype = (MXDType)intypes[i]; | ||
inputs[i].update(indata[i], (MXDType)intypes[i], inIDs[i]); | ||
for (int j = 0; j < indims[i]; j++) { | ||
inputs[i].shape.push_back(inshapes[i][j]); | ||
} | ||
|
@@ -991,8 +999,7 @@ extern "C" { | |
// create a vector of tensors for outputs | ||
std::vector<MXTensor> outputs(num_out); | ||
for (int i = 0; i < num_out; i++) { | ||
outputs[i].data_ptr = outdata[i]; | ||
outputs[i].dtype = (MXDType) outtypes[i]; | ||
outputs[i].update(outdata[i], (MXDType)outtypes[i], outIDs[i]); | ||
for (int j = 0; j < outdims[i]; j++) { | ||
outputs[i].shape.push_back(outshapes[i][j]); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it will be better to unify the naming across all places, like using verID in lib_api.h and here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done