Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
force constructor to has init values
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold committed Apr 16, 2020
1 parent 601bab4 commit 8eb3add
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 90 deletions.
14 changes: 4 additions & 10 deletions include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ struct DataIteratorReg
*/
class Dataset {
public:
/*!
* \brief Initialize the Operator by setting the parameters
* This function need to be called before all other functions.
* \param kwargs the keyword arguments parameters
*/
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
/*!
* \brief Get the size of the dataset
*/
Expand All @@ -154,7 +148,8 @@ class Dataset {
}; // class Dataset

/*! \brief typedef the factory function of dataset */
typedef std::function<Dataset *()> DatasetFactory;
typedef std::function<Dataset *(
const std::vector<std::pair<std::string, std::string> >&)> DatasetFactory;
/*!
* \brief Registry entry for Dataset factory functions.
*/
Expand Down Expand Up @@ -184,8 +179,6 @@ class BatchifyFunction {
public:
/*! \brief Destructor */
virtual ~BatchifyFunction(void) {}
/*! \brief Init */
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
/*! \brief The batchify logic */
virtual bool Batchify(const std::vector<std::vector<NDArray> >& inputs,
std::vector<NDArray>* outputs) = 0;
Expand All @@ -194,7 +187,8 @@ class BatchifyFunction {
using BatchifyFunctionPtr = std::shared_ptr<BatchifyFunction>;

/*! \brief typedef the factory function of data sampler */
typedef std::function<BatchifyFunction *()> BatchifyFunctionFactory;
typedef std::function<BatchifyFunction *(
const std::vector<std::pair<std::string, std::string> >&)> BatchifyFunctionFactory;
/*!
* \brief Registry entry for DataSampler factory functions.
*/
Expand Down
6 changes: 2 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1984,12 +1984,11 @@ int MXDatasetCreateDataset(DatasetCreator handle,
Dataset *dataset = nullptr;
API_BEGIN();
DatasetReg *e = static_cast<DatasetReg *>(handle);
dataset = e->body();
std::vector<std::pair<std::string, std::string> > kwargs;
for (uint32_t i = 0; i < num_param; ++i) {
kwargs.push_back({std::string(keys[i]), std::string(vals[i])});
}
dataset->Init(kwargs);
dataset = e->body(kwargs);
*out = new std::shared_ptr<Dataset>(dataset);
API_END_HANDLE_ERROR(delete dataset);
}
Expand Down Expand Up @@ -2073,12 +2072,11 @@ int MXBatchifyFunctionCreateFunction(BatchifyFunctionCreator handle,
BatchifyFunction *bf = nullptr;
API_BEGIN();
BatchifyFunctionReg *e = static_cast<BatchifyFunctionReg *>(handle);
bf = e->body();
std::vector<std::pair<std::string, std::string> > kwargs;
for (uint32_t i = 0; i < num_param; ++i) {
kwargs.push_back({std::string(keys[i]), std::string(vals[i])});
}
bf->Init(kwargs);
bf = e->body(kwargs);
*out = new BatchifyFunctionPtr(bf);
API_END_HANDLE_ERROR(delete bf);
}
Expand Down
18 changes: 9 additions & 9 deletions src/io/batchify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ DMLC_REGISTER_PARAMETER(GroupBatchifyParam);

class GroupBatchify : public BatchifyFunction {
public:
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
explicit GroupBatchify(const std::vector<std::pair<std::string, std::string> >& kwargs) {
param_.InitAllowUnknown(kwargs);
fs_.reserve(param_.functions.ndim());
for (int i = 0; i < param_.functions.ndim(); ++i) {
Expand Down Expand Up @@ -92,8 +92,8 @@ MXNET_REGISTER_IO_BATCHIFY_FUNCTION(GroupBatchify)
.describe(R"code(Returns the GroupBatchify function.
)code" ADD_FILELINE)
.add_arguments(GroupBatchifyParam::__FIELDS__())
.set_body([]() {
return new GroupBatchify();
.set_body([](const std::vector<std::pair<std::string, std::string> >& kwargs) {
return new GroupBatchify(kwargs);
});

struct StackBatchifyParam : public dmlc::Parameter<StackBatchifyParam> {
Expand All @@ -110,7 +110,7 @@ DMLC_REGISTER_PARAMETER(StackBatchifyParam);

class StackBatchify : public BatchifyFunction {
public:
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
explicit StackBatchify(const std::vector<std::pair<std::string, std::string> >& kwargs) {
param_.InitAllowUnknown(kwargs);
}

Expand Down Expand Up @@ -190,8 +190,8 @@ MXNET_REGISTER_IO_BATCHIFY_FUNCTION(StackBatchify)
.describe(R"code(Returns the StackBatchify function.
)code" ADD_FILELINE)
.add_arguments(StackBatchifyParam::__FIELDS__())
.set_body([]() {
return new StackBatchify();
.set_body([](const std::vector<std::pair<std::string, std::string> >& kwargs) {
return new StackBatchify(kwargs);
});

struct PadBatchifyParam : public dmlc::Parameter<PadBatchifyParam> {
Expand All @@ -216,7 +216,7 @@ DMLC_REGISTER_PARAMETER(PadBatchifyParam);

class PadBatchify : public BatchifyFunction {
public:
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
explicit PadBatchify(const std::vector<std::pair<std::string, std::string> >& kwargs) {
param_.InitAllowUnknown(kwargs);
}

Expand Down Expand Up @@ -390,8 +390,8 @@ MXNET_REGISTER_IO_BATCHIFY_FUNCTION(PadBatchify)
.describe(R"code(Returns the StackBatchify function.
)code" ADD_FILELINE)
.add_arguments(PadBatchifyParam::__FIELDS__())
.set_body([]() {
return new PadBatchify();
.set_body([](const std::vector<std::pair<std::string, std::string> >& kwargs) {
return new PadBatchify(kwargs);
});
} // namespace io
} // namespace mxnet
Loading

0 comments on commit 8eb3add

Please sign in to comment.