Skip to content

Commit

Permalink
fix aggregated posterior computation as described in #3188
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong committed Feb 19, 2025
1 parent 126e0d3 commit 09947b6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/scvi/external/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,11 +789,11 @@ def get_aggregated_posterior(
qu_locs.append(outputs["qu"].loc)
qu_scales.append(outputs["qu"].scale)

qu_loc = jnp.concatenate(qu_locs, axis=0).T
qu_scale = jnp.concatenate(qu_scales, axis=0).T
qu_loc = jnp.concatenate(qu_locs, axis=0)
qu_scale = jnp.concatenate(qu_scales, axis=0)
return MixtureSameFamily(
Categorical(probs=jnp.ones(qu_loc.shape[1]) / qu_loc.shape[1]),
Normal(qu_loc, qu_scale),
Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]),
Normal(qu_loc, qu_scale).to_event(1),
)

def differential_abundance(
Expand Down

0 comments on commit 09947b6

Please sign in to comment.