Skip to content

Commit

Permalink
- decouple the cpu and gpu implementation of metrics
Browse files Browse the repository at this point in the history
  - this is to avoid including the cuda file into the cc file, thus cluttering the code
    with ifdef gimmicks
  - this approach can be used when the 2 implementations digresses fairly considerably
  - it is possible for such an approach to be used elsewhere (objectives and such) where
    the implementation digresses considerably
  - there is no impact to performance
  • Loading branch information
sriramch committed Feb 26, 2020
1 parent 8054f72 commit 8509685
Show file tree
Hide file tree
Showing 4 changed files with 444 additions and 352 deletions.
50 changes: 40 additions & 10 deletions src/metric/metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
* \brief Registry of objective functions.
*/
#include <dmlc/registry.h>

#include <xgboost/metric.h>
#include <xgboost/generic_parameters.h>

namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
}
#include "metric_common.h"

namespace xgboost {
Metric* Metric::Create(const std::string& name, GenericParameter const* tparam) {
template <typename MetricRegistry>
Metric* CreateMetricImpl(const std::string& name, GenericParameter const* tparam) {
std::string buf = name;
std::string prefix = name;
const char* param;
Expand All @@ -26,26 +26,56 @@ Metric* Metric::Create(const std::string& name, GenericParameter const* tparam)
prefix = buf;
param = nullptr;
}
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str());
auto *e = ::dmlc::Registry<MetricRegistry>::Get()->Find(prefix.c_str());
if (e == nullptr) {
LOG(FATAL) << "Unknown metric function " << name;
return nullptr;
}
auto p_metric = (e->body)(param);
p_metric->tparam_ = tparam;
return p_metric;
} else {
std::string prefix = buf.substr(0, pos);
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str());
auto *e = ::dmlc::Registry<MetricRegistry>::Get()->Find(prefix.c_str());
if (e == nullptr) {
LOG(FATAL) << "Unknown metric function " << name;
return nullptr;
}
auto p_metric = (e->body)(buf.substr(pos + 1, buf.length()).c_str());
p_metric->tparam_ = tparam;
return p_metric;
}
}

Metric *
Metric::Create(const std::string& name, GenericParameter const* tparam) {
auto metric = CreateMetricImpl<MetricReg>(name, tparam);
if (metric == nullptr) {
LOG(FATAL) << "Unknown metric function " << name;
}

metric->tparam_ = tparam;
return metric;
}

Metric *
GPUMetric::CreateGPUMetric(const std::string& name, GenericParameter const* tparam) {
auto metric = CreateMetricImpl<MetricGPUReg>(name, tparam);
if (metric == nullptr) {
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
<< ". Resorting to the CPU builder";
return metric;
}

// Narrowing reference only for the compiler to allow assignment to a base class member.
// As such, using this narrowed reference to refer to derived members will be an illegal op.
// This is moot, as this type is stateless.
static_cast<GPUMetric *>(metric)->tparam_ = tparam;
return metric;
}
} // namespace xgboost

namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
DMLC_REGISTRY_ENABLE(::xgboost::MetricGPUReg);
}

namespace xgboost {
namespace metric {
// List of files that will be force linked in static links.
Expand Down
45 changes: 45 additions & 0 deletions src/metric/metric_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,53 @@ using PredIndPairContainer = std::vector<PredIndPair>;
} // anonymous namespace

namespace xgboost {
// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not
// present already. This is created when there is a device ordinal present and if xgboost
// is compiled with CUDA support
struct GPUMetric : Metric {
static Metric *CreateGPUMetric(const std::string& name, GenericParameter const* tparam);
};

/*!
* \brief Internal registry entries for GPU Metric factory functions.
* The additional parameter const char* param gives the value after @, can be null.
* For example, metric map@3, then: param == "3".
*/
struct MetricGPUReg
: public dmlc::FunctionRegEntryBase<MetricGPUReg,
std::function<Metric * (const char*)> > {
};

/*!
* \brief Macro to register metric computed on GPU.
*
* \code
* // example of registering a objective ndcg@k
* XGBOOST_REGISTER_GPU_METRIC(NDCG_GPU, "ndcg")
* .describe("NDCG metric computer on GPU.")
* .set_body([](const char* param) {
* int at_k = atoi(param);
* return new NDCG(at_k);
* });
* \endcode
*/

// Note: Metric names registered in the GPU registry should follow this convention:
// - GPU metric types should be registered with the same name as the non GPU metric types
#define XGBOOST_REGISTER_GPU_METRIC(UniqueId, Name) \
::xgboost::MetricGPUReg& __make_ ## MetricGPUReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::MetricGPUReg>::Get()->__REGISTER__(Name)

namespace metric {

// Ranking config to be used on device and host
struct EvalRankConfig {
public:
unsigned topn{std::numeric_limits<unsigned>::max()};
std::string name;
bool minus{false};
};

class PackedReduceResult {
double residue_sum_;
double weights_sum_;
Expand Down
Loading

0 comments on commit 8509685

Please sign in to comment.