Commit

Implemented cross symmetries
frantropy committed Mar 5, 2024
1 parent 5297979 commit c0943ef
Showing 1 changed file with 31 additions and 76 deletions.
107 changes: 31 additions & 76 deletions src/multiego/ensemble.py
@@ -1101,47 +1101,26 @@ def generate_LJ(meGO_ensemble, train_dataset, check_dataset, parameters):
 
     for sym in symmetries:
         for atypes in itertools.combinations(sym[1:], 2):
-            ai_rows = meGO_LJ[meGO_LJ["ai"].str.startswith(f"{atypes[0]}_") & (mglj_resn_ai == sym[0])].copy()
-            aj_rows = meGO_LJ[meGO_LJ["aj"].str.startswith(f"{atypes[0]}_") & (mglj_resn_aj == sym[0])].copy()
-            ai_rows.loc[:, "ai"] = (
-                atypes[1] + "_" + ai_rows["ai"].str.split("_").str[1] + "_" + ai_rows["ai"].str.split("_").str[2]
+            stmp_df_ai = meGO_LJ[meGO_LJ["ai"].str.startswith(f"{atypes[0]}_") & (mglj_resn_ai == sym[0])].copy()
+            stmp_df_aj = meGO_LJ[meGO_LJ["aj"].str.startswith(f"{atypes[0]}_") & (mglj_resn_aj == sym[0])].copy()
+            stmp_df_ai.loc[:, "ai"] = (
+                atypes[1] + "_" + stmp_df_ai["ai"].str.split("_").str[1] + "_" + stmp_df_ai["ai"].str.split("_").str[2]
             )
-            aj_rows.loc[:, "aj"] = (
-                atypes[1] + "_" + aj_rows["aj"].str.split("_").str[1] + "_" + aj_rows["aj"].str.split("_").str[2]
+            stmp_df_aj.loc[:, "aj"] = (
+                atypes[1] + "_" + stmp_df_aj["aj"].str.split("_").str[1] + "_" + stmp_df_aj["aj"].str.split("_").str[2]
             )
 
-            ai_rows["ai"], ai_rows["aj"] = np.where(
-                ai_rows["ai"].str.split("_").str[2] < ai_rows["aj"].str.split("_").str[2],
-                (ai_rows["ai"], ai_rows["aj"]),
-                (ai_rows["aj"], ai_rows["ai"]),
-            )
-            aj_rows["ai"], aj_rows["aj"] = np.where(
-                aj_rows["ai"].str.split("_").str[2] < aj_rows["aj"].str.split("_").str[2],
-                (aj_rows["ai"], aj_rows["aj"]),
-                (aj_rows["aj"], aj_rows["ai"]),
-            )
-            tmp_df = pd.concat([tmp_df, ai_rows, aj_rows])
-
-            ai_rows = meGO_LJ[meGO_LJ["ai"].str.startswith(f"{atypes[1]}_") & (mglj_resn_ai == sym[0])].copy()
-            aj_rows = meGO_LJ[meGO_LJ["aj"].str.startswith(f"{atypes[1]}_") & (mglj_resn_aj == sym[0])].copy()
-            ai_rows.loc[:, "ai"] = (
-                atypes[0] + "_" + ai_rows["ai"].str.split("_").str[1] + "_" + ai_rows["ai"].str.split("_").str[2]
+            if tmp_df.empty:
+                continue
+            tmp_df_i = tmp_df[tmp_df["ai"].str.startswith(f"{atypes[0]}_") & (mglj_resn_ai == sym[0])].copy()
+            tmp_df_j = tmp_df[tmp_df["aj"].str.startswith(f"{atypes[0]}_") & (mglj_resn_aj == sym[0])].copy()
+            tmp_df_i.loc[:, "ai"] = (
+                atypes[1] + "_" + tmp_df_i["ai"].str.split("_").str[1] + "_" + tmp_df_i["ai"].str.split("_").str[2]
             )
-            aj_rows.loc[:, "aj"] = (
-                atypes[0] + "_" + aj_rows["aj"].str.split("_").str[1] + "_" + aj_rows["aj"].str.split("_").str[2]
+            tmp_df_j.loc[:, "aj"] = (
+                atypes[1] + "_" + tmp_df_j["aj"].str.split("_").str[1] + "_" + tmp_df_j["aj"].str.split("_").str[2]
             )
-
-            ai_rows["ai"], ai_rows["aj"] = np.where(
-                ai_rows["ai"].str.split("_").str[2] < ai_rows["aj"].str.split("_").str[2],
-                (ai_rows["ai"], ai_rows["aj"]),
-                (ai_rows["aj"], ai_rows["ai"]),
-            )
-            aj_rows["ai"], aj_rows["aj"] = np.where(
-                aj_rows["ai"].str.split("_").str[2] < aj_rows["aj"].str.split("_").str[2],
-                (aj_rows["ai"], aj_rows["aj"]),
-                (aj_rows["aj"], aj_rows["ai"]),
-            )
-            tmp_df = pd.concat([tmp_df, ai_rows, aj_rows])
+            tmp_df = pd.concat([tmp_df, stmp_df_ai, stmp_df_aj, tmp_df_i, tmp_df_j])
 
     meGO_LJ = pd.concat([meGO_LJ, tmp_df])
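For readers unfamiliar with the contact-table layout, the renaming step in the hunk above can be illustrated on a toy DataFrame. This is a minimal sketch only: the symmetry entry, the "NAME_MOL_RESNUM" identifier format, and the residue lookup below are illustrative assumptions, not values taken from the repository.

import itertools

import pandas as pd

# toy contact table: one ASP OD1 contact and one ASP CA contact
contacts = pd.DataFrame(
    {
        "ai": ["OD1_A_12", "CA_A_12"],
        "aj": ["NZ_A_45", "NZ_A_45"],
        "epsilon": [0.3, 0.1],
    }
)
# assumed symmetry spec: residue name followed by its equivalent atom names
symmetries = [["ASP", "OD1", "OD2"]]
# toy stand-in for the residue-name lookup (dict_sbtype_to_resname in the diff)
resname_ai = contacts["ai"].map({"OD1_A_12": "ASP", "CA_A_12": "ASP"})

expanded = [contacts]
for sym in symmetries:
    for a, b in itertools.combinations(sym[1:], 2):
        # rows whose "ai" carries the first equivalent atom of this residue
        sel = contacts[contacts["ai"].str.startswith(f"{a}_") & (resname_ai == sym[0])].copy()
        # rebuild the identifier with the symmetric atom name, keeping MOL and RESNUM
        sel["ai"] = b + "_" + sel["ai"].str.split("_").str[1] + "_" + sel["ai"].str.split("_").str[2]
        expanded.append(sel)

contacts = pd.concat(expanded, ignore_index=True)
# the OD1_A_12 / NZ_A_45 contact now also exists as OD2_A_12 / NZ_A_45
print(contacts)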

@@ -1267,57 +1246,33 @@ def generate_LJ(meGO_ensemble, train_dataset, check_dataset, parameters):
     mglj_resn_aj = meGO_check_contacts["aj"].map(dict_sbtype_to_resname)
     tmp_df = pd.DataFrame()
     # apply symmetries to check contacts
+    tmp_df = pd.DataFrame()
     for sym in symmetries:
         for atypes in itertools.combinations(sym[1:], 2):
-            ai_rows = meGO_check_contacts[
+            stmp_df_ai = meGO_check_contacts[
                 meGO_check_contacts["ai"].str.startswith(f"{atypes[0]}_") & (mglj_resn_ai == sym[0])
             ].copy()
-            aj_rows = meGO_check_contacts[
+            stmp_df_aj = meGO_check_contacts[
                 meGO_check_contacts["aj"].str.startswith(f"{atypes[0]}_") & (mglj_resn_aj == sym[0])
             ].copy()
-            ai_rows.loc[:, "ai"] = (
-                atypes[1] + "_" + ai_rows["ai"].str.split("_").str[1] + "_" + ai_rows["ai"].str.split("_").str[2]
+            stmp_df_ai.loc[:, "ai"] = (
+                atypes[1] + "_" + stmp_df_ai["ai"].str.split("_").str[1] + "_" + stmp_df_ai["ai"].str.split("_").str[2]
             )
-            aj_rows.loc[:, "aj"] = (
-                atypes[1] + "_" + aj_rows["aj"].str.split("_").str[1] + "_" + aj_rows["aj"].str.split("_").str[2]
+            stmp_df_aj.loc[:, "aj"] = (
+                atypes[1] + "_" + stmp_df_aj["aj"].str.split("_").str[1] + "_" + stmp_df_aj["aj"].str.split("_").str[2]
             )
 
-            ai_rows["ai"], ai_rows["aj"] = np.where(
-                ai_rows["ai"].str.split("_").str[2] < ai_rows["aj"].str.split("_").str[2],
-                (ai_rows["ai"], ai_rows["aj"]),
-                (ai_rows["aj"], ai_rows["ai"]),
+            if tmp_df.empty:
+                continue
+            tmp_df_i = tmp_df[tmp_df["ai"].str.startswith(f"{atypes[0]}_") & (mglj_resn_ai == sym[0])].copy()
+            tmp_df_j = tmp_df[tmp_df["aj"].str.startswith(f"{atypes[0]}_") & (mglj_resn_aj == sym[0])].copy()
+            tmp_df_i.loc[:, "ai"] = (
+                atypes[1] + "_" + tmp_df_i["ai"].str.split("_").str[1] + "_" + tmp_df_i["ai"].str.split("_").str[2]
             )
-            aj_rows["ai"], aj_rows["aj"] = np.where(
-                aj_rows["ai"].str.split("_").str[2] < aj_rows["aj"].str.split("_").str[2],
-                (aj_rows["ai"], aj_rows["aj"]),
-                (aj_rows["aj"], aj_rows["ai"]),
+            tmp_df_j.loc[:, "aj"] = (
+                atypes[1] + "_" + tmp_df_j["aj"].str.split("_").str[1] + "_" + tmp_df_j["aj"].str.split("_").str[2]
             )
-            tmp_df = pd.concat([tmp_df, ai_rows, aj_rows])
-
-            ai_rows = meGO_check_contacts[
-                meGO_check_contacts["ai"].str.startswith(f"{atypes[1]}_") & (mglj_resn_ai == sym[0])
-            ].copy()
-            aj_rows = meGO_check_contacts[
-                meGO_check_contacts["aj"].str.startswith(f"{atypes[1]}_") & (mglj_resn_aj == sym[0])
-            ].copy()
-            ai_rows.loc[:, "ai"] = (
-                atypes[0] + "_" + ai_rows["ai"].str.split("_").str[1] + "_" + ai_rows["ai"].str.split("_").str[2]
-            )
-            aj_rows.loc[:, "aj"] = (
-                atypes[0] + "_" + aj_rows["aj"].str.split("_").str[1] + "_" + aj_rows["aj"].str.split("_").str[2]
-            )
-
-            ai_rows["ai"], ai_rows["aj"] = np.where(
-                ai_rows["ai"].str.split("_").str[2] < ai_rows["aj"].str.split("_").str[2],
-                (ai_rows["ai"], ai_rows["aj"]),
-                (ai_rows["aj"], ai_rows["ai"]),
-            )
-            aj_rows["ai"], aj_rows["aj"] = np.where(
-                aj_rows["ai"].str.split("_").str[2] < aj_rows["aj"].str.split("_").str[2],
-                (aj_rows["ai"], aj_rows["aj"]),
-                (aj_rows["aj"], aj_rows["ai"]),
-            )
-            tmp_df = pd.concat([tmp_df, ai_rows, aj_rows])
+            tmp_df = pd.concat([tmp_df, stmp_df_ai, stmp_df_aj, tmp_df_i, tmp_df_j])
 
     meGO_check_contacts = pd.concat([meGO_check_contacts, tmp_df])
     # check contacts are all repulsive so among duplicates we keep the one with shortest distance
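The comment above refers to a deduplication step whose body is not shown in this hunk. As a rough sketch of the idea only (the "ai", "aj", and "distance" column names follow the diff; the actual implementation in the function may differ), keeping the shortest-distance row for each (ai, aj) pair is the usual sort-then-drop-duplicates idiom.

import pandas as pd

# toy check-contact table with one (ai, aj) pair duplicated at two distances
checks = pd.DataFrame(
    {
        "ai": ["OD1_A_12", "OD2_A_12", "OD1_A_12"],
        "aj": ["NZ_A_45", "NZ_A_45", "NZ_A_45"],
        "distance": [0.45, 0.40, 0.52],
    }
)

# sort so the shortest distance comes first, then keep the first row of each pair
checks = checks.sort_values("distance").drop_duplicates(subset=["ai", "aj"], keep="first")
# two rows remain: OD2_A_12 / NZ_A_45 at 0.40 and OD1_A_12 / NZ_A_45 at 0.45
print(checks)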