Unstructured (auto)correlation terms #1435

Merged
merged 24 commits into from
Dec 14, 2022

Conversation

paul-buerkner
Owner

@paul-buerkner paul-buerkner commented Dec 7, 2022

This PR implements unstructured autocorrelation matrix terms as discussed with @wds15. May I ask you to check out and test this new feature?

Here is a simple example:

library(brms)
fit <- brm(count ~ Trt + unstr(visit, patient), data = epilepsy)
summary(fit)

@wds15
Contributor

wds15 commented Dec 7, 2022

Great to see this feature landing in brms... will test these days and let you know my findings.

@wds15
Contributor

wds15 commented Dec 8, 2022

Working all well so far. I tried homogeneous and also figured how to get to the heterogeneous version. Nice. One fail I found is with threading:

image

I am still trying out things.

@wds15
Contributor

wds15 commented Dec 8, 2022

The unstructured works for factor levels... but it seems as if the drop_unused_levels=FALSE option is ignored. See the example where the fixed effects are correctly being sampled for a factor level not in the data (but defined as part of the factor) whereas the unstructured term is missing the respective visit:

image

@wds15
Contributor

wds15 commented Dec 8, 2022

One small implementation note: You seem to pass in index spans over the data with begin_tg and end_tg. These indicate the start / stop indexes in the data set per patient. This works fine as is... it's just that I switched to a more compact representation of such a thing in my own code. So instead of doing

begin: 1, 5, 7
end: 4, 6, 9

for coding 3 groups (1-4, 5-6, 7-9) you can have a single slice variable defined as:

slice: 1, 5, 7, 10

then you index group 1 with slice[1] : slice[2]-1 and group i is indexed with slice[i] : slice[i + 1]-1... so slice is an array of length n+1 if you have n groups.

There is no real gain from doing so other than nicer data structures and a more compact representation.
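
As a rough illustration of the slice idea (a Python sketch, not brms or Stan code; `to_slice` and `group_range` are made-up helper names), the two index arrays collapse into one array of length n + 1, and each group's rows are recovered from neighbouring entries:

```python
def to_slice(begin, end):
    """Collapse 1-based inclusive begin/end arrays into one slice array of length n + 1."""
    for prev_end, next_begin in zip(end, begin[1:]):
        if next_begin != prev_end + 1:
            raise ValueError("groups must be sorted and contiguous")
    # The last entry points one past the final row of the last group.
    return list(begin) + [end[-1] + 1]


def group_range(slice_, i):
    """1-based inclusive (first, last) rows of group i, i.e. slice[i] : slice[i + 1] - 1."""
    return slice_[i - 1], slice_[i] - 1


slice_ = to_slice([1, 5, 7], [4, 6, 9])
ranges = [group_range(slice_, i) for i in range(1, len(slice_))]
print(slice_)   # [1, 5, 7, 10]
print(ranges)   # [(1, 4), (5, 6), (7, 9)]
```

The contiguity check is the one assumption this representation needs: it only works when the data are sorted so that the groups follow each other without gaps.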

@paul-buerkner
Owner Author

thanks for checking it out! Threading is not supposed to work at the moment, and I have added an informative error message. drop_unused_levels now works as expected.

I agree the slice representation would be cleaner. I will add a todo for this and adjust when I have time.

@wds15
Contributor

wds15 commented Dec 8, 2022

Performancewise it really hurts to not vectorise the calls. That is tedious to do in this case as you'd need to treat each pattern of data in a different block (one case for all being present and one for a certain missingness pattern)... but it would be very worth the effort of coding.

As it's written right now we get quite large AD tapes... reduce_sum would split these into smaller chunks leading to better memory locality. So reduce_sum itself should lead to speed-ups is what I'd expect here... but it's not trivial with the index wrangling in this case, of course.

@wds15
Contributor

wds15 commented Dec 8, 2022

and in a few places the comment is not correct:

  /* multi-student-t log-PDF for time-series covariance structures
   * in Cholesky parameterization and assuming heterogenous variances
   * Args:
   *   y: response vector
   *   nu: degrees of freedom parameter
   *   mu: mean parameter vector
   *   sigma: scale parameter
   *   chol_cor: cholesky factor of the correlation matrix
   *   se2: square of user defined standard errors
   *     should be set to zero if none are defined
   *   nobs: number of observations in each group
   *   begin: the first observation in each group
   *   end: the last observation in each group
   * Returns:
   *   sum of the log-PDF values of all observations
   */
  real student_t_time_hom_lpdf(vector y, real nu, vector mu, real sigma,
                               matrix chol_cor, data vector se2, int[] nobs,
                               int[] begin, int[] end) {

here the homogeneous function is referred to as heterogeneous.

@paul-buerkner
Owner Author

Thanks! I agree reduce_sum could be useful here, but I don't have the time to work on this in the foreseeable future, given the complexity of this feature, as you say.

I have fixed the doc typos. Are there more things you want to check or try out before I merge this PR?

@wds15
Contributor

wds15 commented Dec 8, 2022

Björn wanted to run some tests these days. From my end I am good for now... looking at the Stan code there are some things which I may try to make things faster, but that can be done later. Maybe I even get my head into the reduce_sum stuff, but that's to be seen, as the brms code for that isn't too simple, as you know.

Unless you want to merge quickly, maybe have Björn test it once more (he intends to couple this with a non-linear model).

@paul-buerkner
Owner Author

Perfect, sounds good. I will wait for Björn's feedback then.

@wds15
Contributor

wds15 commented Dec 8, 2022

Uh... now I see why reduce_sum will be a mess: You cannot split the data at any data rows. The splits must be aligned with the groupings/visits. Yack.

@paul-buerkner
Owner Author

paul-buerkner commented Dec 8, 2022 via email

@wds15
Contributor

wds15 commented Dec 8, 2022

Using the idea of re-using the cholesky factor when it is possible speeds up on my mini example things by a factor of 2:

  int is_equal(int[] a, int[] b) {
    int n_a = size(a);
    int n_b = size(a);
    if(n_a != n_b) return 0;
    for(i in 1:n_a) {
      if(a[i] != b[i])
        return 0;
    }
    return 1;
  }
  real normal_time_hom_flex_lpdf(vector y, vector mu, real sigma, matrix chol_cor,
                                 data vector se2, int[] nobs, int[] begin,
                                 int[] end, int[,] Jtime) {
    int I = size(nobs);
    int has_se = max(se2) > 0;
    vector[I] lp;
    int have_lp[I] = rep_array(0, I);
    matrix[rows(chol_cor), cols(chol_cor)] Cov = sigma^2 * multiply_lower_tri_self_transpose(chol_cor);
    int i = 1;
    while(sum(have_lp) != I) {
      int iobs[nobs[i]] = Jtime[i, 1:nobs[i]];
      matrix[nobs[i], nobs[i]] L;
      if (has_se) {
        // need to add 'se' to the covariance matrix itself
        L = cholesky_decompose(add_diag(Cov[iobs, iobs], se2[begin[i]:end[i]]));
      } else {
        L = diag_pre_multiply( rep_vector(sigma, nobs[i]), chol_cor[iobs, iobs]);
      }
      lp[i] = multi_normal_cholesky_lpdf(y[begin[i]:end[i]] | mu[begin[i]:end[i]], L);
      have_lp[i] = 1;
      // find all additional cases where we have the same visit pattern
      for(j in i:I) {
        if(is_equal(Jtime[j], Jtime[i]) == 1 && have_lp[j] == 0) {
          have_lp[j] = 1;
          if (has_se) {
            // se2 differs per group, so the factor is recomputed with group j's se
            L = cholesky_decompose(add_diag(Cov[iobs, iobs], se2[begin[j]:end[j]]));
          }
          lp[j] = multi_normal_cholesky_lpdf(y[begin[j]:end[j]] | mu[begin[j]:end[j]], L);
        }
      }
      while(have_lp[i] == 1 && i != I) i += 1;
    }
    return sum(lp);
  }

I hope this code covers the full generality of things. The speedup of 2 is obviously for the case when not using a fixed se... as then you have to do the cholesky decomposition every data row...
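
To make the bookkeeping behind this reuse explicit, here is a small Python sketch (not the Stan code above; `groups_by_pattern` is a made-up name, and the zero padding of the Jtime rows is an assumption about the data layout). It collects all groups that share a visit pattern, so that one Cholesky factor per pattern is enough when there are no se terms:

```python
def groups_by_pattern(jtime, nobs):
    """Map each distinct visit pattern to the list of groups (0-based) that share it."""
    patterns = {}
    for g, (row, n) in enumerate(zip(jtime, nobs)):
        # Only the first nobs[g] entries of a row are real time indices.
        patterns.setdefault(tuple(row[:n]), []).append(g)
    return patterns


# Hypothetical index data: five patients, up to three visits, rows padded with 0.
jtime = [[1, 2, 3], [1, 3, 0], [1, 2, 3], [1, 3, 0], [2, 3, 0]]
nobs = [3, 2, 3, 2, 2]

patterns = groups_by_pattern(jtime, nobs)
n_factorizations = len(patterns)  # one factor per pattern instead of one per patient
print(patterns)
print(n_factorizations, "factorizations for", len(nobs), "groups")
```

With few distinct missingness patterns and many patients, the number of factorizations stays small no matter how many groups are in the data.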

@wds15
Contributor

wds15 commented Dec 8, 2022

The above logic can likely go to the other cases as well. And add_diag should be used instead of adding diagonal matrices. So instead of

      if (has_se) {
        // need to add 'se' to the covariance matrix itself
        L = multiply_lower_tri_self_transpose(L);
        L += diag_matrix(se2[begin[i]:end[i]]);
        L = cholesky_decompose(L);
      }

it's better to do

      if (has_se) {
        // need to add 'se' to the covariance matrix itself
        L = add_diag(multiply_lower_tri_self_transpose(L), se2[begin[i]:end[i]]);
        L = cholesky_decompose(L);
      }

This way you avoid putting all those 0's onto the AD stack... which is actually not too much of a problem here, since the se's are data. Still, better to get used to add_diag, which is in stan since 2.21... which is ok?

@wds15
Contributor

wds15 commented Dec 12, 2022

On the possibility to parallelise this I did a bit of investigation and here is what I think could work in generality even in brms hopefully. The key issue is that we have to break down things by groups as defined in the data. However, the current pdf functions already only access the (sorted) data via a few index containers. Hence, we can simply subset the index containers and then reuse the existing lpdf and all will work out just fine. I did not know how to modify the partial sum function generated from brms, which is why I introduced in the model block a parallel flavour of the lpdf like so:

  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    mu += Intercept + Xc * b;
    target += normal_time_hom_flex_parallel_lpdf(Y | mu, sigma, Lcortime, se2, nobs_tg, begin_tg, end_tg, Jtime_tg);
  }

Then I added these two functions:

  real normal_time_hom_flex_partial_lpmf(int[] seq, int start, int end,
                                         data vector y, vector mu, real sigma, matrix chol_cor,
                                         data vector se2, int[] nobs_tg, int[] begin_tg,
                                         int[] end_tg, int[,] Jtime_tg) {
    // subset the indexing things to the groups we take the partial
    // sum over
    return normal_time_hom_flex_lpdf(y| mu, sigma, chol_cor, se2, nobs_tg[seq], begin_tg[seq], end_tg[seq], Jtime_tg[seq]);
  }
  real normal_time_hom_flex_parallel_lpdf(data vector y, vector mu, real sigma, matrix chol_cor,
                                          data vector se2, int[] nobs_tg, int[] begin_tg,
                                          int[] end_tg, int[,] Jtime_tg) {
    int Ng = size(nobs_tg);
    int seq[Ng] = sequence(1, Ng);
    return reduce_sum(normal_time_hom_flex_partial_lpmf, seq, 50, y, mu, sigma, chol_cor, se2, nobs_tg, begin_tg, end_tg, Jtime_tg);
  }

Note that the normal_time_hom_flex_lpdf is not changed at all. The trick is really to pass it the full data while just subsetting the index containers which code how the data is being accessed. From running an example model, I do get the same posterior back as compared to the serial version. For now I also hard-coded the grain size of 50 in here... but ok.
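
The same trick in a few lines of Python, purely as an illustration (this is not how reduce_sum is implemented; `grouped_lpdf` is a stand-in with independent unit-variance normals, and all names are invented for the sketch). The partial sum always receives the full data and only the index containers are subsetted, so splits can only ever fall between groups:

```python
import math


def grouped_lpdf(y, mu, begin, end):
    """Stand-in for the serial lpdf: touches the data only via begin/end (1-based, inclusive)."""
    total = 0.0
    for b, e in zip(begin, end):
        for k in range(b - 1, e):
            total += -0.5 * math.log(2 * math.pi) - 0.5 * (y[k] - mu[k]) ** 2
    return total


def partial_lpdf(seq, y, mu, begin, end):
    """Partial sum over the groups in seq: full data, subsetted index containers."""
    return grouped_lpdf(y, mu, [begin[g] for g in seq], [end[g] for g in seq])


def chunked(seq, grainsize):
    """Split the group indices into consecutive chunks of at most grainsize groups."""
    return [seq[k:k + grainsize] for k in range(0, len(seq), grainsize)]


y = [0.3, -1.2, 0.8, 2.1, 0.0, -0.4, 1.5, 0.9, -0.7]
mu = [0.0] * len(y)
begin, end = [1, 5, 7], [4, 6, 9]
seq = list(range(len(begin)))

serial = grouped_lpdf(y, mu, begin, end)
parallel = sum(partial_lpdf(chunk, y, mu, begin, end) for chunk in chunked(seq, 2))
print(abs(serial - parallel) < 1e-12)
```

Because the chunks are chunks of group indices rather than of data rows, no group is ever split across two partial sums.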

The above approach could possibly even be integrated into brms by way of defining additional parallel variants of the respective unstructured covariance structures lpdfs... which one could likely auto-generate.

Attached is the full Stan model and an example data set for it as json:

brms-mmrm-dev-rs-special_json.txt
brms-mmrm-dev-rs-special_stan.txt

... and I must say that the performance seems to really benefit from this. On this very small example, 2 cores almost halve the wall time spent.

If not in this PR, then maybe this should be filed as an issue for the future?

@wds15
Contributor

wds15 commented Dec 12, 2022

And one q from Björn was if it is possible to have the covariance matrix vary by treatment group? So something like:

library(brms)
fit <- brm(count ~ Trt + unstr(visit | Trt, patient), data = epilepsy)
summary(fit)

I can see how I can vary the sigma by treatment group and possibly even by visit, but not the correlation matrix.

matrix[nobs[i], nobs[i]] L;
L = sigma * chol_cor[1:nobs[i], 1:nobs[i]];
// need to add 'se' to the covariance matrix itself
L = multiply_lower_tri_self_transpose(L);
Contributor

In this homogeneous case you can pull the sigma * L and the multiply_lower_tri_self_transpose out of the loop. This is always the same. All that is needed per data group is to subset the resulting covariance matrix to what is observed.

Cor = multiply_lower_tri_self_transpose(chol_cor);
for (i in 1:I) {
int iobs[nobs[i]] = Jtime[i, 1:nobs[i]];
matrix[nobs[i], nobs[i]] Cov = sigma^2 * Cor[iobs, iobs];
Contributor

multiplication with sigma^2 can happen outside of the loop

}
return err;
}
matrix[nobs[i], nobs[i]] L = cholesky_decompose(Cor[iobs, iobs]);
Contributor

uhh..hurts to read this... maybe we can have an if statement for identifying the case whenever iobs[i] == 1,2,3,...,rows(chol_cor)? In that case you can just take L=chol_cor and can avoid a full cholesky_decompose, which is super expensive and this is probably even a rather common case?

@wds15
Contributor

wds15 commented Dec 13, 2022

Stan code looks good. I left a few more small optimisations as comments. I also opened two more issues for following up on the points we came across in this PR (parallel & by group error structures).

Re-using the cholesky makes my mini examples run substantially faster already. 12s instead of 20s, which is great news for real big problems.

Given that we now process multiple data items at once in the cholesky reuse thing, one could imagine even going for a vectorised call... but I don't see immediately how to code that efficiently (you need to munge data, which I can't see how to do easily off the top of my head). In case there is a good way, we can come back to this.

@wds15
Contributor

wds15 commented Dec 13, 2022

Alright... right after sending off my last comment, I had the idea how we can make use of a fully vectorized call across all grouped observations. The speedup is another 2x for me! So it's really worth doing this.

The problem is that the data is not quite formatted as we need it. While brms has things stored in a flat column vector, the vectorized versions of the multi normal distributions require that we have things as an array of vectors. To reformat things efficiently one has to be very careful so that data stays as data to Stan. The way to achieve that is to do the reformatting in a separate function. This way the separate function gets called with data-only arguments and then we can do the data munging "for free" in case of the data. Here is one of the functions recoded this way:

  /* grouped data stored linearly in "long_data", as indexed by begin and end,
   * is repacked into an array of vectors (one per group with stack[i] == 1).
   */
  vector[] stack_vectors(vector long_data, int n, int[] stack, int[] begin, int[] end) {
    int S = sum(stack);
    int G = size(stack);
    vector[n] stacked[S];
    int j = 1;
    for(i in 1:G) {
      if(stack[i] == 1) {
        stacked[j] = long_data[begin[i]:end[i]];
        j += 1;
      }
    }
    return stacked;
  }
  
  /* multi-normal log-PDF for time-series covariance structures
   * in Cholesky parameterization and assuming homogeneous variances
   * allows for flexible correlation matrix subsets
   * Deviating Args:
   *   Jtime: array of time indices per group
   * Returns:
   *   sum of the log-PDF values of all observations
   */
  real normal_time_hom_flex_lpdf(vector y, vector mu, real sigma, matrix chol_cor,
                                 int[] nobs, int[] begin, int[] end, int[,] Jtime) {
    int I = size(nobs);
    //vector[I] lp;
    real lp = 0.0;
    int has_lp[I] = rep_array(0, I);
    int i = 1;
    while (sum(has_lp) != I) {
      int iobs[nobs[i]] = Jtime[i, 1:nobs[i]];
      int lp_terms[I-i+1] = rep_array(0, I-i+1);
      matrix[nobs[i], nobs[i]] L = diag_pre_multiply(rep_vector(sigma, nobs[i]), chol_cor[iobs, iobs]);
      //lp[i] = multi_normal_cholesky_lpdf(y[begin[i]:end[i]] | mu[begin[i]:end[i]], L);
      has_lp[i] = 1;
      lp_terms[1] = 1;
      // find all additional groups where we have the same timepoints
      for (j in (i+1):I) {
        if (has_lp[j] == 0 && is_equal(Jtime[j], Jtime[i]) == 1) {
          //lp[j] = multi_normal_cholesky_lpdf(y[begin[j]:end[j]] | mu[begin[j]:end[j]], L);
          has_lp[j] = 1;
          lp_terms[j-i+1] = 1;
        }
      }
      lp += multi_normal_cholesky_lpdf(stack_vectors(y, nobs[i], lp_terms, begin[i:I], end[i:I]) | stack_vectors(mu, nobs[i], lp_terms, begin[i:I], end[i:I]), L);
      while (has_lp[i] == 1 && i != I) {
        i += 1;
      }
    }
    return lp;
  }

The above logic can be applied to all other functions... and again - the simulated fake example speeds up by a whopping factor of 2x.

Edit: This trick certainly only applies to the homogeneous sigma cases.
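
For readers who want to see the repacking step in isolation, this is roughly what stack_vectors does, written as a Python sketch (an illustration only, with made-up example data; the argument n of the Stan version is dropped because Python lists carry their own length):

```python
def stack_vectors(long_data, stack, begin, end):
    """Collect the flagged groups of a flat vector into a list of equal-length vectors.

    begin/end are 1-based inclusive row indices; stack[i] == 1 selects group i.
    """
    return [long_data[b - 1:e] for flag, b, e in zip(stack, begin, end) if flag == 1]


# Hypothetical flat data: four groups, of which groups 1, 3 and 4 share one pattern.
long_data = list(range(1, 12))
begin, end = [1, 4, 6, 9], [3, 5, 8, 11]
stack = [1, 0, 1, 1]

stacked = stack_vectors(long_data, stack, begin, end)
print(stacked)  # [[1, 2, 3], [6, 7, 8], [9, 10, 11]]
```

All selected groups must have the same length, since they are passed together with a single Cholesky factor to one vectorised density call.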

@paul-buerkner
Owner Author

paul-buerkner commented Dec 14, 2022

Thanks for the code review and the vectorization code! I have now made another pass over all the Stan code, made it more consistent, and further optimized things as per your suggestions. @wds15 can you do another round of code review please?

One thing we need to be very careful of is not to take arbitrary subsets of cholesky factors of correlation matrices. This is only valid if we subset the first J rows:

Cor <- matrix(c(1, 0.5, 0.3, 0.5, 1, 0.1, 0.3, 0.1, 1), 3, 3)
Cor

L <- t(chol(Cor))
L %*% t(L)  # = Cor

L_i <- L[1:2, 1:2]
L_i %*% t(L_i)  # = Cor[1:2, 1:2]

L_j <- L[c(2, 3), c(2, 3)]
L_j %*% t(L_j)  # != Cor[c(2, 3), c(2, 3)]

That is, patterns such as chol_cor[iobs, iobs] are not valid in general.
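
One way to act on this rule, sketched in Python with plain lists (not the code in this PR; `is_leading_block` and `chol_for_pattern` are hypothetical helpers): reuse the existing factor only when the observed time points are exactly 1, ..., J, and otherwise subset the correlation matrix first and factor that.

```python
import math


def cholesky(a):
    """Lower-triangular Cholesky factor of a symmetric positive definite matrix."""
    n = len(a)
    l = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(l[i][k] * l[j][k] for k in range(j))
            l[i][j] = math.sqrt(s) if i == j else s / l[j][j]
    return l


def subset(m, idx):
    """Rows and columns idx (0-based) of a square matrix."""
    return [[m[i][j] for j in idx] for i in idx]


def is_leading_block(iobs):
    """True if the 1-based time indices are exactly 1, 2, ..., J."""
    return list(iobs) == list(range(1, len(iobs) + 1))


def chol_for_pattern(cor, chol_cor, iobs):
    idx = [t - 1 for t in iobs]
    if is_leading_block(iobs):
        return subset(chol_cor, idx)   # valid: leading J x J block of the factor
    return cholesky(subset(cor, idx))  # otherwise re-factor the subsetted matrix


cor = [[1.0, 0.5, 0.3], [0.5, 1.0, 0.1], [0.3, 0.1, 1.0]]
chol_cor = cholesky(cor)
reused = chol_for_pattern(cor, chol_cor, [1, 2])
refactored = chol_for_pattern(cor, chol_cor, [2, 3])
print(reused)
print(refactored)
```

The fallback branch is the expensive one, so how much this helps depends on how often the leading-block pattern occurs in the data.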

@wds15
Contributor

wds15 commented Dec 14, 2022

Thanks for including these latest bits as well...sorry for throwing these things one at a time at you...but this thing develops. Will check the stan code now.

@paul-buerkner
Owner Author

paul-buerkner commented Dec 14, 2022

No worries, we are continuously making the implementation better so that makes me happy :-)

@wds15
Contributor

wds15 commented Dec 14, 2022

There was a mistake in my is_equal... here it's corrected:

int is_equal(int[] a, int[] b) {
    int n_a = size(a);
    // int n_b = size(a); // WRONG
    int n_b = size(b);
    if(n_a != n_b) return 0;
    for(i in 1:n_a) {
      if(a[i] != b[i])
        return 0;
    }
    return 1;
  }

@wds15
Contributor

wds15 commented Dec 14, 2022

Really nice that you can now recycle the cholesky also for the heterogeneous case. That's great (the most common case we will use). I left a few more comments and I'd recommend to go back to the vector for lp and then sum over it whenever we cannot vectorise things. That should be better in terms of the AD tape usage.

I realised that these improvements are also beneficial to the other correlation models (cosy, AR)... nice.

@wds15
Contributor

wds15 commented Dec 14, 2022

Getting rid of a column and a row from a given Cholesky factor can be done efficiently:

https://math.stackexchange.com/questions/1896467/cholesky-decomposition-when-deleting-one-row-and-one-and-column#1896839

Not sure if we want to do this... but leaving it here for reference at least.

@wds15
Contributor

wds15 commented Dec 14, 2022

Here is a better and more detailed derivation of how to remove a row&column from a Cholesky: https://normalsplines.blogspot.com/2019/02/algorithms-for-updating-cholesky.html

It's worthwhile to do, since we go from n^3 to n^2 effort (n being the dimension).
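
A minimal Python sketch of the idea, assuming the textbook rank-one update formulation (this is neither Stan math nor Eigen code, and `chol_delete` is a made-up name): the rows and columns before the deleted index stay as they are, and the deleted column's sub-diagonal part is folded back into the trailing block with a rank-one update, which costs O(n^2).

```python
import math


def cholesky(a):
    """Lower-triangular Cholesky factor, used here only to check the result."""
    n = len(a)
    l = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = a[i][j] - sum(l[i][k] * l[j][k] for k in range(j))
            l[i][j] = math.sqrt(s) if i == j else s / l[j][j]
    return l


def chol_delete(l, k):
    """Factor of A with row/column k (0-based) removed, given the lower factor l of A."""
    n = len(l)
    keep = [i for i in range(n) if i != k]
    # Dropping row k and column k of l leaves the blocks L11, L31 and L33 in place.
    out = [[l[i][j] for j in keep] for i in keep]
    # The deleted column below the diagonal must be added back to the trailing
    # block: new L33 * new L33' = L33 * L33' + x * x'.
    x = [l[i][k] for i in range(k + 1, n)]
    for p in range(len(x)):
        d = k + p  # position of this diagonal entry in out
        r = math.hypot(out[d][d], x[p])
        c, s = r / out[d][d], x[p] / out[d][d]
        out[d][d] = r
        for q in range(p + 1, len(x)):
            out[k + q][d] = (out[k + q][d] + s * x[q]) / c
            x[q] = c * x[q] - s * out[k + q][d]
    return out


cor = [[1.0, 0.5, 0.3], [0.5, 1.0, 0.1], [0.3, 0.1, 1.0]]
L = cholesky(cor)
L_drop_first = chol_delete(L, 0)  # factor of Cor[c(2, 3), c(2, 3)]
print(L_drop_first)
```

Deleting the last index needs no update at all (the leading block is already a valid factor), while deleting the first index is the most expensive case.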

EDIT: let's file a Stan-math issue for this... such a function would be cool to have in possibly many other problems.

Here is the math issue for that:
stan-dev/math#2855

@paul-buerkner
Owner Author

Thank you again! Having this cholesky factor subsetting efficiently in Stan math would be lovely! And I am glad if I don't have to maintain it inside of brms :-D

I have now incorporated all the remaining comments from above. The code seems to be really fast now.

Any other things to change before I merge?

@wds15
Contributor

wds15 commented Dec 14, 2022

and I checked: Eigen supports these rank updates to the Cholesky factors out of the box. So it really belongs to Stan-math to do these.

For me - I am good with this implementation now. It's great to have it and it indeed got quite a bit faster, which is great. Can't wait to see this land on our production systems! Björn will run some of his examples once more and leave a note here once these are through. I think these are great stress tests for the code - so let's just wait for his note here once done.

@paul-buerkner
Owner Author

Sounds good!

@bjoernholzhauer

I tried everything I had previously tried again, and it all works nicely now. Looks good to me.

@wds15
Contributor

wds15 commented Dec 14, 2022

Great! Then this can go in from my view.

@paul-buerkner
Owner Author

Great, thank you for all your help! Merging now.
