Skip to content

Commit

Permalink
Gstore
Browse files Browse the repository at this point in the history
  • Loading branch information
gmatteo committed May 23, 2024
1 parent bab0e84 commit 0ad2152
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
23 changes: 13 additions & 10 deletions abipy/eph/gstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,9 @@ class Gqk:
vkmat_cart_ibz: np.ndarray | None

@classmethod
def from_gstore(cls, gstore, spin: int):
def from_gstore(cls, gstore: GstoreFile, spin: int):
"""
Build an istance from a GstoreFile and the spin index.
"""
ncr = gstore.r
path = f"gqk_spin{spin+1}"
Expand Down Expand Up @@ -213,25 +214,27 @@ def to_string(self, verbose=0) -> str:

return "\n".join(lines)

def to_dataframe(self, what="g2"):
def get_dataframe(self, what="g2"):
"""
"""
if what == "g2":
g2 = self.g2 if self.g2 is not None else np.abs(self.gvals) ** 2
shape, ndim = g2.shape, g2.ndim
# Flatten the array , get the indices and combine indices and values into a DataFrame
# Flatten the array, get the indices and combine indices and values into a DataFrame
indices = np.indices(shape).reshape(ndim, -1).T
df = pd.DataFrame(indices, columns=["iq", "ik", "imode", "m_kq", "n_k"])
df["g2"] = g2.flatten()

elif what == "v2":
if self.vcart_ibz is None:
raise ValueError("vcart_ibz is not available!")
if self.vk_cart_ibz is None:
raise ValueError("vk_cart_ibz is not available!")

# Compute the squared norm of each vector
v2 = np.sum(self.vcart_ibz ** 2, axis=2)
v2 = np.sum(self.vk_cart_ibz ** 2, axis=2)
shape, ndim = v2.shape, v2.ndim
df = pd.DataFrame(indices, columns=["ik", "imode", "ib"])
# Flatten the array, get the indices and combine indices and values into a DataFrame
indices = np.indices(shape).reshape(ndim, -1).T
df = pd.DataFrame(indices, columns=["ik", "ib"])
df["v2"] = v2.flatten()

else:
Expand All @@ -241,7 +244,7 @@ def to_dataframe(self, what="g2"):

def get_gdf_at_qpt_kpt(self, qpoint, kpoint) -> pd.DataFrame:
"""
Build and return a Dataframe with the |g(k+q,k)|^2 for the given (qpoint, kpoint) pair.
Build and return a dataframe with the |g(k+q,k)|^2 for the given (qpoint, kpoint) pair.
"""
spin = self.spin

Expand Down Expand Up @@ -350,7 +353,7 @@ def __init__(self, filepath: PathLike):
self.kglob2bz -= 1

def find_iq_glob_qpoint(self, qpoint, spin: int):
"""Find the internal index of qpoint"""
"""Find the internal index of qpoint needed to access the gvals array."""
qpoint = np.array(qpoint)
for iq_g, iq_bz in enumerate(self.qglob2bz[spin]):
if np.allclose(qpoint, self.qbz[iq_bz]):
Expand All @@ -359,7 +362,7 @@ def find_iq_glob_qpoint(self, qpoint, spin: int):
raise ValueError(f"Cannot find {qpoint=} in GSTORE.nc")

def find_ik_glob_kpoint(self, kpoint, spin: int):
"""Find the internal indices of kpoint"""
"""Find the internal indices of kpoint needed in access the gvals array."""
kpoint = np.array(kpoint)
for ik_g, ik_bz in enumerate(self.kglob2bz[spin]):
if np.allclose(kpoint, self.kbz[ik_bz]):
Expand Down
1 change: 1 addition & 0 deletions abipy/ml/aseml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,7 @@ class MyM3GNetCalculator(_MyCalculator, M3GNetCalculator):
else:
#model_name = "M3GNet-MP-2021.2.8-PES" if self.model_name is None else self.model_name
model_name = "M3GNet-MP-2021.2.8-DIRECT-PES" if self.model_name is None else self.model_name
print("Using model_name:", model_name)
self._model = matgl.load_model(model_name)

class MyM3GNetCalculator(_MyCalculator, M3GNetCalculator):
Expand Down

0 comments on commit 0ad2152

Please sign in to comment.