From 10c1fd5800ae3a31fc01e16914919dfd8113e701 Mon Sep 17 00:00:00 2001
From: Colin Unger
Date: Sun, 9 Feb 2025 15:49:57 -0800
Subject: [PATCH] Format

---
 lib/models/src/models/dlrm/dlrm.cc         |  6 +++---
 lib/op-attrs/include/op-attrs/datatype.h   |  5 +++--
 lib/op-attrs/src/op-attrs/datatype.cc      | 20 ++++++++++++++------
 lib/op-attrs/src/op-attrs/ops/cast.cc      |  6 ++++--
 lib/op-attrs/test/src/op-attrs/datatype.cc | 21 +++++++++++----------
 5 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/lib/models/src/models/dlrm/dlrm.cc b/lib/models/src/models/dlrm/dlrm.cc
index 4d9c8545f4..718e709352 100644
--- a/lib/models/src/models/dlrm/dlrm.cc
+++ b/lib/models/src/models/dlrm/dlrm.cc
@@ -129,8 +129,7 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
   std::vector<tensor_guid_t> sparse_inputs =
       repeat(num_elements(config.embedding_size), [&]() {
         return create_input_tensor(
-            {config.batch_size, config.embedding_bag_size},
-            DataType::INT64);
+            {config.batch_size, config.embedding_bag_size}, DataType::INT64);
       });
 
   tensor_guid_t dense_input = create_input_tensor(
@@ -146,7 +145,8 @@ ComputationGraph get_dlrm_computation_graph(DLRMConfig const &config) {
 
   std::vector<tensor_guid_t> emb_outputs = transform(
       zip(config.embedding_size, sparse_inputs),
-      [&](std::pair<nonnegative_int, tensor_guid_t> const &combined_pair) -> tensor_guid_t {
+      [&](std::pair<nonnegative_int, tensor_guid_t> const &combined_pair)
+          -> tensor_guid_t {
         return create_dlrm_sparse_embedding_network(
             /*cgb=*/cgb,
             /*config=*/config,
diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h
index 0404e4022e..e17f51b73a 100644
--- a/lib/op-attrs/include/op-attrs/datatype.h
+++ b/lib/op-attrs/include/op-attrs/datatype.h
@@ -53,13 +53,14 @@ using real_type_t = typename data_type_enum_to_class<DT>::type;
 nonnegative_int size_of_datatype(DataType);
 
 /**
- * @brief Maximally semantics-preserving casts, not including identity 
+ * @brief Maximally semantics-preserving casts, not including identity
  * casts (e.g., `float -> float` returns `false`)
  */
 bool can_strictly_promote_datatype_from_to(DataType from, DataType to);
 
 /**
- * @brief Equivalent to [`torch.can_cast`](https://pytorch.org/docs/stable/generated/torch.can_cast.html),
+ * @brief Equivalent to
+ * [`torch.can_cast`](https://pytorch.org/docs/stable/generated/torch.can_cast.html),
  * except that identity casts (e.g., `float -> float`) return `false`
  */
 bool can_torch_strictly_promote_datatype_from_to(DataType from, DataType to);
diff --git a/lib/op-attrs/src/op-attrs/datatype.cc b/lib/op-attrs/src/op-attrs/datatype.cc
index 2fc57a9cd2..f8791521ab 100644
--- a/lib/op-attrs/src/op-attrs/datatype.cc
+++ b/lib/op-attrs/src/op-attrs/datatype.cc
@@ -28,8 +28,11 @@ bool can_strictly_promote_datatype_from_to(DataType src, DataType dst) {
   std::unordered_set<DataType> allowed;
   switch (src) {
     case DataType::BOOL:
-      allowed = {
-          DataType::INT32, DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
+      allowed = {DataType::INT32,
+                 DataType::INT64,
+                 DataType::HALF,
+                 DataType::FLOAT,
+                 DataType::DOUBLE};
       break;
     case DataType::INT32:
       allowed = {DataType::INT64};
@@ -55,14 +58,19 @@ bool can_torch_strictly_promote_datatype_from_to(DataType src, DataType dst) {
   std::unordered_set<DataType> allowed;
   switch (src) {
     case DataType::BOOL:
-      allowed = {
-          DataType::INT32, DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
+      allowed = {DataType::INT32,
+                 DataType::INT64,
+                 DataType::HALF,
+                 DataType::FLOAT,
+                 DataType::DOUBLE};
       break;
     case DataType::INT32:
-      allowed = {DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
+      allowed = {
+          DataType::INT64, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
       break;
     case DataType::INT64:
-      allowed = {DataType::INT32, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
+      allowed = {
+          DataType::INT32, DataType::HALF, DataType::FLOAT, DataType::DOUBLE};
       break;
     case DataType::HALF:
       allowed = {DataType::FLOAT, DataType::DOUBLE};
diff --git a/lib/op-attrs/src/op-attrs/ops/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc
index e7dc586b13..4bdef65457 100644
--- a/lib/op-attrs/src/op-attrs/ops/cast.cc
+++ b/lib/op-attrs/src/op-attrs/ops/cast.cc
@@ -6,7 +6,8 @@ namespace FlexFlow {
 
 tl::expected<TensorShape, std::string>
     get_output_shape(CastAttrs const &attrs, TensorShape const &input) {
-  if (!can_torch_strictly_promote_datatype_from_to(input.data_type, attrs.dtype)) {
+  if (!can_torch_strictly_promote_datatype_from_to(input.data_type,
+                                                   attrs.dtype)) {
     return tl::unexpected(fmt::format(
         "Cast cannot strictly promote input datatype {} to output datatype {}",
         input.data_type,
@@ -21,7 +22,8 @@ tl::expected<TensorShape, std::string>
 tl::expected<ParallelTensorShape, std::string>
     get_output_shape(CastAttrs const &attrs,
                      ParallelTensorShape const &input) {
-  if (!can_torch_strictly_promote_datatype_from_to(input.data_type, attrs.dtype)) {
+  if (!can_torch_strictly_promote_datatype_from_to(input.data_type,
+                                                   attrs.dtype)) {
     return tl::unexpected(fmt::format(
         "Cast cannot strictly promote input datatype {} to output datatype {}",
         input.data_type,
diff --git a/lib/op-attrs/test/src/op-attrs/datatype.cc b/lib/op-attrs/test/src/op-attrs/datatype.cc
index dadc918e10..0289bfdb87 100644
--- a/lib/op-attrs/test/src/op-attrs/datatype.cc
+++ b/lib/op-attrs/test/src/op-attrs/datatype.cc
@@ -35,22 +35,23 @@ TEST_SUITE(FF_TEST_SUITE) {
   }
 
TEST_CASE("can_torch_strictly_promote_datatype_from_to(DataType, DataType)") { - CHECK( - can_torch_strictly_promote_datatype_from_to(DataType::BOOL, DataType::INT32)); + CHECK(can_torch_strictly_promote_datatype_from_to(DataType::BOOL, + DataType::INT32)); CHECK(can_torch_strictly_promote_datatype_from_to(DataType::INT32, - DataType::INT64)); + DataType::INT64)); CHECK(can_torch_strictly_promote_datatype_from_to(DataType::FLOAT, - DataType::DOUBLE)); + DataType::DOUBLE)); RC_SUBCASE("is strict", [](DataType d) { RC_ASSERT(!can_torch_strictly_promote_datatype_from_to(d, d)); }); - RC_SUBCASE("is transitive if end-points are not the same", [](DataType d1, DataType d2, DataType d3) { - RC_PRE(can_torch_strictly_promote_datatype_from_to(d1, d2)); - RC_PRE(can_torch_strictly_promote_datatype_from_to(d2, d3)); - RC_PRE(d1 != d3); - RC_ASSERT(can_torch_strictly_promote_datatype_from_to(d1, d3)); - }); + RC_SUBCASE("is transitive if end-points are not the same", + [](DataType d1, DataType d2, DataType d3) { + RC_PRE(can_torch_strictly_promote_datatype_from_to(d1, d2)); + RC_PRE(can_torch_strictly_promote_datatype_from_to(d2, d3)); + RC_PRE(d1 != d3); + RC_ASSERT(can_torch_strictly_promote_datatype_from_to(d1, d3)); + }); } }