Skip to content

Commit

Permalink
Merge pull request #2627 from matthewturk/nbody_tests
Browse files Browse the repository at this point in the history
Adding nbody answer tests using particle plots
  • Loading branch information
munkm authored Jun 19, 2020
2 parents 44eb8a8 + 04cbde2 commit ec9a562
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
15 changes: 7 additions & 8 deletions yt/frontends/flash/tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
requires_ds, \
small_patch_amr, \
data_dir_load, \
sph_answer
nbody_answer
from yt.frontends.flash.api import FLASHDataset, \
FLASHParticleDataset
from collections import OrderedDict
Expand Down Expand Up @@ -54,12 +54,11 @@ def test_mu():

fid_1to3_b1_fields = OrderedDict(
[
(("deposit", "all_density"), None),
(("deposit", "all_count"), None),
(("deposit", "all_cic"), None),
(("deposit", "all_cic_velocity_x"), ("deposit", "all_cic")),
(("deposit", "all_cic_velocity_y"), ("deposit", "all_cic")),
(("deposit", "all_cic_velocity_z"), ("deposit", "all_cic")),
(("all", "particle_mass"), None),
(("all", "particle_ones"), None),
(("all", "particle_velocity_x"), ("all", "particle_mass")),
(("all", "particle_velocity_y"), ("all", "particle_mass")),
(("all", "particle_velocity_z"), ("all", "particle_mass")),
]
)

Expand Down Expand Up @@ -92,6 +91,6 @@ def test_FLASH25_dataset():
@requires_ds(fid_1to3_b1, big_data=True)
def test_fid_1to3_b1():
ds = data_dir_load(fid_1to3_b1)
for test in sph_answer(ds, 'fiducial_1to3_b1_hdf5_part_0080', 6684119, fid_1to3_b1_fields):
for test in nbody_answer(ds, 'fiducial_1to3_b1_hdf5_part_0080', 6684119, fid_1to3_b1_fields):
test_fid_1to3_b1.__name__ = test.description
yield test
34 changes: 27 additions & 7 deletions yt/utilities/answer_testing/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,13 @@ def __init__(self, ds_fn, axis, field, weight_field = None,
self.weight_field = weight_field
self.obj_type = obj_type

def _get_frb(self, obj):
proj = self.ds.proj(self.field, self.axis,
weight_field=self.weight_field,
data_source = obj)
frb = proj.to_frb((1.0, 'unitary'), 256)
return proj, frb

def run(self):
if self.obj_type is not None:
obj = create_obj(self.ds, self.obj_type)
Expand All @@ -581,6 +588,13 @@ def compare(self, new_result, old_result):
for k in new_result:
assert_rel_equal(new_result[k], old_result[k], 10)

class PixelizedParticleProjectionValuesTest(PixelizedProjectionValuesTest):

def _get_frb(self, obj):
proj_plot = particle_plots.ParticleProjectionPlot(self.ds, self.axis, [self.field],
weight_field = self.weight_field)
return proj_plot.data_source, proj_plot.frb

class GridValuesTest(AnswerTestingTest):
_type_name = "GridValues"
_attrs = ("field",)
Expand Down Expand Up @@ -969,7 +983,7 @@ def big_patch_amr(ds_fn, fields, input_center="max", input_weight="density"):
dobj_name)


def sph_answer(ds, ds_str_repr, ds_nparticles, fields):
def _particle_answers(ds, ds_str_repr, ds_nparticles, fields, proj_test_class):
if not can_run_ds(ds):
return
assert_equal(str(ds), ds_str_repr)
Expand All @@ -981,18 +995,24 @@ def sph_answer(ds, ds_str_repr, ds_nparticles, fields):
assert_equal(tot, ds_nparticles)
for dobj_name in dso:
for field, weight_field in fields.items():
if field[0] in ds.particle_types:
particle_type = True
else:
particle_type = False
particle_type = field[0] in ds.particle_types
for axis in [0, 1, 2]:
if particle_type is False:
yield PixelizedProjectionValuesTest(
if not particle_type:
yield proj_test_class(
ds, axis, field, weight_field,
dobj_name)
yield FieldValuesTest(ds, field, dobj_name,
particle_type=particle_type)


def nbody_answer(ds, ds_str_repr, ds_nparticles, fields):
return _particle_answers(ds, ds_str_repr, ds_nparticles, fields,
PixelizedParticleProjectionValuesTest)

def sph_answer(ds, ds_str_repr, ds_nparticles, fields):
return _particle_answers(ds, ds_str_repr, ds_nparticles, fields,
PixelizedProjectionValuesTest)

def create_obj(ds, obj_type):
# obj_type should be tuple of
# ( obj_name, ( args ) )
Expand Down

0 comments on commit ec9a562

Please sign in to comment.