Skip to content

Commit

Permalink
Merge pull request #8 from xchem/water_retention
Browse files Browse the repository at this point in the history
Water retention
  • Loading branch information
ConorFWild authored May 8, 2024
2 parents cdbc189 + 7e3bd96 commit 882a18d
Show file tree
Hide file tree
Showing 20 changed files with 1,278 additions and 965 deletions.
9 changes: 2 additions & 7 deletions .github/pages/make_switcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,14 @@ def get_versions(ref: str, add: Optional[str], remove: Optional[str]) -> List[st

def write_json(path: Path, repository: str, versions: str):
org, repo_name = repository.split("/")
struct = [
dict(version=version, url=f"https://{org}.github.io/{repo_name}/{version}/")
for version in versions
]
struct = [dict(version=version, url=f"https://{org}.github.io/{repo_name}/{version}/") for version in versions]
text = json.dumps(struct, indent=2)
print(f"JSON switcher:\n{text}")
path.write_text(text)


def main(args=None):
parser = ArgumentParser(
description="Make a versions.txt file from gh-pages directories"
)
parser = ArgumentParser(description="Make a versions.txt file from gh-pages directories")
parser.add_argument(
"--add",
help="Add this directory to the list of existing directories",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ description = "One line description of your module"
dependencies = [
"gemmi",
"loguru",
"pydantic",
"pydantic==2.6.0",
"networkx",
"numpy",
"rich",
Expand Down
139 changes: 63 additions & 76 deletions src/ligand_neighbourhood_alignment/align_xmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ def _get_interpolation_range(neighbourhood: dt.Neighbourhood, transform, referen


def interpolate_range(
reference_xmap,
xmap,
interpolation_ranges: list[Block],
transform,
reference_xmap,
xmap,
interpolation_ranges: list[Block],
transform,
):
# Make a xmap on reference template
new_xmap = gemmi.FloatGrid(reference_xmap.nu, reference_xmap.nv, reference_xmap.nw)
Expand Down Expand Up @@ -294,26 +294,26 @@ def interpolate_range(
logger.debug(f"Block Z Range in output xmap: {rzi} : {rzf}")

grid_np[
rxi:rxf,
ryi:ryf,
rzi:rzf,
rxi:rxf,
ryi:ryf,
rzi:rzf,
] = arr

return new_xmap


def align_xmap(
neighbourhoods: LigandNeighbourhoods,
g,
transforms: Transforms,
site_transforms: SiteTransforms,
reference_xmap,
subsite_reference_id: LigandID,
site_id: int,
subsite_id: int,
lid: LigandID,
xmap,
output_path: Path,
neighbourhoods: LigandNeighbourhoods,
g,
transforms: Transforms,
site_transforms: SiteTransforms,
reference_xmap,
subsite_reference_id: LigandID,
site_id: int,
subsite_id: int,
lid: LigandID,
xmap,
output_path: Path,
):
# Get the ligand neighbourhood
neighbourhood: LigandNeighbourhood = neighbourhoods.get_neighbourhood(lid)
Expand Down Expand Up @@ -392,9 +392,11 @@ def _get_box(neighbourhood: dt.Neighbourhood, xmap, transform):
)
return box


def _write_xmap_from_ccp4(ccp4, path):
ccp4.write_ccp4_map(str(path))


def _write_xmap(xmap, path: Path, neighbourhood: dt.Neighbourhood, transform):

ccp4 = gemmi.Ccp4Map()
Expand All @@ -416,7 +418,6 @@ def _write_xmap(xmap, path: Path, neighbourhood: dt.Neighbourhood, transform):
ccp4.setup(float("nan"))
ccp4.update_ccp4_header()


ccp4.write_ccp4_map(str(path))


Expand All @@ -441,34 +442,22 @@ def get_frame_bounds(ligand_lower_bound, ligand_upper_bound, border, step):

def get_frame_array(frame_lower_bound, frame_upper_bound, step):
interval = np.round((frame_upper_bound - frame_lower_bound) / step)
return np.zeros(
(int(interval[0]), int(interval[1]), int(interval[2])),
dtype=np.float32
)
return np.zeros((int(interval[0]), int(interval[1]), int(interval[2])), dtype=np.float32)
...


def get_frame_transform(frame_lower_bound, frame_array, step):
tr = gemmi.Transform()
tr.vec.fromlist([x for x in frame_lower_bound])
tr.mat.fromlist(
(np.eye(3) * step).tolist()
)
tr.mat.fromlist((np.eye(3) * step).tolist())
return tr

...


def get_cell(frame_array, step):
shape = frame_array.shape
cell = gemmi.UnitCell(
shape[0] * step,
shape[1] * step,
shape[2] * step,
90.0,
90.0,
90.0
)
cell = gemmi.UnitCell(shape[0] * step, shape[1] * step, shape[2] * step, 90.0, 90.0, 90.0)
return cell


Expand All @@ -480,7 +469,7 @@ def get_new_map(cell, sample, frame_min, step):
ccp4 = gemmi.Ccp4Map()
ccp4.grid = grid
ccp4.grid.set_unit_cell(cell)
ccp4.grid.spacegroup = gemmi.SpaceGroup('P1')
ccp4.grid.spacegroup = gemmi.SpaceGroup("P1")
ccp4.update_ccp4_header()
# ccp4.set_header_float(50, frame_min[0]/step)
# ccp4.set_header_float(51, frame_min[1]/step)
Expand All @@ -490,10 +479,8 @@ def get_new_map(cell, sample, frame_min, step):
ccp4.set_header_float(52, frame_min[2])
return ccp4

def resample_xmap(
new_xmap,
aligned_res
):

def resample_xmap(new_xmap, aligned_res):
step = 0.5
border = 5.0
m = new_xmap
Expand Down Expand Up @@ -522,23 +509,26 @@ def resample_xmap(
cell = get_cell(frame_array, step)
# print(cell)

print(f'Origin for xmap is now: {frame_lower_bound}')

new_map = get_new_map(cell, frame_array, frame_lower_bound, step)
return new_map


def __align_xmap(
neighbourhood: dt.Neighbourhood,
g,
ligand_neighbourhood_transforms: dict[tuple[tuple[str, str, str], tuple[str, str, str]], dt.Transform],
reference_xmap,
subsite_reference_id: tuple[str,str,str],
lid: tuple[str, str, str],
xmap,
conformer_site_transforms,
conformer_site_id,
# canonical_site_transforms,
canonical_site_id,
output_path: Path,
aligned_res
neighbourhood: dt.Neighbourhood,
g,
ligand_neighbourhood_transforms: dict[tuple[tuple[str, str, str], tuple[str, str, str]], dt.Transform],
reference_xmap,
subsite_reference_id: tuple[str, str, str],
lid: tuple[str, str, str],
xmap,
conformer_site_transforms,
conformer_site_id,
# canonical_site_transforms,
canonical_site_id,
output_path: Path,
aligned_res,
):
# Get the ligand neighbourhood
# neighbourhood: LigandNeighbourhood = neighbourhoods.get_neighbourhood(lid)
Expand Down Expand Up @@ -590,10 +580,7 @@ def __align_xmap(
)

# Resample the xmap to the aligned structure frame
resampled_xmap = resample_xmap(
new_xmap,
aligned_res
)
resampled_xmap = resample_xmap(new_xmap, aligned_res)

# Output the xmap
# _write_xmap(
Expand All @@ -606,13 +593,13 @@ def __align_xmap(
_write_xmap_from_ccp4(
resampled_xmap,
output_path,

)


def read_xmap_from_mtz(
mtz_path: Path,
map_type="2Fo-Fc",):
mtz_path: Path,
map_type="2Fo-Fc",
):
mtz = gemmi.read_mtz_file(str(mtz_path))

if map_type == "2Fo-Fc":
Expand Down Expand Up @@ -644,15 +631,15 @@ def read_xmap_from_mtz(


def _align_xmaps(
system_data: SystemData,
structures,
canonical_sites: CanonicalSites,
conformer_sites: ConformerSites,
neighbourhoods: LigandNeighbourhoods,
g,
transforms: Transforms,
site_transforms: SiteTransforms,
output: Output,
system_data: SystemData,
structures,
canonical_sites: CanonicalSites,
conformer_sites: ConformerSites,
neighbourhoods: LigandNeighbourhoods,
g,
transforms: Transforms,
site_transforms: SiteTransforms,
output: Output,
):
# Get the global reference
# reference_lid: LigandID = canonical_sites.reference_site.reference_ligand_id
Expand Down Expand Up @@ -779,14 +766,14 @@ def _align_xmaps(


def _align_xmap(
system_data: SystemData,
canonical_sites: CanonicalSites,
conformer_sites: ConformerSites,
neighbourhoods: LigandNeighbourhoods,
g,
transforms: Transforms,
site_transforms: SiteTransforms,
output: Output,
system_data: SystemData,
canonical_sites: CanonicalSites,
conformer_sites: ConformerSites,
neighbourhoods: LigandNeighbourhoods,
g,
transforms: Transforms,
site_transforms: SiteTransforms,
output: Output,
):
# Get the global reference
# reference_lid: LigandID = canonical_sites.reference_site.reference_ligand_id
Expand Down
72 changes: 72 additions & 0 deletions src/ligand_neighbourhood_alignment/alignment_heirarchy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from ligand_neighbourhood_alignment import dt

AlignmentHeirarchy: dict[str, tuple[str, str]]


def _derive_alignment_heirarchy(assemblies: dict[str, dt.Assembly]) -> AlignmentHeirarchy:
# The Alignment hierarchy is the graph of alignments one must perform in order to get from
# a ligand canonical site to the Reference Assembly Frame

# In order to calculate the assembly the following steps are performed:
# 1. Determine the Assembly priority
# 2. Determine the Chain priority
# 3. Find each assembly's reference
# 4. Check per-chain RMSDs and warn if any are high

# 1. Determine the Assembly priority
assembly_priority = {_j: _assembly_name for _j, _assembly_name in enumerate(assemblies)}

# 2. Determine the Chain priority and map assembly names to chains
chain_priority = {}
assembly_chains = {}
chain_priority_count = 0
for _j, _assembly_name in assembly_priority.items():
assembly = assemblies[_assembly_name]
assembly_chains[_assembly_name] = []
for _generator in assembly.generators:
_biological_chain_name = _generator.biomol
assembly_chains[_assembly_name].append(_biological_chain_name)
if _biological_chain_name not in chain_priority.values():
chain_priority[chain_priority_count] = _biological_chain_name
chain_priority_count += 1

# 3. Find each assembly's reference
reference_assemblies = {}
for _assembly_name, _assembly in assemblies.items():
# Get the highest priority chain
reference_chain = min(
[_generator.chain for _generator in _assembly.generators], key=lambda _x: chain_priority[_x]
)

# Get the highest priority assembly in which it occurs
reference_assembly = min(
[
_assembly_name
for _assembly_name in assembly_chains
if reference_chain in assembly_chains[_assembly_name]
],
key=lambda _x: assembly_priority[_x],
)
reference_assemblies[_assembly_name] = (reference_assembly, reference_chain)

# 4. Check per-chain RMSDs and warn if any are high
# TODO

return reference_assemblies


def _chain_to_biochain(chain_name, xtalform: dt.XtalForm, assemblies: dict[str, dt.Assembly]) -> str:
for _xtal_assembly_name, _xtal_assembly in xtalform.assemblies.items():
for _j, _chain_name in enumerate(_xtal_assembly.chains):
if chain_name == _chain_name:
return assemblies[_xtal_assembly.assembly].generators[_j].biomol


StructureLandmarks: dict[tuple[str, str, str], tuple[float, float, float]]


def _calculate_assembly_transform(
assembly_name: str, alignment_heirarchy: AlignmentHeirarchy, assembly_landmarks: dict[str, StructureLandmarks]
):
# Get the chain to align to
...
Loading

0 comments on commit 882a18d

Please sign in to comment.