Skip to content

Commit

Permalink
Add site_inds optional arg to plot_forces for specific sites only
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed Apr 12, 2024
1 parent 69aa1b1 commit 84e1273
Showing 1 changed file with 37 additions and 16 deletions.
53 changes: 37 additions & 16 deletions abipy/ml/aseml.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,27 +636,36 @@ def xy_energies_for_keys(self, key1: str, key2: str, sort=True) -> tuple:

return zip_sort(xs, ys) if sort else (xs, ys)

def xy_forces_for_keys(self, key1, key2, direction, symbol=None) -> tuple:
def xy_forces_for_keys(self, key1, key2, direction, symbol=None, site_inds=None) -> tuple:
"""
Return (xs, ys), sorted arrays with forces along the cart direction for (key1, key2).
Args:
symbol: If not None, select only forces for this atomic species:
symbol: If not None, select only forces for this atomic specie.
site_inds: List of site indices to consider. None if all sites should be included.
"""
idir = self.idir_from_direction(direction)
ik1, ik2 = self.inds_of_keys(key1, key2)

if symbol is None:
xs = self.forces_list[ik1,:,:,idir].flatten()
ys = self.forces_list[ik2,:,:,idir].flatten()
else:
if symbol is not None and site_inds is not None:
raise ValueError("symbol and site_inds are mutually exclusive!")

if symbol is not None:
inds = np.array(self.structure.indices_from_symbol(symbol))
if len(inds) == 0:
raise ValueError(f"Cannot find chemical symbol {symbol} in structure!")
#print("selecting symbol", symbol, "with inds:", inds)
raise ValueError(f"Cannot find chemical {symbol=} in structure!")
xs = self.forces_list[ik1,:,inds,idir].flatten()
ys = self.forces_list[ik2,:,inds,idir].flatten()

elif site_inds is not None:
site_inds = np.array(site_inds)
xs = self.forces_list[ik1,:,site_inds,idir].flatten()
ys = self.forces_list[ik2,:,site_inds,idir].flatten()

else:
xs = self.forces_list[ik1,:,:,idir].flatten()
ys = self.forces_list[ik2,:,:,idir].flatten()

return zip_sort(xs, ys)

def traj_forces_for_keys(self, key1, key2) -> tuple:
Expand Down Expand Up @@ -759,12 +768,13 @@ def plot_energies(self, fontsize=8, **kwargs):
return fig

@add_fig_kwargs
def plot_forces(self, symbol=None, fontsize=8, **kwargs):
def plot_forces(self, symbol=None, site_inds=None, fontsize=8, **kwargs):
"""
Parity plot for forces.
Args:
symbol: If not None, select only forces for this atomic species
symbol: If not None, select only forces for this atomic specie.
site_inds: List of site indices to consider. None if all sites should be included.
"""
key_pairs = self.get_key_pairs()
nrows, ncols = 3, len(key_pairs)
Expand All @@ -776,12 +786,14 @@ def plot_forces(self, symbol=None, fontsize=8, **kwargs):

for icol, (key1, key2) in enumerate(key_pairs):
for irow, direction in enumerate(("x", "y", "z")):
xs, ys = self.xy_forces_for_keys(key1, key2, direction, symbol=symbol)
stats = diff_stats(xs, ys)
ax = ax_mat[irow, icol]
ax.scatter(xs, ys, marker="o")
ax.grid(True)

xs, ys = self.xy_forces_for_keys(key1, key2, direction, symbol=symbol, site_inds=site_inds)
stats = diff_stats(xs, ys)
ax.scatter(xs, ys, marker="o")
linear_fit_ax(ax, xs, ys, fontsize=fontsize, with_label=True)

ax.legend(loc="best", shadow=True, fontsize=fontsize)
f_tex = f"$F_{direction}$"
if icol == 0:
Expand Down Expand Up @@ -865,12 +877,13 @@ def plot_energies_traj(self, delta_mode=True, fontsize=6, markersize=2, **kwargs
return fig

@add_fig_kwargs
def plot_forces_traj(self, delta_mode=True, fontsize=6, markersize=2, **kwargs):
def plot_forces_traj(self, delta_mode=True, symbol=None, fontsize=6, markersize=2, **kwargs):
"""
Plot forces along the trajectory.
Args:
delta_mode: True to plot differences instead of absolute values.
symbol: If not None, select only forces for this atomic species
"""
# Fx,Fy,Fx along rows, pairs along columns.
key_pairs = self.get_key_pairs()
Expand All @@ -884,19 +897,27 @@ def plot_forces_traj(self, delta_mode=True, fontsize=6, markersize=2, **kwargs):
atom2_cmap = plt.get_cmap("jet")
marker_idir = {0: ">", 1: "<", 2: "^"}

if symbol is None:
inds = np.array(self.structure.indices_from_symbol(symbol))
if len(inds) == 0:
raise ValueError(f"Cannot find chemical {symbol=} in structure!")

for icol, (key1, key2) in enumerate(key_pairs):
# Arrays of shape: [nsteps, natom, 3]
f1_tad, f2_tad = self.traj_forces_for_keys(key1, key2)
for idir, direction in enumerate(("x", "y", "z")):
last_row = idir == 2
fp_tex = f"F_{direction}"
xs, ys = self.xy_forces_for_keys(key1, key2, direction)
xs, ys = self.xy_forces_for_keys(key1, key2, direction, symbol=symbol)
stats = diff_stats(xs, ys)
ax = ax_mat[idir, icol]
ax.set_title(f"{key1}/{key2} MAE: {stats.MAE:.6f}", fontsize=fontsize)
symb = "" if symbol is None else f"{symbol=}"
ax.set_title(f"{key1}/{key2} MAE: {stats.MAE:.6f} {symb}", fontsize=fontsize)

zero_values = False
for iatom in range(self.natom):
# Select atoms by symbol
if symbol is not None and iatom not in inds: continue
if delta_mode:
# Plot delta of forces along the trajectory.
style = dict(marker=marker_idir[idir], markersize=markersize,
Expand Down

0 comments on commit 84e1273

Please sign in to comment.