Skip to content

Commit

Permalink
Revert back to using numpy
Browse files Browse the repository at this point in the history
Jax reports errors, type mismatch

Numpy issue may be able to be resolve here: tensorflow/models#9706

	modified:   astronet/metrics.py
  • Loading branch information
tallamjr committed Dec 19, 2021
1 parent 1f12d49 commit 00918ad
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions astronet/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pandas as pd
import sys
import tensorflow as tf
import jax.numpy as jnp

from tensorflow import keras

Expand Down Expand Up @@ -86,7 +85,7 @@ def mywloss(y_true, y_pred):
Where:
wtable - is a numpy 1d array with (the number of times class y_true occur in the data set)/(size of data set)
"""
wtable = jnp.sum(y_true, axis=0) / y_true.shape[0]
wtable = np.sum(y_true, axis=0) / y_true.shape[0]

yc = tf.clip_by_value(y_pred, 1e-15, 1 - 1e-15)

Expand Down

0 comments on commit 00918ad

Please sign in to comment.