Skip to content

Commit

Permalink
Adds timing and r2
Browse files Browse the repository at this point in the history
PTNobel committed Dec 11, 2024
1 parent 0597de2 commit 0abd6cf
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/adelie_example.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@
import randalo.adelie_integration as ai
import torch

ld, alo = ai.get_alo_for_sweep(y, state, torch.nn.MSELoss())
ld, alo, ts, r2 = ai.get_alo_for_sweep(y, state, torch.nn.MSELoss())
dg = ad.diagnostic.diagnostic(state)
dg.plot_devs()
plt.plot(-np.log(ld), alo)
10 changes: 6 additions & 4 deletions randalo/adelie_integration.py
Original file line number Diff line number Diff line change
@@ -76,8 +76,7 @@ def adelie_state_to_jacobian(y, state, adelie_state):

return loss, J

def adelie_state_to_randalo(y, state, adelie_state, loss, J, index, rng=None):
y_hat = (state.X @ state.betas[index].T).squeeze()
def adelie_state_to_randalo(y, y_hat, state, adelie_state, loss, J, index, rng=None):
adelie_state.set_index(index)
randalo = ra.RandALO(
loss,
@@ -92,15 +91,18 @@ def get_alo_for_sweep(y, state, risk_fun):
L, _ = state.betas.shape
adelie_state = AdelieState(state)
loss, J = adelie_state_to_jacobian(y, state, adelie_state)
y_hat = ad.diagnostic.predict(state.X, state.betas, state.intercepts)

output = np.empty(L)
times = np.empty(L)
r2 = np.empty(L)

for i in range(L):
t0 = time.monotonic()
randalo = adelie_state_to_randalo(y, state, adelie_state, loss, J, i)
randalo = adelie_state_to_randalo(y, y_hat[i], state, adelie_state, loss, J, i)
output[i] = randalo.evaluate(risk_fun)
times[i] = time.monotonic() - t0
r2[i] = 1 - np.square(y - y_hat[i]).sum() / np.square(y - np.mean(y)).sum()

return state.lmda_path[:L], output, times
return state.lmda_path[:L], output, times, r2

4 changes: 2 additions & 2 deletions utils/sherlock_script.py
Original file line number Diff line number Diff line change
@@ -67,6 +67,6 @@
oos[i] = loss(torch.from_numpy(y_hat_test[:, i]), torch.from_numpy(y_test))
ins[i] = loss(torch.from_numpy(y_hat_train[:, i]), torch.from_numpy(y_train))

ld, alo, ts = ai.get_alo_for_sweep(y_train, state, loss)
ld, alo, ts, r2 = ai.get_alo_for_sweep(y_train, state, loss)

np.savez(sys.argv[-1], lamda=ld, alo=alo, oos=oos, in_sample=ins, ts=ts)
np.savez(sys.argv[-1], lamda=ld, alo=alo, oos=oos, in_sample=ins, ts=ts, r2=r2)

0 comments on commit 0abd6cf

Please sign in to comment.