Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi-epsilon merge from main #442

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 153 additions & 20 deletions multiego.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import sys
import os
import numpy as np

from src.multiego import ensemble
from src.multiego import io
Expand Down Expand Up @@ -132,11 +133,19 @@ def meGO_parsing():
type=str,
help="Custom dictionary for special molecules",
)
optional_args.add_argument(
"--custom_c12",
type=str,
help="Custom dictionary of c12 for special molecules",
)
optional_args.add_argument(
"--no_header",
action="store_true",
help="Removes headers from the output files when set",
)
parser.add_argument("--multi_epsi_intra", type=str, help="Path to the input file specifying the intra epsilons")
parser.add_argument("--multi_epsi_inter_domain", type=str, help="Path to the input file specifying the intra epsilons")
parser.add_argument("--multi_epsi_inter", type=str, help="Path to the input file specifying the inter epsilons")
optional_args.add_argument(
"--symmetry",
default="",
Expand Down Expand Up @@ -180,8 +189,10 @@ def meGO_parsing():
print("--egos=production requires the list of folders containing the training simulations using the --train flag")
sys.exit()

if args.epsilon is None and args.egos != "rc":
print("--epsilon is required when using --egos=production. The typical range is between 0.2 and 0.4 kJ/mol")
if (args.epsilon is None and args.multi_epsi_intra is None) and args.egos != "rc":
print(
"--epsilon or --multi_epsi_intra is required when using --egos=production. The typical range is between 0.2 and 0.4 kJ/mol"
)
sys.exit()

if args.p_to_learn < 0.9:
Expand All @@ -191,29 +202,151 @@ def meGO_parsing():
print("--epsilon_min (" + str(args.epsilon_min) + ") must be greater than 0.")
sys.exit()

if args.egos != "rc" and args.epsilon <= args.epsilon_min:
print("--epsilon (" + str(args.epsilon) + ") must be greater than --epsilon_min (" + str(args.epsilon_min) + ")")
# MULTI EPSILON CASES
# TODO add the option to write in the multi epsi inter file Nones in order to remove interaction between systems
# CHECK if either the single option or the multi option are provided. If both break
if args.epsilon is not None and args.multi_epsi_intra is not None:
print("""Choose either a single intra epsilon for the system or the multi-epsilon inter. Cannot choose both""")
sys.exit()

if args.egos != "rc" and args.inter_domain_epsilon <= args.epsilon_min:
if args.inter_domain_epsilon is not None and args.multi_epsi_inter_domain is not None:
print(
"--inter_domain_epsilon ("
+ str(args.inter_domain_epsilon)
+ ") must be greater than --epsilon_min ("
+ str(args.epsilon_min)
+ ")"
"Choose either a single inter domain epsilon for the system or the multi-epsilon inter domain. Cannot choose both"
)
sys.exit()
if args.inter_epsilon is not None and args.multi_epsi_inter is not None:
print("Choose either a single inter epsilon for the system or the multi-epsilon inter. Cannot choose both")
sys.exit()

if args.egos != "rc" and args.inter_epsilon <= args.epsilon_min:
# CHECK if multi_epsi_inter_domain or multi_epsi_intra are parsed but not the multi_epsi_intra break
if args.multi_epsi_intra is None and (args.multi_epsi_inter_domain is not None or args.multi_epsi_inter is not None):
print(
"--inter_epsilon ("
+ str(args.inter_epsilon)
+ ") must be greater than --epsilon_min ("
+ str(args.epsilon_min)
+ ")"
"""--multi_epsi_inter_domain or --multi_epsi_inter where used, but --multi_epsi_intra was not parsed.
In order to use the multi-epsilon option --multi_epsi_intra must be parsed. Please provide one or use the single epsslon case with:
--epsilon
--inter_domain_epsilon
--inter_epsilon"""
)
sys.exit()
exit()

# if multi-epsi intra is parsed start overwrite other parameters
if args.multi_epsi_intra is not None:
setattr(args, "multi_mode", True)
args.names, args.multi_epsilon = io.read_intra_file(args.multi_epsi_intra)

# INTER-DOMAIN
# multi_epsi inter domain
if args.multi_epsi_inter_domain is not None:
args.names_inter_domain, args.multi_epsilon_inter_domain = io.read_intra_file(args.multi_epsi_inter_domain)

# multi_inter domain None but inter domain parsed --> ERROR
if args.multi_epsi_inter_domain is None and args.inter_domain_epsilon is not None:
print(
"""Inter domain option should be parsed with --multi_epsi_inter_domain if --multi_epsi_intra is used and not with --inter_domain_epsilon
Choose either multiple epsilon options:
--multi_epsi_intra PATH_TO_FILE
--multi_epsi_inter_domain PATH_TO_FILE
Or the single epsilon options:
--epsilon VALUE
--inter_domain_epsilon VALUE
"""
)
exit()

# CASE: multi intra but no multi inter domain --> set multi_inter_domain as multi_intra
if args.multi_epsi_inter_domain is None and args.inter_domain_epsilon is None and args.multi_epsi_intra is not None:
setattr(args, "names_inter_domain", args.names)
setattr(args, "multi_epsilon_inter_domain", args.multi_epsilon)

# INTER
if args.multi_epsi_inter:
args.names_inter, args.multi_epsilon_inter = io.read_inter_file(args.multi_epsi_inter)

# No multi_epsilon_inter, no inter_epsilon --> set multi_epsilon_inter as one of the multi_epsi_intra (should not be needed if it's not defined explicetily)
if args.multi_epsi_inter is None and args.inter_epsilon is None and args.multi_epsi_intra is not None:
print(
"""--multi intra mode activated, but no information for inter epsilon was set.
Please set also the inter molecular interaction using one of the following options:
-inter_epsilon VALUE
-multi_epsi_inter PATH_TO_FILE """
)
exit()

# No multi_epsilon_inter, inter_epsilon --> set multi_epsilon_inter as inter_epsilon
if args.multi_epsi_inter is None and args.inter_epsilon is not None:
setattr(args, "names_inter", np.array(args.names))
setattr(
args, "multi_epsilon_inter", np.zeros((len(args.multi_epsilon), len(args.multi_epsilon))) + args.inter_epsilon
)

# Multi-case checks:
if args.multi_epsi_inter is not None and args.multi_epsi_intra is not None:
if np.any(np.array(args.names) != np.array(args.names_inter)):
print(
f"""ERROR: the names of the molecules in the files {args.multi_epsi_intra} and {args.multi_epsi_inter} are different.
The names of the molecules must be consistent with each other and with those in the topology"""
)
exit()

# if multi_inter and no multi intra break
if args.multi_epsi_inter is not None and args.multi_epsi_intra is None:
print(
"""if multi_epsi_inter is used, also multi_epsi must be used. define also the set of epsilons via --multi_epsi_intra """
)

# if multi_inter_domain and no multi intra break
if args.multi_epsi_inter_domain is not None and args.multi_epsi_intra is None:
print(
"""--if multi_epsi_inter_domain is used, also multi_epsi must be used. define also the set of epsilons via --multi_epsi_intra """
)
else:
setattr(args, "multi_mode", False)

# CHECK all epsilons are greater than epsilon_min
if args.epsilon is not None:
if args.egos != "rc" and args.epsilon <= args.epsilon_min:
print("--epsilon (" + str(args.epsilon) + ") must be greater than --epsilon_min (" + str(args.epsilon_min) + ")")
sys.exit()

if args.egos != "rc" and args.inter_domain_epsilon <= args.epsilon_min:
print(
"--inter_domain_epsilon ("
+ str(args.inter_domain_epsilon)
+ ") must be greater than --epsilon_min ("
+ str(args.epsilon_min)
+ ")"
)
sys.exit()

if args.egos != "rc" and args.inter_epsilon <= args.epsilon_min:
print(
"--inter_epsilon ("
+ str(args.inter_epsilon)
+ ") must be greater than --epsilon_min ("
+ str(args.epsilon_min)
+ ")"
)
sys.exit()

elif args.multi_mode is not None:
if args.egos != "rc" and np.min(args.multi_epsilon) <= args.epsilon_min:
print(
f"all epsilons in {args.multi_epsi_intra} must be greater than --epsilon_min (" + str(args.epsilon_min) + ")"
)
sys.exit()

if args.egos != "rc" and np.min(args.multi_epsilon_inter_domain) <= args.epsilon_min:
print(
f"all epsilons in {args.multi_epsi_inter_domain} must be greater than --epsilon_min ("
+ str(args.epsilon_min)
+ ")"
)
sys.exit()

if args.egos != "rc" and np.min(args.multi_epsilon_inter) <= args.epsilon_min:
print(
f"all epsilons in {args.multi_epsi_inter} must be greater than --epsilon_min (" + str(args.epsilon_min) + ")"
)
sys.exit()

if args.custom_dict:
custom_dict = parse_json(args.custom_dict)
Expand Down Expand Up @@ -273,11 +406,11 @@ def get_meGO_LJ(meGO_ensemble, args):
"""
pairs14, exclusion_bonds14 = ensemble.generate_14_data(meGO_ensemble)
if args.egos == "rc":
meGO_LJ = ensemble.generate_basic_LJ(meGO_ensemble)
meGO_LJ = ensemble.generate_basic_LJ(meGO_ensemble, args)
meGO_LJ_14 = pairs14
meGO_LJ_14["epsilon"] = -meGO_LJ_14["c12"]
else:
train_dataset, check_dataset = ensemble.init_LJ_datasets(meGO_ensemble, pairs14, exclusion_bonds14)
train_dataset, check_dataset = ensemble.init_LJ_datasets(meGO_ensemble, pairs14, exclusion_bonds14, args)
meGO_LJ, meGO_LJ_14 = ensemble.generate_LJ(meGO_ensemble, train_dataset, check_dataset, args)

meGO_LJ_14 = ensemble.make_pairs_exclusion_topology(meGO_ensemble, meGO_LJ_14)
Expand Down
98 changes: 90 additions & 8 deletions src/multiego/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,67 @@ def initialize_molecular_contacts(contact_matrix, path, ensemble_molecules_idx_s
p_sort_normalized = np.cumsum(p_sort_id) / norm_id
md_threshold_id = p_sort_id[np.min(np.where(p_sort_normalized > args.p_to_learn)[0])]

# needed to obtain the correct epsilon
contact_matrix["molecule_idx_ai_temp"] = contact_matrix["molecule_name_ai"].str.split("_").str[0]
contact_matrix["molecule_idx_aj_temp"] = contact_matrix["molecule_name_aj"].str.split("_").str[0]

# set the epsilon_0 this simplify a lot of the following code
# for intra-domain
contact_matrix.loc[(contact_matrix["same_chain"]) & (contact_matrix["intra_domain"]), "epsilon_0"] = args.epsilon
contact_matrix.loc[(contact_matrix["same_chain"]) & (contact_matrix["intra_domain"]), "zf"] = args.f
# for inter-domain
contact_matrix.loc[(contact_matrix["same_chain"]) & (~contact_matrix["intra_domain"]), "epsilon_0"] = (
args.inter_domain_epsilon
)
contact_matrix.loc[(contact_matrix["same_chain"]) & (~contact_matrix["intra_domain"]), "zf"] = args.inter_domain_f
# for inter-molecular
contact_matrix.loc[(~contact_matrix["same_chain"]), "epsilon_0"] = args.inter_epsilon
contact_matrix.loc[(~contact_matrix["same_chain"]), "zf"] = args.inter_f
if args.multi_mode:
# for intra-domain
if args.multi_epsilon is not None:
temp_epsi_intra = args.multi_epsilon[contact_matrix["molecule_idx_ai_temp"].to_numpy(dtype=int)[0] - 1]
contact_matrix.loc[
(contact_matrix["same_chain"]) & (contact_matrix["intra_domain"]), "epsilon_0"
] = temp_epsi_intra
if name[0] == "intramat":
print(f" -Intra-domain epsilon {temp_epsi_intra}")
else:
print("intra multi modality violated: this should never happend")
# for inter-domain
if args.multi_epsilon_inter_domain is not None:
temp_epsi_inter_dom = args.multi_epsilon_inter_domain[
contact_matrix["molecule_idx_ai_temp"].to_numpy(dtype=int)[0] - 1
]
contact_matrix.loc[
(contact_matrix["same_chain"]) & (~contact_matrix["intra_domain"]), "epsilon_0"
] = temp_epsi_inter_dom
if name[0] == "intramat":
print(f" -Inter-domain epsilon {temp_epsi_inter_dom}")
else:
print("inter domain multi modality violated: this should never happend")
# for inter-molecular
if args.multi_epsilon_inter is not None:
temp_epsi_inter = args.multi_epsilon_inter[
contact_matrix["molecule_idx_ai_temp"].to_numpy(dtype=int)[0] - 1,
contact_matrix["molecule_idx_aj_temp"].to_numpy(dtype=int)[0] - 1,
]
contact_matrix.loc[(~contact_matrix["same_chain"]), "epsilon_0"] = temp_epsi_inter
if name[0] == "intermat":
print(f" -Inter-molecular epsilon {temp_epsi_inter}")
else:
print("inter multi modality violated: this should never happend")
else:
# for intra-domain
contact_matrix.loc[(contact_matrix["same_chain"]) & (contact_matrix["intra_domain"]), "epsilon_0"] = args.epsilon
if name[0] == "intramat":
print(f" -Intra-domain epsilon {args.epsilon}")
# for inter-domain
contact_matrix.loc[
(contact_matrix["same_chain"]) & (~contact_matrix["intra_domain"]), "epsilon_0"
] = args.inter_domain_epsilon
if name[0] == "intramat":
print(f" -Inter-domain epsilon {args.inter_domain_epsilon}")
# for inter-molecular
contact_matrix.loc[(~contact_matrix["same_chain"]), "epsilon_0"] = args.inter_epsilon
if name[0] == "intermat":
print(f" -Inter-molecular epsilon {args.inter_epsilon}")

# add the columns for rc, md threshold
contact_matrix = contact_matrix.assign(md_threshold=md_threshold)
contact_matrix.loc[~(contact_matrix["intra_domain"]), "md_threshold"] = md_threshold_id
Expand Down Expand Up @@ -381,6 +430,23 @@ def init_meGO_ensemble(args):
molecule_type_dict,
) = initialize_topology(reference_topology, custom_dict)

if args.multi_mode:
mol_check = []
for mol in reference_topology.molecules:
mol_check.append(mol)
if len(mol_check) != len(args.names):
print("Error the number of molecules in the input file is different from that in the topology")
exit()
for i, mol_appo in enumerate(mol_check):
if mol_appo != args.names[i]:
print(
f"""ERROR: the name of the molecule from topology is different from that of the input file.
The names must be chosen in the same way to avoid further errors in the association to the specific epsilon.
File: {args.names[i]} ---- mego_topology: {mol_appo}
"""
)
exit()

reference_contact_matrices = {}
io.check_matrix_format(args)
if args.egos != "rc":
Expand Down Expand Up @@ -714,7 +780,7 @@ def generate_14_data(meGO_ensemble):
return pairs14, exclusion_bonds14


def init_LJ_datasets(meGO_ensemble, pairs14, exclusion_bonds14):
def init_LJ_datasets(meGO_ensemble, pairs14, exclusion_bonds14, args):
"""
Initializes LJ (Lennard-Jones) datasets for train and check matrices within a molecular ensemble.

Expand Down Expand Up @@ -788,6 +854,12 @@ def init_LJ_datasets(meGO_ensemble, pairs14, exclusion_bonds14):
train_dataset["type_ai"] = train_dataset["ai"].map(meGO_ensemble["sbtype_type_dict"])
train_dataset["type_aj"] = train_dataset["aj"].map(meGO_ensemble["sbtype_type_dict"])
type_to_c12 = {key: val for key, val in zip(type_definitions.gromos_atp.name, type_definitions.gromos_atp.c12)}

if args.custom_c12 is not None:
custom_c12_dict = io.read_custom_c12_parameters(args.custom_c12)
type_to_c12_appo = {key: val for key, val in zip(custom_c12_dict.name, custom_c12_dict.c12)}
type_to_c12.update(type_to_c12_appo)

oxygen_mask = masking.create_linearized_mask(
train_dataset["type_ai"].to_numpy(),
train_dataset["type_aj"].to_numpy(),
Expand Down Expand Up @@ -859,6 +931,11 @@ def init_LJ_datasets(meGO_ensemble, pairs14, exclusion_bonds14):
check_dataset["type_ai"] = check_dataset["ai"].map(meGO_ensemble["sbtype_type_dict"])
check_dataset["type_aj"] = check_dataset["aj"].map(meGO_ensemble["sbtype_type_dict"])
type_to_c12 = {key: val for key, val in zip(type_definitions.gromos_atp.name, type_definitions.gromos_atp.c12)}
if args.custom_c12 is not None:
custom_c12_dict = io.read_custom_c12_parameters(args.custom_c12)
type_to_c12_appo = {key: val for key, val in zip(custom_c12_dict.name, custom_c12_dict.c12)}
type_to_c12.update(type_to_c12_appo)

oxygen_mask = masking.create_linearized_mask(
check_dataset["type_ai"].to_numpy(),
check_dataset["type_aj"].to_numpy(),
Expand Down Expand Up @@ -888,7 +965,7 @@ def init_LJ_datasets(meGO_ensemble, pairs14, exclusion_bonds14):
return train_dataset, check_dataset


def generate_basic_LJ(meGO_ensemble):
def generate_basic_LJ(meGO_ensemble, args):
"""
Generates basic LJ (Lennard-Jones) interactions DataFrame within a molecular ensemble.

Expand Down Expand Up @@ -934,6 +1011,11 @@ def generate_basic_LJ(meGO_ensemble):

topol_df = meGO_ensemble["topology_dataframe"]
name_to_c12 = {key: val for key, val in zip(type_definitions.gromos_atp.name, type_definitions.gromos_atp.c12)}
if args.custom_c12 is not None:
custom_c12_dict = io.read_custom_c12_parameters(args.custom_c12)
name_to_c12_appo = {key: val for key, val in zip(custom_c12_dict.name, custom_c12_dict.c12)}
name_to_c12.update(name_to_c12_appo)

if meGO_ensemble["reference_matrices"] == {}:
basic_LJ = pd.DataFrame(columns=columns)
basic_LJ["index_ai"] = [
Expand Down Expand Up @@ -1712,7 +1794,7 @@ def generate_LJ(meGO_ensemble, train_dataset, check_dataset, parameters):

# Now is time to add masked default interactions for pairs
# that have not been learned in any other way
basic_LJ = generate_basic_LJ(meGO_ensemble)
basic_LJ = generate_basic_LJ(meGO_ensemble, parameters)
basic_LJ = basic_LJ[needed_fields]
meGO_LJ = pd.concat([meGO_LJ, basic_LJ])

Expand Down
Loading
Loading