Skip to content

Commit

Permalink
bug fixes and some improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
cpignedoli committed Apr 4, 2024
1 parent 33fc6b8 commit 02c8643
Showing 1 changed file with 67 additions and 26 deletions.
93 changes: 67 additions & 26 deletions cubehandler/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,28 @@
ANG_TO_BOHR = 1.8897259886


def remove_useless_zeros(s):
# Pattern to identify unnecessary zeros in a number
# This pattern looks for numbers that have a decimal point followed by any number of zeros, optionally ending with more digits
# It captures the number part before the decimal point and the non-zero digits after the decimal point (if any)
pattern = r"(\d+)\.0+([1-9]*)"

# Replacement function
# If there are non-zero digits after the decimal point, keep one trailing zero (for numbers like 1.0, 2.0, etc.)
# Otherwise, remove the decimal part entirely
def replacer(match):
if match.group(2): # If there are digits after the zeros
return match.group(1) + "." + match.group(2) # Keep the non-zero digits
else:
return match.group(1) # Keep only the integer part
def remove_trailing_zeros(number):
# Format the number using fixed-point notation with high precision
number_str = "{:.11f}".format(number)

# Remove trailing zeros and a possible trailing decimal point
number_str = number_str.rstrip("0").rstrip(".")

# Return the cleaned-up number as a string to avoid automatic conversion to scientific notation
return number_str


def remove_useless_zeros(input_string):
# Regular expression to match floating-point numbers
float_pattern = re.compile(r"\b\d+\.\d+\b")

# Use re.sub() to replace the matches of the pattern in the input string with the output of the replacer function
def replace_float(match):
float_str = match.group(0)
cleaned_float_str = remove_trailing_zeros(float(float_str))
return cleaned_float_str

return re.sub(pattern, replacer, s)
# Replace all occurrences of floating-point numbers in the input string
return float_pattern.sub(replace_float, input_string)


class Cube:
Expand All @@ -37,13 +41,18 @@ class Cube:
"""

default_origin = np.array([0.0, 0.0, 0.0])
default_scaling_factor = 1.0
default_comment = "Cubehandler"
default_low_precision_decimals = 3

def __init__(
self,
title=None,
comment=None,
comment=default_comment,
ase_atoms=None,
origin=default_origin,
scaling_factor=default_scaling_factor,
low_precision_decimals=default_low_precision_decimals,
cell=None,
cell_n=None,
data=None,
Expand All @@ -56,9 +65,10 @@ def __init__(
self.comment = comment
self.ase_atoms = ase_atoms
self.origin = origin
self.scaling_factor = scaling_factor
self.cell = cell
self.data = data
self.scaling_factor = 1.0
self.low_precision_decimals = low_precision_decimals
if data is not None:
self.cell_n = data.shape
else:
Expand All @@ -70,6 +80,8 @@ def from_file_handle(cls, filehandle, read_data=True):
c = cls()
c.title = f.readline().rstrip()
c.comment = f.readline().rstrip()
if "Scaling factor:" in c.comment:
c.scaling_factor = float(c.comment.split()[-1])

line = f.readline().split()
natoms = int(line[0])
Expand Down Expand Up @@ -131,6 +143,8 @@ def from_file(cls, filepath, read_data=True):
def write_cube_file(self, filename, low_precision=False):

natoms = len(self.ase_atoms)
if low_precision:
self.rescale_data()

f = open(filename, "w")

Expand All @@ -139,10 +153,15 @@ def write_cube_file(self, filename, low_precision=False):
else:
f.write(self.title + "\n")

if self.comment is None:
f.write(f"Scaling factor: {self.scaling_factor}\n")
if "Scaling factor:" in self.comment:
self.comment = re.sub(
r"Scaling factor: \d+\.\d+",
f"Scaling factor: {self.scaling_factor}",
self.comment,
)
else:
f.write(self.comment + f" Scaling factor: {self.scaling_factor}" + "\n")
self.comment += f" Scaling factor: {self.scaling_factor}"
f.write(self.comment + "\n")

dv_br = self.cell / self.data.shape

Expand Down Expand Up @@ -170,7 +189,8 @@ def write_cube_file(self, filename, low_precision=False):

if low_precision:
string_io = io.StringIO()
np.savetxt(string_io, self.data.flatten(), fmt="%.3f")
format_string = "%.{}f".format(self.low_precision_decimals)
np.savetxt(string_io, self.data.flatten(), fmt=format_string)
result_string = remove_useless_zeros(string_io.getvalue())
f.write(result_string)
else:
Expand All @@ -182,9 +202,7 @@ def reduce_data_density(self, points_per_angstrom=2):
"""Reduces the data density"""
# We should have ~ 1 point per Bohr
slicer = np.round(
self.data.shape
/ np.linalg.norm(self.ase_atoms.cell, axis=1)
/ points_per_angstrom
1.0 / (points_per_angstrom * np.linalg.norm(self.dv, axis=1))
).astype(int)
try:
self.data = self.data[:: slicer[0], :: slicer[1], :: slicer[2]]
Expand All @@ -194,8 +212,9 @@ def reduce_data_density(self, points_per_angstrom=2):
def rescale_data(self):
"""Rescales the data to be between -1 and 1"""
self.scaling_factor = max(abs(self.data.min()), abs(self.data.max()))
print("check", self.scaling_factor, abs(self.data.min()), abs(self.data.max()))
self.data /= self.scaling_factor
self.data = np.round(self.data, decimals=3)
self.data = np.round(self.data, decimals=self.low_precision_decimals)

# Convert -0 to 0
self.data[self.data == 0] = 0
Expand Down Expand Up @@ -270,6 +289,28 @@ def get_z_index(self, z_ang):
)
)

@property
def scaling_f(self):
scaling_f = self.scaling_factor
if "Scaling_factor" in self.comment:
scaling_f = float(self.comment.split()[-1])
return scaling_f

@property
def dV(self):
"""in [ang]"""
return self.ase_atoms.get_volume() / self.data.size

@property
def dV_ang(self):
"""in [ang]"""
return self.ase_atoms.get_volume() / self.data.size

@property
def dV_au(self):
"""in [au]"""
return ANG_TO_BOHR**3 * self.ase_atoms.get_volume() / self.data.size

@property
def dv(self):
"""in [ang]"""
Expand Down

0 comments on commit 02c8643

Please sign in to comment.