Skip to content

Commit

Permalink
reverse unnecessary changes
Browse files Browse the repository at this point in the history
  • Loading branch information
RylieWeaver committed Dec 13, 2024
1 parent 8002f68 commit 112df12
Showing 1 changed file with 4 additions and 29 deletions.
33 changes: 4 additions & 29 deletions examples/LennardJones/LJ_inference_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,25 +102,6 @@ def getcolordensity(xdata, ydata):
return hist2d_norm


from sklearn.metrics import r2_score


def get_r2(energy_true_list, energy_pred_list, forces_true_list, forces_pred_list):
# Convert inputs to numpy arrays
energy_true_list = np.array(energy_true_list)
energy_pred_list = np.array(energy_pred_list)
forces_true_list = np.array(forces_true_list)
forces_pred_list = np.array(forces_pred_list)

# Compute R^2 for energy
energy_r2 = r2_score(energy_true_list, energy_pred_list)

# Compute R^2 for forces (flatten both arrays for 1D comparison)
forces_r2 = r2_score(forces_true_list.flatten(), forces_pred_list.flatten())

return energy_r2, forces_r2


if __name__ == "__main__":

modelname = "LJ"
Expand Down Expand Up @@ -202,7 +183,7 @@ def get_r2(energy_true_list, energy_pred_list, forces_true_list, forces_pred_lis
variable_index = 0
# for output_name, output_type, output_dim in zip(config["NeuralNetwork"]["Variables_of_interest"]["output_names"], config["NeuralNetwork"]["Variables_of_interest"]["type"], config["NeuralNetwork"]["Variables_of_interest"]["output_dim"]):

# test_MAE = 0.0
test_MAE = 0.0

num_samples = len(testset)
energy_true_list = []
Expand All @@ -216,7 +197,7 @@ def get_r2(energy_true_list, energy_pred_list, forces_true_list, forces_pred_lis
0
] # Note that this is sensitive to energy and forces prediction being single-task (current requirement)
energy_pred = torch.sum(node_energy_pred, dim=0).float()
# test_MAE += torch.norm(energy_pred - data.energy, p=1).item() / len(testset)
test_MAE += torch.norm(energy_pred - data.energy, p=1).item() / len(testset)
# predicted.backward(retain_graph=True)
# gradients = data.pos.grad
grads_energy = torch.autograd.grad(
Expand All @@ -231,12 +212,6 @@ def get_r2(energy_true_list, energy_pred_list, forces_true_list, forces_pred_lis
forces_pred_list.extend((-grads_energy).flatten().tolist())
forces_true_list.extend(data.forces.flatten().tolist())

test_energy_r2, test_force_r2 = get_r2(
energy_true_list, energy_pred_list, forces_true_list, forces_pred_list
)
print(f"Test R2 energy: ", test_energy_r2)
print(f"Test R2 forces: ", test_force_r2)

hist2d_norm = getcolordensity(energy_true_list, energy_pred_list)

fig, ax = plt.subplots()
Expand All @@ -246,12 +221,12 @@ def get_r2(energy_true_list, energy_pred_list, forces_true_list, forces_pred_lis
plt.colorbar()
plt.xlabel("True values")
plt.ylabel("Predicted values")
plt.title(f"Energy")
plt.title(f"energy")
plt.draw()
plt.tight_layout()
plt.savefig(f"./energy_Scatterplot" + ".png", dpi=400)

# print(f"Test MAE energy: ", test_MAE)
print(f"Test MAE energy: ", test_MAE)

hist2d_norm = getcolordensity(forces_pred_list, forces_true_list)
fig, ax = plt.subplots()
Expand Down

0 comments on commit 112df12

Please sign in to comment.