From 350babf3584c3d99e76e4dc0f72a658aa0222afc Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Sun, 23 Feb 2025 19:11:42 -0800 Subject: [PATCH] Passing tests after merge issues --- lib/local-execution/src/allocated_tensors.cc | 5 +- .../src/local_cost_estimator.cc | 59 +++-- .../src/local_training_backing.cc | 2 +- lib/local-execution/src/optimizer.cc | 224 +++++++++++------- lib/local-execution/src/task_registry.cc | 2 +- .../src/unallocated_tensors.cc | 4 +- .../test/src/test_allocated_tensors.cc | 6 - .../test/src/test_local_cost_estimator.cc | 10 +- .../test/src/test_local_tensor_backing.cc | 4 - .../test/src/test_loss_functions.cc | 74 +++--- .../test/src/test_unallocated_tensors.cc | 6 - lib/local-execution/test/src/test_update.cc | 51 ++-- 12 files changed, 247 insertions(+), 200 deletions(-) diff --git a/lib/local-execution/src/allocated_tensors.cc b/lib/local-execution/src/allocated_tensors.cc index 2c40cc3b86..196da16ace 100644 --- a/lib/local-execution/src/allocated_tensors.cc +++ b/lib/local-execution/src/allocated_tensors.cc @@ -54,8 +54,7 @@ bool are_allocated_gradient_tensors_valid( for (std::pair const &tensor_to_grad : allocated_tensors.gradient_mapping) { if (tensor_attrs.count(tensor_to_grad.first)) { - if (tensor_attrs.at(tensor_to_grad.first).create_gradients == - CreateGrad::NO) { + if (tensor_attrs.at(tensor_to_grad.first).create_grad == CreateGrad::NO) { return false; } @@ -96,7 +95,7 @@ bool are_allocated_optimizer_tensors_valid( for (std::pair> const &tensor_to_optimizers : allocated_tensors.optimizer_mapping) { if (tensor_attrs.count(tensor_to_optimizers.first)) { - if (tensor_attrs.at(tensor_to_optimizers.first).create_gradients == + if (tensor_attrs.at(tensor_to_optimizers.first).create_grad == CreateGrad::NO) { return false; } diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index 5c17f011e4..9828a67293 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -4,13 +4,13 @@ #include "local-execution/tracked_allocator.h" #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" -#include "pcg/computation_graph/layer_added_result.dtg.h" #include "pcg/computation_graph.h" +#include "pcg/computation_graph/layer_added_result.dtg.h" #include "pcg/machine_view.dtg.h" #include "pcg/parallel_tensor_attrs.h" #include "utils/containers/concat_vectors.h" +#include "utils/containers/get_only.h" #include "utils/containers/sum.h" -#include "pcg/parallel_tensor_attrs.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" @@ -26,41 +26,36 @@ static ComputationGraph create_computation_graph_for_local_cost_estimation( std::vector const &outputs) { ComputationGraph computation_graph = make_empty_computation_graph(); - // create layer for inputs - auto get_vector_piece_attrs_from_parallel_tensor_shape = - [](std::vector const ¶llel_shapes) { - return transform(parallel_shapes, [](ParallelTensorShape const &p) { - return TensorAttrs{ - get_piece_shape(p), std::nullopt, std::nullopt, CreateGrad::YES}; - }); - }; - - LayerAddedResult inputs_layer = - add_layer(computation_graph, - LayerAttrs{ComputationGraphOpAttrs{InputAttrs{}}, "inputs"}, - {}, - get_vector_piece_attrs_from_parallel_tensor_shape(inputs)); - - // create layer for weights - auto get_vector_piece_attrs_from_parallel_tensor_attrs = - [](std::vector const ¶llel_attrs) { - return transform(parallel_attrs, [](ParallelTensorAttrs const &p) { - return get_piece_attrs(p); - }); - }; - - LayerAddedResult weights_layer = - add_layer(computation_graph, - LayerAttrs{ComputationGraphOpAttrs{InputAttrs{}}, "weights"}, - {}, - get_vector_piece_attrs_from_parallel_tensor_attrs(weights)); + std::vector input_tensors; + for (ParallelTensorShape const &input : inputs) { + LayerAddedResult inputs_layer = add_layer( + computation_graph, + LayerAttrs{ComputationGraphOpAttrs{InputAttrs{get_piece_shape(input)}}, + std::nullopt}, + {}, + {}); + input_tensors.push_back(get_only(inputs_layer.outputs)); + } + + std::vector weight_tensors; + for (ParallelTensorAttrs const &weight : weights) { + LayerAddedResult weights_layer = + add_layer(computation_graph, + LayerAttrs{ComputationGraphOpAttrs{WeightAttrs{ + get_piece_shape(weight.shape), + InitializerAttrs{ZeroInitializerAttrs{}}}}, + std::nullopt}, + {}, + {}); + weight_tensors.push_back(get_only(weights_layer.outputs)); + } // create operator layer LayerAddedResult operator_layer = add_layer( computation_graph, LayerAttrs{compgraph_op_attrs_from_pcg_op_attrs(op), "operator"}, - concat_vectors(inputs_layer.outputs, weights_layer.outputs), - get_vector_piece_attrs_from_parallel_tensor_attrs(outputs)); + input_tensors, + weight_tensors); return computation_graph; } diff --git a/lib/local-execution/src/local_training_backing.cc b/lib/local-execution/src/local_training_backing.cc index df15c707b2..77e62e52af 100644 --- a/lib/local-execution/src/local_training_backing.cc +++ b/lib/local-execution/src/local_training_backing.cc @@ -213,7 +213,7 @@ void execute_update(LocalTrainingBacking const &local_training_backing, Allocator &allocator) { LayerAttrs layer_attrs = get_layer_attrs(local_training_backing.computation_graph, node); - if (layer_attrs.attrs.has()) { + if (layer_attrs.op_attrs.has()) { // get tensors tensor_guid_t weight_tensor = get_only( get_outgoing_tensors(local_training_backing.computation_graph, node)); diff --git a/lib/local-execution/src/optimizer.cc b/lib/local-execution/src/optimizer.cc index a69ae9da61..1b9ce83d14 100644 --- a/lib/local-execution/src/optimizer.cc +++ b/lib/local-execution/src/optimizer.cc @@ -1,6 +1,7 @@ #include "local-execution/optimizer.h" #include "kernels/optimizer_kernels.h" #include "task-spec/profiling.h" +#include "utils/containers/get_only.h" #include "utils/overload.h" namespace FlexFlow { @@ -24,9 +25,12 @@ TaskSignature get_sgd_update_signature() { add_arg_slot(sig, ATTRS); add_arg_slot(sig, PROFILING); - if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - add_unchecked_arg_slot(sig, HANDLE); - } + add_unchecked_arg_slot( + sig, HANDLE); // how to deal with removal of ParamSync? + + // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + // add_unchecked_arg_slot(sig, HANDLE); + // } return sig; } @@ -44,12 +48,16 @@ TaskInvocation sgd_update(SGDOptimizerAttrs const &attrs, b.bind_arg(ATTRS, attrs); b.bind_arg(PROFILING, profiling_settings()); - if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - b.bind_arg(HANDLE, ff_handle()); - return TaskInvocation{task_id_t::SGD_UPD_NCCL_TASK_ID, b}; - } else { - return TaskInvocation{task_id_t::SGD_UPD_PS_TASK_ID, b}; - } + b.bind_arg(HANDLE, ff_handle()); + return TaskInvocation{task_id_t::SGD_UPD_NCCL_TASK_ID, + b}; // how to deal with removal of ParamSync? + + // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + // b.bind_arg(HANDLE, ff_handle()); + // return TaskInvocation{task_id_t::SGD_UPD_NCCL_TASK_ID, b}; + // } else { + // return TaskInvocation{task_id_t::SGD_UPD_PS_TASK_ID, b}; + // } } static void sgd_update_task_impl(TaskArgumentAccessor const &acc) { @@ -73,35 +81,49 @@ static void sgd_update_task_impl(TaskArgumentAccessor const &acc) { sgd_v_ptr = sgd_v.get_float_ptr(); } - if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - auto handle = acc.get_argument(HANDLE); - profile(sgd_nccl_update_task_gpu, - profiling, - "[SGD NCCL] update_time = %.2lfms\n", - attrs.lr, - attrs.momentum, - attrs.nesterov, - attrs.weight_decay, - handle, - weight_grad.get_float_ptr(), - size, - weight.get_float_ptr(), - sgd_v_ptr); - - } else { - profile(sgd_ps_update_task_gpu, - profiling, - "[SGD PS] update_time = %.2lfms\n", - attrs.lr, - attrs.momentum, - attrs.nesterov, - attrs.weight_decay, - weight_grad.get_float_ptr(), - size, - num_replicas, - weight.get_float_ptr(), - sgd_v_ptr); - } + auto handle = acc.get_argument(HANDLE); + profile(sgd_nccl_update_task_gpu, + profiling, + "[SGD NCCL] update_time = %.2lfms\n", + attrs.lr, + attrs.momentum, + attrs.nesterov, + attrs.weight_decay, + handle, + weight_grad.get_float_ptr(), + size, + weight.get_float_ptr(), + sgd_v_ptr); // how to deal with removal of ParamSync? + + // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + // auto handle = acc.get_argument(HANDLE); + // profile(sgd_nccl_update_task_gpu, + // profiling, + // "[SGD NCCL] update_time = %.2lfms\n", + // attrs.lr, + // attrs.momentum, + // attrs.nesterov, + // attrs.weight_decay, + // handle, + // weight_grad.get_float_ptr(), + // size, + // weight.get_float_ptr(), + // sgd_v_ptr); + + // } else { + // profile(sgd_ps_update_task_gpu, + // profiling, + // "[SGD PS] update_time = %.2lfms\n", + // attrs.lr, + // attrs.momentum, + // attrs.nesterov, + // attrs.weight_decay, + // weight_grad.get_float_ptr(), + // size, + // num_replicas, + // weight.get_float_ptr(), + // sgd_v_ptr); + // } } TaskImplFunction get_sgd_update_task_impl() { @@ -117,9 +139,11 @@ TaskSignature get_adam_update_signature() { add_arg_slot(sig, ATTRS); add_arg_slot(sig, PROFILING); - if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - add_unchecked_arg_slot(sig, HANDLE); - } + add_unchecked_arg_slot( + sig, HANDLE); // how to deal with removal of ParamSync? + // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + // add_unchecked_arg_slot(sig, HANDLE); + // } return sig; } @@ -135,13 +159,16 @@ TaskInvocation adam_update(AdamOptimizerAttrs const &attrs, b.bind_optimizer(ADAM_V, adam_v); b.bind_arg(ATTRS, attrs); b.bind_arg(PROFILING, profiling_settings()); + b.bind_arg(HANDLE, ff_handle()); + return TaskInvocation{task_id_t::ADAM_UPD_NCCL_TASK_ID, + b}; // how to deal with removal of ParamSync? - if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - b.bind_arg(HANDLE, ff_handle()); - return TaskInvocation{task_id_t::ADAM_UPD_NCCL_TASK_ID, b}; - } else { - return TaskInvocation{task_id_t::ADAM_UPD_PS_TASK_ID, b}; - } + // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + // b.bind_arg(HANDLE, ff_handle()); + // return TaskInvocation{task_id_t::ADAM_UPD_NCCL_TASK_ID, b}; + // } else { + // return TaskInvocation{task_id_t::ADAM_UPD_PS_TASK_ID, b}; + // } } static void adam_update_task_impl(TaskArgumentAccessor const &acc) { @@ -162,38 +189,54 @@ static void adam_update_task_impl(TaskArgumentAccessor const &acc) { int num_replicas = weight_grad.shape.get_volume().unwrap_nonnegative() / weight.shape.get_volume().unwrap_nonnegative(); - if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { - auto handle = acc.get_argument(HANDLE); - profile(adam_nccl_update_task_gpu, - profiling, - "[Adam NCCL] update_time = %.2lfms\n", - attrs.alpha_t, - attrs.beta1, - attrs.beta2, - attrs.weight_decay, - attrs.epsilon, - size, - handle, - weight_grad.get_float_ptr(), - m_tensor.get_float_ptr(), - v_tensor.get_float_ptr(), - weight.get_float_ptr()); - } else { - profile(adam_ps_update_task_gpu, - profiling, - "[Adam NCCL] update_time = %.2lfms\n", - attrs.alpha_t, - attrs.beta1, - attrs.beta2, - attrs.weight_decay, - attrs.epsilon, - size, - num_replicas, - weight_grad.get_float_ptr(), - m_tensor.get_float_ptr(), - v_tensor.get_float_ptr(), - weight.get_float_ptr()); - } + auto handle = acc.get_argument(HANDLE); + profile(adam_nccl_update_task_gpu, + profiling, + "[Adam NCCL] update_time = %.2lfms\n", + attrs.alpha_t, + attrs.beta1, + attrs.beta2, + attrs.weight_decay, + attrs.epsilon, + size, + handle, + weight_grad.get_float_ptr(), + m_tensor.get_float_ptr(), + v_tensor.get_float_ptr(), + weight.get_float_ptr()); // how to deal with removal of ParamSync? + + // if (CHOSEN_SYNC_TYPE == ParamSync::NCCL) { + // auto handle = acc.get_argument(HANDLE); + // profile(adam_nccl_update_task_gpu, + // profiling, + // "[Adam NCCL] update_time = %.2lfms\n", + // attrs.alpha_t, + // attrs.beta1, + // attrs.beta2, + // attrs.weight_decay, + // attrs.epsilon, + // size, + // handle, + // weight_grad.get_float_ptr(), + // m_tensor.get_float_ptr(), + // v_tensor.get_float_ptr(), + // weight.get_float_ptr()); + // } else { + // profile(adam_ps_update_task_gpu, + // profiling, + // "[Adam NCCL] update_time = %.2lfms\n", + // attrs.alpha_t, + // attrs.beta1, + // attrs.beta2, + // attrs.weight_decay, + // attrs.epsilon, + // size, + // num_replicas, + // weight_grad.get_float_ptr(), + // m_tensor.get_float_ptr(), + // v_tensor.get_float_ptr(), + // weight.get_float_ptr()); + // } } TaskImplFunction get_adam_update_task_impl() { @@ -211,17 +254,18 @@ TaskInvocation get_update_invocation( tensor_guid_t const &weight, gradient_tensor_t const &weight_grad, std::vector const &grad_buffer_tensors) { - return attrs.visit(overload{ - [&](SGDOptimizerAttrs const &s) { - return sgd_update(s, weight, weight_grad, grad_buffer_tensors.at(0)); - }, - [&](AdamOptimizerAttrs const &s) { - return adam_update(s, - weight, - weight_grad, - grad_buffer_tensors.at(0), - grad_buffer_tensors.at(1)); - }}); + return attrs.visit( + overload{[&](SGDOptimizerAttrs const &s) { + return sgd_update( + s, weight, weight_grad, get_only(grad_buffer_tensors)); + }, + [&](AdamOptimizerAttrs const &s) { + return adam_update(s, + weight, + weight_grad, + grad_buffer_tensors.at(0), + grad_buffer_tensors.at(1)); + }}); } TaskImplFunction get_update_task_impl(OptimizerAttrs const &attrs) { diff --git a/lib/local-execution/src/task_registry.cc b/lib/local-execution/src/task_registry.cc index 487bd4420e..3d9dec1e26 100644 --- a/lib/local-execution/src/task_registry.cc +++ b/lib/local-execution/src/task_registry.cc @@ -19,7 +19,7 @@ TaskRegistry construct_task_registry( fwd_task_ids.insert({node, std::nullopt}); bwd_task_ids.insert({node, std::nullopt}); - ComputationGraphOpAttrs attrs = layer_attrs.second.attrs; + ComputationGraphOpAttrs attrs = layer_attrs.second.op_attrs; std::vector task_ids = get_task_ids(attrs); for (task_id_t const &task_id : task_ids) { diff --git a/lib/local-execution/src/unallocated_tensors.cc b/lib/local-execution/src/unallocated_tensors.cc index ea64a46051..363d1eedef 100644 --- a/lib/local-execution/src/unallocated_tensors.cc +++ b/lib/local-execution/src/unallocated_tensors.cc @@ -23,7 +23,7 @@ UnallocatedTensors generate_unallocated_tensors( tensor_type_shapes.insert({tensor_guid_type, tensor_attrs.shape}); } - if (tensor_attrs.create_gradients == CreateGrad::YES && + if (tensor_attrs.create_grad == CreateGrad::YES && !allocated_tensors.gradient_mapping.count(tensor_guid)) { gradient_tensor_t gradient_tensor = gradient_tensor_source.new_gradient_tensor(); @@ -61,7 +61,7 @@ UnallocatedTensors generate_unallocated_tensors_with_optimizer( tensor_attrs_mapping) { tensor_guid_t tensor_guid = tensor_guid_attrs.first; TensorAttrs tensor_attrs = tensor_guid_attrs.second; - if (tensor_attrs.create_gradients == CreateGrad::YES) { + if (tensor_attrs.create_grad == CreateGrad::YES) { std::vector optimizer_tensors; int num_optimizer_tensors_to_allocate = diff --git a/lib/local-execution/test/src/test_allocated_tensors.cc b/lib/local-execution/test/src/test_allocated_tensors.cc index 99abd538d5..45fc8e0a1c 100644 --- a/lib/local-execution/test/src/test_allocated_tensors.cc +++ b/lib/local-execution/test/src/test_allocated_tensors.cc @@ -31,20 +31,14 @@ TEST_SUITE(FF_TEST_SUITE) { TensorAttrs tensor_attrs_1_no_grad = TensorAttrs{ TensorShape{TensorDims{FFOrdered{16_n, 10_n}}, DataType::FLOAT}, - std::nullopt, - std::nullopt, CreateGrad::NO}; TensorAttrs tensor_attrs_2_no_grad = TensorAttrs{ TensorShape{TensorDims{FFOrdered{16_n, 20_n}}, DataType::FLOAT}, - std::nullopt, - std::nullopt, CreateGrad::NO}; TensorAttrs tensor_attrs_3_with_grad = TensorAttrs{ TensorShape{TensorDims{FFOrdered{16_n, 30_n}}, DataType::FLOAT}, - std::nullopt, - std::nullopt, CreateGrad::YES}; GenericTensorAccessorW tensor_backing_1 = diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc index 7220d2a367..30682c9a48 100644 --- a/lib/local-execution/test/src/test_local_cost_estimator.cc +++ b/lib/local-execution/test/src/test_local_cost_estimator.cc @@ -50,18 +50,12 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { ParallelTensorShape weights_shape = throw_if_unexpected( get_weights_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); ParallelTensorAttrs weight_attrs = - ParallelTensorAttrs{weights_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - CreateGrad::YES}; + ParallelTensorAttrs{weights_shape, CreateGrad::YES}; ParallelTensorShape output_shape = throw_if_unexpected( get_output_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); ParallelTensorAttrs output_attrs = - ParallelTensorAttrs{output_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - CreateGrad::YES}; + ParallelTensorAttrs{output_shape, CreateGrad::YES}; CostDetails result = cost_estimator.estimate_cost( PCGOperatorAttrs{attrs}, diff --git a/lib/local-execution/test/src/test_local_tensor_backing.cc b/lib/local-execution/test/src/test_local_tensor_backing.cc index 083b677e18..594051c2f1 100644 --- a/lib/local-execution/test/src/test_local_tensor_backing.cc +++ b/lib/local-execution/test/src/test_local_tensor_backing.cc @@ -96,14 +96,10 @@ TEST_SUITE(FF_TEST_SUITE) { TensorAttrs allocated_tensor_attrs = TensorAttrs{ TensorShape{TensorDims{FFOrdered{16_n, 10_n}}, DataType::FLOAT}, - std::nullopt, - std::nullopt, CreateGrad::NO}; TensorAttrs unallocated_tensor_attrs = TensorAttrs{ TensorShape{TensorDims{FFOrdered{16_n, 20_n}}, DataType::FLOAT}, - std::nullopt, - std::nullopt, CreateGrad::YES}; GenericTensorAccessorW allocated_tensor_backing = diff --git a/lib/local-execution/test/src/test_loss_functions.cc b/lib/local-execution/test/src/test_loss_functions.cc index c0386a4171..bb3e83cc4d 100644 --- a/lib/local-execution/test/src/test_loss_functions.cc +++ b/lib/local-execution/test/src/test_loss_functions.cc @@ -9,6 +9,7 @@ #include "pcg/computation_graph_builder.h" #include "pcg/optimizer_attrs.dtg.h" #include "test_utils.h" +#include "utils/containers/get_only.h" namespace FlexFlow { @@ -24,19 +25,20 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { loss_tensor_source.new_loss_tensor(); nonnegative_int batch_size = 10_n; - nonnegative_int data_dim = 100_n; + nonnegative_int data_dim = 16_n; + nonnegative_int output_dim = 32_n; - TensorShape input_tensor_shape = TensorShape{ - TensorDims{FFOrdered{batch_size, data_dim}}, + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; - TensorShape reduced_input_tensor_shape = + TensorShape reduced_tensor_shape = TensorShape{TensorDims{FFOrdered{batch_size, 1_n}}, DataType::FLOAT}; GenericTensorAccessorW label_for_nonconfigurable_loss_attrs_backing = - allocator.allocate_tensor(reduced_input_tensor_shape); + allocator.allocate_tensor(output_tensor_shape); GenericTensorAccessorW label_for_sparse_cce_loss_attrs_backing = - allocator.allocate_tensor(reduced_input_tensor_shape); + allocator.allocate_tensor(reduced_tensor_shape); AllocatedTensors allocated_tensors = AllocatedTensors{ {{TensorTypeVariant{label_for_nonconfigurable_loss_attrs}, label_for_nonconfigurable_loss_attrs_backing}, @@ -48,24 +50,40 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { // construct computation graph ComputationGraph computation_graph = make_empty_computation_graph(); - TensorAttrs input_tensor_attrs = TensorAttrs{ - input_tensor_shape, std::nullopt, std::nullopt, CreateGrad::YES}; - - LayerAddedResult inputs_layer = - add_layer(computation_graph, - LayerAttrs{ComputationGraphOpAttrs{InputAttrs{}}, "inputs"}, - {}, - {input_tensor_attrs}); - - float scalar = 4.0; - LayerAddedResult scalar_multiply_operator = - add_layer(computation_graph, - LayerAttrs{ComputationGraphOpAttrs{ElementUnaryAttrs{ - OperatorType::SCALAR_MULTIPLY, scalar}}, - "scalar_mult"}, - inputs_layer.outputs, - {input_tensor_attrs}); - tensor_guid_t label_tensor = scalar_multiply_operator.outputs.at(0); + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, + DataType::FLOAT}; + + TensorShape weight_shape = TensorShape{ + TensorDims{FFOrdered{data_dim, output_dim}}, + DataType::FLOAT}; + + LayerAddedResult inputs_layer = add_layer( + computation_graph, + LayerAttrs{ComputationGraphOpAttrs{InputAttrs{input_tensor_shape}}, + "inputs"}, + {}, + {}); + + LayerAddedResult weights_layer = add_layer( + computation_graph, + LayerAttrs{ComputationGraphOpAttrs{WeightAttrs{ + weight_shape, InitializerAttrs{ZeroInitializerAttrs{}}}}, + "weights"}, + {}, + {}); + + LayerAddedResult linear_operator = add_layer( + computation_graph, + LayerAttrs{ComputationGraphOpAttrs{LinearAttrs{output_dim, + /*use_bias=*/true, + DataType::FLOAT, + std::nullopt, + std::nullopt}}, + "linear"}, + inputs_layer.outputs, + {}); + tensor_guid_t logit_tensor = get_only(linear_operator.outputs); // initialize runtime configs ManagedPerDeviceFFHandle managed_handle{}; @@ -85,7 +103,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { compute_loss(local_training_backing, loss_attrs, - label_tensor, + logit_tensor, label_for_sparse_cce_loss_attrs, allocator); } @@ -96,7 +114,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; compute_loss(local_training_backing, loss_attrs, - label_tensor, + logit_tensor, label_for_nonconfigurable_loss_attrs, allocator); } @@ -106,7 +124,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE}}; compute_loss(local_training_backing, loss_attrs, - label_tensor, + logit_tensor, label_for_nonconfigurable_loss_attrs, allocator); } @@ -116,7 +134,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { LossAttrs{NonconfigurableLossAttrs{LossFunction::IDENTITY}}; compute_loss(local_training_backing, loss_attrs, - label_tensor, + logit_tensor, label_for_nonconfigurable_loss_attrs, allocator); } diff --git a/lib/local-execution/test/src/test_unallocated_tensors.cc b/lib/local-execution/test/src/test_unallocated_tensors.cc index 662e7b1878..82f5a132fe 100644 --- a/lib/local-execution/test/src/test_unallocated_tensors.cc +++ b/lib/local-execution/test/src/test_unallocated_tensors.cc @@ -40,20 +40,14 @@ TEST_SUITE(FF_TEST_SUITE) { TensorAttrs tensor_attrs_1_no_grad = TensorAttrs{ TensorShape{TensorDims{FFOrdered{16_n, 10_n}}, DataType::FLOAT}, - std::nullopt, - std::nullopt, CreateGrad::NO}; TensorAttrs tensor_attrs_2_no_grad = TensorAttrs{ TensorShape{TensorDims{FFOrdered{16_n, 20_n}}, DataType::FLOAT}, - std::nullopt, - std::nullopt, CreateGrad::NO}; TensorAttrs tensor_attrs_3_with_grad = TensorAttrs{ TensorShape{TensorDims{FFOrdered{16_n, 30_n}}, DataType::FLOAT}, - std::nullopt, - std::nullopt, CreateGrad::YES}; GenericTensorAccessorW tensor_backing_1 = diff --git a/lib/local-execution/test/src/test_update.cc b/lib/local-execution/test/src/test_update.cc index 3121d8e02b..d6108635af 100644 --- a/lib/local-execution/test/src/test_update.cc +++ b/lib/local-execution/test/src/test_update.cc @@ -20,29 +20,42 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { ComputationGraph computation_graph = make_empty_computation_graph(); nonnegative_int batch_size = 10_n; - nonnegative_int data_dim = 100_n; + nonnegative_int data_dim = 16_n; + nonnegative_int output_dim = 32_n; TensorShape input_tensor_shape = TensorShape{ TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; - TensorAttrs input_tensor_attrs = TensorAttrs{ - input_tensor_shape, std::nullopt, std::nullopt, CreateGrad::YES}; + TensorShape weight_shape = TensorShape{ + TensorDims{FFOrdered{data_dim, output_dim}}, + DataType::FLOAT}; + + LayerAddedResult inputs_layer = add_layer( + computation_graph, + LayerAttrs{ComputationGraphOpAttrs{InputAttrs{input_tensor_shape}}, + "inputs"}, + {}, + {}); - LayerAddedResult inputs_layer = - add_layer(computation_graph, - LayerAttrs{ComputationGraphOpAttrs{InputAttrs{}}, "inputs"}, - {}, - {input_tensor_attrs}); + LayerAddedResult weights_layer = add_layer( + computation_graph, + LayerAttrs{ComputationGraphOpAttrs{WeightAttrs{ + weight_shape, InitializerAttrs{ZeroInitializerAttrs{}}}}, + "weights"}, + {}, + {}); - float scalar = 4.0; - LayerAddedResult scalar_multiply_operator = - add_layer(computation_graph, - LayerAttrs{ComputationGraphOpAttrs{ElementUnaryAttrs{ - OperatorType::SCALAR_MULTIPLY, scalar}}, - "scalar_mult"}, - inputs_layer.outputs, - {input_tensor_attrs}); + LayerAddedResult linear_operator = add_layer( + computation_graph, + LayerAttrs{ComputationGraphOpAttrs{LinearAttrs{output_dim, + /*use_bias=*/true, + DataType::FLOAT, + std::nullopt, + std::nullopt}}, + "linear"}, + inputs_layer.outputs, + {}); // initialize runtime configs ManagedPerDeviceFFHandle managed_handle{}; @@ -66,7 +79,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { runtime_arg_config, optimizer_attrs}; execute_update(local_training_backing, - scalar_multiply_operator.layer, + linear_operator.layer, optimizer_attrs, allocator); } @@ -83,7 +96,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { runtime_arg_config, optimizer_attrs}; execute_update(local_training_backing, - scalar_multiply_operator.layer, + linear_operator.layer, optimizer_attrs, allocator); } @@ -105,7 +118,7 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { runtime_arg_config, optimizer_attrs}; execute_update(local_training_backing, - scalar_multiply_operator.layer, + linear_operator.layer, optimizer_attrs, allocator); }