Skip to content

Commit

Permalink
Fix a bug in AFT loss for uncensored labels
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Mar 17, 2020
1 parent d2c5c56 commit e33fab1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 28 deletions.
43 changes: 25 additions & 18 deletions src/common/survival_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,33 +141,36 @@ double AFTExtreme::HessPDF(double z) {
}


double AFTLoss::Loss(double y_lower, double y_higher, double y_pred, double sigma) {
double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma) {
double pdf;
double cdf_u, cdf_l, z_u, z_l;
double cost;
if (y_lower == y_higher) { // uncensored
z_l = (y_lower - y_pred) / sigma;

const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
if (y_lower == y_upper) { // uncensored
z_l = (log_y_lower - y_pred) / sigma;
pdf = dist_->PDF(z_l);
cost = -std::log(pdf / (sigma * y_lower));
} else { // censored; now check what type of censorship we have
if (std::isinf(y_higher)) { // right-censored
if (std::isinf(y_upper)) { // right-censored
cdf_u = 1;
} else { // left-censored or interval-censored
z_u = (y_higher - y_pred) / sigma;
z_u = (log_y_upper - y_pred) / sigma;
cdf_u = dist_->CDF(z_u);
}
if (std::isinf(y_lower)) { // left-censored
cdf_l = 0;
} else { // right-censored or interval-censored
z_l = (y_lower - y_pred) / sigma;
z_l = (log_y_lower - y_pred) / sigma;
cdf_l = dist_->CDF(z_l);
}
cost = -std::log(cdf_u - cdf_l);
}
return cost;
}

double AFTLoss::Gradient(double y_lower, double y_higher, double y_pred, double sigma) {
double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double sigma) {
double pdf_l;
double pdf_u;
double pdf;
Expand All @@ -180,25 +183,27 @@ double AFTLoss::Gradient(double y_lower, double y_higher, double y_pred, double
double gradient;
const double eps = 1e-12f;

if (y_lower == y_higher) { // uncensored
z = (y_lower - y_pred) / sigma;
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
if (y_lower == y_upper) { // uncensored
z = (log_y_lower - y_pred) / sigma;
pdf = dist_->PDF(z);
grad = dist_->GradPDF(z);
gradient = grad / (sigma * pdf);
} else { // censored; now check what type of censorship we have
if (std::isinf(y_higher)) { // right-censored
if (std::isinf(y_upper)) { // right-censored
pdf_u = 0;
cdf_u = 1;
} else { // interval-censored or left-censored
z_u = (y_higher - y_pred) / sigma;
z_u = (log_y_upper - y_pred) / sigma;
pdf_u = dist_->PDF(z_u);
cdf_u = dist_->CDF(z_u);
}
if (std::isinf(y_lower)) { // left-censored
pdf_l = 0;
cdf_l = 0;
} else { // interval-censored or right-censored
z_l = (y_lower - y_pred) / sigma;
z_l = (log_y_lower - y_pred) / sigma;
pdf_l = dist_->PDF(z_l);
cdf_l = dist_->CDF(z_l);
}
Expand All @@ -208,7 +213,7 @@ double AFTLoss::Gradient(double y_lower, double y_higher, double y_pred, double
return gradient;
}

double AFTLoss::Hessian(double y_lower, double y_higher, double y_pred, double sigma) {
double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) {
double z;
double z_u;
double z_l;
Expand All @@ -232,19 +237,21 @@ double AFTLoss::Hessian(double y_lower, double y_higher, double y_pred, double s
double hess_dist;
const double eps = 1e-12f;

if (y_lower == y_higher) { // uncensored
z = (y_lower - y_pred) / sigma;
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
if (y_lower == y_upper) { // uncensored
z = (log_y_lower - y_pred) / sigma;
pdf = dist_->PDF(z);
grad = dist_->GradPDF(z);
hess_dist = dist_->HessPDF(z);
hessian = -(pdf * hess_dist - std::pow(grad, 2)) / (std::pow(sigma, 2) * std::pow(pdf, 2));
} else { // censored; now check what type of censorship we have
if (std::isinf(y_higher)) { // right-censored
if (std::isinf(y_upper)) { // right-censored
pdf_u = 0;
cdf_u = 1;
grad_u = 0;
} else { // interval-censored or left-censored
z_u = (y_higher - y_pred) / sigma;
z_u = (log_y_upper - y_pred) / sigma;
pdf_u = dist_->PDF(z_u);
cdf_u = dist_->CDF(z_u);
grad_u = dist_->GradPDF(z_u);
Expand All @@ -254,7 +261,7 @@ double AFTLoss::Hessian(double y_lower, double y_higher, double y_pred, double s
cdf_l = 0;
grad_l = 0;
} else { // interval-censored or right-censored
z_l = (y_lower - y_pred) / sigma;
z_l = (log_y_lower - y_pred) / sigma;
pdf_l = dist_->PDF(z_l);
cdf_l = dist_->CDF(z_l);
grad_l = dist_->GradPDF(z_l);
Expand Down
6 changes: 3 additions & 3 deletions src/common/survival_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ class AFTLoss {
}

public:
double Loss(double y_lower, double y_higher, double y_pred, double sigma);
double Gradient(double y_lower, double y_higher, double y_pred, double sigma);
double Hessian(double y_lower, double y_higher, double y_pred, double sigma);
double Loss(double y_lower, double y_upper, double y_pred, double sigma);
double Gradient(double y_lower, double y_upper, double y_pred, double sigma);
double Hessian(double y_lower, double y_upper, double y_pred, double sigma);
};

} // namespace common
Expand Down
2 changes: 1 addition & 1 deletion src/metric/survival_metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct EvalAFT : public Metric {
for (omp_ulong i = 0; i < nsize; ++i) {
// If weights are empty, data is unweighted so we use 1.0 everywhere
double w = is_null_weight ? 1.0 : weights[i];
double loss = loss_->Loss(std::log(y_lower[i]), std::log(y_higher[i]),
double loss = loss_->Loss(y_lower[i], y_higher[i],
yhat[i], param_.aft_loss_distribution_scale);
nloglik_sum += loss;
weight_sum += w;
Expand Down
11 changes: 5 additions & 6 deletions src/objective/aft_obj.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,16 @@ class AFTObj : public ObjFunction {
<< "yhat is too big";
const omp_ulong nsize = static_cast<omp_ulong>(yhat.size());
double first_order_grad;
double second_order_grad;

#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < nsize; ++i) {
// If weights are empty, data is unweighted so we use 1.0 everywhere
double w = is_null_weight ? 1.0 : weights[i];
first_order_grad = loss_->Gradient(std::log(y_lower[i]), std::log(y_higher[i]),
const double w = is_null_weight ? 1.0 : weights[i];
const double grad = loss_->Gradient(y_lower[i], y_higher[i],
yhat[i], param_.aft_loss_distribution_scale);
const double hess = loss_->Hessian(y_lower[i], y_higher[i],
yhat[i], param_.aft_loss_distribution_scale);
second_order_grad = loss_->Hessian(std::log(y_lower[i]), std::log(y_higher[i]),
yhat[i], param_.aft_loss_distribution_scale);
gpair[i] = GradientPair(first_order_grad * w, second_order_grad * w);
gpair[i] = GradientPair(grad * w, hess * w);
}
}

Expand Down

0 comments on commit e33fab1

Please sign in to comment.