Skip to content

Commit

Permalink
Add experiment type (cctbx#702)
Browse files Browse the repository at this point in the history
* Added ExperimentType to Experiment, accessed via get_type().
  • Loading branch information
toastisme authored Apr 24, 2024
1 parent ee74a3d commit 6fcf993
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 75 deletions.
1 change: 1 addition & 0 deletions newsfragments/702.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `Experiment.get_type()` to replace `Experiment.is_still()/Experiment.is_sequence()`
2 changes: 2 additions & 0 deletions src/dxtbx/dxtbx_model_ext.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from scitbx.array_family import shared as flex_shared
from scitbx.array_family.flex import FlexPlain

from dxtbx_model_ext import Probe # type: ignore
from dxtbx_model_ext import ExperimentType

# TypeVar for the set of Experiment models that can be joint-accepted
# - profile, imageset and scalingmodel are handled as 'object'
Expand Down Expand Up @@ -354,6 +355,7 @@ class Experiment:
def is_sequence(self) -> bool: ...
def is_still(self) -> bool: ...
def __contains__(self, obj: TExperimentModel) -> bool: ...
def get_type(self) -> ExperimentType: ...

class ExperimentList:
@overload
Expand Down
14 changes: 12 additions & 2 deletions src/dxtbx/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
DetectorNode,
Experiment,
ExperimentList,
ExperimentType,
Goniometer,
GoniometerBase,
KappaDirection,
Expand Down Expand Up @@ -69,6 +70,7 @@
DetectorNode,
Experiment,
ExperimentList,
ExperimentType,
Goniometer,
GoniometerBase,
KappaDirection,
Expand Down Expand Up @@ -599,11 +601,19 @@ def imagesets(self):

def all_stills(self):
"""Check if all the experiments are stills"""
return all(exp.is_still() for exp in self)
return all(exp.get_type() == ExperimentType.STILL for exp in self)

def all_sequences(self):
"""Check if all the experiments are from sequences"""
return all(exp.is_sequence() for exp in self)
return self.all_rotations()

def all_rotations(self):
"""Check if all the experiments are stills"""
return all(exp.get_type() == ExperimentType.ROTATION for exp in self)

def all_tof(self):
"""Check if all the experiments are time-of-flight"""
return all(exp.get_type() == ExperimentType.TOF for exp in self)

def to_dict(self):
"""Serialize the experiment list to dictionary."""
Expand Down
6 changes: 6 additions & 0 deletions src/dxtbx/model/boost_python/experiment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ namespace dxtbx { namespace model { namespace boost_python {
};

void export_experiment() {
enum_<ExperimentType>("ExperimentType")
.value("STILL", STILL)
.value("ROTATION", ROTATION)
.value("TOF", TOF);

class_<Experiment>("Experiment")
.def(init<std::shared_ptr<BeamBase>,
std::shared_ptr<Detector>,
Expand Down Expand Up @@ -118,6 +123,7 @@ namespace dxtbx { namespace model { namespace boost_python {
.def("is_sequence",
&Experiment::is_sequence,
"Check if this experiment represents swept rotation image(s)")
.def("get_type", &Experiment::get_type)
.def_pickle(ExperimentPickleSuite());
}

Expand Down
83 changes: 18 additions & 65 deletions src/dxtbx/model/experiment.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

namespace dxtbx { namespace model {

enum ExperimentType { ROTATION = 1, STILL = 2, TOF = 3 };

/**
* A class to represent what's in an experiment.
*
Expand Down Expand Up @@ -128,9 +130,6 @@ namespace dxtbx { namespace model {
return profile_ == obj || imageset_ == obj || scaling_model_ == obj;
}

/**
* Compare this experiment with another
*/
bool operator==(const Experiment &other) const {
return imageset_ == other.imageset_ && beam_ == other.beam_
&& detector_ == other.detector_ && goniometer_ == other.goniometer_
Expand All @@ -139,18 +138,11 @@ namespace dxtbx { namespace model {
&& identifier_ == other.identifier_;
}

/**
* Check that the experiment is consistent
*/
bool is_consistent() const {
return true; // FIXME
}

/**
* Check if this experiment represents a still image
*/
bool is_still() const {
return !goniometer_ || !scan_ || scan_->is_still();
return get_type() == STILL;
}

/**
Expand All @@ -160,128 +152,89 @@ namespace dxtbx { namespace model {
return !is_still();
}

/**
* Set the beam model
*/
ExperimentType get_type() const {
if (scan_ && scan_->contains("time_of_flight")) {
return TOF;
}
if (!goniometer_ || !scan_ || scan_->is_still()) {
return STILL;
} else {
return ROTATION;
}
}

bool is_consistent() const {
return true; // FIXME
}

void set_beam(std::shared_ptr<BeamBase> beam) {
beam_ = beam;
}

/**
* Get the beam model
*/
std::shared_ptr<BeamBase> get_beam() const {
return beam_;
}

/**
* Get the detector model
*/
void set_detector(std::shared_ptr<Detector> detector) {
detector_ = detector;
}

/**
* Get the detector model
*/
std::shared_ptr<Detector> get_detector() const {
return detector_;
}

/**
* Get the goniometer model
*/
void set_goniometer(std::shared_ptr<Goniometer> goniometer) {
goniometer_ = goniometer;
}

/**
* Get the goniometer model
*/
std::shared_ptr<Goniometer> get_goniometer() const {
return goniometer_;
}

/**
* Get the scan model
*/
void set_scan(std::shared_ptr<Scan> scan) {
scan_ = scan;
}

/**
* Get the scan model
*/
std::shared_ptr<Scan> get_scan() const {
return scan_;
}

/**
* Get the crystal model
*/
void set_crystal(std::shared_ptr<CrystalBase> crystal) {
crystal_ = crystal;
}

/**
* Get the crystal model
*/
std::shared_ptr<CrystalBase> get_crystal() const {
return crystal_;
}

/**
* Get the profile model
*/
void set_profile(boost::python::object profile) {
profile_ = profile;
}

/**
* Get the profile model
*/
boost::python::object get_profile() const {
return profile_;
}

/**
* Get the imageset model
*/
void set_imageset(boost::python::object imageset) {
imageset_ = imageset;
}

/**
* Get the imageset model
*/
boost::python::object get_imageset() const {
return imageset_;
}

/**
* Set the scaling model
*/
void set_scaling_model(boost::python::object scaling_model) {
scaling_model_ = scaling_model;
}

/**
* Get the scaling model
*/
boost::python::object get_scaling_model() const {
return scaling_model_;
}

/**
* Set the identifier
*/
void set_identifier(std::string identifier) {
identifier_ = identifier;
}

/**
* Get the identifier
*/
std::string get_identifier() const {
return identifier_;
}
Expand Down
17 changes: 9 additions & 8 deletions tests/model/test_experiment_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Detector,
Experiment,
ExperimentList,
ExperimentType,
Goniometer,
Scan,
ScanFactory,
Expand Down Expand Up @@ -773,26 +774,26 @@ def test_partial_missing_model_serialization():
check(elist, elist_)


def test_experiment_is_still():
def test_experiment_type():
experiment = Experiment()
assert experiment.is_still()
assert experiment.get_type() == ExperimentType.STILL
experiment.goniometer = Goniometer()
assert experiment.is_still()
assert experiment.get_type() == ExperimentType.STILL
experiment.scan = Scan()
assert experiment.is_still()
assert experiment.get_type() == ExperimentType.STILL
experiment.scan = Scan((1, 1000), (0, 0.05))
assert not experiment.is_still()
assert experiment.get_type() == ExperimentType.ROTATION
# Specifically test the bug from dxtbx#4 triggered by ending on 0°
experiment.scan = Scan((1, 1800), (-90, 0.05))
assert not experiment.is_still()
assert experiment.get_type() == ExperimentType.ROTATION
experiment.scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"time_of_flight": list(range(10))}
)
assert not experiment.is_still()
assert experiment.get_type() == ExperimentType.TOF
experiment.scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"other_property": list(range(10))}
)
assert experiment.is_still()
assert experiment.get_type() == ExperimentType.STILL


def check(el1, el2):
Expand Down

0 comments on commit 6fcf993

Please sign in to comment.