diff --git a/onert-micro/onert-micro/include/core/train/OMTrainingHandler.h b/onert-micro/onert-micro/include/core/train/OMTrainingHandler.h
new file mode 100644
index 00000000000..08efc77e439
--- /dev/null
+++ b/onert-micro/onert-micro/include/core/train/OMTrainingHandler.h
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_CORE_TRAIN_TRAINING_HANDLER_H
+#define ONERT_MICRO_CORE_TRAIN_TRAINING_HANDLER_H
+
+#include "OMStatus.h"
+#include "OMConfig.h"
+#include "core/OMRuntimeContext.h"
+#include "core/OMRuntimeStorage.h"
+#include "core/train/OMTrainingStorage.h"
+
+namespace onert_micro
+{
+namespace core
+{
+namespace train
+{
+
+/*
+ * Class to handle the training process
+ */
+class OMTrainingHandler
+{
+  OMTrainingStorage _training_storage;
+
+public:
+  OMTrainingHandler() = default;
+  OMTrainingHandler(const OMTrainingHandler &) = delete;
+  OMTrainingHandler(OMTrainingHandler &&) = delete;
+  OMTrainingHandler &operator=(const OMTrainingHandler &) = delete;
+  OMTrainingHandler &operator=(OMTrainingHandler &&) = delete;
+  ~OMTrainingHandler() = default;
+
+  // Save input and target data in OMTrainingStorage
+  void setInputData(uint8_t *data, uint32_t input_index)
+  {
+    _training_storage.setInputData(data, input_index);
+  }
+  void setTargetData(uint8_t *data, uint32_t target_index)
+  {
+    _training_storage.setTargetData(data, target_index);
+  }
+
+  // Get input and target data from OMTrainingStorage
+  uint8_t *getInputData(uint32_t input_index)
+  {
+    return _training_storage.getInputData(input_index);
+  }
+  uint8_t *getTargetData(uint32_t target_index)
+  {
+    return _training_storage.getTargetData(target_index);
+  }
+
+  // Apply the error (loss) function defined in config: calculate the backpropagation
+  // error between the target data and the data calculated by the forward graph.
+  // batch_num is the sequence number of the current sample within the current batch
+  // (needed to compute the offset of the current target sample).
+  OMStatus handleError(const OMConfig &config, OMRuntimeStorage &forward_storage,
+                       OMRuntimeStorage &backward_storage, OMRuntimeContext &context,
+                       uint32_t batch_num);
+
+  // Update the optimizer state with the current gradients
+  OMStatus updateOptimizerState(const OMConfig &config, OMRuntimeStorage &backward_storage,
+                                OMRuntimeContext &context);
+
+  // Update the weights using the current optimizer
+  OMStatus updateWeights(const OMConfig &config, OMRuntimeContext &context);
+
+  // Evaluate the metric and accumulate the result in metric_val
+  // Warning: 1) it is assumed that metric_val is a float value for all OMMetrics types;
+  //          2) metric_val should be initialized before calling this method, because the
+  //             value calculated for the current batch_num (the sequence number of the
+  //             current sample) is added to metric_val.
+  OMStatus evaluateMetric(OMMetrics metric, void *metric_val, OMRuntimeStorage &storage,
+                          OMRuntimeContext &context, uint32_t batch_num);
+
+  // Save the optimizer in OMTrainingStorage
+  OMStatus setOptimizer(const OMConfig &config) { return _training_storage.setOptimizer(config); }
+
+  // Get the training storage
+  OMTrainingStorage &getTrainingStorage() { return _training_storage; }
+
+  // Reset and deallocate all internal states
+  void reset();
+};
+
+} // namespace train
+} // namespace core
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_CORE_TRAIN_TRAINING_HANDLER_H
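To show how these pieces are meant to compose, here is a rough sketch of one training step driven through this interface. It is illustrative only: the trainSingleBatch name, the input/target index 0, and the batch_size field on training_context are assumptions, and running the forward/backward graphs (elided in the comments) is handled elsewhere in the runtime.

#include "core/train/OMTrainingHandler.h"

using namespace onert_micro;

// Hypothetical driver for one training batch; not part of this patch.
OMStatus trainSingleBatch(core::train::OMTrainingHandler &handler, const OMConfig &config,
                          core::OMRuntimeStorage &forward_storage,
                          core::OMRuntimeStorage &backward_storage,
                          core::OMRuntimeContext &context, uint8_t *input_batch,
                          uint8_t *target_batch)
{
  // Hand the batch buffers to the training storage (no copy is made)
  handler.setInputData(input_batch, /*input_index=*/0);
  handler.setTargetData(target_batch, /*target_index=*/0);

  for (uint32_t batch_num = 0; batch_num < config.training_context.batch_size; ++batch_num)
  {
    // ... run the forward graph for this sample (not shown) ...

    // Compute dLoss/dOutput for the current sample into backward_storage
    OMStatus status =
      handler.handleError(config, forward_storage, backward_storage, context, batch_num);
    if (status != Ok)
      return status;

    // ... run the backward graph to produce weight gradients (not shown) ...

    // Accumulate the gradients into the optimizer state
    status = handler.updateOptimizerState(config, backward_storage, context);
    if (status != Ok)
      return status;
  }

  // Apply the accumulated update once per batch
  return handler.updateWeights(config, context);
}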
diff --git a/onert-micro/onert-micro/include/core/train/OMTrainingStorage.h b/onert-micro/onert-micro/include/core/train/OMTrainingStorage.h
new file mode 100644
index 00000000000..55f7f114224
--- /dev/null
+++ b/onert-micro/onert-micro/include/core/train/OMTrainingStorage.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ONERT_MICRO_CORE_TRAIN_TRAINING_STORAGE_H
+#define ONERT_MICRO_CORE_TRAIN_TRAINING_STORAGE_H
+
+#include "OMStatus.h"
+#include "OMConfig.h"
+#include "train/train_optimizers/SGD.h"
+#include "train/train_optimizers/Adam.h"
+
+#include <memory>
+#include <unordered_map>
+
+namespace onert_micro
+{
+namespace core
+{
+namespace train
+{
+
+/*
+ * Class to store training-specific information
+ */
+class OMTrainingStorage
+{
+  // Mapping between an input tensor (its input_index) and its current input data.
+  // The input data must contain a number of samples equal to batch_size.
+  std::unordered_map<uint32_t, uint8_t *> _input_index_to_input_data;
+  // Mapping between an output tensor (its output_index) and its current target data.
+  // The target data must contain a number of samples equal to batch_size.
+  std::unordered_map<uint32_t, uint8_t *> _target_index_to_target_data;
+
+  // SGD optimizer with its own internal state
+  // Note: initially null
+  std::unique_ptr<onert_micro::train::optimizers::SGD> _sgd_optimizer = nullptr;
+  // Adam optimizer with its own internal state
+  // Note: initially null
+  std::unique_ptr<onert_micro::train::optimizers::Adam> _adam_optimizer = nullptr;
+
+public:
+  OMTrainingStorage() = default;
+  OMTrainingStorage(const OMTrainingStorage &) = delete;
+  OMTrainingStorage(OMTrainingStorage &&) = delete;
+  OMTrainingStorage &operator=(const OMTrainingStorage &) = delete;
+  OMTrainingStorage &operator=(OMTrainingStorage &&) = delete;
+  ~OMTrainingStorage() = default;
+
+  // Set input data for the given input tensor
+  void setInputData(uint8_t *data, uint32_t input_index)
+  {
+    _input_index_to_input_data[input_index] = data;
+  }
+  // Set target data for the given output tensor
+  void setTargetData(uint8_t *data, uint32_t target_index)
+  {
+    _target_index_to_target_data[target_index] = data;
+  }
+
+  // Choose and set the optimizer defined in config
+  OMStatus setOptimizer(const OMConfig &config);
+
+  // Get pointer to the SGD optimizer
+  // Note: can return nullptr
+  onert_micro::train::optimizers::SGD *getSGD() { return _sgd_optimizer.get(); }
+  // Get pointer to the Adam optimizer
+  // Note: can return nullptr
+  onert_micro::train::optimizers::Adam *getAdam() { return _adam_optimizer.get(); }
+
+  // Get pointer to the saved input data for the given input tensor
+  uint8_t *getInputData(uint32_t input_index) { return _input_index_to_input_data[input_index]; }
+  // Get pointer to the saved target data for the given output tensor
+  uint8_t *getTargetData(uint32_t target_index)
+  {
+    return _target_index_to_target_data[target_index];
+  }
+
+  // Reset and deallocate all states
+  void reset();
+};
+
+} // namespace train
+} // namespace core
+} // namespace onert_micro
+
+#endif // ONERT_MICRO_CORE_TRAIN_TRAINING_STORAGE_H
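One thing worth making explicit: the storage keeps raw pointers only and never copies or frees the buffers, so the caller has to keep them alive for the whole batch. Combined with the per-sample offset arithmetic in OMTrainingHandler (batch_num * type_size * flat_size), the expected layout is sample-major. A minimal sketch, with kBatchSize and kFlatSize as made-up illustrative values:

#include <cstdint>
#include <vector>

// Illustrative sizes only, not part of this patch
constexpr uint32_t kBatchSize = 4; // samples per batch
constexpr uint32_t kFlatSize = 10; // elements per output tensor

// Sample s occupies elements [s * kFlatSize, (s + 1) * kFlatSize)
std::vector<float> target_buffer(kBatchSize * kFlatSize);

// The buffer is passed as raw bytes; the storage will not free it:
// training_storage.setTargetData(reinterpret_cast<uint8_t *>(target_buffer.data()),
//                                /*target_index=*/0);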
diff --git a/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp b/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp
new file mode 100644
index 00000000000..d636f16b7b6
--- /dev/null
+++ b/onert-micro/onert-micro/src/core/train/OMTrainingHandler.cpp
@@ -0,0 +1,277 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/OMDataType.h"
+#include "core/memory/OMMemoryManager.h"
+#include "core/train/OMTrainingHandler.h"
+#include "train/losses_functions/MSE.h"
+#include "train/losses_functions/CrossEntropy.h"
+#include "train/metrics/MSE.h"
+#include "train/metrics/CrossEntropy.h"
+#include "train/metrics/Accuracy.h"
+#include "train/metrics/MAE.h"
+
+using namespace onert_micro::core::train;
+using namespace onert_micro::core;
+using namespace onert_micro::train;
+using namespace onert_micro;
+
+/*
+ * Calculate the backpropagation error between the data calculated by the forward graph
+ * and the target data
+ */
+OMStatus OMTrainingHandler::handleError(const OMConfig &config, OMRuntimeStorage &forward_storage,
+                                        OMRuntimeStorage &backward_storage,
+                                        OMRuntimeContext &context, uint32_t batch_num)
+{
+  auto forward_outputs = context.getCircleOutputs();
+  // Go over all outputs
+  for (uint32_t i = 0; i < forward_outputs->size(); ++i)
+  {
+    const auto forward_output_index = forward_outputs->operator[](i);
+    const auto forward_output_tensor = context.getTensorByIndex(forward_output_index);
+
+    OMRuntimeShape shape(forward_output_tensor);
+    const auto flat_size = shape.flatSize();
+
+    // Check the type
+    assert(forward_output_tensor->type() == circle::TensorType_FLOAT32 && "Unsupported type");
+    if (forward_output_tensor->type() != circle::TensorType_FLOAT32)
+      return UnsupportedType;
+
+    // Get the calculated data
+    uint8_t *calculated_data = nullptr;
+    OMStatus status = forward_storage.getDataByTensorIndex(&calculated_data, forward_output_index);
+    if (status != Ok)
+      return status;
+    assert(calculated_data != nullptr);
+
+    // Get the target data
+    auto data_type_size = sizeof(core::OMDataType(forward_output_tensor->type()));
+    size_t offset = batch_num * data_type_size * flat_size;
+    uint8_t *target_data = _training_storage.getTargetData(i) + offset;
+    OMLoss loss_type = config.training_context.loss;
+
+    // Allocate data for the error gradient between the calculated and the target data
+    uint8_t *output_grad_data = nullptr;
+    status =
+      core::memory::OMMemoryManager::allocateMemory(flat_size * data_type_size, &output_grad_data);
+    if (status != Ok)
+      return status;
+    // Save the error gradient into the backward storage under the current output tensor index
+    backward_storage.saveDataToTensorIndex(output_grad_data, forward_output_index);
+
+    // Handle different loss types
+    switch (loss_type)
+    {
+      case MSE:
+      {
+        losses_functions::MSE::calculateErrorBackpropagation(
+          flat_size, reinterpret_cast<float *>(calculated_data),
+          reinterpret_cast<float *>(target_data), reinterpret_cast<float *>(output_grad_data));
+        break;
+      }
+      case CROSS_ENTROPY:
+      {
+        losses_functions::CrossEntropy::calculateErrorBackpropagation(
+          flat_size, reinterpret_cast<float *>(calculated_data),
+          reinterpret_cast<float *>(target_data), reinterpret_cast<float *>(output_grad_data));
+        break;
+      }
+      default:
+      {
+        assert(false && "Unsupported loss type");
+        return UnsupportedType;
+      }
+    }
+  }
+
+  return Ok;
+}
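The gradient formulas themselves live in losses_functions::MSE and losses_functions::CrossEntropy, which are outside this patch. For orientation, the textbook MSE case — L = (1/n) * sum_i (y_i - t_i)^2, so dL/dy_i = 2 * (y_i - t_i) / n — would look roughly like the sketch below; the real implementation may use a different scaling convention.

#include <cstdint>

// Sketch of MSE error backpropagation matching the call signature above;
// not the actual losses_functions::MSE implementation.
void mseErrorBackpropagationSketch(uint32_t flat_size, const float *calculated,
                                   const float *target, float *output_grad)
{
  for (uint32_t i = 0; i < flat_size; ++i)
    output_grad[i] = 2.0f * (calculated[i] - target[i]) / static_cast<float>(flat_size);
}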
+
+/*
+ * Update the weights using the current optimizer logic
+ */
+OMStatus OMTrainingHandler::updateWeights(const OMConfig &config, OMRuntimeContext &context)
+{
+  OMStatus status = Ok;
+
+  // Choose the optimizer type
+  switch (config.training_context.optimizer)
+  {
+    case SGD:
+    {
+      auto *sgd_optimizer = _training_storage.getSGD();
+      assert(sgd_optimizer != nullptr);
+      if (sgd_optimizer == nullptr)
+        return UnknownError;
+
+      status = sgd_optimizer->updateWeights(config.training_context, context);
+      assert(status == Ok);
+      // Reset accumulated values
+      sgd_optimizer->reset();
+      break;
+    }
+    case ADAM:
+    {
+      auto *adam_optimizer = _training_storage.getAdam();
+      assert(adam_optimizer != nullptr);
+      if (adam_optimizer == nullptr)
+        return UnknownError;
+
+      status = adam_optimizer->updateWeights(config.training_context, context);
+      assert(status == Ok);
+      // Reset accumulated values
+      adam_optimizer->reset();
+      break;
+    }
+    default:
+    {
+      assert(false && "Unsupported optimizer type");
+      return UnsupportedType;
+    }
+  }
+
+  return status;
+}
+
+/*
+ * Update the optimizer state
+ *
+ * WARNING: it is assumed that backward_storage contains only the calculated gradients
+ * (ensuring this is the job of the backward execution plan creator).
+ */
+OMStatus OMTrainingHandler::updateOptimizerState(const OMConfig &config,
+                                                 OMRuntimeStorage &backward_storage,
+                                                 OMRuntimeContext &context)
+{
+  OMStatus status = Ok;
+
+  switch (config.training_context.optimizer)
+  {
+    case SGD:
+    {
+      auto *sgd_optimizer = _training_storage.getSGD();
+      assert(sgd_optimizer != nullptr);
+      if (sgd_optimizer == nullptr)
+        return UnknownError;
+
+      sgd_optimizer->handle(backward_storage, context);
+      break;
+    }
+    case ADAM:
+    {
+      auto *adam_optimizer = _training_storage.getAdam();
+      assert(adam_optimizer != nullptr);
+      if (adam_optimizer == nullptr)
+        return UnknownError;
+
+      adam_optimizer->handle(backward_storage, context);
+      break;
+    }
+    default:
+    {
+      assert(false && "Unsupported optimizer type");
+      return UnsupportedType;
+    }
+  }
+
+  return status;
+}
+
+void OMTrainingHandler::reset() { _training_storage.reset(); }
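The optimizer internals sit behind SGD::handle/updateWeights (and the Adam equivalents) and are not part of this diff. For plain SGD, handle() accumulating gradients and updateWeights() applying w -= lr * g is the conventional split; below is a minimal sketch of the update half, where the learning_rate parameter and the averaging over accumulated samples are assumptions.

#include <cstdint>

// Sketch of a plain SGD update over one flattened weight tensor;
// not the optimizers::SGD implementation.
void sgdUpdateSketch(uint32_t flat_size, float *weights, const float *accumulated_grad,
                     float learning_rate, uint32_t num_samples)
{
  for (uint32_t i = 0; i < flat_size; ++i)
    weights[i] -= learning_rate * accumulated_grad[i] / static_cast<float>(num_samples);
}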
+
+/*
+ * Evaluate the metric defined by OMMetrics and accumulate the result into metric_val
+ *
+ * Warning: 1) it is assumed that metric_val is a float value for all OMMetrics types;
+ *          2) metric_val should be initialized before calling this method, because the
+ *             value calculated for the current batch_num (the sequence number of the
+ *             current sample) is added to metric_val.
+ */
+OMStatus OMTrainingHandler::evaluateMetric(OMMetrics metric, void *metric_val,
+                                           OMRuntimeStorage &storage, OMRuntimeContext &context,
+                                           uint32_t batch_num)
+{
+  // Go over all outputs and calculate the metric
+  auto forward_outputs = context.getCircleOutputs();
+  for (uint32_t i = 0; i < forward_outputs->size(); ++i)
+  {
+    // Get the output tensor
+    const auto forward_output_index = forward_outputs->operator[](i);
+    const auto forward_output_tensor = context.getTensorByIndex(forward_output_index);
+
+    OMRuntimeShape shape(forward_output_tensor);
+    const auto flat_size = shape.flatSize();
+
+    // Check the type
+    assert(forward_output_tensor->type() == circle::TensorType_FLOAT32 && "Unsupported type");
+    if (forward_output_tensor->type() != circle::TensorType_FLOAT32)
+      return UnsupportedType;
+
+    // Get the calculated data
+    uint8_t *calculated_data = nullptr;
+    OMStatus status = storage.getDataByTensorIndex(&calculated_data, forward_output_index);
+    if (status != Ok)
+      return status;
+    assert(calculated_data != nullptr);
+
+    // Get the target data
+    size_t offset = batch_num * sizeof(core::OMDataType(forward_output_tensor->type())) * flat_size;
+    uint8_t *target_data = _training_storage.getTargetData(i) + offset;
+
+    // Note: metric_val is always treated as a float
+    float *f_metric_val = reinterpret_cast<float *>(metric_val);
+    switch (metric)
+    {
+      case MSE_METRICS:
+      {
+        // Note: add the value calculated for the current sample
+        *f_metric_val +=
+          metrics::MSE::calculateValue(flat_size, reinterpret_cast<float *>(calculated_data),
+                                       reinterpret_cast<float *>(target_data));
+        break;
+      }
+      case MAE_METRICS:
+      {
+        // Note: add the value calculated for the current sample
+        *f_metric_val +=
+          metrics::MAE::calculateValue(flat_size, reinterpret_cast<float *>(calculated_data),
+                                       reinterpret_cast<float *>(target_data));
+        break;
+      }
+      case CROSS_ENTROPY_METRICS:
+      {
+        // Note: add the value calculated for the current sample
+        *f_metric_val += metrics::CrossEntropy::calculateValue(
+          flat_size, reinterpret_cast<float *>(calculated_data),
+          reinterpret_cast<float *>(target_data));
+        break;
+      }
+      case ACCURACY:
+      {
+        // Note: add the value calculated for the current sample
+        *f_metric_val +=
+          metrics::Accuracy::calculateValue(flat_size, reinterpret_cast<float *>(calculated_data),
+                                            reinterpret_cast<float *>(target_data));
+        break;
+      }
+      default:
+      {
+        assert(false && "Unsupported metric type");
+        return UnsupportedType;
+      }
+    }
+  }
+
+  return Ok;
+}
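Because evaluateMetric adds into metric_val instead of overwriting it (see the warning above), a caller that wants an average over a test set must zero the accumulator first and divide afterwards, along these lines (the helper name, num_test_samples, and the elided forward execution are assumptions):

#include "core/train/OMTrainingHandler.h"

using namespace onert_micro;

// Hypothetical evaluation helper; handler/storage/context come from the caller.
OMStatus evaluateAccuracySketch(core::train::OMTrainingHandler &handler,
                                core::OMRuntimeStorage &storage, core::OMRuntimeContext &context,
                                uint32_t num_test_samples, float *out_accuracy)
{
  float metric_val = 0.0f; // must be initialized: evaluateMetric only adds to it
  for (uint32_t batch_num = 0; batch_num < num_test_samples; ++batch_num)
  {
    // ... run the forward graph for this sample (not shown) ...
    OMStatus status = handler.evaluateMetric(ACCURACY, &metric_val, storage, context, batch_num);
    if (status != Ok)
      return status;
  }
  *out_accuracy = metric_val / static_cast<float>(num_test_samples); // average over samples
  return Ok;
}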
diff --git a/onert-micro/onert-micro/src/core/train/OMTrainingStorage.cpp b/onert-micro/onert-micro/src/core/train/OMTrainingStorage.cpp
new file mode 100644
index 00000000000..b3167facde6
--- /dev/null
+++ b/onert-micro/onert-micro/src/core/train/OMTrainingStorage.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "core/train/OMTrainingStorage.h"
+
+#include <memory>
+
+using namespace onert_micro::core::train;
+using namespace onert_micro::train;
+using namespace onert_micro;
+
+OMStatus OMTrainingStorage::setOptimizer(const OMConfig &config)
+{
+  switch (config.training_context.optimizer)
+  {
+    case SGD:
+    {
+      if (_sgd_optimizer != nullptr)
+      {
+        return UnknownError;
+      }
+      _sgd_optimizer = std::make_unique<optimizers::SGD>();
+      break;
+    }
+    case ADAM:
+    {
+      if (_adam_optimizer != nullptr)
+      {
+        return UnknownError;
+      }
+      _adam_optimizer = std::make_unique<optimizers::Adam>();
+      break;
+    }
+    default:
+      assert(false && "Unsupported optimizer type");
+      return UnsupportedType;
+  }
+  return Ok;
+}
+
+void OMTrainingStorage::reset()
+{
+  if (_sgd_optimizer)
+    _sgd_optimizer->reset();
+
+  // Note: reset the Adam state as well, so that all states are actually cleared
+  if (_adam_optimizer)
+    _adam_optimizer->reset();
+
+  _target_index_to_target_data.clear();
+  _input_index_to_input_data.clear();
+}
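Finally, since optimizers::Adam is constructed here but its state handling is out of scope for this diff: the per-weight state Adam conventionally keeps is a pair of moment estimates, updated as in Kingma & Ba. The sketch below is the textbook rule, not necessarily how optimizers::Adam lays out its buffers; beta1, beta2, and epsilon would come from the training context.

#include <cmath>
#include <cstdint>

// Textbook Adam update for one flattened weight tensor (sketch only);
// step is 1-based, m and v are the persistent first/second moment buffers.
void adamUpdateSketch(uint32_t flat_size, float *weights, const float *grad, float *m, float *v,
                      float learning_rate, float beta1, float beta2, float epsilon, uint32_t step)
{
  for (uint32_t i = 0; i < flat_size; ++i)
  {
    m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
    v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
    const float m_hat = m[i] / (1.0f - std::pow(beta1, static_cast<float>(step)));
    const float v_hat = v[i] / (1.0f - std::pow(beta2, static_cast<float>(step)));
    weights[i] -= learning_rate * m_hat / (std::sqrt(v_hat) + epsilon);
  }
}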