Skip to content

Commit

Permalink
Fix incorrectly parsed properties in Scan.from_dict (cctbx#688)
Browse files Browse the repository at this point in the history
Added a missing test for this.

Fixes 2606.
  • Loading branch information
toastisme authored Feb 12, 2024
1 parent 03e7ee0 commit 80f81f5
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 30 deletions.
1 change: 1 addition & 0 deletions newsfragments/688.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed `Scan.from_dict` bug where some properties were not correctly parsed.
11 changes: 9 additions & 2 deletions src/dxtbx/model/boost_python/scan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ namespace dxtbx { namespace model { namespace boost_python {
value[0].attr("__class__").attr("__name__"));

// Handled explicitly as it is in deg when serialised but rad in code
if (key == "oscillation") {
if (key == "oscillation" || key == "oscillation_width") {
DXTBX_ASSERT(obj_type == "float");
scitbx::af::shared<double> osc =
boost::python::extract<scitbx::af::shared<double> >(value);
Expand Down Expand Up @@ -215,7 +215,14 @@ namespace dxtbx { namespace model { namespace boost_python {
dxtbx::af::flex_table_suite::column_to_object_visitor visitor;

for (const_iterator it = properties.begin(); it != properties.end(); ++it) {
if (it->first == "oscillation") { // Handled explicitly due to unit conversion
if (it->first
== "oscillation_width") { // Handled explicitly due to unit conversion
vec2<double> osc_deg = obj.get_oscillation_in_deg();
boost::python::list lst = boost::python::list();
lst.append(osc_deg[1]);
properties_dict[it->first] = lst;
} else if (it->first
== "oscillation") { // Handled explicitly due to unit conversion
properties_dict[it->first] =
boost::python::tuple(obj.get_oscillation_arr_in_deg());
} else {
Expand Down
125 changes: 97 additions & 28 deletions src/dxtbx/model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,59 +113,128 @@ def from_dict(d, t=None):
The scan model
"""

def convert_oscillation_to_vec2(properties_dict):
def add_properties_table(scan_dict, num_images):

"""
If oscillation is in properties_dict,
shared<double> is converted to vec2<double> and
oscillation_width is removed (if present) to ensure
it is replaced correctly if updating t dict from d dict
Handles legacy case before Scan had a properties table.
Moves oscillation, epochs, and exposure times to a properties
table and adds this to scan_dict.
"""

if "oscillation" not in properties_dict:
assert "oscillation_width" not in properties_dict
return properties_dict
if "oscillation_width" in properties_dict:
assert "oscillation" in properties_dict
properties_dict["oscillation"] = (
properties_dict["oscillation"][0],
properties_dict["oscillation_width"][0],
)
del properties_dict["oscillation_width"]
return properties_dict
properties_dict["oscillation"] = (
properties_dict["oscillation"][0],
properties_dict["oscillation"][1] - properties_dict["oscillation"][0],
properties = {}
if scan_dict:
if "oscillation" in scan_dict:
if num_images == 1:
properties["oscillation_width"] = [scan_dict["oscillation"][1]]
properties["oscillation"] = [scan_dict["oscillation"][0]]

else:
osc = scan_dict["oscillation"]
properties["oscillation"] = [
osc[0] + (osc[1] - osc[0]) * i for i in range(num_images)
]
del scan_dict["oscillation"]
if "exposure_time" in scan_dict:
properties["exposure_time"] = scan_dict["exposure_time"]
del scan_dict["exposure_time"]
if "epochs" in scan_dict:
properties["epochs"] = scan_dict["epochs"]
del scan_dict["epochs"]
scan_dict["properties"] = make_properties_table_consistent(
properties, num_images
)
return scan_dict

def make_properties_table_consistent(properties, num_images):

"""
Handles legacy case before Scan had a properties table.
Ensures oscillation, epochs, and exposure times have the same length.
"""

return properties_dict
if not properties:
return properties

if "oscillation" in properties:
assert len(properties["oscillation"]) > 0

if num_images == 1 and "oscillation_width" not in properties:
assert len(properties["oscillation"]) > 1
properties["oscillation_width"] = [properties["oscillation"][1]]
properties["oscillation"] = [properties["oscillation"][0]]
elif num_images > 1:
osc_0 = properties["oscillation"][0]
if "oscillation_width" in properties:
osc_1 = properties["oscillation_width"][0]
del properties["oscillation_width"]
else:
assert len(properties["oscillation"]) > 1
osc_1 = (
properties["oscillation"][1] - properties["oscillation"][0]
)
properties["oscillation"] = [
osc_0 + osc_1 * i for i in range(num_images)
]

if "exposure_time" in properties:
assert len(properties["exposure_time"]) > 0

# Assume same exposure time for each image
properties["exposure_time"] = [
properties["exposure_time"][0] for i in range(num_images)
]

if "epochs" in properties:
assert len(properties["epochs"]) > 0
# If 1 epoch, assume increasing by epochs[0]
# Else assume increasing as epochs[1] - epochs[0]
if len(properties["epochs"]) == 1:
properties["epochs"] = [
properties["epochs"][0] + properties["epochs"][0] * i
for i in range(num_images)
]
else:
diff = properties["epochs"][1] - properties["epochs"][0]
properties["epochs"] = [
properties["epochs"][0] + i * diff for i in range(num_images)
]
return properties

if d is None and t is None:
return None
joint = t.copy() if t else {}

# Accounting for legacy cases where t or d does not
# contain properties dict
num_images = None
if "image_range" in d:
num_images = 1 + d["image_range"][1] - d["image_range"][0]
elif "image_range" in joint:
num_images = 1 + joint["image_range"][1] - joint["image_range"][0]

if "properties" in joint and "properties" in d:
properties = t["properties"].copy()
properties.update(d["properties"])
joint.update(d)
joint["properties"] = properties
elif "properties" in d:
joint = add_properties_table(joint, num_images)
d_copy = d.copy()
d_copy["properties"] = convert_oscillation_to_vec2(
d_copy["properties"].copy()
joint["properties"].update(d_copy["properties"])
joint["properties"] = make_properties_table_consistent(
joint["properties"], num_images
)
joint.update(**d_copy["properties"])
del d_copy["properties"]
joint.update(d_copy)
elif "properties" in joint:
joint["properties"] = convert_oscillation_to_vec2(
joint["properties"].copy()
d = add_properties_table(d, num_images)
d_copy = d.copy()
joint["properties"].update(d_copy["properties"])
joint["properties"] = make_properties_table_consistent(
joint["properties"], num_images
)
joint.update(**joint["properties"])
del joint["properties"]
joint.update(d)
del d_copy["properties"]
joint.update(d_copy)
else:
joint.update(d)

Expand Down
22 changes: 22 additions & 0 deletions tests/model/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,25 @@ def test_print_scan():
)
expected_scan_string = "Scan:\n number of images: 10\n image range: {1,10}\n test_bool: 1 - 1\n test_float: 0 - 18\n test_int: 0 - 9\n test_string: test_0 - test_9\n test_vec2_double: {2,2} - {2,2}\n test_vec3_double: {1,1,1} - {1,1,1}\n"
assert scan.__str__() == expected_scan_string


def test_scan_properties_from_dict():
image_range = (1, 10)
properties = {"test": list(range(10))}
scan = ScanFactory.make_scan_from_properties(image_range, properties)
assert scan == ScanFactory.from_dict(scan.to_dict())

image_range = (1, 1)
properties = {"oscillation": [1.0], "oscillation_width": [0.5]}
scan = ScanFactory.make_scan_from_properties(image_range, properties)
assert scan == ScanFactory.from_dict(scan.to_dict())

scan = ScanFactory.from_dict(
{
"oscillation": [1.0, 0.5],
"image_range": [1, 1],
"exposure_time": [0.5],
"epochs": [1],
}
)
assert scan == ScanFactory.from_dict(scan.to_dict())

0 comments on commit 80f81f5

Please sign in to comment.