Skip to content

Commit

Permalink
Fix dtype float32/64 mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
julianpollmann committed Dec 9, 2024
1 parent 3657d30 commit 122f4ae
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,9 @@ def get_mean_vector(self, keys, weights=None, pre_normalize=True, post_normalize
if len(keys) == 0:
raise ValueError("cannot compute mean with no input")
if isinstance(weights, list):
weights = np.array(weights)
weights = np.array(weights, dtype=self.vectors.dtype)
if weights is None:
weights = np.ones(len(keys))
weights = np.ones(len(keys), dtype=self.vectors.dtype)
if len(keys) != weights.shape[0]: # weights is a 1-D numpy array
raise ValueError(
"keys and weights array must have same number of elements"
Expand Down
5 changes: 3 additions & 2 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,11 @@ def update_dir_prior(prior, N, logphat, rho):
The updated prior.
"""
dtype = logphat.dtype
gradf = N * (psi(np.sum(prior)) - psi(prior) + logphat)

c = N * polygamma(1, np.sum(prior))
q = -N * polygamma(1, prior)
c = N * polygamma(1, np.sum(prior)).astype(dtype)
q = -N * polygamma(1, prior).astype(dtype)

b = np.sum(gradf / q) / (1 / c + np.sum(1 / q))

Expand Down

0 comments on commit 122f4ae

Please sign in to comment.