1be168c0dSopenharmony_cidiff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake 2be168c0dSopenharmony_ciindex 2254c2a7..f15724f1 100644 3be168c0dSopenharmony_ci--- a/cmake/package_lite.cmake 4be168c0dSopenharmony_ci+++ b/cmake/package_lite.cmake 5be168c0dSopenharmony_ci@@ -474,7 +474,7 @@ if(PLATFORM_ARM64) 6be168c0dSopenharmony_ci COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE) 7be168c0dSopenharmony_ci install(DIRECTORY ${TOP_DIR}/include/c_api/ DESTINATION ${RUNTIME_INC_DIR}/c_api 8be168c0dSopenharmony_ci COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h") 9be168c0dSopenharmony_ci- if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR MSLITE_ENABLE_CONVERTER OR TARGET_HIMIX) 10be168c0dSopenharmony_ci+ if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR MSLITE_ENABLE_CONVERTER OR TARGET_HIMIX OR TARGET_OHOS) 11be168c0dSopenharmony_ci __install_micro_wrapper() 12be168c0dSopenharmony_ci endif() 13be168c0dSopenharmony_ci if(MSLITE_ENABLE_RUNTIME_GLOG) 14be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/backend/common/optimizer/pass.h b/mindspore/ccsrc/backend/common/optimizer/pass.h 15be168c0dSopenharmony_cinew file mode 100644 16be168c0dSopenharmony_ciindex 00000000..8d396164 17be168c0dSopenharmony_ci--- /dev/null 18be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/backend/common/optimizer/pass.h 19be168c0dSopenharmony_ci@@ -0,0 +1,48 @@ 20be168c0dSopenharmony_ci+/** 21be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 22be168c0dSopenharmony_ci+ * 23be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 24be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
25be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 26be168c0dSopenharmony_ci+ * 27be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 28be168c0dSopenharmony_ci+ * 29be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 30be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 31be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 33be168c0dSopenharmony_ci+ * limitations under the License. 34be168c0dSopenharmony_ci+ */ 35be168c0dSopenharmony_ci+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ 36be168c0dSopenharmony_ci+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ 37be168c0dSopenharmony_ci+#include <memory> 38be168c0dSopenharmony_ci+#include <string> 39be168c0dSopenharmony_ci+#include "ir/anf.h" 40be168c0dSopenharmony_ci+#include "mindspore/core/ops/array_ops.h" 41be168c0dSopenharmony_ci+#include "mindspore/core/ops/lite_ops.h" 42be168c0dSopenharmony_ci+#include "utils/trace_base.h" 43be168c0dSopenharmony_ci+ 44be168c0dSopenharmony_ci+namespace mindspore { 45be168c0dSopenharmony_ci+namespace opt { 46be168c0dSopenharmony_ci+class CacheManager; 47be168c0dSopenharmony_ci+using CacheManagerPtr = std::shared_ptr<CacheManager>; 48be168c0dSopenharmony_ci+ 49be168c0dSopenharmony_ci+// @brief ANF Graph level optimization base pass 50be168c0dSopenharmony_ci+class Pass { 51be168c0dSopenharmony_ci+public: 52be168c0dSopenharmony_ci+ explicit Pass(const std::string &name = "pass") : name_(name) {} 53be168c0dSopenharmony_ci+ virtual ~Pass() = default; 54be168c0dSopenharmony_ci+ virtual bool Run(const FuncGraphPtr &fun_graph) = 0; 55be168c0dSopenharmony_ci+ const std::string &name() const { return name_;} 56be168c0dSopenharmony_ci+ void SetCacheManager(const CacheManagerPtr &cm) { cache_manager_ = cm;} 
57be168c0dSopenharmony_ci+ const CacheManagerPtr &GetCacheManager() const {return cache_manager_;} 58be168c0dSopenharmony_ci+ 59be168c0dSopenharmony_ci+private: 60be168c0dSopenharmony_ci+ const std::string name_; 61be168c0dSopenharmony_ci+ CacheManagerPtr cache_manager_; 62be168c0dSopenharmony_ci+}; 63be168c0dSopenharmony_ci+using PassPtr = std::shared_ptr<Pass>; 64be168c0dSopenharmony_ci+} // namespace opt 65be168c0dSopenharmony_ci+} // namespace mindspore 66be168c0dSopenharmony_ci+ 67be168c0dSopenharmony_ci+#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PASS_H_ 68be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.cc 69be168c0dSopenharmony_ciindex 55bbddac..378ef00c 100644 70be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.cc 71be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.cc 72be168c0dSopenharmony_ci@@ -60,6 +60,8 @@ bool LstmCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec 73be168c0dSopenharmony_ci hidden_size_ = kernel_ptr->get_hidden_size(); 74be168c0dSopenharmony_ci num_layers_ = kernel_ptr->get_num_layers(); 75be168c0dSopenharmony_ci has_bias_ = kernel_ptr->get_has_bias(); 76be168c0dSopenharmony_ci+ proj_size_ = kernel_ptr->get_proj_size(); 77be168c0dSopenharmony_ci+ real_hidden_size_ = proj_size_ > 0 ? 
proj_size_ : hidden_size_; 78be168c0dSopenharmony_ci constexpr int kBidirectional = 2; 79be168c0dSopenharmony_ci num_directions_ = 1; 80be168c0dSopenharmony_ci if (bidirectional_) { 81be168c0dSopenharmony_ci@@ -73,14 +75,20 @@ bool LstmCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec 82be168c0dSopenharmony_ci MS_LOG(EXCEPTION) << "Layers must be lower than 100!"; 83be168c0dSopenharmony_ci } 84be168c0dSopenharmony_ci 85be168c0dSopenharmony_ci+ weight_size_ = 0; 86be168c0dSopenharmony_ci+ weight_h_size_ = 0; 87be168c0dSopenharmony_ci+ weight_r_size_ = 0; 88be168c0dSopenharmony_ci for (int i = 0; i < num_layers_; ++i) { 89be168c0dSopenharmony_ci weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); 90be168c0dSopenharmony_ci- weight_h_size_ += gate_size * hidden_size_; 91be168c0dSopenharmony_ci+ weight_h_size_ += gate_size * real_hidden_size_; 92be168c0dSopenharmony_ci+ weight_r_size_ += hidden_size_ * proj_size_; 93be168c0dSopenharmony_ci } 94be168c0dSopenharmony_ci weight_size_ = weight_size_ * num_directions_; 95be168c0dSopenharmony_ci weight_h_size_ = weight_h_size_ * num_directions_; 96be168c0dSopenharmony_ci+ weight_r_size_ = weight_r_size_ * num_directions_; 97be168c0dSopenharmony_ci weights_dims_ = {num_layers_, num_directions_, input_size_, kGateNum, hidden_size_}; 98be168c0dSopenharmony_ci- weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, kGateNum, hidden_size_}; 99be168c0dSopenharmony_ci+ weights_h_dims_ = {num_layers_, num_directions_, real_hidden_size_, kGateNum, hidden_size_}; 100be168c0dSopenharmony_ci+ weights_r_dims_ = {num_layers_, num_directions_, hidden_size_, proj_size_}; 101be168c0dSopenharmony_ci bias_dims_ = {num_layers_, num_directions_, kGateNum, hidden_size_}; 102be168c0dSopenharmony_ci is_training_ = 103be168c0dSopenharmony_ci base_operator->HasAttr(kAttrIsTraining) ? 
GetValue<bool>(base_operator->GetAttr(kAttrIsTraining)) : true; 104be168c0dSopenharmony_ci@@ -110,10 +118,10 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 105be168c0dSopenharmony_ci direction = dnnl::rnn_direction::bidirectional_concat; 106be168c0dSopenharmony_ci } 107be168c0dSopenharmony_ci dim src_dims = {seq_len_, batch_size_, input_size_}; 108be168c0dSopenharmony_ci- dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 109be168c0dSopenharmony_ci+ dim src_h_dims = {num_layers_, num_directions_, batch_size_, real_hidden_size_}; 110be168c0dSopenharmony_ci dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 111be168c0dSopenharmony_ci- dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_}; 112be168c0dSopenharmony_ci- dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 113be168c0dSopenharmony_ci+ dim dst_dims = {seq_len_, batch_size_, real_hidden_size_ * num_directions_}; 114be168c0dSopenharmony_ci+ dim dst_h_dims = {num_layers_, num_directions_, batch_size_, real_hidden_size_}; 115be168c0dSopenharmony_ci dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 116be168c0dSopenharmony_ci dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); 117be168c0dSopenharmony_ci dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); 118be168c0dSopenharmony_ci@@ -126,13 +134,16 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 119be168c0dSopenharmony_ci auto prop_kind = is_training_ ? 
dnnl::prop_kind::forward_training : dnnl::prop_kind::forward_inference; 120be168c0dSopenharmony_ci auto weights_desc = formatted_md(weights_dims_, tag::any); 121be168c0dSopenharmony_ci auto weights_h_desc = formatted_md(weights_h_dims_, tag::any); 122be168c0dSopenharmony_ci- auto desc = 123be168c0dSopenharmony_ci- CreatePrimitive<dnnl::lstm_forward::desc>(prop_kind, direction, src_desc, src_h_desc, src_c_desc, weights_desc, 124be168c0dSopenharmony_ci- weights_h_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); 125be168c0dSopenharmony_ci+ auto weights_r_desc = proj_size_ > 0 ? formatted_md(weights_r_dims_, tag::any) : dnnl::memory::desc(); 126be168c0dSopenharmony_ci+ auto peephole_desc = dnnl::memory::desc(); 127be168c0dSopenharmony_ci+ auto desc = CreatePrimitive<dnnl::lstm_forward::desc>(prop_kind, direction, src_desc, src_h_desc, src_c_desc, 128be168c0dSopenharmony_ci+ weights_desc, weights_h_desc, peephole_desc, weights_r_desc, 129be168c0dSopenharmony_ci+ bias_desc, dst_desc, dst_h_desc, dst_c_desc); 130be168c0dSopenharmony_ci prim_desc_ = CreateDesc<dnnl::lstm_forward::primitive_desc>(*desc, engine_); 131be168c0dSopenharmony_ci primitive_ = CreatePrimitive<dnnl::lstm_forward>(prim_desc_); 132be168c0dSopenharmony_ci auto weights_layer = GetWeightsLayerDesc(prim_desc_); 133be168c0dSopenharmony_ci auto weights_iter = GetWeightsIterDesc(prim_desc_); 134be168c0dSopenharmony_ci+ auto weights_proj = GetWeightsProjectionDesc(prim_desc_); 135be168c0dSopenharmony_ci bias_desc_ = GetBiasDesc(prim_desc_); 136be168c0dSopenharmony_ci if (is_training_) { 137be168c0dSopenharmony_ci auto wksp_desc = GetWorkspaceDesc(prim_desc_); 138be168c0dSopenharmony_ci@@ -144,6 +155,7 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 139be168c0dSopenharmony_ci AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); 140be168c0dSopenharmony_ci AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer); 141be168c0dSopenharmony_ci AddArgument(DNNL_ARG_WEIGHTS_ITER, 
weights_iter); 142be168c0dSopenharmony_ci+ AddArgument(DNNL_ARG_WEIGHTS_PROJECTION, weights_proj); 143be168c0dSopenharmony_ci AddArgument(DNNL_ARG_BIAS, bias_desc); 144be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DST_LAYER, dst_desc); 145be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); 146be168c0dSopenharmony_ci@@ -151,10 +163,13 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 147be168c0dSopenharmony_ci 148be168c0dSopenharmony_ci auto weights_dims_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi); 149be168c0dSopenharmony_ci auto weights_h_dims_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi); 150be168c0dSopenharmony_ci+ auto weights_r_dims_desc = CreateDesc<dnnl::memory::desc>(weights_r_dims_, dt::f32, tag::ldoi); 151be168c0dSopenharmony_ci user_weights_memory_ = CreateDesc<dnnl::memory>(weights_dims_desc, engine_); 152be168c0dSopenharmony_ci user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_dims_desc, engine_); 153be168c0dSopenharmony_ci+ user_weights_r_memory_ = CreateDesc<dnnl::memory>(weights_r_dims_desc, engine_); 154be168c0dSopenharmony_ci weights_memory_ = CreateDesc<dnnl::memory>(weights_layer, engine_); 155be168c0dSopenharmony_ci weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter, engine_); 156be168c0dSopenharmony_ci+ weights_r_memory_ = CreateDesc<dnnl::memory>(weights_proj, engine_); 157be168c0dSopenharmony_ci bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, engine_); 158be168c0dSopenharmony_ci 159be168c0dSopenharmony_ci InitOutputSize(outputs); 160be168c0dSopenharmony_ci@@ -163,13 +178,20 @@ int LstmCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve 161be168c0dSopenharmony_ci 162be168c0dSopenharmony_ci bool LstmCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, 163be168c0dSopenharmony_ci const std::vector<kernel::AddressPtr> &outputs) { 
164be168c0dSopenharmony_ci+ size_t offset = 0; 165be168c0dSopenharmony_ci SetDataHandle(user_weights_memory_, inputs[kInputWeightIndex]->addr); 166be168c0dSopenharmony_ci- SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_); 167be168c0dSopenharmony_ci+ offset += weight_size_; 168be168c0dSopenharmony_ci+ SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 169be168c0dSopenharmony_ci+ offset += weight_h_size_; 170be168c0dSopenharmony_ci Reorder(&user_weights_memory_, &weights_memory_); 171be168c0dSopenharmony_ci Reorder(&user_weights_h_memory_, &weights_h_memory_); 172be168c0dSopenharmony_ci+ if (proj_size_ > 0) { 173be168c0dSopenharmony_ci+ SetDataHandle(user_weights_r_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 174be168c0dSopenharmony_ci+ Reorder(&user_weights_r_memory_, &weights_r_memory_); 175be168c0dSopenharmony_ci+ offset += weight_r_size_; 176be168c0dSopenharmony_ci+ } 177be168c0dSopenharmony_ci if (has_bias_) { 178be168c0dSopenharmony_ci- SetDataHandle(bias_memory_, 179be168c0dSopenharmony_ci- reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_ + weight_h_size_); 180be168c0dSopenharmony_ci+ SetDataHandle(bias_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 181be168c0dSopenharmony_ci } else { 182be168c0dSopenharmony_ci auto size = GetSize(bias_desc_); 183be168c0dSopenharmony_ci if (memset_s(GetDataHandle(bias_memory_), size, 0, size) != EOK) { 184be168c0dSopenharmony_ci@@ -182,6 +204,7 @@ bool LstmCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, con 185be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[kInputCIndex]->addr); 186be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_)); 187be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, 
GetDataHandle(weights_h_memory_)); 188be168c0dSopenharmony_ci+ SetArgumentHandle(DNNL_ARG_WEIGHTS_PROJECTION, GetDataHandle(weights_r_memory_)); 189be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_)); 190be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); 191be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); 192be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.h 193be168c0dSopenharmony_ciindex 42609eed..a0241c16 100644 194be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.h 195be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_cpu_kernel.h 196be168c0dSopenharmony_ci@@ -58,14 +58,17 @@ class LstmCpuKernelMod : public MKLCpuKernelMod { 197be168c0dSopenharmony_ci private: 198be168c0dSopenharmony_ci void InitOutputSize(const std::vector<KernelTensorPtr> &outputs); 199be168c0dSopenharmony_ci 200be168c0dSopenharmony_ci- int weight_size_{0}; 201be168c0dSopenharmony_ci- int weight_h_size_{0}; 202be168c0dSopenharmony_ci- int input_size_{0}; 203be168c0dSopenharmony_ci- int hidden_size_{0}; 204be168c0dSopenharmony_ci- int num_layers_{0}; 205be168c0dSopenharmony_ci- int batch_size_{0}; 206be168c0dSopenharmony_ci- int seq_len_{0}; 207be168c0dSopenharmony_ci- int num_directions_{0}; 208be168c0dSopenharmony_ci+ int64_t weight_size_{0}; 209be168c0dSopenharmony_ci+ int64_t weight_h_size_{0}; 210be168c0dSopenharmony_ci+ int64_t weight_r_size_{0}; 211be168c0dSopenharmony_ci+ int64_t input_size_{0}; 212be168c0dSopenharmony_ci+ int64_t hidden_size_{0}; 213be168c0dSopenharmony_ci+ int64_t num_layers_{0}; 214be168c0dSopenharmony_ci+ int64_t batch_size_{0}; 215be168c0dSopenharmony_ci+ int64_t seq_len_{0}; 216be168c0dSopenharmony_ci+ int64_t num_directions_{0}; 217be168c0dSopenharmony_ci+ int64_t proj_size_{0}; 
218be168c0dSopenharmony_ci+ int64_t real_hidden_size_{0}; 219be168c0dSopenharmony_ci bool bidirectional_{false}; 220be168c0dSopenharmony_ci bool has_bias_{false}; 221be168c0dSopenharmony_ci bool is_training_{false}; 222be168c0dSopenharmony_ci@@ -73,13 +76,16 @@ class LstmCpuKernelMod : public MKLCpuKernelMod { 223be168c0dSopenharmony_ci 224be168c0dSopenharmony_ci dnnl::memory::dims weights_dims_; 225be168c0dSopenharmony_ci dnnl::memory::dims weights_h_dims_; 226be168c0dSopenharmony_ci+ dnnl::memory::dims weights_r_dims_; 227be168c0dSopenharmony_ci dnnl::memory::dims bias_dims_; 228be168c0dSopenharmony_ci dnnl::lstm_forward::primitive_desc prim_desc_; 229be168c0dSopenharmony_ci dnnl::memory::desc bias_desc_; 230be168c0dSopenharmony_ci dnnl::memory user_weights_memory_; 231be168c0dSopenharmony_ci dnnl::memory user_weights_h_memory_; 232be168c0dSopenharmony_ci+ dnnl::memory user_weights_r_memory_; 233be168c0dSopenharmony_ci dnnl::memory weights_memory_; 234be168c0dSopenharmony_ci dnnl::memory weights_h_memory_; 235be168c0dSopenharmony_ci+ dnnl::memory weights_r_memory_; 236be168c0dSopenharmony_ci dnnl::memory bias_memory_; 237be168c0dSopenharmony_ci }; 238be168c0dSopenharmony_ci } // namespace kernel 239be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.cc 240be168c0dSopenharmony_ciindex aa1f8b44..0b5d09c1 100644 241be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.cc 242be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.cc 243be168c0dSopenharmony_ci@@ -62,6 +62,8 @@ bool LSTMGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std: 244be168c0dSopenharmony_ci hidden_size_ = op_prim->get_hidden_size(); 245be168c0dSopenharmony_ci num_layers_ = op_prim->get_num_layers(); 246be168c0dSopenharmony_ci has_bias_ = op_prim->get_has_bias(); 
247be168c0dSopenharmony_ci+ proj_size_ = op_prim->get_proj_size(); 248be168c0dSopenharmony_ci+ real_hidden_size_ = proj_size_ > 0 ? proj_size_ : hidden_size_; 249be168c0dSopenharmony_ci auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); 250be168c0dSopenharmony_ci auto match = MatchKernelAttr(kernel_attr, GetOpSupport()); 251be168c0dSopenharmony_ci if (!match.first) { 252be168c0dSopenharmony_ci@@ -103,12 +105,15 @@ int LSTMGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std 253be168c0dSopenharmony_ci } 254be168c0dSopenharmony_ci weight_size_ = 0; 255be168c0dSopenharmony_ci weight_h_size_ = 0; 256be168c0dSopenharmony_ci+ weight_r_size_ = 0; 257be168c0dSopenharmony_ci for (int64_t i = 0; i < num_layers_; ++i) { 258be168c0dSopenharmony_ci weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); 259be168c0dSopenharmony_ci- weight_h_size_ += gate_size * hidden_size_; 260be168c0dSopenharmony_ci+ weight_h_size_ += gate_size * real_hidden_size_; 261be168c0dSopenharmony_ci+ weight_r_size_ += proj_size_ * hidden_size_; 262be168c0dSopenharmony_ci } 263be168c0dSopenharmony_ci weight_size_ = weight_size_ * num_directions_; 264be168c0dSopenharmony_ci weight_h_size_ = weight_h_size_ * num_directions_; 265be168c0dSopenharmony_ci+ weight_r_size_ = weight_r_size_ * num_directions_; 266be168c0dSopenharmony_ci if (num_directions_ * num_layers_ != src_h_shape[0]) { 267be168c0dSopenharmony_ci MS_LOG(ERROR) << "Error iteration shape!"; 268be168c0dSopenharmony_ci return KRET_RESIZE_FAILED; 269be168c0dSopenharmony_ci@@ -124,13 +129,14 @@ void LSTMGradCpuKernelMod::InitDnnl() { 270be168c0dSopenharmony_ci direction = dnnl::rnn_direction::bidirectional_concat; 271be168c0dSopenharmony_ci } 272be168c0dSopenharmony_ci dim src_dims = {seq_len_, batch_size_, input_size_}; 273be168c0dSopenharmony_ci- dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 274be168c0dSopenharmony_ci+ dim src_h_dims = {num_layers_, 
num_directions_, batch_size_, real_hidden_size_}; 275be168c0dSopenharmony_ci dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 276be168c0dSopenharmony_ci weights_dims_ = {num_layers_, num_directions_, input_size_, kNumberFour, hidden_size_}; 277be168c0dSopenharmony_ci- weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, kNumberFour, hidden_size_}; 278be168c0dSopenharmony_ci+ weights_h_dims_ = {num_layers_, num_directions_, real_hidden_size_, kNumberFour, hidden_size_}; 279be168c0dSopenharmony_ci+ weights_r_dims_ = {num_layers_, num_directions_, hidden_size_, proj_size_}; 280be168c0dSopenharmony_ci bias_dims_ = {num_layers_, num_directions_, kNumberFour, hidden_size_}; 281be168c0dSopenharmony_ci- dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_}; 282be168c0dSopenharmony_ci- dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 283be168c0dSopenharmony_ci+ dim dst_dims = {seq_len_, batch_size_, real_hidden_size_ * num_directions_}; 284be168c0dSopenharmony_ci+ dim dst_h_dims = {num_layers_, num_directions_, batch_size_, real_hidden_size_}; 285be168c0dSopenharmony_ci dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; 286be168c0dSopenharmony_ci dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); 287be168c0dSopenharmony_ci dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); 288be168c0dSopenharmony_ci@@ -141,15 +147,17 @@ void LSTMGradCpuKernelMod::InitDnnl() { 289be168c0dSopenharmony_ci dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); 290be168c0dSopenharmony_ci auto weights_desc = formatted_md(weights_dims_, tag::any); 291be168c0dSopenharmony_ci auto weights_h_desc = formatted_md(weights_h_dims_, tag::any); 292be168c0dSopenharmony_ci+ auto weights_r_desc = proj_size_ > 0 ? 
formatted_md(weights_r_dims_, tag::any) : dnnl::memory::desc(); 293be168c0dSopenharmony_ci+ auto peepole_desc = dnnl::memory::desc(); 294be168c0dSopenharmony_ci 295be168c0dSopenharmony_ci- auto forward_desc = CreatePrimitive<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc, 296be168c0dSopenharmony_ci- src_h_desc, src_c_desc, weights_desc, weights_h_desc, 297be168c0dSopenharmony_ci- bias_desc, dst_desc, dst_h_desc, dst_c_desc); 298be168c0dSopenharmony_ci+ auto forward_desc = CreatePrimitive<dnnl::lstm_forward::desc>( 299be168c0dSopenharmony_ci+ dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, 300be168c0dSopenharmony_ci+ peepole_desc, weights_r_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); 301be168c0dSopenharmony_ci auto prim_forward_desc = CreateDesc<dnnl::lstm_forward::primitive_desc>(*forward_desc, eng); 302be168c0dSopenharmony_ci auto backward_desc = CreatePrimitive<dnnl::lstm_backward::desc>( 303be168c0dSopenharmony_ci- dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, bias_desc, 304be168c0dSopenharmony_ci- dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, bias_desc, 305be168c0dSopenharmony_ci- dst_desc, dst_h_desc, dst_c_desc); 306be168c0dSopenharmony_ci+ dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, weights_desc, weights_h_desc, peepole_desc, 307be168c0dSopenharmony_ci+ weights_r_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, src_c_desc, weights_desc, 308be168c0dSopenharmony_ci+ weights_h_desc, peepole_desc, weights_r_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); 309be168c0dSopenharmony_ci prim_backward_desc_ = CreateDesc<dnnl::lstm_backward::primitive_desc>(*backward_desc, eng, prim_forward_desc); 310be168c0dSopenharmony_ci primitive_ = CreatePrimitive<dnnl::lstm_backward>(prim_backward_desc_); 
311be168c0dSopenharmony_ci auto wksp_desc = GetWorkspaceDesc(prim_forward_desc); 312be168c0dSopenharmony_ci@@ -159,24 +167,31 @@ void LSTMGradCpuKernelMod::InitDnnl() { 313be168c0dSopenharmony_ci // construct fw memory 314be168c0dSopenharmony_ci weights_layer_desc_ = GetWeightsLayerDesc(prim_backward_desc_); 315be168c0dSopenharmony_ci weights_iter_desc_ = GetWeightsIterDesc(prim_backward_desc_); 316be168c0dSopenharmony_ci+ weights_proj_desc_ = GetWeightsProjectionDesc(prim_backward_desc_); 317be168c0dSopenharmony_ci bias_desc_ = GetBiasDesc(prim_backward_desc_); 318be168c0dSopenharmony_ci auto weights_mem_desc = CreateDesc<dnnl::memory::desc>(weights_dims_, dt::f32, tag::ldgoi); 319be168c0dSopenharmony_ci auto weights_h_mem_desc = CreateDesc<dnnl::memory::desc>(weights_h_dims_, dt::f32, tag::ldgoi); 320be168c0dSopenharmony_ci+ auto weights_r_mem_desc = CreateDesc<dnnl::memory::desc>(weights_r_dims_, dt::f32, tag::ldoi); 321be168c0dSopenharmony_ci user_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng); 322be168c0dSopenharmony_ci user_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng); 323be168c0dSopenharmony_ci+ user_weights_r_memory_ = CreateDesc<dnnl::memory>(weights_r_mem_desc, eng); 324be168c0dSopenharmony_ci weights_memory_ = CreateDesc<dnnl::memory>(weights_layer_desc_, eng); 325be168c0dSopenharmony_ci weights_h_memory_ = CreateDesc<dnnl::memory>(weights_iter_desc_, eng); 326be168c0dSopenharmony_ci+ weights_r_memory_ = CreateDesc<dnnl::memory>(weights_proj_desc_, eng); 327be168c0dSopenharmony_ci bias_memory_ = CreateDesc<dnnl::memory>(bias_desc_, eng); 328be168c0dSopenharmony_ci 329be168c0dSopenharmony_ci // construct bw memory 330be168c0dSopenharmony_ci diff_weights_layer_desc_ = GetDiffWeightsLayerDesc(prim_backward_desc_); 331be168c0dSopenharmony_ci diff_weights_iter_desc_ = GetDiffWeightsIterDesc(prim_backward_desc_); 332be168c0dSopenharmony_ci+ diff_weights_proj_desc_ = GetDiffWeightsProjectionDesc(prim_backward_desc_); 
333be168c0dSopenharmony_ci diff_bias_desc_ = GetDiffBiasDesc(prim_backward_desc_); 334be168c0dSopenharmony_ci diff_weights_memory_ = CreateDesc<dnnl::memory>(diff_weights_layer_desc_, eng); 335be168c0dSopenharmony_ci diff_weights_h_memory_ = CreateDesc<dnnl::memory>(diff_weights_iter_desc_, eng); 336be168c0dSopenharmony_ci+ diff_weights_r_memory_ = CreateDesc<dnnl::memory>(diff_weights_proj_desc_, eng); 337be168c0dSopenharmony_ci diff_bias_memory_ = CreateDesc<dnnl::memory>(diff_bias_desc_, eng); 338be168c0dSopenharmony_ci user_diff_weights_memory_ = CreateDesc<dnnl::memory>(weights_mem_desc, eng); 339be168c0dSopenharmony_ci user_diff_weights_h_memory_ = CreateDesc<dnnl::memory>(weights_h_mem_desc, eng); 340be168c0dSopenharmony_ci+ user_diff_weights_r_memory_ = CreateDesc<dnnl::memory>(weights_r_mem_desc, eng); 341be168c0dSopenharmony_ci } 342be168c0dSopenharmony_ci 343be168c0dSopenharmony_ci void LSTMGradCpuKernelMod::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc, 344be168c0dSopenharmony_ci@@ -188,6 +203,7 @@ void LSTMGradCpuKernelMod::AddArgumentOp(const dnnl::memory::desc &src_desc, con 345be168c0dSopenharmony_ci AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); 346be168c0dSopenharmony_ci AddArgument(DNNL_ARG_WEIGHTS_LAYER, weights_layer_desc_); 347be168c0dSopenharmony_ci AddArgument(DNNL_ARG_WEIGHTS_ITER, weights_iter_desc_); 348be168c0dSopenharmony_ci+ AddArgument(DNNL_ARG_WEIGHTS_PROJECTION, weights_proj_desc_); 349be168c0dSopenharmony_ci AddArgument(DNNL_ARG_BIAS, bias_desc); 350be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DST_LAYER, dst_desc); 351be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); 352be168c0dSopenharmony_ci@@ -197,6 +213,7 @@ void LSTMGradCpuKernelMod::AddArgumentOp(const dnnl::memory::desc &src_desc, con 353be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); 354be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_desc_); 
355be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_desc_); 356be168c0dSopenharmony_ci+ AddArgument(DNNL_ARG_DIFF_WEIGHTS_PROJECTION, diff_weights_proj_desc_); 357be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); 358be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); 359be168c0dSopenharmony_ci AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); 360be168c0dSopenharmony_ci@@ -211,6 +228,7 @@ void LSTMGradCpuKernelMod::SetArgumentHandleOp(const std::vector<kernel::Address 361be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[kSrcIterCIdx]->addr); 362be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, GetDataHandle(weights_memory_)); 363be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, GetDataHandle(weights_h_memory_)); 364be168c0dSopenharmony_ci+ SetArgumentHandle(DNNL_ARG_WEIGHTS_PROJECTION, GetDataHandle(weights_r_memory_)); 365be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_BIAS, GetDataHandle(bias_memory_)); 366be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[kDstLayerIdx]->addr); 367be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[kDstIterIdx]->addr); 368be168c0dSopenharmony_ci@@ -221,6 +239,7 @@ void LSTMGradCpuKernelMod::SetArgumentHandleOp(const std::vector<kernel::Address 369be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[kSrcIterCIdx]->addr); 370be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, GetDataHandle(diff_weights_memory_)); 371be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, GetDataHandle(diff_weights_h_memory_)); 372be168c0dSopenharmony_ci+ SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_PROJECTION, GetDataHandle(diff_weights_r_memory_)); 373be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DIFF_BIAS, GetDataHandle(diff_bias_memory_)); 374be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, 
inputs[kDiffDstLayerIdx]->addr); 375be168c0dSopenharmony_ci SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[kDiffDstIterIdx]->addr); 376be168c0dSopenharmony_ci@@ -241,13 +260,20 @@ bool LSTMGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, 377be168c0dSopenharmony_ci const std::vector<kernel::AddressPtr> &outputs) { 378be168c0dSopenharmony_ci CHECK_KERNEL_INPUTS_NUM(inputs.size(), kLstmGradInputsNum, kernel_name_); 379be168c0dSopenharmony_ci CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kLstmGradOutputsNum, kernel_name_); 380be168c0dSopenharmony_ci+ size_t offset = 0; 381be168c0dSopenharmony_ci SetDataHandle(user_weights_memory_, inputs[kInputWeightIndex]->addr); 382be168c0dSopenharmony_ci- SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_); 383be168c0dSopenharmony_ci+ offset += weight_size_; 384be168c0dSopenharmony_ci+ SetDataHandle(user_weights_h_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 385be168c0dSopenharmony_ci+ offset += weight_h_size_; 386be168c0dSopenharmony_ci Reorder(&user_weights_memory_, &weights_memory_); 387be168c0dSopenharmony_ci Reorder(&user_weights_h_memory_, &weights_h_memory_); 388be168c0dSopenharmony_ci+ if (proj_size_ > 0) { 389be168c0dSopenharmony_ci+ SetDataHandle(user_weights_r_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 390be168c0dSopenharmony_ci+ Reorder(&user_weights_r_memory_, &weights_r_memory_); 391be168c0dSopenharmony_ci+ offset += weight_r_size_; 392be168c0dSopenharmony_ci+ } 393be168c0dSopenharmony_ci if (has_bias_) { 394be168c0dSopenharmony_ci- SetDataHandle(bias_memory_, 395be168c0dSopenharmony_ci- reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + weight_size_ + weight_h_size_); 396be168c0dSopenharmony_ci+ SetDataHandle(bias_memory_, reinterpret_cast<float *>(inputs[kInputWeightIndex]->addr) + offset); 397be168c0dSopenharmony_ci } else { 398be168c0dSopenharmony_ci auto 
dst_ptr = GetDataHandle(bias_memory_); 399be168c0dSopenharmony_ci auto size = GetSize(bias_desc_); 400be168c0dSopenharmony_ci@@ -256,16 +282,23 @@ bool LSTMGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, 401be168c0dSopenharmony_ci } 402be168c0dSopenharmony_ci } 403be168c0dSopenharmony_ci 404be168c0dSopenharmony_ci+ offset = 0; 405be168c0dSopenharmony_ci SetDataHandle(user_diff_weights_memory_, outputs[kOutputWeightIndex]->addr); 406be168c0dSopenharmony_ci- SetDataHandle(user_diff_weights_h_memory_, 407be168c0dSopenharmony_ci- reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + weight_size_); 408be168c0dSopenharmony_ci+ offset += weight_size_; 409be168c0dSopenharmony_ci+ SetDataHandle(user_diff_weights_h_memory_, reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + offset); 410be168c0dSopenharmony_ci+ offset += weight_h_size_; 411be168c0dSopenharmony_ci ResetMemory(user_diff_weights_memory_, "user weights grad"); 412be168c0dSopenharmony_ci ResetMemory(user_diff_weights_h_memory_, "user weights iter grad"); 413be168c0dSopenharmony_ci ResetMemory(diff_weights_memory_, "weights grad"); 414be168c0dSopenharmony_ci ResetMemory(diff_weights_h_memory_, "weights iter grad"); 415be168c0dSopenharmony_ci+ if (proj_size_ > 0) { 416be168c0dSopenharmony_ci+ SetDataHandle(user_diff_weights_r_memory_, reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + offset); 417be168c0dSopenharmony_ci+ ResetMemory(user_diff_weights_r_memory_, "user weights projection grad"); 418be168c0dSopenharmony_ci+ ResetMemory(diff_weights_r_memory_, "weights projection grad"); 419be168c0dSopenharmony_ci+ offset += weight_r_size_; 420be168c0dSopenharmony_ci+ } 421be168c0dSopenharmony_ci if (has_bias_) { 422be168c0dSopenharmony_ci- SetDataHandle(diff_bias_memory_, 423be168c0dSopenharmony_ci- reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + weight_size_ + weight_h_size_); 424be168c0dSopenharmony_ci+ SetDataHandle(diff_bias_memory_, 
reinterpret_cast<float *>(outputs[kOutputWeightIndex]->addr) + offset); 425be168c0dSopenharmony_ci } 426be168c0dSopenharmony_ci auto dst_ptr = GetDataHandle(diff_bias_memory_); 427be168c0dSopenharmony_ci auto size = GetSize(diff_bias_desc_); 428be168c0dSopenharmony_ci@@ -276,6 +309,9 @@ bool LSTMGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, 429be168c0dSopenharmony_ci ExecutePrimitive(); 430be168c0dSopenharmony_ci Reorder(&diff_weights_memory_, &user_diff_weights_memory_); 431be168c0dSopenharmony_ci Reorder(&diff_weights_h_memory_, &user_diff_weights_h_memory_); 432be168c0dSopenharmony_ci+ if (proj_size_ > 0) { 433be168c0dSopenharmony_ci+ Reorder(&diff_weights_r_memory_, &user_diff_weights_r_memory_); 434be168c0dSopenharmony_ci+ } 435be168c0dSopenharmony_ci return true; 436be168c0dSopenharmony_ci } 437be168c0dSopenharmony_ci 438be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.h 439be168c0dSopenharmony_ciindex f47bafc0..9768464d 100644 440be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.h 441be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/lstm_grad_cpu_kernel.h 442be168c0dSopenharmony_ci@@ -75,34 +75,44 @@ class LSTMGradCpuKernelMod : public MKLCpuKernelMod { 443be168c0dSopenharmony_ci bool has_bias_{false}; 444be168c0dSopenharmony_ci int64_t weight_size_{0}; 445be168c0dSopenharmony_ci int64_t weight_h_size_{0}; 446be168c0dSopenharmony_ci+ int64_t weight_r_size_{0}; 447be168c0dSopenharmony_ci int64_t input_size_{0}; 448be168c0dSopenharmony_ci int64_t hidden_size_{0}; 449be168c0dSopenharmony_ci int64_t num_layers_{0}; 450be168c0dSopenharmony_ci int64_t batch_size_{0}; 451be168c0dSopenharmony_ci int64_t seq_len_{0}; 452be168c0dSopenharmony_ci+ int64_t proj_size_{0}; 453be168c0dSopenharmony_ci+ int64_t real_hidden_size_{0}; 
454be168c0dSopenharmony_ci size_t reserve_size_{0}; 455be168c0dSopenharmony_ci 456be168c0dSopenharmony_ci dnnl::memory::dims weights_dims_; 457be168c0dSopenharmony_ci dnnl::memory::dims weights_h_dims_; 458be168c0dSopenharmony_ci+ dnnl::memory::dims weights_r_dims_; 459be168c0dSopenharmony_ci dnnl::memory::dims bias_dims_; 460be168c0dSopenharmony_ci dnnl::lstm_backward::primitive_desc prim_backward_desc_; 461be168c0dSopenharmony_ci 462be168c0dSopenharmony_ci dnnl::memory::desc weights_layer_desc_; 463be168c0dSopenharmony_ci dnnl::memory::desc weights_iter_desc_; 464be168c0dSopenharmony_ci+ dnnl::memory::desc weights_proj_desc_; 465be168c0dSopenharmony_ci dnnl::memory::desc bias_desc_; 466be168c0dSopenharmony_ci dnnl::memory::desc diff_weights_layer_desc_; 467be168c0dSopenharmony_ci dnnl::memory::desc diff_weights_iter_desc_; 468be168c0dSopenharmony_ci+ dnnl::memory::desc diff_weights_proj_desc_; 469be168c0dSopenharmony_ci dnnl::memory::desc diff_bias_desc_; 470be168c0dSopenharmony_ci dnnl::memory user_weights_memory_; 471be168c0dSopenharmony_ci dnnl::memory user_weights_h_memory_; 472be168c0dSopenharmony_ci+ dnnl::memory user_weights_r_memory_; 473be168c0dSopenharmony_ci dnnl::memory weights_memory_; 474be168c0dSopenharmony_ci dnnl::memory weights_h_memory_; 475be168c0dSopenharmony_ci+ dnnl::memory weights_r_memory_; 476be168c0dSopenharmony_ci dnnl::memory bias_memory_; 477be168c0dSopenharmony_ci dnnl::memory diff_weights_memory_; 478be168c0dSopenharmony_ci dnnl::memory diff_weights_h_memory_; 479be168c0dSopenharmony_ci+ dnnl::memory diff_weights_r_memory_; 480be168c0dSopenharmony_ci dnnl::memory diff_bias_memory_; 481be168c0dSopenharmony_ci dnnl::memory user_diff_weights_memory_; 482be168c0dSopenharmony_ci dnnl::memory user_diff_weights_h_memory_; 483be168c0dSopenharmony_ci+ dnnl::memory user_diff_weights_r_memory_; 484be168c0dSopenharmony_ci }; 485be168c0dSopenharmony_ci } // namespace kernel 486be168c0dSopenharmony_ci } // namespace mindspore 
487be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h 488be168c0dSopenharmony_ciindex 7c8292df..0c98f8f6 100644 489be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h 490be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h 491be168c0dSopenharmony_ci@@ -89,6 +89,14 @@ auto GetWeightsIterDesc(const T &prim_desc) { 492be168c0dSopenharmony_ci return desc; 493be168c0dSopenharmony_ci } 494be168c0dSopenharmony_ci 495be168c0dSopenharmony_ci+template <class T> 496be168c0dSopenharmony_ci+auto GetWeightsProjectionDesc(const T &prim_desc) { 497be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::weights_projection_desc()"; 498be168c0dSopenharmony_ci+ auto desc = prim_desc.weights_projection_desc(); 499be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "end to invoke " << demangle(typeid(T).name()) << "::weights_projection_desc()"; 500be168c0dSopenharmony_ci+ return desc; 501be168c0dSopenharmony_ci+} 502be168c0dSopenharmony_ci+ 503be168c0dSopenharmony_ci template <class T> 504be168c0dSopenharmony_ci auto GetBiasDesc(const T &prim_desc) { 505be168c0dSopenharmony_ci MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::bias_desc()"; 506be168c0dSopenharmony_ci@@ -113,6 +121,14 @@ auto GetDiffWeightsIterDesc(const T &prim_desc) { 507be168c0dSopenharmony_ci return desc; 508be168c0dSopenharmony_ci } 509be168c0dSopenharmony_ci 510be168c0dSopenharmony_ci+template <class T> 511be168c0dSopenharmony_ci+auto GetDiffWeightsProjectionDesc(const T &prim_desc) { 512be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_weights_projection_desc()"; 513be168c0dSopenharmony_ci+ auto desc = prim_desc.diff_weights_projection_desc(); 514be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "end to invoke " << 
demangle(typeid(T).name()) << "::diff_weights_projection_desc()"; 515be168c0dSopenharmony_ci+ return desc; 516be168c0dSopenharmony_ci+} 517be168c0dSopenharmony_ci+ 518be168c0dSopenharmony_ci template <class T> 519be168c0dSopenharmony_ci auto GetDiffBiasDesc(const T &prim_desc) { 520be168c0dSopenharmony_ci MS_LOG(DEBUG) << "begin to invoke " << demangle(typeid(T).name()) << "::diff_bias_desc()"; 521be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 522be168c0dSopenharmony_ciindex 103e53b7..d27817be 100644 523be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 524be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 525be168c0dSopenharmony_ci@@ -501,6 +501,7 @@ infer_shape_sources = [ 526be168c0dSopenharmony_ci "infer/custom_masked_fill_infer.c", 527be168c0dSopenharmony_ci "infer/custom_is_inf_infer.c", 528be168c0dSopenharmony_ci "infer/custom_tensor_scatter_max_infer.c", 529be168c0dSopenharmony_ci+ "infer/custom_gather_d_grad_v2_infer.c", 530be168c0dSopenharmony_ci "infer/decoder_layer_infer.c", 531be168c0dSopenharmony_ci "infer/deconv2d_infer.c", 532be168c0dSopenharmony_ci "infer/depth_to_space_infer.c", 533be168c0dSopenharmony_ci@@ -740,6 +741,7 @@ arm64_fp16_assembly_sources = [ 534be168c0dSopenharmony_ci "assembly/fp16/Matmul12X16Fp16.S", 535be168c0dSopenharmony_ci "assembly/fp16/MatmulBaseFp16Neon.S", 536be168c0dSopenharmony_ci "assembly/fp16/MatmulFp16Opt.S", 537be168c0dSopenharmony_ci+ "assembly/fp16/MatmulFp16OptV2.S", 538be168c0dSopenharmony_ci "assembly/fp16/MatmulFp16.S", 539be168c0dSopenharmony_ci "assembly/fp16/MatmulWinogradFp16.S", 540be168c0dSopenharmony_ci "assembly/fp16/MatVecMulFp16.S", 541be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/fp16/MatmulFp16OptV2.S b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/fp16/MatmulFp16OptV2.S 
542be168c0dSopenharmony_cinew file mode 100644 543be168c0dSopenharmony_ciindex 00000000..2d901a3d 544be168c0dSopenharmony_ci--- /dev/null 545be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/assembly/fp16/MatmulFp16OptV2.S 546be168c0dSopenharmony_ci@@ -0,0 +1,2966 @@ 547be168c0dSopenharmony_ci+/** 548be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 549be168c0dSopenharmony_ci+ * 550be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 551be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 552be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 553be168c0dSopenharmony_ci+ * 554be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 555be168c0dSopenharmony_ci+ * 556be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 557be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 558be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 559be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 560be168c0dSopenharmony_ci+ * limitations under the License. 
561be168c0dSopenharmony_ci+ */ 562be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 563be168c0dSopenharmony_ci+#include "nnacl/assembly_global.h" 564be168c0dSopenharmony_ci+ 565be168c0dSopenharmony_ci+.text 566be168c0dSopenharmony_ci+.align 5 567be168c0dSopenharmony_ci+ 568be168c0dSopenharmony_ci+// void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, 569be168c0dSopenharmony_ci+// size_t depth, size_t row, size_t col, size_t stride, size_t writeMode) 570be168c0dSopenharmony_ci+// x0: a 571be168c0dSopenharmony_ci+// x1: b 572be168c0dSopenharmony_ci+// x2: c 573be168c0dSopenharmony_ci+// x3: bias 574be168c0dSopenharmony_ci+// x4: act_type 575be168c0dSopenharmony_ci+// x5: depth 576be168c0dSopenharmony_ci+// x6: row 577be168c0dSopenharmony_ci+// x7: col 578be168c0dSopenharmony_ci+// x8: stride 579be168c0dSopenharmony_ci+// x9: writeMode 580be168c0dSopenharmony_ci+ 581be168c0dSopenharmony_ci+asm_function MatmulFp16OptV2 582be168c0dSopenharmony_ci+ sub sp, sp, #192 583be168c0dSopenharmony_ci+ st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [sp], #64 584be168c0dSopenharmony_ci+ st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [sp], #64 585be168c0dSopenharmony_ci+ stp x19, x20, [sp], #16 586be168c0dSopenharmony_ci+ stp x21, x22, [sp], #16 587be168c0dSopenharmony_ci+ stp x23, x24, [sp], #16 588be168c0dSopenharmony_ci+ stp x29, x30, [sp], #16 589be168c0dSopenharmony_ci+ 590be168c0dSopenharmony_ci+ ldr x8, [sp] 591be168c0dSopenharmony_ci+ ldr x9, [sp, #8] // writeMode 592be168c0dSopenharmony_ci+ lsl x8, x8, #1 // stride * sizeof(float16_t) 593be168c0dSopenharmony_ci+ 594be168c0dSopenharmony_ci+ lsl x15, x7, #1 // col * sizeof(float16_t) 595be168c0dSopenharmony_ci+ lsl x16, x5, #1 // depth * sizeof(float16_t) 596be168c0dSopenharmony_ci+ mov x11, x2 597be168c0dSopenharmony_ci+ movi v7.8h, #0x46, lsl #8 598be168c0dSopenharmony_ci+ subs x6, x6, #12 599be168c0dSopenharmony_ci+ blt LoopRow8 600be168c0dSopenharmony_ci+LoopRow12: 
601be168c0dSopenharmony_ci+ mov x11, x1 // reload matrixB 602be168c0dSopenharmony_ci+ mov x12, x3 // reload bias 603be168c0dSopenharmony_ci+ mov x13, x7 // reload col 604be168c0dSopenharmony_ci+ mov x21, x2 // relocate output 605be168c0dSopenharmony_ci+ subs x13, x13, #16 606be168c0dSopenharmony_ci+ blt LoopCol12x8 607be168c0dSopenharmony_ci+ LoopCol12x16: 608be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 609be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 610be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 611be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 612be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 613be168c0dSopenharmony_ci+ cbnz x12, InitFromBias12x16 614be168c0dSopenharmony_ci+ dup v8.2d, xzr 615be168c0dSopenharmony_ci+ dup v9.2d, xzr 616be168c0dSopenharmony_ci+ dup v10.2d, xzr 617be168c0dSopenharmony_ci+ dup v11.2d, xzr 618be168c0dSopenharmony_ci+ dup v12.2d, xzr 619be168c0dSopenharmony_ci+ dup v13.2d, xzr 620be168c0dSopenharmony_ci+ dup v14.2d, xzr 621be168c0dSopenharmony_ci+ dup v15.2d, xzr 622be168c0dSopenharmony_ci+ dup v16.2d, xzr 623be168c0dSopenharmony_ci+ dup v17.2d, xzr 624be168c0dSopenharmony_ci+ dup v18.2d, xzr 625be168c0dSopenharmony_ci+ dup v19.2d, xzr 626be168c0dSopenharmony_ci+ dup v20.2d, xzr 627be168c0dSopenharmony_ci+ dup v21.2d, xzr 628be168c0dSopenharmony_ci+ dup v22.2d, xzr 629be168c0dSopenharmony_ci+ dup v23.2d, xzr 630be168c0dSopenharmony_ci+ dup v24.2d, xzr 631be168c0dSopenharmony_ci+ dup v25.2d, xzr 632be168c0dSopenharmony_ci+ dup v26.2d, xzr 633be168c0dSopenharmony_ci+ dup v27.2d, xzr 634be168c0dSopenharmony_ci+ dup v28.2d, xzr 635be168c0dSopenharmony_ci+ dup v29.2d, xzr 636be168c0dSopenharmony_ci+ dup v30.2d, xzr 637be168c0dSopenharmony_ci+ dup v31.2d, xzr 638be168c0dSopenharmony_ci+ b Compute12x16Enter 639be168c0dSopenharmony_ci+ InitFromBias12x16: 640be168c0dSopenharmony_ci+ ld1 {v8.8h, v9.8h}, [x12] 641be168c0dSopenharmony_ci+ ld1 {v10.8h, v11.8h}, [x12] 642be168c0dSopenharmony_ci+ ld1 {v12.8h, v13.8h}, 
[x12] 643be168c0dSopenharmony_ci+ ld1 {v14.8h, v15.8h}, [x12] 644be168c0dSopenharmony_ci+ ld1 {v16.8h, v17.8h}, [x12] 645be168c0dSopenharmony_ci+ ld1 {v18.8h, v19.8h}, [x12] 646be168c0dSopenharmony_ci+ ld1 {v20.8h, v21.8h}, [x12] 647be168c0dSopenharmony_ci+ ld1 {v22.8h, v23.8h}, [x12] 648be168c0dSopenharmony_ci+ ld1 {v24.8h, v25.8h}, [x12] 649be168c0dSopenharmony_ci+ ld1 {v26.8h, v27.8h}, [x12] 650be168c0dSopenharmony_ci+ ld1 {v28.8h, v29.8h}, [x12] 651be168c0dSopenharmony_ci+ ld1 {v30.8h, v31.8h}, [x12] 652be168c0dSopenharmony_ci+ add x12, x12, #32 653be168c0dSopenharmony_ci+ Compute12x16Enter: 654be168c0dSopenharmony_ci+ bl Compute12x16Unit 655be168c0dSopenharmony_ci+ Activation12x16: 656be168c0dSopenharmony_ci+ cmp x4, #3 657be168c0dSopenharmony_ci+ beq Relu612x16 658be168c0dSopenharmony_ci+ cmp x4, #1 659be168c0dSopenharmony_ci+ beq Relu12x16 660be168c0dSopenharmony_ci+ b Write12x16 661be168c0dSopenharmony_ci+ 662be168c0dSopenharmony_ci+ Relu612x16: 663be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 664be168c0dSopenharmony_ci+ fmin v9.8h, v9.8h, v7.8h 665be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 666be168c0dSopenharmony_ci+ fmin v11.8h, v11.8h, v7.8h 667be168c0dSopenharmony_ci+ fmin v12.8h, v12.8h, v7.8h 668be168c0dSopenharmony_ci+ fmin v13.8h, v13.8h, v7.8h 669be168c0dSopenharmony_ci+ fmin v14.8h, v14.8h, v7.8h 670be168c0dSopenharmony_ci+ fmin v15.8h, v15.8h, v7.8h 671be168c0dSopenharmony_ci+ fmin v16.8h, v16.8h, v7.8h 672be168c0dSopenharmony_ci+ fmin v17.8h, v17.8h, v7.8h 673be168c0dSopenharmony_ci+ fmin v18.8h, v18.8h, v7.8h 674be168c0dSopenharmony_ci+ fmin v19.8h, v19.8h, v7.8h 675be168c0dSopenharmony_ci+ fmin v20.8h, v20.8h, v7.8h 676be168c0dSopenharmony_ci+ fmin v21.8h, v21.8h, v7.8h 677be168c0dSopenharmony_ci+ fmin v22.8h, v22.8h, v7.8h 678be168c0dSopenharmony_ci+ fmin v23.8h, v23.8h, v7.8h 679be168c0dSopenharmony_ci+ fmin v24.8h, v24.8h, v7.8h 680be168c0dSopenharmony_ci+ fmin v25.8h, v25.8h, v7.8h 681be168c0dSopenharmony_ci+ fmin v26.8h, 
v26.8h, v7.8h 682be168c0dSopenharmony_ci+ fmin v27.8h, v27.8h, v7.8h 683be168c0dSopenharmony_ci+ fmin v28.8h, v28.8h, v7.8h 684be168c0dSopenharmony_ci+ fmin v29.8h, v29.8h, v7.8h 685be168c0dSopenharmony_ci+ fmin v30.8h, v30.8h, v7.8h 686be168c0dSopenharmony_ci+ fmin v31.8h, v31.8h, v7.8h 687be168c0dSopenharmony_ci+ 688be168c0dSopenharmony_ci+ Relu12x16: 689be168c0dSopenharmony_ci+ dup v6.8h, wzr 690be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 691be168c0dSopenharmony_ci+ fmax v9.8h, v9.8h, v6.8h 692be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 693be168c0dSopenharmony_ci+ fmax v11.8h, v11.8h, v6.8h 694be168c0dSopenharmony_ci+ fmax v12.8h, v12.8h, v6.8h 695be168c0dSopenharmony_ci+ fmax v13.8h, v13.8h, v6.8h 696be168c0dSopenharmony_ci+ fmax v14.8h, v14.8h, v6.8h 697be168c0dSopenharmony_ci+ fmax v15.8h, v15.8h, v6.8h 698be168c0dSopenharmony_ci+ fmax v16.8h, v16.8h, v6.8h 699be168c0dSopenharmony_ci+ fmax v17.8h, v17.8h, v6.8h 700be168c0dSopenharmony_ci+ fmax v18.8h, v18.8h, v6.8h 701be168c0dSopenharmony_ci+ fmax v19.8h, v19.8h, v6.8h 702be168c0dSopenharmony_ci+ fmax v20.8h, v20.8h, v6.8h 703be168c0dSopenharmony_ci+ fmax v21.8h, v21.8h, v6.8h 704be168c0dSopenharmony_ci+ fmax v22.8h, v22.8h, v6.8h 705be168c0dSopenharmony_ci+ fmax v23.8h, v23.8h, v6.8h 706be168c0dSopenharmony_ci+ fmax v24.8h, v24.8h, v6.8h 707be168c0dSopenharmony_ci+ fmax v25.8h, v25.8h, v6.8h 708be168c0dSopenharmony_ci+ fmax v26.8h, v26.8h, v6.8h 709be168c0dSopenharmony_ci+ fmax v27.8h, v27.8h, v6.8h 710be168c0dSopenharmony_ci+ fmax v28.8h, v28.8h, v6.8h 711be168c0dSopenharmony_ci+ fmax v29.8h, v29.8h, v6.8h 712be168c0dSopenharmony_ci+ fmax v30.8h, v30.8h, v6.8h 713be168c0dSopenharmony_ci+ fmax v31.8h, v31.8h, v6.8h 714be168c0dSopenharmony_ci+ Write12x16: 715be168c0dSopenharmony_ci+ mov x22, x21 716be168c0dSopenharmony_ci+ add x23, x21, x8, lsl #2 717be168c0dSopenharmony_ci+ add x24, x21, x8, lsl #3 718be168c0dSopenharmony_ci+ st1 {v8.8h, v9.8h}, [x22], x8 719be168c0dSopenharmony_ci+ st1 
{v10.8h, v11.8h}, [x22], x8 720be168c0dSopenharmony_ci+ st1 {v12.8h, v13.8h}, [x22], x8 721be168c0dSopenharmony_ci+ st1 {v14.8h, v15.8h}, [x22] 722be168c0dSopenharmony_ci+ st1 {v16.8h, v17.8h}, [x23], x8 723be168c0dSopenharmony_ci+ st1 {v18.8h, v19.8h}, [x23], x8 724be168c0dSopenharmony_ci+ st1 {v20.8h, v21.8h}, [x23], x8 725be168c0dSopenharmony_ci+ st1 {v22.8h, v23.8h}, [x23] 726be168c0dSopenharmony_ci+ st1 {v24.8h, v25.8h}, [x24], x8 727be168c0dSopenharmony_ci+ st1 {v26.8h, v27.8h}, [x24], x8 728be168c0dSopenharmony_ci+ st1 {v28.8h, v29.8h}, [x24], x8 729be168c0dSopenharmony_ci+ st1 {v30.8h, v31.8h}, [x24] 730be168c0dSopenharmony_ci+ add x21, x21, #32 731be168c0dSopenharmony_ci+ subs x13, x13, #16 732be168c0dSopenharmony_ci+ bge LoopCol12x16 733be168c0dSopenharmony_ci+ 734be168c0dSopenharmony_ci+ LoopCol12x8: 735be168c0dSopenharmony_ci+ adds x13, x13, #16 736be168c0dSopenharmony_ci+ cbz x13, LoopRow12End 737be168c0dSopenharmony_ci+ subs x13, x13, #8 738be168c0dSopenharmony_ci+ blt LoopCol12x4 739be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 740be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 741be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 742be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 743be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 744be168c0dSopenharmony_ci+ cbnz x12, InitFromBias12x8 745be168c0dSopenharmony_ci+ dup v8.2d, xzr 746be168c0dSopenharmony_ci+ dup v10.2d, xzr 747be168c0dSopenharmony_ci+ dup v12.2d, xzr 748be168c0dSopenharmony_ci+ dup v14.2d, xzr 749be168c0dSopenharmony_ci+ dup v16.2d, xzr 750be168c0dSopenharmony_ci+ dup v18.2d, xzr 751be168c0dSopenharmony_ci+ dup v20.2d, xzr 752be168c0dSopenharmony_ci+ dup v22.2d, xzr 753be168c0dSopenharmony_ci+ dup v24.2d, xzr 754be168c0dSopenharmony_ci+ dup v26.2d, xzr 755be168c0dSopenharmony_ci+ dup v28.2d, xzr 756be168c0dSopenharmony_ci+ dup v30.2d, xzr 757be168c0dSopenharmony_ci+ b Compute12x8Enter 758be168c0dSopenharmony_ci+ InitFromBias12x8: 759be168c0dSopenharmony_ci+ ld1 {v8.8h}, 
[x12] 760be168c0dSopenharmony_ci+ ld1 {v10.8h}, [x12] 761be168c0dSopenharmony_ci+ ld1 {v12.8h}, [x12] 762be168c0dSopenharmony_ci+ ld1 {v14.8h}, [x12] 763be168c0dSopenharmony_ci+ ld1 {v16.8h}, [x12] 764be168c0dSopenharmony_ci+ ld1 {v18.8h}, [x12] 765be168c0dSopenharmony_ci+ ld1 {v20.8h}, [x12] 766be168c0dSopenharmony_ci+ ld1 {v22.8h}, [x12] 767be168c0dSopenharmony_ci+ ld1 {v24.8h}, [x12] 768be168c0dSopenharmony_ci+ ld1 {v26.8h}, [x12] 769be168c0dSopenharmony_ci+ ld1 {v28.8h}, [x12] 770be168c0dSopenharmony_ci+ ld1 {v30.8h}, [x12] 771be168c0dSopenharmony_ci+ add x12, x12, #16 772be168c0dSopenharmony_ci+ Compute12x8Enter: 773be168c0dSopenharmony_ci+ bl Compute12x8Unit 774be168c0dSopenharmony_ci+ Activation12x8: 775be168c0dSopenharmony_ci+ cmp x4, #3 776be168c0dSopenharmony_ci+ beq Relu612x8 777be168c0dSopenharmony_ci+ cmp x4, #1 778be168c0dSopenharmony_ci+ beq Relu12x8 779be168c0dSopenharmony_ci+ b Write12x8 780be168c0dSopenharmony_ci+ 781be168c0dSopenharmony_ci+ Relu612x8: 782be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 783be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 784be168c0dSopenharmony_ci+ fmin v12.8h, v12.8h, v7.8h 785be168c0dSopenharmony_ci+ fmin v14.8h, v14.8h, v7.8h 786be168c0dSopenharmony_ci+ fmin v16.8h, v16.8h, v7.8h 787be168c0dSopenharmony_ci+ fmin v18.8h, v18.8h, v7.8h 788be168c0dSopenharmony_ci+ fmin v20.8h, v20.8h, v7.8h 789be168c0dSopenharmony_ci+ fmin v22.8h, v22.8h, v7.8h 790be168c0dSopenharmony_ci+ fmin v24.8h, v24.8h, v7.8h 791be168c0dSopenharmony_ci+ fmin v26.8h, v26.8h, v7.8h 792be168c0dSopenharmony_ci+ fmin v28.8h, v28.8h, v7.8h 793be168c0dSopenharmony_ci+ fmin v30.8h, v30.8h, v7.8h 794be168c0dSopenharmony_ci+ 795be168c0dSopenharmony_ci+ Relu12x8: 796be168c0dSopenharmony_ci+ dup v6.8h, wzr 797be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 798be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 799be168c0dSopenharmony_ci+ fmax v12.8h, v12.8h, v6.8h 800be168c0dSopenharmony_ci+ fmax v14.8h, v14.8h, v6.8h 801be168c0dSopenharmony_ci+ 
fmax v16.8h, v16.8h, v6.8h 802be168c0dSopenharmony_ci+ fmax v18.8h, v18.8h, v6.8h 803be168c0dSopenharmony_ci+ fmax v20.8h, v20.8h, v6.8h 804be168c0dSopenharmony_ci+ fmax v22.8h, v22.8h, v6.8h 805be168c0dSopenharmony_ci+ fmax v24.8h, v24.8h, v6.8h 806be168c0dSopenharmony_ci+ fmax v26.8h, v26.8h, v6.8h 807be168c0dSopenharmony_ci+ fmax v28.8h, v28.8h, v6.8h 808be168c0dSopenharmony_ci+ fmax v30.8h, v30.8h, v6.8h 809be168c0dSopenharmony_ci+ Write12x8: 810be168c0dSopenharmony_ci+ mov x22, x21 811be168c0dSopenharmony_ci+ add x23, x21, x8, lsl #2 812be168c0dSopenharmony_ci+ add x24, x21, x8, lsl #3 813be168c0dSopenharmony_ci+ st1 {v8.8h}, [x22], x8 814be168c0dSopenharmony_ci+ st1 {v10.8h}, [x22], x8 815be168c0dSopenharmony_ci+ st1 {v12.8h}, [x22], x8 816be168c0dSopenharmony_ci+ st1 {v14.8h}, [x22] 817be168c0dSopenharmony_ci+ st1 {v16.8h}, [x23], x8 818be168c0dSopenharmony_ci+ st1 {v18.8h}, [x23], x8 819be168c0dSopenharmony_ci+ st1 {v20.8h}, [x23], x8 820be168c0dSopenharmony_ci+ st1 {v22.8h}, [x23] 821be168c0dSopenharmony_ci+ st1 {v24.8h}, [x24], x8 822be168c0dSopenharmony_ci+ st1 {v26.8h}, [x24], x8 823be168c0dSopenharmony_ci+ st1 {v28.8h}, [x24], x8 824be168c0dSopenharmony_ci+ st1 {v30.8h}, [x24] 825be168c0dSopenharmony_ci+ add x21, x21, #16 826be168c0dSopenharmony_ci+ subs x13, x13, #8 827be168c0dSopenharmony_ci+ 828be168c0dSopenharmony_ci+ LoopCol12x4: 829be168c0dSopenharmony_ci+ adds x13, x13, #8 830be168c0dSopenharmony_ci+ cbz x13, LoopRow12End 831be168c0dSopenharmony_ci+ LoopCol12x4Core: 832be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 833be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 834be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 835be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 836be168c0dSopenharmony_ci+ ld1 {v3.4h}, [x11], #8 837be168c0dSopenharmony_ci+ cbnz x12, InitFromBias12x4 838be168c0dSopenharmony_ci+ dup v8.2s, wzr 839be168c0dSopenharmony_ci+ dup v10.2s, wzr 840be168c0dSopenharmony_ci+ dup v12.2s, wzr 841be168c0dSopenharmony_ci+ 
dup v14.2s, wzr 842be168c0dSopenharmony_ci+ dup v16.2s, wzr 843be168c0dSopenharmony_ci+ dup v18.2s, wzr 844be168c0dSopenharmony_ci+ dup v20.2s, wzr 845be168c0dSopenharmony_ci+ dup v22.2s, wzr 846be168c0dSopenharmony_ci+ dup v24.2s, wzr 847be168c0dSopenharmony_ci+ dup v26.2s, wzr 848be168c0dSopenharmony_ci+ dup v28.2s, wzr 849be168c0dSopenharmony_ci+ dup v30.2s, wzr 850be168c0dSopenharmony_ci+ b Compute12x4Enter 851be168c0dSopenharmony_ci+ InitFromBias12x4: 852be168c0dSopenharmony_ci+ ld1 {v8.4h}, [x12] 853be168c0dSopenharmony_ci+ ld1 {v10.4h}, [x12] 854be168c0dSopenharmony_ci+ ld1 {v12.4h}, [x12] 855be168c0dSopenharmony_ci+ ld1 {v14.4h}, [x12] 856be168c0dSopenharmony_ci+ ld1 {v16.4h}, [x12] 857be168c0dSopenharmony_ci+ ld1 {v18.4h}, [x12] 858be168c0dSopenharmony_ci+ ld1 {v20.4h}, [x12] 859be168c0dSopenharmony_ci+ ld1 {v22.4h}, [x12] 860be168c0dSopenharmony_ci+ ld1 {v24.4h}, [x12] 861be168c0dSopenharmony_ci+ ld1 {v26.4h}, [x12] 862be168c0dSopenharmony_ci+ ld1 {v28.4h}, [x12] 863be168c0dSopenharmony_ci+ ld1 {v30.4h}, [x12] 864be168c0dSopenharmony_ci+ add x12, x12, #8 865be168c0dSopenharmony_ci+ Compute12x4Enter: 866be168c0dSopenharmony_ci+ bl Compute12x4Unit 867be168c0dSopenharmony_ci+ Activation12x4: 868be168c0dSopenharmony_ci+ cmp x4, #3 869be168c0dSopenharmony_ci+ beq Relu612x4 870be168c0dSopenharmony_ci+ cmp x4, #1 871be168c0dSopenharmony_ci+ beq Relu12x4 872be168c0dSopenharmony_ci+ b Write12x4 873be168c0dSopenharmony_ci+ 874be168c0dSopenharmony_ci+ Relu612x4: 875be168c0dSopenharmony_ci+ fmin v8.4h, v8.4h, v7.4h 876be168c0dSopenharmony_ci+ fmin v10.4h, v10.4h, v7.4h 877be168c0dSopenharmony_ci+ fmin v12.4h, v12.4h, v7.4h 878be168c0dSopenharmony_ci+ fmin v14.4h, v14.4h, v7.4h 879be168c0dSopenharmony_ci+ fmin v16.4h, v16.4h, v7.4h 880be168c0dSopenharmony_ci+ fmin v18.4h, v18.4h, v7.4h 881be168c0dSopenharmony_ci+ fmin v20.4h, v20.4h, v7.4h 882be168c0dSopenharmony_ci+ fmin v22.4h, v22.4h, v7.4h 883be168c0dSopenharmony_ci+ fmin v24.4h, v24.4h, v7.4h 
884be168c0dSopenharmony_ci+ fmin v26.4h, v26.4h, v7.4h 885be168c0dSopenharmony_ci+ fmin v28.4h, v28.4h, v7.4h 886be168c0dSopenharmony_ci+ fmin v30.4h, v30.4h, v7.4h 887be168c0dSopenharmony_ci+ 888be168c0dSopenharmony_ci+ Relu12x4: 889be168c0dSopenharmony_ci+ dup v6.4h, wzr 890be168c0dSopenharmony_ci+ fmax v8.4h, v8.4h, v6.4h 891be168c0dSopenharmony_ci+ fmax v10.4h, v10.4h, v6.4h 892be168c0dSopenharmony_ci+ fmax v12.4h, v12.4h, v6.4h 893be168c0dSopenharmony_ci+ fmax v14.4h, v14.4h, v6.4h 894be168c0dSopenharmony_ci+ fmax v16.4h, v16.4h, v6.4h 895be168c0dSopenharmony_ci+ fmax v18.4h, v18.4h, v6.4h 896be168c0dSopenharmony_ci+ fmax v20.4h, v20.4h, v6.4h 897be168c0dSopenharmony_ci+ fmax v22.4h, v22.4h, v6.4h 898be168c0dSopenharmony_ci+ fmax v24.4h, v24.4h, v6.4h 899be168c0dSopenharmony_ci+ fmax v26.4h, v26.4h, v6.4h 900be168c0dSopenharmony_ci+ fmax v28.4h, v28.4h, v6.4h 901be168c0dSopenharmony_ci+ fmax v30.4h, v30.4h, v6.4h 902be168c0dSopenharmony_ci+ Write12x4: 903be168c0dSopenharmony_ci+ mov x22, x21 904be168c0dSopenharmony_ci+ add x23, x21, x8, lsl #2 905be168c0dSopenharmony_ci+ add x24, x21, x8, lsl #3 906be168c0dSopenharmony_ci+ cmp x13, #1 907be168c0dSopenharmony_ci+ beq Write12x1 908be168c0dSopenharmony_ci+ cmp x13, #2 909be168c0dSopenharmony_ci+ beq Write12x2 910be168c0dSopenharmony_ci+ cmp x13, #3 911be168c0dSopenharmony_ci+ beq Write12x3 912be168c0dSopenharmony_ci+ st1 {v8.4h}, [x22], x8 913be168c0dSopenharmony_ci+ st1 {v10.4h}, [x22], x8 914be168c0dSopenharmony_ci+ st1 {v12.4h}, [x22], x8 915be168c0dSopenharmony_ci+ st1 {v14.4h}, [x22] 916be168c0dSopenharmony_ci+ st1 {v16.4h}, [x23], x8 917be168c0dSopenharmony_ci+ st1 {v18.4h}, [x23], x8 918be168c0dSopenharmony_ci+ st1 {v20.4h}, [x23], x8 919be168c0dSopenharmony_ci+ st1 {v22.4h}, [x23] 920be168c0dSopenharmony_ci+ st1 {v24.4h}, [x24], x8 921be168c0dSopenharmony_ci+ st1 {v26.4h}, [x24], x8 922be168c0dSopenharmony_ci+ st1 {v28.4h}, [x24], x8 923be168c0dSopenharmony_ci+ st1 {v30.4h}, [x24] 
924be168c0dSopenharmony_ci+ add x21, x21, #8 925be168c0dSopenharmony_ci+ subs x13, x13, #4 926be168c0dSopenharmony_ci+ bgt LoopCol12x4Core 927be168c0dSopenharmony_ci+ b LoopRow12End 928be168c0dSopenharmony_ci+ Write12x1: 929be168c0dSopenharmony_ci+ st1 {v8.h}[0], [x22], x8 930be168c0dSopenharmony_ci+ st1 {v10.h}[0], [x22], x8 931be168c0dSopenharmony_ci+ st1 {v12.h}[0], [x22], x8 932be168c0dSopenharmony_ci+ st1 {v14.h}[0], [x22] 933be168c0dSopenharmony_ci+ st1 {v16.h}[0], [x23], x8 934be168c0dSopenharmony_ci+ st1 {v18.h}[0], [x23], x8 935be168c0dSopenharmony_ci+ st1 {v20.h}[0], [x23], x8 936be168c0dSopenharmony_ci+ st1 {v22.h}[0], [x23] 937be168c0dSopenharmony_ci+ st1 {v24.h}[0], [x24], x8 938be168c0dSopenharmony_ci+ st1 {v26.h}[0], [x24], x8 939be168c0dSopenharmony_ci+ st1 {v28.h}[0], [x24], x8 940be168c0dSopenharmony_ci+ st1 {v30.h}[0], [x24] 941be168c0dSopenharmony_ci+ b LoopRow12End 942be168c0dSopenharmony_ci+ Write12x2: 943be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 944be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 945be168c0dSopenharmony_ci+ st1 {v12.s}[0], [x22], x8 946be168c0dSopenharmony_ci+ st1 {v14.s}[0], [x22] 947be168c0dSopenharmony_ci+ st1 {v16.s}[0], [x23], x8 948be168c0dSopenharmony_ci+ st1 {v18.s}[0], [x23], x8 949be168c0dSopenharmony_ci+ st1 {v20.s}[0], [x23], x8 950be168c0dSopenharmony_ci+ st1 {v22.s}[0], [x23] 951be168c0dSopenharmony_ci+ st1 {v24.s}[0], [x24], x8 952be168c0dSopenharmony_ci+ st1 {v26.s}[0], [x24], x8 953be168c0dSopenharmony_ci+ st1 {v28.s}[0], [x24], x8 954be168c0dSopenharmony_ci+ st1 {v30.s}[0], [x24] 955be168c0dSopenharmony_ci+ b LoopRow12End 956be168c0dSopenharmony_ci+ Write12x3: 957be168c0dSopenharmony_ci+ add x23, x22, #4 958be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 959be168c0dSopenharmony_ci+ st1 {v8.h}[2], [x23], x8 960be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 961be168c0dSopenharmony_ci+ st1 {v10.h}[2], [x23], x8 962be168c0dSopenharmony_ci+ st1 {v12.s}[0], [x22], x8 963be168c0dSopenharmony_ci+ st1 
{v12.h}[2], [x23], x8 964be168c0dSopenharmony_ci+ st1 {v14.s}[0], [x22], x8 965be168c0dSopenharmony_ci+ st1 {v14.h}[2], [x23], x8 966be168c0dSopenharmony_ci+ st1 {v16.s}[0], [x22], x8 967be168c0dSopenharmony_ci+ st1 {v16.h}[2], [x23], x8 968be168c0dSopenharmony_ci+ st1 {v18.s}[0], [x22], x8 969be168c0dSopenharmony_ci+ st1 {v18.h}[2], [x23], x8 970be168c0dSopenharmony_ci+ st1 {v20.s}[0], [x22], x8 971be168c0dSopenharmony_ci+ st1 {v20.h}[2], [x23], x8 972be168c0dSopenharmony_ci+ st1 {v22.s}[0], [x22], x8 973be168c0dSopenharmony_ci+ st1 {v22.h}[2], [x23], x8 974be168c0dSopenharmony_ci+ st1 {v24.s}[0], [x22], x8 975be168c0dSopenharmony_ci+ st1 {v24.h}[2], [x23], x8 976be168c0dSopenharmony_ci+ st1 {v26.s}[0], [x22], x8 977be168c0dSopenharmony_ci+ st1 {v26.h}[2], [x23], x8 978be168c0dSopenharmony_ci+ st1 {v28.s}[0], [x22], x8 979be168c0dSopenharmony_ci+ st1 {v28.h}[2], [x23], x8 980be168c0dSopenharmony_ci+ st1 {v30.s}[0], [x22] 981be168c0dSopenharmony_ci+ st1 {v30.h}[2], [x23] 982be168c0dSopenharmony_ci+ LoopRow12End: 983be168c0dSopenharmony_ci+ add x0, x0, x16, lsl #3 984be168c0dSopenharmony_ci+ add x0, x0, x16, lsl #2 985be168c0dSopenharmony_ci+ add x2, x2, x8, lsl #3 986be168c0dSopenharmony_ci+ add x2, x2, x8, lsl #2 987be168c0dSopenharmony_ci+ subs x6, x6, #12 988be168c0dSopenharmony_ci+ bge LoopRow12 989be168c0dSopenharmony_ci+ 990be168c0dSopenharmony_ci+LoopRow8: 991be168c0dSopenharmony_ci+ adds x6, x6,#12 992be168c0dSopenharmony_ci+ cbz x6, End 993be168c0dSopenharmony_ci+ subs x6, x6, #8 994be168c0dSopenharmony_ci+ blt LoopRow4 995be168c0dSopenharmony_ci+ mov x11, x1 // reload matrixB 996be168c0dSopenharmony_ci+ mov x12, x3 // reload bias 997be168c0dSopenharmony_ci+ mov x13, x7 // reload col 998be168c0dSopenharmony_ci+ mov x21, x2 // relocate output 999be168c0dSopenharmony_ci+ subs x13, x13, #16 1000be168c0dSopenharmony_ci+ blt LoopCol8x8 1001be168c0dSopenharmony_ci+ LoopCol8x16: 1002be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 
1003be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 1004be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1005be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 1006be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 1007be168c0dSopenharmony_ci+ cbnz x12, InitFromBias8x16 1008be168c0dSopenharmony_ci+ dup v8.2d, xzr 1009be168c0dSopenharmony_ci+ dup v9.2d, xzr 1010be168c0dSopenharmony_ci+ dup v10.2d, xzr 1011be168c0dSopenharmony_ci+ dup v11.2d, xzr 1012be168c0dSopenharmony_ci+ dup v12.2d, xzr 1013be168c0dSopenharmony_ci+ dup v13.2d, xzr 1014be168c0dSopenharmony_ci+ dup v14.2d, xzr 1015be168c0dSopenharmony_ci+ dup v15.2d, xzr 1016be168c0dSopenharmony_ci+ dup v16.2d, xzr 1017be168c0dSopenharmony_ci+ dup v17.2d, xzr 1018be168c0dSopenharmony_ci+ dup v18.2d, xzr 1019be168c0dSopenharmony_ci+ dup v19.2d, xzr 1020be168c0dSopenharmony_ci+ dup v20.2d, xzr 1021be168c0dSopenharmony_ci+ dup v21.2d, xzr 1022be168c0dSopenharmony_ci+ dup v22.2d, xzr 1023be168c0dSopenharmony_ci+ dup v23.2d, xzr 1024be168c0dSopenharmony_ci+ b Compute8x16Enter 1025be168c0dSopenharmony_ci+ InitFromBias8x16: 1026be168c0dSopenharmony_ci+ ld1 {v8.8h, v9.8h}, [x12] 1027be168c0dSopenharmony_ci+ ld1 {v10.8h, v11.8h}, [x12] 1028be168c0dSopenharmony_ci+ ld1 {v12.8h, v13.8h}, [x12] 1029be168c0dSopenharmony_ci+ ld1 {v14.8h, v15.8h}, [x12] 1030be168c0dSopenharmony_ci+ ld1 {v16.8h, v17.8h}, [x12] 1031be168c0dSopenharmony_ci+ ld1 {v18.8h, v19.8h}, [x12] 1032be168c0dSopenharmony_ci+ ld1 {v20.8h, v21.8h}, [x12] 1033be168c0dSopenharmony_ci+ ld1 {v22.8h, v23.8h}, [x12] 1034be168c0dSopenharmony_ci+ add x12, x12, #32 1035be168c0dSopenharmony_ci+ Compute8x16Enter: 1036be168c0dSopenharmony_ci+ bl Compute8x16Unit 1037be168c0dSopenharmony_ci+ Activation8x16: 1038be168c0dSopenharmony_ci+ cmp x4, #3 1039be168c0dSopenharmony_ci+ beq Relu68x16 1040be168c0dSopenharmony_ci+ cmp x4, #1 1041be168c0dSopenharmony_ci+ beq Relu8x16 1042be168c0dSopenharmony_ci+ b Write8x16 1043be168c0dSopenharmony_ci+ 1044be168c0dSopenharmony_ci+ 
Relu68x16: 1045be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1046be168c0dSopenharmony_ci+ fmin v9.8h, v9.8h, v7.8h 1047be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 1048be168c0dSopenharmony_ci+ fmin v11.8h, v11.8h, v7.8h 1049be168c0dSopenharmony_ci+ fmin v12.8h, v12.8h, v7.8h 1050be168c0dSopenharmony_ci+ fmin v13.8h, v13.8h, v7.8h 1051be168c0dSopenharmony_ci+ fmin v14.8h, v14.8h, v7.8h 1052be168c0dSopenharmony_ci+ fmin v15.8h, v15.8h, v7.8h 1053be168c0dSopenharmony_ci+ fmin v16.8h, v16.8h, v7.8h 1054be168c0dSopenharmony_ci+ fmin v17.8h, v17.8h, v7.8h 1055be168c0dSopenharmony_ci+ fmin v18.8h, v18.8h, v7.8h 1056be168c0dSopenharmony_ci+ fmin v19.8h, v19.8h, v7.8h 1057be168c0dSopenharmony_ci+ fmin v20.8h, v20.8h, v7.8h 1058be168c0dSopenharmony_ci+ fmin v21.8h, v21.8h, v7.8h 1059be168c0dSopenharmony_ci+ fmin v22.8h, v22.8h, v7.8h 1060be168c0dSopenharmony_ci+ fmin v23.8h, v23.8h, v7.8h 1061be168c0dSopenharmony_ci+ 1062be168c0dSopenharmony_ci+ Relu8x16: 1063be168c0dSopenharmony_ci+ dup v6.8h, wzr 1064be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1065be168c0dSopenharmony_ci+ fmax v9.8h, v9.8h, v6.8h 1066be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 1067be168c0dSopenharmony_ci+ fmax v11.8h, v11.8h, v6.8h 1068be168c0dSopenharmony_ci+ fmax v12.8h, v12.8h, v6.8h 1069be168c0dSopenharmony_ci+ fmax v13.8h, v13.8h, v6.8h 1070be168c0dSopenharmony_ci+ fmax v14.8h, v14.8h, v6.8h 1071be168c0dSopenharmony_ci+ fmax v15.8h, v15.8h, v6.8h 1072be168c0dSopenharmony_ci+ fmax v16.8h, v16.8h, v6.8h 1073be168c0dSopenharmony_ci+ fmax v17.8h, v17.8h, v6.8h 1074be168c0dSopenharmony_ci+ fmax v18.8h, v18.8h, v6.8h 1075be168c0dSopenharmony_ci+ fmax v19.8h, v19.8h, v6.8h 1076be168c0dSopenharmony_ci+ fmax v20.8h, v20.8h, v6.8h 1077be168c0dSopenharmony_ci+ fmax v21.8h, v21.8h, v6.8h 1078be168c0dSopenharmony_ci+ fmax v22.8h, v22.8h, v6.8h 1079be168c0dSopenharmony_ci+ fmax v23.8h, v23.8h, v6.8h 1080be168c0dSopenharmony_ci+ Write8x16: 1081be168c0dSopenharmony_ci+ mov x22, x21 
1082be168c0dSopenharmony_ci+ add x23, x21, x8, lsl #2 1083be168c0dSopenharmony_ci+ st1 {v8.8h, v9.8h}, [x22], x8 1084be168c0dSopenharmony_ci+ st1 {v10.8h, v11.8h}, [x22], x8 1085be168c0dSopenharmony_ci+ st1 {v12.8h, v13.8h}, [x22], x8 1086be168c0dSopenharmony_ci+ st1 {v14.8h, v15.8h}, [x22] 1087be168c0dSopenharmony_ci+ st1 {v16.8h, v17.8h}, [x23], x8 1088be168c0dSopenharmony_ci+ st1 {v18.8h, v19.8h}, [x23], x8 1089be168c0dSopenharmony_ci+ st1 {v20.8h, v21.8h}, [x23], x8 1090be168c0dSopenharmony_ci+ st1 {v22.8h, v23.8h}, [x23] 1091be168c0dSopenharmony_ci+ add x21, x21, #32 1092be168c0dSopenharmony_ci+ subs x13, x13, #16 1093be168c0dSopenharmony_ci+ bge LoopCol8x16 1094be168c0dSopenharmony_ci+ 1095be168c0dSopenharmony_ci+ LoopCol8x8: 1096be168c0dSopenharmony_ci+ adds x13, x13, #16 1097be168c0dSopenharmony_ci+ cbz x13, LoopRow8End 1098be168c0dSopenharmony_ci+ subs x13, x13, #8 1099be168c0dSopenharmony_ci+ blt LoopCol8x4 1100be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1101be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 1102be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1103be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 1104be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 1105be168c0dSopenharmony_ci+ cbnz x12, InitFromBias8x8 1106be168c0dSopenharmony_ci+ dup v8.2d, xzr 1107be168c0dSopenharmony_ci+ dup v10.2d, xzr 1108be168c0dSopenharmony_ci+ dup v12.2d, xzr 1109be168c0dSopenharmony_ci+ dup v14.2d, xzr 1110be168c0dSopenharmony_ci+ dup v16.2d, xzr 1111be168c0dSopenharmony_ci+ dup v18.2d, xzr 1112be168c0dSopenharmony_ci+ dup v20.2d, xzr 1113be168c0dSopenharmony_ci+ dup v22.2d, xzr 1114be168c0dSopenharmony_ci+ b Compute8x8Enter 1115be168c0dSopenharmony_ci+ InitFromBias8x8: 1116be168c0dSopenharmony_ci+ ld1 {v8.8h}, [x12] 1117be168c0dSopenharmony_ci+ ld1 {v10.8h}, [x12] 1118be168c0dSopenharmony_ci+ ld1 {v12.8h}, [x12] 1119be168c0dSopenharmony_ci+ ld1 {v14.8h}, [x12] 1120be168c0dSopenharmony_ci+ ld1 {v16.8h}, [x12] 1121be168c0dSopenharmony_ci+ ld1 
{v18.8h}, [x12] 1122be168c0dSopenharmony_ci+ ld1 {v20.8h}, [x12] 1123be168c0dSopenharmony_ci+ ld1 {v22.8h}, [x12] 1124be168c0dSopenharmony_ci+ add x12, x12, #16 1125be168c0dSopenharmony_ci+ Compute8x8Enter: 1126be168c0dSopenharmony_ci+ bl Compute8x8Unit 1127be168c0dSopenharmony_ci+ Activation8x8: 1128be168c0dSopenharmony_ci+ cmp x4, #3 1129be168c0dSopenharmony_ci+ beq Relu68x8 1130be168c0dSopenharmony_ci+ cmp x4, #1 1131be168c0dSopenharmony_ci+ beq Relu8x8 1132be168c0dSopenharmony_ci+ b Write8x8 1133be168c0dSopenharmony_ci+ 1134be168c0dSopenharmony_ci+ Relu68x8: 1135be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1136be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 1137be168c0dSopenharmony_ci+ fmin v12.8h, v12.8h, v7.8h 1138be168c0dSopenharmony_ci+ fmin v14.8h, v14.8h, v7.8h 1139be168c0dSopenharmony_ci+ fmin v16.8h, v16.8h, v7.8h 1140be168c0dSopenharmony_ci+ fmin v18.8h, v18.8h, v7.8h 1141be168c0dSopenharmony_ci+ fmin v20.8h, v20.8h, v7.8h 1142be168c0dSopenharmony_ci+ fmin v22.8h, v22.8h, v7.8h 1143be168c0dSopenharmony_ci+ 1144be168c0dSopenharmony_ci+ Relu8x8: 1145be168c0dSopenharmony_ci+ dup v6.8h, wzr 1146be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1147be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 1148be168c0dSopenharmony_ci+ fmax v12.8h, v12.8h, v6.8h 1149be168c0dSopenharmony_ci+ fmax v14.8h, v14.8h, v6.8h 1150be168c0dSopenharmony_ci+ fmax v16.8h, v16.8h, v6.8h 1151be168c0dSopenharmony_ci+ fmax v18.8h, v18.8h, v6.8h 1152be168c0dSopenharmony_ci+ fmax v20.8h, v20.8h, v6.8h 1153be168c0dSopenharmony_ci+ fmax v22.8h, v22.8h, v6.8h 1154be168c0dSopenharmony_ci+ Write8x8: 1155be168c0dSopenharmony_ci+ mov x22, x21 1156be168c0dSopenharmony_ci+ add x23, x21, x8, lsl #2 1157be168c0dSopenharmony_ci+ st1 {v8.8h}, [x22], x8 1158be168c0dSopenharmony_ci+ st1 {v10.8h}, [x22], x8 1159be168c0dSopenharmony_ci+ st1 {v12.8h}, [x22], x8 1160be168c0dSopenharmony_ci+ st1 {v14.8h}, [x22] 1161be168c0dSopenharmony_ci+ st1 {v16.8h}, [x23], x8 1162be168c0dSopenharmony_ci+ st1 
{v18.8h}, [x23], x8 1163be168c0dSopenharmony_ci+ st1 {v20.8h}, [x23], x8 1164be168c0dSopenharmony_ci+ st1 {v22.8h}, [x23] 1165be168c0dSopenharmony_ci+ add x21, x21, #16 1166be168c0dSopenharmony_ci+ subs x13, x13, #8 1167be168c0dSopenharmony_ci+ 1168be168c0dSopenharmony_ci+ LoopCol8x4: 1169be168c0dSopenharmony_ci+ adds x13, x13, #8 1170be168c0dSopenharmony_ci+ cbz x13, LoopRow8End 1171be168c0dSopenharmony_ci+ LoopCol8x4Core: 1172be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1173be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 1174be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1175be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 1176be168c0dSopenharmony_ci+ ld1 {v3.4h}, [x11], #8 1177be168c0dSopenharmony_ci+ cbnz x12, InitFromBias8x4 1178be168c0dSopenharmony_ci+ dup v8.2s, wzr 1179be168c0dSopenharmony_ci+ dup v10.2s, wzr 1180be168c0dSopenharmony_ci+ dup v12.2s, wzr 1181be168c0dSopenharmony_ci+ dup v14.2s, wzr 1182be168c0dSopenharmony_ci+ dup v16.2s, wzr 1183be168c0dSopenharmony_ci+ dup v18.2s, wzr 1184be168c0dSopenharmony_ci+ dup v20.2s, wzr 1185be168c0dSopenharmony_ci+ dup v22.2s, wzr 1186be168c0dSopenharmony_ci+ b Compute8x4Enter 1187be168c0dSopenharmony_ci+ InitFromBias8x4: 1188be168c0dSopenharmony_ci+ ld1 {v8.4h}, [x12] 1189be168c0dSopenharmony_ci+ ld1 {v10.4h}, [x12] 1190be168c0dSopenharmony_ci+ ld1 {v12.4h}, [x12] 1191be168c0dSopenharmony_ci+ ld1 {v14.4h}, [x12] 1192be168c0dSopenharmony_ci+ ld1 {v16.4h}, [x12] 1193be168c0dSopenharmony_ci+ ld1 {v18.4h}, [x12] 1194be168c0dSopenharmony_ci+ ld1 {v20.4h}, [x12] 1195be168c0dSopenharmony_ci+ ld1 {v22.4h}, [x12] 1196be168c0dSopenharmony_ci+ add x12, x12, #8 1197be168c0dSopenharmony_ci+ Compute8x4Enter: 1198be168c0dSopenharmony_ci+ bl Compute8x4Unit 1199be168c0dSopenharmony_ci+ Activation8x4: 1200be168c0dSopenharmony_ci+ cmp x4, #3 1201be168c0dSopenharmony_ci+ beq Relu68x4 1202be168c0dSopenharmony_ci+ cmp x4, #1 1203be168c0dSopenharmony_ci+ beq Relu8x4 1204be168c0dSopenharmony_ci+ b Write8x4 
1205be168c0dSopenharmony_ci+ 1206be168c0dSopenharmony_ci+ Relu68x4: 1207be168c0dSopenharmony_ci+ fmin v8.4h, v8.4h, v7.4h 1208be168c0dSopenharmony_ci+ fmin v10.4h, v10.4h, v7.4h 1209be168c0dSopenharmony_ci+ fmin v12.4h, v12.4h, v7.4h 1210be168c0dSopenharmony_ci+ fmin v14.4h, v14.4h, v7.4h 1211be168c0dSopenharmony_ci+ fmin v16.4h, v16.4h, v7.4h 1212be168c0dSopenharmony_ci+ fmin v18.4h, v18.4h, v7.4h 1213be168c0dSopenharmony_ci+ fmin v20.4h, v20.4h, v7.4h 1214be168c0dSopenharmony_ci+ fmin v22.4h, v22.4h, v7.4h 1215be168c0dSopenharmony_ci+ 1216be168c0dSopenharmony_ci+ Relu8x4: 1217be168c0dSopenharmony_ci+ dup v6.4h, wzr 1218be168c0dSopenharmony_ci+ fmax v8.4h, v8.4h, v6.4h 1219be168c0dSopenharmony_ci+ fmax v10.4h, v10.4h, v6.4h 1220be168c0dSopenharmony_ci+ fmax v12.4h, v12.4h, v6.4h 1221be168c0dSopenharmony_ci+ fmax v14.4h, v14.4h, v6.4h 1222be168c0dSopenharmony_ci+ fmax v16.4h, v16.4h, v6.4h 1223be168c0dSopenharmony_ci+ fmax v18.4h, v18.4h, v6.4h 1224be168c0dSopenharmony_ci+ fmax v20.4h, v20.4h, v6.4h 1225be168c0dSopenharmony_ci+ fmax v22.4h, v22.4h, v6.4h 1226be168c0dSopenharmony_ci+ Write8x4: 1227be168c0dSopenharmony_ci+ mov x22, x21 1228be168c0dSopenharmony_ci+ add x23, x21, x8, lsl #2 1229be168c0dSopenharmony_ci+ cmp x13, #1 1230be168c0dSopenharmony_ci+ beq Write8x1 1231be168c0dSopenharmony_ci+ cmp x13, #2 1232be168c0dSopenharmony_ci+ beq Write8x2 1233be168c0dSopenharmony_ci+ cmp x13, #3 1234be168c0dSopenharmony_ci+ beq Write8x3 1235be168c0dSopenharmony_ci+ st1 {v8.4h}, [x22], x8 1236be168c0dSopenharmony_ci+ st1 {v10.4h}, [x22], x8 1237be168c0dSopenharmony_ci+ st1 {v12.4h}, [x22], x8 1238be168c0dSopenharmony_ci+ st1 {v14.4h}, [x22] 1239be168c0dSopenharmony_ci+ st1 {v16.4h}, [x23], x8 1240be168c0dSopenharmony_ci+ st1 {v18.4h}, [x23], x8 1241be168c0dSopenharmony_ci+ st1 {v20.4h}, [x23], x8 1242be168c0dSopenharmony_ci+ st1 {v22.4h}, [x23] 1243be168c0dSopenharmony_ci+ add x21, x21, #8 1244be168c0dSopenharmony_ci+ subs x13, x13, #4 1245be168c0dSopenharmony_ci+ bgt 
LoopCol8x4Core 1246be168c0dSopenharmony_ci+ b LoopRow8End 1247be168c0dSopenharmony_ci+ Write8x1: 1248be168c0dSopenharmony_ci+ st1 {v8.h}[0], [x22], x8 1249be168c0dSopenharmony_ci+ st1 {v10.h}[0], [x22], x8 1250be168c0dSopenharmony_ci+ st1 {v12.h}[0], [x22], x8 1251be168c0dSopenharmony_ci+ st1 {v14.h}[0], [x22] 1252be168c0dSopenharmony_ci+ st1 {v16.h}[0], [x23], x8 1253be168c0dSopenharmony_ci+ st1 {v18.h}[0], [x23], x8 1254be168c0dSopenharmony_ci+ st1 {v20.h}[0], [x23], x8 1255be168c0dSopenharmony_ci+ st1 {v22.h}[0], [x23] 1256be168c0dSopenharmony_ci+ b LoopRow8End 1257be168c0dSopenharmony_ci+ Write8x2: 1258be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 1259be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 1260be168c0dSopenharmony_ci+ st1 {v12.s}[0], [x22], x8 1261be168c0dSopenharmony_ci+ st1 {v14.s}[0], [x22] 1262be168c0dSopenharmony_ci+ st1 {v16.s}[0], [x23], x8 1263be168c0dSopenharmony_ci+ st1 {v18.s}[0], [x23], x8 1264be168c0dSopenharmony_ci+ st1 {v20.s}[0], [x23], x8 1265be168c0dSopenharmony_ci+ st1 {v22.s}[0], [x23] 1266be168c0dSopenharmony_ci+ b LoopRow8End 1267be168c0dSopenharmony_ci+ Write8x3: 1268be168c0dSopenharmony_ci+ add x23, x22, #4 1269be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 1270be168c0dSopenharmony_ci+ st1 {v8.h}[2], [x23], x8 1271be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 1272be168c0dSopenharmony_ci+ st1 {v10.h}[2], [x23], x8 1273be168c0dSopenharmony_ci+ st1 {v12.s}[0], [x22], x8 1274be168c0dSopenharmony_ci+ st1 {v12.h}[2], [x23], x8 1275be168c0dSopenharmony_ci+ st1 {v14.s}[0], [x22], x8 1276be168c0dSopenharmony_ci+ st1 {v14.h}[2], [x23], x8 1277be168c0dSopenharmony_ci+ st1 {v16.s}[0], [x22], x8 1278be168c0dSopenharmony_ci+ st1 {v16.h}[2], [x23], x8 1279be168c0dSopenharmony_ci+ st1 {v18.s}[0], [x22], x8 1280be168c0dSopenharmony_ci+ st1 {v18.h}[2], [x23], x8 1281be168c0dSopenharmony_ci+ st1 {v20.s}[0], [x22], x8 1282be168c0dSopenharmony_ci+ st1 {v20.h}[2], [x23], x8 1283be168c0dSopenharmony_ci+ st1 {v22.s}[0], [x22], x8 
1284be168c0dSopenharmony_ci+ st1 {v22.h}[2], [x23], x8 1285be168c0dSopenharmony_ci+ LoopRow8End: 1286be168c0dSopenharmony_ci+ add x0, x0, x16, lsl #3 1287be168c0dSopenharmony_ci+ add x2, x2, x8, lsl #3 1288be168c0dSopenharmony_ci+ subs x6, x6, #8 1289be168c0dSopenharmony_ci+ 1290be168c0dSopenharmony_ci+LoopRow4: 1291be168c0dSopenharmony_ci+ adds x6, x6, #8 1292be168c0dSopenharmony_ci+ cbz x6, End 1293be168c0dSopenharmony_ci+ subs x6, x6, #4 1294be168c0dSopenharmony_ci+ blt LoopRowTail 1295be168c0dSopenharmony_ci+ mov x11, x1 // reload matrixB 1296be168c0dSopenharmony_ci+ mov x12, x3 // reload bias 1297be168c0dSopenharmony_ci+ mov x13, x7 // reload col 1298be168c0dSopenharmony_ci+ mov x21, x2 // relocate output 1299be168c0dSopenharmony_ci+ subs x13, x13, #16 1300be168c0dSopenharmony_ci+ blt LoopCol4x8 1301be168c0dSopenharmony_ci+ LoopCol4x16: 1302be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1303be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 1304be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1305be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 1306be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 1307be168c0dSopenharmony_ci+ cbnz x12, InitFromBias4x16 1308be168c0dSopenharmony_ci+ dup v8.2d, xzr 1309be168c0dSopenharmony_ci+ dup v9.2d, xzr 1310be168c0dSopenharmony_ci+ dup v10.2d, xzr 1311be168c0dSopenharmony_ci+ dup v11.2d, xzr 1312be168c0dSopenharmony_ci+ dup v12.2d, xzr 1313be168c0dSopenharmony_ci+ dup v13.2d, xzr 1314be168c0dSopenharmony_ci+ dup v14.2d, xzr 1315be168c0dSopenharmony_ci+ dup v15.2d, xzr 1316be168c0dSopenharmony_ci+ b Compute4x16Enter 1317be168c0dSopenharmony_ci+ InitFromBias4x16: 1318be168c0dSopenharmony_ci+ ld1 {v8.8h, v9.8h}, [x12] 1319be168c0dSopenharmony_ci+ ld1 {v10.8h, v11.8h}, [x12] 1320be168c0dSopenharmony_ci+ ld1 {v12.8h, v13.8h}, [x12] 1321be168c0dSopenharmony_ci+ ld1 {v14.8h, v15.8h}, [x12] 1322be168c0dSopenharmony_ci+ add x12, x12, #32 1323be168c0dSopenharmony_ci+ Compute4x16Enter: 1324be168c0dSopenharmony_ci+ bl 
Compute4x16Unit 1325be168c0dSopenharmony_ci+ Activation4x16: 1326be168c0dSopenharmony_ci+ cmp x4, #3 1327be168c0dSopenharmony_ci+ beq Relu64x16 1328be168c0dSopenharmony_ci+ cmp x4, #1 1329be168c0dSopenharmony_ci+ beq Relu4x16 1330be168c0dSopenharmony_ci+ b Write4x16 1331be168c0dSopenharmony_ci+ 1332be168c0dSopenharmony_ci+ Relu64x16: 1333be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1334be168c0dSopenharmony_ci+ fmin v9.8h, v9.8h, v7.8h 1335be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 1336be168c0dSopenharmony_ci+ fmin v11.8h, v11.8h, v7.8h 1337be168c0dSopenharmony_ci+ fmin v12.8h, v12.8h, v7.8h 1338be168c0dSopenharmony_ci+ fmin v13.8h, v13.8h, v7.8h 1339be168c0dSopenharmony_ci+ fmin v14.8h, v14.8h, v7.8h 1340be168c0dSopenharmony_ci+ fmin v15.8h, v15.8h, v7.8h 1341be168c0dSopenharmony_ci+ 1342be168c0dSopenharmony_ci+ Relu4x16: 1343be168c0dSopenharmony_ci+ dup v6.8h, wzr 1344be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1345be168c0dSopenharmony_ci+ fmax v9.8h, v9.8h, v6.8h 1346be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 1347be168c0dSopenharmony_ci+ fmax v11.8h, v11.8h, v6.8h 1348be168c0dSopenharmony_ci+ fmax v12.8h, v12.8h, v6.8h 1349be168c0dSopenharmony_ci+ fmax v13.8h, v13.8h, v6.8h 1350be168c0dSopenharmony_ci+ fmax v14.8h, v14.8h, v6.8h 1351be168c0dSopenharmony_ci+ fmax v15.8h, v15.8h, v6.8h 1352be168c0dSopenharmony_ci+ Write4x16: 1353be168c0dSopenharmony_ci+ mov x22, x21 1354be168c0dSopenharmony_ci+ st1 {v8.8h, v9.8h}, [x22], x8 1355be168c0dSopenharmony_ci+ st1 {v10.8h, v11.8h}, [x22], x8 1356be168c0dSopenharmony_ci+ st1 {v12.8h, v13.8h}, [x22], x8 1357be168c0dSopenharmony_ci+ st1 {v14.8h, v15.8h}, [x22] 1358be168c0dSopenharmony_ci+ add x21, x21, #32 1359be168c0dSopenharmony_ci+ subs x13, x13, #16 1360be168c0dSopenharmony_ci+ bge LoopCol4x16 1361be168c0dSopenharmony_ci+ 1362be168c0dSopenharmony_ci+ LoopCol4x8: 1363be168c0dSopenharmony_ci+ adds x13, x13, #16 1364be168c0dSopenharmony_ci+ cbz x13, LoopRow4End 1365be168c0dSopenharmony_ci+ subs 
x13, x13, #8 1366be168c0dSopenharmony_ci+ blt LoopCol4x4 1367be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1368be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 1369be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1370be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 1371be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 1372be168c0dSopenharmony_ci+ cbnz x12, InitFromBias4x8 1373be168c0dSopenharmony_ci+ dup v8.2d, xzr 1374be168c0dSopenharmony_ci+ dup v10.2d, xzr 1375be168c0dSopenharmony_ci+ dup v12.2d, xzr 1376be168c0dSopenharmony_ci+ dup v14.2d, xzr 1377be168c0dSopenharmony_ci+ b Compute4x8Enter 1378be168c0dSopenharmony_ci+ InitFromBias4x8: 1379be168c0dSopenharmony_ci+ ld1 {v8.8h}, [x12] 1380be168c0dSopenharmony_ci+ ld1 {v10.8h}, [x12] 1381be168c0dSopenharmony_ci+ ld1 {v12.8h}, [x12] 1382be168c0dSopenharmony_ci+ ld1 {v14.8h}, [x12] 1383be168c0dSopenharmony_ci+ add x12, x12, #16 1384be168c0dSopenharmony_ci+ Compute4x8Enter: 1385be168c0dSopenharmony_ci+ bl Compute4x8Unit 1386be168c0dSopenharmony_ci+ Activation4x8: 1387be168c0dSopenharmony_ci+ cmp x4, #3 1388be168c0dSopenharmony_ci+ beq Relu64x8 1389be168c0dSopenharmony_ci+ cmp x4, #1 1390be168c0dSopenharmony_ci+ beq Relu4x8 1391be168c0dSopenharmony_ci+ b Write4x8 1392be168c0dSopenharmony_ci+ 1393be168c0dSopenharmony_ci+ Relu64x8: 1394be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1395be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 1396be168c0dSopenharmony_ci+ fmin v12.8h, v12.8h, v7.8h 1397be168c0dSopenharmony_ci+ fmin v14.8h, v14.8h, v7.8h 1398be168c0dSopenharmony_ci+ 1399be168c0dSopenharmony_ci+ Relu4x8: 1400be168c0dSopenharmony_ci+ dup v6.8h, wzr 1401be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1402be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 1403be168c0dSopenharmony_ci+ fmax v12.8h, v12.8h, v6.8h 1404be168c0dSopenharmony_ci+ fmax v14.8h, v14.8h, v6.8h 1405be168c0dSopenharmony_ci+ Write4x8: 1406be168c0dSopenharmony_ci+ mov x22, x21 1407be168c0dSopenharmony_ci+ st1 {v8.8h}, [x22], x8 
1408be168c0dSopenharmony_ci+ st1 {v10.8h}, [x22], x8 1409be168c0dSopenharmony_ci+ st1 {v12.8h}, [x22], x8 1410be168c0dSopenharmony_ci+ st1 {v14.8h}, [x22] 1411be168c0dSopenharmony_ci+ add x21, x21, #16 1412be168c0dSopenharmony_ci+ subs x13, x13, #8 1413be168c0dSopenharmony_ci+ 1414be168c0dSopenharmony_ci+ LoopCol4x4: 1415be168c0dSopenharmony_ci+ adds x13, x13, #8 1416be168c0dSopenharmony_ci+ cbz x13, LoopRow4End 1417be168c0dSopenharmony_ci+ LoopCol4x4Core: 1418be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1419be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 1420be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1421be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 1422be168c0dSopenharmony_ci+ ld1 {v3.4h}, [x11], #8 1423be168c0dSopenharmony_ci+ cbnz x12, InitFromBias4x4 1424be168c0dSopenharmony_ci+ dup v8.2s, wzr 1425be168c0dSopenharmony_ci+ dup v10.2s, wzr 1426be168c0dSopenharmony_ci+ dup v12.2s, wzr 1427be168c0dSopenharmony_ci+ dup v14.2s, wzr 1428be168c0dSopenharmony_ci+ b Compute4x4Enter 1429be168c0dSopenharmony_ci+ InitFromBias4x4: 1430be168c0dSopenharmony_ci+ ld1 {v8.4h}, [x12] 1431be168c0dSopenharmony_ci+ ld1 {v10.4h}, [x12] 1432be168c0dSopenharmony_ci+ ld1 {v12.4h}, [x12] 1433be168c0dSopenharmony_ci+ ld1 {v14.4h}, [x12] 1434be168c0dSopenharmony_ci+ add x12, x12, #8 1435be168c0dSopenharmony_ci+ Compute4x4Enter: 1436be168c0dSopenharmony_ci+ bl Compute4x4Unit 1437be168c0dSopenharmony_ci+ Activation4x4: 1438be168c0dSopenharmony_ci+ cmp x4, #3 1439be168c0dSopenharmony_ci+ beq Relu64x4 1440be168c0dSopenharmony_ci+ cmp x4, #1 1441be168c0dSopenharmony_ci+ beq Relu4x4 1442be168c0dSopenharmony_ci+ b Write4x4 1443be168c0dSopenharmony_ci+ 1444be168c0dSopenharmony_ci+ Relu64x4: 1445be168c0dSopenharmony_ci+ fmin v8.4h, v8.4h, v7.4h 1446be168c0dSopenharmony_ci+ fmin v10.4h, v10.4h, v7.4h 1447be168c0dSopenharmony_ci+ fmin v12.4h, v12.4h, v7.4h 1448be168c0dSopenharmony_ci+ fmin v14.4h, v14.4h, v7.4h 1449be168c0dSopenharmony_ci+ 1450be168c0dSopenharmony_ci+ 
Relu4x4: 1451be168c0dSopenharmony_ci+ dup v6.4h, wzr 1452be168c0dSopenharmony_ci+ fmax v8.4h, v8.4h, v6.4h 1453be168c0dSopenharmony_ci+ fmax v10.4h, v10.4h, v6.4h 1454be168c0dSopenharmony_ci+ fmax v12.4h, v12.4h, v6.4h 1455be168c0dSopenharmony_ci+ fmax v14.4h, v14.4h, v6.4h 1456be168c0dSopenharmony_ci+ Write4x4: 1457be168c0dSopenharmony_ci+ mov x22, x21 1458be168c0dSopenharmony_ci+ cmp x13, #1 1459be168c0dSopenharmony_ci+ beq Write4x1 1460be168c0dSopenharmony_ci+ cmp x13, #2 1461be168c0dSopenharmony_ci+ beq Write4x2 1462be168c0dSopenharmony_ci+ cmp x13, #3 1463be168c0dSopenharmony_ci+ beq Write4x3 1464be168c0dSopenharmony_ci+ st1 {v8.4h}, [x22], x8 1465be168c0dSopenharmony_ci+ st1 {v10.4h}, [x22], x8 1466be168c0dSopenharmony_ci+ st1 {v12.4h}, [x22], x8 1467be168c0dSopenharmony_ci+ st1 {v14.4h}, [x22] 1468be168c0dSopenharmony_ci+ add x21, x21, #8 1469be168c0dSopenharmony_ci+ subs x13, x13, #4 1470be168c0dSopenharmony_ci+ bgt LoopCol4x4Core 1471be168c0dSopenharmony_ci+ b LoopRow4End 1472be168c0dSopenharmony_ci+ Write4x1: 1473be168c0dSopenharmony_ci+ st1 {v8.h}[0], [x22], x8 1474be168c0dSopenharmony_ci+ st1 {v10.h}[0], [x22], x8 1475be168c0dSopenharmony_ci+ st1 {v12.h}[0], [x22], x8 1476be168c0dSopenharmony_ci+ st1 {v14.h}[0], [x22] 1477be168c0dSopenharmony_ci+ b LoopRow4End 1478be168c0dSopenharmony_ci+ Write4x2: 1479be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 1480be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 1481be168c0dSopenharmony_ci+ st1 {v12.s}[0], [x22], x8 1482be168c0dSopenharmony_ci+ st1 {v14.s}[0], [x22] 1483be168c0dSopenharmony_ci+ b LoopRow4End 1484be168c0dSopenharmony_ci+ Write4x3: 1485be168c0dSopenharmony_ci+ add x23, x22, #4 1486be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 1487be168c0dSopenharmony_ci+ st1 {v8.h}[2], [x23], x8 1488be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 1489be168c0dSopenharmony_ci+ st1 {v10.h}[2], [x23], x8 1490be168c0dSopenharmony_ci+ st1 {v12.s}[0], [x22], x8 1491be168c0dSopenharmony_ci+ st1 {v12.h}[2], [x23], 
x8 1492be168c0dSopenharmony_ci+ st1 {v14.s}[0], [x22], x8 1493be168c0dSopenharmony_ci+ st1 {v14.h}[2], [x23], x8 1494be168c0dSopenharmony_ci+ LoopRow4End: 1495be168c0dSopenharmony_ci+ add x0, x0, x16, lsl #2 1496be168c0dSopenharmony_ci+ add x2, x2, x8, lsl #2 1497be168c0dSopenharmony_ci+ subs x6, x6, #4 1498be168c0dSopenharmony_ci+ 1499be168c0dSopenharmony_ci+LoopRowTail: 1500be168c0dSopenharmony_ci+ adds x6, x6, #4 1501be168c0dSopenharmony_ci+ cbz x6, End 1502be168c0dSopenharmony_ci+ cmp x6, #1 1503be168c0dSopenharmony_ci+ beq LoopRow1 1504be168c0dSopenharmony_ci+ cmp x6, #2 1505be168c0dSopenharmony_ci+ beq LoopRow2 1506be168c0dSopenharmony_ci+ // LoopRow3 1507be168c0dSopenharmony_ci+ mov x11, x1 // reload matrixB 1508be168c0dSopenharmony_ci+ mov x12, x3 // reload bias 1509be168c0dSopenharmony_ci+ mov x13, x7 // reload col 1510be168c0dSopenharmony_ci+ mov x21, x2 // relocate output 1511be168c0dSopenharmony_ci+ subs x13, x13, #16 1512be168c0dSopenharmony_ci+ blt LoopCol3x8 1513be168c0dSopenharmony_ci+ LoopCol3x16: 1514be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1515be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1516be168c0dSopenharmony_ci+ cbnz x12, InitFromBias3x16 1517be168c0dSopenharmony_ci+ dup v8.2d, xzr 1518be168c0dSopenharmony_ci+ dup v9.2d, xzr 1519be168c0dSopenharmony_ci+ dup v10.2d, xzr 1520be168c0dSopenharmony_ci+ dup v11.2d, xzr 1521be168c0dSopenharmony_ci+ dup v12.2d, xzr 1522be168c0dSopenharmony_ci+ dup v13.2d, xzr 1523be168c0dSopenharmony_ci+ b Compute3x16Enter 1524be168c0dSopenharmony_ci+ InitFromBias3x16: 1525be168c0dSopenharmony_ci+ ld1 {v8.8h, v9.8h}, [x12] 1526be168c0dSopenharmony_ci+ ld1 {v10.8h, v11.8h}, [x12] 1527be168c0dSopenharmony_ci+ ld1 {v12.8h, v13.8h}, [x12] 1528be168c0dSopenharmony_ci+ add x12, x12, #32 1529be168c0dSopenharmony_ci+ Compute3x16Enter: 1530be168c0dSopenharmony_ci+ bl Compute3x16Unit 1531be168c0dSopenharmony_ci+ Activation3x16: 1532be168c0dSopenharmony_ci+ cmp x4, #3 1533be168c0dSopenharmony_ci+ beq 
Relu63x16 1534be168c0dSopenharmony_ci+ cmp x4, #1 1535be168c0dSopenharmony_ci+ beq Relu3x16 1536be168c0dSopenharmony_ci+ b Write3x16 1537be168c0dSopenharmony_ci+ 1538be168c0dSopenharmony_ci+ Relu63x16: 1539be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1540be168c0dSopenharmony_ci+ fmin v9.8h, v9.8h, v7.8h 1541be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 1542be168c0dSopenharmony_ci+ fmin v11.8h, v11.8h, v7.8h 1543be168c0dSopenharmony_ci+ fmin v12.8h, v12.8h, v7.8h 1544be168c0dSopenharmony_ci+ fmin v13.8h, v13.8h, v7.8h 1545be168c0dSopenharmony_ci+ 1546be168c0dSopenharmony_ci+ Relu3x16: 1547be168c0dSopenharmony_ci+ dup v6.8h, wzr 1548be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1549be168c0dSopenharmony_ci+ fmax v9.8h, v9.8h, v6.8h 1550be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 1551be168c0dSopenharmony_ci+ fmax v11.8h, v11.8h, v6.8h 1552be168c0dSopenharmony_ci+ fmax v12.8h, v12.8h, v6.8h 1553be168c0dSopenharmony_ci+ fmax v13.8h, v13.8h, v6.8h 1554be168c0dSopenharmony_ci+ Write3x16: 1555be168c0dSopenharmony_ci+ mov x22, x21 1556be168c0dSopenharmony_ci+ st1 {v8.8h, v9.8h}, [x22], x8 1557be168c0dSopenharmony_ci+ st1 {v10.8h, v11.8h}, [x22], x8 1558be168c0dSopenharmony_ci+ st1 {v12.8h, v13.8h}, [x22] 1559be168c0dSopenharmony_ci+ add x21, x21, #32 1560be168c0dSopenharmony_ci+ subs x13, x13, #16 1561be168c0dSopenharmony_ci+ bge LoopCol3x16 1562be168c0dSopenharmony_ci+ 1563be168c0dSopenharmony_ci+ LoopCol3x8: 1564be168c0dSopenharmony_ci+ adds x13, x13, #16 1565be168c0dSopenharmony_ci+ cbz x13, End 1566be168c0dSopenharmony_ci+ subs x13, x13, #8 1567be168c0dSopenharmony_ci+ blt LoopCol3x4 1568be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1569be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1570be168c0dSopenharmony_ci+ cbnz x12, InitFromBias3x8 1571be168c0dSopenharmony_ci+ dup v8.2d, xzr 1572be168c0dSopenharmony_ci+ dup v10.2d, xzr 1573be168c0dSopenharmony_ci+ dup v12.2d, xzr 1574be168c0dSopenharmony_ci+ b Compute3x8Enter 
1575be168c0dSopenharmony_ci+ InitFromBias3x8: 1576be168c0dSopenharmony_ci+ ld1 {v8.8h}, [x12] 1577be168c0dSopenharmony_ci+ ld1 {v10.8h}, [x12] 1578be168c0dSopenharmony_ci+ ld1 {v12.8h}, [x12] 1579be168c0dSopenharmony_ci+ add x12, x12, #16 1580be168c0dSopenharmony_ci+ Compute3x8Enter: 1581be168c0dSopenharmony_ci+ bl Compute3x8Unit 1582be168c0dSopenharmony_ci+ Activation3x8: 1583be168c0dSopenharmony_ci+ cmp x4, #3 1584be168c0dSopenharmony_ci+ beq Relu63x8 1585be168c0dSopenharmony_ci+ cmp x4, #1 1586be168c0dSopenharmony_ci+ beq Relu3x8 1587be168c0dSopenharmony_ci+ b Write3x8 1588be168c0dSopenharmony_ci+ 1589be168c0dSopenharmony_ci+ Relu63x8: 1590be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1591be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 1592be168c0dSopenharmony_ci+ fmin v12.8h, v12.8h, v7.8h 1593be168c0dSopenharmony_ci+ 1594be168c0dSopenharmony_ci+ Relu3x8: 1595be168c0dSopenharmony_ci+ dup v6.8h, wzr 1596be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1597be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 1598be168c0dSopenharmony_ci+ fmax v12.8h, v12.8h, v6.8h 1599be168c0dSopenharmony_ci+ Write3x8: 1600be168c0dSopenharmony_ci+ mov x22, x21 1601be168c0dSopenharmony_ci+ st1 {v8.8h}, [x22], x8 1602be168c0dSopenharmony_ci+ st1 {v10.8h}, [x22], x8 1603be168c0dSopenharmony_ci+ st1 {v12.8h}, [x22] 1604be168c0dSopenharmony_ci+ add x21, x21, #16 1605be168c0dSopenharmony_ci+ subs x13, x13, #8 1606be168c0dSopenharmony_ci+ 1607be168c0dSopenharmony_ci+ LoopCol3x4: 1608be168c0dSopenharmony_ci+ adds x13, x13, #8 1609be168c0dSopenharmony_ci+ cbz x13, End 1610be168c0dSopenharmony_ci+ LoopCol3x4Core: 1611be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1612be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1613be168c0dSopenharmony_ci+ cbnz x12, InitFromBias3x4 1614be168c0dSopenharmony_ci+ dup v8.2s, wzr 1615be168c0dSopenharmony_ci+ dup v10.2s, wzr 1616be168c0dSopenharmony_ci+ dup v12.2s, wzr 1617be168c0dSopenharmony_ci+ b Compute3x4Enter 
1618be168c0dSopenharmony_ci+ InitFromBias3x4: 1619be168c0dSopenharmony_ci+ ld1 {v8.4h}, [x12] 1620be168c0dSopenharmony_ci+ ld1 {v10.4h}, [x12] 1621be168c0dSopenharmony_ci+ ld1 {v12.4h}, [x12] 1622be168c0dSopenharmony_ci+ add x12, x12, #8 1623be168c0dSopenharmony_ci+ Compute3x4Enter: 1624be168c0dSopenharmony_ci+ bl Compute3x4Unit 1625be168c0dSopenharmony_ci+ Activation3x4: 1626be168c0dSopenharmony_ci+ cmp x4, #3 1627be168c0dSopenharmony_ci+ beq Relu63x4 1628be168c0dSopenharmony_ci+ cmp x4, #1 1629be168c0dSopenharmony_ci+ beq Relu3x4 1630be168c0dSopenharmony_ci+ b Write3x4 1631be168c0dSopenharmony_ci+ 1632be168c0dSopenharmony_ci+ Relu63x4: 1633be168c0dSopenharmony_ci+ fmin v8.4h, v8.4h, v7.4h 1634be168c0dSopenharmony_ci+ fmin v10.4h, v10.4h, v7.4h 1635be168c0dSopenharmony_ci+ fmin v12.4h, v12.4h, v7.4h 1636be168c0dSopenharmony_ci+ 1637be168c0dSopenharmony_ci+ Relu3x4: 1638be168c0dSopenharmony_ci+ dup v6.4h, wzr 1639be168c0dSopenharmony_ci+ fmax v8.4h, v8.4h, v6.4h 1640be168c0dSopenharmony_ci+ fmax v10.4h, v10.4h, v6.4h 1641be168c0dSopenharmony_ci+ fmax v12.4h, v12.4h, v6.4h 1642be168c0dSopenharmony_ci+ Write3x4: 1643be168c0dSopenharmony_ci+ mov x22, x21 1644be168c0dSopenharmony_ci+ cmp x13, #1 1645be168c0dSopenharmony_ci+ beq Write3x1 1646be168c0dSopenharmony_ci+ cmp x13, #2 1647be168c0dSopenharmony_ci+ beq Write3x2 1648be168c0dSopenharmony_ci+ cmp x13, #3 1649be168c0dSopenharmony_ci+ beq Write3x3 1650be168c0dSopenharmony_ci+ st1 {v8.4h}, [x22], x8 1651be168c0dSopenharmony_ci+ st1 {v10.4h}, [x22], x8 1652be168c0dSopenharmony_ci+ st1 {v12.4h}, [x22] 1653be168c0dSopenharmony_ci+ add x21, x21, #8 1654be168c0dSopenharmony_ci+ subs x13, x13, #4 1655be168c0dSopenharmony_ci+ bgt LoopCol3x4Core 1656be168c0dSopenharmony_ci+ b End 1657be168c0dSopenharmony_ci+ Write3x1: 1658be168c0dSopenharmony_ci+ st1 {v8.h}[0], [x22], x8 1659be168c0dSopenharmony_ci+ st1 {v10.h}[0], [x22], x8 1660be168c0dSopenharmony_ci+ st1 {v12.h}[0], [x22] 1661be168c0dSopenharmony_ci+ b End 
1662be168c0dSopenharmony_ci+ Write3x2: 1663be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 1664be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 1665be168c0dSopenharmony_ci+ st1 {v12.s}[0], [x22] 1666be168c0dSopenharmony_ci+ b End 1667be168c0dSopenharmony_ci+ Write3x3: 1668be168c0dSopenharmony_ci+ add x23, x22, #4 1669be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 1670be168c0dSopenharmony_ci+ st1 {v8.h}[2], [x23], x8 1671be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 1672be168c0dSopenharmony_ci+ st1 {v10.h}[2], [x23], x8 1673be168c0dSopenharmony_ci+ st1 {v12.s}[0], [x22], x8 1674be168c0dSopenharmony_ci+ st1 {v12.h}[2], [x23], x8 1675be168c0dSopenharmony_ci+ b End 1676be168c0dSopenharmony_ci+ 1677be168c0dSopenharmony_ci+LoopRow2: 1678be168c0dSopenharmony_ci+ mov x11, x1 // reload matrixB 1679be168c0dSopenharmony_ci+ mov x12, x3 // reload bias 1680be168c0dSopenharmony_ci+ mov x13, x7 // reload col 1681be168c0dSopenharmony_ci+ mov x21, x2 // relocate output 1682be168c0dSopenharmony_ci+ subs x13, x13, #16 1683be168c0dSopenharmony_ci+ blt LoopCol2x8 1684be168c0dSopenharmony_ci+ LoopCol2x16: 1685be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1686be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1687be168c0dSopenharmony_ci+ cbnz x12, InitFromBias2x16 1688be168c0dSopenharmony_ci+ dup v8.2d, xzr 1689be168c0dSopenharmony_ci+ dup v9.2d, xzr 1690be168c0dSopenharmony_ci+ dup v10.2d, xzr 1691be168c0dSopenharmony_ci+ dup v11.2d, xzr 1692be168c0dSopenharmony_ci+ b Compute2x16Enter 1693be168c0dSopenharmony_ci+ InitFromBias2x16: 1694be168c0dSopenharmony_ci+ ld1 {v8.8h, v9.8h}, [x12] 1695be168c0dSopenharmony_ci+ ld1 {v10.8h, v11.8h}, [x12] 1696be168c0dSopenharmony_ci+ add x12, x12, #32 1697be168c0dSopenharmony_ci+ Compute2x16Enter: 1698be168c0dSopenharmony_ci+ bl Compute2x16Unit 1699be168c0dSopenharmony_ci+ Activation2x16: 1700be168c0dSopenharmony_ci+ cmp x4, #3 1701be168c0dSopenharmony_ci+ beq Relu62x16 1702be168c0dSopenharmony_ci+ cmp x4, #1 
1703be168c0dSopenharmony_ci+ beq Relu2x16 1704be168c0dSopenharmony_ci+ b Write2x16 1705be168c0dSopenharmony_ci+ 1706be168c0dSopenharmony_ci+ Relu62x16: 1707be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1708be168c0dSopenharmony_ci+ fmin v9.8h, v9.8h, v7.8h 1709be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 1710be168c0dSopenharmony_ci+ fmin v11.8h, v11.8h, v7.8h 1711be168c0dSopenharmony_ci+ 1712be168c0dSopenharmony_ci+ Relu2x16: 1713be168c0dSopenharmony_ci+ dup v6.8h, wzr 1714be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1715be168c0dSopenharmony_ci+ fmax v9.8h, v9.8h, v6.8h 1716be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 1717be168c0dSopenharmony_ci+ fmax v11.8h, v11.8h, v6.8h 1718be168c0dSopenharmony_ci+ Write2x16: 1719be168c0dSopenharmony_ci+ mov x22, x21 1720be168c0dSopenharmony_ci+ st1 {v8.8h, v9.8h}, [x22], x8 1721be168c0dSopenharmony_ci+ st1 {v10.8h, v11.8h}, [x22] 1722be168c0dSopenharmony_ci+ add x21, x21, #32 1723be168c0dSopenharmony_ci+ subs x13, x13, #16 1724be168c0dSopenharmony_ci+ bge LoopCol2x16 1725be168c0dSopenharmony_ci+ 1726be168c0dSopenharmony_ci+ LoopCol2x8: 1727be168c0dSopenharmony_ci+ adds x13, x13, #16 1728be168c0dSopenharmony_ci+ cbz x13, End 1729be168c0dSopenharmony_ci+ subs x13, x13, #8 1730be168c0dSopenharmony_ci+ blt LoopCol2x4 1731be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1732be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1733be168c0dSopenharmony_ci+ cbnz x12, InitFromBias2x8 1734be168c0dSopenharmony_ci+ dup v8.2d, xzr 1735be168c0dSopenharmony_ci+ dup v10.2d, xzr 1736be168c0dSopenharmony_ci+ b Compute2x8Enter 1737be168c0dSopenharmony_ci+ InitFromBias2x8: 1738be168c0dSopenharmony_ci+ ld1 {v8.8h}, [x12] 1739be168c0dSopenharmony_ci+ ld1 {v10.8h}, [x12] 1740be168c0dSopenharmony_ci+ add x12, x12, #16 1741be168c0dSopenharmony_ci+ Compute2x8Enter: 1742be168c0dSopenharmony_ci+ bl Compute2x8Unit 1743be168c0dSopenharmony_ci+ Activation2x8: 1744be168c0dSopenharmony_ci+ cmp x4, #3 1745be168c0dSopenharmony_ci+ beq 
Relu62x8 1746be168c0dSopenharmony_ci+ cmp x4, #1 1747be168c0dSopenharmony_ci+ beq Relu2x8 1748be168c0dSopenharmony_ci+ b Write2x8 1749be168c0dSopenharmony_ci+ 1750be168c0dSopenharmony_ci+ Relu62x8: 1751be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1752be168c0dSopenharmony_ci+ fmin v10.8h, v10.8h, v7.8h 1753be168c0dSopenharmony_ci+ 1754be168c0dSopenharmony_ci+ Relu2x8: 1755be168c0dSopenharmony_ci+ dup v6.8h, wzr 1756be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1757be168c0dSopenharmony_ci+ fmax v10.8h, v10.8h, v6.8h 1758be168c0dSopenharmony_ci+ Write2x8: 1759be168c0dSopenharmony_ci+ mov x22, x21 1760be168c0dSopenharmony_ci+ st1 {v8.8h}, [x22], x8 1761be168c0dSopenharmony_ci+ st1 {v10.8h}, [x22] 1762be168c0dSopenharmony_ci+ add x21, x21, #16 1763be168c0dSopenharmony_ci+ subs x13, x13, #8 1764be168c0dSopenharmony_ci+ 1765be168c0dSopenharmony_ci+ LoopCol2x4: 1766be168c0dSopenharmony_ci+ adds x13, x13, #8 1767be168c0dSopenharmony_ci+ cbz x13, End 1768be168c0dSopenharmony_ci+ LoopCol2x4Core: 1769be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1770be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1771be168c0dSopenharmony_ci+ cbnz x12, InitFromBias2x4 1772be168c0dSopenharmony_ci+ dup v8.2s, wzr 1773be168c0dSopenharmony_ci+ dup v10.2s, wzr 1774be168c0dSopenharmony_ci+ b Compute2x4Enter 1775be168c0dSopenharmony_ci+ InitFromBias2x4: 1776be168c0dSopenharmony_ci+ ld1 {v8.4h}, [x12] 1777be168c0dSopenharmony_ci+ ld1 {v10.4h}, [x12] 1778be168c0dSopenharmony_ci+ add x12, x12, #8 1779be168c0dSopenharmony_ci+ Compute2x4Enter: 1780be168c0dSopenharmony_ci+ bl Compute2x4Unit 1781be168c0dSopenharmony_ci+ Activation2x4: 1782be168c0dSopenharmony_ci+ cmp x4, #3 1783be168c0dSopenharmony_ci+ beq Relu62x4 1784be168c0dSopenharmony_ci+ cmp x4, #1 1785be168c0dSopenharmony_ci+ beq Relu2x4 1786be168c0dSopenharmony_ci+ b Write2x4 1787be168c0dSopenharmony_ci+ 1788be168c0dSopenharmony_ci+ Relu62x4: 1789be168c0dSopenharmony_ci+ fmin v8.4h, v8.4h, v7.4h 1790be168c0dSopenharmony_ci+ 
fmin v10.4h, v10.4h, v7.4h 1791be168c0dSopenharmony_ci+ Relu2x4: 1792be168c0dSopenharmony_ci+ dup v6.4h, wzr 1793be168c0dSopenharmony_ci+ fmax v8.4h, v8.4h, v6.4h 1794be168c0dSopenharmony_ci+ fmax v10.4h, v10.4h, v6.4h 1795be168c0dSopenharmony_ci+ Write2x4: 1796be168c0dSopenharmony_ci+ mov x22, x21 1797be168c0dSopenharmony_ci+ cmp x13, #1 1798be168c0dSopenharmony_ci+ beq Write2x1 1799be168c0dSopenharmony_ci+ cmp x13, #2 1800be168c0dSopenharmony_ci+ beq Write2x2 1801be168c0dSopenharmony_ci+ cmp x13, #3 1802be168c0dSopenharmony_ci+ beq Write2x3 1803be168c0dSopenharmony_ci+ st1 {v8.4h}, [x22], x8 1804be168c0dSopenharmony_ci+ st1 {v10.4h}, [x22] 1805be168c0dSopenharmony_ci+ add x21, x21, #8 1806be168c0dSopenharmony_ci+ subs x13, x13, #4 1807be168c0dSopenharmony_ci+ bgt LoopCol2x4Core 1808be168c0dSopenharmony_ci+ b End 1809be168c0dSopenharmony_ci+ Write2x1: 1810be168c0dSopenharmony_ci+ st1 {v8.h}[0], [x22], x8 1811be168c0dSopenharmony_ci+ st1 {v10.h}[0], [x22] 1812be168c0dSopenharmony_ci+ b End 1813be168c0dSopenharmony_ci+ Write2x2: 1814be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 1815be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22] 1816be168c0dSopenharmony_ci+ b End 1817be168c0dSopenharmony_ci+ Write2x3: 1818be168c0dSopenharmony_ci+ add x23, x22, #4 1819be168c0dSopenharmony_ci+ st1 {v8.s}[0], [x22], x8 1820be168c0dSopenharmony_ci+ st1 {v8.h}[2], [x23], x8 1821be168c0dSopenharmony_ci+ st1 {v10.s}[0], [x22], x8 1822be168c0dSopenharmony_ci+ st1 {v10.h}[2], [x23], x8 1823be168c0dSopenharmony_ci+ b End 1824be168c0dSopenharmony_ci+ 1825be168c0dSopenharmony_ci+LoopRow1: 1826be168c0dSopenharmony_ci+ mov x11, x1 // reload matrixB 1827be168c0dSopenharmony_ci+ mov x12, x3 // reload bias 1828be168c0dSopenharmony_ci+ mov x13, x7 // reload col 1829be168c0dSopenharmony_ci+ mov x21, x2 // relocate output 1830be168c0dSopenharmony_ci+ subs x13, x13, #16 1831be168c0dSopenharmony_ci+ blt LoopCol1x8 1832be168c0dSopenharmony_ci+ LoopCol1x16: 1833be168c0dSopenharmony_ci+ mov x10, x0 // 
update matrixA 1834be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1835be168c0dSopenharmony_ci+ cbnz x12, InitFromBias1x16 1836be168c0dSopenharmony_ci+ dup v8.2d, xzr 1837be168c0dSopenharmony_ci+ dup v9.2d, xzr 1838be168c0dSopenharmony_ci+ b Compute1x16Enter 1839be168c0dSopenharmony_ci+ InitFromBias1x16: 1840be168c0dSopenharmony_ci+ ld1 {v8.8h, v9.8h}, [x12], #32 1841be168c0dSopenharmony_ci+ Compute1x16Enter: 1842be168c0dSopenharmony_ci+ bl Compute1x16Unit 1843be168c0dSopenharmony_ci+ Activation1x16: 1844be168c0dSopenharmony_ci+ cmp x4, #3 1845be168c0dSopenharmony_ci+ beq Relu61x16 1846be168c0dSopenharmony_ci+ cmp x4, #1 1847be168c0dSopenharmony_ci+ beq Relu1x16 1848be168c0dSopenharmony_ci+ b Write1x16 1849be168c0dSopenharmony_ci+ 1850be168c0dSopenharmony_ci+ Relu61x16: 1851be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1852be168c0dSopenharmony_ci+ fmin v9.8h, v9.8h, v7.8h 1853be168c0dSopenharmony_ci+ 1854be168c0dSopenharmony_ci+ Relu1x16: 1855be168c0dSopenharmony_ci+ dup v6.8h, wzr 1856be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1857be168c0dSopenharmony_ci+ fmax v9.8h, v9.8h, v6.8h 1858be168c0dSopenharmony_ci+ Write1x16: 1859be168c0dSopenharmony_ci+ st1 {v8.8h, v9.8h}, [x21], #32 1860be168c0dSopenharmony_ci+ subs x13, x13, #16 1861be168c0dSopenharmony_ci+ bge LoopCol1x16 1862be168c0dSopenharmony_ci+ 1863be168c0dSopenharmony_ci+ LoopCol1x8: 1864be168c0dSopenharmony_ci+ adds x13, x13, #16 1865be168c0dSopenharmony_ci+ cbz x13, End 1866be168c0dSopenharmony_ci+ subs x13, x13, #8 1867be168c0dSopenharmony_ci+ blt LoopCol1x4 1868be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1869be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1870be168c0dSopenharmony_ci+ cbnz x12, InitFromBias1x8 1871be168c0dSopenharmony_ci+ dup v8.2d, xzr 1872be168c0dSopenharmony_ci+ b Compute1x8Enter 1873be168c0dSopenharmony_ci+ InitFromBias1x8: 1874be168c0dSopenharmony_ci+ ld1 {v8.8h}, [x12], #16 1875be168c0dSopenharmony_ci+ Compute1x8Enter: 1876be168c0dSopenharmony_ci+ bl 
Compute1x8Unit 1877be168c0dSopenharmony_ci+ Activation1x8: 1878be168c0dSopenharmony_ci+ cmp x4, #3 1879be168c0dSopenharmony_ci+ beq Relu61x8 1880be168c0dSopenharmony_ci+ cmp x4, #1 1881be168c0dSopenharmony_ci+ beq Relu1x8 1882be168c0dSopenharmony_ci+ b Write1x8 1883be168c0dSopenharmony_ci+ 1884be168c0dSopenharmony_ci+ Relu61x8: 1885be168c0dSopenharmony_ci+ fmin v8.8h, v8.8h, v7.8h 1886be168c0dSopenharmony_ci+ 1887be168c0dSopenharmony_ci+ Relu1x8: 1888be168c0dSopenharmony_ci+ dup v6.8h, wzr 1889be168c0dSopenharmony_ci+ fmax v8.8h, v8.8h, v6.8h 1890be168c0dSopenharmony_ci+ Write1x8: 1891be168c0dSopenharmony_ci+ st1 {v8.8h}, [x21], #16 1892be168c0dSopenharmony_ci+ subs x13, x13, #8 1893be168c0dSopenharmony_ci+ 1894be168c0dSopenharmony_ci+ LoopCol1x4: 1895be168c0dSopenharmony_ci+ adds x13, x13, #8 1896be168c0dSopenharmony_ci+ cbz x13, End 1897be168c0dSopenharmony_ci+ LoopCol1x4Core: 1898be168c0dSopenharmony_ci+ mov x10, x0 // update matrixA 1899be168c0dSopenharmony_ci+ mov x14, x5 // reload depth 1900be168c0dSopenharmony_ci+ cbnz x12, InitFromBias1x4 1901be168c0dSopenharmony_ci+ dup v8.2s, wzr 1902be168c0dSopenharmony_ci+ b Compute1x4Enter 1903be168c0dSopenharmony_ci+ InitFromBias1x4: 1904be168c0dSopenharmony_ci+ ld1 {v8.4h}, [x12], #8 1905be168c0dSopenharmony_ci+ Compute1x4Enter: 1906be168c0dSopenharmony_ci+ bl Compute1x4Unit 1907be168c0dSopenharmony_ci+ Activation1x4: 1908be168c0dSopenharmony_ci+ cmp x4, #3 1909be168c0dSopenharmony_ci+ beq Relu61x4 1910be168c0dSopenharmony_ci+ cmp x4, #1 1911be168c0dSopenharmony_ci+ beq Relu1x4 1912be168c0dSopenharmony_ci+ b Write1x4 1913be168c0dSopenharmony_ci+ 1914be168c0dSopenharmony_ci+ Relu61x4: 1915be168c0dSopenharmony_ci+ fmin v8.4h, v8.4h, v7.4h 1916be168c0dSopenharmony_ci+ Relu1x4: 1917be168c0dSopenharmony_ci+ dup v6.4h, wzr 1918be168c0dSopenharmony_ci+ fmax v8.4h, v8.4h, v6.4h 1919be168c0dSopenharmony_ci+ Write1x4: 1920be168c0dSopenharmony_ci+ cmp x13, #1 1921be168c0dSopenharmony_ci+ beq Write1x1 
        // (continuation of Write1x4) remaining-column dispatch for the 1-row tile
        cmp x13, #2
        beq Write1x2
        cmp x13, #3
        beq Write1x3
        st1 {v8.4h}, [x21], #8
        subs x13, x13, #4
        bgt LoopCol1x4Core
        b End
    Write1x1:                            // 1 column left: single halfword
        st1 {v8.h}[0], [x21]
        b End
    Write1x2:                            // 2 columns left: one 32-bit word
        st1 {v8.s}[0], [x21]
        b End
    Write1x3:                            // 3 columns left: word + halfword lane 2
        add x22, x21, #4
        st1 {v8.s}[0], [x21]
        st1 {v8.h}[2], [x22]
        b End

// ============================================================================
// Compute12x16Unit: depth (inner-product) loop for a 12-row x 16-col fp16 tile.
// Accumulators v8..v31: even regs hold columns 0-7 (B vectors v3/v5), odd regs
// hold columns 8-15 (B vectors v4/v6). The 12 A scalars come from
// v0.h[0..7] + v1.h[0..3]. Depth is consumed two steps per loop iteration;
// the End/End1 tails handle the last 2 or 1 step(s).
// NOTE(review): uses v0 (A) and v3 (B) before loading them — they are
// presumably preloaded by the caller before the `bl`; confirm in the section
// heads above. The PRFM prefetches (keep for A, strm for B) hide load latency.
Compute12x16Unit:
    subs x14, x14, #2
    ble Compute12x16End
    Compute12x16:
        // Depth step k: B columns 0-7 in v3, columns 8-15 in v4.
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h, v2.8h}, [x10], #32
        ld1 {v4.8h, v5.8h}, [x11], #32
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        fmla v24.8h, v3.8h, v1.h[0]
        fmla v26.8h, v3.8h, v1.h[1]
        fmla v28.8h, v3.8h, v1.h[2]
        fmla v30.8h, v3.8h, v1.h[3]
        prfm pldl1strm, [x11, #632]
        ld1 {v6.8h}, [x11], #16
        fmla v9.8h, v4.8h, v0.h[0]
        fmla v11.8h, v4.8h, v0.h[1]
        fmla v13.8h, v4.8h, v0.h[2]
        fmla v15.8h, v4.8h, v0.h[3]
        fmla v17.8h, v4.8h, v0.h[4]
        fmla v19.8h, v4.8h, v0.h[5]
        fmla v21.8h, v4.8h, v0.h[6]
        fmla v23.8h, v4.8h, v0.h[7]
        fmla v25.8h, v4.8h, v1.h[0]
        fmla v27.8h, v4.8h, v1.h[1]
        fmla v29.8h, v4.8h, v1.h[2]
        fmla v31.8h, v4.8h, v1.h[3]

        // Depth step k+1: B columns 0-7 in v5, columns 8-15 in v6;
        // A scalars are v1.h[4..7] + v2.h[0..7].
        fmla v8.8h, v5.8h, v1.h[4]
        fmla v10.8h, v5.8h, v1.h[5]
        fmla v12.8h, v5.8h, v1.h[6]
        fmla v14.8h, v5.8h, v1.h[7]
        fmla v16.8h, v5.8h, v2.h[0]
        fmla v18.8h, v5.8h, v2.h[1]
        fmla v20.8h, v5.8h, v2.h[2]
        fmla v22.8h, v5.8h, v2.h[3]
        fmla v24.8h, v5.8h, v2.h[4]
        fmla v26.8h, v5.8h, v2.h[5]
        fmla v28.8h, v5.8h, v2.h[6]
        fmla v30.8h, v5.8h, v2.h[7]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.8h}, [x11], #16          // preload B for the next iteration
        fmla v9.8h, v6.8h, v1.h[4]
        fmla v11.8h, v6.8h, v1.h[5]
        fmla v13.8h, v6.8h, v1.h[6]
        fmla v15.8h, v6.8h, v1.h[7]
        prfm pldl1keep, [x10, #632]
        ld1 {v0.8h}, [x10], #16          // preload A for the next iteration
        fmla v17.8h, v6.8h, v2.h[0]
        fmla v19.8h, v6.8h, v2.h[1]
        fmla v21.8h, v6.8h, v2.h[2]
        fmla v23.8h, v6.8h, v2.h[3]
        fmla v25.8h, v6.8h, v2.h[4]
        fmla v27.8h, v6.8h, v2.h[5]
        fmla v29.8h, v6.8h, v2.h[6]
        fmla v31.8h, v6.8h, v2.h[7]

        subs x14, x14, #2
        bgt Compute12x16
    Compute12x16End:
        // x14 == 0 -> exactly two depth steps remain (do one here, one in End1);
        // x14 != 0 (negative after the subs) -> only one step remains.
        cbnz x14, Compute12x16End1
        prfm pldl1keep, [x10, #632]
        ld1 {v1.4h}, [x10], #8
        ld1 {v4.8h}, [x11], #16
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        fmla v24.8h, v3.8h, v1.h[0]
        fmla v26.8h, v3.8h, v1.h[1]
        fmla v28.8h, v3.8h, v1.h[2]
        fmla v30.8h, v3.8h, v1.h[3]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.8h}, [x11], #16
        fmla v9.8h, v4.8h, v0.h[0]
        fmla v11.8h, v4.8h, v0.h[1]
        fmla v13.8h, v4.8h, v0.h[2]
        fmla v15.8h, v4.8h, v0.h[3]
        ld1 {v2.8h}, [x10], #16
        fmla v17.8h, v4.8h, v0.h[4]
        fmla v19.8h, v4.8h, v0.h[5]
        fmla v21.8h, v4.8h, v0.h[6]
        fmla v23.8h, v4.8h, v0.h[7]
        fmla v25.8h, v4.8h, v1.h[0]
        fmla v27.8h, v4.8h, v1.h[1]
        fmla v29.8h, v4.8h, v1.h[2]
        fmla v31.8h, v4.8h, v1.h[3]
        mov v0.16b, v2.16b               // hand the freshly loaded A to End1
    Compute12x16End1:
        // Final depth step (no post-increment on the last A load).
        ld1 {v1.4h}, [x10]
        ld1 {v4.8h}, [x11], #16
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        fmla v24.8h, v3.8h, v1.h[0]
        fmla v26.8h, v3.8h, v1.h[1]
        fmla v28.8h, v3.8h, v1.h[2]
        fmla v30.8h, v3.8h, v1.h[3]
        fmla v9.8h, v4.8h, v0.h[0]
        fmla v11.8h, v4.8h, v0.h[1]
        fmla v13.8h, v4.8h, v0.h[2]
        fmla v15.8h, v4.8h, v0.h[3]
        fmla v17.8h, v4.8h, v0.h[4]
        fmla v19.8h, v4.8h, v0.h[5]
        fmla v21.8h, v4.8h, v0.h[6]
        fmla v23.8h, v4.8h, v0.h[7]
        fmla v25.8h, v4.8h, v1.h[0]
        fmla v27.8h, v4.8h, v1.h[1]
        fmla v29.8h, v4.8h, v1.h[2]
        fmla v31.8h, v4.8h, v1.h[3]
        ret

// ============================================================================
// Compute12x8Unit: depth loop for a 12-row x 8-col tile. One B vector per
// depth step (even accumulators v8..v30 only); A scalars v0.h[0..7]+v1.h[0..3].
// Same two-steps-per-iteration pipelining and tail scheme as Compute12x16Unit.
Compute12x8Unit:
    subs x14, x14, #2
    ble Compute12x8End
    // (body of Compute12x8Unit — its entry/`subs` is immediately above)
    Compute12x8:
        // Depth step k with B in v3, step k+1 with B in v4.
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h, v2.8h}, [x10], #32
        ld1 {v4.8h}, [x11], #16
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        fmla v24.8h, v3.8h, v1.h[0]
        fmla v26.8h, v3.8h, v1.h[1]
        fmla v28.8h, v3.8h, v1.h[2]
        fmla v30.8h, v3.8h, v1.h[3]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.8h}, [x11], #16          // preload B for next iteration
        fmla v8.8h, v4.8h, v1.h[4]
        fmla v10.8h, v4.8h, v1.h[5]
        fmla v12.8h, v4.8h, v1.h[6]
        fmla v14.8h, v4.8h, v1.h[7]
        ld1 {v0.8h}, [x10], #16          // preload A for next iteration
        fmla v16.8h, v4.8h, v2.h[0]
        fmla v18.8h, v4.8h, v2.h[1]
        fmla v20.8h, v4.8h, v2.h[2]
        fmla v22.8h, v4.8h, v2.h[3]
        fmla v24.8h, v4.8h, v2.h[4]
        fmla v26.8h, v4.8h, v2.h[5]
        fmla v28.8h, v4.8h, v2.h[6]
        fmla v30.8h, v4.8h, v2.h[7]

        subs x14, x14, #2
        bgt Compute12x8
    Compute12x8End:
        // x14 == 0 -> two steps remain (one here + one in End1), else one.
        cbnz x14, Compute12x8End1
        prfm pldl1keep, [x10, #632]
        ld1 {v1.4h}, [x10], #8
        ld1 {v4.8h}, [x11], #16
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        fmla v24.8h, v3.8h, v1.h[0]
        fmla v26.8h, v3.8h, v1.h[1]
        fmla v28.8h, v3.8h, v1.h[2]
        fmla v30.8h, v3.8h, v1.h[3]
        ld1 {v0.8h}, [x10], #16
        mov v3.16b, v4.16b               // hand the freshly loaded B to End1
    Compute12x8End1:
        // Final depth step (last A load without post-increment).
        ld1 {v1.4h}, [x10]
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        fmla v24.8h, v3.8h, v1.h[0]
        fmla v26.8h, v3.8h, v1.h[1]
        fmla v28.8h, v3.8h, v1.h[2]
        fmla v30.8h, v3.8h, v1.h[3]
        ret

// ============================================================================
// Compute12x4Unit: depth loop for a 12-row x 4-col tile. Same structure as
// Compute12x8Unit but with 4-lane (.4h) B vectors and accumulators.
Compute12x4Unit:
    subs x14, x14, #2
    ble Compute12x4End
    Compute12x4:
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h, v2.8h}, [x10], #32
        ld1 {v4.4h}, [x11], #8
        fmla v8.4h, v3.4h, v0.h[0]
        fmla v10.4h, v3.4h, v0.h[1]
        fmla v12.4h, v3.4h, v0.h[2]
        fmla v14.4h, v3.4h, v0.h[3]
        fmla v16.4h, v3.4h, v0.h[4]
        fmla v18.4h, v3.4h, v0.h[5]
        fmla v20.4h, v3.4h, v0.h[6]
        fmla v22.4h, v3.4h, v0.h[7]
        fmla v24.4h, v3.4h, v1.h[0]
        fmla v26.4h, v3.4h, v1.h[1]
        fmla v28.4h, v3.4h, v1.h[2]
        fmla v30.4h, v3.4h, v1.h[3]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.4h}, [x11], #8           // preload B for next iteration
        fmla v8.4h, v4.4h, v1.h[4]
        fmla v10.4h, v4.4h, v1.h[5]
        fmla v12.4h, v4.4h, v1.h[6]
        fmla v14.4h, v4.4h, v1.h[7]
        ld1 {v0.8h}, [x10], #16          // preload A for next iteration
        fmla v16.4h, v4.4h, v2.h[0]
        fmla v18.4h, v4.4h, v2.h[1]
        fmla v20.4h, v4.4h, v2.h[2]
        fmla v22.4h, v4.4h, v2.h[3]
        fmla v24.4h, v4.4h, v2.h[4]
        fmla v26.4h, v4.4h, v2.h[5]
        fmla v28.4h, v4.4h, v2.h[6]
        fmla v30.4h, v4.4h, v2.h[7]

        subs x14, x14, #2
        bgt Compute12x4
    Compute12x4End:
        cbnz x14, Compute12x4End1
        prfm pldl1keep, [x10, #632]
        ld1 {v1.4h}, [x10], #8
        ld1 {v4.4h}, [x11], #8
        fmla v8.4h, v3.4h, v0.h[0]
        fmla v10.4h, v3.4h, v0.h[1]
        fmla v12.4h, v3.4h, v0.h[2]
        fmla v14.4h, v3.4h, v0.h[3]
        fmla v16.4h, v3.4h, v0.h[4]
        fmla v18.4h, v3.4h, v0.h[5]
        fmla v20.4h, v3.4h, v0.h[6]
        fmla v22.4h, v3.4h, v0.h[7]
        fmla v24.4h, v3.4h, v1.h[0]
        fmla v26.4h, v3.4h, v1.h[1]
        fmla v28.4h, v3.4h, v1.h[2]
        fmla v30.4h, v3.4h, v1.h[3]
        ld1 {v0.8h}, [x10], #16
        mov v3.8b, v4.8b                 // 64-bit move: only the 4h lanes matter
    Compute12x4End1:
        ld1 {v1.4h}, [x10]
        fmla v8.4h, v3.4h, v0.h[0]
        fmla v10.4h, v3.4h, v0.h[1]
        fmla v12.4h, v3.4h, v0.h[2]
        fmla v14.4h, v3.4h, v0.h[3]
        fmla v16.4h, v3.4h, v0.h[4]
        fmla v18.4h, v3.4h, v0.h[5]
        fmla v20.4h, v3.4h, v0.h[6]
        fmla v22.4h, v3.4h, v0.h[7]
        fmla v24.4h, v3.4h, v1.h[0]
        fmla v26.4h, v3.4h, v1.h[1]
        fmla v28.4h, v3.4h, v1.h[2]
        fmla v30.4h, v3.4h, v1.h[3]
        ret

// ============================================================================
// Compute8x16Unit: depth loop for an 8-row x 16-col tile. Even accumulators
// v8..v22 hold columns 0-7, odd v9..v23 hold columns 8-15; A scalars v0.h[0..7].
Compute8x16Unit:
    subs x14, x14, #2
    ble Compute8x16End
    Compute8x16:
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h}, [x10], #16
        ld1 {v4.8h, v5.8h}, [x11], #32
        // (continuation of the Compute8x16 loop body)
        // Depth step k: columns 0-7 via v3, columns 8-15 via v4.
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        prfm pldl1strm, [x11, #632]
        ld1 {v6.8h}, [x11], #16
        fmla v9.8h, v4.8h, v0.h[0]
        fmla v11.8h, v4.8h, v0.h[1]
        fmla v13.8h, v4.8h, v0.h[2]
        fmla v15.8h, v4.8h, v0.h[3]
        fmla v17.8h, v4.8h, v0.h[4]
        fmla v19.8h, v4.8h, v0.h[5]
        fmla v21.8h, v4.8h, v0.h[6]
        fmla v23.8h, v4.8h, v0.h[7]

        // Depth step k+1: columns 0-7 via v5, columns 8-15 via v6; A scalars v1.
        fmla v8.8h, v5.8h, v1.h[0]
        fmla v10.8h, v5.8h, v1.h[1]
        fmla v12.8h, v5.8h, v1.h[2]
        fmla v14.8h, v5.8h, v1.h[3]
        fmla v16.8h, v5.8h, v1.h[4]
        fmla v18.8h, v5.8h, v1.h[5]
        fmla v20.8h, v5.8h, v1.h[6]
        fmla v22.8h, v5.8h, v1.h[7]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.8h}, [x11], #16          // preload B for next iteration
        fmla v9.8h, v6.8h, v1.h[0]
        fmla v11.8h, v6.8h, v1.h[1]
        fmla v13.8h, v6.8h, v1.h[2]
        fmla v15.8h, v6.8h, v1.h[3]
        prfm pldl1keep, [x10, #632]
        ld1 {v0.8h}, [x10], #16          // preload A for next iteration
        fmla v17.8h, v6.8h, v1.h[4]
        fmla v19.8h, v6.8h, v1.h[5]
        fmla v21.8h, v6.8h, v1.h[6]
        fmla v23.8h, v6.8h, v1.h[7]

        subs x14, x14, #2
        bgt Compute8x16
    Compute8x16End:
        // x14 == 0 -> two depth steps remain (one here + one in End1), else one.
        cbnz x14, Compute8x16End1
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h}, [x10]
        ld1 {v4.8h}, [x11], #16
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.8h}, [x11], #16
        fmla v9.8h, v4.8h, v0.h[0]
        fmla v11.8h, v4.8h, v0.h[1]
        fmla v13.8h, v4.8h, v0.h[2]
        fmla v15.8h, v4.8h, v0.h[3]
        fmla v17.8h, v4.8h, v0.h[4]
        fmla v19.8h, v4.8h, v0.h[5]
        fmla v21.8h, v4.8h, v0.h[6]
        fmla v23.8h, v4.8h, v0.h[7]
        mov v0.16b, v1.16b               // hand the freshly loaded A to End1
    Compute8x16End1:
        ld1 {v4.8h}, [x11], #16
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        fmla v9.8h, v4.8h, v0.h[0]
        fmla v11.8h, v4.8h, v0.h[1]
        fmla v13.8h, v4.8h, v0.h[2]
        fmla v15.8h, v4.8h, v0.h[3]
        fmla v17.8h, v4.8h, v0.h[4]
        fmla v19.8h, v4.8h, v0.h[5]
        fmla v21.8h, v4.8h, v0.h[6]
        fmla v23.8h, v4.8h, v0.h[7]
        ret

// ============================================================================
// Compute8x8Unit: depth loop for an 8-row x 8-col tile (accumulators v8..v22,
// A scalars v0.h[0..7], one B vector per depth step).
Compute8x8Unit:
    subs x14, x14, #2
    ble Compute8x8End
    Compute8x8:
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h}, [x10], #16
        ld1 {v4.8h}, [x11], #16
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.8h}, [x11], #16          // preload B for next iteration
        fmla v8.8h, v4.8h, v1.h[0]
        fmla v10.8h, v4.8h, v1.h[1]
        fmla v12.8h, v4.8h, v1.h[2]
        fmla v14.8h, v4.8h, v1.h[3]
        ld1 {v0.8h}, [x10], #16          // preload A for next iteration
        fmla v16.8h, v4.8h, v1.h[4]
        fmla v18.8h, v4.8h, v1.h[5]
        fmla v20.8h, v4.8h, v1.h[6]
        fmla v22.8h, v4.8h, v1.h[7]

        subs x14, x14, #2
        bgt Compute8x8
    Compute8x8End:
        cbnz x14, Compute8x8End1
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h}, [x10]
        ld1 {v4.8h}, [x11], #16
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        mov v0.16b, v1.16b               // hand loaded A/B to the final step
        mov v3.16b, v4.16b
    Compute8x8End1:
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        fmla v16.8h, v3.8h, v0.h[4]
        fmla v18.8h, v3.8h, v0.h[5]
        fmla v20.8h, v3.8h, v0.h[6]
        fmla v22.8h, v3.8h, v0.h[7]
        ret

// ============================================================================
// Compute8x4Unit: depth loop for an 8-row x 4-col tile (4-lane B vectors).
Compute8x4Unit:
    subs x14, x14, #2
    ble Compute8x4End
    Compute8x4:
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h}, [x10], #16
        ld1 {v4.4h}, [x11], #8
        fmla v8.4h, v3.4h, v0.h[0]
        fmla v10.4h, v3.4h, v0.h[1]
        fmla v12.4h, v3.4h, v0.h[2]
        fmla v14.4h, v3.4h, v0.h[3]
        fmla v16.4h, v3.4h, v0.h[4]
        fmla v18.4h, v3.4h, v0.h[5]
        fmla v20.4h, v3.4h, v0.h[6]
        fmla v22.4h, v3.4h, v0.h[7]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.4h}, [x11], #8           // preload B for next iteration
        fmla v8.4h, v4.4h, v1.h[0]
        fmla v10.4h, v4.4h, v1.h[1]
        fmla v12.4h, v4.4h, v1.h[2]
        fmla v14.4h, v4.4h, v1.h[3]
        ld1 {v0.8h}, [x10], #16          // preload A for next iteration
        fmla v16.4h, v4.4h, v1.h[4]
        fmla v18.4h, v4.4h, v1.h[5]
        fmla v20.4h, v4.4h, v1.h[6]
        fmla v22.4h, v4.4h, v1.h[7]

        subs x14, x14, #2
        bgt Compute8x4
    Compute8x4End:
        cbnz x14, Compute8x4End1
        prfm pldl1keep, [x10, #632]
        ld1 {v1.8h}, [x10]
        ld1 {v4.4h}, [x11], #8
        fmla v8.4h, v3.4h, v0.h[0]
        fmla v10.4h, v3.4h, v0.h[1]
        fmla v12.4h, v3.4h, v0.h[2]
        fmla v14.4h, v3.4h, v0.h[3]
        fmla v16.4h, v3.4h, v0.h[4]
        fmla v18.4h, v3.4h, v0.h[5]
        fmla v20.4h, v3.4h, v0.h[6]
        fmla v22.4h, v3.4h, v0.h[7]
        mov v0.16b, v1.16b
        mov v3.8b, v4.8b                 // 64-bit move: only the 4h lanes matter
    Compute8x4End1:
        fmla v8.4h, v3.4h, v0.h[0]
        fmla v10.4h, v3.4h, v0.h[1]
        fmla v12.4h, v3.4h, v0.h[2]
        fmla v14.4h, v3.4h, v0.h[3]
        fmla v16.4h, v3.4h, v0.h[4]
        fmla v18.4h, v3.4h, v0.h[5]
        fmla v20.4h, v3.4h, v0.h[6]
        fmla v22.4h, v3.4h, v0.h[7]
        ret

// ============================================================================
// Compute4x16Unit: depth loop for a 4-row x 16-col tile (even v8..v14 = cols
// 0-7, odd v9..v15 = cols 8-15; A scalars v0.h[0..3] / v1.h[0..3]).
// NOTE(review): this routine continues past the end of this excerpt — the
// remainder of its body (and the End tails) is outside the visible chunk.
Compute4x16Unit:
    subs x14, x14, #2
    ble Compute4x16End
    Compute4x16:
        prfm pldl1keep, [x10, #632]
        ld1 {v1.4h}, [x10], #8
        ld1 {v4.8h, v5.8h}, [x11], #32
        fmla v8.8h, v3.8h, v0.h[0]
        fmla v10.8h, v3.8h, v0.h[1]
        fmla v12.8h, v3.8h, v0.h[2]
        fmla v14.8h, v3.8h, v0.h[3]
        prfm pldl1strm, [x11, #632]
        ld1 {v6.8h}, [x11], #16
        fmla v9.8h, v4.8h, v0.h[0]
        fmla v11.8h, v4.8h, v0.h[1]
        fmla v13.8h, v4.8h, v0.h[2]
        fmla v15.8h, v4.8h, v0.h[3]

        fmla v8.8h, v5.8h, v1.h[0]
        fmla v10.8h, v5.8h, v1.h[1]
        fmla v12.8h, v5.8h, v1.h[2]
        fmla v14.8h, v5.8h, v1.h[3]
        prfm pldl1strm, [x11, #632]
        ld1 {v3.8h}, [x11], #16
        fmla v9.8h, v6.8h, v1.h[0]
        fmla v11.8h, v6.8h, v1.h[1]
        fmla v13.8h, v6.8h, v1.h[2]
fmla v15.8h, v6.8h, v1.h[3] 2444be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 2445be168c0dSopenharmony_ci+ 2446be168c0dSopenharmony_ci+ subs x14, x14, #2 2447be168c0dSopenharmony_ci+ bgt Compute4x16 2448be168c0dSopenharmony_ci+ Compute4x16End: 2449be168c0dSopenharmony_ci+ cbnz x14, Compute4x16End1 2450be168c0dSopenharmony_ci+ prfm pldl1keep, [x10, #632] 2451be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x10] 2452be168c0dSopenharmony_ci+ ld1 {v4.8h}, [x11], #16 2453be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2454be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v0.h[1] 2455be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v0.h[2] 2456be168c0dSopenharmony_ci+ fmla v14.8h, v3.8h, v0.h[3] 2457be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2458be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 2459be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 2460be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v0.h[1] 2461be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v0.h[2] 2462be168c0dSopenharmony_ci+ fmla v15.8h, v4.8h, v0.h[3] 2463be168c0dSopenharmony_ci+ mov v0.8b, v1.8b 2464be168c0dSopenharmony_ci+ Compute4x16End1: 2465be168c0dSopenharmony_ci+ ld1 {v4.8h}, [x11], #16 2466be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2467be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v0.h[1] 2468be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v0.h[2] 2469be168c0dSopenharmony_ci+ fmla v14.8h, v3.8h, v0.h[3] 2470be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 2471be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v0.h[1] 2472be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v0.h[2] 2473be168c0dSopenharmony_ci+ fmla v15.8h, v4.8h, v0.h[3] 2474be168c0dSopenharmony_ci+ ret 2475be168c0dSopenharmony_ci+ 2476be168c0dSopenharmony_ci+Compute4x8Unit: 2477be168c0dSopenharmony_ci+ subs x14, x14, #2 2478be168c0dSopenharmony_ci+ ble Compute4x8End 2479be168c0dSopenharmony_ci+ Compute4x8: 2480be168c0dSopenharmony_ci+ prfm pldl1keep, [x10, #632] 2481be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x10], #8 
2482be168c0dSopenharmony_ci+ ld1 {v4.8h}, [x11], #16 2483be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2484be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v0.h[1] 2485be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v0.h[2] 2486be168c0dSopenharmony_ci+ fmla v14.8h, v3.8h, v0.h[3] 2487be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2488be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 2489be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v1.h[0] 2490be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[1] 2491be168c0dSopenharmony_ci+ fmla v12.8h, v4.8h, v1.h[2] 2492be168c0dSopenharmony_ci+ fmla v14.8h, v4.8h, v1.h[3] 2493be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 2494be168c0dSopenharmony_ci+ 2495be168c0dSopenharmony_ci+ subs x14, x14, #2 2496be168c0dSopenharmony_ci+ bgt Compute4x8 2497be168c0dSopenharmony_ci+ Compute4x8End: 2498be168c0dSopenharmony_ci+ cbnz x14, Compute4x8End1 2499be168c0dSopenharmony_ci+ prfm pldl1keep, [x10, #632] 2500be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x10] 2501be168c0dSopenharmony_ci+ ld1 {v4.8h}, [x11], #16 2502be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2503be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v0.h[1] 2504be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v0.h[2] 2505be168c0dSopenharmony_ci+ fmla v14.8h, v3.8h, v0.h[3] 2506be168c0dSopenharmony_ci+ mov v0.8b, v1.8b 2507be168c0dSopenharmony_ci+ mov v3.16b, v4.16b 2508be168c0dSopenharmony_ci+ Compute4x8End1: 2509be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2510be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v0.h[1] 2511be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v0.h[2] 2512be168c0dSopenharmony_ci+ fmla v14.8h, v3.8h, v0.h[3] 2513be168c0dSopenharmony_ci+ ret 2514be168c0dSopenharmony_ci+ 2515be168c0dSopenharmony_ci+Compute4x4Unit: 2516be168c0dSopenharmony_ci+ subs x14, x14, #2 2517be168c0dSopenharmony_ci+ ble Compute4x4End 2518be168c0dSopenharmony_ci+ Compute4x4: 2519be168c0dSopenharmony_ci+ prfm pldl1keep, [x10, #632] 2520be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x10], #8 
2521be168c0dSopenharmony_ci+ ld1 {v4.4h}, [x11], #8 2522be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 2523be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v0.h[1] 2524be168c0dSopenharmony_ci+ fmla v12.4h, v3.4h, v0.h[2] 2525be168c0dSopenharmony_ci+ fmla v14.4h, v3.4h, v0.h[3] 2526be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2527be168c0dSopenharmony_ci+ ld1 {v3.4h}, [x11], #8 2528be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v1.h[0] 2529be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[1] 2530be168c0dSopenharmony_ci+ fmla v12.4h, v4.4h, v1.h[2] 2531be168c0dSopenharmony_ci+ fmla v14.4h, v4.4h, v1.h[3] 2532be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 2533be168c0dSopenharmony_ci+ 2534be168c0dSopenharmony_ci+ subs x14, x14, #2 2535be168c0dSopenharmony_ci+ bgt Compute4x4 2536be168c0dSopenharmony_ci+ Compute4x4End: 2537be168c0dSopenharmony_ci+ cbnz x14, Compute4x4End1 2538be168c0dSopenharmony_ci+ prfm pldl1keep, [x10, #632] 2539be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x10] 2540be168c0dSopenharmony_ci+ ld1 {v4.4h}, [x11], #8 2541be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 2542be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v0.h[1] 2543be168c0dSopenharmony_ci+ fmla v12.4h, v3.4h, v0.h[2] 2544be168c0dSopenharmony_ci+ fmla v14.4h, v3.4h, v0.h[3] 2545be168c0dSopenharmony_ci+ mov v0.8b, v1.8b 2546be168c0dSopenharmony_ci+ mov v3.8b, v4.8b 2547be168c0dSopenharmony_ci+ Compute4x4End1: 2548be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 2549be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v0.h[1] 2550be168c0dSopenharmony_ci+ fmla v12.4h, v3.4h, v0.h[2] 2551be168c0dSopenharmony_ci+ fmla v14.4h, v3.4h, v0.h[3] 2552be168c0dSopenharmony_ci+ ret 2553be168c0dSopenharmony_ci+ 2554be168c0dSopenharmony_ci+Compute3x16Unit: 2555be168c0dSopenharmony_ci+ add x19, x10, x16 2556be168c0dSopenharmony_ci+ add x20, x10, x16, lsl #1 2557be168c0dSopenharmony_ci+ subs x14, x14, #8 2558be168c0dSopenharmony_ci+ blt Compute3x16End4 2559be168c0dSopenharmony_ci+ Compute3x16: 
2560be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 2561be168c0dSopenharmony_ci+ ld1 {v1.8h}, [x19], #16 2562be168c0dSopenharmony_ci+ ld1 {v2.8h}, [x20], #16 2563be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2564be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2565be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2566be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2567be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2568be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2569be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 2570be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 2571be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[0] 2572be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 2573be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[1] 2574be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[1] 2575be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2576be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 2577be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[1] 2578be168c0dSopenharmony_ci+ fmla v13.8h, v6.8h, v2.h[1] 2579be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[2] 2580be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[2] 2581be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[2] 2582be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2583be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[2] 2584be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[2] 2585be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[2] 2586be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[3] 2587be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[3] 2588be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[3] 2589be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2590be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2591be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[3] 2592be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[3] 2593be168c0dSopenharmony_ci+ fmla v13.8h, v6.8h, v2.h[3] 2594be168c0dSopenharmony_ci+ 2595be168c0dSopenharmony_ci+ fmla 
v8.8h, v3.8h, v0.h[4] 2596be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[4] 2597be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[4] 2598be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2599be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[4] 2600be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[4] 2601be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[4] 2602be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[5] 2603be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[5] 2604be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[5] 2605be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2606be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[5] 2607be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[5] 2608be168c0dSopenharmony_ci+ fmla v13.8h, v6.8h, v2.h[5] 2609be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[6] 2610be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[6] 2611be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[6] 2612be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2613be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[6] 2614be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[6] 2615be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[6] 2616be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[7] 2617be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[7] 2618be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[7] 2619be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[7] 2620be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[7] 2621be168c0dSopenharmony_ci+ fmla v13.8h, v6.8h, v2.h[7] 2622be168c0dSopenharmony_ci+ 2623be168c0dSopenharmony_ci+ subs x14, x14, #8 2624be168c0dSopenharmony_ci+ bge Compute3x16 2625be168c0dSopenharmony_ci+ Compute3x16End4: 2626be168c0dSopenharmony_ci+ adds x14, x14, #8 2627be168c0dSopenharmony_ci+ cbz x14, Compute3x16Return 2628be168c0dSopenharmony_ci+ subs x14, x14, #4 2629be168c0dSopenharmony_ci+ blt Compute3x16EndTail 2630be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 2631be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19], #8 
2632be168c0dSopenharmony_ci+ ld1 {v2.4h}, [x20], #8 2633be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2634be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2635be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2636be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2637be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2638be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2639be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 2640be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 2641be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[0] 2642be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 2643be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[1] 2644be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[1] 2645be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2646be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 2647be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[1] 2648be168c0dSopenharmony_ci+ fmla v13.8h, v6.8h, v2.h[1] 2649be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[2] 2650be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[2] 2651be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[2] 2652be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2653be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[2] 2654be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[2] 2655be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[2] 2656be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[3] 2657be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[3] 2658be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[3] 2659be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[3] 2660be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[3] 2661be168c0dSopenharmony_ci+ fmla v13.8h, v6.8h, v2.h[3] 2662be168c0dSopenharmony_ci+ subs x14, x14, #4 2663be168c0dSopenharmony_ci+ Compute3x16EndTail: 2664be168c0dSopenharmony_ci+ adds x14, x14, #4 2665be168c0dSopenharmony_ci+ cbz x14, Compute3x16Return 2666be168c0dSopenharmony_ci+ cmp x14, #1 2667be168c0dSopenharmony_ci+ beq 
Compute3x16EndTail1 2668be168c0dSopenharmony_ci+ cmp x14, #2 2669be168c0dSopenharmony_ci+ beq Compute3x16EndTail2 2670be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 2671be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19] 2672be168c0dSopenharmony_ci+ ld1 {v2.s}[0], [x20], #4 2673be168c0dSopenharmony_ci+ ld1 {v2.h}[2], [x20] 2674be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2675be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2676be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2677be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2678be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2679be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2680be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 2681be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 2682be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[0] 2683be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 2684be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[1] 2685be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[1] 2686be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2687be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 2688be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[1] 2689be168c0dSopenharmony_ci+ fmla v13.8h, v6.8h, v2.h[1] 2690be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[2] 2691be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[2] 2692be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[2] 2693be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[2] 2694be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[2] 2695be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[2] 2696be168c0dSopenharmony_ci+ b Compute3x16Return 2697be168c0dSopenharmony_ci+ Compute3x16EndTail2: 2698be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 2699be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19] 2700be168c0dSopenharmony_ci+ ld1 {v2.s}[0], [x20] 2701be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2702be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2703be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 
2704be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2705be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2706be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2707be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 2708be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 2709be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[0] 2710be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 2711be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[1] 2712be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[1] 2713be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 2714be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[1] 2715be168c0dSopenharmony_ci+ fmla v13.8h, v6.8h, v2.h[1] 2716be168c0dSopenharmony_ci+ b Compute3x16Return 2717be168c0dSopenharmony_ci+ Compute3x16EndTail1: 2718be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 2719be168c0dSopenharmony_ci+ ld1 {v1.h}[0], [x19] 2720be168c0dSopenharmony_ci+ ld1 {v2.h}[0], [x20] 2721be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2722be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2723be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2724be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2725be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2726be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 2727be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 2728be168c0dSopenharmony_ci+ fmla v13.8h, v4.8h, v2.h[0] 2729be168c0dSopenharmony_ci+ Compute3x16Return: 2730be168c0dSopenharmony_ci+ ret 2731be168c0dSopenharmony_ci+ 2732be168c0dSopenharmony_ci+Compute3x8Unit: 2733be168c0dSopenharmony_ci+ add x19, x10, x16 2734be168c0dSopenharmony_ci+ add x20, x10, x16, lsl #1 2735be168c0dSopenharmony_ci+ subs x14, x14, #8 2736be168c0dSopenharmony_ci+ blt Compute3x8End4 2737be168c0dSopenharmony_ci+ Compute3x8: 2738be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 2739be168c0dSopenharmony_ci+ ld1 {v1.8h}, [x19], #16 2740be168c0dSopenharmony_ci+ ld1 {v2.8h}, [x20], #16 2741be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 
2742be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2743be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2744be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2745be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2746be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2747be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[1] 2748be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[1] 2749be168c0dSopenharmony_ci+ fmla v12.8h, v4.8h, v2.h[1] 2750be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[2] 2751be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[2] 2752be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[2] 2753be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2754be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2755be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[3] 2756be168c0dSopenharmony_ci+ fmla v10.8h, v6.8h, v1.h[3] 2757be168c0dSopenharmony_ci+ fmla v12.8h, v6.8h, v2.h[3] 2758be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[4] 2759be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[4] 2760be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[4] 2761be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2762be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[5] 2763be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[5] 2764be168c0dSopenharmony_ci+ fmla v12.8h, v4.8h, v2.h[5] 2765be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[6] 2766be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[6] 2767be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[6] 2768be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[7] 2769be168c0dSopenharmony_ci+ fmla v10.8h, v6.8h, v1.h[7] 2770be168c0dSopenharmony_ci+ fmla v12.8h, v6.8h, v2.h[7] 2771be168c0dSopenharmony_ci+ 2772be168c0dSopenharmony_ci+ subs x14, x14, #8 2773be168c0dSopenharmony_ci+ bge Compute3x8 2774be168c0dSopenharmony_ci+ Compute3x8End4: 2775be168c0dSopenharmony_ci+ adds x14, x14, #8 2776be168c0dSopenharmony_ci+ cbz x14, Compute3x8Return 2777be168c0dSopenharmony_ci+ subs x14, x14, #4 2778be168c0dSopenharmony_ci+ 
blt Compute3x8EndTail 2779be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 2780be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19], #8 2781be168c0dSopenharmony_ci+ ld1 {v2.4h}, [x20], #8 2782be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2783be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2784be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2785be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2786be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2787be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2788be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[1] 2789be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[1] 2790be168c0dSopenharmony_ci+ fmla v12.8h, v4.8h, v2.h[1] 2791be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[2] 2792be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[2] 2793be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[2] 2794be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[3] 2795be168c0dSopenharmony_ci+ fmla v10.8h, v6.8h, v1.h[3] 2796be168c0dSopenharmony_ci+ fmla v12.8h, v6.8h, v2.h[3] 2797be168c0dSopenharmony_ci+ subs x14, x14, #4 2798be168c0dSopenharmony_ci+ Compute3x8EndTail: 2799be168c0dSopenharmony_ci+ adds x14, x14, #4 2800be168c0dSopenharmony_ci+ cbz x14, Compute3x8Return 2801be168c0dSopenharmony_ci+ cmp x14, #1 2802be168c0dSopenharmony_ci+ beq Compute3x8EndTail1 2803be168c0dSopenharmony_ci+ cmp x14, #2 2804be168c0dSopenharmony_ci+ beq Compute3x8EndTail2 2805be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 2806be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19] 2807be168c0dSopenharmony_ci+ ld1 {v2.s}[0], [x20], #4 2808be168c0dSopenharmony_ci+ ld1 {v2.h}[2], [x20] 2809be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2810be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2811be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2812be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2813be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2814be168c0dSopenharmony_ci+ ld1 {v5.8h}, [x11], #16 2815be168c0dSopenharmony_ci+ fmla v8.8h, 
v4.8h, v0.h[1] 2816be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[1] 2817be168c0dSopenharmony_ci+ fmla v12.8h, v4.8h, v2.h[1] 2818be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[2] 2819be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[2] 2820be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[2] 2821be168c0dSopenharmony_ci+ b Compute3x8Return 2822be168c0dSopenharmony_ci+ Compute3x8EndTail2: 2823be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 2824be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19] 2825be168c0dSopenharmony_ci+ ld2 {v2.h, v3.h}[0], [x20] 2826be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2827be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2828be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[0] 2829be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[0] 2830be168c0dSopenharmony_ci+ fmla v12.8h, v5.8h, v2.h[0] 2831be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[1] 2832be168c0dSopenharmony_ci+ fmla v10.8h, v6.8h, v1.h[1] 2833be168c0dSopenharmony_ci+ fmla v12.8h, v6.8h, v3.h[0] 2834be168c0dSopenharmony_ci+ b Compute3x8Return 2835be168c0dSopenharmony_ci+ Compute3x8EndTail1: 2836be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 2837be168c0dSopenharmony_ci+ ld1 {v1.h}[0], [x19] 2838be168c0dSopenharmony_ci+ ld1 {v2.h}[0], [x20] 2839be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2840be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 2841be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2842be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2843be168c0dSopenharmony_ci+ fmla v12.8h, v3.8h, v2.h[0] 2844be168c0dSopenharmony_ci+ Compute3x8Return: 2845be168c0dSopenharmony_ci+ ret 2846be168c0dSopenharmony_ci+ 2847be168c0dSopenharmony_ci+Compute3x4Unit: 2848be168c0dSopenharmony_ci+ add x19, x10, x16 2849be168c0dSopenharmony_ci+ add x20, x10, x16, lsl #1 2850be168c0dSopenharmony_ci+ subs x14, x14, #8 2851be168c0dSopenharmony_ci+ blt Compute3x4End4 2852be168c0dSopenharmony_ci+ Compute3x4: 2853be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 
2854be168c0dSopenharmony_ci+ ld1 {v1.8h}, [x19], #16 2855be168c0dSopenharmony_ci+ ld1 {v2.8h}, [x20], #16 2856be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2857be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 2858be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 2859be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[0] 2860be168c0dSopenharmony_ci+ fmla v12.4h, v3.4h, v2.h[0] 2861be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 2862be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[1] 2863be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[1] 2864be168c0dSopenharmony_ci+ fmla v12.4h, v4.4h, v2.h[1] 2865be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[2] 2866be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v1.h[2] 2867be168c0dSopenharmony_ci+ fmla v12.4h, v5.4h, v2.h[2] 2868be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2869be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 2870be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[3] 2871be168c0dSopenharmony_ci+ fmla v10.4h, v6.4h, v1.h[3] 2872be168c0dSopenharmony_ci+ fmla v12.4h, v6.4h, v2.h[3] 2873be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[4] 2874be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[4] 2875be168c0dSopenharmony_ci+ fmla v12.4h, v3.4h, v2.h[4] 2876be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 2877be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[5] 2878be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[5] 2879be168c0dSopenharmony_ci+ fmla v12.4h, v4.4h, v2.h[5] 2880be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[6] 2881be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v1.h[6] 2882be168c0dSopenharmony_ci+ fmla v12.4h, v5.4h, v2.h[6] 2883be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[7] 2884be168c0dSopenharmony_ci+ fmla v10.4h, v6.4h, v1.h[7] 2885be168c0dSopenharmony_ci+ fmla v12.4h, v6.4h, v2.h[7] 2886be168c0dSopenharmony_ci+ 2887be168c0dSopenharmony_ci+ subs x14, x14, #8 2888be168c0dSopenharmony_ci+ bge Compute3x4 2889be168c0dSopenharmony_ci+ Compute3x4End4: 
2890be168c0dSopenharmony_ci+ adds x14, x14, #8 2891be168c0dSopenharmony_ci+ cbz x14, Compute3x4Return 2892be168c0dSopenharmony_ci+ subs x14, x14, #4 2893be168c0dSopenharmony_ci+ blt Compute3x4EndTail 2894be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 2895be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19], #8 2896be168c0dSopenharmony_ci+ ld1 {v2.4h}, [x20], #8 2897be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2898be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 2899be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 2900be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[0] 2901be168c0dSopenharmony_ci+ fmla v12.4h, v3.4h, v2.h[0] 2902be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 2903be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[1] 2904be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[1] 2905be168c0dSopenharmony_ci+ fmla v12.4h, v4.4h, v2.h[1] 2906be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[2] 2907be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v1.h[2] 2908be168c0dSopenharmony_ci+ fmla v12.4h, v5.4h, v2.h[2] 2909be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[3] 2910be168c0dSopenharmony_ci+ fmla v10.4h, v6.4h, v1.h[3] 2911be168c0dSopenharmony_ci+ fmla v12.4h, v6.4h, v2.h[3] 2912be168c0dSopenharmony_ci+ subs x14, x14, #4 2913be168c0dSopenharmony_ci+ Compute3x4EndTail: 2914be168c0dSopenharmony_ci+ adds x14, x14, #4 2915be168c0dSopenharmony_ci+ cbz x14, Compute3x4Return 2916be168c0dSopenharmony_ci+ cmp x14, #1 2917be168c0dSopenharmony_ci+ beq Compute3x4EndTail1 2918be168c0dSopenharmony_ci+ cmp x14, #2 2919be168c0dSopenharmony_ci+ beq Compute3x4EndTail2 2920be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 2921be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19] 2922be168c0dSopenharmony_ci+ ld1 {v2.s}[0], [x20], #4 2923be168c0dSopenharmony_ci+ ld1 {v2.h}[2], [x20] 2924be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2925be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 2926be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 2927be168c0dSopenharmony_ci+ 
fmla v10.4h, v3.4h, v1.h[0] 2928be168c0dSopenharmony_ci+ fmla v12.4h, v3.4h, v2.h[0] 2929be168c0dSopenharmony_ci+ ld1 {v5.4h}, [x11], #8 2930be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[1] 2931be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[1] 2932be168c0dSopenharmony_ci+ fmla v12.4h, v4.4h, v2.h[1] 2933be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[2] 2934be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v1.h[2] 2935be168c0dSopenharmony_ci+ fmla v12.4h, v5.4h, v2.h[2] 2936be168c0dSopenharmony_ci+ b Compute3x4Return 2937be168c0dSopenharmony_ci+ Compute3x4EndTail2: 2938be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 2939be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19] 2940be168c0dSopenharmony_ci+ ld2 {v2.h, v3.h}[0], [x20] 2941be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2942be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 2943be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[0] 2944be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v1.h[0] 2945be168c0dSopenharmony_ci+ fmla v12.4h, v5.4h, v2.h[0] 2946be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[1] 2947be168c0dSopenharmony_ci+ fmla v10.4h, v6.4h, v1.h[1] 2948be168c0dSopenharmony_ci+ fmla v12.4h, v6.4h, v3.h[0] 2949be168c0dSopenharmony_ci+ b Compute3x4Return 2950be168c0dSopenharmony_ci+ Compute3x4EndTail1: 2951be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 2952be168c0dSopenharmony_ci+ ld1 {v1.h}[0], [x19] 2953be168c0dSopenharmony_ci+ ld1 {v2.h}[0], [x20] 2954be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2955be168c0dSopenharmony_ci+ ld1 {v3.4h}, [x11], #8 2956be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 2957be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[0] 2958be168c0dSopenharmony_ci+ fmla v12.4h, v3.4h, v2.h[0] 2959be168c0dSopenharmony_ci+ Compute3x4Return: 2960be168c0dSopenharmony_ci+ ret 2961be168c0dSopenharmony_ci+ 2962be168c0dSopenharmony_ci+Compute2x16Unit: 2963be168c0dSopenharmony_ci+ add x19, x10, x16 2964be168c0dSopenharmony_ci+ subs x14, x14, #8 2965be168c0dSopenharmony_ci+ blt 
Compute2x16End4 2966be168c0dSopenharmony_ci+ Compute2x16: 2967be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 2968be168c0dSopenharmony_ci+ ld1 {v1.8h}, [x19], #16 2969be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2970be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2971be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 2972be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 2973be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2974be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 2975be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 2976be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 2977be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[1] 2978be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2979be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 2980be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[1] 2981be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[2] 2982be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[2] 2983be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2984be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[2] 2985be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[2] 2986be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[3] 2987be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[3] 2988be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 2989be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 2990be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[3] 2991be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[3] 2992be168c0dSopenharmony_ci+ 2993be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[4] 2994be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[4] 2995be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 2996be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[4] 2997be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[4] 2998be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[5] 2999be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[5] 3000be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3001be168c0dSopenharmony_ci+ 
fmla v9.8h, v6.8h, v0.h[5] 3002be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[5] 3003be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[6] 3004be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[6] 3005be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3006be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[6] 3007be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[6] 3008be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[7] 3009be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[7] 3010be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[7] 3011be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[7] 3012be168c0dSopenharmony_ci+ 3013be168c0dSopenharmony_ci+ subs x14, x14, #8 3014be168c0dSopenharmony_ci+ bge Compute2x16 3015be168c0dSopenharmony_ci+ Compute2x16End4: 3016be168c0dSopenharmony_ci+ adds x14, x14, #8 3017be168c0dSopenharmony_ci+ cbz x14, Compute2x16Return 3018be168c0dSopenharmony_ci+ subs x14, x14, #4 3019be168c0dSopenharmony_ci+ blt Compute2x16EndTail 3020be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 3021be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19], #8 3022be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3023be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3024be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3025be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 3026be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3027be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3028be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 3029be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 3030be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[1] 3031be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3032be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 3033be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[1] 3034be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[2] 3035be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[2] 3036be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3037be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[2] 
3038be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[2] 3039be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[3] 3040be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[3] 3041be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[3] 3042be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[3] 3043be168c0dSopenharmony_ci+ subs x14, x14, #4 3044be168c0dSopenharmony_ci+ Compute2x16EndTail: 3045be168c0dSopenharmony_ci+ adds x14, x14, #4 3046be168c0dSopenharmony_ci+ cbz x14, Compute2x16Return 3047be168c0dSopenharmony_ci+ cmp x14, #1 3048be168c0dSopenharmony_ci+ beq Compute2x16EndTail1 3049be168c0dSopenharmony_ci+ cmp x14, #2 3050be168c0dSopenharmony_ci+ beq Compute2x16EndTail2 3051be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 3052be168c0dSopenharmony_ci+ ld1 {v1.s}[0], [x19], #4 3053be168c0dSopenharmony_ci+ ld1 {v1.h}[2], [x19] 3054be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3055be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3056be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3057be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 3058be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3059be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3060be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 3061be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 3062be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[1] 3063be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3064be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 3065be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v1.h[1] 3066be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[2] 3067be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[2] 3068be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[2] 3069be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[2] 3070be168c0dSopenharmony_ci+ b Compute2x16Return 3071be168c0dSopenharmony_ci+ Compute2x16EndTail2: 3072be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 3073be168c0dSopenharmony_ci+ ld2 {v1.h, v2.h}[0], [x19] 3074be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 
3075be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3076be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3077be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 3078be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3079be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3080be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 3081be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 3082be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v2.h[0] 3083be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 3084be168c0dSopenharmony_ci+ fmla v11.8h, v6.8h, v2.h[0] 3085be168c0dSopenharmony_ci+ b Compute2x16Return 3086be168c0dSopenharmony_ci+ Compute2x16EndTail1: 3087be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 3088be168c0dSopenharmony_ci+ ld1 {v1.h}[0], [x19] 3089be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3090be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3091be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3092be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 3093be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3094be168c0dSopenharmony_ci+ fmla v11.8h, v4.8h, v1.h[0] 3095be168c0dSopenharmony_ci+ Compute2x16Return: 3096be168c0dSopenharmony_ci+ ret 3097be168c0dSopenharmony_ci+ 3098be168c0dSopenharmony_ci+Compute2x8Unit: 3099be168c0dSopenharmony_ci+ add x19, x10, x16 3100be168c0dSopenharmony_ci+ subs x14, x14, #8 3101be168c0dSopenharmony_ci+ blt Compute2x8End4 3102be168c0dSopenharmony_ci+ Compute2x8: 3103be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 3104be168c0dSopenharmony_ci+ ld1 {v1.8h}, [x19], #16 3105be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3106be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3107be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3108be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 3109be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3110be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[1] 3111be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[1] 3112be168c0dSopenharmony_ci+ fmla 
v8.8h, v5.8h, v0.h[2] 3113be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[2] 3114be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3115be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3116be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[3] 3117be168c0dSopenharmony_ci+ fmla v10.8h, v6.8h, v1.h[3] 3118be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[4] 3119be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[4] 3120be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3121be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[5] 3122be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[5] 3123be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[6] 3124be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[6] 3125be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[7] 3126be168c0dSopenharmony_ci+ fmla v10.8h, v6.8h, v1.h[7] 3127be168c0dSopenharmony_ci+ 3128be168c0dSopenharmony_ci+ subs x14, x14, #8 3129be168c0dSopenharmony_ci+ bge Compute2x8 3130be168c0dSopenharmony_ci+ Compute2x8End4: 3131be168c0dSopenharmony_ci+ adds x14, x14, #8 3132be168c0dSopenharmony_ci+ cbz x14, Compute2x8Return 3133be168c0dSopenharmony_ci+ subs x14, x14, #4 3134be168c0dSopenharmony_ci+ blt Compute2x8EndTail 3135be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 3136be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19], #8 3137be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3138be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3139be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3140be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 3141be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3142be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[1] 3143be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[1] 3144be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[2] 3145be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v1.h[2] 3146be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[3] 3147be168c0dSopenharmony_ci+ fmla v10.8h, v6.8h, v1.h[3] 3148be168c0dSopenharmony_ci+ subs x14, x14, #4 3149be168c0dSopenharmony_ci+ 
Compute2x8EndTail: 3150be168c0dSopenharmony_ci+ adds x14, x14, #4 3151be168c0dSopenharmony_ci+ cbz x14, Compute2x8Return 3152be168c0dSopenharmony_ci+ cmp x14, #1 3153be168c0dSopenharmony_ci+ beq Compute2x8EndTail1 3154be168c0dSopenharmony_ci+ cmp x14, #2 3155be168c0dSopenharmony_ci+ beq Compute2x8EndTail2 3156be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 3157be168c0dSopenharmony_ci+ ld3 {v1.h, v2.h, v3.h}[0], [x19] 3158be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3159be168c0dSopenharmony_ci+ ld1 {v4.8h, v5.8h}, [x11], #32 3160be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[0] 3161be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v1.h[0] 3162be168c0dSopenharmony_ci+ ld1 {v6.8h}, [x11], #16 3163be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 3164be168c0dSopenharmony_ci+ fmla v10.8h, v5.8h, v2.h[0] 3165be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[2] 3166be168c0dSopenharmony_ci+ fmla v10.8h, v6.8h, v3.h[0] 3167be168c0dSopenharmony_ci+ b Compute2x8Return 3168be168c0dSopenharmony_ci+ Compute2x8EndTail2: 3169be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 3170be168c0dSopenharmony_ci+ ld2 {v1.h, v2.h}[0], [x19] 3171be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3172be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3173be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3174be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 3175be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[1] 3176be168c0dSopenharmony_ci+ fmla v10.8h, v4.8h, v2.h[0] 3177be168c0dSopenharmony_ci+ b Compute2x8Return 3178be168c0dSopenharmony_ci+ Compute2x8EndTail1: 3179be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 3180be168c0dSopenharmony_ci+ ld1 {v1.h}[0], [x19] 3181be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3182be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 3183be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3184be168c0dSopenharmony_ci+ fmla v10.8h, v3.8h, v1.h[0] 3185be168c0dSopenharmony_ci+ Compute2x8Return: 3186be168c0dSopenharmony_ci+ ret 3187be168c0dSopenharmony_ci+ 
3188be168c0dSopenharmony_ci+Compute2x4Unit: 3189be168c0dSopenharmony_ci+ add x19, x10, x16 3190be168c0dSopenharmony_ci+ subs x14, x14, #8 3191be168c0dSopenharmony_ci+ blt Compute2x4End4 3192be168c0dSopenharmony_ci+ Compute2x4: 3193be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 3194be168c0dSopenharmony_ci+ ld1 {v1.8h}, [x19], #16 3195be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3196be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 3197be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3198be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[0] 3199be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 3200be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[1] 3201be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[1] 3202be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[2] 3203be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v1.h[2] 3204be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3205be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 3206be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[3] 3207be168c0dSopenharmony_ci+ fmla v10.4h, v6.4h, v1.h[3] 3208be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[4] 3209be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[4] 3210be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 3211be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[5] 3212be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[5] 3213be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[6] 3214be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v1.h[6] 3215be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[7] 3216be168c0dSopenharmony_ci+ fmla v10.4h, v6.4h, v1.h[7] 3217be168c0dSopenharmony_ci+ 3218be168c0dSopenharmony_ci+ subs x14, x14, #8 3219be168c0dSopenharmony_ci+ bge Compute2x4 3220be168c0dSopenharmony_ci+ Compute2x4End4: 3221be168c0dSopenharmony_ci+ adds x14, x14, #8 3222be168c0dSopenharmony_ci+ cbz x14, Compute2x4Return 3223be168c0dSopenharmony_ci+ subs x14, x14, #4 3224be168c0dSopenharmony_ci+ blt Compute2x4EndTail 3225be168c0dSopenharmony_ci+ ld1 {v0.4h}, 
[x10], #8 3226be168c0dSopenharmony_ci+ ld1 {v1.4h}, [x19], #8 3227be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3228be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 3229be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3230be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[0] 3231be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 3232be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[1] 3233be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[1] 3234be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[2] 3235be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v1.h[2] 3236be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[3] 3237be168c0dSopenharmony_ci+ fmla v10.4h, v6.4h, v1.h[3] 3238be168c0dSopenharmony_ci+ subs x14, x14, #4 3239be168c0dSopenharmony_ci+ Compute2x4EndTail: 3240be168c0dSopenharmony_ci+ adds x14, x14, #4 3241be168c0dSopenharmony_ci+ cbz x14, Compute2x4Return 3242be168c0dSopenharmony_ci+ cmp x14, #1 3243be168c0dSopenharmony_ci+ beq Compute2x4EndTail1 3244be168c0dSopenharmony_ci+ cmp x14, #2 3245be168c0dSopenharmony_ci+ beq Compute2x4EndTail2 3246be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 3247be168c0dSopenharmony_ci+ ld3 {v1.h, v2.h, v3.h}[0], [x19] 3248be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3249be168c0dSopenharmony_ci+ ld1 {v4.4h, v5.4h}, [x11], #16 3250be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[0] 3251be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v1.h[0] 3252be168c0dSopenharmony_ci+ ld1 {v6.4h}, [x11], #8 3253be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[1] 3254be168c0dSopenharmony_ci+ fmla v10.4h, v5.4h, v2.h[0] 3255be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[2] 3256be168c0dSopenharmony_ci+ fmla v10.4h, v6.4h, v3.h[0] 3257be168c0dSopenharmony_ci+ b Compute2x4Return 3258be168c0dSopenharmony_ci+ Compute2x4EndTail2: 3259be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10] 3260be168c0dSopenharmony_ci+ ld2 {v1.h, v2.h}[0], [x19] 3261be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3262be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, 
[x11], #16 3263be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3264be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[0] 3265be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[1] 3266be168c0dSopenharmony_ci+ fmla v10.4h, v4.4h, v2.h[0] 3267be168c0dSopenharmony_ci+ b Compute2x4Return 3268be168c0dSopenharmony_ci+ Compute2x4EndTail1: 3269be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 3270be168c0dSopenharmony_ci+ ld1 {v1.h}[0], [x19] 3271be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3272be168c0dSopenharmony_ci+ ld1 {v3.4h}, [x11], #8 3273be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3274be168c0dSopenharmony_ci+ fmla v10.4h, v3.4h, v1.h[0] 3275be168c0dSopenharmony_ci+ Compute2x4Return: 3276be168c0dSopenharmony_ci+ ret 3277be168c0dSopenharmony_ci+ 3278be168c0dSopenharmony_ci+Compute1x16Unit: 3279be168c0dSopenharmony_ci+ subs x14, x14, #8 3280be168c0dSopenharmony_ci+ blt Compute1x16End4 3281be168c0dSopenharmony_ci+ Compute1x16: 3282be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 3283be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3284be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3285be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3286be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3287be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3288be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 3289be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3290be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 3291be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[2] 3292be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3293be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[2] 3294be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[3] 3295be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3296be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3297be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[3] 3298be168c0dSopenharmony_ci+ 3299be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[4] 3300be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, 
[x11], #32 3301be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[4] 3302be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[5] 3303be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3304be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[5] 3305be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[6] 3306be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3307be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[6] 3308be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[7] 3309be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[7] 3310be168c0dSopenharmony_ci+ 3311be168c0dSopenharmony_ci+ subs x14, x14, #8 3312be168c0dSopenharmony_ci+ bge Compute1x16 3313be168c0dSopenharmony_ci+ Compute1x16End4: 3314be168c0dSopenharmony_ci+ adds x14, x14, #8 3315be168c0dSopenharmony_ci+ cbz x14, Compute1x16Return 3316be168c0dSopenharmony_ci+ subs x14, x14, #4 3317be168c0dSopenharmony_ci+ blt Compute1x16EndTail 3318be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 3319be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3320be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3321be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3322be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3323be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3324be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[1] 3325be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3326be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[1] 3327be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[2] 3328be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3329be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[2] 3330be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[3] 3331be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v0.h[3] 3332be168c0dSopenharmony_ci+ subs x14, x14, #4 3333be168c0dSopenharmony_ci+ Compute1x16EndTail: 3334be168c0dSopenharmony_ci+ adds x14, x14, #4 3335be168c0dSopenharmony_ci+ cbz x14, Compute1x16Return 3336be168c0dSopenharmony_ci+ cmp x14, #1 3337be168c0dSopenharmony_ci+ beq Compute1x16EndTail1 
3338be168c0dSopenharmony_ci+ cmp x14, #2 3339be168c0dSopenharmony_ci+ beq Compute1x16EndTail2 3340be168c0dSopenharmony_ci+ ld3 {v0.h, v1.h, v2.h}[0], [x10] 3341be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3342be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3343be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3344be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3345be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3346be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v1.h[0] 3347be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3348be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v1.h[0] 3349be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v2.h[0] 3350be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v2.h[0] 3351be168c0dSopenharmony_ci+ b Compute1x16Return 3352be168c0dSopenharmony_ci+ Compute1x16EndTail2: 3353be168c0dSopenharmony_ci+ ld2 {v0.h, v1.h}[0], [x10] 3354be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3355be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3356be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3357be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3358be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3359be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v1.h[0] 3360be168c0dSopenharmony_ci+ fmla v9.8h, v6.8h, v1.h[0] 3361be168c0dSopenharmony_ci+ b Compute1x16Return 3362be168c0dSopenharmony_ci+ Compute1x16EndTail1: 3363be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 3364be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3365be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3366be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3367be168c0dSopenharmony_ci+ fmla v9.8h, v4.8h, v0.h[0] 3368be168c0dSopenharmony_ci+ Compute1x16Return: 3369be168c0dSopenharmony_ci+ ret 3370be168c0dSopenharmony_ci+ 3371be168c0dSopenharmony_ci+Compute1x8Unit: 3372be168c0dSopenharmony_ci+ subs x14, x14, #8 3373be168c0dSopenharmony_ci+ blt Compute1x8End4 3374be168c0dSopenharmony_ci+ Compute1x8: 3375be168c0dSopenharmony_ci+ ld1 {v0.8h}, 
[x10], #16 3376be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3377be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3378be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3379be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3380be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[1] 3381be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[2] 3382be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3383be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3384be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[3] 3385be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[4] 3386be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3387be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[5] 3388be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[6] 3389be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[7] 3390be168c0dSopenharmony_ci+ 3391be168c0dSopenharmony_ci+ subs x14, x14, #8 3392be168c0dSopenharmony_ci+ bge Compute1x8 3393be168c0dSopenharmony_ci+ Compute1x8End4: 3394be168c0dSopenharmony_ci+ adds x14, x14, #8 3395be168c0dSopenharmony_ci+ cbz x14, Compute1x8Return 3396be168c0dSopenharmony_ci+ subs x14, x14, #4 3397be168c0dSopenharmony_ci+ blt Compute1x8EndTail 3398be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 3399be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3400be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3401be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3402be168c0dSopenharmony_ci+ ld1 {v5.8h, v6.8h}, [x11], #32 3403be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v0.h[1] 3404be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v0.h[2] 3405be168c0dSopenharmony_ci+ fmla v8.8h, v6.8h, v0.h[3] 3406be168c0dSopenharmony_ci+ subs x14, x14, #4 3407be168c0dSopenharmony_ci+ Compute1x8EndTail: 3408be168c0dSopenharmony_ci+ adds x14, x14, #4 3409be168c0dSopenharmony_ci+ cbz x14, Compute1x8Return 3410be168c0dSopenharmony_ci+ cmp x14, #1 3411be168c0dSopenharmony_ci+ beq Compute1x8EndTail1 3412be168c0dSopenharmony_ci+ cmp x14, #2 3413be168c0dSopenharmony_ci+ beq 
Compute1x8EndTail2 3414be168c0dSopenharmony_ci+ ld3 {v0.h, v1.h, v2.h}[0], [x10] 3415be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3416be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3417be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3418be168c0dSopenharmony_ci+ ld1 {v5.8h}, [x11], #16 3419be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v1.h[0] 3420be168c0dSopenharmony_ci+ fmla v8.8h, v5.8h, v2.h[0] 3421be168c0dSopenharmony_ci+ b Compute1x8Return 3422be168c0dSopenharmony_ci+ Compute1x8EndTail2: 3423be168c0dSopenharmony_ci+ ld2 {v0.h, v1.h}[0], [x10] 3424be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3425be168c0dSopenharmony_ci+ ld1 {v3.8h, v4.8h}, [x11], #32 3426be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3427be168c0dSopenharmony_ci+ fmla v8.8h, v4.8h, v1.h[0] 3428be168c0dSopenharmony_ci+ b Compute1x8Return 3429be168c0dSopenharmony_ci+ Compute1x8EndTail1: 3430be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 3431be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3432be168c0dSopenharmony_ci+ ld1 {v3.8h}, [x11], #16 3433be168c0dSopenharmony_ci+ fmla v8.8h, v3.8h, v0.h[0] 3434be168c0dSopenharmony_ci+ Compute1x8Return: 3435be168c0dSopenharmony_ci+ ret 3436be168c0dSopenharmony_ci+ 3437be168c0dSopenharmony_ci+Compute1x4Unit: 3438be168c0dSopenharmony_ci+ subs x14, x14, #8 3439be168c0dSopenharmony_ci+ blt Compute1x4End4 3440be168c0dSopenharmony_ci+ Compute1x4: 3441be168c0dSopenharmony_ci+ ld1 {v0.8h}, [x10], #16 3442be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3443be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 3444be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3445be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 3446be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[1] 3447be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[2] 3448be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3449be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 3450be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[3] 
3451be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[4] 3452be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 3453be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[5] 3454be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[6] 3455be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[7] 3456be168c0dSopenharmony_ci+ 3457be168c0dSopenharmony_ci+ subs x14, x14, #8 3458be168c0dSopenharmony_ci+ bge Compute1x4 3459be168c0dSopenharmony_ci+ Compute1x4End4: 3460be168c0dSopenharmony_ci+ adds x14, x14, #8 3461be168c0dSopenharmony_ci+ cbz x14, Compute1x4Return 3462be168c0dSopenharmony_ci+ subs x14, x14, #4 3463be168c0dSopenharmony_ci+ blt Compute1x4EndTail 3464be168c0dSopenharmony_ci+ ld1 {v0.4h}, [x10], #8 3465be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3466be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 3467be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3468be168c0dSopenharmony_ci+ ld1 {v5.4h, v6.4h}, [x11], #16 3469be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v0.h[1] 3470be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v0.h[2] 3471be168c0dSopenharmony_ci+ fmla v8.4h, v6.4h, v0.h[3] 3472be168c0dSopenharmony_ci+ subs x14, x14, #4 3473be168c0dSopenharmony_ci+ Compute1x4EndTail: 3474be168c0dSopenharmony_ci+ adds x14, x14, #4 3475be168c0dSopenharmony_ci+ cbz x14, Compute1x4Return 3476be168c0dSopenharmony_ci+ cmp x14, #1 3477be168c0dSopenharmony_ci+ beq Compute1x4EndTail1 3478be168c0dSopenharmony_ci+ cmp x14, #2 3479be168c0dSopenharmony_ci+ beq Compute1x4EndTail2 3480be168c0dSopenharmony_ci+ ld3 {v0.h, v1.h, v2.h}[0], [x10] 3481be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3482be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 3483be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3484be168c0dSopenharmony_ci+ ld1 {v5.4h}, [x11], #8 3485be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v1.h[0] 3486be168c0dSopenharmony_ci+ fmla v8.4h, v5.4h, v2.h[0] 3487be168c0dSopenharmony_ci+ b Compute1x4Return 3488be168c0dSopenharmony_ci+ Compute1x4EndTail2: 
3489be168c0dSopenharmony_ci+ ld2 {v0.h, v1.h}[0], [x10] 3490be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3491be168c0dSopenharmony_ci+ ld1 {v3.4h, v4.4h}, [x11], #16 3492be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3493be168c0dSopenharmony_ci+ fmla v8.4h, v4.4h, v1.h[0] 3494be168c0dSopenharmony_ci+ b Compute1x4Return 3495be168c0dSopenharmony_ci+ Compute1x4EndTail1: 3496be168c0dSopenharmony_ci+ ld1 {v0.h}[0], [x10] 3497be168c0dSopenharmony_ci+ prfm pldl1strm, [x11, #632] 3498be168c0dSopenharmony_ci+ ld1 {v3.4h}, [x11], #8 3499be168c0dSopenharmony_ci+ fmla v8.4h, v3.4h, v0.h[0] 3500be168c0dSopenharmony_ci+ Compute1x4Return: 3501be168c0dSopenharmony_ci+ ret 3502be168c0dSopenharmony_ci+ 3503be168c0dSopenharmony_ci+End: 3504be168c0dSopenharmony_ci+ sub sp, sp, #192 3505be168c0dSopenharmony_ci+ ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 3506be168c0dSopenharmony_ci+ ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 3507be168c0dSopenharmony_ci+ ldp x19, x20, [sp], #16 3508be168c0dSopenharmony_ci+ ldp x21, x22, [sp], #16 3509be168c0dSopenharmony_ci+ ldp x23, x24, [sp], #16 3510be168c0dSopenharmony_ci+ ldp x29, x30, [sp], #16 3511be168c0dSopenharmony_ci+ ret 3512be168c0dSopenharmony_ci+#endif 3513be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gather_d_grad_v2_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gather_d_grad_v2_parameter.h 3514be168c0dSopenharmony_cinew file mode 100644 3515be168c0dSopenharmony_ciindex 00000000..541c7ff1 3516be168c0dSopenharmony_ci--- /dev/null 3517be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_gather_d_grad_v2_parameter.h 3518be168c0dSopenharmony_ci@@ -0,0 +1,28 @@ 3519be168c0dSopenharmony_ci+/** 3520be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 3521be168c0dSopenharmony_ci+ * 3522be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 3523be168c0dSopenharmony_ci+ * you may 
not use this file except in compliance with the License. 3524be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 3525be168c0dSopenharmony_ci+ * 3526be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 3527be168c0dSopenharmony_ci+ * 3528be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 3529be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 3530be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3531be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 3532be168c0dSopenharmony_ci+ * limitations under the License. 3533be168c0dSopenharmony_ci+ */ 3534be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_PARAMETER_H_ 3535be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_PARAMETER_H_ 3536be168c0dSopenharmony_ci+ 3537be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 3538be168c0dSopenharmony_ci+ 3539be168c0dSopenharmony_ci+typedef struct CustomGatherGradV2Parameter { 3540be168c0dSopenharmony_ci+ // Primitive parameter 3541be168c0dSopenharmony_ci+ OpParameter op_parameter_; 3542be168c0dSopenharmony_ci+ // shape correlative 3543be168c0dSopenharmony_ci+ int dim; 3544be168c0dSopenharmony_ci+} CustomGatherGradV2Parameter; 3545be168c0dSopenharmony_ci+ 3546be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_PARAMETER_H_ 3547be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c 3548be168c0dSopenharmony_ciindex 6e754569..72391811 100644 3549be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c 3550be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/custom_gru_fp16.c 3551be168c0dSopenharmony_ci@@ -35,13 +35,13 @@ void 
CustomGruFp16(float16_t *output, const float16_t *input, const float16_t *w 3552be168c0dSopenharmony_ci float16_t *hidden_gate = buffer[C3NUM]; 3553be168c0dSopenharmony_ci for (int i = 0; i < num_step; ++i) { 3554be168c0dSopenharmony_ci if (batch_size != 1) { 3555be168c0dSopenharmony_ci- RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size); 3556be168c0dSopenharmony_ci+ RowMajor2ColNMajorFp16(input + i * batch_size * input_size, buffer[0], batch_size, input_size, false); 3557be168c0dSopenharmony_ci for (int j = 0; j < C3NUM; ++j) { 3558be168c0dSopenharmony_ci MatmulBaseFp16Neon(buffer[0], weight_input + j * weight_in_offset, input_gate + j * output_size, 3559be168c0dSopenharmony_ci bias_input + j * col_align, ActType_No, input_size, batch_size, hidden_size, hidden_size, 3560be168c0dSopenharmony_ci OutType_Nhwc); 3561be168c0dSopenharmony_ci } 3562be168c0dSopenharmony_ci- RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size); 3563be168c0dSopenharmony_ci+ RowMajor2ColNMajorFp16(init_h, buffer[C2NUM], batch_size, hidden_size, false); 3564be168c0dSopenharmony_ci for (int j = 0; j < C3NUM; ++j) { 3565be168c0dSopenharmony_ci MatmulBaseFp16Neon(buffer[C2NUM], weight_hidden + j * weight_hidden_offset, hidden_gate + j * output_size, 3566be168c0dSopenharmony_ci bias_hidden + j * col_align, ActType_No, hidden_size, batch_size, hidden_size, hidden_size, 3567be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/exp_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/exp_fp16.c 3568be168c0dSopenharmony_ciindex d1555953..93f005c8 100644 3569be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/exp_fp16.c 3570be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/exp_fp16.c 3571be168c0dSopenharmony_ci@@ -20,8 +20,10 @@ 3572be168c0dSopenharmony_ci 3573be168c0dSopenharmony_ci #if defined(ENABLE_NEON) 3574be168c0dSopenharmony_ci 
static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) { 3575be168c0dSopenharmony_ci- static float16x8_t maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f}; 3576be168c0dSopenharmony_ci- static float16x8_t minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f}; 3577be168c0dSopenharmony_ci+ static float16x8_t maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 3578be168c0dSopenharmony_ci+ 88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; 3579be168c0dSopenharmony_ci+ static float16x8_t minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, 3580be168c0dSopenharmony_ci+ -87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; 3581be168c0dSopenharmony_ci input = vmaxq_f16(minv, vminq_f16(input, maxv)); 3582be168c0dSopenharmony_ci vst1q_f16(dst, VexpFp16(input)); 3583be168c0dSopenharmony_ci } 3584be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.c 3585be168c0dSopenharmony_ciindex 813237fa..614842a1 100644 3586be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.c 3587be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.c 3588be168c0dSopenharmony_ci@@ -23,28 +23,38 @@ 3589be168c0dSopenharmony_ci #include "nnacl/fp16/cast_fp16.h" 3590be168c0dSopenharmony_ci #include "nnacl/intrinsics/ms_simd_instructions_fp16.h" 3591be168c0dSopenharmony_ci 3592be168c0dSopenharmony_ci-void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) { 3593be168c0dSopenharmony_ci+void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, 3594be168c0dSopenharmony_ci+ const int32_t *order) { 3595be168c0dSopenharmony_ci for (int i = 0; i < batch; i++) { 
3596be168c0dSopenharmony_ci const float *src_batch = src + i * col * deep; 3597be168c0dSopenharmony_ci- float16_t *dst_batch = dst + i * col_align * deep; 3598be168c0dSopenharmony_ci+ float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align * deep; 3599be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 3600be168c0dSopenharmony_ci+ RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, true); 3601be168c0dSopenharmony_ci+#else 3602be168c0dSopenharmony_ci RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, true); 3603be168c0dSopenharmony_ci+#endif 3604be168c0dSopenharmony_ci } 3605be168c0dSopenharmony_ci } 3606be168c0dSopenharmony_ci 3607be168c0dSopenharmony_ci-void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align) { 3608be168c0dSopenharmony_ci+void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align, 3609be168c0dSopenharmony_ci+ const int32_t *order) { 3610be168c0dSopenharmony_ci for (int i = 0; i < batch; i++) { 3611be168c0dSopenharmony_ci const float16_t *src_batch = src + i * col * deep; 3612be168c0dSopenharmony_ci- float16_t *dst_batch = dst + i * col_align * deep; 3613be168c0dSopenharmony_ci+ float16_t *dst_batch = dst + (order == NULL ? 
i : order[i]) * col_align * deep; 3614be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 3615be168c0dSopenharmony_ci+ RowMajor2ColNMajorFp16(src_batch, dst_batch, col, deep, false); 3616be168c0dSopenharmony_ci+#else 3617be168c0dSopenharmony_ci RowMajor2Col8MajorFp16(src_batch, dst_batch, col, deep, false); 3618be168c0dSopenharmony_ci+#endif 3619be168c0dSopenharmony_ci } 3620be168c0dSopenharmony_ci } 3621be168c0dSopenharmony_ci 3622be168c0dSopenharmony_ci-void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, 3623be168c0dSopenharmony_ci- bool is_bidirectional) { 3624be168c0dSopenharmony_ci+void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 3625be168c0dSopenharmony_ci+ const int32_t *order) { 3626be168c0dSopenharmony_ci int unidirectional_batch = is_bidirectional ? batch / 2 : batch; 3627be168c0dSopenharmony_ci for (int i = 0; i < unidirectional_batch; i++) { 3628be168c0dSopenharmony_ci const float *src_batch = src + i * col; 3629be168c0dSopenharmony_ci- float16_t *dst_batch = dst + i * col_align; 3630be168c0dSopenharmony_ci+ float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align; 3631be168c0dSopenharmony_ci Float32ToFloat16(src_batch, dst_batch, col); 3632be168c0dSopenharmony_ci } 3633be168c0dSopenharmony_ci if (is_bidirectional) { 3634be168c0dSopenharmony_ci@@ -52,17 +62,18 @@ void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col 3635be168c0dSopenharmony_ci float16_t *backward_dst = dst + unidirectional_batch * col_align; 3636be168c0dSopenharmony_ci for (int i = 0; i < unidirectional_batch; i++) { 3637be168c0dSopenharmony_ci const float *backward_src_batch = backward_src + i * col; 3638be168c0dSopenharmony_ci- float16_t *backward_dst_batch = backward_dst + i * col_align; 3639be168c0dSopenharmony_ci+ float16_t *backward_dst_batch = backward_dst + (order == NULL ? 
i : order[i]) * col_align; 3640be168c0dSopenharmony_ci Float32ToFloat16(backward_src_batch, backward_dst_batch, col); 3641be168c0dSopenharmony_ci } 3642be168c0dSopenharmony_ci } 3643be168c0dSopenharmony_ci } 3644be168c0dSopenharmony_ci 3645be168c0dSopenharmony_ci-void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional) { 3646be168c0dSopenharmony_ci+void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, 3647be168c0dSopenharmony_ci+ const int32_t *order) { 3648be168c0dSopenharmony_ci int unidirectional_batch = is_bidirectional ? batch / 2 : batch; 3649be168c0dSopenharmony_ci for (int i = 0; i < unidirectional_batch; i++) { 3650be168c0dSopenharmony_ci const float16_t *src_batch = src + i * col; 3651be168c0dSopenharmony_ci- float16_t *dst_batch = dst + i * col_align; 3652be168c0dSopenharmony_ci+ float16_t *dst_batch = dst + (order == NULL ? i : order[i]) * col_align; 3653be168c0dSopenharmony_ci (void)memcpy(dst_batch, src_batch, col * sizeof(float16_t)); 3654be168c0dSopenharmony_ci } 3655be168c0dSopenharmony_ci if (is_bidirectional) { 3656be168c0dSopenharmony_ci@@ -70,7 +81,7 @@ void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, 3657be168c0dSopenharmony_ci float16_t *backward_dst = dst + unidirectional_batch * col_align; 3658be168c0dSopenharmony_ci for (int i = 0; i < unidirectional_batch; i++) { 3659be168c0dSopenharmony_ci const float16_t *backward_src_batch = backward_src + i * col; 3660be168c0dSopenharmony_ci- float16_t *backward_dst_batch = backward_dst + i * col_align; 3661be168c0dSopenharmony_ci+ float16_t *backward_dst_batch = backward_dst + (order == NULL ? 
i : order[i]) * col_align; 3662be168c0dSopenharmony_ci (void)memcpy(backward_dst_batch, backward_src_batch, col * sizeof(float16_t)); 3663be168c0dSopenharmony_ci } 3664be168c0dSopenharmony_ci } 3665be168c0dSopenharmony_ci@@ -152,13 +163,13 @@ void UpdateOutputFp16(float16_t *hidden_state, float16_t *output, const float16_ 3666be168c0dSopenharmony_ci const LstmParameter *lstm_param) { 3667be168c0dSopenharmony_ci int batch = lstm_param->batch_; 3668be168c0dSopenharmony_ci int hidden_size = lstm_param->hidden_size_; 3669be168c0dSopenharmony_ci- int project_size = lstm_param->project_size_; 3670be168c0dSopenharmony_ci+ int output_size = lstm_param->output_size_; 3671be168c0dSopenharmony_ci float16_t *state_buffer = buffer[C5NUM]; 3672be168c0dSopenharmony_ci float16_t *hidden_buffer = weight_project ? buffer[C3NUM] : hidden_state; 3673be168c0dSopenharmony_ci float16_t zoneout = lstm_param->zoneout_hidden_; 3674be168c0dSopenharmony_ci if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { 3675be168c0dSopenharmony_ci- (void)memcpy(state_buffer, hidden_state, batch * project_size * sizeof(float16_t)); 3676be168c0dSopenharmony_ci- ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * project_size, false); 3677be168c0dSopenharmony_ci+ (void)memcpy(state_buffer, hidden_state, batch * output_size * sizeof(float16_t)); 3678be168c0dSopenharmony_ci+ ElementOptMulFp16(state_buffer, &zoneout, state_buffer, batch * output_size, false); 3679be168c0dSopenharmony_ci } 3680be168c0dSopenharmony_ci 3681be168c0dSopenharmony_ci TanhFp16(cell_state, hidden_buffer, batch * hidden_size); 3682be168c0dSopenharmony_ci@@ -166,19 +177,32 @@ void UpdateOutputFp16(float16_t *hidden_state, float16_t *output, const float16_ 3683be168c0dSopenharmony_ci 3684be168c0dSopenharmony_ci if (weight_project) { 3685be168c0dSopenharmony_ci float16_t *left_matrix = hidden_buffer; 3686be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 3687be168c0dSopenharmony_ci+ if (batch >= C4NUM) { 
3688be168c0dSopenharmony_ci+ left_matrix = buffer[C6NUM]; 3689be168c0dSopenharmony_ci+ RowMajor2ColLadder12MajorFp16(hidden_buffer, left_matrix, batch, hidden_size); 3690be168c0dSopenharmony_ci+ } 3691be168c0dSopenharmony_ci+#else 3692be168c0dSopenharmony_ci if (batch != 1) { 3693be168c0dSopenharmony_ci left_matrix = buffer[C6NUM]; 3694be168c0dSopenharmony_ci RowMajor2Col16MajorFp16(hidden_buffer, left_matrix, batch, hidden_size, false); 3695be168c0dSopenharmony_ci } 3696be168c0dSopenharmony_ci- LstmMatMulFp16(hidden_state, left_matrix, weight_project, project_bias, batch, hidden_size, project_size, 3697be168c0dSopenharmony_ci+#endif 3698be168c0dSopenharmony_ci+ LstmMatMulFp16(hidden_state, left_matrix, weight_project, project_bias, batch, hidden_size, output_size, 3699be168c0dSopenharmony_ci batch == 1); 3700be168c0dSopenharmony_ci } 3701be168c0dSopenharmony_ci if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { 3702be168c0dSopenharmony_ci- ElementOptMulAccFp16(hidden_state, 1 - zoneout, state_buffer, batch * project_size); 3703be168c0dSopenharmony_ci+ ElementOptMulAccFp16(hidden_state, 1 - zoneout, state_buffer, batch * output_size); 3704be168c0dSopenharmony_ci } 3705be168c0dSopenharmony_ci- (void)memcpy(output, hidden_state, batch * project_size * sizeof(float16_t)); 3706be168c0dSopenharmony_ci+ (void)memcpy(output, hidden_state, batch * output_size * sizeof(float16_t)); 3707be168c0dSopenharmony_ci } 3708be168c0dSopenharmony_ci 3709be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 3710be168c0dSopenharmony_ci+void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, 3711be168c0dSopenharmony_ci+ int col, bool is_vec) { 3712be168c0dSopenharmony_ci+ MatmulFp16OptV2(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); 3713be168c0dSopenharmony_ci+} 3714be168c0dSopenharmony_ci+#else 3715be168c0dSopenharmony_ci void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t 
*bias, int row, int deep, 3716be168c0dSopenharmony_ci int col, bool is_vec) { 3717be168c0dSopenharmony_ci if (is_vec) { 3718be168c0dSopenharmony_ci@@ -188,11 +212,12 @@ void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const 3719be168c0dSopenharmony_ci MatMulFp16(a, b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc); 3720be168c0dSopenharmony_ci } 3721be168c0dSopenharmony_ci } 3722be168c0dSopenharmony_ci+#endif 3723be168c0dSopenharmony_ci 3724be168c0dSopenharmony_ci void UpdateLstmGateFp16(float16_t *gate_buffer, const float16_t *input, const float16_t *weight, const float16_t *bias, 3725be168c0dSopenharmony_ci int row, int deep, int col, int col_align, bool is_vec) { 3726be168c0dSopenharmony_ci for (int i = 0; i < 4; i++) { 3727be168c0dSopenharmony_ci- const float16_t *weight_i = weight + deep * col * i; 3728be168c0dSopenharmony_ci+ const float16_t *weight_i = weight + deep * col_align * i; 3729be168c0dSopenharmony_ci const float16_t *bias_i = bias + col_align * i; 3730be168c0dSopenharmony_ci float16_t *gate = gate_buffer + row * col * i; 3731be168c0dSopenharmony_ci LstmMatMulFp16(gate, input, weight_i, bias_i, row, deep, col, is_vec); 3732be168c0dSopenharmony_ci@@ -207,16 +232,26 @@ void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forge 3733be168c0dSopenharmony_ci float16_t *state_gate = buffer[C3NUM]; 3734be168c0dSopenharmony_ci float16_t *cell_buffer = buffer[C4NUM]; 3735be168c0dSopenharmony_ci float16_t *hidden_buffer = buffer[C5NUM]; 3736be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 3737be168c0dSopenharmony_ci+ if (lstm_param->batch_ <= C3NUM) { 3738be168c0dSopenharmony_ci+ UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 3739be168c0dSopenharmony_ci+ lstm_param->hidden_size_, lstm_param->state_col_align_, false); 3740be168c0dSopenharmony_ci+ } else { 3741be168c0dSopenharmony_ci+ RowMajor2ColLadder12MajorFp16(hidden_state, packed_state, 
lstm_param->batch_, lstm_param->output_size_); 3742be168c0dSopenharmony_ci+ UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 3743be168c0dSopenharmony_ci+ lstm_param->hidden_size_, lstm_param->state_col_align_, false); 3744be168c0dSopenharmony_ci+ } 3745be168c0dSopenharmony_ci+#else 3746be168c0dSopenharmony_ci bool is_vec = lstm_param->batch_ == 1; 3747be168c0dSopenharmony_ci if (is_vec) { 3748be168c0dSopenharmony_ci- UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, 3749be168c0dSopenharmony_ci- lstm_param->project_size_, lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); 3750be168c0dSopenharmony_ci+ UpdateLstmGateFp16(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 3751be168c0dSopenharmony_ci+ lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); 3752be168c0dSopenharmony_ci } else { 3753be168c0dSopenharmony_ci- // pack state for matmul 3754be168c0dSopenharmony_ci- RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->project_size_, false); 3755be168c0dSopenharmony_ci- UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, 3756be168c0dSopenharmony_ci- lstm_param->project_size_, lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); 3757be168c0dSopenharmony_ci+ RowMajor2Col16MajorFp16(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_, false); 3758be168c0dSopenharmony_ci+ UpdateLstmGateFp16(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 3759be168c0dSopenharmony_ci+ lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec); 3760be168c0dSopenharmony_ci } 3761be168c0dSopenharmony_ci+#endif 3762be168c0dSopenharmony_ci ElementAddFp16(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); 
3763be168c0dSopenharmony_ci ElementAddFp16(forget_gate, state_gate + lstm_param->batch_ * lstm_param->hidden_size_ * 2, forget_gate, 3764be168c0dSopenharmony_ci lstm_param->batch_ * lstm_param->hidden_size_); 3765be168c0dSopenharmony_ci@@ -247,24 +282,43 @@ void LstmStepUnitFp16(float16_t *output, float16_t *input_gate, float16_t *forge 3766be168c0dSopenharmony_ci } 3767be168c0dSopenharmony_ci 3768be168c0dSopenharmony_ci if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { 3769be168c0dSopenharmony_ci- (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->project_size_ * sizeof(float16_t)); 3770be168c0dSopenharmony_ci+ (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float16_t)); 3771be168c0dSopenharmony_ci } 3772be168c0dSopenharmony_ci } 3773be168c0dSopenharmony_ci 3774be168c0dSopenharmony_ci-void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i, 3775be168c0dSopenharmony_ci- const float16_t *weight_h, const float16_t *input_bias, const float16_t *state_bias, 3776be168c0dSopenharmony_ci- const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, 3777be168c0dSopenharmony_ci- float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param, 3778be168c0dSopenharmony_ci- bool is_backward) { 3779be168c0dSopenharmony_ci- float16_t *gate = buffer[1]; 3780be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 3781be168c0dSopenharmony_ci+void LstmGateCompute(float16_t *gate, const float16_t *input, const float16_t *weight_i, const float16_t *input_bias, 3782be168c0dSopenharmony_ci+ const LstmParameter *lstm_param) { 3783be168c0dSopenharmony_ci+ int row_input = lstm_param->seq_len_ * lstm_param->batch_; 3784be168c0dSopenharmony_ci+ for (int i = 0; i < C4NUM; i++) { 3785be168c0dSopenharmony_ci+ const float16_t *weight_loop = weight_i + lstm_param->input_size_ * 
lstm_param->input_col_align_ * i; 3786be168c0dSopenharmony_ci+ const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i; 3787be168c0dSopenharmony_ci+ float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; 3788be168c0dSopenharmony_ci+ MatmulFp16OptV2(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, row_input, 3789be168c0dSopenharmony_ci+ lstm_param->hidden_size_, lstm_param->hidden_size_, OutType_Nhwc); 3790be168c0dSopenharmony_ci+ } 3791be168c0dSopenharmony_ci+} 3792be168c0dSopenharmony_ci+#else 3793be168c0dSopenharmony_ci+void LstmGateCompute(float16_t *gate, const float16_t *input, const float16_t *weight_i, const float16_t *input_bias, 3794be168c0dSopenharmony_ci+ const LstmParameter *lstm_param) { 3795be168c0dSopenharmony_ci for (int i = 0; i < C4NUM; i++) { 3796be168c0dSopenharmony_ci const float16_t *weight_loop = weight_i + lstm_param->input_size_ * lstm_param->input_col_align_ * i; 3797be168c0dSopenharmony_ci const float16_t *bias_loop = input_bias + lstm_param->input_col_align_ * i; 3798be168c0dSopenharmony_ci float16_t *gate_loop = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * i; 3799be168c0dSopenharmony_ci- MatMulFp16(packed_input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, 3800be168c0dSopenharmony_ci+ MatMulFp16(input, weight_loop, gate_loop, bias_loop, ActType_No, lstm_param->input_size_, 3801be168c0dSopenharmony_ci lstm_param->seq_len_ * lstm_param->batch_, lstm_param->hidden_size_, lstm_param->hidden_size_, 3802be168c0dSopenharmony_ci OutType_Nhwc); 3803be168c0dSopenharmony_ci } 3804be168c0dSopenharmony_ci+} 3805be168c0dSopenharmony_ci+#endif 3806be168c0dSopenharmony_ci+ 3807be168c0dSopenharmony_ci+void LstmUnidirectionalFp16(float16_t *output, const float16_t *packed_input, const float16_t *weight_i, 3808be168c0dSopenharmony_ci+ const float16_t *weight_h, const float16_t *input_bias, const 
float16_t *state_bias, 3809be168c0dSopenharmony_ci+ const float16_t *weight_project, const float16_t *project_bias, float16_t *hidden_state, 3810be168c0dSopenharmony_ci+ float16_t *cell_state, float16_t *buffer[C7NUM], const LstmParameter *lstm_param, 3811be168c0dSopenharmony_ci+ bool is_backward) { 3812be168c0dSopenharmony_ci+ float16_t *gate = buffer[1]; 3813be168c0dSopenharmony_ci+ LstmGateCompute(gate, packed_input, weight_i, input_bias, lstm_param); 3814be168c0dSopenharmony_ci 3815be168c0dSopenharmony_ci float16_t *input_gate = gate; 3816be168c0dSopenharmony_ci float16_t *forget_gate = gate + lstm_param->seq_len_ * lstm_param->batch_ * lstm_param->hidden_size_ * 2; 3817be168c0dSopenharmony_ci@@ -287,26 +341,33 @@ void LstmFp16(float16_t *output, const float16_t *input, const float16_t *weight 3818be168c0dSopenharmony_ci const float16_t *project_bias, float16_t *hidden_state, float16_t *cell_state, float16_t *buffer[C7NUM], 3819be168c0dSopenharmony_ci const LstmParameter *lstm_param) { 3820be168c0dSopenharmony_ci // forward 3821be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 3822be168c0dSopenharmony_ci+ const float16_t *packed_input = input; 3823be168c0dSopenharmony_ci+ if (lstm_param->batch_ * lstm_param->seq_len_ >= C4NUM) { 3824be168c0dSopenharmony_ci+ float16_t *temp_input = buffer[0]; 3825be168c0dSopenharmony_ci+ RowMajor2ColLadder12MajorFp16(input, temp_input, lstm_param->seq_len_ * lstm_param->batch_, 3826be168c0dSopenharmony_ci+ lstm_param->input_size_); 3827be168c0dSopenharmony_ci+ packed_input = temp_input; 3828be168c0dSopenharmony_ci+ } 3829be168c0dSopenharmony_ci+#else 3830be168c0dSopenharmony_ci float16_t *packed_input = buffer[0]; 3831be168c0dSopenharmony_ci RowMajor2Col16MajorFp16(input, packed_input, lstm_param->seq_len_ * lstm_param->batch_, lstm_param->input_size_, 3832be168c0dSopenharmony_ci false); 3833be168c0dSopenharmony_ci+#endif 3834be168c0dSopenharmony_ci LstmUnidirectionalFp16(output, packed_input, weight_i, weight_h, input_bias, 
state_bias, weight_project, project_bias, 3835be168c0dSopenharmony_ci hidden_state, cell_state, buffer, lstm_param, false); 3836be168c0dSopenharmony_ci 3837be168c0dSopenharmony_ci // backward 3838be168c0dSopenharmony_ci if (lstm_param->bidirectional_) { 3839be168c0dSopenharmony_ci const float16_t *backward_weight_i = weight_i + 4 * lstm_param->input_col_align_ * lstm_param->input_size_; 3840be168c0dSopenharmony_ci- const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_; 3841be168c0dSopenharmony_ci+ const float16_t *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; 3842be168c0dSopenharmony_ci const float16_t *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; 3843be168c0dSopenharmony_ci const float16_t *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; 3844be168c0dSopenharmony_ci const float16_t *backward_weight_project = 3845be168c0dSopenharmony_ci- weight_project ? weight_project + lstm_param->hidden_size_ * (lstm_param->batch_ == 1 3846be168c0dSopenharmony_ci- ? lstm_param->project_size_ 3847be168c0dSopenharmony_ci- : UP_ROUND(lstm_param->project_size_, C8NUM)) 3848be168c0dSopenharmony_ci- : NULL; 3849be168c0dSopenharmony_ci- float16_t *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; 3850be168c0dSopenharmony_ci+ weight_project ? 
weight_project + lstm_param->hidden_size_ * lstm_param->proj_col_align_ : NULL; 3851be168c0dSopenharmony_ci+ float16_t *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; 3852be168c0dSopenharmony_ci float16_t *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; 3853be168c0dSopenharmony_ci- float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; 3854be168c0dSopenharmony_ci+ float16_t *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; 3855be168c0dSopenharmony_ci 3856be168c0dSopenharmony_ci LstmUnidirectionalFp16(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, 3857be168c0dSopenharmony_ci backward_state_bias, backward_weight_project, project_bias, backward_hidden_state, 3858be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.h 3859be168c0dSopenharmony_ciindex f6f853b4..d6af9c78 100644 3860be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.h 3861be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/lstm_fp16.h 3862be168c0dSopenharmony_ci@@ -21,13 +21,17 @@ 3863be168c0dSopenharmony_ci #ifdef __cplusplus 3864be168c0dSopenharmony_ci extern "C" { 3865be168c0dSopenharmony_ci #endif 3866be168c0dSopenharmony_ci-void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align); 3867be168c0dSopenharmony_ci+void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align, 3868be168c0dSopenharmony_ci+ const int32_t *order); 3869be168c0dSopenharmony_ci 3870be168c0dSopenharmony_ci-void PackLstmWeightFp16(float16_t *dst, const float16_t *src, int batch, int deep, int col, int col_align); 3871be168c0dSopenharmony_ci+void PackLstmWeightFp16(float16_t *dst, 
const float16_t *src, int batch, int deep, int col, int col_align, 3872be168c0dSopenharmony_ci+ const int32_t *order); 3873be168c0dSopenharmony_ci 3874be168c0dSopenharmony_ci-void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional); 3875be168c0dSopenharmony_ci+void PackLstmBiasFp32ToFp16(float16_t *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 3876be168c0dSopenharmony_ci+ const int32_t *order); 3877be168c0dSopenharmony_ci 3878be168c0dSopenharmony_ci-void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional); 3879be168c0dSopenharmony_ci+void PackLstmBiasFp16(float16_t *dst, const float16_t *src, int batch, int col, int col_align, bool is_bidirectional, 3880be168c0dSopenharmony_ci+ const int32_t *order); 3881be168c0dSopenharmony_ci 3882be168c0dSopenharmony_ci void LstmMatMulFp16(float16_t *c, const float16_t *a, const float16_t *b, const float16_t *bias, int row, int deep, 3883be168c0dSopenharmony_ci int col, bool is_vec); 3884be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.c 3885be168c0dSopenharmony_ciindex 1aefbaf5..39dcb9ee 100644 3886be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.c 3887be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.c 3888be168c0dSopenharmony_ci@@ -16,7 +16,7 @@ 3889be168c0dSopenharmony_ci 3890be168c0dSopenharmony_ci #include "nnacl/fp16/matmul_fp16.h" 3891be168c0dSopenharmony_ci 3892be168c0dSopenharmony_ci-static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { 3893be168c0dSopenharmony_ci+static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, int row, int col) { 3894be168c0dSopenharmony_ci int row_c8 = row / C8NUM * C8NUM; 
3895be168c0dSopenharmony_ci int col_c8 = col / C8NUM * C8NUM; 3896be168c0dSopenharmony_ci const float16_t *src = (const float16_t *)src_ptr; 3897be168c0dSopenharmony_ci@@ -108,7 +108,7 @@ static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t 3898be168c0dSopenharmony_ci } 3899be168c0dSopenharmony_ci } 3900be168c0dSopenharmony_ci 3901be168c0dSopenharmony_ci-static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { 3902be168c0dSopenharmony_ci+static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, int row, int col) { 3903be168c0dSopenharmony_ci int row_c8 = row / C8NUM * C8NUM; 3904be168c0dSopenharmony_ci int col_c8 = col / C8NUM * C8NUM; 3905be168c0dSopenharmony_ci int ci = 0; 3906be168c0dSopenharmony_ci@@ -410,17 +410,14 @@ void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f 3907be168c0dSopenharmony_ci int di = 0; 3908be168c0dSopenharmony_ci for (; di < depth - C8NUM + 1; di += C8NUM) { 3909be168c0dSopenharmony_ci float16x8_t av = vld1q_f16(a + di); 3910be168c0dSopenharmony_ci- float16x8_t bv_0; 3911be168c0dSopenharmony_ci- float16x8_t bv_1; 3912be168c0dSopenharmony_ci- for (int i = 0; i < C8NUM; i += C2NUM) { 3913be168c0dSopenharmony_ci- bv_0 = vld1q_f16(bv_base); // bv_i为一行,8列数据 3914be168c0dSopenharmony_ci- acc_0 = vfmaq_n_f16(acc_0, bv_0, av[i]); // av[i]为向量中的一个值 3915be168c0dSopenharmony_ci- bv_base += C8NUM; 3916be168c0dSopenharmony_ci- 3917be168c0dSopenharmony_ci- bv_1 = vld1q_f16(bv_base); // bv_i为一行,8列数据 3918be168c0dSopenharmony_ci- acc_0 = vfmaq_n_f16(acc_0, bv_1, av[i + 1]); // av[i]为向量中的一个值 3919be168c0dSopenharmony_ci+ float16x8_t bv_0[C8NUM]; 3920be168c0dSopenharmony_ci+ for (int i = 0; i < C8NUM; ++i) { 3921be168c0dSopenharmony_ci+ bv_0[i] = vld1q_f16(bv_base); 3922be168c0dSopenharmony_ci bv_base += C8NUM; 3923be168c0dSopenharmony_ci } 3924be168c0dSopenharmony_ci+ for (int i = 0; i < C8NUM; ++i) { 3925be168c0dSopenharmony_ci+ acc_0 = 
vfmaq_n_f16(acc_0, bv_0[i], av[i]); 3926be168c0dSopenharmony_ci+ } 3927be168c0dSopenharmony_ci } 3928be168c0dSopenharmony_ci if (di < depth) { 3929be168c0dSopenharmony_ci for (; di < depth; ++di) { 3930be168c0dSopenharmony_ci@@ -636,8 +633,94 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si 3931be168c0dSopenharmony_ci } 3932be168c0dSopenharmony_ci 3933be168c0dSopenharmony_ci #ifdef ENABLE_ARM64 3934be168c0dSopenharmony_ci-void RowMajor2ColNMajorFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) { 3935be168c0dSopenharmony_ci- // Col16Major ==> Col8Major ==> Col4Major 3936be168c0dSopenharmony_ci+void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col) { 3937be168c0dSopenharmony_ci+ // Col12Major ==> Col8Major ==> Col4Major 3938be168c0dSopenharmony_ci+ const float16_t *src_r = src; 3939be168c0dSopenharmony_ci+ float16_t *dst_r = dst_ptr; 3940be168c0dSopenharmony_ci+ int ri = 0; 3941be168c0dSopenharmony_ci+ size_t col8 = col / C8NUM * C8NUM; 3942be168c0dSopenharmony_ci+ // find 16 block unit 3943be168c0dSopenharmony_ci+ for (; ri <= row - C12NUM; ri += C12NUM) { 3944be168c0dSopenharmony_ci+ size_t ci = 0; 3945be168c0dSopenharmony_ci+ for (; ci < col8; ci += C8NUM) { 3946be168c0dSopenharmony_ci+ const float16_t *src_c = src_r + ci; 3947be168c0dSopenharmony_ci+ float16_t *dst_c = dst_r + ci * C12NUM; 3948be168c0dSopenharmony_ci+ Transpose12x8ARM64Fp16(src_c, dst_c, col * C2NUM, C24NUM); 3949be168c0dSopenharmony_ci+ } 3950be168c0dSopenharmony_ci+ for (; ci < col; ci++) { 3951be168c0dSopenharmony_ci+ const float16_t *src_c = src_r + ci; 3952be168c0dSopenharmony_ci+ float16_t *dst_c = dst_r + ci * C12NUM; 3953be168c0dSopenharmony_ci+ for (size_t i = 0; i < C12NUM; i++) { 3954be168c0dSopenharmony_ci+ dst_c[i] = src_c[i * col]; 3955be168c0dSopenharmony_ci+ } 3956be168c0dSopenharmony_ci+ } 3957be168c0dSopenharmony_ci+ src_r += C12NUM * col; 3958be168c0dSopenharmony_ci+ dst_r += C12NUM * 
col; 3959be168c0dSopenharmony_ci+ } 3960be168c0dSopenharmony_ci+ for (; ri <= row - C8NUM; ri += C8NUM) { 3961be168c0dSopenharmony_ci+ size_t ci = 0; 3962be168c0dSopenharmony_ci+ for (; ci < col8; ci += C8NUM) { 3963be168c0dSopenharmony_ci+ const float16_t *src_c = src_r + ci; 3964be168c0dSopenharmony_ci+ float16_t *dst_c = dst_r + ci * C8NUM; 3965be168c0dSopenharmony_ci+ Transpose8x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C8NUM * sizeof(float16_t)); 3966be168c0dSopenharmony_ci+ } 3967be168c0dSopenharmony_ci+ for (; ci < col; ci++) { 3968be168c0dSopenharmony_ci+ const float16_t *src_c = src_r + ci; 3969be168c0dSopenharmony_ci+ float16_t *dst_c = dst_r + ci * C8NUM; 3970be168c0dSopenharmony_ci+ for (size_t i = 0; i < C8NUM; i++) { 3971be168c0dSopenharmony_ci+ dst_c[i] = src_c[i * col]; 3972be168c0dSopenharmony_ci+ } 3973be168c0dSopenharmony_ci+ } 3974be168c0dSopenharmony_ci+ src_r += C8NUM * col; 3975be168c0dSopenharmony_ci+ dst_r += C8NUM * col; 3976be168c0dSopenharmony_ci+ } 3977be168c0dSopenharmony_ci+ for (; ri <= row - C4NUM; ri += C4NUM) { 3978be168c0dSopenharmony_ci+ size_t ci = 0; 3979be168c0dSopenharmony_ci+ for (; ci < col8; ci += C8NUM) { 3980be168c0dSopenharmony_ci+ const float16_t *src_c = src_r + ci; 3981be168c0dSopenharmony_ci+ float16_t *dst_c = dst_r + ci * C4NUM; 3982be168c0dSopenharmony_ci+ Transpose4x8ARM64Fp16(src_c, dst_c, col * sizeof(float16_t), C4NUM * sizeof(float16_t)); 3983be168c0dSopenharmony_ci+ } 3984be168c0dSopenharmony_ci+ for (; ci < col; ci++) { 3985be168c0dSopenharmony_ci+ const float16_t *src_c = src_r + ci; 3986be168c0dSopenharmony_ci+ float16_t *dst_c = dst_r + ci * C4NUM; 3987be168c0dSopenharmony_ci+ for (size_t i = 0; i < C4NUM; i++) { 3988be168c0dSopenharmony_ci+ dst_c[i] = src_c[i * col]; 3989be168c0dSopenharmony_ci+ } 3990be168c0dSopenharmony_ci+ } 3991be168c0dSopenharmony_ci+ src_r += C4NUM * col; 3992be168c0dSopenharmony_ci+ dst_r += C4NUM * col; 3993be168c0dSopenharmony_ci+ } 3994be168c0dSopenharmony_ci+ if (ri 
< row) { 3995be168c0dSopenharmony_ci+ memcpy(dst_r, src_r, (row - ri) * col * C2NUM); 3996be168c0dSopenharmony_ci+ } 3997be168c0dSopenharmony_ci+} 3998be168c0dSopenharmony_ci+ 3999be168c0dSopenharmony_ci+void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col) { 4000be168c0dSopenharmony_ci+ // Row12 ==> Row8 ==> Row4 4001be168c0dSopenharmony_ci+ for (int r = 0; r < row; r++) { 4002be168c0dSopenharmony_ci+ int c = 0; 4003be168c0dSopenharmony_ci+ for (; c <= col - C12NUM; c += C12NUM) { 4004be168c0dSopenharmony_ci+ MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4005be168c0dSopenharmony_ci+ MS_FLOAT16X4 src_data1 = MS_LD_F16(src + r * col + c + C8NUM); 4006be168c0dSopenharmony_ci+ MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM, src_data); 4007be168c0dSopenharmony_ci+ MS_ST_F16(dst + c / C12NUM * C12NUM * row + r * C12NUM + C8NUM, src_data1); 4008be168c0dSopenharmony_ci+ } 4009be168c0dSopenharmony_ci+ for (; c <= col - C8NUM; c += C8NUM) { 4010be168c0dSopenharmony_ci+ MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4011be168c0dSopenharmony_ci+ MS_STQ_F16(dst + c / C12NUM * C12NUM * row + r * C8NUM, src_data); 4012be168c0dSopenharmony_ci+ } 4013be168c0dSopenharmony_ci+ for (; c <= col - C4NUM; c += C4NUM) { 4014be168c0dSopenharmony_ci+ MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); 4015be168c0dSopenharmony_ci+ MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); 4016be168c0dSopenharmony_ci+ } 4017be168c0dSopenharmony_ci+ for (; c < col; ++c) { 4018be168c0dSopenharmony_ci+ dst[c / C4NUM * C4NUM * row + r + c % C4NUM * row] = src[r * col + c]; 4019be168c0dSopenharmony_ci+ } 4020be168c0dSopenharmony_ci+ } 4021be168c0dSopenharmony_ci+} 4022be168c0dSopenharmony_ci+ 4023be168c0dSopenharmony_ci+void RowMajor2ColNMajorFp16srcFp16(const float16_t *src_ptr, float16_t *dst_ptr, int row, int col) { 4024be168c0dSopenharmony_ci const float16_t *src_r = src_ptr; 4025be168c0dSopenharmony_ci float16_t *dst_r = 
dst_ptr; 4026be168c0dSopenharmony_ci int ri = 0; 4027be168c0dSopenharmony_ci@@ -702,6 +785,112 @@ void RowMajor2ColNMajorFp16(const float16_t *src_ptr, float16_t *dst_ptr, int ro 4028be168c0dSopenharmony_ci dst_r += 1; 4029be168c0dSopenharmony_ci } 4030be168c0dSopenharmony_ci } 4031be168c0dSopenharmony_ci+ 4032be168c0dSopenharmony_ci+void RowMajor2ColNMajorFp16(const void *src_ptr, float16_t *dst_ptr, int row, int col, bool is_fp32_src) { 4033be168c0dSopenharmony_ci+ // Col16Major ==> Col8Major ==> Col4Major 4034be168c0dSopenharmony_ci+ if (!is_fp32_src) { 4035be168c0dSopenharmony_ci+ RowMajor2ColNMajorFp16srcFp16((const float16_t *)src_ptr, dst_ptr, row, col); 4036be168c0dSopenharmony_ci+ return; 4037be168c0dSopenharmony_ci+ } 4038be168c0dSopenharmony_ci+ const float *src_r = src_ptr; 4039be168c0dSopenharmony_ci+ float16_t *dst_r = dst_ptr; 4040be168c0dSopenharmony_ci+ int ri = 0; 4041be168c0dSopenharmony_ci+ // find 16 block unit 4042be168c0dSopenharmony_ci+ for (; ri <= row - C16NUM; ri += C16NUM) { 4043be168c0dSopenharmony_ci+ for (int r = 0; r < C16NUM; ++r) { 4044be168c0dSopenharmony_ci+ for (int c = 0; c < col; ++c) { 4045be168c0dSopenharmony_ci+ dst_r[c * C16NUM + r % C16NUM] = src_r[r * col + c]; 4046be168c0dSopenharmony_ci+ } 4047be168c0dSopenharmony_ci+ } 4048be168c0dSopenharmony_ci+ src_r += C16NUM * col; 4049be168c0dSopenharmony_ci+ dst_r += C16NUM * col; 4050be168c0dSopenharmony_ci+ } 4051be168c0dSopenharmony_ci+ for (; ri <= row - C8NUM; ri += C8NUM) { 4052be168c0dSopenharmony_ci+ for (int r = 0; r < C8NUM; ++r) { 4053be168c0dSopenharmony_ci+ for (int c = 0; c < col; ++c) { 4054be168c0dSopenharmony_ci+ dst_r[c * C8NUM + r % C8NUM] = src_r[r * col + c]; 4055be168c0dSopenharmony_ci+ } 4056be168c0dSopenharmony_ci+ } 4057be168c0dSopenharmony_ci+ src_r += C8NUM * col; 4058be168c0dSopenharmony_ci+ dst_r += C8NUM * col; 4059be168c0dSopenharmony_ci+ } 4060be168c0dSopenharmony_ci+ for (; ri <= row - C4NUM; ri += C4NUM) { 4061be168c0dSopenharmony_ci+ for (int 
r = 0; r < C4NUM; ++r) { 4062be168c0dSopenharmony_ci+ for (int c = 0; c < col; ++c) { 4063be168c0dSopenharmony_ci+ dst_r[c * C4NUM + r % C4NUM] = src_r[r * col + c]; 4064be168c0dSopenharmony_ci+ } 4065be168c0dSopenharmony_ci+ } 4066be168c0dSopenharmony_ci+ src_r += C4NUM * col; 4067be168c0dSopenharmony_ci+ dst_r += C4NUM * col; 4068be168c0dSopenharmony_ci+ } 4069be168c0dSopenharmony_ci+ for (; ri < row; ++ri) { 4070be168c0dSopenharmony_ci+ for (size_t i = 0; i < col; ++i) { 4071be168c0dSopenharmony_ci+ dst_r[i * C4NUM] = src_r[i]; 4072be168c0dSopenharmony_ci+ } 4073be168c0dSopenharmony_ci+ src_r += col; 4074be168c0dSopenharmony_ci+ dst_r += 1; 4075be168c0dSopenharmony_ci+ } 4076be168c0dSopenharmony_ci+} 4077be168c0dSopenharmony_ci+ 4078be168c0dSopenharmony_ci+void RowMajor2RowNMajorFp16(const void *src_ptr, float16_t *dst, int row, int col, bool is_fp32_src) { 4079be168c0dSopenharmony_ci+ // Row16 ==> Row8 ==> Row4 4080be168c0dSopenharmony_ci+ if (is_fp32_src) { 4081be168c0dSopenharmony_ci+ const float *src = (const float *)src_ptr; 4082be168c0dSopenharmony_ci+ for (int r = 0; r < row; r++) { 4083be168c0dSopenharmony_ci+ int c = 0; 4084be168c0dSopenharmony_ci+ for (; c <= col - C16NUM; c += C16NUM) { 4085be168c0dSopenharmony_ci+ const float *cur_src = src + r * col + c; 4086be168c0dSopenharmony_ci+ MS_FLOAT32X4X4 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM), MS_LDQ_F32(cur_src + C8NUM), 4087be168c0dSopenharmony_ci+ MS_LDQ_F32(cur_src + C12NUM)}; 4088be168c0dSopenharmony_ci+ MS_FLOAT16X4X4 res = { 4089be168c0dSopenharmony_ci+ MS_CVT_F16_F32(src_f32_data.val[0]), 4090be168c0dSopenharmony_ci+ MS_CVT_F16_F32(src_f32_data.val[1]), 4091be168c0dSopenharmony_ci+ MS_CVT_F16_F32(src_f32_data.val[2]), 4092be168c0dSopenharmony_ci+ MS_CVT_F16_F32(src_f32_data.val[3]), 4093be168c0dSopenharmony_ci+ }; 4094be168c0dSopenharmony_ci+ MS_ST4_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, res); 4095be168c0dSopenharmony_ci+ } 4096be168c0dSopenharmony_ci+ for (; c 
<= col - C8NUM; c += C8NUM) { 4097be168c0dSopenharmony_ci+ const float *cur_src = src + r * col + c; 4098be168c0dSopenharmony_ci+ MS_FLOAT32X4X2 src_f32_data = {MS_LDQ_F32(cur_src), MS_LDQ_F32(cur_src + C4NUM)}; 4099be168c0dSopenharmony_ci+ MS_FLOAT16X4X2 res = { 4100be168c0dSopenharmony_ci+ MS_CVT_F16_F32(src_f32_data.val[0]), 4101be168c0dSopenharmony_ci+ MS_CVT_F16_F32(src_f32_data.val[1]), 4102be168c0dSopenharmony_ci+ }; 4103be168c0dSopenharmony_ci+ MS_ST2_F16(dst + c / C8NUM * C8NUM * row + r * C8NUM, res); 4104be168c0dSopenharmony_ci+ } 4105be168c0dSopenharmony_ci+ for (; c <= col - C4NUM; c += C4NUM) { 4106be168c0dSopenharmony_ci+ MS_FLOAT16X4 src_data = MS_CVT_F16_F32(MS_LDQ_F32(src + r * col + c)); 4107be168c0dSopenharmony_ci+ MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); 4108be168c0dSopenharmony_ci+ } 4109be168c0dSopenharmony_ci+ for (; c < col; ++c) { 4110be168c0dSopenharmony_ci+ dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; 4111be168c0dSopenharmony_ci+ } 4112be168c0dSopenharmony_ci+ } 4113be168c0dSopenharmony_ci+ return; 4114be168c0dSopenharmony_ci+ } 4115be168c0dSopenharmony_ci+ const float16_t *src = (const float16_t *)src_ptr; 4116be168c0dSopenharmony_ci+ for (int r = 0; r < row; r++) { 4117be168c0dSopenharmony_ci+ int c = 0; 4118be168c0dSopenharmony_ci+ for (; c <= col - C16NUM; c += C16NUM) { 4119be168c0dSopenharmony_ci+ MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4120be168c0dSopenharmony_ci+ MS_FLOAT16X8 src_data1 = MS_LDQ_F16(src + r * col + c + C8NUM); 4121be168c0dSopenharmony_ci+ MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, src_data); 4122be168c0dSopenharmony_ci+ MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM + C8NUM, src_data1); 4123be168c0dSopenharmony_ci+ } 4124be168c0dSopenharmony_ci+ for (; c <= col - C8NUM; c += C8NUM) { 4125be168c0dSopenharmony_ci+ MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4126be168c0dSopenharmony_ci+ MS_STQ_F16(dst + c / C8NUM * 
C8NUM * row + r * C8NUM, src_data); 4127be168c0dSopenharmony_ci+ } 4128be168c0dSopenharmony_ci+ for (; c <= col - C4NUM; c += C4NUM) { 4129be168c0dSopenharmony_ci+ MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); 4130be168c0dSopenharmony_ci+ MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); 4131be168c0dSopenharmony_ci+ } 4132be168c0dSopenharmony_ci+ for (; c < col; ++c) { 4133be168c0dSopenharmony_ci+ dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; 4134be168c0dSopenharmony_ci+ } 4135be168c0dSopenharmony_ci+ } 4136be168c0dSopenharmony_ci+} 4137be168c0dSopenharmony_ci #endif 4138be168c0dSopenharmony_ci 4139be168c0dSopenharmony_ci void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { 4140be168c0dSopenharmony_ci@@ -802,32 +991,6 @@ void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, 4141be168c0dSopenharmony_ci } 4142be168c0dSopenharmony_ci } 4143be168c0dSopenharmony_ci 4144be168c0dSopenharmony_ci-#ifdef ENABLE_ARM64 4145be168c0dSopenharmony_ci-void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col) { 4146be168c0dSopenharmony_ci- // Row16 ==> Row8 ==> Row4 4147be168c0dSopenharmony_ci- for (int r = 0; r < row; r++) { 4148be168c0dSopenharmony_ci- int c = 0; 4149be168c0dSopenharmony_ci- for (; c <= col - C16NUM; c += C16NUM) { 4150be168c0dSopenharmony_ci- MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4151be168c0dSopenharmony_ci- MS_FLOAT16X8 src_data1 = MS_LDQ_F16(src + r * col + c + C8NUM); 4152be168c0dSopenharmony_ci- MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM, src_data); 4153be168c0dSopenharmony_ci- MS_STQ_F16(dst + c / C16NUM * C16NUM * row + r * C16NUM + C8NUM, src_data1); 4154be168c0dSopenharmony_ci- } 4155be168c0dSopenharmony_ci- for (; c <= col - C8NUM; c += C8NUM) { 4156be168c0dSopenharmony_ci- MS_FLOAT16X8 src_data = MS_LDQ_F16(src + r * col + c); 4157be168c0dSopenharmony_ci- MS_STQ_F16(dst + c 
/ C8NUM * C8NUM * row + r * C8NUM, src_data); 4158be168c0dSopenharmony_ci- } 4159be168c0dSopenharmony_ci- for (; c <= col - C4NUM; c += C4NUM) { 4160be168c0dSopenharmony_ci- MS_FLOAT16X4 src_data = MS_LD_F16(src + r * col + c); 4161be168c0dSopenharmony_ci- MS_ST_F16(dst + c / C4NUM * C4NUM * row + r * C4NUM, src_data); 4162be168c0dSopenharmony_ci- } 4163be168c0dSopenharmony_ci- for (; c < col; ++c) { 4164be168c0dSopenharmony_ci- dst[c / C4NUM * C4NUM * row + r * C4NUM + c % C4NUM] = src[r * col + c]; 4165be168c0dSopenharmony_ci- } 4166be168c0dSopenharmony_ci- } 4167be168c0dSopenharmony_ci-} 4168be168c0dSopenharmony_ci-#endif 4169be168c0dSopenharmony_ci- 4170be168c0dSopenharmony_ci void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col) { 4171be168c0dSopenharmony_ci int col_align = UP_ROUND(col, C16NUM); 4172be168c0dSopenharmony_ci for (int r = 0; r < row; r++) { 4173be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.h 4174be168c0dSopenharmony_ciindex be7f8443..7acef622 100644 4175be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.h 4176be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/matmul_fp16.h 4177be168c0dSopenharmony_ci@@ -14,8 +14,8 @@ 4178be168c0dSopenharmony_ci * limitations under the License. 
4179be168c0dSopenharmony_ci */ 4180be168c0dSopenharmony_ci 4181be168c0dSopenharmony_ci-#ifndef NNACL_FP16_MATMUL_FP16_H_ 4182be168c0dSopenharmony_ci-#define NNACL_FP16_MATMUL_FP16_H_ 4183be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_FP16_MATMUL_H_ 4184be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_FP16_MATMUL_H_ 4185be168c0dSopenharmony_ci 4186be168c0dSopenharmony_ci #include <float.h> 4187be168c0dSopenharmony_ci #include <string.h> 4188be168c0dSopenharmony_ci@@ -45,9 +45,13 @@ void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, cons 4189be168c0dSopenharmony_ci int deep, int row, int col, int stride, int write_mode); 4190be168c0dSopenharmony_ci 4191be168c0dSopenharmony_ci #ifdef ENABLE_ARM64 4192be168c0dSopenharmony_ci-void RowMajor2ColNMajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col); 4193be168c0dSopenharmony_ci+void RowMajor2ColLadder12MajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col); 4194be168c0dSopenharmony_ci 4195be168c0dSopenharmony_ci-void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col); 4196be168c0dSopenharmony_ci+void RowMajor2RowLadder12MajorFp16(const float16_t *src, float16_t *dst, int row, int col); 4197be168c0dSopenharmony_ci+ 4198be168c0dSopenharmony_ci+void RowMajor2ColNMajorFp16(const void *src, float16_t *dst_ptr, int row, int col, bool is_fp32_src); 4199be168c0dSopenharmony_ci+ 4200be168c0dSopenharmony_ci+void RowMajor2RowNMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); 4201be168c0dSopenharmony_ci 4202be168c0dSopenharmony_ci void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, 4203be168c0dSopenharmony_ci int deep, int row, int col, size_t stride, size_t out_type); 4204be168c0dSopenharmony_ci@@ -60,6 +64,9 @@ void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, c 4205be168c0dSopenharmony_ci void MatmulBaseFp16Neon(const 
float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, 4206be168c0dSopenharmony_ci size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); 4207be168c0dSopenharmony_ci 4208be168c0dSopenharmony_ci+void MatmulFp16OptV2(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, 4209be168c0dSopenharmony_ci+ size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); 4210be168c0dSopenharmony_ci+ 4211be168c0dSopenharmony_ci #ifdef ENABLE_DEBUG 4212be168c0dSopenharmony_ci void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, 4213be168c0dSopenharmony_ci size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); 4214be168c0dSopenharmony_ci@@ -118,4 +125,4 @@ void RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bo 4215be168c0dSopenharmony_ci } 4216be168c0dSopenharmony_ci #endif 4217be168c0dSopenharmony_ci 4218be168c0dSopenharmony_ci-#endif // NNACL_FP16_MATMUL_FP16_H_ 4219be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_FP16_MATMUL_H_ 4220be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c 4221be168c0dSopenharmony_ciindex 74e75115..da9f6bef 100644 4222be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c 4223be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.c 4224be168c0dSopenharmony_ci@@ -33,7 +33,7 @@ static void PackLstmMatrix(const float *src_batch, float *dst_batch, int col, in 4225be168c0dSopenharmony_ci } 4226be168c0dSopenharmony_ci 4227be168c0dSopenharmony_ci static void PackLstmWeightBatch(float *dst, const float *src, int batch, int deep, int col, int col_align, 4228be168c0dSopenharmony_ci- const int32_t *order) { 4229be168c0dSopenharmony_ci+ const int *order) { 4230be168c0dSopenharmony_ci 
for (int i = 0; i < batch; i++) { 4231be168c0dSopenharmony_ci const float *src_batch = src + i * col * deep; 4232be168c0dSopenharmony_ci float *dst_batch = dst + ((order == NULL) ? i : order[i]) * col_align * deep; 4233be168c0dSopenharmony_ci@@ -41,12 +41,12 @@ static void PackLstmWeightBatch(float *dst, const float *src, int batch, int dee 4234be168c0dSopenharmony_ci } 4235be168c0dSopenharmony_ci } 4236be168c0dSopenharmony_ci 4237be168c0dSopenharmony_ci-void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order) { 4238be168c0dSopenharmony_ci+void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int *order) { 4239be168c0dSopenharmony_ci PackLstmWeightBatch(dst, src, batch, deep, col, col_align, order); 4240be168c0dSopenharmony_ci } 4241be168c0dSopenharmony_ci 4242be168c0dSopenharmony_ci void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, 4243be168c0dSopenharmony_ci- bool is_bidirectional, int stride, const int32_t *order) { 4244be168c0dSopenharmony_ci+ bool is_bidirectional, int stride, const int *order) { 4245be168c0dSopenharmony_ci int unidirectional_batch = is_bidirectional ? batch / 2 : batch; 4246be168c0dSopenharmony_ci PackLstmWeightBatch(dst, src, unidirectional_batch, deep, col, col_align, order); 4247be168c0dSopenharmony_ci src += stride; 4248be168c0dSopenharmony_ci@@ -57,7 +57,7 @@ void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, 4249be168c0dSopenharmony_ci } 4250be168c0dSopenharmony_ci 4251be168c0dSopenharmony_ci void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 4252be168c0dSopenharmony_ci- const int32_t *order) { 4253be168c0dSopenharmony_ci+ const int *order) { 4254be168c0dSopenharmony_ci int unidirectional_batch = is_bidirectional ? 
batch / 2 : batch; 4255be168c0dSopenharmony_ci for (int i = 0; i < unidirectional_batch; i++) { 4256be168c0dSopenharmony_ci const float *src_batch = src + i * col; 4257be168c0dSopenharmony_ci@@ -76,7 +76,7 @@ void PackLstmBias(float *dst, const float *src, int batch, int col, int col_alig 4258be168c0dSopenharmony_ci } 4259be168c0dSopenharmony_ci 4260be168c0dSopenharmony_ci void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 4261be168c0dSopenharmony_ci- int b_stride, const int32_t *order) { 4262be168c0dSopenharmony_ci+ int b_stride, const int *order) { 4263be168c0dSopenharmony_ci int unidirectional_batch = is_bidirectional ? batch / 2 : batch; 4264be168c0dSopenharmony_ci for (int i = 0; i < unidirectional_batch; i++) { 4265be168c0dSopenharmony_ci const float *src_batch = src + i * col; 4266be168c0dSopenharmony_ci@@ -175,13 +175,13 @@ void UpdateOutput(float *hidden_state, float *output, const float *cell_state, c 4267be168c0dSopenharmony_ci const float *weight_project, float *buffer[C8NUM], const LstmParameter *lstm_param) { 4268be168c0dSopenharmony_ci int batch = lstm_param->batch_; 4269be168c0dSopenharmony_ci int hidden_size = lstm_param->hidden_size_; 4270be168c0dSopenharmony_ci- int project_size = lstm_param->project_size_; 4271be168c0dSopenharmony_ci+ int output_size = lstm_param->output_size_; 4272be168c0dSopenharmony_ci float *state_buffer = buffer[C4NUM]; 4273be168c0dSopenharmony_ci float *hidden_buffer = weight_project ? 
buffer[C2NUM] : hidden_state; 4274be168c0dSopenharmony_ci float zoneout = lstm_param->zoneout_hidden_; 4275be168c0dSopenharmony_ci if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { 4276be168c0dSopenharmony_ci- (void)memcpy(state_buffer, hidden_state, batch * project_size * sizeof(float)); 4277be168c0dSopenharmony_ci- ElementOptMul(state_buffer, &zoneout, state_buffer, batch * project_size, false); 4278be168c0dSopenharmony_ci+ (void)memcpy(state_buffer, hidden_state, batch * hidden_size * sizeof(float)); 4279be168c0dSopenharmony_ci+ ElementOptMul(state_buffer, &zoneout, state_buffer, batch * hidden_size, false); 4280be168c0dSopenharmony_ci } 4281be168c0dSopenharmony_ci 4282be168c0dSopenharmony_ci Tanh(cell_state, batch * hidden_size, hidden_buffer); 4283be168c0dSopenharmony_ci@@ -193,20 +193,13 @@ void UpdateOutput(float *hidden_state, float *output, const float *cell_state, c 4284be168c0dSopenharmony_ci left_matrix = buffer[C6NUM]; 4285be168c0dSopenharmony_ci PackLstmInput(hidden_buffer, left_matrix, batch, hidden_size); 4286be168c0dSopenharmony_ci } 4287be168c0dSopenharmony_ci-#ifdef ENABLE_AVX 4288be168c0dSopenharmony_ci- int col_tile = batch == 1 ? 
C8NUM : C16NUM; 4289be168c0dSopenharmony_ci-#elif defined(ENABLE_ARM32) 4290be168c0dSopenharmony_ci- int col_tile = C4NUM; 4291be168c0dSopenharmony_ci-#else 4292be168c0dSopenharmony_ci- int col_tile = C8NUM; 4293be168c0dSopenharmony_ci-#endif 4294be168c0dSopenharmony_ci- LstmMatMul(hidden_state, left_matrix, weight_project, NULL, batch, hidden_size, project_size, 4295be168c0dSopenharmony_ci- UP_ROUND(project_size, col_tile), batch == 1, buffer[C7NUM]); 4296be168c0dSopenharmony_ci+ LstmMatMul(hidden_state, left_matrix, weight_project, NULL, batch, hidden_size, output_size, 4297be168c0dSopenharmony_ci+ lstm_param->proj_col_align_, batch == 1, buffer[C7NUM]); 4298be168c0dSopenharmony_ci } 4299be168c0dSopenharmony_ci if (!(zoneout >= -FLT_EPSILON && zoneout <= FLT_EPSILON)) { 4300be168c0dSopenharmony_ci- ElementOptMulAcc(hidden_state, 1 - zoneout, state_buffer, batch * project_size); 4301be168c0dSopenharmony_ci+ ElementOptMulAcc(hidden_state, 1 - zoneout, state_buffer, batch * output_size); 4302be168c0dSopenharmony_ci } 4303be168c0dSopenharmony_ci- (void)memcpy(output, hidden_state, batch * project_size * sizeof(float)); 4304be168c0dSopenharmony_ci+ (void)memcpy(output, hidden_state, batch * output_size * sizeof(float)); 4305be168c0dSopenharmony_ci } 4306be168c0dSopenharmony_ci 4307be168c0dSopenharmony_ci void UpdateLstmGate(float *gate_buffer, const float *input, const float *weight, const float *bias, int row, int deep, 4308be168c0dSopenharmony_ci@@ -238,12 +231,12 @@ void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *c 4309be168c0dSopenharmony_ci bool is_vec = lstm_param->batch_ == 1; 4310be168c0dSopenharmony_ci // state * weight 4311be168c0dSopenharmony_ci if (is_vec) { 4312be168c0dSopenharmony_ci- UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, lstm_param->project_size_, 4313be168c0dSopenharmony_ci+ UpdateLstmGate(state_gate, hidden_state, state_weight, state_bias, lstm_param->batch_, 
lstm_param->output_size_, 4314be168c0dSopenharmony_ci lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); 4315be168c0dSopenharmony_ci } else { 4316be168c0dSopenharmony_ci // pack state for matmul 4317be168c0dSopenharmony_ci- PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->project_size_); 4318be168c0dSopenharmony_ci- UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->project_size_, 4319be168c0dSopenharmony_ci+ PackLstmInput(hidden_state, packed_state, lstm_param->batch_, lstm_param->output_size_); 4320be168c0dSopenharmony_ci+ UpdateLstmGate(state_gate, packed_state, state_weight, state_bias, lstm_param->batch_, lstm_param->output_size_, 4321be168c0dSopenharmony_ci lstm_param->hidden_size_, lstm_param->state_col_align_, is_vec, packed_output); 4322be168c0dSopenharmony_ci } 4323be168c0dSopenharmony_ci ElementAdd(input_gate, state_gate, input_gate, lstm_param->batch_ * lstm_param->hidden_size_); 4324be168c0dSopenharmony_ci@@ -276,7 +269,7 @@ void LstmStepUnit(float *output, float *input_gate, float *forget_gate, float *c 4325be168c0dSopenharmony_ci } 4326be168c0dSopenharmony_ci 4327be168c0dSopenharmony_ci if (!(lstm_param->zoneout_hidden_ >= -FLT_EPSILON && lstm_param->zoneout_hidden_ <= FLT_EPSILON)) { 4328be168c0dSopenharmony_ci- (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->project_size_ * sizeof(float)); 4329be168c0dSopenharmony_ci+ (void)memcpy(hidden_state, hidden_buffer, lstm_param->batch_ * lstm_param->output_size_ * sizeof(float)); 4330be168c0dSopenharmony_ci } 4331be168c0dSopenharmony_ci } 4332be168c0dSopenharmony_ci 4333be168c0dSopenharmony_ci@@ -322,12 +315,12 @@ void Lstm(float *output, const float *input, const float *weight_i, const float 4334be168c0dSopenharmony_ci // backward 4335be168c0dSopenharmony_ci if (lstm_param->bidirectional_) { 4336be168c0dSopenharmony_ci const float *backward_weight_i = weight_i + 4 * 
lstm_param->input_col_align_ * lstm_param->input_size_; 4337be168c0dSopenharmony_ci- const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->hidden_size_; 4338be168c0dSopenharmony_ci+ const float *backward_weight_h = weight_h + 4 * lstm_param->state_col_align_ * lstm_param->output_size_; 4339be168c0dSopenharmony_ci const float *backward_input_bias = input_bias + 4 * lstm_param->input_col_align_; 4340be168c0dSopenharmony_ci const float *backward_state_bias = state_bias + 4 * lstm_param->state_col_align_; 4341be168c0dSopenharmony_ci- float *backward_output = output + lstm_param->batch_ * lstm_param->hidden_size_; 4342be168c0dSopenharmony_ci+ float *backward_output = output + lstm_param->batch_ * lstm_param->output_size_; 4343be168c0dSopenharmony_ci float *backward_cell_state = cell_state + lstm_param->batch_ * lstm_param->hidden_size_; 4344be168c0dSopenharmony_ci- float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->hidden_size_; 4345be168c0dSopenharmony_ci+ float *backward_hidden_state = hidden_state + lstm_param->batch_ * lstm_param->output_size_; 4346be168c0dSopenharmony_ci 4347be168c0dSopenharmony_ci LstmUnidirectional(backward_output, packed_input, backward_weight_i, backward_weight_h, backward_input_bias, 4348be168c0dSopenharmony_ci backward_state_bias, backward_hidden_state, backward_cell_state, buffer, lstm_param, true); 4349be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h 4350be168c0dSopenharmony_ciindex 88dd9d16..f94f0bb7 100644 4351be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h 4352be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/lstm_fp32.h 4353be168c0dSopenharmony_ci@@ -21,16 +21,16 @@ 4354be168c0dSopenharmony_ci #ifdef __cplusplus 4355be168c0dSopenharmony_ci extern "C" { 4356be168c0dSopenharmony_ci 
#endif 4357be168c0dSopenharmony_ci-void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int32_t *order); 4358be168c0dSopenharmony_ci+void PackLstmWeight(float *dst, const float *src, int batch, int deep, int col, int col_align, const int *order); 4359be168c0dSopenharmony_ci 4360be168c0dSopenharmony_ci void PackLstmWeightWithStride(float *dst, const float *src, int batch, int deep, int col, int col_align, 4361be168c0dSopenharmony_ci- bool is_bidirectional, int stride, const int32_t *order); 4362be168c0dSopenharmony_ci+ bool is_bidirectional, int stride, const int *order); 4363be168c0dSopenharmony_ci 4364be168c0dSopenharmony_ci void PackLstmBias(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 4365be168c0dSopenharmony_ci- const int32_t *order); 4366be168c0dSopenharmony_ci+ const int *order); 4367be168c0dSopenharmony_ci 4368be168c0dSopenharmony_ci void PackLstmBiasWithStride(float *dst, const float *src, int batch, int col, int col_align, bool is_bidirectional, 4369be168c0dSopenharmony_ci- int b_stride, const int32_t *order); 4370be168c0dSopenharmony_ci+ int b_stride, const int *order); 4371be168c0dSopenharmony_ci 4372be168c0dSopenharmony_ci void PackLstmInput(const float *src, float *dst, int row, int deep); 4373be168c0dSopenharmony_ci 4374be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c 4375be168c0dSopenharmony_ciindex 308419fb..1898ffd4 100644 4376be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c 4377be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/matmul_fp32.c 4378be168c0dSopenharmony_ci@@ -440,8 +440,8 @@ void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float * 4379be168c0dSopenharmony_ci } 4380be168c0dSopenharmony_ci c[oc_index] = dst; 
4381be168c0dSopenharmony_ci } 4382be168c0dSopenharmony_ci- a += k; 4383be168c0dSopenharmony_ci- b += k * col; 4384be168c0dSopenharmony_ci+ a += C1500NUM; 4385be168c0dSopenharmony_ci+ b += C1500NUM * col; 4386be168c0dSopenharmony_ci } 4387be168c0dSopenharmony_ci if (k == depth) { 4388be168c0dSopenharmony_ci return; 4389be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.c 4390be168c0dSopenharmony_cinew file mode 100644 4391be168c0dSopenharmony_ciindex 00000000..ad1cac2e 4392be168c0dSopenharmony_ci--- /dev/null 4393be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.c 4394be168c0dSopenharmony_ci@@ -0,0 +1,36 @@ 4395be168c0dSopenharmony_ci+/** 4396be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 4397be168c0dSopenharmony_ci+ * 4398be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 4399be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 4400be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 4401be168c0dSopenharmony_ci+ * 4402be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 4403be168c0dSopenharmony_ci+ * 4404be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 4405be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 4406be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4407be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 4408be168c0dSopenharmony_ci+ * limitations under the License. 
4409be168c0dSopenharmony_ci+ */ 4410be168c0dSopenharmony_ci+ 4411be168c0dSopenharmony_ci+#include "nnacl/infer/custom_gather_d_grad_v2_infer.h" 4412be168c0dSopenharmony_ci+#include "nnacl/infer/infer_register.h" 4413be168c0dSopenharmony_ci+ 4414be168c0dSopenharmony_ci+int CustomGatherDGradV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, 4415be168c0dSopenharmony_ci+ size_t outputs_size, OpParameter *parameter) { 4416be168c0dSopenharmony_ci+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); 4417be168c0dSopenharmony_ci+ if (check_ret != NNACL_OK) { 4418be168c0dSopenharmony_ci+ return check_ret; 4419be168c0dSopenharmony_ci+ } 4420be168c0dSopenharmony_ci+ const TensorC *input = inputs[0]; 4421be168c0dSopenharmony_ci+ TensorC *output = outputs[0]; 4422be168c0dSopenharmony_ci+ SetDataTypeFormat(output, input); 4423be168c0dSopenharmony_ci+ if (!InferFlag(inputs, inputs_size)) { 4424be168c0dSopenharmony_ci+ return NNACL_INFER_INVALID; 4425be168c0dSopenharmony_ci+ } 4426be168c0dSopenharmony_ci+ SetShapeTensor(output, input); 4427be168c0dSopenharmony_ci+ return NNACL_OK; 4428be168c0dSopenharmony_ci+} 4429be168c0dSopenharmony_ci+ 4430be168c0dSopenharmony_ci+REG_INFER(CustomGatherDGradV2, PrimType_Inner_CustomGatherDGradV2, CustomGatherDGradV2InferShape) 4431be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.h 4432be168c0dSopenharmony_cinew file mode 100644 4433be168c0dSopenharmony_ciindex 00000000..68d85d20 4434be168c0dSopenharmony_ci--- /dev/null 4435be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_gather_d_grad_v2_infer.h 4436be168c0dSopenharmony_ci@@ -0,0 +1,30 @@ 4437be168c0dSopenharmony_ci+/** 4438be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 
4439be168c0dSopenharmony_ci+ * 4440be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 4441be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 4442be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 4443be168c0dSopenharmony_ci+ * 4444be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 4445be168c0dSopenharmony_ci+ * 4446be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 4447be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 4448be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4449be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 4450be168c0dSopenharmony_ci+ * limitations under the License. 4451be168c0dSopenharmony_ci+ */ 4452be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_INFER_H 4453be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_INFER_H 4454be168c0dSopenharmony_ci+#include "nnacl/infer/common_infer.h" 4455be168c0dSopenharmony_ci+ 4456be168c0dSopenharmony_ci+#ifdef __cplusplus 4457be168c0dSopenharmony_ci+extern "C" { 4458be168c0dSopenharmony_ci+#endif 4459be168c0dSopenharmony_ci+ 4460be168c0dSopenharmony_ci+int CustomGatherDGradV2InferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, 4461be168c0dSopenharmony_ci+ size_t outputs_size, OpParameter *parameter); 4462be168c0dSopenharmony_ci+ 4463be168c0dSopenharmony_ci+#ifdef __cplusplus 4464be168c0dSopenharmony_ci+} 4465be168c0dSopenharmony_ci+#endif 4466be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_CUSTOM_GATHER_D_GRAD_V2_INFER_H 4467be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c 4468be168c0dSopenharmony_ciindex 9892ef0b..391e2522 
100644 4469be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c 4470be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/lstm_infer.c 4471be168c0dSopenharmony_ci@@ -17,41 +17,81 @@ 4472be168c0dSopenharmony_ci #include "nnacl/infer/lstm_infer.h" 4473be168c0dSopenharmony_ci #include "nnacl/infer/infer_register.h" 4474be168c0dSopenharmony_ci 4475be168c0dSopenharmony_ci-static const int num_of_gates = 4; 4476be168c0dSopenharmony_ci-static const int no_of_recorde_values = 6; 4477be168c0dSopenharmony_ci+static const int no_of_recorde_values = 5; 4478be168c0dSopenharmony_ci 4479be168c0dSopenharmony_ci int CheckInputShapeValid(const TensorC *const *inputs, size_t inputs_size, const LstmParameter *parameter) { 4480be168c0dSopenharmony_ci+ if (inputs_size < C6NUM) { 4481be168c0dSopenharmony_ci+ return NNACL_INPUT_TENSOR_ERROR; 4482be168c0dSopenharmony_ci+ } 4483be168c0dSopenharmony_ci const TensorC *input = inputs[FIRST_INPUT]; 4484be168c0dSopenharmony_ci const TensorC *weight_i = inputs[SECOND_INPUT]; 4485be168c0dSopenharmony_ci const TensorC *weight_g = inputs[THIRD_INPUT]; 4486be168c0dSopenharmony_ci const TensorC *bias = inputs[FOURTH_INPUT]; 4487be168c0dSopenharmony_ci- const TensorC *cell = inputs[FIFTH_INPUT]; 4488be168c0dSopenharmony_ci+ const TensorC *hidden_init = inputs[FIFTH_INPUT]; 4489be168c0dSopenharmony_ci+ const TensorC *cell_init = inputs[SIXTH_INPUT]; 4490be168c0dSopenharmony_ci+ 4491be168c0dSopenharmony_ci+ NNACL_CHECK_TRUE_RET(input->shape_size_ == DIMENSION_3D && weight_i->shape_size_ == DIMENSION_3D && 4492be168c0dSopenharmony_ci+ weight_g->shape_size_ == DIMENSION_3D && bias->shape_size_ == DIMENSION_2D, 4493be168c0dSopenharmony_ci+ NNACL_ERR); 4494be168c0dSopenharmony_ci int batch = input->shape_[kNHWC_H]; 4495be168c0dSopenharmony_ci int input_size = input->shape_[kNHWC_W]; 4496be168c0dSopenharmony_ci int hidden_size = weight_i->shape_[kNHWC_H] / C4NUM; 
4497be168c0dSopenharmony_ci- int project_size = inputs_size == C7NUM ? inputs[C6NUM]->shape_[kNHWC_H] : hidden_size; 4498be168c0dSopenharmony_ci- bool bidirectional = parameter->bidirectional_; 4499be168c0dSopenharmony_ci- if (input->shape_size_ != DIMENSION_3D || weight_i->shape_size_ != DIMENSION_3D) { 4500be168c0dSopenharmony_ci- return NNACL_ERR; 4501be168c0dSopenharmony_ci+ int out_size = hidden_size; 4502be168c0dSopenharmony_ci+ if (inputs_size == C7NUM) { 4503be168c0dSopenharmony_ci+ NNACL_CHECK_TRUE_RET(inputs[SEVENTH_INPUT]->shape_size_ == DIMENSION_3D, NNACL_INPUT_TENSOR_ERROR); 4504be168c0dSopenharmony_ci+ out_size = inputs[SEVENTH_INPUT]->shape_[kNHWC_H]; 4505be168c0dSopenharmony_ci } 4506be168c0dSopenharmony_ci+ bool bidirectional = parameter->bidirectional_; 4507be168c0dSopenharmony_ci int bidirection = bidirectional ? C2NUM : C1NUM; 4508be168c0dSopenharmony_ci NNACL_CHECK_TRUE_RET(weight_i->shape_[kNHWC_N] == bidirection && weight_i->shape_[kNHWC_H] == hidden_size * C4NUM && 4509be168c0dSopenharmony_ci weight_i->shape_[kNHWC_W] == input_size, 4510be168c0dSopenharmony_ci NNACL_ERR); 4511be168c0dSopenharmony_ci NNACL_CHECK_TRUE_RET(weight_g->shape_[kNHWC_N] == bidirection && weight_g->shape_[kNHWC_H] == hidden_size * C4NUM && 4512be168c0dSopenharmony_ci- weight_g->shape_[kNHWC_W] == project_size, 4513be168c0dSopenharmony_ci+ weight_g->shape_[kNHWC_W] == out_size, 4514be168c0dSopenharmony_ci NNACL_ERR); 4515be168c0dSopenharmony_ci NNACL_CHECK_TRUE_RET(bias->shape_[kNHWC_N] == bidirection && bias->shape_[kNHWC_H] == hidden_size * C8NUM, NNACL_ERR); 4516be168c0dSopenharmony_ci- if (!bidirectional && cell->shape_size_ == DIMENSION_2D) { 4517be168c0dSopenharmony_ci- NNACL_CHECK_TRUE_RET(cell->shape_[kNHWC_N] == batch && cell->shape_[kNHWC_H] == hidden_size, NNACL_ERR); 4518be168c0dSopenharmony_ci+ if (!bidirectional && hidden_init->shape_size_ == DIMENSION_2D) { 4519be168c0dSopenharmony_ci+ NNACL_CHECK_TRUE_RET(hidden_init->shape_[kNHWC_N] == batch && 
hidden_init->shape_[kNHWC_H] == out_size, NNACL_ERR); 4520be168c0dSopenharmony_ci } else { 4521be168c0dSopenharmony_ci- NNACL_CHECK_TRUE_RET( 4522be168c0dSopenharmony_ci- cell->shape_[kNHWC_N] == bidirection && cell->shape_[kNHWC_H] == batch && cell->shape_[kNHWC_W] == project_size, 4523be168c0dSopenharmony_ci- NNACL_ERR); 4524be168c0dSopenharmony_ci+ NNACL_CHECK_TRUE_RET(hidden_init->shape_size_ == DIMENSION_3D && hidden_init->shape_[kNHWC_N] == bidirection && 4525be168c0dSopenharmony_ci+ hidden_init->shape_[kNHWC_H] == batch && hidden_init->shape_[kNHWC_W] == out_size, 4526be168c0dSopenharmony_ci+ NNACL_ERR); 4527be168c0dSopenharmony_ci+ } 4528be168c0dSopenharmony_ci+ if (!bidirectional && cell_init->shape_size_ == DIMENSION_2D) { 4529be168c0dSopenharmony_ci+ NNACL_CHECK_TRUE_RET(cell_init->shape_[kNHWC_N] == batch && cell_init->shape_[kNHWC_H] == hidden_size, NNACL_ERR); 4530be168c0dSopenharmony_ci+ } else { 4531be168c0dSopenharmony_ci+ NNACL_CHECK_TRUE_RET(cell_init->shape_size_ == DIMENSION_3D && cell_init->shape_[kNHWC_N] == bidirection && 4532be168c0dSopenharmony_ci+ cell_init->shape_[kNHWC_H] == batch && cell_init->shape_[kNHWC_W] == hidden_size, 4533be168c0dSopenharmony_ci+ NNACL_ERR); 4534be168c0dSopenharmony_ci } 4535be168c0dSopenharmony_ci return NNACL_OK; 4536be168c0dSopenharmony_ci } 4537be168c0dSopenharmony_ci 4538be168c0dSopenharmony_ci+int InferFirstOutputMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) { 4539be168c0dSopenharmony_ci+ for (size_t i = 0; i < inputs_size; ++i) { 4540be168c0dSopenharmony_ci+ if (inputs[i]->shape_size_ != C3NUM) { 4541be168c0dSopenharmony_ci+ return NNACL_INPUT_TENSOR_ERROR; 4542be168c0dSopenharmony_ci+ } 4543be168c0dSopenharmony_ci+ } 4544be168c0dSopenharmony_ci+ ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_); 4545be168c0dSopenharmony_ci+ int out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT]; 4546be168c0dSopenharmony_ci+ 
output->shape_[THIRD_INPUT] = (param->bidirectional_ ? C2NUM : 1) * out_size; 4547be168c0dSopenharmony_ci+ return NNACL_OK; 4548be168c0dSopenharmony_ci+} 4549be168c0dSopenharmony_ci+ 4550be168c0dSopenharmony_ci+int InferFirstOutputNonMindir(const TensorC *const *inputs, size_t inputs_size, TensorC *output, LstmParameter *param) { 4551be168c0dSopenharmony_ci+ if (CheckInputShapeValid(inputs, inputs_size, param) != NNACL_OK) { 4552be168c0dSopenharmony_ci+ return NNACL_ERR; 4553be168c0dSopenharmony_ci+ } 4554be168c0dSopenharmony_ci+ ShapeSet(output->shape_, &output->shape_size_, inputs[0]->shape_, inputs[0]->shape_size_); 4555be168c0dSopenharmony_ci+ const TensorC *hidden_init = inputs[FIFTH_INPUT]; 4556be168c0dSopenharmony_ci+ int out_size = hidden_init->shape_[hidden_init->shape_size_ - 1]; 4557be168c0dSopenharmony_ci+ output->shape_[THIRD_INPUT] = out_size; 4558be168c0dSopenharmony_ci+ int direction = param->bidirectional_ ? C2NUM : C1NUM; 4559be168c0dSopenharmony_ci+ int ret = ShapeInsert(output->shape_, &output->shape_size_, 1, direction); 4560be168c0dSopenharmony_ci+ return ret; 4561be168c0dSopenharmony_ci+} 4562be168c0dSopenharmony_ci+ 4563be168c0dSopenharmony_ci int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 4564be168c0dSopenharmony_ci OpParameter *parameter) { 4565be168c0dSopenharmony_ci int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 4, 3); 4566be168c0dSopenharmony_ci@@ -60,9 +100,8 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o 4567be168c0dSopenharmony_ci } 4568be168c0dSopenharmony_ci 4569be168c0dSopenharmony_ci const TensorC *input = inputs[0]; 4570be168c0dSopenharmony_ci- const TensorC *weight_i = inputs[1]; 4571be168c0dSopenharmony_ci TensorC *output = outputs[0]; 4572be168c0dSopenharmony_ci- for (int i = 0; i < 3; i++) { 4573be168c0dSopenharmony_ci+ for (int i = 0; i < outputs_size; i++) { 
4574be168c0dSopenharmony_ci SetDataTypeFormat(outputs[i], input); 4575be168c0dSopenharmony_ci } 4576be168c0dSopenharmony_ci 4577be168c0dSopenharmony_ci@@ -71,42 +110,31 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o 4578be168c0dSopenharmony_ci if (!InferFlag(inputs, inputs_size)) { 4579be168c0dSopenharmony_ci return NNACL_INFER_INVALID; 4580be168c0dSopenharmony_ci } 4581be168c0dSopenharmony_ci- int dir_multiplier = param->bidirectional_ ? 2 : 1; 4582be168c0dSopenharmony_ci- int out_shape[MAX_SHAPE_SIZE]; 4583be168c0dSopenharmony_ci- size_t out_shape_size = 0; 4584be168c0dSopenharmony_ci- int hidden_size = 1; 4585be168c0dSopenharmony_ci- int project_size = 1; 4586be168c0dSopenharmony_ci- ShapeSet(out_shape, &out_shape_size, input->shape_, input->shape_size_); 4587be168c0dSopenharmony_ci- if (inputs_size == DIMENSION_4D) { // if input from MINDIR 4588be168c0dSopenharmony_ci- hidden_size = weight_i->shape_[THIRD_INPUT]; 4589be168c0dSopenharmony_ci- project_size = hidden_size; 4590be168c0dSopenharmony_ci- out_shape[THIRD_INPUT] = hidden_size * dir_multiplier; 4591be168c0dSopenharmony_ci- } else { 4592be168c0dSopenharmony_ci- if (CheckInputShapeValid(inputs, inputs_size, param) != NNACL_OK) { 4593be168c0dSopenharmony_ci- return NNACL_ERR; 4594be168c0dSopenharmony_ci+ int hidden_size = 0; 4595be168c0dSopenharmony_ci+ int out_size = 0; 4596be168c0dSopenharmony_ci+ if (inputs_size == C4NUM) { 4597be168c0dSopenharmony_ci+ int ret = InferFirstOutputMindir(inputs, inputs_size, output, param); 4598be168c0dSopenharmony_ci+ if (ret != NNACL_OK) { 4599be168c0dSopenharmony_ci+ return ret; 4600be168c0dSopenharmony_ci } 4601be168c0dSopenharmony_ci- hidden_size = weight_i->shape_[1] / num_of_gates; 4602be168c0dSopenharmony_ci- project_size = inputs_size == C7NUM ? 
inputs[C6NUM]->shape_[kNHWC_H] : hidden_size; 4603be168c0dSopenharmony_ci- out_shape[THIRD_INPUT] = project_size; 4604be168c0dSopenharmony_ci- if (param->bidirectional_) { 4605be168c0dSopenharmony_ci- int ret = ShapeInsert(out_shape, &out_shape_size, 1, 2); 4606be168c0dSopenharmony_ci- if (ret != NNACL_OK) { 4607be168c0dSopenharmony_ci- return NNACL_ERR; 4608be168c0dSopenharmony_ci- } 4609be168c0dSopenharmony_ci- } else { 4610be168c0dSopenharmony_ci- int ret = ShapeInsert(out_shape, &out_shape_size, 1, 1); 4611be168c0dSopenharmony_ci- if (ret != NNACL_OK) { 4612be168c0dSopenharmony_ci- return NNACL_ERR; 4613be168c0dSopenharmony_ci- } 4614be168c0dSopenharmony_ci+ hidden_size = inputs[THIRD_INPUT]->shape_[THIRD_INPUT]; 4615be168c0dSopenharmony_ci+ out_size = inputs[SECOND_INPUT]->shape_[THIRD_INPUT]; 4616be168c0dSopenharmony_ci+ } else { 4617be168c0dSopenharmony_ci+ int ret = InferFirstOutputNonMindir(inputs, inputs_size, output, param); 4618be168c0dSopenharmony_ci+ if (ret != NNACL_OK) { 4619be168c0dSopenharmony_ci+ return ret; 4620be168c0dSopenharmony_ci } 4621be168c0dSopenharmony_ci+ hidden_size = inputs[SIXTH_INPUT]->shape_[inputs[SIXTH_INPUT]->shape_size_ - 1]; 4622be168c0dSopenharmony_ci+ out_size = inputs[FIFTH_INPUT]->shape_[inputs[FIFTH_INPUT]->shape_size_ - 1]; 4623be168c0dSopenharmony_ci } 4624be168c0dSopenharmony_ci- SetShapeArray(output, out_shape, out_shape_size); 4625be168c0dSopenharmony_ci+ 4626be168c0dSopenharmony_ci+ int dir_multiplier = param->bidirectional_ ? 
C2NUM : C1NUM; 4627be168c0dSopenharmony_ci int state_shape[MAX_SHAPE_SIZE]; 4628be168c0dSopenharmony_ci size_t state_shape_size = 0; 4629be168c0dSopenharmony_ci 4630be168c0dSopenharmony_ci ShapeSet(state_shape, &state_shape_size, input->shape_, input->shape_size_); 4631be168c0dSopenharmony_ci state_shape[FIRST_INPUT] = dir_multiplier; 4632be168c0dSopenharmony_ci- state_shape[THIRD_INPUT] = project_size; 4633be168c0dSopenharmony_ci+ state_shape[THIRD_INPUT] = out_size; 4634be168c0dSopenharmony_ci SetShapeArray(outputs[SECOND_INPUT], state_shape, state_shape_size); 4635be168c0dSopenharmony_ci state_shape[THIRD_INPUT] = hidden_size; 4636be168c0dSopenharmony_ci SetShapeArray(outputs[THIRD_INPUT], state_shape, state_shape_size); 4637be168c0dSopenharmony_ci@@ -116,11 +144,9 @@ int LstmInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o 4638be168c0dSopenharmony_ci const size_t intermediate_states_shape_size = 1; 4639be168c0dSopenharmony_ci int batch_size = input->shape_[SECOND_INPUT]; 4640be168c0dSopenharmony_ci int seq_len = input->shape_[FIRST_INPUT]; 4641be168c0dSopenharmony_ci- intermediate_states_shape[FIRST_INPUT] = no_of_recorde_values * batch_size * hidden_size * seq_len * dir_multiplier; 4642be168c0dSopenharmony_ci- SetDataTypeFormat(outputs[FOURTH_INPUT], inputs[FIRST_INPUT]); 4643be168c0dSopenharmony_ci+ intermediate_states_shape[FIRST_INPUT] = 4644be168c0dSopenharmony_ci+ batch_size * seq_len * dir_multiplier * (out_size + no_of_recorde_values * hidden_size); 4645be168c0dSopenharmony_ci SetShapeArray(outputs[FOURTH_INPUT], intermediate_states_shape, intermediate_states_shape_size); 4646be168c0dSopenharmony_ci- 4647be168c0dSopenharmony_ci- SetDataTypeFormat(outputs[FIFTH_INPUT], inputs[FIRST_INPUT]); 4648be168c0dSopenharmony_ci SetShapeArray(outputs[FIFTH_INPUT], state_shape, state_shape_size); 4649be168c0dSopenharmony_ci } 4650be168c0dSopenharmony_ci 4651be168c0dSopenharmony_cidiff --git 
a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 4652be168c0dSopenharmony_ciindex 287e9de3..3c192df7 100644 4653be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 4654be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 4655be168c0dSopenharmony_ci@@ -33,12 +33,14 @@ int CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size 4656be168c0dSopenharmony_ci } 4657be168c0dSopenharmony_ci ShapePush(out_shape, out_shape_size, data[i]); 4658be168c0dSopenharmony_ci } 4659be168c0dSopenharmony_ci- 4660be168c0dSopenharmony_ci+ if (size == 0) { 4661be168c0dSopenharmony_ci+ return NNACL_ERR; 4662be168c0dSopenharmony_ci+ } 4663be168c0dSopenharmony_ci if ((int)(data[index]) == -1) { 4664be168c0dSopenharmony_ci if (index >= MAX_SHAPE_SIZE) { 4665be168c0dSopenharmony_ci return NNACL_ERR; 4666be168c0dSopenharmony_ci } 4667be168c0dSopenharmony_ci- out_shape[index] = size == 0 ? 
0 : input_count / size; 4668be168c0dSopenharmony_ci+ out_shape[index] = input_count / size; 4669be168c0dSopenharmony_ci } 4670be168c0dSopenharmony_ci return NNACL_OK; 4671be168c0dSopenharmony_ci } 4672be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h 4673be168c0dSopenharmony_ciindex 377993cd..6a933785 100644 4674be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h 4675be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions.h 4676be168c0dSopenharmony_ci@@ -308,7 +308,7 @@ static inline float simd_exp32_f32(float data) { 4677be168c0dSopenharmony_ci #else 4678be168c0dSopenharmony_ci data = MS_MAX32_F32(-88.0f, MS_MIN32_F32(88.0f, data)); // clamp(-88, 88) 4679be168c0dSopenharmony_ci #endif 4680be168c0dSopenharmony_ci- int integer = floor(data * 1.44269504088896341f + 0.5f); 4681be168c0dSopenharmony_ci+ int integer = data / param[0]; 4682be168c0dSopenharmony_ci float decimal = data - integer * param[0]; 4683be168c0dSopenharmony_ci fi int_exp; 4684be168c0dSopenharmony_ci int_exp.i = (integer + 127) << 23; // Approximate calculation : (integer + 127) << 23 4685be168c0dSopenharmony_ci@@ -324,14 +324,19 @@ static inline void simd_exp32(float src, float *dst) { 4686be168c0dSopenharmony_ci int i; 4687be168c0dSopenharmony_ci } fi; 4688be168c0dSopenharmony_ci static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f}; // log(2.0f) 4689be168c0dSopenharmony_ci- src = MS_MAX32_F32(-88.0f, MS_MIN32_F32(88.0f, src)); // clamp(-88.0f, 88.0f) 4690be168c0dSopenharmony_ci+ src = MS_MAX32_F32(-87.3365478515625f, MS_MIN32_F32(88.72283935546875f, src)); // clamp(logf(FLT_MIN), logf(FLT_MAX)) 4691be168c0dSopenharmony_ci int integer = floor(src * 1.44269504088896341f + 0.5f); 4692be168c0dSopenharmony_ci float 
decimal = src - integer * param[0]; 4693be168c0dSopenharmony_ci fi int_exp; 4694be168c0dSopenharmony_ci- int_exp.i = (integer + 127) << 23; // integer num approximate calculation : (x + 127) << 23 4695be168c0dSopenharmony_ci+ const int shift = 23; 4696be168c0dSopenharmony_ci+ const int bias = 126; 4697be168c0dSopenharmony_ci+ const float factor = 2; 4698be168c0dSopenharmony_ci+ // 2^n * exp(r) should be counted 2 * 2^(n - 1) * exp(r), 4699be168c0dSopenharmony_ci+ // because n may be 128, and it is not representable by fp32. 4700be168c0dSopenharmony_ci+ int_exp.i = (integer + bias) << shift; // integer num 2^(n - 1) approximate calculation : ((x - 1) + 127) << 23 4701be168c0dSopenharmony_ci const float decimal_exp = 4702be168c0dSopenharmony_ci 1.0f + decimal * (1.0f + decimal * (0.5f + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); 4703be168c0dSopenharmony_ci- *dst = int_exp.f * decimal_exp; 4704be168c0dSopenharmony_ci+ *dst = factor * int_exp.f * decimal_exp; 4705be168c0dSopenharmony_ci } 4706be168c0dSopenharmony_ci 4707be168c0dSopenharmony_ci // define (float/int) data 4708be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions_fp16.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions_fp16.h 4709be168c0dSopenharmony_ciindex a29c4dbb..94ed4b89 100644 4710be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions_fp16.h 4711be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_instructions_fp16.h 4712be168c0dSopenharmony_ci@@ -94,9 +94,13 @@ static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) { 4713be168c0dSopenharmony_ci 4714be168c0dSopenharmony_ci #define MS_FLOAT16X8 float16x8_t 4715be168c0dSopenharmony_ci #define MS_FLOAT16X4 float16x4_t 4716be168c0dSopenharmony_ci+#define MS_FLOAT16X4X4 float16x4x4_t 4717be168c0dSopenharmony_ci+#define MS_FLOAT16X4X2 
float16x4x2_t 4718be168c0dSopenharmony_ci #define MS_MOVQ_F16 vmovq_n_f16 4719be168c0dSopenharmony_ci #define MS_STQ_F16(ptr, val) vst1q_f16(ptr, val) 4720be168c0dSopenharmony_ci #define MS_ST_F16 vst1_f16 4721be168c0dSopenharmony_ci+#define MS_ST2_F16 vst2_f16 4722be168c0dSopenharmony_ci+#define MS_ST4_F16 vst4_f16 4723be168c0dSopenharmony_ci #define MS_MINQ_F16 vminq_f16 4724be168c0dSopenharmony_ci #define MS_MAXQ_F16 vmaxq_f16 4725be168c0dSopenharmony_ci #define MS_LDQ_F16(ptr) vld1q_f16(ptr) 4726be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_neon_instructions.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_neon_instructions.h 4727be168c0dSopenharmony_ciindex c4bc34d9..fb38b452 100644 4728be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_neon_instructions.h 4729be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/intrinsics/ms_simd_neon_instructions.h 4730be168c0dSopenharmony_ci@@ -25,6 +25,8 @@ 4731be168c0dSopenharmony_ci #define MS128_F32_GETI(src, i) src[i] 4732be168c0dSopenharmony_ci #define MS_FLOAT32X4 float32x4_t 4733be168c0dSopenharmony_ci #define MS_FLOAT128_F32 float32x4_t 4734be168c0dSopenharmony_ci+#define MS_FLOAT32X4X2 float32x4x2_t 4735be168c0dSopenharmony_ci+#define MS_FLOAT32X4X4 float32x4x4_t 4736be168c0dSopenharmony_ci #define MS_INT32X4 int32x4_t 4737be168c0dSopenharmony_ci #define MS_INT128_EPI32 int32x4_t 4738be168c0dSopenharmony_ci #define MS_UINT32X4 uint32x4_t 4739be168c0dSopenharmony_ci@@ -222,29 +224,30 @@ static inline MS_FLOAT32X4 VexpFp32(MS_FLOAT32X4 input) { 4740be168c0dSopenharmony_ci {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, 4741be168c0dSopenharmony_ci {0.5f, 0.5f, 0.5f, 0.5f}, 4742be168c0dSopenharmony_ci {1.0f, 1.0f, 1.0f, 1.0f}, 4743be168c0dSopenharmony_ci- {1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}}; 4744be168c0dSopenharmony_ci+ 
{1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f, 1.44269504088896341f}, 4745be168c0dSopenharmony_ci+ {2.0f, 2.0f, 2.0f, 2.0f}}; 4746be168c0dSopenharmony_ci static MS_FLOAT32X4 negative_flag = {-0.0f, -0.0f, -0.0f, -0.0f}; 4747be168c0dSopenharmony_ci 4748be168c0dSopenharmony_ci MS_INT32X4 integer = 4749be168c0dSopenharmony_ci MS_CVTQPS_EPI32(MS_FMADD128_F32(input, param[6], MS_OR128_F32(MS_AND128_F32(input, negative_flag), param[4]))); 4750be168c0dSopenharmony_ci MS_FLOAT32X4 decimal = MS_SUBQ_F32(input, MS_MULQ_F32(MS_CVTQEPI32_PS(integer), param[0])); 4751be168c0dSopenharmony_ci- MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(127)), 23); 4752be168c0dSopenharmony_ci+ MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(126)), 23); 4753be168c0dSopenharmony_ci MS_FLOAT32X4 tmp = MS_MULQ_F32(decimal, (MS_ADDQ_F32(param[2], MS_MULQ_F32(decimal, param[1])))); 4754be168c0dSopenharmony_ci tmp = MS_MULQ_F32(decimal, MS_ADDQ_F32(param[4], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[3], tmp)))); 4755be168c0dSopenharmony_ci MS_FLOAT32X4 decimal_exp = MS_ADDQ_F32(param[5], MS_MULQ_F32(decimal, MS_ADDQ_F32(param[5], tmp))); 4756be168c0dSopenharmony_ci- return MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp)); 4757be168c0dSopenharmony_ci+ return MS_MULQ_F32(param[7], MS_MULQ_F32(decimal_exp, MS_CAST128_F32_S32(int_exp))); 4758be168c0dSopenharmony_ci } 4759be168c0dSopenharmony_ci 4760be168c0dSopenharmony_ci static inline void simd_exp128(MS_FLOAT32X4 input, float *dst) { 4761be168c0dSopenharmony_ci- static MS_FLOAT32X4 maxv = {88.0f, 88.0f, 88.0f, 88.0f}; 4762be168c0dSopenharmony_ci- static MS_FLOAT32X4 minv = {-88.0f, -88.0f, -88.0f, -88.0f}; 4763be168c0dSopenharmony_ci+ static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; 4764be168c0dSopenharmony_ci+ static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; 
4765be168c0dSopenharmony_ci input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); 4766be168c0dSopenharmony_ci MS_STQ_F32(dst, VexpFp32(input)); 4767be168c0dSopenharmony_ci } 4768be168c0dSopenharmony_ci 4769be168c0dSopenharmony_ci static inline MS_FLOAT32X4 simd_exp128_f32(MS_FLOAT32X4 input) { 4770be168c0dSopenharmony_ci- static MS_FLOAT32X4 maxv = {88.0f, 88.0f, 88.0f, 88.0f}; 4771be168c0dSopenharmony_ci- static MS_FLOAT32X4 minv = {-88.0f, -88.0f, -88.0f, -88.0f}; 4772be168c0dSopenharmony_ci+ static MS_FLOAT32X4 maxv = {88.72283935546875f, 88.72283935546875f, 88.72283935546875f, 88.72283935546875f}; 4773be168c0dSopenharmony_ci+ static MS_FLOAT32X4 minv = {-87.3365478515625f, -87.3365478515625f, -87.3365478515625f, -87.3365478515625f}; 4774be168c0dSopenharmony_ci input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); 4775be168c0dSopenharmony_ci return VexpFp32(input); 4776be168c0dSopenharmony_ci } 4777be168c0dSopenharmony_ci@@ -286,18 +289,6 @@ static inline MS_FLOAT32X4 MS_TANHX4_F32(MS_FLOAT32X4 src) { 4778be168c0dSopenharmony_ci return res; 4779be168c0dSopenharmony_ci } 4780be168c0dSopenharmony_ci 4781be168c0dSopenharmony_ci-static inline MS_FLOAT128_F32 SIMD_SIGN128_F32(MS_FLOAT128_F32 src) { 4782be168c0dSopenharmony_ci- MS_FLOAT128_F32 abs_src = MS_ABS128_F32(src); 4783be168c0dSopenharmony_ci- MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); 4784be168c0dSopenharmony_ci- MS_FLOAT128_F32 sign = MS_DIV128_F32(abs_src, src_tmp); 4785be168c0dSopenharmony_ci- return sign; 4786be168c0dSopenharmony_ci-} 4787be168c0dSopenharmony_ci- 4788be168c0dSopenharmony_ci-static inline MS_FLOAT128_F32 SIMD_SIGNABS128_F32(MS_FLOAT128_F32 src, MS_FLOAT128_F32 abs_src) { 4789be168c0dSopenharmony_ci- MS_FLOAT128_F32 src_tmp = MS_OR128_F32(src, MS_MOV128_F32(1.0f)); 4790be168c0dSopenharmony_ci- return MS_DIV128_F32(abs_src, src_tmp); 4791be168c0dSopenharmony_ci-} 4792be168c0dSopenharmony_ci- 4793be168c0dSopenharmony_ci #define MS_TANH128_F32 MS_TANHX4_F32 
4794be168c0dSopenharmony_ci 4795be168c0dSopenharmony_ci static inline MS_FLOAT32X4 MS128_ERF_F32(MS_FLOAT32X4 src) { 4796be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/lstm_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/lstm_parameter.h 4797be168c0dSopenharmony_ciindex 9ecd8409..5baf10fa 100644 4798be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/lstm_parameter.h 4799be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/lstm_parameter.h 4800be168c0dSopenharmony_ci@@ -25,6 +25,7 @@ typedef struct LstmParameter { 4801be168c0dSopenharmony_ci int input_size_; 4802be168c0dSopenharmony_ci int hidden_size_; 4803be168c0dSopenharmony_ci int project_size_; 4804be168c0dSopenharmony_ci+ int output_size_; 4805be168c0dSopenharmony_ci int seq_len_; 4806be168c0dSopenharmony_ci int batch_; 4807be168c0dSopenharmony_ci // other parameter 4808be168c0dSopenharmony_ci@@ -36,6 +37,8 @@ typedef struct LstmParameter { 4809be168c0dSopenharmony_ci int input_col_align_; 4810be168c0dSopenharmony_ci int state_row_align_; 4811be168c0dSopenharmony_ci int state_col_align_; 4812be168c0dSopenharmony_ci+ int proj_col_align_; 4813be168c0dSopenharmony_ci+ bool has_bias_; 4814be168c0dSopenharmony_ci } LstmParameter; 4815be168c0dSopenharmony_ci 4816be168c0dSopenharmony_ci #endif // NNACL_LSTM_PARAMETER_H_ 4817be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 4818be168c0dSopenharmony_ciindex 895f7e3d..bd0d152c 100644 4819be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 4820be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 4821be168c0dSopenharmony_ci@@ -562,6 +562,7 @@ enum PrimType { 4822be168c0dSopenharmony_ci PrimType_Inner_CustomMaskedFill = 10014, 4823be168c0dSopenharmony_ci PrimType_Inner_CustomTensorScatterMax = 10015, 
4824be168c0dSopenharmony_ci PrimType_Inner_CustomIsInf = 10016, 4825be168c0dSopenharmony_ci+ PrimType_Inner_CustomGatherDGradV2 = 10017, 4826be168c0dSopenharmony_ci PrimType_InnerOpMax, 4827be168c0dSopenharmony_ci PrimType_InnerOpMin = PrimType_Inner_ToFormat 4828be168c0dSopenharmony_ci }; 4829be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc 4830be168c0dSopenharmony_ciindex 2301be8c..342ffb7f 100644 4831be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/threadpool.cc 4832be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/threadpool.cc 4833be168c0dSopenharmony_ci@@ -53,7 +53,7 @@ Worker::~Worker() { 4834be168c0dSopenharmony_ci void Worker::CreateThread() { thread_ = std::make_unique<std::thread>(&Worker::Run, this); } 4835be168c0dSopenharmony_ci 4836be168c0dSopenharmony_ci void Worker::ReinitAfterFork() { 4837be168c0dSopenharmony_ci- THREAD_INFO("worker %ld recreate thread after fork in child process", worker_id_); 4838be168c0dSopenharmony_ci+ THREAD_INFO("worker %zu recreate thread after fork in child process", worker_id_); 4839be168c0dSopenharmony_ci if (cond_var_ != nullptr) { 4840be168c0dSopenharmony_ci (void)cond_var_.release(); 4841be168c0dSopenharmony_ci cond_var_ = std::make_unique<std::condition_variable>(); 4842be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/base_operator.h b/mindspore/core/ops/base_operator.h 4843be168c0dSopenharmony_ciindex 811a6000..23652e8e 100644 4844be168c0dSopenharmony_ci--- a/mindspore/core/ops/base_operator.h 4845be168c0dSopenharmony_ci+++ b/mindspore/core/ops/base_operator.h 4846be168c0dSopenharmony_ci@@ -75,7 +75,7 @@ class MIND_API OperatorRegisterHelper { 4847be168c0dSopenharmony_ci public: 4848be168c0dSopenharmony_ci OperatorRegisterHelper(const std::string &kname, const OperatorDefineFunc &fn) { 4849be168c0dSopenharmony_ci OperatorRegister::GetInstance().SetOperatorMap(kname, fn); 4850be168c0dSopenharmony_ci- 
(void)id_; // make compiler happy on macos 4851be168c0dSopenharmony_ci+ // (void)id_; // make compiler happy on macos 4852be168c0dSopenharmony_ci } 4853be168c0dSopenharmony_ci 4854be168c0dSopenharmony_ci ~OperatorRegisterHelper() = default; 4855be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/grad/gather_d_grad_v2.cc b/mindspore/core/ops/grad/gather_d_grad_v2.cc 4856be168c0dSopenharmony_ciindex 3ce5f887..c999ca88 100644 4857be168c0dSopenharmony_ci--- a/mindspore/core/ops/grad/gather_d_grad_v2.cc 4858be168c0dSopenharmony_ci+++ b/mindspore/core/ops/grad/gather_d_grad_v2.cc 4859be168c0dSopenharmony_ci@@ -75,6 +75,11 @@ TypePtr GatherDGradV2InferType(const PrimitivePtr &prim, const std::vector<Abstr 4860be168c0dSopenharmony_ci } 4861be168c0dSopenharmony_ci } // namespace 4862be168c0dSopenharmony_ci 4863be168c0dSopenharmony_ci+int64_t GatherDGradV2::get_dim() const { 4864be168c0dSopenharmony_ci+ auto value_ptr = this->GetAttr(kDim); 4865be168c0dSopenharmony_ci+ return GetValue<int64_t>(value_ptr); 4866be168c0dSopenharmony_ci+} 4867be168c0dSopenharmony_ci+ 4868be168c0dSopenharmony_ci AbstractBasePtr GatherDGradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, 4869be168c0dSopenharmony_ci const std::vector<AbstractBasePtr> &input_args) { 4870be168c0dSopenharmony_ci auto infer_type = GatherDGradV2InferType(primitive, input_args); 4871be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/grad/gather_d_grad_v2.h b/mindspore/core/ops/grad/gather_d_grad_v2.h 4872be168c0dSopenharmony_ciindex 94274e3b..40a6e412 100644 4873be168c0dSopenharmony_ci--- a/mindspore/core/ops/grad/gather_d_grad_v2.h 4874be168c0dSopenharmony_ci+++ b/mindspore/core/ops/grad/gather_d_grad_v2.h 4875be168c0dSopenharmony_ci@@ -25,6 +25,7 @@ class MIND_API GatherDGradV2 : public BaseOperator { 4876be168c0dSopenharmony_ci public: 4877be168c0dSopenharmony_ci MIND_API_BASE_MEMBER(GatherDGradV2); 4878be168c0dSopenharmony_ci GatherDGradV2() : BaseOperator(kNameGatherDGradV2) { 
InitIOName({"x", "dim", "index", "grad"}, {"output"}); } 4879be168c0dSopenharmony_ci+ int64_t get_dim() const; 4880be168c0dSopenharmony_ci }; 4881be168c0dSopenharmony_ci MIND_API abstract::AbstractBasePtr GatherDGradV2Infer(const abstract::AnalysisEnginePtr &, 4882be168c0dSopenharmony_ci const PrimitivePtr &primitive, 4883be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/grad/lstm_grad.cc b/mindspore/core/ops/grad/lstm_grad.cc 4884be168c0dSopenharmony_ciindex d51c4882..c25e0379 100644 4885be168c0dSopenharmony_ci--- a/mindspore/core/ops/grad/lstm_grad.cc 4886be168c0dSopenharmony_ci+++ b/mindspore/core/ops/grad/lstm_grad.cc 4887be168c0dSopenharmony_ci@@ -98,15 +98,22 @@ void LSTMGrad::set_zoneout_hidden(float zoneout_hidden) { 4888be168c0dSopenharmony_ci 4889be168c0dSopenharmony_ci float LSTMGrad::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); } 4890be168c0dSopenharmony_ci 4891be168c0dSopenharmony_ci+void LSTMGrad::set_proj_size(const int64_t proj_size) { 4892be168c0dSopenharmony_ci+ (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name()); 4893be168c0dSopenharmony_ci+ (void)AddAttr(kProjection_size, api::MakeValue(proj_size)); 4894be168c0dSopenharmony_ci+} 4895be168c0dSopenharmony_ci+int64_t LSTMGrad::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); } 4896be168c0dSopenharmony_ci+ 4897be168c0dSopenharmony_ci void LSTMGrad::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, 4898be168c0dSopenharmony_ci- const float dropout, const bool bidirectional, const float zoneout_cell, 4899be168c0dSopenharmony_ci- const float zoneout_hidden) { 4900be168c0dSopenharmony_ci+ const float dropout, const bool bidirectional, const float zoneout_cell, const float zoneout_hidden, 4901be168c0dSopenharmony_ci+ const int64_t proj_size) { 4902be168c0dSopenharmony_ci this->set_input_size(input_size); 4903be168c0dSopenharmony_ci 
this->set_hidden_size(hidden_size); 4904be168c0dSopenharmony_ci this->set_num_layers(num_layers); 4905be168c0dSopenharmony_ci this->set_has_bias(has_bias); 4906be168c0dSopenharmony_ci this->set_dropout(dropout); 4907be168c0dSopenharmony_ci this->set_bidirectional(bidirectional); 4908be168c0dSopenharmony_ci+ this->set_proj_size(proj_size); 4909be168c0dSopenharmony_ci if (bidirectional) { 4910be168c0dSopenharmony_ci constexpr int k2Directions = 2; 4911be168c0dSopenharmony_ci this->set_num_directions(k2Directions); 4912be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/grad/lstm_grad.h b/mindspore/core/ops/grad/lstm_grad.h 4913be168c0dSopenharmony_ciindex 73272d55..f6eba32c 100644 4914be168c0dSopenharmony_ci--- a/mindspore/core/ops/grad/lstm_grad.h 4915be168c0dSopenharmony_ci+++ b/mindspore/core/ops/grad/lstm_grad.h 4916be168c0dSopenharmony_ci@@ -31,7 +31,7 @@ class MIND_API LSTMGrad : public BaseOperator { 4917be168c0dSopenharmony_ci LSTMGrad() : BaseOperator(kNameLSTMGrad) {} 4918be168c0dSopenharmony_ci void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, 4919be168c0dSopenharmony_ci const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f, 4920be168c0dSopenharmony_ci- const float zoneout_hidden = 0.0f); 4921be168c0dSopenharmony_ci+ const float zoneout_hidden = 0.0f, const int64_t proj_size = 0); 4922be168c0dSopenharmony_ci void set_input_size(const int64_t input_size); 4923be168c0dSopenharmony_ci int64_t get_input_size() const; 4924be168c0dSopenharmony_ci void set_hidden_size(const int64_t hidden_size); 4925be168c0dSopenharmony_ci@@ -51,6 +51,8 @@ class MIND_API LSTMGrad : public BaseOperator { 4926be168c0dSopenharmony_ci void set_zoneout_hidden(float zoneout_hidden); 4927be168c0dSopenharmony_ci float get_zoneout_hidden() const; 4928be168c0dSopenharmony_ci int64_t get_good_ld(const int64_t dim, const int64_t type_size); 4929be168c0dSopenharmony_ci+ void set_proj_size(const 
int64_t proj_size); 4930be168c0dSopenharmony_ci+ int64_t get_proj_size() const; 4931be168c0dSopenharmony_ci }; 4932be168c0dSopenharmony_ci } // namespace ops 4933be168c0dSopenharmony_ci } // namespace mindspore 4934be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/grad/lstm_grad_data.cc b/mindspore/core/ops/grad/lstm_grad_data.cc 4935be168c0dSopenharmony_ciindex 573d26f4..2b25282c 100644 4936be168c0dSopenharmony_ci--- a/mindspore/core/ops/grad/lstm_grad_data.cc 4937be168c0dSopenharmony_ci+++ b/mindspore/core/ops/grad/lstm_grad_data.cc 4938be168c0dSopenharmony_ci@@ -91,15 +91,23 @@ void LSTMGradData::set_zoneout_hidden(float zoneout_hidden) { 4939be168c0dSopenharmony_ci 4940be168c0dSopenharmony_ci float LSTMGradData::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); } 4941be168c0dSopenharmony_ci 4942be168c0dSopenharmony_ci+void LSTMGradData::set_proj_size(const int64_t proj_size) { 4943be168c0dSopenharmony_ci+ (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name()); 4944be168c0dSopenharmony_ci+ (void)AddAttr(kProjection_size, api::MakeValue(proj_size)); 4945be168c0dSopenharmony_ci+} 4946be168c0dSopenharmony_ci+ 4947be168c0dSopenharmony_ci+int64_t LSTMGradData::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); } 4948be168c0dSopenharmony_ci+ 4949be168c0dSopenharmony_ci void LSTMGradData::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, 4950be168c0dSopenharmony_ci const bool has_bias, const float dropout, const bool bidirectional, const float zoneout_cell, 4951be168c0dSopenharmony_ci- const float zoneout_hidden) { 4952be168c0dSopenharmony_ci+ const float zoneout_hidden, const int64_t proj_size) { 4953be168c0dSopenharmony_ci this->set_input_size(input_size); 4954be168c0dSopenharmony_ci this->set_hidden_size(hidden_size); 4955be168c0dSopenharmony_ci this->set_num_layers(num_layers); 4956be168c0dSopenharmony_ci 
this->set_has_bias(has_bias); 4957be168c0dSopenharmony_ci this->set_dropout(dropout); 4958be168c0dSopenharmony_ci this->set_bidirectional(bidirectional); 4959be168c0dSopenharmony_ci+ this->set_proj_size(proj_size); 4960be168c0dSopenharmony_ci if (bidirectional) { 4961be168c0dSopenharmony_ci constexpr int k2Directions = 2; 4962be168c0dSopenharmony_ci this->set_num_directions(k2Directions); 4963be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/grad/lstm_grad_data.h b/mindspore/core/ops/grad/lstm_grad_data.h 4964be168c0dSopenharmony_ciindex adcf2ee7..f93e3260 100644 4965be168c0dSopenharmony_ci--- a/mindspore/core/ops/grad/lstm_grad_data.h 4966be168c0dSopenharmony_ci+++ b/mindspore/core/ops/grad/lstm_grad_data.h 4967be168c0dSopenharmony_ci@@ -32,7 +32,7 @@ class MIND_API LSTMGradData : public BaseOperator { 4968be168c0dSopenharmony_ci LSTMGradData() : BaseOperator(kNameLSTMGradData) {} 4969be168c0dSopenharmony_ci void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, 4970be168c0dSopenharmony_ci const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f, 4971be168c0dSopenharmony_ci- const float zoneout_hidden = 0.0f); 4972be168c0dSopenharmony_ci+ const float zoneout_hidden = 0.0f, const int64_t proj_size = 0); 4973be168c0dSopenharmony_ci void set_input_size(const int64_t input_size); 4974be168c0dSopenharmony_ci int64_t get_input_size() const; 4975be168c0dSopenharmony_ci void set_hidden_size(const int64_t hidden_size); 4976be168c0dSopenharmony_ci@@ -52,6 +52,8 @@ class MIND_API LSTMGradData : public BaseOperator { 4977be168c0dSopenharmony_ci void set_zoneout_hidden(float zoneout_hidden); 4978be168c0dSopenharmony_ci float get_zoneout_hidden() const; 4979be168c0dSopenharmony_ci int64_t get_good_ld(const int64_t dim, const int64_t type_size); 4980be168c0dSopenharmony_ci+ void set_proj_size(const int64_t proj_size); 4981be168c0dSopenharmony_ci+ int64_t get_proj_size() const; 
4982be168c0dSopenharmony_ci }; 4983be168c0dSopenharmony_ci MIND_API abstract::AbstractBasePtr LstmGradDataInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, 4984be168c0dSopenharmony_ci const std::vector<abstract::AbstractBasePtr> &input_args); 4985be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/grad/lstm_grad_weight.cc b/mindspore/core/ops/grad/lstm_grad_weight.cc 4986be168c0dSopenharmony_ciindex 22b519c3..ce0aca94 100644 4987be168c0dSopenharmony_ci--- a/mindspore/core/ops/grad/lstm_grad_weight.cc 4988be168c0dSopenharmony_ci+++ b/mindspore/core/ops/grad/lstm_grad_weight.cc 4989be168c0dSopenharmony_ci@@ -88,15 +88,23 @@ void LSTMGradWeight::set_zoneout_hidden(float zoneout_hidden) { 4990be168c0dSopenharmony_ci 4991be168c0dSopenharmony_ci float LSTMGradWeight::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); } 4992be168c0dSopenharmony_ci 4993be168c0dSopenharmony_ci+void LSTMGradWeight::set_proj_size(const int64_t proj_size) { 4994be168c0dSopenharmony_ci+ (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name()); 4995be168c0dSopenharmony_ci+ (void)AddAttr(kProjection_size, api::MakeValue(proj_size)); 4996be168c0dSopenharmony_ci+} 4997be168c0dSopenharmony_ci+ 4998be168c0dSopenharmony_ci+int64_t LSTMGradWeight::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); } 4999be168c0dSopenharmony_ci+ 5000be168c0dSopenharmony_ci void LSTMGradWeight::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, 5001be168c0dSopenharmony_ci const bool has_bias, const float dropout, const bool bidirectional, const float zoneout_cell, 5002be168c0dSopenharmony_ci- const float zoneout_hidden) { 5003be168c0dSopenharmony_ci+ const float zoneout_hidden, const int64_t proj_size) { 5004be168c0dSopenharmony_ci this->set_input_size(input_size); 5005be168c0dSopenharmony_ci this->set_hidden_size(hidden_size); 5006be168c0dSopenharmony_ci 
this->set_num_layers(num_layers); 5007be168c0dSopenharmony_ci this->set_has_bias(has_bias); 5008be168c0dSopenharmony_ci this->set_dropout(dropout); 5009be168c0dSopenharmony_ci this->set_bidirectional(bidirectional); 5010be168c0dSopenharmony_ci+ this->set_proj_size(proj_size); 5011be168c0dSopenharmony_ci if (bidirectional) { 5012be168c0dSopenharmony_ci constexpr int k2Directions = 2; 5013be168c0dSopenharmony_ci this->set_num_directions(k2Directions); 5014be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/grad/lstm_grad_weight.h b/mindspore/core/ops/grad/lstm_grad_weight.h 5015be168c0dSopenharmony_ciindex c2ca6b5e..add816d3 100644 5016be168c0dSopenharmony_ci--- a/mindspore/core/ops/grad/lstm_grad_weight.h 5017be168c0dSopenharmony_ci+++ b/mindspore/core/ops/grad/lstm_grad_weight.h 5018be168c0dSopenharmony_ci@@ -32,7 +32,7 @@ class MIND_API LSTMGradWeight : public BaseOperator { 5019be168c0dSopenharmony_ci LSTMGradWeight() : BaseOperator(kNameLSTMGradWeight) {} 5020be168c0dSopenharmony_ci void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, 5021be168c0dSopenharmony_ci const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f, 5022be168c0dSopenharmony_ci- const float zoneout_hidden = 0.0f); 5023be168c0dSopenharmony_ci+ const float zoneout_hidden = 0.0f, const int64_t proj_size = 0); 5024be168c0dSopenharmony_ci void set_input_size(const int64_t input_size); 5025be168c0dSopenharmony_ci int64_t get_input_size() const; 5026be168c0dSopenharmony_ci void set_hidden_size(const int64_t hidden_size); 5027be168c0dSopenharmony_ci@@ -52,6 +52,8 @@ class MIND_API LSTMGradWeight : public BaseOperator { 5028be168c0dSopenharmony_ci void set_zoneout_hidden(float zoneout_hidden); 5029be168c0dSopenharmony_ci float get_zoneout_hidden() const; 5030be168c0dSopenharmony_ci int64_t get_good_ld(const int64_t dim, const int64_t type_size); 5031be168c0dSopenharmony_ci+ void set_proj_size(const int64_t 
proj_size); 5032be168c0dSopenharmony_ci+ int64_t get_proj_size() const; 5033be168c0dSopenharmony_ci }; 5034be168c0dSopenharmony_ci MIND_API abstract::AbstractBasePtr LstmGradWeightInfer(const abstract::AnalysisEnginePtr &, 5035be168c0dSopenharmony_ci const PrimitivePtr &primitive, 5036be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/lstm.cc b/mindspore/core/ops/lstm.cc 5037be168c0dSopenharmony_ciindex 43b9241c..937207df 100644 5038be168c0dSopenharmony_ci--- a/mindspore/core/ops/lstm.cc 5039be168c0dSopenharmony_ci+++ b/mindspore/core/ops/lstm.cc 5040be168c0dSopenharmony_ci@@ -68,6 +68,7 @@ abstract::TupleShapePtr LSTMInferShape(const PrimitivePtr &primitive, const std: 5041be168c0dSopenharmony_ci int64_t input_x_size = GetValue<int64_t>(primitive->GetAttr(kInput_size)); 5042be168c0dSopenharmony_ci int64_t num_layers = GetValue<int64_t>(primitive->GetAttr(kNumLayers)); 5043be168c0dSopenharmony_ci bool bidirectional = GetValue<bool>(primitive->GetAttr(kBidirectional)); 5044be168c0dSopenharmony_ci+ int64_t proj_size = GetValue<int64_t>(primitive->GetAttr(kProjection_size)); 5045be168c0dSopenharmony_ci int64_t num_directions = 1; 5046be168c0dSopenharmony_ci if (bidirectional) { 5047be168c0dSopenharmony_ci num_directions = 2; 5048be168c0dSopenharmony_ci@@ -90,7 +91,8 @@ abstract::TupleShapePtr LSTMInferShape(const PrimitivePtr &primitive, const std: 5049be168c0dSopenharmony_ci (void)CheckAndConvertUtils::CheckInteger("h_shape[1]", h_input_shape[1], kEqual, x_input_shape[1], prim_name); 5050be168c0dSopenharmony_ci } 5051be168c0dSopenharmony_ci 5052be168c0dSopenharmony_ci- std::vector<int64_t> y_shape = {x_input_shape[0], x_input_shape[1], hidden_size * num_directions}; 5053be168c0dSopenharmony_ci+ auto real_hidden_size = proj_size > 0 ? 
proj_size : hidden_size; 5054be168c0dSopenharmony_ci+ std::vector<int64_t> y_shape = {x_input_shape[0], x_input_shape[1], real_hidden_size * num_directions}; 5055be168c0dSopenharmony_ci std::vector<int64_t> h_shape = {h_input_shape}; 5056be168c0dSopenharmony_ci std::vector<int64_t> c_shape = {c_input_shape}; 5057be168c0dSopenharmony_ci std::vector<int64_t> reverse_shape = {1, 1}; 5058be168c0dSopenharmony_ci@@ -135,6 +137,11 @@ void LSTM::set_hidden_size(const int64_t hidden_size) { 5059be168c0dSopenharmony_ci (void)AddAttr(kHidden_size, api::MakeValue(hidden_size)); 5060be168c0dSopenharmony_ci } 5061be168c0dSopenharmony_ci int64_t LSTM::get_hidden_size() const { return GetValue<int64_t>(GetAttr(kHidden_size)); } 5062be168c0dSopenharmony_ci+void LSTM::set_proj_size(const int64_t proj_size) { 5063be168c0dSopenharmony_ci+ (void)CheckAndConvertUtils::CheckInteger(kProjection_size, proj_size, kGreaterThan, 0, this->name()); 5064be168c0dSopenharmony_ci+ (void)AddAttr(kProjection_size, api::MakeValue(proj_size)); 5065be168c0dSopenharmony_ci+} 5066be168c0dSopenharmony_ci+int64_t LSTM::get_proj_size() const { return GetValue<int64_t>(GetAttr(kProjection_size)); } 5067be168c0dSopenharmony_ci void LSTM::set_num_layers(const int64_t num_layers) { 5068be168c0dSopenharmony_ci (void)CheckAndConvertUtils::CheckInteger(kNumLayers, num_layers, kGreaterThan, 0, this->name()); 5069be168c0dSopenharmony_ci (void)AddAttr(kNumLayers, api::MakeValue(num_layers)); 5070be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/lstm.h b/mindspore/core/ops/lstm.h 5071be168c0dSopenharmony_ciindex 4d3c8756..e32c5781 100644 5072be168c0dSopenharmony_ci--- a/mindspore/core/ops/lstm.h 5073be168c0dSopenharmony_ci+++ b/mindspore/core/ops/lstm.h 5074be168c0dSopenharmony_ci@@ -51,6 +51,12 @@ class MIND_API LSTM : public BaseOperator { 5075be168c0dSopenharmony_ci /// 5076be168c0dSopenharmony_ci /// \return hidden_size. 
5077be168c0dSopenharmony_ci int64_t get_hidden_size() const; 5078be168c0dSopenharmony_ci+ /// \brief Set proj_size. 5079be168c0dSopenharmony_ci+ void set_proj_size(const int64_t proj_size); 5080be168c0dSopenharmony_ci+ /// \brief Get proj_size. 5081be168c0dSopenharmony_ci+ /// 5082be168c0dSopenharmony_ci+ /// \return proj_size. 5083be168c0dSopenharmony_ci+ int64_t get_proj_size() const; 5084be168c0dSopenharmony_ci /// \brief Set num_layers. 5085be168c0dSopenharmony_ci void set_num_layers(const int64_t num_layers); 5086be168c0dSopenharmony_ci /// \brief Get num_layers. 5087be168c0dSopenharmony_cidiff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h 5088be168c0dSopenharmony_ciindex ce68079f..ad9066e7 100644 5089be168c0dSopenharmony_ci--- a/mindspore/core/ops/op_name.h 5090be168c0dSopenharmony_ci+++ b/mindspore/core/ops/op_name.h 5091be168c0dSopenharmony_ci@@ -268,6 +268,7 @@ constexpr auto kWindowSize = "window_size"; 5092be168c0dSopenharmony_ci constexpr auto kPaddings = "paddings"; 5093be168c0dSopenharmony_ci constexpr auto kInput_size = "input_size"; 5094be168c0dSopenharmony_ci constexpr auto kHidden_size = "hidden_size"; 5095be168c0dSopenharmony_ci+constexpr auto kProjection_size = "proj_size"; 5096be168c0dSopenharmony_ci constexpr auto kChannelShared = "channel_shared"; 5097be168c0dSopenharmony_ci constexpr auto kSlope = "slope"; 5098be168c0dSopenharmony_ci constexpr auto kBase = "base"; 5099be168c0dSopenharmony_cidiff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn 5100be168c0dSopenharmony_ciindex f7e465e2..9318d54e 100644 5101be168c0dSopenharmony_ci--- a/mindspore/lite/BUILD.gn 5102be168c0dSopenharmony_ci+++ b/mindspore/lite/BUILD.gn 5103be168c0dSopenharmony_ci@@ -602,6 +602,8 @@ all_train_sources = [ 5104be168c0dSopenharmony_ci "src/train/optimizer/fusion/matmul_activation_fusion_pass.cc", 5105be168c0dSopenharmony_ci "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc", 5106be168c0dSopenharmony_ci 
"src/train/optimizer/fusion/gru_fusion_pass.cc", 5107be168c0dSopenharmony_ci+ "src/train/optimizer/fusion/matmul_add_fusion_pass.cc", 5108be168c0dSopenharmony_ci+ "src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc", 5109be168c0dSopenharmony_ci "src/common/storage.cc", 5110be168c0dSopenharmony_ci "tools/converter/optimizer.cc", 5111be168c0dSopenharmony_ci "tools/converter/legacy_optimizer/fusion/fusion_pass.cc", 5112be168c0dSopenharmony_ci@@ -646,6 +648,7 @@ fp32_train_kernel_sources = [ 5113be168c0dSopenharmony_ci "src/litert/kernel/cpu/fp32_grad/convolution.cc", 5114be168c0dSopenharmony_ci "src/litert/kernel/cpu/fp32_grad/convolution_grad_filter.cc", 5115be168c0dSopenharmony_ci "src/litert/kernel/cpu/fp32_grad/convolution_grad_input.cc", 5116be168c0dSopenharmony_ci+ "src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.cc", 5117be168c0dSopenharmony_ci "src/litert/kernel/cpu/fp32_grad/deconvolution_grad_filter.cc", 5118be168c0dSopenharmony_ci "src/litert/kernel/cpu/fp32_grad/dropout.cc", 5119be168c0dSopenharmony_ci "src/litert/kernel/cpu/fp32_grad/dropout_grad.cc", 5120be168c0dSopenharmony_cidiff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt 5121be168c0dSopenharmony_ciindex 1faf2f38..f2b5809f 100644 5122be168c0dSopenharmony_ci--- a/mindspore/lite/CMakeLists.txt 5123be168c0dSopenharmony_ci+++ b/mindspore/lite/CMakeLists.txt 5124be168c0dSopenharmony_ci@@ -977,7 +977,7 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "lite" OR MSLITE_MINDDATA_IMPLEMENT STREQU 5125be168c0dSopenharmony_ci endif() 5126be168c0dSopenharmony_ci 5127be168c0dSopenharmony_ci add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/src/common/ops) 5128be168c0dSopenharmony_ci-if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR TARGET_OHOS_LITE OR TARGET_HIMIX) 5129be168c0dSopenharmony_ci+if(ANDROID_NDK_TOOLCHAIN_INCLUDED OR TARGET_OHOS_LITE OR TARGET_HIMIX OR TARGET_OHOS) 5130be168c0dSopenharmony_ci add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter/micro/coder) 
5131be168c0dSopenharmony_ci endif() 5132be168c0dSopenharmony_ci 5133be168c0dSopenharmony_cidiff --git a/mindspore/lite/schema/inner/ops_generated.h b/mindspore/lite/schema/inner/ops_generated.h 5134be168c0dSopenharmony_ciindex c4fd8c15..6c861aa5 100644 5135be168c0dSopenharmony_ci--- a/mindspore/lite/schema/inner/ops_generated.h 5136be168c0dSopenharmony_ci+++ b/mindspore/lite/schema/inner/ops_generated.h 5137be168c0dSopenharmony_ci@@ -11338,6 +11338,7 @@ struct LSTMT : public flatbuffers::NativeTable { 5138be168c0dSopenharmony_ci float dropout = 0.0f; 5139be168c0dSopenharmony_ci float zoneout_cell = 0.0f; 5140be168c0dSopenharmony_ci float zoneout_hidden = 0.0f; 5141be168c0dSopenharmony_ci+ int64_t proj_size = 0; 5142be168c0dSopenharmony_ci }; 5143be168c0dSopenharmony_ci 5144be168c0dSopenharmony_ci struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5145be168c0dSopenharmony_ci@@ -11355,7 +11356,8 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5146be168c0dSopenharmony_ci VT_NUM_DIRECTIONS = 14, 5147be168c0dSopenharmony_ci VT_DROPOUT = 16, 5148be168c0dSopenharmony_ci VT_ZONEOUT_CELL = 18, 5149be168c0dSopenharmony_ci- VT_ZONEOUT_HIDDEN = 20 5150be168c0dSopenharmony_ci+ VT_ZONEOUT_HIDDEN = 20, 5151be168c0dSopenharmony_ci+ VT_PROJ_SIZE = 22 5152be168c0dSopenharmony_ci }; 5153be168c0dSopenharmony_ci bool bidirectional() const { 5154be168c0dSopenharmony_ci return GetField<uint8_t>(VT_BIDIRECTIONAL, 0) != 0; 5155be168c0dSopenharmony_ci@@ -11411,6 +11413,12 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5156be168c0dSopenharmony_ci bool mutate_zoneout_hidden(float _zoneout_hidden) { 5157be168c0dSopenharmony_ci return SetField<float>(VT_ZONEOUT_HIDDEN, _zoneout_hidden, 0.0f); 5158be168c0dSopenharmony_ci } 5159be168c0dSopenharmony_ci+ int64_t proj_size() const { 5160be168c0dSopenharmony_ci+ return GetField<int64_t>(VT_PROJ_SIZE, 0); 5161be168c0dSopenharmony_ci+ } 5162be168c0dSopenharmony_ci+ bool 
mutate_proj_size(int64_t _proj_size) { 5163be168c0dSopenharmony_ci+ return SetField<int64_t>(VT_PROJ_SIZE, _proj_size, 0); 5164be168c0dSopenharmony_ci+ } 5165be168c0dSopenharmony_ci bool Verify(flatbuffers::Verifier &verifier) const { 5166be168c0dSopenharmony_ci return VerifyTableStart(verifier) && 5167be168c0dSopenharmony_ci VerifyField<uint8_t>(verifier, VT_BIDIRECTIONAL) && 5168be168c0dSopenharmony_ci@@ -11422,6 +11430,7 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5169be168c0dSopenharmony_ci VerifyField<float>(verifier, VT_DROPOUT) && 5170be168c0dSopenharmony_ci VerifyField<float>(verifier, VT_ZONEOUT_CELL) && 5171be168c0dSopenharmony_ci VerifyField<float>(verifier, VT_ZONEOUT_HIDDEN) && 5172be168c0dSopenharmony_ci+ VerifyField<int64_t>(verifier, VT_PROJ_SIZE) && 5173be168c0dSopenharmony_ci verifier.EndTable(); 5174be168c0dSopenharmony_ci } 5175be168c0dSopenharmony_ci LSTMT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; 5176be168c0dSopenharmony_ci@@ -11460,6 +11469,9 @@ struct LSTMBuilder { 5177be168c0dSopenharmony_ci void add_zoneout_hidden(float zoneout_hidden) { 5178be168c0dSopenharmony_ci fbb_.AddElement<float>(LSTM::VT_ZONEOUT_HIDDEN, zoneout_hidden, 0.0f); 5179be168c0dSopenharmony_ci } 5180be168c0dSopenharmony_ci+ void add_proj_size(int64_t proj_size) { 5181be168c0dSopenharmony_ci+ fbb_.AddElement<int64_t>(LSTM::VT_PROJ_SIZE, proj_size, 0); 5182be168c0dSopenharmony_ci+ } 5183be168c0dSopenharmony_ci explicit LSTMBuilder(flatbuffers::FlatBufferBuilder &_fbb) 5184be168c0dSopenharmony_ci : fbb_(_fbb) { 5185be168c0dSopenharmony_ci start_ = fbb_.StartTable(); 5186be168c0dSopenharmony_ci@@ -11481,8 +11493,10 @@ inline flatbuffers::Offset<LSTM> CreateLSTM( 5187be168c0dSopenharmony_ci int64_t num_directions = 0, 5188be168c0dSopenharmony_ci float dropout = 0.0f, 5189be168c0dSopenharmony_ci float zoneout_cell = 0.0f, 5190be168c0dSopenharmony_ci- float zoneout_hidden = 0.0f) { 5191be168c0dSopenharmony_ci+ float 
zoneout_hidden = 0.0f, 5192be168c0dSopenharmony_ci+ int64_t proj_size = 0) { 5193be168c0dSopenharmony_ci LSTMBuilder builder_(_fbb); 5194be168c0dSopenharmony_ci+ builder_.add_proj_size(proj_size); 5195be168c0dSopenharmony_ci builder_.add_num_directions(num_directions); 5196be168c0dSopenharmony_ci builder_.add_num_layers(num_layers); 5197be168c0dSopenharmony_ci builder_.add_hidden_size(hidden_size); 5198be168c0dSopenharmony_ci@@ -23571,6 +23585,7 @@ inline void LSTM::UnPackTo(LSTMT *_o, const flatbuffers::resolver_function_t *_r 5199be168c0dSopenharmony_ci { auto _e = dropout(); _o->dropout = _e; } 5200be168c0dSopenharmony_ci { auto _e = zoneout_cell(); _o->zoneout_cell = _e; } 5201be168c0dSopenharmony_ci { auto _e = zoneout_hidden(); _o->zoneout_hidden = _e; } 5202be168c0dSopenharmony_ci+ { auto _e = proj_size(); _o->proj_size = _e; } 5203be168c0dSopenharmony_ci } 5204be168c0dSopenharmony_ci 5205be168c0dSopenharmony_ci inline flatbuffers::Offset<LSTM> LSTM::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMT* _o, const flatbuffers::rehasher_function_t *_rehasher) { 5206be168c0dSopenharmony_ci@@ -23590,6 +23605,7 @@ inline flatbuffers::Offset<LSTM> CreateLSTM(flatbuffers::FlatBufferBuilder &_fbb 5207be168c0dSopenharmony_ci auto _dropout = _o->dropout; 5208be168c0dSopenharmony_ci auto _zoneout_cell = _o->zoneout_cell; 5209be168c0dSopenharmony_ci auto _zoneout_hidden = _o->zoneout_hidden; 5210be168c0dSopenharmony_ci+ auto _proj_size = _o->proj_size; 5211be168c0dSopenharmony_ci return mindspore::schema::CreateLSTM( 5212be168c0dSopenharmony_ci _fbb, 5213be168c0dSopenharmony_ci _bidirectional, 5214be168c0dSopenharmony_ci@@ -23600,7 +23616,8 @@ inline flatbuffers::Offset<LSTM> CreateLSTM(flatbuffers::FlatBufferBuilder &_fbb 5215be168c0dSopenharmony_ci _num_directions, 5216be168c0dSopenharmony_ci _dropout, 5217be168c0dSopenharmony_ci _zoneout_cell, 5218be168c0dSopenharmony_ci- _zoneout_hidden); 5219be168c0dSopenharmony_ci+ _zoneout_hidden, 5220be168c0dSopenharmony_ci+ 
_proj_size); 5221be168c0dSopenharmony_ci } 5222be168c0dSopenharmony_ci 5223be168c0dSopenharmony_ci inline LSTMGradT *LSTMGrad::UnPack(const flatbuffers::resolver_function_t *_resolver) const { 5224be168c0dSopenharmony_cidiff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs 5225be168c0dSopenharmony_ciindex 76caf810..920c0d31 100644 5226be168c0dSopenharmony_ci--- a/mindspore/lite/schema/ops.fbs 5227be168c0dSopenharmony_ci+++ b/mindspore/lite/schema/ops.fbs 5228be168c0dSopenharmony_ci@@ -688,6 +688,7 @@ table LSTM { 5229be168c0dSopenharmony_ci dropout: float; 5230be168c0dSopenharmony_ci zoneout_cell: float = 0; 5231be168c0dSopenharmony_ci zoneout_hidden: float = 0; 5232be168c0dSopenharmony_ci+ proj_size: long = 0; 5233be168c0dSopenharmony_ci } 5234be168c0dSopenharmony_ci 5235be168c0dSopenharmony_ci table LSTMGrad { 5236be168c0dSopenharmony_cidiff --git a/mindspore/lite/schema/ops_generated.h b/mindspore/lite/schema/ops_generated.h 5237be168c0dSopenharmony_ciindex 2f792706..8d387e9d 100644 5238be168c0dSopenharmony_ci--- a/mindspore/lite/schema/ops_generated.h 5239be168c0dSopenharmony_ci+++ b/mindspore/lite/schema/ops_generated.h 5240be168c0dSopenharmony_ci@@ -7046,7 +7046,8 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5241be168c0dSopenharmony_ci VT_NUM_DIRECTIONS = 14, 5242be168c0dSopenharmony_ci VT_DROPOUT = 16, 5243be168c0dSopenharmony_ci VT_ZONEOUT_CELL = 18, 5244be168c0dSopenharmony_ci- VT_ZONEOUT_HIDDEN = 20 5245be168c0dSopenharmony_ci+ VT_ZONEOUT_HIDDEN = 20, 5246be168c0dSopenharmony_ci+ VT_PROJ_SIZE = 22 5247be168c0dSopenharmony_ci }; 5248be168c0dSopenharmony_ci bool bidirectional() const { 5249be168c0dSopenharmony_ci return GetField<uint8_t>(VT_BIDIRECTIONAL, 0) != 0; 5250be168c0dSopenharmony_ci@@ -7075,6 +7076,9 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5251be168c0dSopenharmony_ci float zoneout_hidden() const { 5252be168c0dSopenharmony_ci return GetField<float>(VT_ZONEOUT_HIDDEN, 0.0f); 
5253be168c0dSopenharmony_ci } 5254be168c0dSopenharmony_ci+ int64_t proj_size() const { 5255be168c0dSopenharmony_ci+ return GetField<int64_t>(VT_PROJ_SIZE, 0); 5256be168c0dSopenharmony_ci+ } 5257be168c0dSopenharmony_ci bool Verify(flatbuffers::Verifier &verifier) const { 5258be168c0dSopenharmony_ci return VerifyTableStart(verifier) && 5259be168c0dSopenharmony_ci VerifyField<uint8_t>(verifier, VT_BIDIRECTIONAL) && 5260be168c0dSopenharmony_ci@@ -7086,6 +7090,7 @@ struct LSTM FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { 5261be168c0dSopenharmony_ci VerifyField<float>(verifier, VT_DROPOUT) && 5262be168c0dSopenharmony_ci VerifyField<float>(verifier, VT_ZONEOUT_CELL) && 5263be168c0dSopenharmony_ci VerifyField<float>(verifier, VT_ZONEOUT_HIDDEN) && 5264be168c0dSopenharmony_ci+ VerifyField<int64_t>(verifier, VT_PROJ_SIZE) && 5265be168c0dSopenharmony_ci verifier.EndTable(); 5266be168c0dSopenharmony_ci } 5267be168c0dSopenharmony_ci }; 5268be168c0dSopenharmony_ci@@ -7121,6 +7126,9 @@ struct LSTMBuilder { 5269be168c0dSopenharmony_ci void add_zoneout_hidden(float zoneout_hidden) { 5270be168c0dSopenharmony_ci fbb_.AddElement<float>(LSTM::VT_ZONEOUT_HIDDEN, zoneout_hidden, 0.0f); 5271be168c0dSopenharmony_ci } 5272be168c0dSopenharmony_ci+ void add_proj_size(int64_t proj_size) { 5273be168c0dSopenharmony_ci+ fbb_.AddElement<int64_t>(LSTM::VT_PROJ_SIZE, proj_size, 0); 5274be168c0dSopenharmony_ci+ } 5275be168c0dSopenharmony_ci explicit LSTMBuilder(flatbuffers::FlatBufferBuilder &_fbb) 5276be168c0dSopenharmony_ci : fbb_(_fbb) { 5277be168c0dSopenharmony_ci start_ = fbb_.StartTable(); 5278be168c0dSopenharmony_ci@@ -7142,8 +7150,10 @@ inline flatbuffers::Offset<LSTM> CreateLSTM( 5279be168c0dSopenharmony_ci int64_t num_directions = 0, 5280be168c0dSopenharmony_ci float dropout = 0.0f, 5281be168c0dSopenharmony_ci float zoneout_cell = 0.0f, 5282be168c0dSopenharmony_ci- float zoneout_hidden = 0.0f) { 5283be168c0dSopenharmony_ci+ float zoneout_hidden = 0.0f, 5284be168c0dSopenharmony_ci+ 
int64_t proj_size = 0) { 5285be168c0dSopenharmony_ci LSTMBuilder builder_(_fbb); 5286be168c0dSopenharmony_ci+ builder_.add_proj_size(proj_size); 5287be168c0dSopenharmony_ci builder_.add_num_directions(num_directions); 5288be168c0dSopenharmony_ci builder_.add_num_layers(num_layers); 5289be168c0dSopenharmony_ci builder_.add_hidden_size(hidden_size); 5290be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt 5291be168c0dSopenharmony_ciindex de1781cd..469bcb6b 100644 5292be168c0dSopenharmony_ci--- a/mindspore/lite/src/CMakeLists.txt 5293be168c0dSopenharmony_ci+++ b/mindspore/lite/src/CMakeLists.txt 5294be168c0dSopenharmony_ci@@ -337,6 +337,8 @@ set(TRAIN_SRC 5295be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/common/fusion_utils.cc 5296be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/gru_fusion_pass.cc 5297be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_activation_fusion_pass.cc 5298be168c0dSopenharmony_ci+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_add_fusion_pass.cc 5299be168c0dSopenharmony_ci+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc 5300be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc 5301be168c0dSopenharmony_ci ${TOOLS_DIR}/converter/optimizer.cc 5302be168c0dSopenharmony_ci ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc 5303be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/ops_def.cc b/mindspore/lite/src/common/ops/ops_def.cc 5304be168c0dSopenharmony_ciindex e5c7f5ca..baa2497a 100644 5305be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/ops_def.cc 5306be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/ops_def.cc 5307be168c0dSopenharmony_ci@@ -688,6 +688,7 @@ OP_ATTR(num_directions, long) 5308be168c0dSopenharmony_ci OP_ATTR(dropout, float) 5309be168c0dSopenharmony_ci 
OP_ATTR_WITH_VALUE(zoneout_cell, float, 0) 5310be168c0dSopenharmony_ci OP_ATTR_WITH_VALUE(zoneout_hidden, float, 0) 5311be168c0dSopenharmony_ci+OP_ATTR_WITH_VALUE(proj_size, long, 0) 5312be168c0dSopenharmony_ci OP_SCHEMA_DEF_END(LSTM) 5313be168c0dSopenharmony_ci 5314be168c0dSopenharmony_ci OP_SCHEMA_DEF(LSTMGrad) 5315be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc 5316be168c0dSopenharmony_ciindex 13957ed7..6c490130 100644 5317be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc 5318be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc 5319be168c0dSopenharmony_ci@@ -22,6 +22,7 @@ 5320be168c0dSopenharmony_ci #include "nnacl/custom_masked_fill_parameter.h" 5321be168c0dSopenharmony_ci #include "nnacl/custom_is_inf_parameter.h" 5322be168c0dSopenharmony_ci #include "nnacl/custom_tensor_scatter_max_parameter.h" 5323be168c0dSopenharmony_ci+#include "nnacl/custom_gather_d_grad_v2_parameter.h" 5324be168c0dSopenharmony_ci using mindspore::schema::PrimitiveType_Custom; 5325be168c0dSopenharmony_ci 5326be168c0dSopenharmony_ci namespace mindspore { 5327be168c0dSopenharmony_ci@@ -128,6 +129,33 @@ OpParameter *CreateCustomMaskedFillParameter() { 5328be168c0dSopenharmony_ci return reinterpret_cast<OpParameter *>(param); 5329be168c0dSopenharmony_ci } 5330be168c0dSopenharmony_ci 5331be168c0dSopenharmony_ci+OpParameter *CreateCustomGatherDGradV2Parameter(const schema::Custom *value) { 5332be168c0dSopenharmony_ci+ if (value->attr()->size() < 1) { 5333be168c0dSopenharmony_ci+ return nullptr; 5334be168c0dSopenharmony_ci+ } 5335be168c0dSopenharmony_ci+ auto *param = static_cast<CustomGatherGradV2Parameter *>(malloc(sizeof(CustomGatherGradV2Parameter))); 5336be168c0dSopenharmony_ci+ if (param == nullptr) { 5337be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc CustomGruParameter failed."; 5338be168c0dSopenharmony_ci+ 
return nullptr; 5339be168c0dSopenharmony_ci+ } 5340be168c0dSopenharmony_ci+ 5341be168c0dSopenharmony_ci+ std::string dim_str; 5342be168c0dSopenharmony_ci+ auto attrs = value->attr(); 5343be168c0dSopenharmony_ci+ for (size_t i = 0; i < attrs->size(); i++) { 5344be168c0dSopenharmony_ci+ auto attr = attrs->Get(i); 5345be168c0dSopenharmony_ci+ if (attr->name()->str() == "dim") { 5346be168c0dSopenharmony_ci+ auto data = attr->data(); 5347be168c0dSopenharmony_ci+ dim_str = std::string(reinterpret_cast<const char *>(data->Data()), data->size()); 5348be168c0dSopenharmony_ci+ break; 5349be168c0dSopenharmony_ci+ } 5350be168c0dSopenharmony_ci+ } 5351be168c0dSopenharmony_ci+ 5352be168c0dSopenharmony_ci+ memset(param, 0, sizeof(CustomGatherGradV2Parameter)); 5353be168c0dSopenharmony_ci+ param->dim = std::stoi(dim_str.c_str()); 5354be168c0dSopenharmony_ci+ param->op_parameter_.type_ = PrimType_Inner_CustomGatherDGradV2; 5355be168c0dSopenharmony_ci+ return reinterpret_cast<OpParameter *>(param); 5356be168c0dSopenharmony_ci+} 5357be168c0dSopenharmony_ci+ 5358be168c0dSopenharmony_ci OpParameter *PopulateCustomParameter(const void *prim) { 5359be168c0dSopenharmony_ci MS_CHECK_TRUE_RET(prim != nullptr, nullptr); 5360be168c0dSopenharmony_ci auto primitive = static_cast<const schema::Primitive *>(prim); 5361be168c0dSopenharmony_ci@@ -167,6 +195,8 @@ OpParameter *PopulateCustomParameter(const void *prim) { 5362be168c0dSopenharmony_ci return CreateCustomGruParameter(); 5363be168c0dSopenharmony_ci } else if (type == "CastGatherReduceFusion") { 5364be168c0dSopenharmony_ci return CreateParam(PrimType_Inner_CastGatherReduceFusion); 5365be168c0dSopenharmony_ci+ } else if (type == "GatherDGradV2") { 5366be168c0dSopenharmony_ci+ return CreateCustomGatherDGradV2Parameter(value); 5367be168c0dSopenharmony_ci } else if (type == "ThirdPartyModel") { 5368be168c0dSopenharmony_ci auto *param = static_cast<CustomParameter *>(malloc(sizeof(CustomParameter))); 5369be168c0dSopenharmony_ci if (param == 
nullptr) { 5370be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/populate/lstm_populate.cc b/mindspore/lite/src/common/ops/populate/lstm_populate.cc 5371be168c0dSopenharmony_ciindex 522da7ad..b3a85b64 100644 5372be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/populate/lstm_populate.cc 5373be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/populate/lstm_populate.cc 5374be168c0dSopenharmony_ci@@ -37,8 +37,12 @@ OpParameter *PopulateLstmParameter(const void *prim) { 5375be168c0dSopenharmony_ci 5376be168c0dSopenharmony_ci param->op_parameter_.type_ = primitive->value_type(); 5377be168c0dSopenharmony_ci param->bidirectional_ = value->bidirectional(); 5378be168c0dSopenharmony_ci+ param->has_bias_ = value->has_bias(); 5379be168c0dSopenharmony_ci+ param->input_size_ = value->input_size(); 5380be168c0dSopenharmony_ci+ param->hidden_size_ = value->hidden_size(); 5381be168c0dSopenharmony_ci param->zoneout_cell_ = value->zoneout_cell(); 5382be168c0dSopenharmony_ci param->zoneout_hidden_ = value->zoneout_hidden(); 5383be168c0dSopenharmony_ci+ param->project_size_ = value->proj_size(); 5384be168c0dSopenharmony_ci return reinterpret_cast<OpParameter *>(param); 5385be168c0dSopenharmony_ci } 5386be168c0dSopenharmony_ci 5387be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/prim_util.cc b/mindspore/lite/src/common/prim_util.cc 5388be168c0dSopenharmony_ciindex 5ded05e9..7263775a 100644 5389be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/prim_util.cc 5390be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/prim_util.cc 5391be168c0dSopenharmony_ci@@ -29,11 +29,25 @@ static std::set<schema::PrimitiveType> kTensorListOps = { 5392be168c0dSopenharmony_ci schema::PrimitiveType_TensorListReserve, schema::PrimitiveType_TensorListSetItem, 5393be168c0dSopenharmony_ci schema::PrimitiveType_TensorListStack}; 5394be168c0dSopenharmony_ci 5395be168c0dSopenharmony_ci-static const char *const kInnerOpNames[C10NUM] = {"Inner_ToFormat", 
"Inner_GltextureToOpencl", 5396be168c0dSopenharmony_ci- "Inner_Identity", "Inner_ShapeFusion", 5397be168c0dSopenharmony_ci- "Inner_GraphKernel", "Inner_SplitReduceConcatFusion", 5398be168c0dSopenharmony_ci- "Inner_EncoderLayer", "Inner_DecoderLayer", 5399be168c0dSopenharmony_ci- "Inner_UsePastEmbedding", "Inner_CustomGru"}; 5400be168c0dSopenharmony_ci+static const char *const kInnerOpNames[C20NUM] = {"Inner_ToFormat", 5401be168c0dSopenharmony_ci+ "Inner_GltextureToOpencl", 5402be168c0dSopenharmony_ci+ "Inner_Identity", 5403be168c0dSopenharmony_ci+ "Inner_ShapeFusion", 5404be168c0dSopenharmony_ci+ "Inner_GraphKernel", 5405be168c0dSopenharmony_ci+ "Inner_SplitReduceConcatFusion", 5406be168c0dSopenharmony_ci+ "Inner_EncoderLayer", 5407be168c0dSopenharmony_ci+ "PrimType_Inner_FseDecode", 5408be168c0dSopenharmony_ci+ "Inner_DecoderLayer", 5409be168c0dSopenharmony_ci+ "Inner_UsePastEmbedding", 5410be168c0dSopenharmony_ci+ "Inner_CustomGru", 5411be168c0dSopenharmony_ci+ "PrimType_Inner_CastGatherReduceFusion", 5412be168c0dSopenharmony_ci+ "PrimType_Inner_ReduceConcatFusion", 5413be168c0dSopenharmony_ci+ "PrimType_Inner_ThirdPartyModel", 5414be168c0dSopenharmony_ci+ "PrimType_Inner_CustomMaskedFill", 5415be168c0dSopenharmony_ci+ "PrimType_Inner_CustomTensorScatterMax", 5416be168c0dSopenharmony_ci+ "PrimType_Inner_CustomIsInf", 5417be168c0dSopenharmony_ci+ "PrimType_Inner_CustomGatherDGradV2"}; 5418be168c0dSopenharmony_ci+ 5419be168c0dSopenharmony_ci int GetPrimitiveType(const void *primitive, int schema_version) { 5420be168c0dSopenharmony_ci if (primitive == nullptr) { 5421be168c0dSopenharmony_ci return -1; 5422be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5423be168c0dSopenharmony_ciindex 65065b5b..7b813314 100644 5424be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5425be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 
5426be168c0dSopenharmony_ci@@ -85,6 +85,9 @@ cpu_kernel_sources = [ 5427be168c0dSopenharmony_ci "fp32/invert_permutation_fp32.cc", 5428be168c0dSopenharmony_ci "fp32/l2_norm_fp32.cc", 5429be168c0dSopenharmony_ci "fp32/lstm_fp32.cc", 5430be168c0dSopenharmony_ci+ "fp32/lstm_fp32_base.cc", 5431be168c0dSopenharmony_ci+ "fp32/lstm_mindir_fp32.cc", 5432be168c0dSopenharmony_ci+ "fp32/lstm_non_mindir_fp32.cc", 5433be168c0dSopenharmony_ci "fp32/matmul_fp32_arm32.cc", 5434be168c0dSopenharmony_ci "fp32/matmul_fp32_arm64.cc", 5435be168c0dSopenharmony_ci "fp32/matmul_fp32_avx512.cc", 5436be168c0dSopenharmony_ci@@ -174,6 +177,9 @@ fp16_kernel_sources = [ 5437be168c0dSopenharmony_ci "fp16/instance_norm_fp16.cc", 5438be168c0dSopenharmony_ci "fp16/layout_transform_fp16.cc", 5439be168c0dSopenharmony_ci "fp16/lstm_fp16.cc", 5440be168c0dSopenharmony_ci+ "fp16/lstm_fp16_base.cc", 5441be168c0dSopenharmony_ci+ "fp16/lstm_mindir_fp16.cc", 5442be168c0dSopenharmony_ci+ "fp16/lstm_non_mindir_fp16.cc", 5443be168c0dSopenharmony_ci "fp16/matmul_base_fp16.cc", 5444be168c0dSopenharmony_ci "fp16/matmul_fp16.cc", 5445be168c0dSopenharmony_ci "fp16/power_fp16.cc", 5446be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/gru_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/gru_fp16.cc 5447be168c0dSopenharmony_ciindex 232bbe44..89945e1c 100644 5448be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/fp16/gru_fp16.cc 5449be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/gru_fp16.cc 5450be168c0dSopenharmony_ci@@ -100,10 +100,10 @@ int GruFp16CPUKernel::InitInputWeightBias() { 5451be168c0dSopenharmony_ci } 5452be168c0dSopenharmony_ci if (weight_g->data_type() == kNumberTypeFloat32) { 5453be168c0dSopenharmony_ci PackLstmWeightFp32ToFp16(weight_g_ptr_, reinterpret_cast<float *>(weight_g->data()), weight_batch_, 5454be168c0dSopenharmony_ci- gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_); 
5455be168c0dSopenharmony_ci+ gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_, nullptr); 5456be168c0dSopenharmony_ci } else if (weight_g->data_type() == kNumberTypeFloat16) { 5457be168c0dSopenharmony_ci PackLstmWeightFp16(weight_g_ptr_, reinterpret_cast<float16_t *>(weight_g->data()), weight_batch_, 5458be168c0dSopenharmony_ci- gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_); 5459be168c0dSopenharmony_ci+ gru_param_->input_size_, gru_param_->hidden_size_, gru_param_->input_col_align_, nullptr); 5460be168c0dSopenharmony_ci } else { 5461be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported data type of weight_g tensor for gru."; 5462be168c0dSopenharmony_ci return RET_ERROR; 5463be168c0dSopenharmony_ci@@ -120,10 +120,10 @@ int GruFp16CPUKernel::InitInputWeightBias() { 5464be168c0dSopenharmony_ci memset(input_bias_, 0, weight_batch_ * gru_param_->input_col_align_ * sizeof(float16_t)); 5465be168c0dSopenharmony_ci if (bias->data_type() == kNumberTypeFloat32) { 5466be168c0dSopenharmony_ci PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias->data()), weight_batch_, 5467be168c0dSopenharmony_ci- gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_); 5468be168c0dSopenharmony_ci+ gru_param_->hidden_size_, gru_param_->input_col_align_, gru_param_->bidirectional_, nullptr); 5469be168c0dSopenharmony_ci } else if (bias->data_type() == kNumberTypeFloat16) { 5470be168c0dSopenharmony_ci PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias->data()), weight_batch_, gru_param_->hidden_size_, 5471be168c0dSopenharmony_ci- gru_param_->input_col_align_, gru_param_->bidirectional_); 5472be168c0dSopenharmony_ci+ gru_param_->input_col_align_, gru_param_->bidirectional_, nullptr); 5473be168c0dSopenharmony_ci } else { 5474be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported data type of bias tensor for gru."; 5475be168c0dSopenharmony_ci return RET_ERROR; 
5476be168c0dSopenharmony_ci@@ -148,10 +148,10 @@ int GruFp16CPUKernel::InitStateWeightBias() { 5477be168c0dSopenharmony_ci if (!is_vec_) { 5478be168c0dSopenharmony_ci if (weight_r->data_type() == kNumberTypeFloat32) { 5479be168c0dSopenharmony_ci PackLstmWeightFp32ToFp16(weight_r_ptr_, reinterpret_cast<float *>(weight_r->data()), weight_batch_, 5480be168c0dSopenharmony_ci- gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_); 5481be168c0dSopenharmony_ci+ gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_, nullptr); 5482be168c0dSopenharmony_ci } else if (weight_r->data_type() == kNumberTypeFloat16) { 5483be168c0dSopenharmony_ci PackLstmWeightFp16(weight_r_ptr_, reinterpret_cast<float16_t *>(weight_r->data()), weight_batch_, 5484be168c0dSopenharmony_ci- gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_); 5485be168c0dSopenharmony_ci+ gru_param_->hidden_size_, gru_param_->hidden_size_, gru_param_->state_col_align_, nullptr); 5486be168c0dSopenharmony_ci } else { 5487be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported data type of weight_r tensor for gru."; 5488be168c0dSopenharmony_ci return RET_ERROR; 5489be168c0dSopenharmony_ci@@ -179,11 +179,11 @@ int GruFp16CPUKernel::InitStateWeightBias() { 5490be168c0dSopenharmony_ci if (bias->data_type() == kNumberTypeFloat32) { 5491be168c0dSopenharmony_ci auto state_bias_data = reinterpret_cast<float *>(bias->data()) + gate_num * gru_param_->hidden_size_; 5492be168c0dSopenharmony_ci PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, gru_param_->hidden_size_, 5493be168c0dSopenharmony_ci- gru_param_->state_col_align_, gru_param_->bidirectional_); 5494be168c0dSopenharmony_ci+ gru_param_->state_col_align_, gru_param_->bidirectional_, nullptr); 5495be168c0dSopenharmony_ci } else if (bias->data_type() == kNumberTypeFloat16) { 5496be168c0dSopenharmony_ci auto state_bias_data = reinterpret_cast<float16_t *>(bias->data()) + gate_num 
* gru_param_->hidden_size_; 5497be168c0dSopenharmony_ci PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, gru_param_->hidden_size_, 5498be168c0dSopenharmony_ci- gru_param_->state_col_align_, gru_param_->bidirectional_); 5499be168c0dSopenharmony_ci+ gru_param_->state_col_align_, gru_param_->bidirectional_, nullptr); 5500be168c0dSopenharmony_ci } else { 5501be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported data type of bias tensor for gru."; 5502be168c0dSopenharmony_ci return RET_ERROR; 5503be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16.cc 5504be168c0dSopenharmony_ciindex b583358a..bd99b791 100644 5505be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16.cc 5506be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16.cc 5507be168c0dSopenharmony_ci@@ -1,5 +1,5 @@ 5508be168c0dSopenharmony_ci /** 5509be168c0dSopenharmony_ci- * Copyright 2021 Huawei Technologies Co., Ltd 5510be168c0dSopenharmony_ci+ * Copyright 2021-2023 Huawei Technologies Co., Ltd 5511be168c0dSopenharmony_ci * 5512be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License"); 5513be168c0dSopenharmony_ci * you may not use this file except in compliance with the License. 
5514be168c0dSopenharmony_ci@@ -16,13 +16,9 @@ 5515be168c0dSopenharmony_ci 5516be168c0dSopenharmony_ci #include "src/litert/kernel/cpu/fp16/lstm_fp16.h" 5517be168c0dSopenharmony_ci #include <vector> 5518be168c0dSopenharmony_ci-#include <cfloat> 5519be168c0dSopenharmony_ci-#include "schema/model_generated.h" 5520be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h" 5521be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h" 5522be168c0dSopenharmony_ci #include "src/litert/kernel_registry.h" 5523be168c0dSopenharmony_ci-#include "include/errorcode.h" 5524be168c0dSopenharmony_ci-#include "nnacl/fp16/lstm_fp16.h" 5525be168c0dSopenharmony_ci-#include "nnacl/fp16/cast_fp16.h" 5526be168c0dSopenharmony_ci-#include "nnacl/errorcode.h" 5527be168c0dSopenharmony_ci 5528be168c0dSopenharmony_ci using mindspore::kernel::KERNEL_ARCH; 5529be168c0dSopenharmony_ci using mindspore::lite::KernelRegistrar; 5530be168c0dSopenharmony_ci@@ -31,389 +27,34 @@ using mindspore::lite::RET_OK; 5531be168c0dSopenharmony_ci using mindspore::schema::PrimitiveType_LSTM; 5532be168c0dSopenharmony_ci 5533be168c0dSopenharmony_ci namespace mindspore::kernel { 5534be168c0dSopenharmony_ci-void LstmFp16CPUKernel::FreeTmpBuffer() { 5535be168c0dSopenharmony_ci- if (weight_i_ptr_ != nullptr) { 5536be168c0dSopenharmony_ci- free(weight_i_ptr_); 5537be168c0dSopenharmony_ci- weight_i_ptr_ = nullptr; 5538be168c0dSopenharmony_ci- } 5539be168c0dSopenharmony_ci- if (input_bias_ != nullptr) { 5540be168c0dSopenharmony_ci- free(input_bias_); 5541be168c0dSopenharmony_ci- input_bias_ = nullptr; 5542be168c0dSopenharmony_ci- } 5543be168c0dSopenharmony_ci- if (weight_h_ptr_ != nullptr) { 5544be168c0dSopenharmony_ci- free(weight_h_ptr_); 5545be168c0dSopenharmony_ci- weight_h_ptr_ = nullptr; 5546be168c0dSopenharmony_ci- } 5547be168c0dSopenharmony_ci- if (state_bias_ != nullptr) { 5548be168c0dSopenharmony_ci- free(state_bias_); 5549be168c0dSopenharmony_ci- state_bias_ = 
nullptr; 5550be168c0dSopenharmony_ci- } 5551be168c0dSopenharmony_ci- if (weight_project_ptr_ != nullptr) { 5552be168c0dSopenharmony_ci- free(weight_project_ptr_); 5553be168c0dSopenharmony_ci- weight_project_ptr_ = nullptr; 5554be168c0dSopenharmony_ci- } 5555be168c0dSopenharmony_ci- if (project_bias_ != nullptr) { 5556be168c0dSopenharmony_ci- free(project_bias_); 5557be168c0dSopenharmony_ci- project_bias_ = nullptr; 5558be168c0dSopenharmony_ci- } 5559be168c0dSopenharmony_ci-} 5560be168c0dSopenharmony_ci- 5561be168c0dSopenharmony_ci-void LstmFp16CPUKernel::FreeRunBuffer() { 5562be168c0dSopenharmony_ci- ms_context_->allocator->Free(buffer_[packed_input_index]); 5563be168c0dSopenharmony_ci- ms_context_->allocator->Free(buffer_[input_gate_index]); 5564be168c0dSopenharmony_ci- if (!is_vec_) { 5565be168c0dSopenharmony_ci- ms_context_->allocator->Free(buffer_[packed_state_index]); 5566be168c0dSopenharmony_ci- } 5567be168c0dSopenharmony_ci- ms_context_->allocator->Free(buffer_[state_gate_index]); 5568be168c0dSopenharmony_ci- if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { 5569be168c0dSopenharmony_ci- ms_context_->allocator->Free(buffer_[cell_state_index]); 5570be168c0dSopenharmony_ci- } 5571be168c0dSopenharmony_ci- if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { 5572be168c0dSopenharmony_ci- ms_context_->allocator->Free(buffer_[hidden_state_index]); 5573be168c0dSopenharmony_ci- } 5574be168c0dSopenharmony_ci-} 5575be168c0dSopenharmony_ci- 5576be168c0dSopenharmony_ci-int LstmFp16CPUKernel::InitParam() { 5577be168c0dSopenharmony_ci- auto input = in_tensors_.front(); 5578be168c0dSopenharmony_ci- std::vector<int> in_shape = input->shape(); 5579be168c0dSopenharmony_ci- lstm_param_->seq_len_ = in_shape.at(0); 5580be168c0dSopenharmony_ci- lstm_param_->batch_ = in_shape.at(1); 5581be168c0dSopenharmony_ci- lstm_param_->input_size_ = in_shape.at(kNHWC_W); 5582be168c0dSopenharmony_ci- 
5583be168c0dSopenharmony_ci- auto weight_i = in_tensors_.at(1); 5584be168c0dSopenharmony_ci- std::vector<int> w_shape = weight_i->shape(); 5585be168c0dSopenharmony_ci- NNACL_CHECK_ZERO_RETURN_ERR(gate_num); 5586be168c0dSopenharmony_ci- lstm_param_->hidden_size_ = w_shape.at(1) / gate_num; 5587be168c0dSopenharmony_ci- 5588be168c0dSopenharmony_ci- auto weight_h = in_tensors_.at(C2NUM); 5589be168c0dSopenharmony_ci- auto h_shape = weight_h->shape(); 5590be168c0dSopenharmony_ci- lstm_param_->project_size_ = h_shape.back(); 5591be168c0dSopenharmony_ci- 5592be168c0dSopenharmony_ci- const int twice = 2; 5593be168c0dSopenharmony_ci- lstm_param_->output_step_ = lstm_param_->bidirectional_ ? twice * lstm_param_->batch_ * lstm_param_->hidden_size_ 5594be168c0dSopenharmony_ci- : lstm_param_->batch_ * lstm_param_->hidden_size_; 5595be168c0dSopenharmony_ci- weight_batch_ = lstm_param_->bidirectional_ ? twice * gate_num : gate_num; 5596be168c0dSopenharmony_ci- lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C16NUM); 5597be168c0dSopenharmony_ci- lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C8NUM); 5598be168c0dSopenharmony_ci- 5599be168c0dSopenharmony_ci- is_vec_ = lstm_param_->batch_ == 1; 5600be168c0dSopenharmony_ci- lstm_param_->state_row_align_ = is_vec_ ? lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM); 5601be168c0dSopenharmony_ci- lstm_param_->state_col_align_ = is_vec_ ? 
lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM); 5602be168c0dSopenharmony_ci- return RET_OK; 5603be168c0dSopenharmony_ci-} 5604be168c0dSopenharmony_ci- 5605be168c0dSopenharmony_ci-int LstmFp16CPUKernel::InitInputWeightBias() { 5606be168c0dSopenharmony_ci- // malloc and init input * weight right matrix buffer 5607be168c0dSopenharmony_ci- // input -- row: seq_len * batch; col: input_size 5608be168c0dSopenharmony_ci- // weight -- row: hidden_size; col: input_size, need transpose 5609be168c0dSopenharmony_ci- // result -- row: seq_len * batch; col: hidden_size 5610be168c0dSopenharmony_ci- auto weight_i = in_tensors_.at(1); 5611be168c0dSopenharmony_ci- auto weight_i_data = weight_i->data(); 5612be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_i_data); 5613be168c0dSopenharmony_ci- weight_i_ptr_ = reinterpret_cast<float16_t *>( 5614be168c0dSopenharmony_ci- malloc(weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float16_t))); 5615be168c0dSopenharmony_ci- if (weight_i_ptr_ == nullptr) { 5616be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_i_ptr_ error."; 5617be168c0dSopenharmony_ci- return RET_ERROR; 5618be168c0dSopenharmony_ci- } 5619be168c0dSopenharmony_ci- if (weight_i->data_type() == kNumberTypeFloat32) { 5620be168c0dSopenharmony_ci- PackLstmWeightFp32ToFp16(weight_i_ptr_, reinterpret_cast<float *>(weight_i_data), weight_batch_, 5621be168c0dSopenharmony_ci- lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_); 5622be168c0dSopenharmony_ci- } else if (weight_i->data_type() == kNumberTypeFloat16) { 5623be168c0dSopenharmony_ci- PackLstmWeightFp16(weight_i_ptr_, reinterpret_cast<float16_t *>(weight_i_data), weight_batch_, 5624be168c0dSopenharmony_ci- lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_); 5625be168c0dSopenharmony_ci+namespace { 5626be168c0dSopenharmony_ci+constexpr size_t kMindirInputTensorNum = 4; 
5627be168c0dSopenharmony_ci+} // namespace 5628be168c0dSopenharmony_ci+ 5629be168c0dSopenharmony_ci+LiteKernel *LstmFp16KernelCreator(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, 5630be168c0dSopenharmony_ci+ OpParameter *parameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc) { 5631be168c0dSopenharmony_ci+ if (parameter == nullptr) { 5632be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "parameter is nullptr."; 5633be168c0dSopenharmony_ci+ return nullptr; 5634be168c0dSopenharmony_ci+ } 5635be168c0dSopenharmony_ci+ if (desc.data_type == kTypeUnknown) { 5636be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "desc data_type is unknown."; 5637be168c0dSopenharmony_ci+ } 5638be168c0dSopenharmony_ci+ LiteKernel *kernel{nullptr}; 5639be168c0dSopenharmony_ci+ if (inputs.size() == kMindirInputTensorNum) { 5640be168c0dSopenharmony_ci+ kernel = new (std::nothrow) 5641be168c0dSopenharmony_ci+ LstmMindirFp16CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 5642be168c0dSopenharmony_ci } else { 5643be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported data type of weight_i tensor for lstm."; 5644be168c0dSopenharmony_ci- return RET_ERROR; 5645be168c0dSopenharmony_ci- } 5646be168c0dSopenharmony_ci- 5647be168c0dSopenharmony_ci- // input bias 5648be168c0dSopenharmony_ci- auto bias = in_tensors_.at(FOURTH_INPUT); 5649be168c0dSopenharmony_ci- auto bias_data = bias->data(); 5650be168c0dSopenharmony_ci- CHECK_NULL_RETURN(bias_data); 5651be168c0dSopenharmony_ci- input_bias_ = 5652be168c0dSopenharmony_ci- reinterpret_cast<float16_t *>(malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float16_t))); 5653be168c0dSopenharmony_ci- if (input_bias_ == nullptr) { 5654be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input_bias_ error."; 5655be168c0dSopenharmony_ci- return RET_ERROR; 5656be168c0dSopenharmony_ci- } 5657be168c0dSopenharmony_ci- memset(input_bias_, 0, weight_batch_ * 
lstm_param_->input_col_align_ * sizeof(float16_t)); 5658be168c0dSopenharmony_ci- if (bias->data_type() == kNumberTypeFloat32) { 5659be168c0dSopenharmony_ci- PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias_data), weight_batch_, lstm_param_->hidden_size_, 5660be168c0dSopenharmony_ci- lstm_param_->input_col_align_, lstm_param_->bidirectional_); 5661be168c0dSopenharmony_ci- } else if (bias->data_type() == kNumberTypeFloat16) { 5662be168c0dSopenharmony_ci- PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias_data), weight_batch_, lstm_param_->hidden_size_, 5663be168c0dSopenharmony_ci- lstm_param_->input_col_align_, lstm_param_->bidirectional_); 5664be168c0dSopenharmony_ci- } else { 5665be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; 5666be168c0dSopenharmony_ci- return RET_ERROR; 5667be168c0dSopenharmony_ci- } 5668be168c0dSopenharmony_ci- return RET_OK; 5669be168c0dSopenharmony_ci-} 5670be168c0dSopenharmony_ci- 5671be168c0dSopenharmony_ci-int LstmFp16CPUKernel::InitStateWeightBias() { 5672be168c0dSopenharmony_ci- // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 
5673be168c0dSopenharmony_ci- // state -- row: batch; col: hidden_size 5674be168c0dSopenharmony_ci- // weight -- row: hidden_size; col: hidden_size, need transpose 5675be168c0dSopenharmony_ci- // result -- row: batch; col: hidden_size 5676be168c0dSopenharmony_ci- auto weight_h = in_tensors_.at(THIRD_INPUT); 5677be168c0dSopenharmony_ci- auto weight_h_data = weight_h->data(); 5678be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_h_data); 5679be168c0dSopenharmony_ci- weight_h_ptr_ = reinterpret_cast<float16_t *>( 5680be168c0dSopenharmony_ci- malloc(weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->project_size_ * sizeof(float16_t))); 5681be168c0dSopenharmony_ci- if (weight_h_ptr_ == nullptr) { 5682be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_h_ptr_ error."; 5683be168c0dSopenharmony_ci- return RET_ERROR; 5684be168c0dSopenharmony_ci- } 5685be168c0dSopenharmony_ci- 5686be168c0dSopenharmony_ci- if (!is_vec_) { 5687be168c0dSopenharmony_ci- if (weight_h->data_type() == kNumberTypeFloat32) { 5688be168c0dSopenharmony_ci- PackLstmWeightFp32ToFp16(weight_h_ptr_, reinterpret_cast<float *>(weight_h_data), weight_batch_, 5689be168c0dSopenharmony_ci- lstm_param_->project_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_); 5690be168c0dSopenharmony_ci- } else if (weight_h->data_type() == kNumberTypeFloat16) { 5691be168c0dSopenharmony_ci- PackLstmWeightFp16(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_batch_, 5692be168c0dSopenharmony_ci- lstm_param_->project_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_); 5693be168c0dSopenharmony_ci- } else { 5694be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; 5695be168c0dSopenharmony_ci- return RET_ERROR; 5696be168c0dSopenharmony_ci- } 5697be168c0dSopenharmony_ci- } else { 5698be168c0dSopenharmony_ci- if (weight_h->data_type() == kNumberTypeFloat32) { 5699be168c0dSopenharmony_ci- 
Float32ToFloat16(reinterpret_cast<float *>(weight_h_data), weight_h_ptr_, weight_h->ElementsNum()); 5700be168c0dSopenharmony_ci- } else if (weight_h->data_type() == kNumberTypeFloat16) { 5701be168c0dSopenharmony_ci- memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_h->Size()); 5702be168c0dSopenharmony_ci- } else { 5703be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; 5704be168c0dSopenharmony_ci- return RET_ERROR; 5705be168c0dSopenharmony_ci- } 5706be168c0dSopenharmony_ci- } 5707be168c0dSopenharmony_ci- 5708be168c0dSopenharmony_ci- // state bias 5709be168c0dSopenharmony_ci- auto bias = in_tensors_.at(FOURTH_INPUT); 5710be168c0dSopenharmony_ci- auto bias_data = bias->data(); 5711be168c0dSopenharmony_ci- CHECK_NULL_RETURN(bias_data); 5712be168c0dSopenharmony_ci- state_bias_ = 5713be168c0dSopenharmony_ci- reinterpret_cast<float16_t *>(malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t))); 5714be168c0dSopenharmony_ci- if (state_bias_ == nullptr) { 5715be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_bias_ error."; 5716be168c0dSopenharmony_ci- return RET_ERROR; 5717be168c0dSopenharmony_ci- } 5718be168c0dSopenharmony_ci- memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float16_t)); 5719be168c0dSopenharmony_ci- if (bias->data_type() == kNumberTypeFloat32) { 5720be168c0dSopenharmony_ci- auto state_bias_data = reinterpret_cast<float *>(bias_data) + gate_num * lstm_param_->hidden_size_; 5721be168c0dSopenharmony_ci- PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_, 5722be168c0dSopenharmony_ci- lstm_param_->state_col_align_, lstm_param_->bidirectional_); 5723be168c0dSopenharmony_ci- } else if (bias->data_type() == kNumberTypeFloat16) { 5724be168c0dSopenharmony_ci- auto state_bias_data = reinterpret_cast<float16_t *>(bias_data) + gate_num * lstm_param_->hidden_size_; 
5725be168c0dSopenharmony_ci- PackLstmBiasFp16(state_bias_, state_bias_data, weight_batch_, lstm_param_->hidden_size_, 5726be168c0dSopenharmony_ci- lstm_param_->state_col_align_, lstm_param_->bidirectional_); 5727be168c0dSopenharmony_ci- } else { 5728be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; 5729be168c0dSopenharmony_ci- return RET_ERROR; 5730be168c0dSopenharmony_ci- } 5731be168c0dSopenharmony_ci- return RET_OK; 5732be168c0dSopenharmony_ci-} 5733be168c0dSopenharmony_ci- 5734be168c0dSopenharmony_ci-int LstmFp16CPUKernel::InitProjectWeight() { 5735be168c0dSopenharmony_ci- if (in_tensors_.size() < C7NUM) { 5736be168c0dSopenharmony_ci- return RET_OK; 5737be168c0dSopenharmony_ci- } 5738be168c0dSopenharmony_ci- auto weight_pro = in_tensors_.at(SEVENTH_INPUT); 5739be168c0dSopenharmony_ci- auto shape = weight_pro->shape(); 5740be168c0dSopenharmony_ci- if (shape.size() != C3NUM) { 5741be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Project-weight's shape must be 3D."; 5742be168c0dSopenharmony_ci- return RET_ERROR; 5743be168c0dSopenharmony_ci- } 5744be168c0dSopenharmony_ci- auto weight_pro_data = weight_pro->data(); 5745be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_pro_data); 5746be168c0dSopenharmony_ci- int batch = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 5747be168c0dSopenharmony_ci- if (shape[0] != batch) { 5748be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Project-weight's shape[0] must be 1(bidirectional=false) or 2(bidirectional=true)."; 5749be168c0dSopenharmony_ci- return RET_ERROR; 5750be168c0dSopenharmony_ci+ kernel = new (std::nothrow) 5751be168c0dSopenharmony_ci+ LstmNonMindirFp16CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 5752be168c0dSopenharmony_ci } 5753be168c0dSopenharmony_ci- int pro_col_align = is_vec_ ? 
lstm_param_->project_size_ : UP_ROUND(lstm_param_->project_size_, C8NUM); 5754be168c0dSopenharmony_ci- weight_project_ptr_ = 5755be168c0dSopenharmony_ci- reinterpret_cast<float16_t *>(malloc(batch * lstm_param_->hidden_size_ * pro_col_align * sizeof(float16_t))); 5756be168c0dSopenharmony_ci- if (weight_project_ptr_ == nullptr) { 5757be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc weight_project_ptr_ error."; 5758be168c0dSopenharmony_ci- return RET_ERROR; 5759be168c0dSopenharmony_ci- } 5760be168c0dSopenharmony_ci- 5761be168c0dSopenharmony_ci- if (!is_vec_) { 5762be168c0dSopenharmony_ci- if (weight_pro->data_type() == kNumberTypeFloat32) { 5763be168c0dSopenharmony_ci- PackLstmWeightFp32ToFp16(weight_project_ptr_, reinterpret_cast<float *>(weight_pro_data), batch, 5764be168c0dSopenharmony_ci- lstm_param_->hidden_size_, lstm_param_->project_size_, pro_col_align); 5765be168c0dSopenharmony_ci- } else if (weight_pro->data_type() == kNumberTypeFloat16) { 5766be168c0dSopenharmony_ci- PackLstmWeightFp16(weight_project_ptr_, reinterpret_cast<float16_t *>(weight_pro_data), batch, 5767be168c0dSopenharmony_ci- lstm_param_->hidden_size_, lstm_param_->project_size_, pro_col_align); 5768be168c0dSopenharmony_ci- } else { 5769be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported data type of weight_project tensor for lstm."; 5770be168c0dSopenharmony_ci- return RET_ERROR; 5771be168c0dSopenharmony_ci- } 5772be168c0dSopenharmony_ci- } else { 5773be168c0dSopenharmony_ci- if (weight_pro->data_type() == kNumberTypeFloat32) { 5774be168c0dSopenharmony_ci- Float32ToFloat16(reinterpret_cast<float *>(weight_pro_data), weight_project_ptr_, weight_pro->ElementsNum()); 5775be168c0dSopenharmony_ci- } else if (weight_pro->data_type() == kNumberTypeFloat16) { 5776be168c0dSopenharmony_ci- memcpy(weight_project_ptr_, weight_pro_data, weight_pro->Size()); 5777be168c0dSopenharmony_ci- } else { 5778be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported data type of weight_project 
tensor for lstm."; 5779be168c0dSopenharmony_ci- return RET_ERROR; 5780be168c0dSopenharmony_ci- } 5781be168c0dSopenharmony_ci- } 5782be168c0dSopenharmony_ci- size_t bias_size = UP_ROUND(lstm_param_->project_size_, C8NUM) * sizeof(float16_t); 5783be168c0dSopenharmony_ci- project_bias_ = reinterpret_cast<float16_t *>(malloc(bias_size)); 5784be168c0dSopenharmony_ci- if (project_bias_ == nullptr) { 5785be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc project_bias_ error."; 5786be168c0dSopenharmony_ci- return RET_ERROR; 5787be168c0dSopenharmony_ci- } 5788be168c0dSopenharmony_ci- (void)memset(project_bias_, 0, bias_size); 5789be168c0dSopenharmony_ci- return RET_OK; 5790be168c0dSopenharmony_ci-} 5791be168c0dSopenharmony_ci- 5792be168c0dSopenharmony_ci-int LstmFp16CPUKernel::Prepare() { 5793be168c0dSopenharmony_ci- CHECK_LESS_RETURN(in_tensors_.size(), C6NUM); 5794be168c0dSopenharmony_ci- for (size_t i = 0; i < in_tensors_.size(); i++) { 5795be168c0dSopenharmony_ci- CHECK_NULL_RETURN(in_tensors_.at(i)); 5796be168c0dSopenharmony_ci- } 5797be168c0dSopenharmony_ci- CHECK_LESS_RETURN(out_tensors_.size(), C3NUM); 5798be168c0dSopenharmony_ci- for (size_t i = 0; i < out_tensors_.size(); i++) { 5799be168c0dSopenharmony_ci- CHECK_NULL_RETURN(out_tensors_.at(i)); 5800be168c0dSopenharmony_ci- } 5801be168c0dSopenharmony_ci- CHECK_NULL_RETURN(lstm_param_); 5802be168c0dSopenharmony_ci- if (!InferShapeDone()) { 5803be168c0dSopenharmony_ci- return RET_OK; 5804be168c0dSopenharmony_ci- } 5805be168c0dSopenharmony_ci- return ReSize(); 5806be168c0dSopenharmony_ci-} 5807be168c0dSopenharmony_ci- 5808be168c0dSopenharmony_ci-int LstmFp16CPUKernel::ReSize() { 5809be168c0dSopenharmony_ci- auto ret = InitParam(); 5810be168c0dSopenharmony_ci- if (ret != RET_OK) { 5811be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Lstm fp16 InitParam error."; 5812be168c0dSopenharmony_ci- return RET_ERROR; 5813be168c0dSopenharmony_ci- } 5814be168c0dSopenharmony_ci- 5815be168c0dSopenharmony_ci- 
FreeTmpBuffer(); 5816be168c0dSopenharmony_ci- ret = InitInputWeightBias(); 5817be168c0dSopenharmony_ci- if (ret != RET_OK) { 5818be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Lstm fp16 InitInputWeightBias error."; 5819be168c0dSopenharmony_ci- FreeTmpBuffer(); 5820be168c0dSopenharmony_ci- return RET_ERROR; 5821be168c0dSopenharmony_ci- } 5822be168c0dSopenharmony_ci- 5823be168c0dSopenharmony_ci- ret = InitStateWeightBias(); 5824be168c0dSopenharmony_ci- if (ret != RET_OK) { 5825be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Lstm fp16 InitStateWeightBias error."; 5826be168c0dSopenharmony_ci- FreeTmpBuffer(); 5827be168c0dSopenharmony_ci- return RET_ERROR; 5828be168c0dSopenharmony_ci- } 5829be168c0dSopenharmony_ci- 5830be168c0dSopenharmony_ci- ret = InitProjectWeight(); 5831be168c0dSopenharmony_ci- if (ret != RET_OK) { 5832be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Lstm fp16 InitProjectWeight error."; 5833be168c0dSopenharmony_ci- FreeTmpBuffer(); 5834be168c0dSopenharmony_ci- return RET_ERROR; 5835be168c0dSopenharmony_ci- } 5836be168c0dSopenharmony_ci- return RET_OK; 5837be168c0dSopenharmony_ci-} 5838be168c0dSopenharmony_ci- 5839be168c0dSopenharmony_ci-int LstmFp16CPUKernel::MallocRunBuffer() { 5840be168c0dSopenharmony_ci- for (int i = 0; i < C7NUM; i++) { 5841be168c0dSopenharmony_ci- buffer_[i] = nullptr; 5842be168c0dSopenharmony_ci- } 5843be168c0dSopenharmony_ci- buffer_[packed_input_index] = reinterpret_cast<float16_t *>( 5844be168c0dSopenharmony_ci- ms_context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float16_t))); 5845be168c0dSopenharmony_ci- if (buffer_[packed_input_index] == nullptr) { 5846be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error."; 5847be168c0dSopenharmony_ci- return RET_ERROR; 5848be168c0dSopenharmony_ci- } 5849be168c0dSopenharmony_ci- 5850be168c0dSopenharmony_ci- buffer_[input_gate_index] = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc( 
5851be168c0dSopenharmony_ci- gate_num * lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 5852be168c0dSopenharmony_ci- if (buffer_[input_gate_index] == nullptr) { 5853be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; 5854be168c0dSopenharmony_ci- return RET_ERROR; 5855be168c0dSopenharmony_ci- } 5856be168c0dSopenharmony_ci- 5857be168c0dSopenharmony_ci- if (!is_vec_) { 5858be168c0dSopenharmony_ci- buffer_[packed_state_index] = reinterpret_cast<float16_t *>( 5859be168c0dSopenharmony_ci- ms_context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->project_size_ * sizeof(float16_t))); 5860be168c0dSopenharmony_ci- if (buffer_[packed_state_index] == nullptr) { 5861be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; 5862be168c0dSopenharmony_ci- return RET_ERROR; 5863be168c0dSopenharmony_ci- } 5864be168c0dSopenharmony_ci- } 5865be168c0dSopenharmony_ci- 5866be168c0dSopenharmony_ci- buffer_[state_gate_index] = reinterpret_cast<float16_t *>( 5867be168c0dSopenharmony_ci- ms_context_->allocator->Malloc(gate_num * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 5868be168c0dSopenharmony_ci- if (buffer_[state_gate_index] == nullptr) { 5869be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state gate buffer_ error."; 5870be168c0dSopenharmony_ci- return RET_ERROR; 5871be168c0dSopenharmony_ci- } 5872be168c0dSopenharmony_ci- 5873be168c0dSopenharmony_ci- if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { 5874be168c0dSopenharmony_ci- int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t); 5875be168c0dSopenharmony_ci- buffer_[cell_state_index] = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(buffer_size)); 5876be168c0dSopenharmony_ci- if (buffer_[cell_state_index] == nullptr) { 
5877be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for cell error."; 5878be168c0dSopenharmony_ci- return RET_ERROR; 5879be168c0dSopenharmony_ci- } 5880be168c0dSopenharmony_ci- } 5881be168c0dSopenharmony_ci- if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { 5882be168c0dSopenharmony_ci- int buffer_size = lstm_param_->batch_ * lstm_param_->project_size_ * sizeof(float16_t); 5883be168c0dSopenharmony_ci- buffer_[hidden_state_index] = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(buffer_size)); 5884be168c0dSopenharmony_ci- if (buffer_[hidden_state_index] == nullptr) { 5885be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for hidden error."; 5886be168c0dSopenharmony_ci- return RET_ERROR; 5887be168c0dSopenharmony_ci- } 5888be168c0dSopenharmony_ci- } 5889be168c0dSopenharmony_ci- if (!is_vec_ && in_tensors_.size() == C7NUM) { 5890be168c0dSopenharmony_ci- buffer_[project_input_index] = reinterpret_cast<float16_t *>( 5891be168c0dSopenharmony_ci- ms_context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 5892be168c0dSopenharmony_ci- if (buffer_[project_input_index] == nullptr) { 5893be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel malloc project_buffer for hidden error."; 5894be168c0dSopenharmony_ci- return RET_ERROR; 5895be168c0dSopenharmony_ci- } 5896be168c0dSopenharmony_ci- } 5897be168c0dSopenharmony_ci- return RET_OK; 5898be168c0dSopenharmony_ci-} 5899be168c0dSopenharmony_ci- 5900be168c0dSopenharmony_ci-int LstmFp16CPUKernel::Run() { 5901be168c0dSopenharmony_ci- auto input = in_tensors_.at(0); 5902be168c0dSopenharmony_ci- auto input_ptr = reinterpret_cast<float16_t *>(input->data()); 5903be168c0dSopenharmony_ci- CHECK_NULL_RETURN(input_ptr); 5904be168c0dSopenharmony_ci- auto output = out_tensors_.at(0); 5905be168c0dSopenharmony_ci- auto output_ptr = reinterpret_cast<float16_t 
*>(output->data()); 5906be168c0dSopenharmony_ci- CHECK_NULL_RETURN(output_ptr); 5907be168c0dSopenharmony_ci- 5908be168c0dSopenharmony_ci- auto hidden_state = in_tensors_.at(FIFTH_INPUT); 5909be168c0dSopenharmony_ci- CHECK_NULL_RETURN(hidden_state->data()); 5910be168c0dSopenharmony_ci- auto cell_state = in_tensors_.at(SIXTH_INPUT); 5911be168c0dSopenharmony_ci- CHECK_NULL_RETURN(cell_state->data()); 5912be168c0dSopenharmony_ci- 5913be168c0dSopenharmony_ci- auto output_hidden_state = out_tensors_[1]; 5914be168c0dSopenharmony_ci- CHECK_NULL_RETURN(output_hidden_state->data()); 5915be168c0dSopenharmony_ci- memcpy(output_hidden_state->data(), hidden_state->data(), hidden_state->ElementsNum() * sizeof(float16_t)); 5916be168c0dSopenharmony_ci- auto output_cell_state = out_tensors_[THIRD_INPUT]; 5917be168c0dSopenharmony_ci- CHECK_NULL_RETURN(output_cell_state->data()); 5918be168c0dSopenharmony_ci- memcpy(output_cell_state->data(), cell_state->data(), cell_state->ElementsNum() * sizeof(float16_t)); 5919be168c0dSopenharmony_ci- 5920be168c0dSopenharmony_ci- auto ret = MallocRunBuffer(); 5921be168c0dSopenharmony_ci- if (ret != RET_OK) { 5922be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmFp16CPUKernel MallocRunBuffer error."; 5923be168c0dSopenharmony_ci- FreeRunBuffer(); 5924be168c0dSopenharmony_ci- return RET_ERROR; 5925be168c0dSopenharmony_ci+ if (kernel == nullptr) { 5926be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "kernel: " << parameter->name_ << "is nullptr."; 5927be168c0dSopenharmony_ci+ free(parameter); 5928be168c0dSopenharmony_ci+ return nullptr; 5929be168c0dSopenharmony_ci } 5930be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_i_ptr_); 5931be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_h_ptr_); 5932be168c0dSopenharmony_ci- CHECK_NULL_RETURN(input_bias_); 5933be168c0dSopenharmony_ci- CHECK_NULL_RETURN(state_bias_); 5934be168c0dSopenharmony_ci- LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, weight_project_ptr_, 
5935be168c0dSopenharmony_ci- project_bias_, reinterpret_cast<float16_t *>(output_hidden_state->data()), 5936be168c0dSopenharmony_ci- reinterpret_cast<float16_t *>(output_cell_state->data()), buffer_, lstm_param_); 5937be168c0dSopenharmony_ci- FreeRunBuffer(); 5938be168c0dSopenharmony_ci- return RET_OK; 5939be168c0dSopenharmony_ci+ return kernel; 5940be168c0dSopenharmony_ci } 5941be168c0dSopenharmony_ci 5942be168c0dSopenharmony_ci-REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LSTM, LiteKernelCreator<LstmFp16CPUKernel>) 5943be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LSTM, LstmFp16KernelCreator) 5944be168c0dSopenharmony_ci } // namespace mindspore::kernel 5945be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc 5946be168c0dSopenharmony_cinew file mode 100644 5947be168c0dSopenharmony_ciindex 00000000..767fdef3 5948be168c0dSopenharmony_ci--- /dev/null 5949be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.cc 5950be168c0dSopenharmony_ci@@ -0,0 +1,270 @@ 5951be168c0dSopenharmony_ci+/** 5952be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 5953be168c0dSopenharmony_ci+ * 5954be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 5955be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 5956be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 5957be168c0dSopenharmony_ci+ * 5958be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 5959be168c0dSopenharmony_ci+ * 5960be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 5961be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 5962be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
5963be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 5964be168c0dSopenharmony_ci+ * limitations under the License. 5965be168c0dSopenharmony_ci+ */ 5966be168c0dSopenharmony_ci+ 5967be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp16/lstm_fp16_base.h" 5968be168c0dSopenharmony_ci+#include <cfloat> 5969be168c0dSopenharmony_ci+#include "nnacl/fp16/lstm_fp16.h" 5970be168c0dSopenharmony_ci+ 5971be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR; 5972be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 5973be168c0dSopenharmony_ci+ 5974be168c0dSopenharmony_ci+namespace mindspore::kernel { 5975be168c0dSopenharmony_ci+namespace { 5976be168c0dSopenharmony_ci+constexpr int kGateNum = 4; 5977be168c0dSopenharmony_ci+constexpr int kTempInputBufferIndex = 0; 5978be168c0dSopenharmony_ci+constexpr int kTempInputGateBufferIndex = 1; 5979be168c0dSopenharmony_ci+constexpr int kTempStateBufferIndex = 2; 5980be168c0dSopenharmony_ci+constexpr int kTempStateGateBufferIndex = 3; 5981be168c0dSopenharmony_ci+constexpr int kTempCellStateBufferIndex = 4; 5982be168c0dSopenharmony_ci+constexpr int kTempHiddenStateBufferIndex = 5; 5983be168c0dSopenharmony_ci+constexpr int kTempProjectInputBufferIndex = 6; 5984be168c0dSopenharmony_ci+} // namespace 5985be168c0dSopenharmony_ci+ 5986be168c0dSopenharmony_ci+LstmFp16BaseCPUKernel::~LstmFp16BaseCPUKernel() { FreePackBuffer(); } 5987be168c0dSopenharmony_ci+ 5988be168c0dSopenharmony_ci+int LstmFp16BaseCPUKernel::Prepare() { 5989be168c0dSopenharmony_ci+ for (size_t i = 0; i < in_tensors_.size(); ++i) { 5990be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(in_tensors_[i]); 5991be168c0dSopenharmony_ci+ } 5992be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(out_tensors_.size(), C3NUM); 5993be168c0dSopenharmony_ci+ for (size_t i = 0; i < out_tensors_.size(); ++i) { 5994be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(out_tensors_[i]); 5995be168c0dSopenharmony_ci+ } 5996be168c0dSopenharmony_ci+ 
CHECK_NULL_RETURN(lstm_param_); 5997be168c0dSopenharmony_ci+ if (!InferShapeDone()) { 5998be168c0dSopenharmony_ci+ return RET_OK; 5999be168c0dSopenharmony_ci+ } 6000be168c0dSopenharmony_ci+ return ReSize(); 6001be168c0dSopenharmony_ci+} 6002be168c0dSopenharmony_ci+ 6003be168c0dSopenharmony_ci+int LstmFp16BaseCPUKernel::ReSize() { 6004be168c0dSopenharmony_ci+ auto ret = InitParam(); 6005be168c0dSopenharmony_ci+ if (ret != RET_OK) { 6006be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16 InitParam failed."; 6007be168c0dSopenharmony_ci+ return RET_ERROR; 6008be168c0dSopenharmony_ci+ } 6009be168c0dSopenharmony_ci+ if (running_pack_) { 6010be168c0dSopenharmony_ci+ return RET_OK; 6011be168c0dSopenharmony_ci+ } 6012be168c0dSopenharmony_ci+ return PackWeightAndBias(); 6013be168c0dSopenharmony_ci+} 6014be168c0dSopenharmony_ci+ 6015be168c0dSopenharmony_ci+int LstmFp16BaseCPUKernel::Run() { 6016be168c0dSopenharmony_ci+ auto input_ptr = reinterpret_cast<float16_t *>(in_tensors_[FIRST_INPUT]->data()); 6017be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(input_ptr); 6018be168c0dSopenharmony_ci+ auto output_ptr = reinterpret_cast<float16_t *>(out_tensors_[FIRST_INPUT]->data()); 6019be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(output_ptr); 6020be168c0dSopenharmony_ci+ 6021be168c0dSopenharmony_ci+ auto hidden_init = in_tensors_[hidden_init_index_]->data(); 6022be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(hidden_init); 6023be168c0dSopenharmony_ci+ auto cell_init = in_tensors_[cell_init_index_]->data(); 6024be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(cell_init); 6025be168c0dSopenharmony_ci+ 6026be168c0dSopenharmony_ci+ auto output_hidden = out_tensors_[SECOND_INPUT]->data(); 6027be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(output_hidden); 6028be168c0dSopenharmony_ci+ (void)memcpy(output_hidden, hidden_init, in_tensors_[hidden_init_index_]->ElementsNum() * sizeof(float16_t)); 6029be168c0dSopenharmony_ci+ auto output_cell = out_tensors_[THIRD_INPUT]->data(); 6030be168c0dSopenharmony_ci+ 
CHECK_NULL_RETURN(output_cell); 6031be168c0dSopenharmony_ci+ (void)memcpy(output_cell, cell_init, in_tensors_[cell_init_index_]->ElementsNum() * sizeof(float16_t)); 6032be168c0dSopenharmony_ci+ 6033be168c0dSopenharmony_ci+ if (running_pack_) { 6034be168c0dSopenharmony_ci+ auto ret = PackWeightAndBias(); 6035be168c0dSopenharmony_ci+ if (ret != lite::RET_OK) { 6036be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16 PackWeightAndBias failed."; 6037be168c0dSopenharmony_ci+ return ret; 6038be168c0dSopenharmony_ci+ } 6039be168c0dSopenharmony_ci+ } 6040be168c0dSopenharmony_ci+ auto ret = MallocRunBuffer(); 6041be168c0dSopenharmony_ci+ if (ret != RET_OK) { 6042be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16CPUKernel MallocRunBuffer error."; 6043be168c0dSopenharmony_ci+ FreeRunBuffer(); 6044be168c0dSopenharmony_ci+ if (running_pack_) { 6045be168c0dSopenharmony_ci+ FreePackBuffer(); 6046be168c0dSopenharmony_ci+ } 6047be168c0dSopenharmony_ci+ return RET_ERROR; 6048be168c0dSopenharmony_ci+ } 6049be168c0dSopenharmony_ci+ LstmFp16(output_ptr, input_ptr, weight_i_ptr_, weight_h_ptr_, input_bias_, state_bias_, weight_project_ptr_, 6050be168c0dSopenharmony_ci+ project_bias_, reinterpret_cast<float16_t *>(output_hidden), reinterpret_cast<float16_t *>(output_cell), 6051be168c0dSopenharmony_ci+ running_buffer_, lstm_param_); 6052be168c0dSopenharmony_ci+ FreeRunBuffer(); 6053be168c0dSopenharmony_ci+ if (running_pack_) { 6054be168c0dSopenharmony_ci+ FreePackBuffer(); 6055be168c0dSopenharmony_ci+ } 6056be168c0dSopenharmony_ci+ return RET_OK; 6057be168c0dSopenharmony_ci+} 6058be168c0dSopenharmony_ci+ 6059be168c0dSopenharmony_ci+int LstmFp16BaseCPUKernel::InitParam() { 6060be168c0dSopenharmony_ci+ auto in_shape = in_tensors_[FIRST_INPUT]->shape(); 6061be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in_shape.size() == C3NUM, lite::RET_INPUT_TENSOR_ERROR, 6062be168c0dSopenharmony_ci+ "The dims of LSTM's first input must be 3."); 6063be168c0dSopenharmony_ci+ lstm_param_->seq_len_ = 
in_shape[0]; 6064be168c0dSopenharmony_ci+ lstm_param_->batch_ = in_shape[1]; 6065be168c0dSopenharmony_ci+ lstm_param_->input_size_ = in_shape.back(); 6066be168c0dSopenharmony_ci+ 6067be168c0dSopenharmony_ci+ auto h_init_shape = in_tensors_.at(hidden_init_index_)->shape(); 6068be168c0dSopenharmony_ci+ auto c_init_shape = in_tensors_.at(cell_init_index_)->shape(); 6069be168c0dSopenharmony_ci+ lstm_param_->hidden_size_ = c_init_shape.back(); 6070be168c0dSopenharmony_ci+ lstm_param_->output_size_ = h_init_shape.back(); 6071be168c0dSopenharmony_ci+ 6072be168c0dSopenharmony_ci+ lstm_param_->output_step_ = lstm_param_->bidirectional_ ? C2NUM * lstm_param_->batch_ * lstm_param_->output_size_ 6073be168c0dSopenharmony_ci+ : lstm_param_->batch_ * lstm_param_->output_size_; 6074be168c0dSopenharmony_ci+ weight_segment_num_ = lstm_param_->bidirectional_ ? C2NUM * kGateNum : kGateNum; 6075be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 6076be168c0dSopenharmony_ci+ lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C1NUM); 6077be168c0dSopenharmony_ci+ lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C4NUM); 6078be168c0dSopenharmony_ci+ 6079be168c0dSopenharmony_ci+ lstm_param_->state_row_align_ = UP_ROUND(lstm_param_->batch_, C1NUM); 6080be168c0dSopenharmony_ci+ lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C4NUM); 6081be168c0dSopenharmony_ci+ lstm_param_->proj_col_align_ = UP_ROUND(lstm_param_->output_size_, C4NUM); 6082be168c0dSopenharmony_ci+ weight_need_pack_ = true; 6083be168c0dSopenharmony_ci+#else 6084be168c0dSopenharmony_ci+ lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, C16NUM); 6085be168c0dSopenharmony_ci+ lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C8NUM); 6086be168c0dSopenharmony_ci+ 6087be168c0dSopenharmony_ci+ lstm_param_->state_row_align_ = 6088be168c0dSopenharmony_ci+ lstm_param_->batch_ == 1 ? 
lstm_param_->batch_ : UP_ROUND(lstm_param_->batch_, C16NUM); 6089be168c0dSopenharmony_ci+ lstm_param_->state_col_align_ = 6090be168c0dSopenharmony_ci+ lstm_param_->batch_ == 1 ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, C8NUM); 6091be168c0dSopenharmony_ci+ lstm_param_->proj_col_align_ = 6092be168c0dSopenharmony_ci+ lstm_param_->batch_ == 1 ? lstm_param_->output_size_ : UP_ROUND(lstm_param_->output_size_, C8NUM); 6093be168c0dSopenharmony_ci+ weight_need_pack_ = lstm_param_->batch_ != 1; 6094be168c0dSopenharmony_ci+#endif 6095be168c0dSopenharmony_ci+ return RET_OK; 6096be168c0dSopenharmony_ci+} 6097be168c0dSopenharmony_ci+ 6098be168c0dSopenharmony_ci+int LstmFp16BaseCPUKernel::PackWeightAndBias() { 6099be168c0dSopenharmony_ci+ FreePackBuffer(); 6100be168c0dSopenharmony_ci+ auto ret = InitInputWeightBias(); 6101be168c0dSopenharmony_ci+ if (ret != RET_OK) { 6102be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16 InitInputWeightBias failed."; 6103be168c0dSopenharmony_ci+ FreePackBuffer(); 6104be168c0dSopenharmony_ci+ return RET_ERROR; 6105be168c0dSopenharmony_ci+ } 6106be168c0dSopenharmony_ci+ 6107be168c0dSopenharmony_ci+ ret = InitStateWeightBias(); 6108be168c0dSopenharmony_ci+ if (ret != RET_OK) { 6109be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16 InitStateWeightBias failed."; 6110be168c0dSopenharmony_ci+ FreePackBuffer(); 6111be168c0dSopenharmony_ci+ return RET_ERROR; 6112be168c0dSopenharmony_ci+ } 6113be168c0dSopenharmony_ci+ 6114be168c0dSopenharmony_ci+ ret = InitProjectWeight(); 6115be168c0dSopenharmony_ci+ if (ret != RET_OK) { 6116be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16 InitProjectWeight failed."; 6117be168c0dSopenharmony_ci+ FreePackBuffer(); 6118be168c0dSopenharmony_ci+ return RET_ERROR; 6119be168c0dSopenharmony_ci+ } 6120be168c0dSopenharmony_ci+ return RET_OK; 6121be168c0dSopenharmony_ci+} 6122be168c0dSopenharmony_ci+ 6123be168c0dSopenharmony_ci+void LstmFp16BaseCPUKernel::FreePackBuffer() { 6124be168c0dSopenharmony_ci+ 
for (auto buffer : pack_buffer_) { 6125be168c0dSopenharmony_ci+ if (buffer) { 6126be168c0dSopenharmony_ci+ free(buffer); 6127be168c0dSopenharmony_ci+ } 6128be168c0dSopenharmony_ci+ } 6129be168c0dSopenharmony_ci+ pack_buffer_.clear(); 6130be168c0dSopenharmony_ci+} 6131be168c0dSopenharmony_ci+ 6132be168c0dSopenharmony_ci+int LstmFp16BaseCPUKernel::MallocRunBuffer() { 6133be168c0dSopenharmony_ci+ for (int i = 0; i < C7NUM; i++) { 6134be168c0dSopenharmony_ci+ running_buffer_[i] = nullptr; 6135be168c0dSopenharmony_ci+ } 6136be168c0dSopenharmony_ci+ bool need_pack_input = true; 6137be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 6138be168c0dSopenharmony_ci+ need_pack_input = lstm_param_->seq_len_ * lstm_param_->batch_ >= C4NUM; 6139be168c0dSopenharmony_ci+#endif 6140be168c0dSopenharmony_ci+ if (need_pack_input) { 6141be168c0dSopenharmony_ci+ running_buffer_[kTempInputBufferIndex] = reinterpret_cast<float16_t *>( 6142be168c0dSopenharmony_ci+ ms_context_->allocator->Malloc(lstm_param_->input_row_align_ * lstm_param_->input_size_ * sizeof(float16_t))); 6143be168c0dSopenharmony_ci+ if (running_buffer_[kTempInputBufferIndex] == nullptr) { 6144be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc input * weight left matirx error."; 6145be168c0dSopenharmony_ci+ return RET_ERROR; 6146be168c0dSopenharmony_ci+ } 6147be168c0dSopenharmony_ci+ } 6148be168c0dSopenharmony_ci+ 6149be168c0dSopenharmony_ci+ running_buffer_[kTempInputGateBufferIndex] = reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc( 6150be168c0dSopenharmony_ci+ kGateNum * lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 6151be168c0dSopenharmony_ci+ if (running_buffer_[kTempInputGateBufferIndex] == nullptr) { 6152be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; 6153be168c0dSopenharmony_ci+ return RET_ERROR; 6154be168c0dSopenharmony_ci+ } 6155be168c0dSopenharmony_ci+ 6156be168c0dSopenharmony_ci+ 
need_pack_input = lstm_param_->batch_ != 1; 6157be168c0dSopenharmony_ci+#ifdef ENABLE_ARM64 6158be168c0dSopenharmony_ci+ need_pack_input = lstm_param_->batch_ >= C4NUM; 6159be168c0dSopenharmony_ci+#endif 6160be168c0dSopenharmony_ci+ if (need_pack_input) { 6161be168c0dSopenharmony_ci+ running_buffer_[kTempStateBufferIndex] = reinterpret_cast<float16_t *>( 6162be168c0dSopenharmony_ci+ ms_context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->output_size_ * sizeof(float16_t))); 6163be168c0dSopenharmony_ci+ if (running_buffer_[kTempStateBufferIndex] == nullptr) { 6164be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state * weight left matirx error."; 6165be168c0dSopenharmony_ci+ return RET_ERROR; 6166be168c0dSopenharmony_ci+ } 6167be168c0dSopenharmony_ci+ } 6168be168c0dSopenharmony_ci+ 6169be168c0dSopenharmony_ci+ running_buffer_[kTempStateGateBufferIndex] = reinterpret_cast<float16_t *>( 6170be168c0dSopenharmony_ci+ ms_context_->allocator->Malloc(kGateNum * lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 6171be168c0dSopenharmony_ci+ if (running_buffer_[kTempStateGateBufferIndex] == nullptr) { 6172be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state gate buffer_ error."; 6173be168c0dSopenharmony_ci+ return RET_ERROR; 6174be168c0dSopenharmony_ci+ } 6175be168c0dSopenharmony_ci+ 6176be168c0dSopenharmony_ci+ if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { 6177be168c0dSopenharmony_ci+ int buffer_size = lstm_param_->batch_ * lstm_param_->hidden_size_ * sizeof(float16_t); 6178be168c0dSopenharmony_ci+ running_buffer_[kTempCellStateBufferIndex] = 6179be168c0dSopenharmony_ci+ reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(buffer_size)); 6180be168c0dSopenharmony_ci+ if (running_buffer_[kTempCellStateBufferIndex] == nullptr) { 6181be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for cell error."; 
6182be168c0dSopenharmony_ci+ return RET_ERROR; 6183be168c0dSopenharmony_ci+ } 6184be168c0dSopenharmony_ci+ } 6185be168c0dSopenharmony_ci+ if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { 6186be168c0dSopenharmony_ci+ int buffer_size = lstm_param_->batch_ * lstm_param_->output_size_ * sizeof(float16_t); 6187be168c0dSopenharmony_ci+ running_buffer_[kTempHiddenStateBufferIndex] = 6188be168c0dSopenharmony_ci+ reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(buffer_size)); 6189be168c0dSopenharmony_ci+ if (running_buffer_[kTempHiddenStateBufferIndex] == nullptr) { 6190be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc state_buffer for hidden error."; 6191be168c0dSopenharmony_ci+ return RET_ERROR; 6192be168c0dSopenharmony_ci+ } 6193be168c0dSopenharmony_ci+ } 6194be168c0dSopenharmony_ci+ 6195be168c0dSopenharmony_ci+ if (need_pack_input && in_tensors_.size() == C7NUM) { 6196be168c0dSopenharmony_ci+ running_buffer_[kTempProjectInputBufferIndex] = reinterpret_cast<float16_t *>( 6197be168c0dSopenharmony_ci+ ms_context_->allocator->Malloc(lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * sizeof(float16_t))); 6198be168c0dSopenharmony_ci+ if (running_buffer_[kTempProjectInputBufferIndex] == nullptr) { 6199be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmFp16CPUKernel malloc project_buffer for hidden error."; 6200be168c0dSopenharmony_ci+ return RET_ERROR; 6201be168c0dSopenharmony_ci+ } 6202be168c0dSopenharmony_ci+ } 6203be168c0dSopenharmony_ci+ return RET_OK; 6204be168c0dSopenharmony_ci+} 6205be168c0dSopenharmony_ci+ 6206be168c0dSopenharmony_ci+void LstmFp16BaseCPUKernel::FreeRunBuffer() { 6207be168c0dSopenharmony_ci+ ms_context_->allocator->Free(running_buffer_[kTempInputBufferIndex]); 6208be168c0dSopenharmony_ci+ ms_context_->allocator->Free(running_buffer_[kTempInputGateBufferIndex]); 6209be168c0dSopenharmony_ci+ if (lstm_param_->batch_ != 1) { 6210be168c0dSopenharmony_ci+ 
ms_context_->allocator->Free(running_buffer_[kTempStateBufferIndex]); 6211be168c0dSopenharmony_ci+ } 6212be168c0dSopenharmony_ci+ ms_context_->allocator->Free(running_buffer_[kTempStateGateBufferIndex]); 6213be168c0dSopenharmony_ci+ if (!(lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON)) { 6214be168c0dSopenharmony_ci+ ms_context_->allocator->Free(running_buffer_[kTempCellStateBufferIndex]); 6215be168c0dSopenharmony_ci+ } 6216be168c0dSopenharmony_ci+ if (!(lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON)) { 6217be168c0dSopenharmony_ci+ ms_context_->allocator->Free(running_buffer_[kTempHiddenStateBufferIndex]); 6218be168c0dSopenharmony_ci+ } 6219be168c0dSopenharmony_ci+} 6220be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6221be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h 6222be168c0dSopenharmony_cinew file mode 100644 6223be168c0dSopenharmony_ciindex 00000000..0bcb9e94 6224be168c0dSopenharmony_ci--- /dev/null 6225be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_fp16_base.h 6226be168c0dSopenharmony_ci@@ -0,0 +1,68 @@ 6227be168c0dSopenharmony_ci+/** 6228be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 6229be168c0dSopenharmony_ci+ * 6230be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 6231be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
6232be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6233be168c0dSopenharmony_ci+ * 6234be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6235be168c0dSopenharmony_ci+ * 6236be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6237be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6238be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6239be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6240be168c0dSopenharmony_ci+ * limitations under the License. 6241be168c0dSopenharmony_ci+ */ 6242be168c0dSopenharmony_ci+ 6243be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_FP16_BASE_H_ 6244be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_FP16_BASE_H_ 6245be168c0dSopenharmony_ci+ 6246be168c0dSopenharmony_ci+#include <vector> 6247be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h" 6248be168c0dSopenharmony_ci+#include "nnacl/lstm_parameter.h" 6249be168c0dSopenharmony_ci+ 6250be168c0dSopenharmony_ci+namespace mindspore::kernel { 6251be168c0dSopenharmony_ci+class LstmFp16BaseCPUKernel : public LiteKernel { 6252be168c0dSopenharmony_ci+ public: 6253be168c0dSopenharmony_ci+ LstmFp16BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6254be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6255be168c0dSopenharmony_ci+ : LiteKernel(parameter, inputs, outputs, ctx) { 6256be168c0dSopenharmony_ci+ lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_); 6257be168c0dSopenharmony_ci+ } 6258be168c0dSopenharmony_ci+ 6259be168c0dSopenharmony_ci+ ~LstmFp16BaseCPUKernel() override; 6260be168c0dSopenharmony_ci+ 6261be168c0dSopenharmony_ci+ int Prepare() override; 6262be168c0dSopenharmony_ci+ int ReSize() override; 
6263be168c0dSopenharmony_ci+ int Run() override; 6264be168c0dSopenharmony_ci+ 6265be168c0dSopenharmony_ci+ protected: 6266be168c0dSopenharmony_ci+ virtual int InitInputWeightBias() = 0; 6267be168c0dSopenharmony_ci+ virtual int InitStateWeightBias() = 0; 6268be168c0dSopenharmony_ci+ virtual int InitProjectWeight() = 0; 6269be168c0dSopenharmony_ci+ 6270be168c0dSopenharmony_ci+ bool running_pack_{false}; 6271be168c0dSopenharmony_ci+ bool weight_need_pack_{false}; 6272be168c0dSopenharmony_ci+ int hidden_init_index_{0}; 6273be168c0dSopenharmony_ci+ int cell_init_index_{0}; 6274be168c0dSopenharmony_ci+ int weight_segment_num_{0}; 6275be168c0dSopenharmony_ci+ float16_t *weight_i_ptr_{nullptr}; 6276be168c0dSopenharmony_ci+ float16_t *weight_h_ptr_{nullptr}; 6277be168c0dSopenharmony_ci+ float16_t *weight_project_ptr_{nullptr}; 6278be168c0dSopenharmony_ci+ float16_t *input_bias_{nullptr}; 6279be168c0dSopenharmony_ci+ float16_t *state_bias_{nullptr}; 6280be168c0dSopenharmony_ci+ float16_t *project_bias_{nullptr}; 6281be168c0dSopenharmony_ci+ LstmParameter *lstm_param_{nullptr}; 6282be168c0dSopenharmony_ci+ float16_t *running_buffer_[C7NUM] = {nullptr}; 6283be168c0dSopenharmony_ci+ std::vector<void *> pack_buffer_; 6284be168c0dSopenharmony_ci+ 6285be168c0dSopenharmony_ci+ private: 6286be168c0dSopenharmony_ci+ int PackWeightAndBias(); 6287be168c0dSopenharmony_ci+ int InitParam(); 6288be168c0dSopenharmony_ci+ void FreePackBuffer(); 6289be168c0dSopenharmony_ci+ void FreeRunBuffer(); 6290be168c0dSopenharmony_ci+ int MallocRunBuffer(); 6291be168c0dSopenharmony_ci+}; 6292be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6293be168c0dSopenharmony_ci+ 6294be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_FP16_BASE_H_ 6295be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc 6296be168c0dSopenharmony_cinew file mode 100644 
6297be168c0dSopenharmony_ciindex 00000000..cf4071eb 6298be168c0dSopenharmony_ci--- /dev/null 6299be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.cc 6300be168c0dSopenharmony_ci@@ -0,0 +1,35 @@ 6301be168c0dSopenharmony_ci+/** 6302be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 6303be168c0dSopenharmony_ci+ * 6304be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 6305be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 6306be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6307be168c0dSopenharmony_ci+ * 6308be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6309be168c0dSopenharmony_ci+ * 6310be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6311be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6312be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6313be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6314be168c0dSopenharmony_ci+ * limitations under the License. 
6315be168c0dSopenharmony_ci+ */ 6316be168c0dSopenharmony_ci+ 6317be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h" 6318be168c0dSopenharmony_ci+ 6319be168c0dSopenharmony_ci+namespace mindspore::kernel { 6320be168c0dSopenharmony_ci+namespace { 6321be168c0dSopenharmony_ci+constexpr size_t kMindirInputTensorNum = 4; 6322be168c0dSopenharmony_ci+} // namespace 6323be168c0dSopenharmony_ci+ 6324be168c0dSopenharmony_ci+int LstmMindirFp16CPUKernel::Prepare() { 6325be168c0dSopenharmony_ci+ CHECK_NOT_EQUAL_RETURN(in_tensors_.size(), kMindirInputTensorNum); 6326be168c0dSopenharmony_ci+ running_pack_ = trainable_ || !in_tensors_[FOURTH_INPUT]->IsConst(); 6327be168c0dSopenharmony_ci+ return LstmFp16BaseCPUKernel::Prepare(); 6328be168c0dSopenharmony_ci+} 6329be168c0dSopenharmony_ci+ 6330be168c0dSopenharmony_ci+int LstmMindirFp16CPUKernel::InitInputWeightBias() { return lite::RET_NOT_SUPPORT; } 6331be168c0dSopenharmony_ci+ 6332be168c0dSopenharmony_ci+int LstmMindirFp16CPUKernel::InitStateWeightBias() { return lite::RET_NOT_SUPPORT; } 6333be168c0dSopenharmony_ci+ 6334be168c0dSopenharmony_ci+int LstmMindirFp16CPUKernel::InitProjectWeight() { return lite::RET_NOT_SUPPORT; } 6335be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6336be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h 6337be168c0dSopenharmony_cinew file mode 100644 6338be168c0dSopenharmony_ciindex 00000000..bd8500d0 6339be168c0dSopenharmony_ci--- /dev/null 6340be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_mindir_fp16.h 6341be168c0dSopenharmony_ci@@ -0,0 +1,56 @@ 6342be168c0dSopenharmony_ci+/** 6343be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 6344be168c0dSopenharmony_ci+ * 6345be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 6346be168c0dSopenharmony_ci+ * you may not use this file 
except in compliance with the License. 6347be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6348be168c0dSopenharmony_ci+ * 6349be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6350be168c0dSopenharmony_ci+ * 6351be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6352be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6353be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6354be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6355be168c0dSopenharmony_ci+ * limitations under the License. 6356be168c0dSopenharmony_ci+ */ 6357be168c0dSopenharmony_ci+ 6358be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_MINDIR_FP16_H_ 6359be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_MINDIR_FP16_H_ 6360be168c0dSopenharmony_ci+ 6361be168c0dSopenharmony_ci+#include <vector> 6362be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp16/lstm_fp16_base.h" 6363be168c0dSopenharmony_ci+ 6364be168c0dSopenharmony_ci+namespace mindspore::kernel { 6365be168c0dSopenharmony_ci+/* 6366be168c0dSopenharmony_ci+ * 1. LSTM without project, output_size = hidden_size 6367be168c0dSopenharmony_ci+ * h_init: second input, shape is [bidirectional, batch_size, hidden_size] 6368be168c0dSopenharmony_ci+ * c_init: third input, shape is [bidirectional, batch_size, hidden_size] 6369be168c0dSopenharmony_ci+ * weight_bias: forth input, weight_ih + weight_hh + bias, the gate order is IFGO 6370be168c0dSopenharmony_ci+ * 6371be168c0dSopenharmony_ci+ * 2. 
LSTM with project, output_size = project_size 6372be168c0dSopenharmony_ci+ * don't support 6373be168c0dSopenharmony_ci+ * h_init: second input, shape is [bidirectional, batch_size, hidden_size] 6374be168c0dSopenharmony_ci+ * c_init: third input, shape is [bidirectional, batch_size, hidden_size] 6375be168c0dSopenharmony_ci+ * weight_bias: forth input, weight_ih + weight_hh + proj + bias, the gate order is IFGO 6376be168c0dSopenharmony_ci+ */ 6377be168c0dSopenharmony_ci+class LstmMindirFp16CPUKernel : public LstmFp16BaseCPUKernel { 6378be168c0dSopenharmony_ci+ public: 6379be168c0dSopenharmony_ci+ LstmMindirFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6380be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6381be168c0dSopenharmony_ci+ : LstmFp16BaseCPUKernel(parameter, inputs, outputs, ctx) { 6382be168c0dSopenharmony_ci+ hidden_init_index_ = SECOND_INPUT; 6383be168c0dSopenharmony_ci+ cell_init_index_ = THIRD_INPUT; 6384be168c0dSopenharmony_ci+ } 6385be168c0dSopenharmony_ci+ 6386be168c0dSopenharmony_ci+ ~LstmMindirFp16CPUKernel() override = default; 6387be168c0dSopenharmony_ci+ 6388be168c0dSopenharmony_ci+ int Prepare() override; 6389be168c0dSopenharmony_ci+ 6390be168c0dSopenharmony_ci+ protected: 6391be168c0dSopenharmony_ci+ int InitInputWeightBias() override; 6392be168c0dSopenharmony_ci+ int InitStateWeightBias() override; 6393be168c0dSopenharmony_ci+ int InitProjectWeight() override; 6394be168c0dSopenharmony_ci+}; 6395be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6396be168c0dSopenharmony_ci+ 6397be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_MINDIR_FP16_H_ 6398be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc 6399be168c0dSopenharmony_cinew file mode 100644 6400be168c0dSopenharmony_ciindex 00000000..473fe9b0 
6401be168c0dSopenharmony_ci--- /dev/null 6402be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.cc 6403be168c0dSopenharmony_ci@@ -0,0 +1,194 @@ 6404be168c0dSopenharmony_ci+/** 6405be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 6406be168c0dSopenharmony_ci+ * 6407be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 6408be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 6409be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6410be168c0dSopenharmony_ci+ * 6411be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6412be168c0dSopenharmony_ci+ * 6413be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6414be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6415be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6416be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6417be168c0dSopenharmony_ci+ * limitations under the License. 
6418be168c0dSopenharmony_ci+ */ 6419be168c0dSopenharmony_ci+ 6420be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h" 6421be168c0dSopenharmony_ci+#include "nnacl/fp16/lstm_fp16.h" 6422be168c0dSopenharmony_ci+#include "nnacl/fp16/cast_fp16.h" 6423be168c0dSopenharmony_ci+ 6424be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR; 6425be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 6426be168c0dSopenharmony_ci+ 6427be168c0dSopenharmony_ci+namespace mindspore::kernel { 6428be168c0dSopenharmony_ci+namespace { 6429be168c0dSopenharmony_ci+constexpr int kGateNum = 4; 6430be168c0dSopenharmony_ci+constexpr size_t kInputTensorNumMin = 6; 6431be168c0dSopenharmony_ci+} // namespace 6432be168c0dSopenharmony_ci+ 6433be168c0dSopenharmony_ci+int LstmNonMindirFp16CPUKernel::Prepare() { 6434be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(in_tensors_.size(), kInputTensorNumMin); 6435be168c0dSopenharmony_ci+ running_pack_ = train_mode_; 6436be168c0dSopenharmony_ci+ for (size_t i = 1; i <= FOURTH_INPUT; ++i) { 6437be168c0dSopenharmony_ci+ running_pack_ = running_pack_ || !in_tensors_[i]->IsConst(); 6438be168c0dSopenharmony_ci+ } 6439be168c0dSopenharmony_ci+ return LstmFp16BaseCPUKernel::Prepare(); 6440be168c0dSopenharmony_ci+} 6441be168c0dSopenharmony_ci+ 6442be168c0dSopenharmony_ci+int LstmNonMindirFp16CPUKernel::InitInputWeightBias() { 6443be168c0dSopenharmony_ci+ // malloc and init input * weight right matrix buffer 6444be168c0dSopenharmony_ci+ // input -- row: seq_len * batch; col: input_size 6445be168c0dSopenharmony_ci+ // weight -- row: hidden_size; col: input_size, need transpose 6446be168c0dSopenharmony_ci+ // result -- row: seq_len * batch; col: hidden_size 6447be168c0dSopenharmony_ci+ auto weight_i = in_tensors_.at(1); 6448be168c0dSopenharmony_ci+ auto weight_i_data = weight_i->data(); 6449be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_i_data); 6450be168c0dSopenharmony_ci+ weight_i_ptr_ = reinterpret_cast<float16_t *>( 
6451be168c0dSopenharmony_ci+ malloc(weight_segment_num_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float16_t))); 6452be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_i_ptr_ != nullptr, lite::RET_NULL_PTR, 6453be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_i_ptr_ failed."); 6454be168c0dSopenharmony_ci+ pack_buffer_.push_back(weight_i_ptr_); 6455be168c0dSopenharmony_ci+ if (weight_i->data_type() == kNumberTypeFloat32) { 6456be168c0dSopenharmony_ci+ PackLstmWeightFp32ToFp16(weight_i_ptr_, reinterpret_cast<float *>(weight_i_data), weight_segment_num_, 6457be168c0dSopenharmony_ci+ lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, 6458be168c0dSopenharmony_ci+ nullptr); 6459be168c0dSopenharmony_ci+ } else if (weight_i->data_type() == kNumberTypeFloat16) { 6460be168c0dSopenharmony_ci+ PackLstmWeightFp16(weight_i_ptr_, reinterpret_cast<float16_t *>(weight_i_data), weight_segment_num_, 6461be168c0dSopenharmony_ci+ lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, nullptr); 6462be168c0dSopenharmony_ci+ } else { 6463be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported data type of weight_i tensor for lstm."; 6464be168c0dSopenharmony_ci+ return RET_ERROR; 6465be168c0dSopenharmony_ci+ } 6466be168c0dSopenharmony_ci+ 6467be168c0dSopenharmony_ci+ // input bias 6468be168c0dSopenharmony_ci+ auto bias = in_tensors_.at(FOURTH_INPUT); 6469be168c0dSopenharmony_ci+ auto bias_data = bias->data(); 6470be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(bias_data); 6471be168c0dSopenharmony_ci+ input_bias_ = 6472be168c0dSopenharmony_ci+ reinterpret_cast<float16_t *>(malloc(weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float16_t))); 6473be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_bias_ != nullptr, lite::RET_NULL_PTR, "LstmNonMindirCPUKernel malloc input_bias_ failed."); 6474be168c0dSopenharmony_ci+ pack_buffer_.push_back(input_bias_); 6475be168c0dSopenharmony_ci+ 
(void)memset(input_bias_, 0, weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float16_t)); 6476be168c0dSopenharmony_ci+ if (bias->data_type() == kNumberTypeFloat32) { 6477be168c0dSopenharmony_ci+ PackLstmBiasFp32ToFp16(input_bias_, reinterpret_cast<float *>(bias_data), weight_segment_num_, 6478be168c0dSopenharmony_ci+ lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, 6479be168c0dSopenharmony_ci+ nullptr); 6480be168c0dSopenharmony_ci+ } else if (bias->data_type() == kNumberTypeFloat16) { 6481be168c0dSopenharmony_ci+ PackLstmBiasFp16(input_bias_, reinterpret_cast<float16_t *>(bias_data), weight_segment_num_, 6482be168c0dSopenharmony_ci+ lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, nullptr); 6483be168c0dSopenharmony_ci+ } else { 6484be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; 6485be168c0dSopenharmony_ci+ return RET_ERROR; 6486be168c0dSopenharmony_ci+ } 6487be168c0dSopenharmony_ci+ return RET_OK; 6488be168c0dSopenharmony_ci+} 6489be168c0dSopenharmony_ci+ 6490be168c0dSopenharmony_ci+int LstmNonMindirFp16CPUKernel::InitStateWeightBias() { 6491be168c0dSopenharmony_ci+ // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 
6492be168c0dSopenharmony_ci+ // state -- row: batch; col: hidden_size 6493be168c0dSopenharmony_ci+ // weight -- row: hidden_size; col: hidden_size, need transpose 6494be168c0dSopenharmony_ci+ // result -- row: batch; col: hidden_size 6495be168c0dSopenharmony_ci+ auto weight_h = in_tensors_.at(THIRD_INPUT); 6496be168c0dSopenharmony_ci+ auto weight_h_data = weight_h->data(); 6497be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_h_data); 6498be168c0dSopenharmony_ci+ weight_h_ptr_ = reinterpret_cast<float16_t *>( 6499be168c0dSopenharmony_ci+ malloc(weight_segment_num_ * lstm_param_->state_col_align_ * lstm_param_->output_size_ * sizeof(float16_t))); 6500be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_h_ptr_ != nullptr, lite::RET_NULL_PTR, 6501be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_h_ptr_ failed."); 6502be168c0dSopenharmony_ci+ 6503be168c0dSopenharmony_ci+ if (weight_need_pack_) { 6504be168c0dSopenharmony_ci+ if (weight_h->data_type() == kNumberTypeFloat32) { 6505be168c0dSopenharmony_ci+ PackLstmWeightFp32ToFp16(weight_h_ptr_, reinterpret_cast<float *>(weight_h_data), weight_segment_num_, 6506be168c0dSopenharmony_ci+ lstm_param_->output_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, 6507be168c0dSopenharmony_ci+ nullptr); 6508be168c0dSopenharmony_ci+ } else if (weight_h->data_type() == kNumberTypeFloat16) { 6509be168c0dSopenharmony_ci+ PackLstmWeightFp16(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_segment_num_, 6510be168c0dSopenharmony_ci+ lstm_param_->output_size_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, nullptr); 6511be168c0dSopenharmony_ci+ } else { 6512be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; 6513be168c0dSopenharmony_ci+ return RET_ERROR; 6514be168c0dSopenharmony_ci+ } 6515be168c0dSopenharmony_ci+ } else { 6516be168c0dSopenharmony_ci+ if (weight_h->data_type() == kNumberTypeFloat32) { 6517be168c0dSopenharmony_ci+ 
Float32ToFloat16(reinterpret_cast<float *>(weight_h_data), weight_h_ptr_, weight_h->ElementsNum()); 6518be168c0dSopenharmony_ci+ } else if (weight_h->data_type() == kNumberTypeFloat16) { 6519be168c0dSopenharmony_ci+ (void)memcpy(weight_h_ptr_, reinterpret_cast<float16_t *>(weight_h_data), weight_h->Size()); 6520be168c0dSopenharmony_ci+ } else { 6521be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported data type of weight_h tensor for lstm."; 6522be168c0dSopenharmony_ci+ return RET_ERROR; 6523be168c0dSopenharmony_ci+ } 6524be168c0dSopenharmony_ci+ } 6525be168c0dSopenharmony_ci+ 6526be168c0dSopenharmony_ci+ // state bias 6527be168c0dSopenharmony_ci+ auto bias = in_tensors_[FOURTH_INPUT]; 6528be168c0dSopenharmony_ci+ auto bias_data = bias->data(); 6529be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(bias_data); 6530be168c0dSopenharmony_ci+ state_bias_ = 6531be168c0dSopenharmony_ci+ reinterpret_cast<float16_t *>(malloc(weight_segment_num_ * lstm_param_->state_col_align_ * sizeof(float16_t))); 6532be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(state_bias_ != nullptr, lite::RET_NULL_PTR, "LstmNonMindirCPUKernel malloc state_bias_ failed."); 6533be168c0dSopenharmony_ci+ (void)memset(state_bias_, 0, weight_segment_num_ * lstm_param_->state_col_align_ * sizeof(float16_t)); 6534be168c0dSopenharmony_ci+ if (bias->data_type() == kNumberTypeFloat32) { 6535be168c0dSopenharmony_ci+ auto state_bias_data = reinterpret_cast<float *>(bias_data) + kGateNum * lstm_param_->hidden_size_; 6536be168c0dSopenharmony_ci+ PackLstmBiasFp32ToFp16(state_bias_, state_bias_data, weight_segment_num_, lstm_param_->hidden_size_, 6537be168c0dSopenharmony_ci+ lstm_param_->state_col_align_, lstm_param_->bidirectional_, nullptr); 6538be168c0dSopenharmony_ci+ } else if (bias->data_type() == kNumberTypeFloat16) { 6539be168c0dSopenharmony_ci+ auto state_bias_data = reinterpret_cast<float16_t *>(bias_data) + kGateNum * lstm_param_->hidden_size_; 6540be168c0dSopenharmony_ci+ PackLstmBiasFp16(state_bias_, 
state_bias_data, weight_segment_num_, lstm_param_->hidden_size_, 6541be168c0dSopenharmony_ci+ lstm_param_->state_col_align_, lstm_param_->bidirectional_, nullptr); 6542be168c0dSopenharmony_ci+ } else { 6543be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported data type of bias tensor for lstm."; 6544be168c0dSopenharmony_ci+ return RET_ERROR; 6545be168c0dSopenharmony_ci+ } 6546be168c0dSopenharmony_ci+ return RET_OK; 6547be168c0dSopenharmony_ci+} 6548be168c0dSopenharmony_ci+ 6549be168c0dSopenharmony_ci+int LstmNonMindirFp16CPUKernel::InitProjectWeight() { 6550be168c0dSopenharmony_ci+ if (in_tensors_.size() < C7NUM) { 6551be168c0dSopenharmony_ci+ return RET_OK; 6552be168c0dSopenharmony_ci+ } 6553be168c0dSopenharmony_ci+ auto weight_pro = in_tensors_[SEVENTH_INPUT]; 6554be168c0dSopenharmony_ci+ auto shape = weight_pro->shape(); 6555be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(shape.size() == C3NUM, lite::RET_ERROR, "Project-weight's shape must be 3D."); 6556be168c0dSopenharmony_ci+ auto weight_pro_data = weight_pro->data(); 6557be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_pro_data); 6558be168c0dSopenharmony_ci+ int batch = lstm_param_->bidirectional_ ? 
C2NUM : C1NUM; 6559be168c0dSopenharmony_ci+ if (shape[0] != batch) { 6560be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Project-weight's shape[0] must be 1(bidirectional=false) or 2(bidirectional=true)."; 6561be168c0dSopenharmony_ci+ return RET_ERROR; 6562be168c0dSopenharmony_ci+ } 6563be168c0dSopenharmony_ci+ int pro_col_align = lstm_param_->proj_col_align_; 6564be168c0dSopenharmony_ci+ weight_project_ptr_ = 6565be168c0dSopenharmony_ci+ reinterpret_cast<float16_t *>(malloc(batch * lstm_param_->hidden_size_ * pro_col_align * sizeof(float16_t))); 6566be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 6567be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 6568be168c0dSopenharmony_ci+ 6569be168c0dSopenharmony_ci+ if (weight_need_pack_) { 6570be168c0dSopenharmony_ci+ if (weight_pro->data_type() == kNumberTypeFloat32) { 6571be168c0dSopenharmony_ci+ PackLstmWeightFp32ToFp16(weight_project_ptr_, reinterpret_cast<float *>(weight_pro_data), batch, 6572be168c0dSopenharmony_ci+ lstm_param_->hidden_size_, lstm_param_->output_size_, pro_col_align, nullptr); 6573be168c0dSopenharmony_ci+ } else if (weight_pro->data_type() == kNumberTypeFloat16) { 6574be168c0dSopenharmony_ci+ PackLstmWeightFp16(weight_project_ptr_, reinterpret_cast<float16_t *>(weight_pro_data), batch, 6575be168c0dSopenharmony_ci+ lstm_param_->hidden_size_, lstm_param_->output_size_, pro_col_align, nullptr); 6576be168c0dSopenharmony_ci+ } else { 6577be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported data type of weight_project tensor for lstm."; 6578be168c0dSopenharmony_ci+ return RET_ERROR; 6579be168c0dSopenharmony_ci+ } 6580be168c0dSopenharmony_ci+ } else { 6581be168c0dSopenharmony_ci+ if (weight_pro->data_type() == kNumberTypeFloat32) { 6582be168c0dSopenharmony_ci+ Float32ToFloat16(reinterpret_cast<float *>(weight_pro_data), weight_project_ptr_, weight_pro->ElementsNum()); 6583be168c0dSopenharmony_ci+ } else if 
(weight_pro->data_type() == kNumberTypeFloat16) { 6584be168c0dSopenharmony_ci+ (void)memcpy(weight_project_ptr_, weight_pro_data, weight_pro->Size()); 6585be168c0dSopenharmony_ci+ } else { 6586be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported data type of weight_project tensor for lstm."; 6587be168c0dSopenharmony_ci+ return RET_ERROR; 6588be168c0dSopenharmony_ci+ } 6589be168c0dSopenharmony_ci+ } 6590be168c0dSopenharmony_ci+ size_t bias_size = UP_ROUND(lstm_param_->output_size_, C8NUM) * sizeof(float16_t); 6591be168c0dSopenharmony_ci+ project_bias_ = reinterpret_cast<float16_t *>(malloc(bias_size)); 6592be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(project_bias_ != nullptr, lite::RET_NULL_PTR, 6593be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc project_bias_ failed."); 6594be168c0dSopenharmony_ci+ (void)memset(project_bias_, 0, bias_size); 6595be168c0dSopenharmony_ci+ return RET_OK; 6596be168c0dSopenharmony_ci+} 6597be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6598be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h 6599be168c0dSopenharmony_cinew file mode 100644 6600be168c0dSopenharmony_ciindex 00000000..132ef1cf 6601be168c0dSopenharmony_ci--- /dev/null 6602be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/lstm_non_mindir_fp16.h 6603be168c0dSopenharmony_ci@@ -0,0 +1,59 @@ 6604be168c0dSopenharmony_ci+/** 6605be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 6606be168c0dSopenharmony_ci+ * 6607be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 6608be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
6609be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6610be168c0dSopenharmony_ci+ * 6611be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6612be168c0dSopenharmony_ci+ * 6613be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6614be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6615be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6616be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6617be168c0dSopenharmony_ci+ * limitations under the License. 6618be168c0dSopenharmony_ci+ */ 6619be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_NON_MINDIR_FP16_H_ 6620be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_NON_MINDIR_FP16_H_ 6621be168c0dSopenharmony_ci+ 6622be168c0dSopenharmony_ci+#include <vector> 6623be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp16/lstm_fp16_base.h" 6624be168c0dSopenharmony_ci+ 6625be168c0dSopenharmony_ci+namespace mindspore::kernel { 6626be168c0dSopenharmony_ci+/* 6627be168c0dSopenharmony_ci+ * 1. LSTM without project, output_size = hidden_size 6628be168c0dSopenharmony_ci+ * weight_ih: second input, shape is [bidirectional, 4 * hidden_size, input_size] 6629be168c0dSopenharmony_ci+ * weight_hh: third input, shape is [bidirectional, 4 * hidden_size, hidden_size] 6630be168c0dSopenharmony_ci+ * bias: forth input, shape is [bidirectional, 8 * hidden_size] 6631be168c0dSopenharmony_ci+ * h_init: fifth input, shape is [bidirectional, batch_size, hidden_size] 6632be168c0dSopenharmony_ci+ * c_init: sixth input, shape is [bidirectional, batch_size, hidden_size] 6633be168c0dSopenharmony_ci+ * 6634be168c0dSopenharmony_ci+ * 2. 
LSTM with project, output_size = project_size 6635be168c0dSopenharmony_ci+ * weight_ih: second input, shape is [bidirectional, 4 * hidden_size, input_size] 6636be168c0dSopenharmony_ci+ * weight_hh: third input, shape is [bidirectional, 4 * hidden_size, project_size] 6637be168c0dSopenharmony_ci+ * bias: forth input, shape is [bidirectional, 8 * hidden_size] 6638be168c0dSopenharmony_ci+ * h_init: fifth input, shape is [bidirectional, batch_size, project_size] 6639be168c0dSopenharmony_ci+ * c_init: sixth input, shape is [bidirectional, batch_size, hidden_size] 6640be168c0dSopenharmony_ci+ * weight_pro: seventh input, shape is [bidirectional, project_size, hidden_size] 6641be168c0dSopenharmony_ci+ */ 6642be168c0dSopenharmony_ci+class LstmNonMindirFp16CPUKernel : public LstmFp16BaseCPUKernel { 6643be168c0dSopenharmony_ci+ public: 6644be168c0dSopenharmony_ci+ LstmNonMindirFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6645be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6646be168c0dSopenharmony_ci+ : LstmFp16BaseCPUKernel(parameter, inputs, outputs, ctx) { 6647be168c0dSopenharmony_ci+ hidden_init_index_ = FIFTH_INPUT; 6648be168c0dSopenharmony_ci+ cell_init_index_ = SIXTH_INPUT; 6649be168c0dSopenharmony_ci+ } 6650be168c0dSopenharmony_ci+ 6651be168c0dSopenharmony_ci+ ~LstmNonMindirFp16CPUKernel() override = default; 6652be168c0dSopenharmony_ci+ 6653be168c0dSopenharmony_ci+ int Prepare() override; 6654be168c0dSopenharmony_ci+ 6655be168c0dSopenharmony_ci+ protected: 6656be168c0dSopenharmony_ci+ int InitInputWeightBias() override; 6657be168c0dSopenharmony_ci+ int InitStateWeightBias() override; 6658be168c0dSopenharmony_ci+ int InitProjectWeight() override; 6659be168c0dSopenharmony_ci+}; 6660be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6661be168c0dSopenharmony_ci+ 6662be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP16_LSTM_NON_MINDIR_FP16_H_ 
6663be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc b/mindspore/lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc 6664be168c0dSopenharmony_ciindex 8adb97b9..d6f94fd9 100644 6665be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc 6666be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp16/matmul_base_fp16.cc 6667be168c0dSopenharmony_ci@@ -187,13 +187,13 @@ void MatmulBaseFP16CPUKernel::InitMatrixA(const void *src_ptr) { 6668be168c0dSopenharmony_ci float16_t *dst = a_pack_ptr_ + i * params_->deep_ * params_->row_align_; 6669be168c0dSopenharmony_ci if (params_->a_transpose_) { 6670be168c0dSopenharmony_ci #ifdef ENABLE_ARM64 6671be168c0dSopenharmony_ci- RowMajor2RowNMajorFp16((const float16_t *)src, dst, params_->deep_, params_->row_); 6672be168c0dSopenharmony_ci+ RowMajor2RowNMajorFp16(src, dst, params_->deep_, params_->row_, src_data_type == kNumberTypeFloat32); 6673be168c0dSopenharmony_ci #else 6674be168c0dSopenharmony_ci RowMajor2Row12MajorFp16(src, dst, params_->deep_, params_->row_, src_data_type == kNumberTypeFloat32); 6675be168c0dSopenharmony_ci #endif 6676be168c0dSopenharmony_ci } else { 6677be168c0dSopenharmony_ci #ifdef ENABLE_ARM64 6678be168c0dSopenharmony_ci- RowMajor2ColNMajorFp16((const float16_t *)src, dst, params_->row_, params_->deep_); 6679be168c0dSopenharmony_ci+ RowMajor2ColNMajorFp16(src, dst, params_->row_, params_->deep_, src_data_type == kNumberTypeFloat32); 6680be168c0dSopenharmony_ci #else 6681be168c0dSopenharmony_ci RowMajor2Col12MajorFp16(src, dst, params_->row_, params_->deep_, src_data_type == kNumberTypeFloat32); 6682be168c0dSopenharmony_ci #endif 6683be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32.cc 6684be168c0dSopenharmony_ciindex 0b67f2c2..67f42265 100644 6685be168c0dSopenharmony_ci--- 
a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32.cc 6686be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32.cc 6687be168c0dSopenharmony_ci@@ -1,5 +1,5 @@ 6688be168c0dSopenharmony_ci /** 6689be168c0dSopenharmony_ci- * Copyright 2020 Huawei Technologies Co., Ltd 6690be168c0dSopenharmony_ci+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 6691be168c0dSopenharmony_ci * 6692be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License"); 6693be168c0dSopenharmony_ci * you may not use this file except in compliance with the License. 6694be168c0dSopenharmony_ci@@ -14,14 +14,11 @@ 6695be168c0dSopenharmony_ci * limitations under the License. 6696be168c0dSopenharmony_ci */ 6697be168c0dSopenharmony_ci 6698be168c0dSopenharmony_ci-#include "src/litert/kernel/cpu/fp32/lstm_fp32.h" 6699be168c0dSopenharmony_ci-#include <cfloat> 6700be168c0dSopenharmony_ci #include <vector> 6701be168c0dSopenharmony_ci-#include "schema/model_generated.h" 6702be168c0dSopenharmony_ci+#include "src/litert//kernel/cpu/fp32/lstm_mindir_fp32.h" 6703be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h" 6704be168c0dSopenharmony_ci #include "src/litert/kernel_registry.h" 6705be168c0dSopenharmony_ci #include "include/errorcode.h" 6706be168c0dSopenharmony_ci-#include "nnacl/fp32/pack_fp32.h" 6707be168c0dSopenharmony_ci-#include "nnacl/fp32/matmul_fp32.h" 6708be168c0dSopenharmony_ci 6709be168c0dSopenharmony_ci using mindspore::kernel::KERNEL_ARCH; 6710be168c0dSopenharmony_ci using mindspore::lite::KernelRegistrar; 6711be168c0dSopenharmony_ci@@ -32,664 +29,31 @@ using mindspore::schema::PrimitiveType_LSTM; 6712be168c0dSopenharmony_ci 6713be168c0dSopenharmony_ci namespace mindspore::kernel { 6714be168c0dSopenharmony_ci namespace { 6715be168c0dSopenharmony_ci-constexpr int kOutputHiddenStatusIndex = 1; 6716be168c0dSopenharmony_ci-constexpr int kOutputCellStatusIndex = 2; 6717be168c0dSopenharmony_ci-} // namespace 
6718be168c0dSopenharmony_ci- 6719be168c0dSopenharmony_ci-int LstmInputMulWeightRun(void *cdata, int task_id, float, float) { 6720be168c0dSopenharmony_ci- auto kernel = reinterpret_cast<const LstmCPUKernel *>(cdata); 6721be168c0dSopenharmony_ci- CHECK_NULL_RETURN(kernel); 6722be168c0dSopenharmony_ci- kernel->InputWeightMatMul(task_id); 6723be168c0dSopenharmony_ci- return RET_OK; 6724be168c0dSopenharmony_ci-} 6725be168c0dSopenharmony_ci- 6726be168c0dSopenharmony_ci-int LstmSequenceLoopRun(void *cdata, int task_id, float, float) { 6727be168c0dSopenharmony_ci- auto kernel = reinterpret_cast<LstmCPUKernel *>(cdata); 6728be168c0dSopenharmony_ci- CHECK_NULL_RETURN(kernel); 6729be168c0dSopenharmony_ci- auto ret = kernel->DoSequenceLoop(task_id); 6730be168c0dSopenharmony_ci- if (ret != RET_OK) { 6731be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LSTM: Do Sequence-loop failed."; 6732be168c0dSopenharmony_ci- } 6733be168c0dSopenharmony_ci- return ret; 6734be168c0dSopenharmony_ci-} 6735be168c0dSopenharmony_ci- 6736be168c0dSopenharmony_ci-void LstmCPUKernel::FreeRunBuffer() { 6737be168c0dSopenharmony_ci- for (auto data : buffer_running_malloc_) { 6738be168c0dSopenharmony_ci- ms_context_->allocator->Free(data); 6739be168c0dSopenharmony_ci- } 6740be168c0dSopenharmony_ci- buffer_running_malloc_.clear(); 6741be168c0dSopenharmony_ci-} 6742be168c0dSopenharmony_ci- 6743be168c0dSopenharmony_ci-int LstmCPUKernel::InitInputWeightBias() { 6744be168c0dSopenharmony_ci- // malloc and init input * weight right matrix buffer 6745be168c0dSopenharmony_ci- // input -- row: seq_len * batch; col: input_size 6746be168c0dSopenharmony_ci- // weight -- row: hidden_size; col: input_size, need transpose 6747be168c0dSopenharmony_ci- // result -- row: seq_len * batch; col: hidden_size 6748be168c0dSopenharmony_ci- weight_i_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 6749be168c0dSopenharmony_ci- weight_batch_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float))); 
6750be168c0dSopenharmony_ci- if (weight_i_ptr_ == nullptr) { 6751be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_i_ptr_ error."; 6752be168c0dSopenharmony_ci- return RET_ERROR; 6753be168c0dSopenharmony_ci- } 6754be168c0dSopenharmony_ci- buffer_running_malloc_.push_back(weight_i_ptr_); 6755be168c0dSopenharmony_ci- int i_index = (in_tensors_.size() == mindir_input_tensors) ? combined_weights_index : onnx_weight_i_index; 6756be168c0dSopenharmony_ci- const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? weights_order_IFOG : nullptr; 6757be168c0dSopenharmony_ci- auto weight_i = in_tensors_.at(i_index); 6758be168c0dSopenharmony_ci- auto weight_i_data = reinterpret_cast<float *>(weight_i->data()); 6759be168c0dSopenharmony_ci- 6760be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_i_data); 6761be168c0dSopenharmony_ci- int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); 6762be168c0dSopenharmony_ci- int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_); 6763be168c0dSopenharmony_ci- int b_size = (lstm_param_->hidden_size_); 6764be168c0dSopenharmony_ci- bool has_bias = (weight_batch_ * (cw_size + hh_size) < weight_i->ElementsNum()) ? true : false; 6765be168c0dSopenharmony_ci- int stride = (gpu_orig_state_) ? 
gate_num * (cw_size + hh_size) : gate_num * (cw_size); 6766be168c0dSopenharmony_ci- PackLstmWeightWithStride(weight_i_ptr_, weight_i_data, weight_batch_, lstm_param_->input_size_, 6767be168c0dSopenharmony_ci- lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, 6768be168c0dSopenharmony_ci- stride, weights_order); 6769be168c0dSopenharmony_ci- // input bias 6770be168c0dSopenharmony_ci- input_bias_ = reinterpret_cast<float *>( 6771be168c0dSopenharmony_ci- ms_context_->allocator->Malloc(weight_batch_ * lstm_param_->input_col_align_ * sizeof(float))); 6772be168c0dSopenharmony_ci- if (input_bias_ == nullptr) { 6773be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel malloc input_bias_ error."; 6774be168c0dSopenharmony_ci- return RET_ERROR; 6775be168c0dSopenharmony_ci- } 6776be168c0dSopenharmony_ci- memset(input_bias_, 0, weight_batch_ * lstm_param_->input_col_align_ * sizeof(float)); 6777be168c0dSopenharmony_ci- buffer_running_malloc_.push_back(input_bias_); 6778be168c0dSopenharmony_ci- 6779be168c0dSopenharmony_ci- int offset = weight_batch_ * (cw_size + hh_size); 6780be168c0dSopenharmony_ci- float *bias_data = (has_bias) ? weight_i_data + offset : nullptr; 6781be168c0dSopenharmony_ci- int dir_mul = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 6782be168c0dSopenharmony_ci- int b_stride = (gpu_orig_state_) ? 
gate_num * (dir_mul * b_size) : gate_num * (b_size); 6783be168c0dSopenharmony_ci- if (in_tensors_.size() > mindir_input_tensors) { 6784be168c0dSopenharmony_ci- bias_data = reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data()); 6785be168c0dSopenharmony_ci- CHECK_NULL_RETURN(bias_data); 6786be168c0dSopenharmony_ci- PackLstmBias(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, 6787be168c0dSopenharmony_ci- lstm_param_->bidirectional_, weights_order); 6788be168c0dSopenharmony_ci- } else { 6789be168c0dSopenharmony_ci- if (bias_data != nullptr) { 6790be168c0dSopenharmony_ci- PackLstmBiasWithStride(input_bias_, bias_data, weight_batch_, lstm_param_->hidden_size_, 6791be168c0dSopenharmony_ci- lstm_param_->input_col_align_, lstm_param_->bidirectional_, b_stride, weights_order); 6792be168c0dSopenharmony_ci- } 6793be168c0dSopenharmony_ci- } 6794be168c0dSopenharmony_ci- return RET_OK; 6795be168c0dSopenharmony_ci-} 6796be168c0dSopenharmony_ci- 6797be168c0dSopenharmony_ci-int LstmCPUKernel::InitStateWeightBias() { 6798be168c0dSopenharmony_ci- // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 6799be168c0dSopenharmony_ci- // state -- row: batch; col: hidden_size 6800be168c0dSopenharmony_ci- // weight -- row: hidden_size; col: hidden_size, need transpose 6801be168c0dSopenharmony_ci- // result -- row: batch; col: hidden_size 6802be168c0dSopenharmony_ci- int weight_i_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; 6803be168c0dSopenharmony_ci- int h_index = (in_tensors_.size() == mindir_input_tensors) ? 
combined_weights_index : onnx_weight_h_index; 6804be168c0dSopenharmony_ci- auto weight_h = in_tensors_.at(h_index); 6805be168c0dSopenharmony_ci- auto weight_h_data = (reinterpret_cast<float *>(weight_h->data())); 6806be168c0dSopenharmony_ci- 6807be168c0dSopenharmony_ci- int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); 6808be168c0dSopenharmony_ci- int hh_size = (lstm_param_->hidden_size_ * lstm_param_->project_size_); 6809be168c0dSopenharmony_ci- int b_size = (lstm_param_->hidden_size_); 6810be168c0dSopenharmony_ci- int stride = (gpu_orig_state_) ? gate_num * (cw_size + hh_size) : gate_num * (hh_size); 6811be168c0dSopenharmony_ci- 6812be168c0dSopenharmony_ci- if (in_tensors_.size() == mindir_input_tensors) { 6813be168c0dSopenharmony_ci- if (gpu_orig_state_) { 6814be168c0dSopenharmony_ci- weight_h_data += gate_num * cw_size; 6815be168c0dSopenharmony_ci- } else { 6816be168c0dSopenharmony_ci- weight_h_data += weight_i_size; 6817be168c0dSopenharmony_ci- } 6818be168c0dSopenharmony_ci- } 6819be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_h_data); 6820be168c0dSopenharmony_ci- if (!state_is_vec_) { 6821be168c0dSopenharmony_ci- weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 6822be168c0dSopenharmony_ci- weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->project_size_ * sizeof(float))); 6823be168c0dSopenharmony_ci- if (weight_h_ptr_ == nullptr) { 6824be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error."; 6825be168c0dSopenharmony_ci- return RET_ERROR; 6826be168c0dSopenharmony_ci- } 6827be168c0dSopenharmony_ci- buffer_running_malloc_.push_back(weight_h_ptr_); 6828be168c0dSopenharmony_ci- const int *weights_order = (in_tensors_.size() == mindir_input_tensors) ? 
weights_order_IFOG : nullptr; 6829be168c0dSopenharmony_ci- PackLstmWeightWithStride(weight_h_ptr_, weight_h_data, weight_batch_, lstm_param_->project_size_, 6830be168c0dSopenharmony_ci- lstm_param_->hidden_size_, lstm_param_->state_col_align_, lstm_param_->bidirectional_, 6831be168c0dSopenharmony_ci- stride, weights_order); 6832be168c0dSopenharmony_ci- } else { 6833be168c0dSopenharmony_ci-#ifdef ENABLE_AVX 6834be168c0dSopenharmony_ci- weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 6835be168c0dSopenharmony_ci- weight_batch_ * lstm_param_->state_col_align_ * lstm_param_->project_size_ * sizeof(float))); 6836be168c0dSopenharmony_ci- if (weight_h_ptr_ == nullptr) { 6837be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_h_ptr_ error."; 6838be168c0dSopenharmony_ci- return RET_ERROR; 6839be168c0dSopenharmony_ci- } 6840be168c0dSopenharmony_ci- buffer_running_malloc_.push_back(weight_h_ptr_); 6841be168c0dSopenharmony_ci- for (int i = 0; i < weight_batch_; i++) { 6842be168c0dSopenharmony_ci- const float *src_batch = weight_h_data + i * lstm_param_->hidden_size_ * lstm_param_->project_size_; 6843be168c0dSopenharmony_ci- float *dst_batch = weight_h_ptr_ + i * lstm_param_->state_col_align_ * lstm_param_->project_size_; 6844be168c0dSopenharmony_ci- RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->hidden_size_, lstm_param_->project_size_); 6845be168c0dSopenharmony_ci- } 6846be168c0dSopenharmony_ci-#else 6847be168c0dSopenharmony_ci- weight_h_ptr_ = weight_h_data; 6848be168c0dSopenharmony_ci-#endif 6849be168c0dSopenharmony_ci- } 6850be168c0dSopenharmony_ci- 6851be168c0dSopenharmony_ci- // state bias 6852be168c0dSopenharmony_ci- int weight_h_size = weight_batch_ * lstm_param_->hidden_size_ * lstm_param_->hidden_size_; 6853be168c0dSopenharmony_ci- int bias_size = weight_batch_ * lstm_param_->hidden_size_; 6854be168c0dSopenharmony_ci- state_bias_ = reinterpret_cast<float *>( 6855be168c0dSopenharmony_ci- 
ms_context_->allocator->Malloc(weight_batch_ * lstm_param_->state_col_align_ * sizeof(float))); 6856be168c0dSopenharmony_ci- if (state_bias_ == nullptr) { 6857be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel malloc state_bias_ error."; 6858be168c0dSopenharmony_ci- return RET_ERROR; 6859be168c0dSopenharmony_ci- } 6860be168c0dSopenharmony_ci- memset(state_bias_, 0, weight_batch_ * lstm_param_->state_col_align_ * sizeof(float)); 6861be168c0dSopenharmony_ci- buffer_running_malloc_.push_back(state_bias_); 6862be168c0dSopenharmony_ci- // if ONNX, secend bias is also present order IOFG 6863be168c0dSopenharmony_ci- if (in_tensors_.size() > mindir_input_tensors) { 6864be168c0dSopenharmony_ci- float *state_bias = 6865be168c0dSopenharmony_ci- reinterpret_cast<float *>(in_tensors_.at(onnx_bias_index)->data()) + gate_num * lstm_param_->hidden_size_; 6866be168c0dSopenharmony_ci- CHECK_NULL_RETURN(state_bias); 6867be168c0dSopenharmony_ci- PackLstmBias(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, 6868be168c0dSopenharmony_ci- lstm_param_->bidirectional_, nullptr); 6869be168c0dSopenharmony_ci- } else if (weight_h->ElementsNum() - weight_i_size - weight_h_size - C2NUM * bias_size == 0) { 6870be168c0dSopenharmony_ci- // mindir from device "GPU", secend bias is also present order IFOG 6871be168c0dSopenharmony_ci- int dir_mul = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 6872be168c0dSopenharmony_ci- int bias_offset = (gpu_orig_state_) ? gate_num * ((dir_mul - C1NUM) * cw_size + dir_mul * hh_size + b_size) 6873be168c0dSopenharmony_ci- : weight_h_size + bias_size; 6874be168c0dSopenharmony_ci- float *state_bias = weight_h_data + bias_offset; 6875be168c0dSopenharmony_ci- int b_stride = (gpu_orig_state_) ? 
gate_num * (b_size * C2NUM) : gate_num * b_size; 6876be168c0dSopenharmony_ci- PackLstmBiasWithStride(state_bias_, state_bias, weight_batch_, lstm_param_->hidden_size_, 6877be168c0dSopenharmony_ci- lstm_param_->state_col_align_, lstm_param_->bidirectional_, b_stride, weights_order_IFOG); 6878be168c0dSopenharmony_ci- } 6879be168c0dSopenharmony_ci- return RET_OK; 6880be168c0dSopenharmony_ci-} 6881be168c0dSopenharmony_ci- 6882be168c0dSopenharmony_ci-int LstmCPUKernel::InitProjectWeight() { 6883be168c0dSopenharmony_ci- if (in_tensors_.size() < C7NUM) { 6884be168c0dSopenharmony_ci- return RET_OK; 6885be168c0dSopenharmony_ci- } 6886be168c0dSopenharmony_ci- auto weight_pro = in_tensors_.at(SEVENTH_INPUT); 6887be168c0dSopenharmony_ci- auto shape = weight_pro->shape(); 6888be168c0dSopenharmony_ci- if (shape.size() != C3NUM) { 6889be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Project-weight's shape must be 3D."; 6890be168c0dSopenharmony_ci- return RET_ERROR; 6891be168c0dSopenharmony_ci- } 6892be168c0dSopenharmony_ci- auto weight_pro_data = reinterpret_cast<float *>(weight_pro->data()); 6893be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_pro_data); 6894be168c0dSopenharmony_ci- int batch = lstm_param_->bidirectional_ ? 
C2NUM : C1NUM; 6895be168c0dSopenharmony_ci- if (shape[0] != batch) { 6896be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Project-weight's shape[0] must be 1(bidirectional=false) or 2(bidirectional=true)."; 6897be168c0dSopenharmony_ci- return RET_ERROR; 6898be168c0dSopenharmony_ci- } 6899be168c0dSopenharmony_ci- int col_align = UP_ROUND(lstm_param_->project_size_, col_tile_); 6900be168c0dSopenharmony_ci- if (!state_is_vec_) { 6901be168c0dSopenharmony_ci- weight_project_ptr_ = reinterpret_cast<float *>( 6902be168c0dSopenharmony_ci- ms_context_->allocator->Malloc(batch * lstm_param_->hidden_size_ * col_align * sizeof(float))); 6903be168c0dSopenharmony_ci- if (weight_project_ptr_ == nullptr) { 6904be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_project_ptr_ error."; 6905be168c0dSopenharmony_ci- return RET_ERROR; 6906be168c0dSopenharmony_ci- } 6907be168c0dSopenharmony_ci- buffer_running_malloc_.push_back(weight_project_ptr_); 6908be168c0dSopenharmony_ci- PackLstmWeightWithStride(weight_project_ptr_, weight_pro_data, batch, lstm_param_->hidden_size_, 6909be168c0dSopenharmony_ci- lstm_param_->project_size_, col_align, lstm_param_->bidirectional_, 6910be168c0dSopenharmony_ci- lstm_param_->hidden_size_ * lstm_param_->project_size_, nullptr); 6911be168c0dSopenharmony_ci- } else { 6912be168c0dSopenharmony_ci-#ifdef ENABLE_AVX 6913be168c0dSopenharmony_ci- weight_project_ptr_ = reinterpret_cast<float *>( 6914be168c0dSopenharmony_ci- ms_context_->allocator->Malloc(batch * lstm_param_->hidden_size_ * col_align * sizeof(float))); 6915be168c0dSopenharmony_ci- if (weight_project_ptr_ == nullptr) { 6916be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel malloc weight_project_ptr_ error."; 6917be168c0dSopenharmony_ci- return RET_ERROR; 6918be168c0dSopenharmony_ci- } 6919be168c0dSopenharmony_ci- buffer_running_malloc_.push_back(weight_project_ptr_); 6920be168c0dSopenharmony_ci- for (int i = 0; i < batch; ++i) { 6921be168c0dSopenharmony_ci- const float *src_batch 
= weight_pro_data + i * lstm_param_->hidden_size_ * lstm_param_->project_size_; 6922be168c0dSopenharmony_ci- float *dst_batch = weight_project_ptr_ + i * lstm_param_->hidden_size_ * col_align; 6923be168c0dSopenharmony_ci- RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->project_size_, lstm_param_->hidden_size_); 6924be168c0dSopenharmony_ci- } 6925be168c0dSopenharmony_ci-#else 6926be168c0dSopenharmony_ci- weight_project_ptr_ = weight_pro_data; 6927be168c0dSopenharmony_ci-#endif 6928be168c0dSopenharmony_ci- } 6929be168c0dSopenharmony_ci- return RET_OK; 6930be168c0dSopenharmony_ci-} 6931be168c0dSopenharmony_ci- 6932be168c0dSopenharmony_ci-int LstmCPUKernel::InitParam() { 6933be168c0dSopenharmony_ci- auto input = in_tensors_.front(); 6934be168c0dSopenharmony_ci- std::vector<int> in_shape = input->shape(); 6935be168c0dSopenharmony_ci- lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT); 6936be168c0dSopenharmony_ci- lstm_param_->batch_ = in_shape.at(SECOND_INPUT); 6937be168c0dSopenharmony_ci- lstm_param_->input_size_ = in_shape.at(THIRD_INPUT); 6938be168c0dSopenharmony_ci- 6939be168c0dSopenharmony_ci- auto weight_i = in_tensors_.at(onnx_weight_i_index); 6940be168c0dSopenharmony_ci- std::vector<int> w_shape = weight_i->shape(); 6941be168c0dSopenharmony_ci- if (in_tensors_.size() == mindir_input_tensors) { 6942be168c0dSopenharmony_ci- hidden_state_input_index_ = mindir_hidden_state_input_index; 6943be168c0dSopenharmony_ci- cell_state_input_index_ = mindir_cell_state_input_index; 6944be168c0dSopenharmony_ci- lstm_param_->hidden_size_ = w_shape.at(THIRD_INPUT); 6945be168c0dSopenharmony_ci- lstm_param_->project_size_ = lstm_param_->hidden_size_; 6946be168c0dSopenharmony_ci- } else { 6947be168c0dSopenharmony_ci- lstm_param_->hidden_size_ = w_shape.at(SECOND_INPUT) / gate_num; 6948be168c0dSopenharmony_ci- auto weight_h = in_tensors_[THIRD_INPUT]; 6949be168c0dSopenharmony_ci- auto h_shape = weight_h->shape(); 6950be168c0dSopenharmony_ci- lstm_param_->project_size_ = 
h_shape.back(); 6951be168c0dSopenharmony_ci- } 6952be168c0dSopenharmony_ci- 6953be168c0dSopenharmony_ci- lstm_param_->output_step_ = lstm_param_->bidirectional_ ? C2NUM * lstm_param_->batch_ * lstm_param_->hidden_size_ 6954be168c0dSopenharmony_ci- : lstm_param_->batch_ * lstm_param_->hidden_size_; 6955be168c0dSopenharmony_ci- weight_batch_ = lstm_param_->bidirectional_ ? C2NUM * gate_num : gate_num; 6956be168c0dSopenharmony_ci- state_is_vec_ = lstm_param_->batch_ == 1; 6957be168c0dSopenharmony_ci- // determine FB origin 6958be168c0dSopenharmony_ci- gpu_orig_state_ = false; 6959be168c0dSopenharmony_ci- if (in_tensors_.size() == mindir_input_tensors) { 6960be168c0dSopenharmony_ci- gpu_orig_state_ = gpu_orig_cfg_; 6961be168c0dSopenharmony_ci- auto weight_t = in_tensors_.at(combined_weights_index); 6962be168c0dSopenharmony_ci- int cw_size = (lstm_param_->input_size_ * lstm_param_->hidden_size_); 6963be168c0dSopenharmony_ci- int hh_size = (lstm_param_->hidden_size_ * lstm_param_->hidden_size_); 6964be168c0dSopenharmony_ci- int b_size = (lstm_param_->hidden_size_); 6965be168c0dSopenharmony_ci- bool has_bias = (weight_batch_ * (cw_size + hh_size) < weight_t->ElementsNum()) ? true : false; 6966be168c0dSopenharmony_ci- // if bias exist we can determine the gpu_orig_state_ 6967be168c0dSopenharmony_ci- if (has_bias) { 6968be168c0dSopenharmony_ci- gpu_orig_state_ = 6969be168c0dSopenharmony_ci- (weight_batch_ * (cw_size + hh_size + C2NUM * b_size) == weight_t->ElementsNum()) ? 
true : false; 6970be168c0dSopenharmony_ci- } 6971be168c0dSopenharmony_ci- } 6972be168c0dSopenharmony_ci- 6973be168c0dSopenharmony_ci-#ifdef ENABLE_AVX 6974be168c0dSopenharmony_ci- row_tile_ = C6NUM; 6975be168c0dSopenharmony_ci- col_tile_ = C16NUM; 6976be168c0dSopenharmony_ci-#elif defined(ENABLE_ARM32) 6977be168c0dSopenharmony_ci- row_tile_ = C12NUM; 6978be168c0dSopenharmony_ci- col_tile_ = C4NUM; 6979be168c0dSopenharmony_ci-#elif defined(ENABLE_SSE) 6980be168c0dSopenharmony_ci- row_tile_ = C4NUM; 6981be168c0dSopenharmony_ci- col_tile_ = C8NUM; 6982be168c0dSopenharmony_ci-#else 6983be168c0dSopenharmony_ci- row_tile_ = C12NUM; 6984be168c0dSopenharmony_ci- col_tile_ = C8NUM; 6985be168c0dSopenharmony_ci-#endif 6986be168c0dSopenharmony_ci- lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_); 6987be168c0dSopenharmony_ci- lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_); 6988be168c0dSopenharmony_ci- input_thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(lstm_param_->input_col_align_, col_tile_)); 6989be168c0dSopenharmony_ci- MS_CHECK_FALSE(input_thread_count_ == 0, RET_ERROR); 6990be168c0dSopenharmony_ci- input_thread_stride_ = UP_DIV(UP_DIV(lstm_param_->input_col_align_, col_tile_), input_thread_count_); 6991be168c0dSopenharmony_ci- 6992be168c0dSopenharmony_ci- state_row_tile_ = row_tile_; 6993be168c0dSopenharmony_ci- state_col_tile_ = col_tile_; 6994be168c0dSopenharmony_ci-#ifdef ENABLE_AVX 6995be168c0dSopenharmony_ci- if (state_is_vec_) { 6996be168c0dSopenharmony_ci- state_row_tile_ = 1; 6997be168c0dSopenharmony_ci- state_col_tile_ = C8NUM; 6998be168c0dSopenharmony_ci- } 6999be168c0dSopenharmony_ci-#endif 7000be168c0dSopenharmony_ci- 7001be168c0dSopenharmony_ci- lstm_param_->state_row_align_ = state_is_vec_ ? 
1 : UP_ROUND(lstm_param_->batch_, state_row_tile_); 7002be168c0dSopenharmony_ci-#ifdef ENABLE_AVX 7003be168c0dSopenharmony_ci- lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, state_col_tile_); 7004be168c0dSopenharmony_ci-#else 7005be168c0dSopenharmony_ci- lstm_param_->state_col_align_ = 7006be168c0dSopenharmony_ci- state_is_vec_ ? lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, state_col_tile_); 7007be168c0dSopenharmony_ci-#endif 7008be168c0dSopenharmony_ci- return RET_OK; 7009be168c0dSopenharmony_ci+constexpr size_t kMindirInputTensorNum = 4; 7010be168c0dSopenharmony_ci } 7011be168c0dSopenharmony_ci- 7012be168c0dSopenharmony_ci-int LstmCPUKernel::Prepare() { 7013be168c0dSopenharmony_ci- CHECK_LESS_RETURN(in_tensors_.size(), mindir_input_tensors); 7014be168c0dSopenharmony_ci- for (size_t i = 0; i < in_tensors_.size(); i++) { 7015be168c0dSopenharmony_ci- CHECK_NULL_RETURN(in_tensors_.at(i)); 7016be168c0dSopenharmony_ci- } 7017be168c0dSopenharmony_ci- CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_3D); 7018be168c0dSopenharmony_ci- for (size_t i = 0; i < out_tensors_.size(); i++) { 7019be168c0dSopenharmony_ci- CHECK_NULL_RETURN(out_tensors_.at(i)); 7020be168c0dSopenharmony_ci- } 7021be168c0dSopenharmony_ci- CHECK_NULL_RETURN(lstm_param_); 7022be168c0dSopenharmony_ci- if (!InferShapeDone()) { 7023be168c0dSopenharmony_ci- return RET_OK; 7024be168c0dSopenharmony_ci- } 7025be168c0dSopenharmony_ci- return ReSize(); 7026be168c0dSopenharmony_ci-} 7027be168c0dSopenharmony_ci- 7028be168c0dSopenharmony_ci-int LstmCPUKernel::ReSize() { 7029be168c0dSopenharmony_ci- auto ret = InitParam(); 7030be168c0dSopenharmony_ci- if (ret != RET_OK) { 7031be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel InitParam error."; 7032be168c0dSopenharmony_ci- return RET_ERROR; 7033be168c0dSopenharmony_ci+LiteKernel *LstmFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs, 7034be168c0dSopenharmony_ci+ 
OpParameter *parameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc) { 7035be168c0dSopenharmony_ci+ if (parameter == nullptr) { 7036be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "parameter is nullptr."; 7037be168c0dSopenharmony_ci+ return nullptr; 7038be168c0dSopenharmony_ci } 7039be168c0dSopenharmony_ci- 7040be168c0dSopenharmony_ci- return RET_OK; 7041be168c0dSopenharmony_ci-} 7042be168c0dSopenharmony_ci- 7043be168c0dSopenharmony_ci-int LstmCPUKernel::MallocRunBuffer(bool is_double) { 7044be168c0dSopenharmony_ci- bool need_zone = lstm_param_->zoneout_cell_ < -FLT_EPSILON || lstm_param_->zoneout_cell_ > FLT_EPSILON; 7045be168c0dSopenharmony_ci- size_t whole_size = 0; 7046be168c0dSopenharmony_ci- std::vector<size_t> segments; 7047be168c0dSopenharmony_ci- int scale = is_double ? C2NUM : 1; 7048be168c0dSopenharmony_ci- size_t segment = gate_num * lstm_param_->seq_len_ * lstm_param_->batch_ * 7049be168c0dSopenharmony_ci- lstm_param_->hidden_size_; // 0: input * weight for result matrix 7050be168c0dSopenharmony_ci- segments.push_back(segment); 7051be168c0dSopenharmony_ci- whole_size += segment * scale; 7052be168c0dSopenharmony_ci- 7053be168c0dSopenharmony_ci- segment = state_is_vec_ 7054be168c0dSopenharmony_ci- ? 0 7055be168c0dSopenharmony_ci- : lstm_param_->state_row_align_ * lstm_param_->project_size_; // 1: state * weight for left matirx 7056be168c0dSopenharmony_ci- segments.push_back(segment); 7057be168c0dSopenharmony_ci- whole_size += segment * scale; 7058be168c0dSopenharmony_ci- 7059be168c0dSopenharmony_ci- segment = gate_num * lstm_param_->batch_ * lstm_param_->hidden_size_; // 2: state gate buffer 7060be168c0dSopenharmony_ci- segments.push_back(segment); 7061be168c0dSopenharmony_ci- whole_size += segment * scale; 7062be168c0dSopenharmony_ci- 7063be168c0dSopenharmony_ci- segment = need_zone ? 
lstm_param_->batch_ * lstm_param_->hidden_size_ : 0; // 3: state_buffer for cell 7064be168c0dSopenharmony_ci- segments.push_back(segment); 7065be168c0dSopenharmony_ci- whole_size += segment * scale; 7066be168c0dSopenharmony_ci- 7067be168c0dSopenharmony_ci- segment = need_zone ? lstm_param_->batch_ * lstm_param_->project_size_ : 0; // 4: state_buffer for hidden 7068be168c0dSopenharmony_ci- segments.push_back(segment); 7069be168c0dSopenharmony_ci- whole_size += segment * scale; 7070be168c0dSopenharmony_ci- 7071be168c0dSopenharmony_ci- segment = 0; 7072be168c0dSopenharmony_ci-#ifdef ENABLE_AVX 7073be168c0dSopenharmony_ci- bool output_need_packed = lstm_param_->hidden_size_ % state_col_tile_; 7074be168c0dSopenharmony_ci- if (state_is_vec_ && output_need_packed) { // vec matmul need to malloc dst 7075be168c0dSopenharmony_ci- int out_channel = lstm_param_->hidden_size_; 7076be168c0dSopenharmony_ci- int oc_block_num = UP_DIV(out_channel, state_col_tile_); 7077be168c0dSopenharmony_ci- MS_ASSERT(ms_context_->allocator != nullptr); 7078be168c0dSopenharmony_ci- segment = lstm_param_->batch_ * oc_block_num * state_col_tile_; // 5: tmp output data 7079be168c0dSopenharmony_ci+ if (desc.data_type == kTypeUnknown) { 7080be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "desc data_type is unknown."; 7081be168c0dSopenharmony_ci } 7082be168c0dSopenharmony_ci-#endif 7083be168c0dSopenharmony_ci- segments.push_back(segment); 7084be168c0dSopenharmony_ci- whole_size += segment * scale; 7085be168c0dSopenharmony_ci- 7086be168c0dSopenharmony_ci- if (in_tensors_.size() == C7NUM) { 7087be168c0dSopenharmony_ci- segment = state_is_vec_ ? 
0 : lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * scale; 7088be168c0dSopenharmony_ci- segments.push_back(segment); // 6: project-layer input 7089be168c0dSopenharmony_ci- whole_size += segment; 7090be168c0dSopenharmony_ci- segment = 0; 7091be168c0dSopenharmony_ci-#ifdef ENABLE_AVX 7092be168c0dSopenharmony_ci- segment = 7093be168c0dSopenharmony_ci- output_need_packed ? lstm_param_->batch_ * UP_ROUND(lstm_param_->project_size_, state_col_tile_) * scale : 0; 7094be168c0dSopenharmony_ci-#endif 7095be168c0dSopenharmony_ci- segments.push_back(segment); // 7: project-layer output 7096be168c0dSopenharmony_ci- whole_size += segment; 7097be168c0dSopenharmony_ci+ LiteKernel *kernel{nullptr}; 7098be168c0dSopenharmony_ci+ if (inputs.size() == kMindirInputTensorNum) { 7099be168c0dSopenharmony_ci+ kernel = new (std::nothrow) 7100be168c0dSopenharmony_ci+ LstmMindirFp32CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 7101be168c0dSopenharmony_ci } else { 7102be168c0dSopenharmony_ci- (void)segments.insert(segments.end(), C2NUM, 0); 7103be168c0dSopenharmony_ci- } 7104be168c0dSopenharmony_ci- 7105be168c0dSopenharmony_ci- segment = 0; 7106be168c0dSopenharmony_ci- if (!(in_tensors_.size() > mindir_input_tensors)) { 7107be168c0dSopenharmony_ci- segment = lstm_param_->batch_ * lstm_param_->hidden_size_; 7108be168c0dSopenharmony_ci- } 7109be168c0dSopenharmony_ci- segments.push_back(segment); 7110be168c0dSopenharmony_ci- whole_size += segment * scale; 7111be168c0dSopenharmony_ci- 7112be168c0dSopenharmony_ci- segment = 7113be168c0dSopenharmony_ci- lstm_param_->input_row_align_ * lstm_param_->input_size_; // input * weight for left matrix, which only once 7114be168c0dSopenharmony_ci- whole_size += segment; 7115be168c0dSopenharmony_ci- 7116be168c0dSopenharmony_ci- auto whole_memory = reinterpret_cast<float *>(ms_context_->allocator->Malloc(whole_size * sizeof(float))); 7117be168c0dSopenharmony_ci- MS_CHECK_TRUE_MSG(whole_memory != nullptr, 
RET_ERROR, "LSTM: malloc failed."); 7118be168c0dSopenharmony_ci- buffer_running_malloc_.push_back(whole_memory); 7119be168c0dSopenharmony_ci- MS_ASSERT(segments.size() == C9NUM); 7120be168c0dSopenharmony_ci- auto Allocate = [&whole_memory, &segments](float **buffer) mutable { 7121be168c0dSopenharmony_ci- for (int i = 0; i < C9NUM; ++i) { 7122be168c0dSopenharmony_ci- buffer[i] = nullptr; 7123be168c0dSopenharmony_ci- if (segments[i] == 0) { 7124be168c0dSopenharmony_ci- continue; 7125be168c0dSopenharmony_ci- } 7126be168c0dSopenharmony_ci- buffer[i] = whole_memory; 7127be168c0dSopenharmony_ci- whole_memory += segments[i]; 7128be168c0dSopenharmony_ci- } 7129be168c0dSopenharmony_ci- }; 7130be168c0dSopenharmony_ci- Allocate(buffer_forward_); 7131be168c0dSopenharmony_ci- if (is_double) { 7132be168c0dSopenharmony_ci- Allocate(buffer_backward_); 7133be168c0dSopenharmony_ci- } 7134be168c0dSopenharmony_ci- packed_input_ = whole_memory; 7135be168c0dSopenharmony_ci- return RET_OK; 7136be168c0dSopenharmony_ci-} 7137be168c0dSopenharmony_ci- 7138be168c0dSopenharmony_ci-void LstmCPUKernel::InputWeightMatMul(int task_id) const { 7139be168c0dSopenharmony_ci- int current_start_oc = task_id * input_thread_stride_ * col_tile_; 7140be168c0dSopenharmony_ci- int current_rest_oc = 0; 7141be168c0dSopenharmony_ci- current_rest_oc = lstm_param_->hidden_size_ - current_start_oc; 7142be168c0dSopenharmony_ci- int cur_oc = MSMIN(input_thread_stride_ * col_tile_, current_rest_oc); 7143be168c0dSopenharmony_ci- if (cur_oc <= 0) { 7144be168c0dSopenharmony_ci- return; 7145be168c0dSopenharmony_ci- } 7146be168c0dSopenharmony_ci- 7147be168c0dSopenharmony_ci- auto b = weight_loop_ + current_start_oc * lstm_param_->input_size_; 7148be168c0dSopenharmony_ci- auto c = gate_loop_ + current_start_oc; 7149be168c0dSopenharmony_ci- auto bias = (bias_loop_ == nullptr) ? 
nullptr : bias_loop_ + current_start_oc; 7150be168c0dSopenharmony_ci- MatMulOpt(packed_input_, b, c, bias, ActType_No, lstm_param_->input_size_, 7151be168c0dSopenharmony_ci- lstm_param_->seq_len_ * lstm_param_->batch_, cur_oc, lstm_param_->hidden_size_, OutType_Nhwc); 7152be168c0dSopenharmony_ci-} 7153be168c0dSopenharmony_ci- 7154be168c0dSopenharmony_ci-int LstmCPUKernel::DoSequenceLoop(int task_id) { 7155be168c0dSopenharmony_ci- if (task_id == 0) { 7156be168c0dSopenharmony_ci- LstmForwardLoop(buffer_forward_); 7157be168c0dSopenharmony_ci- return RET_OK; 7158be168c0dSopenharmony_ci- } 7159be168c0dSopenharmony_ci- if (task_id == 1) { 7160be168c0dSopenharmony_ci- LstmBackwardLoop(buffer_backward_); 7161be168c0dSopenharmony_ci- return RET_OK; 7162be168c0dSopenharmony_ci- } 7163be168c0dSopenharmony_ci- return RET_ERROR; 7164be168c0dSopenharmony_ci-} 7165be168c0dSopenharmony_ci- 7166be168c0dSopenharmony_ci-int LstmCPUKernel::LstmPreProcessWithInput(const float *weight_i, const float *input_bias, float *dst) { 7167be168c0dSopenharmony_ci- for (int i = 0; i < gate_num; i++) { 7168be168c0dSopenharmony_ci- weight_loop_ = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i; 7169be168c0dSopenharmony_ci- bias_loop_ = input_bias + lstm_param_->input_col_align_ * i; 7170be168c0dSopenharmony_ci- gate_loop_ = dst + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * i; 7171be168c0dSopenharmony_ci- auto ret = ParallelLaunch(this->ms_context_, LstmInputMulWeightRun, this, input_thread_count_); 7172be168c0dSopenharmony_ci- if (ret != RET_OK) { 7173be168c0dSopenharmony_ci- return RET_ERROR; 7174be168c0dSopenharmony_ci- } 7175be168c0dSopenharmony_ci- } 7176be168c0dSopenharmony_ci- return RET_OK; 7177be168c0dSopenharmony_ci-} 7178be168c0dSopenharmony_ci- 7179be168c0dSopenharmony_ci-void LstmCPUKernel::LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, 7180be168c0dSopenharmony_ci- float *hidden_state, float 
*cell_state, const float *weight_project, 7181be168c0dSopenharmony_ci- float *intermediate_states, float *buffer[], bool is_backward) { 7182be168c0dSopenharmony_ci- float *gate = buffer[input_gate_index]; 7183be168c0dSopenharmony_ci- float *input_gate = gate; 7184be168c0dSopenharmony_ci- float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C2NUM; 7185be168c0dSopenharmony_ci- float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C3NUM; 7186be168c0dSopenharmony_ci- float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_; 7187be168c0dSopenharmony_ci- float *tmp = buffer[tmp_hidden_output_index]; 7188be168c0dSopenharmony_ci- int dir_mult = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 7189be168c0dSopenharmony_ci- for (int t = 0; t < lstm_param_->seq_len_; t++) { 7190be168c0dSopenharmony_ci- int real_t = is_backward ? lstm_param_->seq_len_ - t - C1NUM : t; 7191be168c0dSopenharmony_ci- float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 7192be168c0dSopenharmony_ci- float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 7193be168c0dSopenharmony_ci- float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 7194be168c0dSopenharmony_ci- float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 7195be168c0dSopenharmony_ci- // if ONNX 7196be168c0dSopenharmony_ci- if (in_tensors_.size() > mindir_input_tensors) { 7197be168c0dSopenharmony_ci- // Sequence, DirMul, Batch, Hidden 7198be168c0dSopenharmony_ci- float *output_ptr = output + real_t * lstm_param_->output_step_; 7199be168c0dSopenharmony_ci- 7200be168c0dSopenharmony_ci- LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, 7201be168c0dSopenharmony_ci- weight_project, hidden_state, 
cell_state, buffer, lstm_param_); 7202be168c0dSopenharmony_ci- } else { 7203be168c0dSopenharmony_ci- // Sequence, Batch, DirMul, Hidden 7204be168c0dSopenharmony_ci- LstmStepUnit(tmp, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, nullptr, 7205be168c0dSopenharmony_ci- hidden_state, cell_state, buffer, lstm_param_); 7206be168c0dSopenharmony_ci- int seq_offset = real_t * lstm_param_->batch_ * dir_mult * lstm_param_->hidden_size_; 7207be168c0dSopenharmony_ci- for (int b = 0; b < lstm_param_->batch_; b++) { 7208be168c0dSopenharmony_ci- int batch_offset = b * dir_mult * lstm_param_->hidden_size_; 7209be168c0dSopenharmony_ci- float *output_ptr = output + seq_offset + batch_offset; 7210be168c0dSopenharmony_ci- memcpy(output_ptr, tmp + b * lstm_param_->hidden_size_, lstm_param_->hidden_size_ * sizeof(float)); 7211be168c0dSopenharmony_ci- } 7212be168c0dSopenharmony_ci- } 7213be168c0dSopenharmony_ci- if (intermediate_states) { 7214be168c0dSopenharmony_ci- RecordStates(hidden_state, cell_state, input_gate_t, output_gate_t, forget_gate_t, cell_gate_t, 7215be168c0dSopenharmony_ci- intermediate_states, real_t); 7216be168c0dSopenharmony_ci- } 7217be168c0dSopenharmony_ci- } 7218be168c0dSopenharmony_ci-} 7219be168c0dSopenharmony_ci- 7220be168c0dSopenharmony_ci-void LstmCPUKernel::RecordStates(const float *hidden_state, float *cell_state, float *input_gate, 7221be168c0dSopenharmony_ci- const float *output_gate, float *forget_gate, const float *cell_gate, 7222be168c0dSopenharmony_ci- float *intermediate_states, int step) { 7223be168c0dSopenharmony_ci- float *states = intermediate_states; 7224be168c0dSopenharmony_ci- auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_; 7225be168c0dSopenharmony_ci- if (state_size < 0) { 7226be168c0dSopenharmony_ci- MS_LOG(ERROR) << "state size should be greater than or equal to zero."; 7227be168c0dSopenharmony_ci- return; 7228be168c0dSopenharmony_ci- } 7229be168c0dSopenharmony_ci- auto stride = step * 
lstm_param_->output_step_; 7230be168c0dSopenharmony_ci- auto seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; 7231be168c0dSopenharmony_ci- memcpy(states + stride, hidden_state, state_size * sizeof(float)); 7232be168c0dSopenharmony_ci- stride += seq_stride; 7233be168c0dSopenharmony_ci- memcpy(states + stride, cell_state, state_size * sizeof(float)); 7234be168c0dSopenharmony_ci- stride += seq_stride; 7235be168c0dSopenharmony_ci- memcpy(states + stride, input_gate, state_size * sizeof(float)); 7236be168c0dSopenharmony_ci- stride += seq_stride; 7237be168c0dSopenharmony_ci- memcpy(states + stride, output_gate, state_size * sizeof(float)); 7238be168c0dSopenharmony_ci- stride += seq_stride; 7239be168c0dSopenharmony_ci- memcpy(states + stride, forget_gate, state_size * sizeof(float)); 7240be168c0dSopenharmony_ci- stride += seq_stride; 7241be168c0dSopenharmony_ci- memcpy(states + stride, cell_gate, state_size * sizeof(float)); 7242be168c0dSopenharmony_ci-} 7243be168c0dSopenharmony_ci- 7244be168c0dSopenharmony_ci-void LstmCPUKernel::LstmForwardLoop(float *buffer[]) { 7245be168c0dSopenharmony_ci- auto *output = reinterpret_cast<float *>(out_tensors_.at(0)->data()); 7246be168c0dSopenharmony_ci- auto *hidden_state = reinterpret_cast<float *>(out_tensors_.at(1)->data()); 7247be168c0dSopenharmony_ci- auto *cell_state = reinterpret_cast<float *>(out_tensors_.at(C2NUM)->data()); 7248be168c0dSopenharmony_ci- LstmUnidirectional(output, weight_h_ptr_, state_bias_, hidden_state, cell_state, weight_project_ptr_, 7249be168c0dSopenharmony_ci- intermediate_states_, buffer, false); 7250be168c0dSopenharmony_ci-} 7251be168c0dSopenharmony_ci- 7252be168c0dSopenharmony_ci-void LstmCPUKernel::LstmBackwardLoop(float *buffer[]) { 7253be168c0dSopenharmony_ci- auto *output = reinterpret_cast<float *>(out_tensors_.at(0)->data()); 7254be168c0dSopenharmony_ci- auto *hidden_state = reinterpret_cast<float *>(out_tensors_.at(1)->data()); 7255be168c0dSopenharmony_ci- auto *cell_state = 
reinterpret_cast<float *>(out_tensors_.at(C2NUM)->data()); 7256be168c0dSopenharmony_ci- const float *backward_weight_h = weight_h_ptr_ + gate_num * lstm_param_->state_col_align_ * lstm_param_->hidden_size_; 7257be168c0dSopenharmony_ci- const float *backward_state_bias = state_bias_ + gate_num * lstm_param_->state_col_align_; 7258be168c0dSopenharmony_ci- float *backward_output = output + lstm_param_->batch_ * lstm_param_->hidden_size_; 7259be168c0dSopenharmony_ci- if (in_tensors_.size() == mindir_input_tensors) { 7260be168c0dSopenharmony_ci- backward_output = output + lstm_param_->hidden_size_; 7261be168c0dSopenharmony_ci- } 7262be168c0dSopenharmony_ci- float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_; 7263be168c0dSopenharmony_ci- float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->hidden_size_; 7264be168c0dSopenharmony_ci- float *intermediate_states = nullptr; 7265be168c0dSopenharmony_ci- if (intermediate_states_) { 7266be168c0dSopenharmony_ci- intermediate_states = intermediate_states_ + lstm_param_->batch_ * lstm_param_->hidden_size_; 7267be168c0dSopenharmony_ci- } 7268be168c0dSopenharmony_ci- float *backward_weight_project = 7269be168c0dSopenharmony_ci- weight_project_ptr_ 7270be168c0dSopenharmony_ci- ? 
weight_project_ptr_ + lstm_param_->hidden_size_ * UP_ROUND(lstm_param_->project_size_, col_tile_) 7271be168c0dSopenharmony_ci- : nullptr; 7272be168c0dSopenharmony_ci- LstmUnidirectional(backward_output, backward_weight_h, backward_state_bias, backward_hidden_state, 7273be168c0dSopenharmony_ci- backward_cell_state, backward_weight_project, intermediate_states, buffer, true); 7274be168c0dSopenharmony_ci-} 7275be168c0dSopenharmony_ci- 7276be168c0dSopenharmony_ci-int LstmCPUKernel::ExecuteUnidirectionalOrSingleThread() { 7277be168c0dSopenharmony_ci- auto ret = LstmPreProcessWithInput(weight_i_ptr_, input_bias_, buffer_forward_[input_gate_index]); 7278be168c0dSopenharmony_ci- if (ret != RET_OK) { 7279be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LSTM Forward: Input-MatMul running failed."; 7280be168c0dSopenharmony_ci- return RET_ERROR; 7281be168c0dSopenharmony_ci- } 7282be168c0dSopenharmony_ci- LstmForwardLoop(buffer_forward_); 7283be168c0dSopenharmony_ci- 7284be168c0dSopenharmony_ci- // backward 7285be168c0dSopenharmony_ci- if (lstm_param_->bidirectional_) { 7286be168c0dSopenharmony_ci- const float *backward_weight_i = 7287be168c0dSopenharmony_ci- weight_i_ptr_ + gate_num * lstm_param_->input_col_align_ * lstm_param_->input_size_; 7288be168c0dSopenharmony_ci- const float *backward_input_bias = input_bias_ + gate_num * lstm_param_->input_col_align_; 7289be168c0dSopenharmony_ci- ret = LstmPreProcessWithInput(backward_weight_i, backward_input_bias, buffer_forward_[input_gate_index]); 7290be168c0dSopenharmony_ci- if (ret != RET_OK) { 7291be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LSTM Backward: Input-MatMul running failed."; 7292be168c0dSopenharmony_ci- return RET_ERROR; 7293be168c0dSopenharmony_ci- } 7294be168c0dSopenharmony_ci- LstmBackwardLoop(buffer_forward_); 7295be168c0dSopenharmony_ci- } 7296be168c0dSopenharmony_ci- return RET_OK; 7297be168c0dSopenharmony_ci-} 7298be168c0dSopenharmony_ci- 7299be168c0dSopenharmony_ci-int 
LstmCPUKernel::ExecuteBidirectionalWithMultiThread() { 7300be168c0dSopenharmony_ci- auto ret = LstmPreProcessWithInput(weight_i_ptr_, input_bias_, buffer_forward_[input_gate_index]); 7301be168c0dSopenharmony_ci- if (ret != RET_OK) { 7302be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LSTM Forward: Input-MatMul running failed."; 7303be168c0dSopenharmony_ci- return RET_ERROR; 7304be168c0dSopenharmony_ci- } 7305be168c0dSopenharmony_ci- const float *backward_weight_i = weight_i_ptr_ + gate_num * lstm_param_->input_col_align_ * lstm_param_->input_size_; 7306be168c0dSopenharmony_ci- const float *backward_input_bias = input_bias_ + gate_num * lstm_param_->input_col_align_; 7307be168c0dSopenharmony_ci- ret = LstmPreProcessWithInput(backward_weight_i, backward_input_bias, buffer_backward_[input_gate_index]); 7308be168c0dSopenharmony_ci- if (ret != RET_OK) { 7309be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LSTM Backward: Input-MatMul running failed."; 7310be168c0dSopenharmony_ci- return RET_ERROR; 7311be168c0dSopenharmony_ci- } 7312be168c0dSopenharmony_ci- ret = ParallelLaunch(this->ms_context_, LstmSequenceLoopRun, this, C2NUM); 7313be168c0dSopenharmony_ci- if (ret != RET_OK) { 7314be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LSTM: Do sequence-loop failed."; 7315be168c0dSopenharmony_ci- } 7316be168c0dSopenharmony_ci- return ret; 7317be168c0dSopenharmony_ci-} 7318be168c0dSopenharmony_ci- 7319be168c0dSopenharmony_ci-int LstmCPUKernel::Run() { 7320be168c0dSopenharmony_ci- auto input = in_tensors_.at(0); 7321be168c0dSopenharmony_ci- auto output = out_tensors_.at(0); 7322be168c0dSopenharmony_ci- CHECK_NULL_RETURN(input); 7323be168c0dSopenharmony_ci- CHECK_NULL_RETURN(output); 7324be168c0dSopenharmony_ci- auto input_ptr = reinterpret_cast<float *>(input->data()); 7325be168c0dSopenharmony_ci- CHECK_NULL_RETURN(input_ptr); 7326be168c0dSopenharmony_ci- auto output_ptr = reinterpret_cast<float *>(output->data()); 7327be168c0dSopenharmony_ci- CHECK_NULL_RETURN(output_ptr); 
7328be168c0dSopenharmony_ci- 7329be168c0dSopenharmony_ci- auto hidden_state = in_tensors_.at(hidden_state_input_index_); 7330be168c0dSopenharmony_ci- CHECK_NULL_RETURN(hidden_state->data()); 7331be168c0dSopenharmony_ci- auto cell_state = in_tensors_.at(cell_state_input_index_); 7332be168c0dSopenharmony_ci- CHECK_NULL_RETURN(cell_state->data()); 7333be168c0dSopenharmony_ci- 7334be168c0dSopenharmony_ci- auto output_hidden_state = out_tensors_[kOutputHiddenStatusIndex]; 7335be168c0dSopenharmony_ci- CHECK_NULL_RETURN(output_hidden_state->data()); 7336be168c0dSopenharmony_ci- (void)memcpy(output_hidden_state->data(), hidden_state->data(), hidden_state->ElementsNum() * sizeof(float)); 7337be168c0dSopenharmony_ci- auto output_cell_state = out_tensors_[kOutputCellStatusIndex]; 7338be168c0dSopenharmony_ci- CHECK_NULL_RETURN(output_cell_state->data()); 7339be168c0dSopenharmony_ci- (void)memcpy(output_cell_state->data(), cell_state->data(), cell_state->ElementsNum() * sizeof(float)); 7340be168c0dSopenharmony_ci- 7341be168c0dSopenharmony_ci- auto ret = InitInputWeightBias(); 7342be168c0dSopenharmony_ci- if (ret != RET_OK) { 7343be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel InitInputWeightBias error."; 7344be168c0dSopenharmony_ci- FreeRunBuffer(); 7345be168c0dSopenharmony_ci- return RET_ERROR; 7346be168c0dSopenharmony_ci- } 7347be168c0dSopenharmony_ci- 7348be168c0dSopenharmony_ci- ret = InitStateWeightBias(); 7349be168c0dSopenharmony_ci- if (ret != RET_OK) { 7350be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel InitStateWeightBias error."; 7351be168c0dSopenharmony_ci- FreeRunBuffer(); 7352be168c0dSopenharmony_ci- return RET_ERROR; 7353be168c0dSopenharmony_ci+ kernel = new (std::nothrow) 7354be168c0dSopenharmony_ci+ LstmNonMindirFp32CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx)); 7355be168c0dSopenharmony_ci } 7356be168c0dSopenharmony_ci- 7357be168c0dSopenharmony_ci- ret = InitProjectWeight(); 7358be168c0dSopenharmony_ci- if 
(ret != RET_OK) { 7359be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel InitProjectWeight error."; 7360be168c0dSopenharmony_ci- FreeRunBuffer(); 7361be168c0dSopenharmony_ci- return RET_ERROR; 7362be168c0dSopenharmony_ci- } 7363be168c0dSopenharmony_ci- bool is_bidirectional_with_multi_thread = thread_num_ != 1 && lstm_param_->bidirectional_; 7364be168c0dSopenharmony_ci- ret = MallocRunBuffer(is_bidirectional_with_multi_thread); 7365be168c0dSopenharmony_ci- if (ret != RET_OK) { 7366be168c0dSopenharmony_ci- MS_LOG(ERROR) << "LstmCPUKernel MallocRunBuffer Error."; 7367be168c0dSopenharmony_ci- FreeRunBuffer(); 7368be168c0dSopenharmony_ci- return RET_ERROR; 7369be168c0dSopenharmony_ci- } 7370be168c0dSopenharmony_ci- 7371be168c0dSopenharmony_ci- PackLstmInput(input_ptr, packed_input_, lstm_param_->seq_len_ * lstm_param_->batch_, lstm_param_->input_size_); 7372be168c0dSopenharmony_ci- if (IsTrain() && IsTrainable()) { 7373be168c0dSopenharmony_ci- intermediate_states_ = reinterpret_cast<float *>(out_tensors_[out_intermediate_states_index]->data()); 7374be168c0dSopenharmony_ci+ if (kernel == nullptr) { 7375be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "kernel: " << parameter->name_ << "is nullptr."; 7376be168c0dSopenharmony_ci+ free(parameter); 7377be168c0dSopenharmony_ci+ return nullptr; 7378be168c0dSopenharmony_ci } 7379be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_h_ptr_); 7380be168c0dSopenharmony_ci- CHECK_NULL_RETURN(weight_i_ptr_); 7381be168c0dSopenharmony_ci- CHECK_NULL_RETURN(input_bias_); 7382be168c0dSopenharmony_ci- CHECK_NULL_RETURN(state_bias_); 7383be168c0dSopenharmony_ci- if (is_bidirectional_with_multi_thread) { 7384be168c0dSopenharmony_ci- ret = ExecuteBidirectionalWithMultiThread(); 7385be168c0dSopenharmony_ci- } else { 7386be168c0dSopenharmony_ci- ret = ExecuteUnidirectionalOrSingleThread(); 7387be168c0dSopenharmony_ci- } 7388be168c0dSopenharmony_ci- FreeRunBuffer(); 7389be168c0dSopenharmony_ci- return ret; 7390be168c0dSopenharmony_ci+ return kernel; 
7391be168c0dSopenharmony_ci } 7392be168c0dSopenharmony_ci- 7393be168c0dSopenharmony_ci-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTM, LiteKernelCreator<LstmCPUKernel>) 7394be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LSTM, LstmFp32KernelCreator) 7395be168c0dSopenharmony_ci } // namespace mindspore::kernel 7396be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc 7397be168c0dSopenharmony_cinew file mode 100644 7398be168c0dSopenharmony_ciindex 00000000..bd0f0e7d 7399be168c0dSopenharmony_ci--- /dev/null 7400be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.cc 7401be168c0dSopenharmony_ci@@ -0,0 +1,398 @@ 7402be168c0dSopenharmony_ci+/** 7403be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 7404be168c0dSopenharmony_ci+ * 7405be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 7406be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 7407be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 7408be168c0dSopenharmony_ci+ * 7409be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 7410be168c0dSopenharmony_ci+ * 7411be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 7412be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 7413be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7414be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 7415be168c0dSopenharmony_ci+ * limitations under the License. 
7416be168c0dSopenharmony_ci+ */ 7417be168c0dSopenharmony_ci+ 7418be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp32/lstm_fp32_base.h" 7419be168c0dSopenharmony_ci+#include <vector> 7420be168c0dSopenharmony_ci+#include "include/errorcode.h" 7421be168c0dSopenharmony_ci+#include "nnacl/fp32/pack_fp32.h" 7422be168c0dSopenharmony_ci+#include "nnacl/fp32/matmul_fp32.h" 7423be168c0dSopenharmony_ci+ 7424be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR; 7425be168c0dSopenharmony_ci+using mindspore::lite::RET_MEMORY_FAILED; 7426be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 7427be168c0dSopenharmony_ci+ 7428be168c0dSopenharmony_ci+namespace mindspore::kernel { 7429be168c0dSopenharmony_ci+namespace { 7430be168c0dSopenharmony_ci+constexpr size_t kMindirInputTensorNum = 4; 7431be168c0dSopenharmony_ci+constexpr int kGateNum = 4; 7432be168c0dSopenharmony_ci+constexpr int kOutIntermediateStatesIndex = 3; 7433be168c0dSopenharmony_ci+constexpr int kInputGateIndex = 0; 7434be168c0dSopenharmony_ci+} // namespace 7435be168c0dSopenharmony_ci+ 7436be168c0dSopenharmony_ci+int LstmSequenceLoopRun(void *cdata, int task_id, float, float) { 7437be168c0dSopenharmony_ci+ auto kernel = reinterpret_cast<LstmFp32BaseCPUKernel *>(cdata); 7438be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(kernel); 7439be168c0dSopenharmony_ci+ auto ret = kernel->DoSequenceLoop(task_id); 7440be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7441be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LSTM: Do Sequence-loop failed."; 7442be168c0dSopenharmony_ci+ } 7443be168c0dSopenharmony_ci+ return ret; 7444be168c0dSopenharmony_ci+} 7445be168c0dSopenharmony_ci+ 7446be168c0dSopenharmony_ci+int LstmFp32BaseCPUKernel::Prepare() { 7447be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in_tensors_.size() == kMindirInputTensorNum || in_tensors_.size() >= C6NUM, 7448be168c0dSopenharmony_ci+ lite::RET_INPUT_TENSOR_ERROR, "Lstm's input-num is invalid."); 7449be168c0dSopenharmony_ci+ for (size_t i = 0; i < in_tensors_.size(); i++) { 
7450be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(in_tensors_.at(i)); 7451be168c0dSopenharmony_ci+ } 7452be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(out_tensors_.size(), DIMENSION_3D); 7453be168c0dSopenharmony_ci+ for (size_t i = 0; i < out_tensors_.size(); i++) { 7454be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(out_tensors_.at(i)); 7455be168c0dSopenharmony_ci+ } 7456be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(lstm_param_); 7457be168c0dSopenharmony_ci+ if (!InferShapeDone()) { 7458be168c0dSopenharmony_ci+ return RET_OK; 7459be168c0dSopenharmony_ci+ } 7460be168c0dSopenharmony_ci+ return ReSize(); 7461be168c0dSopenharmony_ci+} 7462be168c0dSopenharmony_ci+ 7463be168c0dSopenharmony_ci+int LstmFp32BaseCPUKernel::ReSize() { 7464be168c0dSopenharmony_ci+ auto input = in_tensors_.front(); 7465be168c0dSopenharmony_ci+ std::vector<int> in_shape = input->shape(); 7466be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in_shape.size() == C3NUM, lite::RET_INPUT_TENSOR_ERROR, 7467be168c0dSopenharmony_ci+ "The dims of LSTM's first input must be 3."); 7468be168c0dSopenharmony_ci+ lstm_param_->seq_len_ = in_shape.at(FIRST_INPUT); 7469be168c0dSopenharmony_ci+ lstm_param_->batch_ = in_shape.at(SECOND_INPUT); 7470be168c0dSopenharmony_ci+ lstm_param_->input_size_ = in_shape.at(THIRD_INPUT); 7471be168c0dSopenharmony_ci+ 7472be168c0dSopenharmony_ci+ auto h_init_shape = in_tensors_.at(hidden_init_index_)->shape(); 7473be168c0dSopenharmony_ci+ auto c_init_shape = in_tensors_.at(cell_init_index_)->shape(); 7474be168c0dSopenharmony_ci+ lstm_param_->hidden_size_ = c_init_shape.back(); 7475be168c0dSopenharmony_ci+ lstm_param_->output_size_ = h_init_shape.back(); 7476be168c0dSopenharmony_ci+ 7477be168c0dSopenharmony_ci+ lstm_param_->output_step_ = lstm_param_->bidirectional_ ? C2NUM * lstm_param_->batch_ * lstm_param_->output_size_ 7478be168c0dSopenharmony_ci+ : lstm_param_->batch_ * lstm_param_->output_size_; 7479be168c0dSopenharmony_ci+ weight_segment_num_ = lstm_param_->bidirectional_ ? 
C2NUM * kGateNum : kGateNum; 7480be168c0dSopenharmony_ci+ 7481be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 7482be168c0dSopenharmony_ci+ row_tile_ = C6NUM; 7483be168c0dSopenharmony_ci+ col_tile_ = C16NUM; 7484be168c0dSopenharmony_ci+#elif defined(ENABLE_ARM32) 7485be168c0dSopenharmony_ci+ row_tile_ = C12NUM; 7486be168c0dSopenharmony_ci+ col_tile_ = C4NUM; 7487be168c0dSopenharmony_ci+#elif defined(ENABLE_SSE) 7488be168c0dSopenharmony_ci+ row_tile_ = C4NUM; 7489be168c0dSopenharmony_ci+ col_tile_ = C8NUM; 7490be168c0dSopenharmony_ci+#else 7491be168c0dSopenharmony_ci+ row_tile_ = C12NUM; 7492be168c0dSopenharmony_ci+ col_tile_ = C8NUM; 7493be168c0dSopenharmony_ci+#endif 7494be168c0dSopenharmony_ci+ lstm_param_->input_row_align_ = UP_ROUND(lstm_param_->seq_len_ * lstm_param_->batch_, row_tile_); 7495be168c0dSopenharmony_ci+ lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, col_tile_); 7496be168c0dSopenharmony_ci+ 7497be168c0dSopenharmony_ci+ state_row_tile_ = row_tile_; 7498be168c0dSopenharmony_ci+ state_col_tile_ = col_tile_; 7499be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 7500be168c0dSopenharmony_ci+ if (lstm_param_->batch_ == 1) { 7501be168c0dSopenharmony_ci+ state_row_tile_ = 1; 7502be168c0dSopenharmony_ci+ state_col_tile_ = C8NUM; 7503be168c0dSopenharmony_ci+ } 7504be168c0dSopenharmony_ci+#endif 7505be168c0dSopenharmony_ci+ 7506be168c0dSopenharmony_ci+ lstm_param_->state_row_align_ = lstm_param_->batch_ == 1 ? 1 : UP_ROUND(lstm_param_->batch_, state_row_tile_); 7507be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 7508be168c0dSopenharmony_ci+ lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, state_col_tile_); 7509be168c0dSopenharmony_ci+ lstm_param_->proj_col_align_ = UP_ROUND(lstm_param_->output_size_, state_col_tile_); 7510be168c0dSopenharmony_ci+#else 7511be168c0dSopenharmony_ci+ lstm_param_->state_col_align_ = 7512be168c0dSopenharmony_ci+ lstm_param_->batch_ == 1 ? 
lstm_param_->hidden_size_ : UP_ROUND(lstm_param_->hidden_size_, state_col_tile_); 7513be168c0dSopenharmony_ci+ lstm_param_->proj_col_align_ = 7514be168c0dSopenharmony_ci+ lstm_param_->batch_ == 1 ? lstm_param_->output_size_ : UP_ROUND(lstm_param_->output_size_, state_col_tile_); 7515be168c0dSopenharmony_ci+#endif 7516be168c0dSopenharmony_ci+ return RET_OK; 7517be168c0dSopenharmony_ci+} 7518be168c0dSopenharmony_ci+ 7519be168c0dSopenharmony_ci+int LstmFp32BaseCPUKernel::Run() { 7520be168c0dSopenharmony_ci+ auto input = in_tensors_.at(FIRST_INPUT); 7521be168c0dSopenharmony_ci+ auto output = out_tensors_.at(FIRST_INPUT); 7522be168c0dSopenharmony_ci+ auto input_ptr = reinterpret_cast<float *>(input->data()); 7523be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(input_ptr); 7524be168c0dSopenharmony_ci+ auto output_ptr = reinterpret_cast<float *>(output->data()); 7525be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(output_ptr); 7526be168c0dSopenharmony_ci+ 7527be168c0dSopenharmony_ci+ auto hidden_state = in_tensors_.at(hidden_init_index_); 7528be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(hidden_state->data()); 7529be168c0dSopenharmony_ci+ auto cell_state = in_tensors_.at(cell_init_index_); 7530be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(cell_state->data()); 7531be168c0dSopenharmony_ci+ 7532be168c0dSopenharmony_ci+ auto output_hidden_state = out_tensors_[SECOND_INPUT]; 7533be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(output_hidden_state->data()); 7534be168c0dSopenharmony_ci+ (void)memcpy(output_hidden_state->data(), hidden_state->data(), hidden_state->ElementsNum() * sizeof(float)); 7535be168c0dSopenharmony_ci+ auto output_cell_state = out_tensors_[THIRD_INPUT]; 7536be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(output_cell_state->data()); 7537be168c0dSopenharmony_ci+ (void)memcpy(output_cell_state->data(), cell_state->data(), cell_state->ElementsNum() * sizeof(float)); 7538be168c0dSopenharmony_ci+ 7539be168c0dSopenharmony_ci+ auto ret = InitInputWeightBias(); 7540be168c0dSopenharmony_ci+ if 
(ret != RET_OK) { 7541be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmCPUKernel InitInputWeightBias error."; 7542be168c0dSopenharmony_ci+ FreeRunBuffer(); 7543be168c0dSopenharmony_ci+ return RET_ERROR; 7544be168c0dSopenharmony_ci+ } 7545be168c0dSopenharmony_ci+ 7546be168c0dSopenharmony_ci+ ret = InitStateWeightBias(); 7547be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7548be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmCPUKernel InitStateWeightBias error."; 7549be168c0dSopenharmony_ci+ FreeRunBuffer(); 7550be168c0dSopenharmony_ci+ return RET_ERROR; 7551be168c0dSopenharmony_ci+ } 7552be168c0dSopenharmony_ci+ 7553be168c0dSopenharmony_ci+ ret = InitProjectWeight(); 7554be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7555be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmCPUKernel InitProjectWeight error."; 7556be168c0dSopenharmony_ci+ FreeRunBuffer(); 7557be168c0dSopenharmony_ci+ return RET_ERROR; 7558be168c0dSopenharmony_ci+ } 7559be168c0dSopenharmony_ci+ bool is_bidirectional_with_multi_thread = thread_num_ != 1 && lstm_param_->bidirectional_; 7560be168c0dSopenharmony_ci+ ret = MallocRunBuffer(is_bidirectional_with_multi_thread); 7561be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7562be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmCPUKernel MallocRunBuffer Error."; 7563be168c0dSopenharmony_ci+ FreeRunBuffer(); 7564be168c0dSopenharmony_ci+ return RET_ERROR; 7565be168c0dSopenharmony_ci+ } 7566be168c0dSopenharmony_ci+ 7567be168c0dSopenharmony_ci+ PackLstmInput(input_ptr, packed_input_, lstm_param_->seq_len_ * lstm_param_->batch_, lstm_param_->input_size_); 7568be168c0dSopenharmony_ci+ if (IsTrain() && IsTrainable()) { 7569be168c0dSopenharmony_ci+ intermediate_states_ = reinterpret_cast<float *>(out_tensors_[kOutIntermediateStatesIndex]->data()); 7570be168c0dSopenharmony_ci+ } 7571be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_h_ptr_); 7572be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_i_ptr_); 7573be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(input_bias_); 
7574be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(state_bias_); 7575be168c0dSopenharmony_ci+ if (is_bidirectional_with_multi_thread) { 7576be168c0dSopenharmony_ci+ ret = ExecuteBidirectionalWithMultiThread(); 7577be168c0dSopenharmony_ci+ } else { 7578be168c0dSopenharmony_ci+ ret = ExecuteUnidirectionalOrSingleThread(); 7579be168c0dSopenharmony_ci+ } 7580be168c0dSopenharmony_ci+ FreeRunBuffer(); 7581be168c0dSopenharmony_ci+ return ret; 7582be168c0dSopenharmony_ci+} 7583be168c0dSopenharmony_ci+ 7584be168c0dSopenharmony_ci+void LstmFp32BaseCPUKernel::FreeRunBuffer() { 7585be168c0dSopenharmony_ci+ for (auto data : running_buffer_) { 7586be168c0dSopenharmony_ci+ ms_context_->allocator->Free(data); 7587be168c0dSopenharmony_ci+ } 7588be168c0dSopenharmony_ci+ running_buffer_.clear(); 7589be168c0dSopenharmony_ci+} 7590be168c0dSopenharmony_ci+ 7591be168c0dSopenharmony_ci+int LstmFp32BaseCPUKernel::MallocRunBuffer(bool is_double) { 7592be168c0dSopenharmony_ci+ bool need_zone = lstm_param_->zoneout_cell_ < -FLT_EPSILON || lstm_param_->zoneout_cell_ > FLT_EPSILON; 7593be168c0dSopenharmony_ci+ size_t whole_size = 0; 7594be168c0dSopenharmony_ci+ std::vector<size_t> segments; 7595be168c0dSopenharmony_ci+ int scale = is_double ? C2NUM : 1; 7596be168c0dSopenharmony_ci+ size_t segment = kGateNum * lstm_param_->seq_len_ * lstm_param_->batch_ * 7597be168c0dSopenharmony_ci+ lstm_param_->hidden_size_; // 0: input * weight for result matrix 7598be168c0dSopenharmony_ci+ segments.push_back(segment); 7599be168c0dSopenharmony_ci+ whole_size += segment * scale; 7600be168c0dSopenharmony_ci+ 7601be168c0dSopenharmony_ci+ segment = lstm_param_->batch_ == 1 7602be168c0dSopenharmony_ci+ ? 
0 7603be168c0dSopenharmony_ci+ : lstm_param_->state_row_align_ * lstm_param_->output_size_; // 1: state * weight for left matirx 7604be168c0dSopenharmony_ci+ segments.push_back(segment); 7605be168c0dSopenharmony_ci+ whole_size += segment * scale; 7606be168c0dSopenharmony_ci+ 7607be168c0dSopenharmony_ci+ segment = kGateNum * lstm_param_->batch_ * lstm_param_->hidden_size_; // 2: state gate buffer 7608be168c0dSopenharmony_ci+ segments.push_back(segment); 7609be168c0dSopenharmony_ci+ whole_size += segment * scale; 7610be168c0dSopenharmony_ci+ 7611be168c0dSopenharmony_ci+ segment = need_zone ? lstm_param_->batch_ * lstm_param_->hidden_size_ : 0; // 3: state_buffer for cell 7612be168c0dSopenharmony_ci+ segments.push_back(segment); 7613be168c0dSopenharmony_ci+ whole_size += segment * scale; 7614be168c0dSopenharmony_ci+ 7615be168c0dSopenharmony_ci+ segment = need_zone ? lstm_param_->batch_ * lstm_param_->output_size_ : 0; // 4: state_buffer for hidden 7616be168c0dSopenharmony_ci+ segments.push_back(segment); 7617be168c0dSopenharmony_ci+ whole_size += segment * scale; 7618be168c0dSopenharmony_ci+ 7619be168c0dSopenharmony_ci+ segment = 0; 7620be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 7621be168c0dSopenharmony_ci+ bool output_need_packed = lstm_param_->hidden_size_ % state_col_tile_; 7622be168c0dSopenharmony_ci+ if (lstm_param_->batch_ == 1 && output_need_packed) { // vec matmul need to malloc dst 7623be168c0dSopenharmony_ci+ int out_channel = lstm_param_->hidden_size_; 7624be168c0dSopenharmony_ci+ int oc_block_num = UP_DIV(out_channel, state_col_tile_); 7625be168c0dSopenharmony_ci+ MS_ASSERT(ms_context_->allocator != nullptr); 7626be168c0dSopenharmony_ci+ segment = lstm_param_->batch_ * oc_block_num * state_col_tile_; // 5: tmp output data 7627be168c0dSopenharmony_ci+ } 7628be168c0dSopenharmony_ci+#endif 7629be168c0dSopenharmony_ci+ segments.push_back(segment); 7630be168c0dSopenharmony_ci+ whole_size += segment * scale; 7631be168c0dSopenharmony_ci+ 
7632be168c0dSopenharmony_ci+ if (in_tensors_.size() == C7NUM || lstm_param_->project_size_ != 0) { 7633be168c0dSopenharmony_ci+ segment = lstm_param_->batch_ == 1 ? 0 : lstm_param_->state_row_align_ * lstm_param_->hidden_size_ * scale; 7634be168c0dSopenharmony_ci+ segments.push_back(segment); // 6: project-layer input 7635be168c0dSopenharmony_ci+ whole_size += segment; 7636be168c0dSopenharmony_ci+ segment = 0; 7637be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 7638be168c0dSopenharmony_ci+ segment = 7639be168c0dSopenharmony_ci+ output_need_packed ? lstm_param_->batch_ * UP_ROUND(lstm_param_->output_size_, state_col_tile_) * scale : 0; 7640be168c0dSopenharmony_ci+#endif 7641be168c0dSopenharmony_ci+ segments.push_back(segment); // 7: project-layer output 7642be168c0dSopenharmony_ci+ whole_size += segment; 7643be168c0dSopenharmony_ci+ } else { 7644be168c0dSopenharmony_ci+ (void)segments.insert(segments.end(), C2NUM, 0); 7645be168c0dSopenharmony_ci+ } 7646be168c0dSopenharmony_ci+ 7647be168c0dSopenharmony_ci+ segment = 0; 7648be168c0dSopenharmony_ci+ if (in_tensors_.size() == kMindirInputTensorNum) { 7649be168c0dSopenharmony_ci+ segment = lstm_param_->batch_ * lstm_param_->output_size_; 7650be168c0dSopenharmony_ci+ } 7651be168c0dSopenharmony_ci+ segments.push_back(segment); 7652be168c0dSopenharmony_ci+ whole_size += segment * scale; 7653be168c0dSopenharmony_ci+ 7654be168c0dSopenharmony_ci+ segment = 7655be168c0dSopenharmony_ci+ lstm_param_->input_row_align_ * lstm_param_->input_size_; // input * weight for left matrix, which only once 7656be168c0dSopenharmony_ci+ whole_size += segment; 7657be168c0dSopenharmony_ci+ 7658be168c0dSopenharmony_ci+ auto whole_memory = reinterpret_cast<float *>(ms_context_->allocator->Malloc(whole_size * sizeof(float))); 7659be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(whole_memory != nullptr, RET_ERROR, "LSTM: malloc failed."); 7660be168c0dSopenharmony_ci+ running_buffer_.push_back(whole_memory); 7661be168c0dSopenharmony_ci+ MS_ASSERT(segments.size() 
== C9NUM); 7662be168c0dSopenharmony_ci+ auto Allocate = [&whole_memory, &segments](float **buffer) mutable { 7663be168c0dSopenharmony_ci+ for (int i = 0; i < C9NUM; ++i) { 7664be168c0dSopenharmony_ci+ buffer[i] = nullptr; 7665be168c0dSopenharmony_ci+ if (segments[i] == 0) { 7666be168c0dSopenharmony_ci+ continue; 7667be168c0dSopenharmony_ci+ } 7668be168c0dSopenharmony_ci+ buffer[i] = whole_memory; 7669be168c0dSopenharmony_ci+ whole_memory += segments[i]; 7670be168c0dSopenharmony_ci+ } 7671be168c0dSopenharmony_ci+ }; 7672be168c0dSopenharmony_ci+ Allocate(buffer_forward_); 7673be168c0dSopenharmony_ci+ if (is_double) { 7674be168c0dSopenharmony_ci+ Allocate(buffer_backward_); 7675be168c0dSopenharmony_ci+ } 7676be168c0dSopenharmony_ci+ packed_input_ = whole_memory; 7677be168c0dSopenharmony_ci+ return RET_OK; 7678be168c0dSopenharmony_ci+} 7679be168c0dSopenharmony_ci+ 7680be168c0dSopenharmony_ci+int LstmFp32BaseCPUKernel::ExecuteBidirectionalWithMultiThread() { 7681be168c0dSopenharmony_ci+ auto ret = LstmPreProcessWithInput(weight_i_ptr_, input_bias_, buffer_forward_[kInputGateIndex]); 7682be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7683be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LSTM Forward: Input-MatMul running failed."; 7684be168c0dSopenharmony_ci+ return RET_ERROR; 7685be168c0dSopenharmony_ci+ } 7686be168c0dSopenharmony_ci+ const float *backward_weight_i = weight_i_ptr_ + kGateNum * lstm_param_->input_col_align_ * lstm_param_->input_size_; 7687be168c0dSopenharmony_ci+ const float *backward_input_bias = input_bias_ + kGateNum * lstm_param_->input_col_align_; 7688be168c0dSopenharmony_ci+ ret = LstmPreProcessWithInput(backward_weight_i, backward_input_bias, buffer_backward_[kInputGateIndex]); 7689be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7690be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LSTM Backward: Input-MatMul running failed."; 7691be168c0dSopenharmony_ci+ return RET_ERROR; 7692be168c0dSopenharmony_ci+ } 7693be168c0dSopenharmony_ci+ ret = 
ParallelLaunch(this->ms_context_, LstmSequenceLoopRun, this, C2NUM); 7694be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7695be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LSTM: Do sequence-loop failed."; 7696be168c0dSopenharmony_ci+ } 7697be168c0dSopenharmony_ci+ return ret; 7698be168c0dSopenharmony_ci+} 7699be168c0dSopenharmony_ci+ 7700be168c0dSopenharmony_ci+int LstmFp32BaseCPUKernel::ExecuteUnidirectionalOrSingleThread() { 7701be168c0dSopenharmony_ci+ auto ret = LstmPreProcessWithInput(weight_i_ptr_, input_bias_, buffer_forward_[kInputGateIndex]); 7702be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7703be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LSTM Forward: Input-MatMul running failed."; 7704be168c0dSopenharmony_ci+ return RET_ERROR; 7705be168c0dSopenharmony_ci+ } 7706be168c0dSopenharmony_ci+ LstmForwardLoop(buffer_forward_); 7707be168c0dSopenharmony_ci+ 7708be168c0dSopenharmony_ci+ // backward 7709be168c0dSopenharmony_ci+ if (lstm_param_->bidirectional_) { 7710be168c0dSopenharmony_ci+ const float *backward_weight_i = 7711be168c0dSopenharmony_ci+ weight_i_ptr_ + kGateNum * lstm_param_->input_col_align_ * lstm_param_->input_size_; 7712be168c0dSopenharmony_ci+ const float *backward_input_bias = input_bias_ + kGateNum * lstm_param_->input_col_align_; 7713be168c0dSopenharmony_ci+ ret = LstmPreProcessWithInput(backward_weight_i, backward_input_bias, buffer_forward_[kInputGateIndex]); 7714be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7715be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LSTM Backward: Input-MatMul running failed."; 7716be168c0dSopenharmony_ci+ return RET_ERROR; 7717be168c0dSopenharmony_ci+ } 7718be168c0dSopenharmony_ci+ LstmBackwardLoop(buffer_forward_); 7719be168c0dSopenharmony_ci+ } 7720be168c0dSopenharmony_ci+ return RET_OK; 7721be168c0dSopenharmony_ci+} 7722be168c0dSopenharmony_ci+ 7723be168c0dSopenharmony_ci+int LstmFp32BaseCPUKernel::LstmPreProcessWithInput(const float *weight_i, const float *input_bias, float *dst) { 7724be168c0dSopenharmony_ci+ const float 
*weight{nullptr}; 7725be168c0dSopenharmony_ci+ const float *bias{nullptr}; 7726be168c0dSopenharmony_ci+ float *gate{nullptr}; 7727be168c0dSopenharmony_ci+ int thread_num = MSMIN(op_parameter_->thread_num_, UP_DIV(lstm_param_->input_col_align_, col_tile_)); 7728be168c0dSopenharmony_ci+ MS_CHECK_FALSE(thread_num == 0, RET_ERROR); 7729be168c0dSopenharmony_ci+ int stride = UP_DIV(UP_DIV(lstm_param_->input_col_align_, col_tile_), thread_num); 7730be168c0dSopenharmony_ci+ auto MatMulCoreFunc = [this, &weight, &bias, &gate, &stride](void *, int task_id, float, float) { 7731be168c0dSopenharmony_ci+ int current_start_oc = task_id * stride * col_tile_; 7732be168c0dSopenharmony_ci+ int current_rest_oc = 0; 7733be168c0dSopenharmony_ci+ current_rest_oc = lstm_param_->hidden_size_ - current_start_oc; 7734be168c0dSopenharmony_ci+ int cur_oc = MSMIN(stride * col_tile_, current_rest_oc); 7735be168c0dSopenharmony_ci+ if (cur_oc <= 0) { 7736be168c0dSopenharmony_ci+ return RET_OK; 7737be168c0dSopenharmony_ci+ } 7738be168c0dSopenharmony_ci+ 7739be168c0dSopenharmony_ci+ auto b = weight + current_start_oc * lstm_param_->input_size_; 7740be168c0dSopenharmony_ci+ auto c = gate + current_start_oc; 7741be168c0dSopenharmony_ci+ auto bias_ = (bias == nullptr) ? 
nullptr : bias + current_start_oc; 7742be168c0dSopenharmony_ci+ MatMulOpt(packed_input_, b, c, bias_, ActType_No, lstm_param_->input_size_, 7743be168c0dSopenharmony_ci+ lstm_param_->seq_len_ * lstm_param_->batch_, cur_oc, lstm_param_->hidden_size_, OutType_Nhwc); 7744be168c0dSopenharmony_ci+ return RET_OK; 7745be168c0dSopenharmony_ci+ }; 7746be168c0dSopenharmony_ci+ for (int i = 0; i < kGateNum; i++) { 7747be168c0dSopenharmony_ci+ weight = weight_i + lstm_param_->input_size_ * lstm_param_->input_col_align_ * i; 7748be168c0dSopenharmony_ci+ bias = input_bias + lstm_param_->input_col_align_ * i; 7749be168c0dSopenharmony_ci+ gate = dst + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * i; 7750be168c0dSopenharmony_ci+ auto ret = ParallelLaunch(this->ms_context_, MatMulCoreFunc, nullptr, thread_num); 7751be168c0dSopenharmony_ci+ if (ret != RET_OK) { 7752be168c0dSopenharmony_ci+ return RET_ERROR; 7753be168c0dSopenharmony_ci+ } 7754be168c0dSopenharmony_ci+ } 7755be168c0dSopenharmony_ci+ return RET_OK; 7756be168c0dSopenharmony_ci+} 7757be168c0dSopenharmony_ci+ 7758be168c0dSopenharmony_ci+int LstmFp32BaseCPUKernel::DoSequenceLoop(int task_id) { 7759be168c0dSopenharmony_ci+ if (task_id == 0) { 7760be168c0dSopenharmony_ci+ LstmForwardLoop(buffer_forward_); 7761be168c0dSopenharmony_ci+ return RET_OK; 7762be168c0dSopenharmony_ci+ } 7763be168c0dSopenharmony_ci+ if (task_id == 1) { 7764be168c0dSopenharmony_ci+ LstmBackwardLoop(buffer_backward_); 7765be168c0dSopenharmony_ci+ return RET_OK; 7766be168c0dSopenharmony_ci+ } 7767be168c0dSopenharmony_ci+ return RET_ERROR; 7768be168c0dSopenharmony_ci+} 7769be168c0dSopenharmony_ci+ 7770be168c0dSopenharmony_ci+void LstmFp32BaseCPUKernel::LstmForwardLoop(float *buffer[]) { 7771be168c0dSopenharmony_ci+ auto *output = reinterpret_cast<float *>(out_tensors_.at(FIRST_INPUT)->data()); 7772be168c0dSopenharmony_ci+ auto *hidden_state = reinterpret_cast<float *>(out_tensors_.at(SECOND_INPUT)->data()); 
7773be168c0dSopenharmony_ci+ auto *cell_state = reinterpret_cast<float *>(out_tensors_.at(THIRD_INPUT)->data()); 7774be168c0dSopenharmony_ci+ LstmUnidirectional(output, weight_h_ptr_, state_bias_, hidden_state, cell_state, weight_project_ptr_, 7775be168c0dSopenharmony_ci+ intermediate_states_, buffer, false); 7776be168c0dSopenharmony_ci+} 7777be168c0dSopenharmony_ci+ 7778be168c0dSopenharmony_ci+void LstmFp32BaseCPUKernel::LstmBackwardLoop(float *buffer[]) { 7779be168c0dSopenharmony_ci+ auto *output = reinterpret_cast<float *>(out_tensors_.at(0)->data()); 7780be168c0dSopenharmony_ci+ auto *hidden_state = reinterpret_cast<float *>(out_tensors_.at(1)->data()); 7781be168c0dSopenharmony_ci+ auto *cell_state = reinterpret_cast<float *>(out_tensors_.at(C2NUM)->data()); 7782be168c0dSopenharmony_ci+ const float *backward_weight_h = weight_h_ptr_ + kGateNum * lstm_param_->state_col_align_ * lstm_param_->output_size_; 7783be168c0dSopenharmony_ci+ const float *backward_state_bias = state_bias_ + kGateNum * lstm_param_->state_col_align_; 7784be168c0dSopenharmony_ci+ float *backward_output = output + lstm_param_->batch_ * lstm_param_->output_size_; 7785be168c0dSopenharmony_ci+ if (in_tensors_.size() == kMindirInputTensorNum) { 7786be168c0dSopenharmony_ci+ backward_output = output + lstm_param_->output_size_; 7787be168c0dSopenharmony_ci+ } 7788be168c0dSopenharmony_ci+ float *backward_cell_state = cell_state + lstm_param_->batch_ * lstm_param_->hidden_size_; 7789be168c0dSopenharmony_ci+ float *backward_hidden_state = hidden_state + lstm_param_->batch_ * lstm_param_->output_size_; 7790be168c0dSopenharmony_ci+ float *intermediate_states = nullptr; 7791be168c0dSopenharmony_ci+ if (intermediate_states_) { 7792be168c0dSopenharmony_ci+ intermediate_states = intermediate_states_ + lstm_param_->batch_ * lstm_param_->output_size_; 7793be168c0dSopenharmony_ci+ } 7794be168c0dSopenharmony_ci+ float *backward_weight_project = 7795be168c0dSopenharmony_ci+ weight_project_ptr_ ? 
weight_project_ptr_ + lstm_param_->hidden_size_ * lstm_param_->proj_col_align_ : nullptr; 7796be168c0dSopenharmony_ci+ LstmUnidirectional(backward_output, backward_weight_h, backward_state_bias, backward_hidden_state, 7797be168c0dSopenharmony_ci+ backward_cell_state, backward_weight_project, intermediate_states, buffer, true); 7798be168c0dSopenharmony_ci+} 7799be168c0dSopenharmony_ci+} // namespace mindspore::kernel 7800be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h 7801be168c0dSopenharmony_cinew file mode 100644 7802be168c0dSopenharmony_ciindex 00000000..c3c10cea 7803be168c0dSopenharmony_ci--- /dev/null 7804be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_fp32_base.h 7805be168c0dSopenharmony_ci@@ -0,0 +1,78 @@ 7806be168c0dSopenharmony_ci+/** 7807be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 7808be168c0dSopenharmony_ci+ * 7809be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 7810be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 7811be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 7812be168c0dSopenharmony_ci+ * 7813be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 7814be168c0dSopenharmony_ci+ * 7815be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 7816be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 7817be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7818be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 7819be168c0dSopenharmony_ci+ * limitations under the License. 
7820be168c0dSopenharmony_ci+ */ 7821be168c0dSopenharmony_ci+ 7822be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_FP32_BASE_H_ 7823be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_FP32_BASE_H_ 7824be168c0dSopenharmony_ci+ 7825be168c0dSopenharmony_ci+#include <vector> 7826be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h" 7827be168c0dSopenharmony_ci+#include "nnacl/fp32/lstm_fp32.h" 7828be168c0dSopenharmony_ci+ 7829be168c0dSopenharmony_ci+namespace mindspore::kernel { 7830be168c0dSopenharmony_ci+class LstmFp32BaseCPUKernel : public LiteKernel { 7831be168c0dSopenharmony_ci+ public: 7832be168c0dSopenharmony_ci+ LstmFp32BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 7833be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 7834be168c0dSopenharmony_ci+ : LiteKernel(parameter, inputs, outputs, ctx) { 7835be168c0dSopenharmony_ci+ lstm_param_ = reinterpret_cast<LstmParameter *>(op_parameter_); 7836be168c0dSopenharmony_ci+ } 7837be168c0dSopenharmony_ci+ 7838be168c0dSopenharmony_ci+ ~LstmFp32BaseCPUKernel() override = default; 7839be168c0dSopenharmony_ci+ 7840be168c0dSopenharmony_ci+ int Prepare() override; 7841be168c0dSopenharmony_ci+ int ReSize() override; 7842be168c0dSopenharmony_ci+ int Run() override; 7843be168c0dSopenharmony_ci+ int DoSequenceLoop(int task_id); 7844be168c0dSopenharmony_ci+ 7845be168c0dSopenharmony_ci+ protected: 7846be168c0dSopenharmony_ci+ virtual int InitInputWeightBias() = 0; 7847be168c0dSopenharmony_ci+ virtual int InitStateWeightBias() = 0; 7848be168c0dSopenharmony_ci+ virtual int InitProjectWeight() = 0; 7849be168c0dSopenharmony_ci+ virtual void LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, float *hidden_state, 7850be168c0dSopenharmony_ci+ float *cell_state, const float *weight_project, float *intermediate_states, 7851be168c0dSopenharmony_ci+ float *buffer[], 
bool is_backward) = 0; 7852be168c0dSopenharmony_ci+ 7853be168c0dSopenharmony_ci+ int hidden_init_index_{0}; 7854be168c0dSopenharmony_ci+ int cell_init_index_{0}; 7855be168c0dSopenharmony_ci+ int row_tile_{0}; 7856be168c0dSopenharmony_ci+ int col_tile_{0}; 7857be168c0dSopenharmony_ci+ int state_row_tile_{0}; 7858be168c0dSopenharmony_ci+ int state_col_tile_{0}; 7859be168c0dSopenharmony_ci+ int weight_segment_num_{0}; 7860be168c0dSopenharmony_ci+ float *weight_i_ptr_{nullptr}; 7861be168c0dSopenharmony_ci+ float *weight_h_ptr_{nullptr}; 7862be168c0dSopenharmony_ci+ float *weight_project_ptr_{nullptr}; 7863be168c0dSopenharmony_ci+ float *input_bias_{nullptr}; 7864be168c0dSopenharmony_ci+ float *state_bias_{nullptr}; 7865be168c0dSopenharmony_ci+ LstmParameter *lstm_param_{nullptr}; 7866be168c0dSopenharmony_ci+ std::vector<void *> running_buffer_; 7867be168c0dSopenharmony_ci+ 7868be168c0dSopenharmony_ci+ private: 7869be168c0dSopenharmony_ci+ void FreeRunBuffer(); 7870be168c0dSopenharmony_ci+ int MallocRunBuffer(bool is_double); 7871be168c0dSopenharmony_ci+ int ExecuteBidirectionalWithMultiThread(); 7872be168c0dSopenharmony_ci+ int ExecuteUnidirectionalOrSingleThread(); 7873be168c0dSopenharmony_ci+ int LstmPreProcessWithInput(const float *weight_i, const float *input_bias, float *dst); 7874be168c0dSopenharmony_ci+ void LstmForwardLoop(float *buffer[]); 7875be168c0dSopenharmony_ci+ void LstmBackwardLoop(float *buffer[]); 7876be168c0dSopenharmony_ci+ float *packed_input_{nullptr}; 7877be168c0dSopenharmony_ci+ float *intermediate_states_{nullptr}; 7878be168c0dSopenharmony_ci+ float *buffer_forward_[C9NUM] = {nullptr}; 7879be168c0dSopenharmony_ci+ float *buffer_backward_[C9NUM] = {nullptr}; 7880be168c0dSopenharmony_ci+}; 7881be168c0dSopenharmony_ci+} // namespace mindspore::kernel 7882be168c0dSopenharmony_ci+ 7883be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_FP32_BASE_H_ 7884be168c0dSopenharmony_cidiff --git 
a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc 7885be168c0dSopenharmony_cinew file mode 100644 7886be168c0dSopenharmony_ciindex 00000000..476d5940 7887be168c0dSopenharmony_ci--- /dev/null 7888be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.cc 7889be168c0dSopenharmony_ci@@ -0,0 +1,266 @@ 7890be168c0dSopenharmony_ci+/** 7891be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 7892be168c0dSopenharmony_ci+ * 7893be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 7894be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 7895be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 7896be168c0dSopenharmony_ci+ * 7897be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 7898be168c0dSopenharmony_ci+ * 7899be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 7900be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 7901be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7902be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 7903be168c0dSopenharmony_ci+ * limitations under the License. 
7904be168c0dSopenharmony_ci+ */ 7905be168c0dSopenharmony_ci+ 7906be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h" 7907be168c0dSopenharmony_ci+#include "nnacl/fp32/pack_fp32.h" 7908be168c0dSopenharmony_ci+ 7909be168c0dSopenharmony_ci+namespace mindspore::kernel { 7910be168c0dSopenharmony_ci+namespace { 7911be168c0dSopenharmony_ci+constexpr int kInputGateIndex = 0; 7912be168c0dSopenharmony_ci+constexpr int kTempHiddenOutputIndex = 8; 7913be168c0dSopenharmony_ci+constexpr int kGateNum = 4; 7914be168c0dSopenharmony_ci+constexpr int kWeightsIndex = 3; 7915be168c0dSopenharmony_ci+const int kWeightsOrderMap[8] = {0, 2, 3, 1, 4, 6, 7, 5}; // IFGO order to IOFG order 7916be168c0dSopenharmony_ci+} // namespace 7917be168c0dSopenharmony_ci+ 7918be168c0dSopenharmony_ci+int LstmMindirFp32CPUKernel::ReSize() { 7919be168c0dSopenharmony_ci+ auto ret = LstmFp32BaseCPUKernel::ReSize(); 7920be168c0dSopenharmony_ci+ if (ret != lite::RET_OK) { 7921be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmMindirFp32CPUKernel resize failed."; 7922be168c0dSopenharmony_ci+ return ret; 7923be168c0dSopenharmony_ci+ } 7924be168c0dSopenharmony_ci+ // determine FB origin 7925be168c0dSopenharmony_ci+ gpu_orig_state_ = false; 7926be168c0dSopenharmony_ci+ auto weight_t = in_tensors_.at(kWeightsIndex); 7927be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(lstm_param_->hidden_size_, lstm_param_->input_size_, lite::RET_ERROR); 7928be168c0dSopenharmony_ci+ int hi_unit_size = lstm_param_->hidden_size_ * lstm_param_->input_size_; 7929be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_segment_num_, hi_unit_size, lite::RET_ERROR); 7930be168c0dSopenharmony_ci+ int hi_whole_size = weight_segment_num_ * hi_unit_size; 7931be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(lstm_param_->hidden_size_, lstm_param_->output_size_, lite::RET_ERROR); 7932be168c0dSopenharmony_ci+ int hh_unit_size = lstm_param_->hidden_size_ * lstm_param_->output_size_; 
7933be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_segment_num_, hh_unit_size, lite::RET_ERROR); 7934be168c0dSopenharmony_ci+ int hh_whole_size = weight_segment_num_ * hh_unit_size; 7935be168c0dSopenharmony_ci+ int scale = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 7936be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(lstm_param_->hidden_size_, lstm_param_->project_size_, lite::RET_ERROR); 7937be168c0dSopenharmony_ci+ int hp_unit_size = lstm_param_->hidden_size_ * lstm_param_->project_size_; 7938be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(scale, hp_unit_size, lite::RET_ERROR); 7939be168c0dSopenharmony_ci+ int hp_whole_size = scale * hp_unit_size; 7940be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(weight_segment_num_ * C2NUM, lstm_param_->hidden_size_, lite::RET_ERROR); 7941be168c0dSopenharmony_ci+ int bias_whole_size = weight_segment_num_ * C2NUM * lstm_param_->hidden_size_; 7942be168c0dSopenharmony_ci+ auto whole_size = weight_t->ElementsNum(); 7943be168c0dSopenharmony_ci+ bool has_bias = (hi_whole_size + hh_whole_size + hp_whole_size < whole_size) ? true : false; 7944be168c0dSopenharmony_ci+ // if bias exist we can determine the gpu_orig_state_ 7945be168c0dSopenharmony_ci+ if (has_bias) { 7946be168c0dSopenharmony_ci+ gpu_orig_state_ = (hi_whole_size + hh_whole_size + hp_whole_size + bias_whole_size == whole_size) ? 
true : false; 7947be168c0dSopenharmony_ci+ } else { 7948be168c0dSopenharmony_ci+ bias_whole_size = 0; 7949be168c0dSopenharmony_ci+ } 7950be168c0dSopenharmony_ci+ if (gpu_orig_state_) { 7951be168c0dSopenharmony_ci+ return lite::RET_OK; 7952be168c0dSopenharmony_ci+ } 7953be168c0dSopenharmony_ci+ bias_whole_size /= C2NUM; 7954be168c0dSopenharmony_ci+ if (hi_whole_size + hh_whole_size + hp_whole_size + bias_whole_size != whole_size) { 7955be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmMindir is invalid when original model exports from CPU."; 7956be168c0dSopenharmony_ci+ return lite::RET_INPUT_TENSOR_ERROR; 7957be168c0dSopenharmony_ci+ } 7958be168c0dSopenharmony_ci+ return lite::RET_OK; 7959be168c0dSopenharmony_ci+} 7960be168c0dSopenharmony_ci+ 7961be168c0dSopenharmony_ci+int LstmMindirFp32CPUKernel::InitInputWeightBias() { 7962be168c0dSopenharmony_ci+ // malloc and init input * weight right matrix buffer 7963be168c0dSopenharmony_ci+ // input -- row: seq_len * batch; col: input_size 7964be168c0dSopenharmony_ci+ // weight -- row: hidden_size; col: input_size, need transpose 7965be168c0dSopenharmony_ci+ // result -- row: seq_len * batch; col: hidden_size 7966be168c0dSopenharmony_ci+ weight_i_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 7967be168c0dSopenharmony_ci+ weight_segment_num_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float))); 7968be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_i_ptr_ != nullptr, lite::RET_NULL_PTR, "LstmMindirCPUKernel malloc weight_i_ptr_ failed."); 7969be168c0dSopenharmony_ci+ running_buffer_.push_back(weight_i_ptr_); 7970be168c0dSopenharmony_ci+ auto weight_data = reinterpret_cast<float *>(in_tensors_.at(kWeightsIndex)->data()); 7971be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_data); 7972be168c0dSopenharmony_ci+ 7973be168c0dSopenharmony_ci+ int hi_unit_size = lstm_param_->input_size_ * lstm_param_->hidden_size_; 7974be168c0dSopenharmony_ci+ int hh_unit_size = lstm_param_->hidden_size_ * 
lstm_param_->output_size_; 7975be168c0dSopenharmony_ci+ int stride = (gpu_orig_state_) ? kGateNum * (hi_unit_size + hh_unit_size) : kGateNum * hi_unit_size; 7976be168c0dSopenharmony_ci+ PackLstmWeightWithStride(weight_i_ptr_, weight_data, weight_segment_num_, lstm_param_->input_size_, 7977be168c0dSopenharmony_ci+ lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, 7978be168c0dSopenharmony_ci+ stride, kWeightsOrderMap); 7979be168c0dSopenharmony_ci+ // input bias 7980be168c0dSopenharmony_ci+ auto bias_size = weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float); 7981be168c0dSopenharmony_ci+ input_bias_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(bias_size)); 7982be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_bias_ != nullptr, lite::RET_NULL_PTR, "LstmMindirCPUKernel malloc input_bias_ failed."); 7983be168c0dSopenharmony_ci+ memset(input_bias_, 0, bias_size); 7984be168c0dSopenharmony_ci+ running_buffer_.push_back(input_bias_); 7985be168c0dSopenharmony_ci+ if (!lstm_param_->has_bias_) { 7986be168c0dSopenharmony_ci+ return RET_OK; 7987be168c0dSopenharmony_ci+ } 7988be168c0dSopenharmony_ci+ int scale = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 7989be168c0dSopenharmony_ci+ int offset = weight_segment_num_ * (hi_unit_size + hh_unit_size) + 7990be168c0dSopenharmony_ci+ scale * lstm_param_->project_size_ * lstm_param_->hidden_size_; 7991be168c0dSopenharmony_ci+ float *bias_data = weight_data + offset; 7992be168c0dSopenharmony_ci+ int b_stride = 7993be168c0dSopenharmony_ci+ (gpu_orig_state_) ? 
kGateNum * (scale * lstm_param_->hidden_size_) : kGateNum * (lstm_param_->hidden_size_); 7994be168c0dSopenharmony_ci+ PackLstmBiasWithStride(input_bias_, bias_data, weight_segment_num_, lstm_param_->hidden_size_, 7995be168c0dSopenharmony_ci+ lstm_param_->input_col_align_, lstm_param_->bidirectional_, b_stride, kWeightsOrderMap); 7996be168c0dSopenharmony_ci+ return RET_OK; 7997be168c0dSopenharmony_ci+} 7998be168c0dSopenharmony_ci+ 7999be168c0dSopenharmony_ci+int LstmMindirFp32CPUKernel::InitStateWeightBias() { 8000be168c0dSopenharmony_ci+ // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 8001be168c0dSopenharmony_ci+ // state -- row: batch; col: hidden_size 8002be168c0dSopenharmony_ci+ // weight -- row: hidden_size; col: hidden_size, need transpose 8003be168c0dSopenharmony_ci+ // result -- row: batch; col: hidden_size 8004be168c0dSopenharmony_ci+ auto weight_data = (reinterpret_cast<float *>(in_tensors_.at(kWeightsIndex)->data())); 8005be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_data); 8006be168c0dSopenharmony_ci+ 8007be168c0dSopenharmony_ci+ int hi_unit_size = lstm_param_->input_size_ * lstm_param_->hidden_size_; 8008be168c0dSopenharmony_ci+ int hh_unit_size = lstm_param_->hidden_size_ * lstm_param_->output_size_; 8009be168c0dSopenharmony_ci+ int stride = (gpu_orig_state_) ? kGateNum * (hi_unit_size + hh_unit_size) : kGateNum * hh_unit_size; 8010be168c0dSopenharmony_ci+ 8011be168c0dSopenharmony_ci+ auto weight_h_data = weight_data + (gpu_orig_state_ ? 
kGateNum * hi_unit_size : weight_segment_num_ * hi_unit_size); 8012be168c0dSopenharmony_ci+ 8013be168c0dSopenharmony_ci+ auto weight_unit_pack_size = sizeof(float) * lstm_param_->state_col_align_ * lstm_param_->output_size_; 8014be168c0dSopenharmony_ci+ auto weight_pack_size = weight_segment_num_ * weight_unit_pack_size; 8015be168c0dSopenharmony_ci+ weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(weight_pack_size)); 8016be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_h_ptr_ != nullptr, lite::RET_NULL_PTR, "LstmMindirCPUKernel malloc weight_h_ptr_ failed."); 8017be168c0dSopenharmony_ci+ running_buffer_.push_back(weight_h_ptr_); 8018be168c0dSopenharmony_ci+ if (lstm_param_->batch_ != 1) { 8019be168c0dSopenharmony_ci+ PackLstmWeightWithStride(weight_h_ptr_, weight_h_data, weight_segment_num_, lstm_param_->output_size_, 8020be168c0dSopenharmony_ci+ lstm_param_->hidden_size_, lstm_param_->state_col_align_, lstm_param_->bidirectional_, 8021be168c0dSopenharmony_ci+ stride, kWeightsOrderMap); 8022be168c0dSopenharmony_ci+ } else { 8023be168c0dSopenharmony_ci+ for (int i = 0; i < weight_segment_num_; i++) { 8024be168c0dSopenharmony_ci+ const float *src_batch = weight_h_data + i * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8025be168c0dSopenharmony_ci+ float *dst_batch = 8026be168c0dSopenharmony_ci+ weight_h_ptr_ + kWeightsOrderMap[i] * lstm_param_->state_col_align_ * lstm_param_->output_size_; 8027be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 8028be168c0dSopenharmony_ci+ RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->hidden_size_, lstm_param_->output_size_); 8029be168c0dSopenharmony_ci+#else 8030be168c0dSopenharmony_ci+ (void)memcpy(dst_batch, src_batch, weight_unit_pack_size); 8031be168c0dSopenharmony_ci+#endif 8032be168c0dSopenharmony_ci+ } 8033be168c0dSopenharmony_ci+ } 8034be168c0dSopenharmony_ci+ 8035be168c0dSopenharmony_ci+ // state bias 8036be168c0dSopenharmony_ci+ auto bias_pack_size = weight_segment_num_ * 
lstm_param_->state_col_align_ * sizeof(float); 8037be168c0dSopenharmony_ci+ state_bias_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(bias_pack_size)); 8038be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(state_bias_ != nullptr, lite::RET_NULL_PTR, "LstmMindirCPUKernel malloc state_bias_ failed."); 8039be168c0dSopenharmony_ci+ memset(state_bias_, 0, bias_pack_size); 8040be168c0dSopenharmony_ci+ running_buffer_.push_back(state_bias_); 8041be168c0dSopenharmony_ci+ if (!lstm_param_->has_bias_ || !gpu_orig_state_) { 8042be168c0dSopenharmony_ci+ return RET_OK; 8043be168c0dSopenharmony_ci+ } 8044be168c0dSopenharmony_ci+ 8045be168c0dSopenharmony_ci+ int hi_whole_size = weight_segment_num_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; 8046be168c0dSopenharmony_ci+ int hh_whole_size = weight_segment_num_ * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8047be168c0dSopenharmony_ci+ int proj_size = 8048be168c0dSopenharmony_ci+ (lstm_param_->bidirectional_ ? C2NUM : C1NUM) * lstm_param_->project_size_ * lstm_param_->hidden_size_; 8049be168c0dSopenharmony_ci+ // mindir from device "GPU", secend bias is also present order IFOG 8050be168c0dSopenharmony_ci+ int bias_offset = hi_whole_size + hh_whole_size + proj_size + lstm_param_->hidden_size_ * kGateNum; 8051be168c0dSopenharmony_ci+ float *state_bias = weight_data + bias_offset; 8052be168c0dSopenharmony_ci+ int b_stride = kGateNum * lstm_param_->hidden_size_ * C2NUM; 8053be168c0dSopenharmony_ci+ PackLstmBiasWithStride(state_bias_, state_bias, weight_segment_num_, lstm_param_->hidden_size_, 8054be168c0dSopenharmony_ci+ lstm_param_->state_col_align_, lstm_param_->bidirectional_, b_stride, kWeightsOrderMap); 8055be168c0dSopenharmony_ci+ return RET_OK; 8056be168c0dSopenharmony_ci+} 8057be168c0dSopenharmony_ci+ 8058be168c0dSopenharmony_ci+int LstmMindirFp32CPUKernel::InitProjectWeight() { 8059be168c0dSopenharmony_ci+ if (lstm_param_->project_size_ == 0) { 8060be168c0dSopenharmony_ci+ return RET_OK; 
8061be168c0dSopenharmony_ci+ } 8062be168c0dSopenharmony_ci+ auto weight_data = (reinterpret_cast<float *>(in_tensors_.at(kWeightsIndex)->data())); 8063be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_data); 8064be168c0dSopenharmony_ci+ int hi_whole_size = weight_segment_num_ * lstm_param_->hidden_size_ * lstm_param_->input_size_; 8065be168c0dSopenharmony_ci+ int hh_whole_size = weight_segment_num_ * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8066be168c0dSopenharmony_ci+ auto weight_proj_data = weight_data + hi_whole_size + hh_whole_size; 8067be168c0dSopenharmony_ci+ int batch = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 8068be168c0dSopenharmony_ci+ auto pack_size = batch * lstm_param_->hidden_size_ * lstm_param_->proj_col_align_ * sizeof(float); 8069be168c0dSopenharmony_ci+ if (lstm_param_->batch_ != 1) { 8070be168c0dSopenharmony_ci+ weight_project_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_size)); 8071be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 8072be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 8073be168c0dSopenharmony_ci+ running_buffer_.push_back(weight_project_ptr_); 8074be168c0dSopenharmony_ci+ PackLstmWeightWithStride(weight_project_ptr_, weight_proj_data, batch, lstm_param_->hidden_size_, 8075be168c0dSopenharmony_ci+ lstm_param_->output_size_, lstm_param_->proj_col_align_, lstm_param_->bidirectional_, 8076be168c0dSopenharmony_ci+ lstm_param_->hidden_size_ * lstm_param_->output_size_, nullptr); 8077be168c0dSopenharmony_ci+ } else { 8078be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 8079be168c0dSopenharmony_ci+ weight_project_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_size)); 8080be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 8081be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 8082be168c0dSopenharmony_ci+ 
running_buffer_.push_back(weight_project_ptr_); 8083be168c0dSopenharmony_ci+ for (int i = 0; i < batch; ++i) { 8084be168c0dSopenharmony_ci+ const float *src_batch = weight_proj_data + i * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8085be168c0dSopenharmony_ci+ float *dst_batch = weight_project_ptr_ + i * lstm_param_->hidden_size_ * lstm_param_->proj_col_align_; 8086be168c0dSopenharmony_ci+ RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->output_size_, lstm_param_->hidden_size_); 8087be168c0dSopenharmony_ci+ } 8088be168c0dSopenharmony_ci+#else 8089be168c0dSopenharmony_ci+ weight_project_ptr_ = weight_proj_data; 8090be168c0dSopenharmony_ci+#endif 8091be168c0dSopenharmony_ci+ } 8092be168c0dSopenharmony_ci+ return RET_OK; 8093be168c0dSopenharmony_ci+} 8094be168c0dSopenharmony_ci+ 8095be168c0dSopenharmony_ci+void LstmMindirFp32CPUKernel::LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, 8096be168c0dSopenharmony_ci+ float *hidden_state, float *cell_state, const float *weight_project, 8097be168c0dSopenharmony_ci+ float *intermediate_states, float **buffer, bool is_backward) { 8098be168c0dSopenharmony_ci+ float *gate = buffer[kInputGateIndex]; 8099be168c0dSopenharmony_ci+ float *input_gate = gate; 8100be168c0dSopenharmony_ci+ float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C2NUM; 8101be168c0dSopenharmony_ci+ float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C3NUM; 8102be168c0dSopenharmony_ci+ float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_; 8103be168c0dSopenharmony_ci+ float *tmp = buffer[kTempHiddenOutputIndex]; 8104be168c0dSopenharmony_ci+ int dir_mult = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 8105be168c0dSopenharmony_ci+ for (int t = 0; t < lstm_param_->seq_len_; t++) { 8106be168c0dSopenharmony_ci+ int real_t = is_backward ? 
lstm_param_->seq_len_ - t - C1NUM : t; 8107be168c0dSopenharmony_ci+ float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8108be168c0dSopenharmony_ci+ float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8109be168c0dSopenharmony_ci+ float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8110be168c0dSopenharmony_ci+ float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8111be168c0dSopenharmony_ci+ // Sequence, Batch, DirMul, Hidden 8112be168c0dSopenharmony_ci+ LstmStepUnit(tmp, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, weight_project, 8113be168c0dSopenharmony_ci+ hidden_state, cell_state, buffer, lstm_param_); 8114be168c0dSopenharmony_ci+ int seq_offset = real_t * lstm_param_->batch_ * dir_mult * lstm_param_->output_size_; 8115be168c0dSopenharmony_ci+ for (int b = 0; b < lstm_param_->batch_; b++) { 8116be168c0dSopenharmony_ci+ int batch_offset = b * dir_mult * lstm_param_->output_size_; 8117be168c0dSopenharmony_ci+ float *output_ptr = output + seq_offset + batch_offset; 8118be168c0dSopenharmony_ci+ memcpy(output_ptr, tmp + b * lstm_param_->output_size_, lstm_param_->output_size_ * sizeof(float)); 8119be168c0dSopenharmony_ci+ } 8120be168c0dSopenharmony_ci+ if (intermediate_states) { 8121be168c0dSopenharmony_ci+ RecordStates(hidden_state, cell_state, input_gate_t, output_gate_t, forget_gate_t, cell_gate_t, 8122be168c0dSopenharmony_ci+ intermediate_states, real_t); 8123be168c0dSopenharmony_ci+ } 8124be168c0dSopenharmony_ci+ } 8125be168c0dSopenharmony_ci+} 8126be168c0dSopenharmony_ci+ 8127be168c0dSopenharmony_ci+void LstmMindirFp32CPUKernel::RecordStates(const float *hidden_state, float *cell_state, float *input_gate, 8128be168c0dSopenharmony_ci+ const float *output_gate, float *forget_gate, const float *cell_gate, 8129be168c0dSopenharmony_ci+ float *intermediate_states, int 
step) { 8130be168c0dSopenharmony_ci+ float *states = intermediate_states; 8131be168c0dSopenharmony_ci+ auto hidden_size = lstm_param_->batch_ * lstm_param_->output_size_; 8132be168c0dSopenharmony_ci+ auto state_size = lstm_param_->batch_ * lstm_param_->hidden_size_; 8133be168c0dSopenharmony_ci+ if (state_size < 0) { 8134be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "state size should be greater than or equal to zero."; 8135be168c0dSopenharmony_ci+ return; 8136be168c0dSopenharmony_ci+ } 8137be168c0dSopenharmony_ci+ auto hidden_stride = step * lstm_param_->output_step_; 8138be168c0dSopenharmony_ci+ auto hidden_seq_stride = lstm_param_->seq_len_ * lstm_param_->output_step_; 8139be168c0dSopenharmony_ci+ auto other_output_step = lstm_param_->bidirectional_ ? C2NUM * lstm_param_->batch_ * lstm_param_->hidden_size_ 8140be168c0dSopenharmony_ci+ : lstm_param_->batch_ * lstm_param_->hidden_size_; 8141be168c0dSopenharmony_ci+ auto stride = step * other_output_step; 8142be168c0dSopenharmony_ci+ auto seq_stride = lstm_param_->seq_len_ * other_output_step; 8143be168c0dSopenharmony_ci+ memcpy(states + hidden_stride, hidden_state, hidden_size * sizeof(float)); 8144be168c0dSopenharmony_ci+ stride += hidden_seq_stride; 8145be168c0dSopenharmony_ci+ memcpy(states + stride, cell_state, state_size * sizeof(float)); 8146be168c0dSopenharmony_ci+ stride += seq_stride; 8147be168c0dSopenharmony_ci+ memcpy(states + stride, input_gate, state_size * sizeof(float)); 8148be168c0dSopenharmony_ci+ stride += seq_stride; 8149be168c0dSopenharmony_ci+ memcpy(states + stride, output_gate, state_size * sizeof(float)); 8150be168c0dSopenharmony_ci+ stride += seq_stride; 8151be168c0dSopenharmony_ci+ memcpy(states + stride, forget_gate, state_size * sizeof(float)); 8152be168c0dSopenharmony_ci+ stride += seq_stride; 8153be168c0dSopenharmony_ci+ memcpy(states + stride, cell_gate, state_size * sizeof(float)); 8154be168c0dSopenharmony_ci+} 8155be168c0dSopenharmony_ci+} // namespace mindspore::kernel 
8156be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h 8157be168c0dSopenharmony_cinew file mode 100644 8158be168c0dSopenharmony_ciindex 00000000..84cdd38e 8159be168c0dSopenharmony_ci--- /dev/null 8160be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_mindir_fp32.h 8161be168c0dSopenharmony_ci@@ -0,0 +1,63 @@ 8162be168c0dSopenharmony_ci+/** 8163be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 8164be168c0dSopenharmony_ci+ * 8165be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8166be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 8167be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8168be168c0dSopenharmony_ci+ * 8169be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8170be168c0dSopenharmony_ci+ * 8171be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8172be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8173be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8174be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8175be168c0dSopenharmony_ci+ * limitations under the License. 
8176be168c0dSopenharmony_ci+ */ 8177be168c0dSopenharmony_ci+ 8178be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_MINDIR_FP32_H_ 8179be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_MINDIR_FP32_H_ 8180be168c0dSopenharmony_ci+ 8181be168c0dSopenharmony_ci+#include <vector> 8182be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp32/lstm_fp32_base.h" 8183be168c0dSopenharmony_ci+ 8184be168c0dSopenharmony_ci+namespace mindspore::kernel { 8185be168c0dSopenharmony_ci+/* 8186be168c0dSopenharmony_ci+ * 1. LSTM without project, output_size = hidden_size 8187be168c0dSopenharmony_ci+ * h_init: second input, shape is [bidirectional, batch_size, hidden_size] 8188be168c0dSopenharmony_ci+ * c_init: third input, shape is [bidirectional, batch_size, hidden_size] 8189be168c0dSopenharmony_ci+ * weight_bias: forth input, weight_ih + weight_hh + bias, the gate order is IFGO 8190be168c0dSopenharmony_ci+ * 8191be168c0dSopenharmony_ci+ * 2. LSTM with project, output_size = project_size 8192be168c0dSopenharmony_ci+ * h_init: second input, shape is [bidirectional, batch_size, project_size] 8193be168c0dSopenharmony_ci+ * c_init: third input, shape is [bidirectional, batch_size, hidden_size] 8194be168c0dSopenharmony_ci+ * weight_bias: forth input, weight_ih + weight_hh + proj + bias, the gate order is IFGO 8195be168c0dSopenharmony_ci+ */ 8196be168c0dSopenharmony_ci+class LstmMindirFp32CPUKernel : public LstmFp32BaseCPUKernel { 8197be168c0dSopenharmony_ci+ public: 8198be168c0dSopenharmony_ci+ LstmMindirFp32CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 8199be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 8200be168c0dSopenharmony_ci+ : LstmFp32BaseCPUKernel(parameter, inputs, outputs, ctx) { 8201be168c0dSopenharmony_ci+ hidden_init_index_ = SECOND_INPUT; 8202be168c0dSopenharmony_ci+ cell_init_index_ = THIRD_INPUT; 8203be168c0dSopenharmony_ci+ } 
8204be168c0dSopenharmony_ci+ 8205be168c0dSopenharmony_ci+ ~LstmMindirFp32CPUKernel() override = default; 8206be168c0dSopenharmony_ci+ 8207be168c0dSopenharmony_ci+ int ReSize() override; 8208be168c0dSopenharmony_ci+ 8209be168c0dSopenharmony_ci+ protected: 8210be168c0dSopenharmony_ci+ int InitInputWeightBias() override; 8211be168c0dSopenharmony_ci+ int InitStateWeightBias() override; 8212be168c0dSopenharmony_ci+ int InitProjectWeight() override; 8213be168c0dSopenharmony_ci+ void LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, float *hidden_state, 8214be168c0dSopenharmony_ci+ float *cell_state, const float *weight_project, float *intermediate_states, float *buffer[], 8215be168c0dSopenharmony_ci+ bool is_backward) override; 8216be168c0dSopenharmony_ci+ 8217be168c0dSopenharmony_ci+ private: 8218be168c0dSopenharmony_ci+ void RecordStates(const float *hidden_state, float *cell_state, float *input_gate, const float *output_gate, 8219be168c0dSopenharmony_ci+ float *forget_gate, const float *cell_gate, float *intermediate_states, int step); 8220be168c0dSopenharmony_ci+ bool gpu_orig_state_{false}; 8221be168c0dSopenharmony_ci+}; 8222be168c0dSopenharmony_ci+} // namespace mindspore::kernel 8223be168c0dSopenharmony_ci+ 8224be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_MINDIR_FP32_H_ 8225be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc 8226be168c0dSopenharmony_cinew file mode 100644 8227be168c0dSopenharmony_ciindex 00000000..62f9f2b7 8228be168c0dSopenharmony_ci--- /dev/null 8229be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.cc 8230be168c0dSopenharmony_ci@@ -0,0 +1,173 @@ 8231be168c0dSopenharmony_ci+/** 8232be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 8233be168c0dSopenharmony_ci+ * 8234be168c0dSopenharmony_ci+ * Licensed 
under the Apache License, Version 2.0 (the "License"); 8235be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 8236be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8237be168c0dSopenharmony_ci+ * 8238be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8239be168c0dSopenharmony_ci+ * 8240be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8241be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8242be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8243be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8244be168c0dSopenharmony_ci+ * limitations under the License. 8245be168c0dSopenharmony_ci+ */ 8246be168c0dSopenharmony_ci+ 8247be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h" 8248be168c0dSopenharmony_ci+#include "nnacl/fp32/pack_fp32.h" 8249be168c0dSopenharmony_ci+ 8250be168c0dSopenharmony_ci+namespace mindspore::kernel { 8251be168c0dSopenharmony_ci+namespace { 8252be168c0dSopenharmony_ci+constexpr int kInputGateIndex = 0; 8253be168c0dSopenharmony_ci+constexpr int kGateNum = 4; 8254be168c0dSopenharmony_ci+constexpr int kWeightInputIndex = 1; 8255be168c0dSopenharmony_ci+constexpr int kWeightHiddenindex = 2; 8256be168c0dSopenharmony_ci+constexpr int kCombinedBiasIndex = 3; 8257be168c0dSopenharmony_ci+} // namespace 8258be168c0dSopenharmony_ci+ 8259be168c0dSopenharmony_ci+int LstmNonMindirFp32CPUKernel::InitInputWeightBias() { 8260be168c0dSopenharmony_ci+ // malloc and init input * weight right matrix buffer 8261be168c0dSopenharmony_ci+ // input -- row: seq_len * batch; col: input_size 8262be168c0dSopenharmony_ci+ // weight -- row: hidden_size; col: input_size, need transpose 8263be168c0dSopenharmony_ci+ // result -- row: seq_len * batch; col: hidden_size 8264be168c0dSopenharmony_ci+ 
weight_i_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc( 8265be168c0dSopenharmony_ci+ weight_segment_num_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * sizeof(float))); 8266be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_i_ptr_ != nullptr, lite::RET_NULL_PTR, 8267be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_i_ptr_ failed."); 8268be168c0dSopenharmony_ci+ running_buffer_.push_back(weight_i_ptr_); 8269be168c0dSopenharmony_ci+ auto weight_i = in_tensors_.at(kWeightInputIndex); 8270be168c0dSopenharmony_ci+ auto weight_i_data = reinterpret_cast<float *>(weight_i->data()); 8271be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_i_data); 8272be168c0dSopenharmony_ci+ 8273be168c0dSopenharmony_ci+ int stride = kGateNum * lstm_param_->input_size_ * lstm_param_->hidden_size_; 8274be168c0dSopenharmony_ci+ PackLstmWeightWithStride(weight_i_ptr_, weight_i_data, weight_segment_num_, lstm_param_->input_size_, 8275be168c0dSopenharmony_ci+ lstm_param_->hidden_size_, lstm_param_->input_col_align_, lstm_param_->bidirectional_, 8276be168c0dSopenharmony_ci+ stride, nullptr); 8277be168c0dSopenharmony_ci+ // input bias 8278be168c0dSopenharmony_ci+ input_bias_ = reinterpret_cast<float *>( 8279be168c0dSopenharmony_ci+ ms_context_->allocator->Malloc(weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float))); 8280be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_bias_ != nullptr, lite::RET_NULL_PTR, "LstmNonMindirCPUKernel malloc input_bias_ failed."); 8281be168c0dSopenharmony_ci+ memset(input_bias_, 0, weight_segment_num_ * lstm_param_->input_col_align_ * sizeof(float)); 8282be168c0dSopenharmony_ci+ running_buffer_.push_back(input_bias_); 8283be168c0dSopenharmony_ci+ auto bias_data = reinterpret_cast<float *>(in_tensors_.at(kCombinedBiasIndex)->data()); 8284be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(bias_data); 8285be168c0dSopenharmony_ci+ PackLstmBias(input_bias_, bias_data, weight_segment_num_, lstm_param_->hidden_size_, 
lstm_param_->input_col_align_, 8286be168c0dSopenharmony_ci+ lstm_param_->bidirectional_, nullptr); 8287be168c0dSopenharmony_ci+ return RET_OK; 8288be168c0dSopenharmony_ci+} 8289be168c0dSopenharmony_ci+ 8290be168c0dSopenharmony_ci+int LstmNonMindirFp32CPUKernel::InitStateWeightBias() { 8291be168c0dSopenharmony_ci+ // malloc and init state * weight right matrix buffer, state * weight will be executed seq_len_ times. 8292be168c0dSopenharmony_ci+ // state -- row: batch; col: hidden_size 8293be168c0dSopenharmony_ci+ // weight -- row: hidden_size; col: hidden_size, need transpose 8294be168c0dSopenharmony_ci+ // result -- row: batch; col: hidden_size 8295be168c0dSopenharmony_ci+ auto weight_h = in_tensors_.at(kWeightHiddenindex); 8296be168c0dSopenharmony_ci+ auto weight_h_data = reinterpret_cast<float *>(weight_h->data()); 8297be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_h_data); 8298be168c0dSopenharmony_ci+ 8299be168c0dSopenharmony_ci+ int stride = kGateNum * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8300be168c0dSopenharmony_ci+ auto weight_pack_size = 8301be168c0dSopenharmony_ci+ weight_segment_num_ * lstm_param_->state_col_align_ * lstm_param_->output_size_ * sizeof(float); 8302be168c0dSopenharmony_ci+ if (lstm_param_->batch_ != 1) { 8303be168c0dSopenharmony_ci+ weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(weight_pack_size)); 8304be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_h_ptr_ != nullptr, lite::RET_NULL_PTR, 8305be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_h_ptr_ failed."); 8306be168c0dSopenharmony_ci+ running_buffer_.push_back(weight_h_ptr_); 8307be168c0dSopenharmony_ci+ PackLstmWeightWithStride(weight_h_ptr_, weight_h_data, weight_segment_num_, lstm_param_->output_size_, 8308be168c0dSopenharmony_ci+ lstm_param_->hidden_size_, lstm_param_->state_col_align_, lstm_param_->bidirectional_, 8309be168c0dSopenharmony_ci+ stride, nullptr); 8310be168c0dSopenharmony_ci+ } else { 
8311be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 8312be168c0dSopenharmony_ci+ weight_h_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(weight_pack_size)); 8313be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_h_ptr_ != nullptr, lite::RET_NULL_PTR, 8314be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_h_ptr_ failed."); 8315be168c0dSopenharmony_ci+ running_buffer_.push_back(weight_h_ptr_); 8316be168c0dSopenharmony_ci+ for (int i = 0; i < weight_segment_num_; i++) { 8317be168c0dSopenharmony_ci+ const float *src_batch = weight_h_data + i * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8318be168c0dSopenharmony_ci+ float *dst_batch = weight_h_ptr_ + i * lstm_param_->state_col_align_ * lstm_param_->output_size_; 8319be168c0dSopenharmony_ci+ RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->hidden_size_, lstm_param_->output_size_); 8320be168c0dSopenharmony_ci+ } 8321be168c0dSopenharmony_ci+#else 8322be168c0dSopenharmony_ci+ weight_h_ptr_ = weight_h_data; 8323be168c0dSopenharmony_ci+#endif 8324be168c0dSopenharmony_ci+ } 8325be168c0dSopenharmony_ci+ 8326be168c0dSopenharmony_ci+ // state bias 8327be168c0dSopenharmony_ci+ auto bias_pack_size = weight_segment_num_ * lstm_param_->state_col_align_ * sizeof(float); 8328be168c0dSopenharmony_ci+ state_bias_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(bias_pack_size)); 8329be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(state_bias_ != nullptr, lite::RET_NULL_PTR, "LstmNonMindirCPUKernel malloc state_bias_ failed."); 8330be168c0dSopenharmony_ci+ memset(state_bias_, 0, bias_pack_size); 8331be168c0dSopenharmony_ci+ running_buffer_.push_back(state_bias_); 8332be168c0dSopenharmony_ci+ // if ONNX, secend bias is also present order IOFG 8333be168c0dSopenharmony_ci+ auto bias_data = reinterpret_cast<float *>(in_tensors_.at(kCombinedBiasIndex)->data()); 8334be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(bias_data); 8335be168c0dSopenharmony_ci+ auto *state_bias = bias_data + kGateNum * 
lstm_param_->hidden_size_; 8336be168c0dSopenharmony_ci+ PackLstmBias(state_bias_, state_bias, weight_segment_num_, lstm_param_->hidden_size_, lstm_param_->state_col_align_, 8337be168c0dSopenharmony_ci+ lstm_param_->bidirectional_, nullptr); 8338be168c0dSopenharmony_ci+ return RET_OK; 8339be168c0dSopenharmony_ci+} 8340be168c0dSopenharmony_ci+ 8341be168c0dSopenharmony_ci+int LstmNonMindirFp32CPUKernel::InitProjectWeight() { 8342be168c0dSopenharmony_ci+ if (in_tensors_.size() < C7NUM) { 8343be168c0dSopenharmony_ci+ return RET_OK; 8344be168c0dSopenharmony_ci+ } 8345be168c0dSopenharmony_ci+ auto weight_pro = in_tensors_.at(SEVENTH_INPUT); 8346be168c0dSopenharmony_ci+ auto shape = weight_pro->shape(); 8347be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(shape.size() == C3NUM, lite::RET_ERROR, "Project-weight's shape must be 3D."); 8348be168c0dSopenharmony_ci+ auto weight_pro_data = reinterpret_cast<float *>(weight_pro->data()); 8349be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(weight_pro_data); 8350be168c0dSopenharmony_ci+ int batch = lstm_param_->bidirectional_ ? 
C2NUM : C1NUM; 8351be168c0dSopenharmony_ci+ if (shape[0] != batch) { 8352be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Project-weight's shape[0] must be 1(bidirectional=false) or 2(bidirectional=true)."; 8353be168c0dSopenharmony_ci+ return lite::RET_ERROR; 8354be168c0dSopenharmony_ci+ } 8355be168c0dSopenharmony_ci+ int col_align = UP_ROUND(lstm_param_->output_size_, col_tile_); 8356be168c0dSopenharmony_ci+ auto pack_size = batch * lstm_param_->hidden_size_ * col_align * sizeof(float); 8357be168c0dSopenharmony_ci+ if (lstm_param_->batch_ != 1) { 8358be168c0dSopenharmony_ci+ weight_project_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_size)); 8359be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 8360be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 8361be168c0dSopenharmony_ci+ running_buffer_.push_back(weight_project_ptr_); 8362be168c0dSopenharmony_ci+ PackLstmWeightWithStride(weight_project_ptr_, weight_pro_data, batch, lstm_param_->hidden_size_, 8363be168c0dSopenharmony_ci+ lstm_param_->output_size_, col_align, lstm_param_->bidirectional_, 8364be168c0dSopenharmony_ci+ lstm_param_->hidden_size_ * lstm_param_->output_size_, nullptr); 8365be168c0dSopenharmony_ci+ } else { 8366be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 8367be168c0dSopenharmony_ci+ weight_project_ptr_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(pack_size)); 8368be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weight_project_ptr_ != nullptr, lite::RET_NULL_PTR, 8369be168c0dSopenharmony_ci+ "LstmNonMindirCPUKernel malloc weight_project_ptr_ failed."); 8370be168c0dSopenharmony_ci+ running_buffer_.push_back(weight_project_ptr_); 8371be168c0dSopenharmony_ci+ for (int i = 0; i < batch; ++i) { 8372be168c0dSopenharmony_ci+ const float *src_batch = weight_pro_data + i * lstm_param_->hidden_size_ * lstm_param_->output_size_; 8373be168c0dSopenharmony_ci+ float *dst_batch = weight_project_ptr_ + i * 
lstm_param_->hidden_size_ * col_align; 8374be168c0dSopenharmony_ci+ RowMajor2Col32Major(src_batch, dst_batch, lstm_param_->output_size_, lstm_param_->hidden_size_); 8375be168c0dSopenharmony_ci+ } 8376be168c0dSopenharmony_ci+#else 8377be168c0dSopenharmony_ci+ weight_project_ptr_ = weight_pro_data; 8378be168c0dSopenharmony_ci+#endif 8379be168c0dSopenharmony_ci+ } 8380be168c0dSopenharmony_ci+ return RET_OK; 8381be168c0dSopenharmony_ci+} 8382be168c0dSopenharmony_ci+ 8383be168c0dSopenharmony_ci+void LstmNonMindirFp32CPUKernel::LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, 8384be168c0dSopenharmony_ci+ float *hidden_state, float *cell_state, const float *weight_project, 8385be168c0dSopenharmony_ci+ float *intermediate_states, float **buffer, bool is_backward) { 8386be168c0dSopenharmony_ci+ float *gate = buffer[kInputGateIndex]; 8387be168c0dSopenharmony_ci+ float *input_gate = gate; 8388be168c0dSopenharmony_ci+ float *forget_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C2NUM; 8389be168c0dSopenharmony_ci+ float *cell_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_ * C3NUM; 8390be168c0dSopenharmony_ci+ float *output_gate = gate + lstm_param_->seq_len_ * lstm_param_->batch_ * lstm_param_->hidden_size_; 8391be168c0dSopenharmony_ci+ for (int t = 0; t < lstm_param_->seq_len_; t++) { 8392be168c0dSopenharmony_ci+ int real_t = is_backward ? 
lstm_param_->seq_len_ - t - C1NUM : t; 8393be168c0dSopenharmony_ci+ float *input_gate_t = input_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8394be168c0dSopenharmony_ci+ float *forget_gate_t = forget_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8395be168c0dSopenharmony_ci+ float *cell_gate_t = cell_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8396be168c0dSopenharmony_ci+ float *output_gate_t = output_gate + lstm_param_->batch_ * lstm_param_->hidden_size_ * real_t; 8397be168c0dSopenharmony_ci+ // Sequence, DirMul, Batch, Hidden 8398be168c0dSopenharmony_ci+ float *output_ptr = output + real_t * lstm_param_->output_step_; 8399be168c0dSopenharmony_ci+ LstmStepUnit(output_ptr, input_gate_t, forget_gate_t, cell_gate_t, output_gate_t, weight_h, state_bias, 8400be168c0dSopenharmony_ci+ weight_project, hidden_state, cell_state, buffer, lstm_param_); 8401be168c0dSopenharmony_ci+ } 8402be168c0dSopenharmony_ci+} 8403be168c0dSopenharmony_ci+} // namespace mindspore::kernel 8404be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h 8405be168c0dSopenharmony_cinew file mode 100644 8406be168c0dSopenharmony_ciindex 00000000..b16e9175 8407be168c0dSopenharmony_ci--- /dev/null 8408be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/lstm_non_mindir_fp32.h 8409be168c0dSopenharmony_ci@@ -0,0 +1,61 @@ 8410be168c0dSopenharmony_ci+/** 8411be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 8412be168c0dSopenharmony_ci+ * 8413be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8414be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
8415be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8416be168c0dSopenharmony_ci+ * 8417be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8418be168c0dSopenharmony_ci+ * 8419be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8420be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8421be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8422be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8423be168c0dSopenharmony_ci+ * limitations under the License. 8424be168c0dSopenharmony_ci+ */ 8425be168c0dSopenharmony_ci+ 8426be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_NON_MINDIR_FP32_H_ 8427be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_NON_MINDIR_FP32_H_ 8428be168c0dSopenharmony_ci+ 8429be168c0dSopenharmony_ci+#include <vector> 8430be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp32/lstm_fp32_base.h" 8431be168c0dSopenharmony_ci+ 8432be168c0dSopenharmony_ci+namespace mindspore::kernel { 8433be168c0dSopenharmony_ci+/* 8434be168c0dSopenharmony_ci+ * 1. LSTM without project, output_size = hidden_size 8435be168c0dSopenharmony_ci+ * weight_ih: second input, shape is [bidirectional, 4 * hidden_size, input_size] 8436be168c0dSopenharmony_ci+ * weight_hh: third input, shape is [bidirectional, 4 * hidden_size, hidden_size] 8437be168c0dSopenharmony_ci+ * bias: forth input, shape is [bidirectional, 8 * hidden_size] 8438be168c0dSopenharmony_ci+ * h_init: fifth input, shape is [bidirectional, batch_size, hidden_size] 8439be168c0dSopenharmony_ci+ * c_init: sixth input, shape is [bidirectional, batch_size, hidden_size] 8440be168c0dSopenharmony_ci+ * 8441be168c0dSopenharmony_ci+ * 2. 
LSTM with project, output_size = project_size 8442be168c0dSopenharmony_ci+ * weight_ih: second input, shape is [bidirectional, 4 * hidden_size, input_size] 8443be168c0dSopenharmony_ci+ * weight_hh: third input, shape is [bidirectional, 4 * hidden_size, project_size] 8444be168c0dSopenharmony_ci+ * bias: forth input, shape is [bidirectional, 8 * hidden_size] 8445be168c0dSopenharmony_ci+ * h_init: fifth input, shape is [bidirectional, batch_size, project_size] 8446be168c0dSopenharmony_ci+ * c_init: sixth input, shape is [bidirectional, batch_size, hidden_size] 8447be168c0dSopenharmony_ci+ * weight_pro: seventh input, shape is [bidirectional, project_size, hidden_size] 8448be168c0dSopenharmony_ci+ */ 8449be168c0dSopenharmony_ci+class LstmNonMindirFp32CPUKernel : public LstmFp32BaseCPUKernel { 8450be168c0dSopenharmony_ci+ public: 8451be168c0dSopenharmony_ci+ LstmNonMindirFp32CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 8452be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 8453be168c0dSopenharmony_ci+ : LstmFp32BaseCPUKernel(parameter, inputs, outputs, ctx) { 8454be168c0dSopenharmony_ci+ hidden_init_index_ = FIFTH_INPUT; 8455be168c0dSopenharmony_ci+ cell_init_index_ = SIXTH_INPUT; 8456be168c0dSopenharmony_ci+ } 8457be168c0dSopenharmony_ci+ 8458be168c0dSopenharmony_ci+ ~LstmNonMindirFp32CPUKernel() override = default; 8459be168c0dSopenharmony_ci+ 8460be168c0dSopenharmony_ci+ protected: 8461be168c0dSopenharmony_ci+ int InitInputWeightBias() override; 8462be168c0dSopenharmony_ci+ int InitStateWeightBias() override; 8463be168c0dSopenharmony_ci+ int InitProjectWeight() override; 8464be168c0dSopenharmony_ci+ void LstmUnidirectional(float *output, const float *weight_h, const float *state_bias, float *hidden_state, 8465be168c0dSopenharmony_ci+ float *cell_state, const float *weight_project, float *intermediate_states, float *buffer[], 8466be168c0dSopenharmony_ci+ bool is_backward) override; 
8467be168c0dSopenharmony_ci+}; 8468be168c0dSopenharmony_ci+} // namespace mindspore::kernel 8469be168c0dSopenharmony_ci+ 8470be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_LSTM_NON_MINDIR_FP32_H_ 8471be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.cc 8472be168c0dSopenharmony_cinew file mode 100644 8473be168c0dSopenharmony_ciindex 00000000..60d3f213 8474be168c0dSopenharmony_ci--- /dev/null 8475be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.cc 8476be168c0dSopenharmony_ci@@ -0,0 +1,147 @@ 8477be168c0dSopenharmony_ci+/** 8478be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 8479be168c0dSopenharmony_ci+ * 8480be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8481be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 8482be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8483be168c0dSopenharmony_ci+ * 8484be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8485be168c0dSopenharmony_ci+ * 8486be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8487be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8488be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8489be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8490be168c0dSopenharmony_ci+ * limitations under the License. 
8491be168c0dSopenharmony_ci+ */ 8492be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.h" 8493be168c0dSopenharmony_ci+#include "src/litert//kernel_registry.h" 8494be168c0dSopenharmony_ci+#include "include/errorcode.h" 8495be168c0dSopenharmony_ci+#include "src/common/log_adapter.h" 8496be168c0dSopenharmony_ci+#include "nnacl/custom_gather_d_grad_v2_parameter.h" 8497be168c0dSopenharmony_ci+ 8498be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar; 8499be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR; 8500be168c0dSopenharmony_ci+using mindspore::lite::RET_NOT_SUPPORT; 8501be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 8502be168c0dSopenharmony_ci+ 8503be168c0dSopenharmony_ci+namespace mindspore::kernel { 8504be168c0dSopenharmony_ci+namespace { 8505be168c0dSopenharmony_ci+constexpr size_t index_idx_{1}; 8506be168c0dSopenharmony_ci+constexpr size_t grad_idx_{2}; 8507be168c0dSopenharmony_ci+size_t get_element_num(const std::vector<int> &shape) { 8508be168c0dSopenharmony_ci+ return std::accumulate(shape.begin(), shape.end(), static_cast<std::size_t>(1), std::multiplies<int>()); 8509be168c0dSopenharmony_ci+} 8510be168c0dSopenharmony_ci+ 8511be168c0dSopenharmony_ci+void GatherDGradCopyTask(size_t cur, std::vector<size_t> *pos, float *input, int *index, const int &dim, float *output, 8512be168c0dSopenharmony_ci+ const std::vector<int> &output_shape, const std::vector<size_t> &out_cargo_size, 8513be168c0dSopenharmony_ci+ const std::vector<size_t> &input_cargo_size) { 8514be168c0dSopenharmony_ci+ for (int i = 0; i < output_shape[cur]; ++i) { 8515be168c0dSopenharmony_ci+ (*pos)[cur] = i; 8516be168c0dSopenharmony_ci+ if (cur == output_shape.size() - 1) { 8517be168c0dSopenharmony_ci+ int input_offset = 0; 8518be168c0dSopenharmony_ci+ int out_offset = 0; 8519be168c0dSopenharmony_ci+ // out offset 8520be168c0dSopenharmony_ci+ for (size_t j = 0; j < output_shape.size(); ++j) { 8521be168c0dSopenharmony_ci+ 
out_offset += (*pos)[j] * out_cargo_size[j]; 8522be168c0dSopenharmony_ci+ } 8523be168c0dSopenharmony_ci+ // input offset 8524be168c0dSopenharmony_ci+ int cur_index = (*pos)[dim]; 8525be168c0dSopenharmony_ci+ (*pos)[dim] = index[out_offset]; 8526be168c0dSopenharmony_ci+ for (size_t j = 0; j < output_shape.size(); ++j) { 8527be168c0dSopenharmony_ci+ input_offset += (*pos)[j] * input_cargo_size[j]; 8528be168c0dSopenharmony_ci+ } 8529be168c0dSopenharmony_ci+ // do copy 8530be168c0dSopenharmony_ci+ input[input_offset] += output[out_offset]; 8531be168c0dSopenharmony_ci+ (*pos)[dim] = cur_index; 8532be168c0dSopenharmony_ci+ } else { 8533be168c0dSopenharmony_ci+ // CopyTask 8534be168c0dSopenharmony_ci+ GatherDGradCopyTask(cur + 1, pos, input, index, dim, output, output_shape, out_cargo_size, input_cargo_size); 8535be168c0dSopenharmony_ci+ } 8536be168c0dSopenharmony_ci+ } 8537be168c0dSopenharmony_ci+} 8538be168c0dSopenharmony_ci+} // namespace 8539be168c0dSopenharmony_ci+ 8540be168c0dSopenharmony_ci+CustomGatherDGradV2CPUKernel::~CustomGatherDGradV2CPUKernel() {} 8541be168c0dSopenharmony_ci+ 8542be168c0dSopenharmony_ci+int CustomGatherDGradV2CPUKernel::Prepare() { 8543be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(in_tensors_.size(), C3NUM); 8544be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(out_tensors_.size(), C1NUM); 8545be168c0dSopenharmony_ci+ if (InitParamter() != RET_OK) { 8546be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Init Built-in CustomGatherGradV2 Parameter failed." 
<< name_; 8547be168c0dSopenharmony_ci+ return RET_ERROR; 8548be168c0dSopenharmony_ci+ } 8549be168c0dSopenharmony_ci+ if (!InferShapeDone()) { 8550be168c0dSopenharmony_ci+ return RET_OK; 8551be168c0dSopenharmony_ci+ } 8552be168c0dSopenharmony_ci+ return ReSize(); 8553be168c0dSopenharmony_ci+} 8554be168c0dSopenharmony_ci+ 8555be168c0dSopenharmony_ci+int CustomGatherDGradV2CPUKernel::InitParamter() { 8556be168c0dSopenharmony_ci+ auto param = reinterpret_cast<CustomGatherGradV2Parameter *>(op_parameter_); 8557be168c0dSopenharmony_ci+ axis_ = param->dim; 8558be168c0dSopenharmony_ci+ return RET_OK; 8559be168c0dSopenharmony_ci+} 8560be168c0dSopenharmony_ci+ 8561be168c0dSopenharmony_ci+int CustomGatherDGradV2CPUKernel::ReSize() { 8562be168c0dSopenharmony_ci+ index_shape_ = in_tensors_[index_idx_]->shape(); 8563be168c0dSopenharmony_ci+ grad_shape_ = in_tensors_[grad_idx_]->shape(); 8564be168c0dSopenharmony_ci+ output_shape_ = out_tensors_[0]->shape(); 8565be168c0dSopenharmony_ci+ if (grad_shape_.size() != index_shape_.size() || output_shape_.size() != index_shape_.size()) { 8566be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "For '" << name_ << "', the dimension of grad and output must be the equal to the " 8567be168c0dSopenharmony_ci+ << "dimension of index: " << index_shape_.size() 8568be168c0dSopenharmony_ci+ << ", but got the dimension of grad: " << grad_shape_.size() 8569be168c0dSopenharmony_ci+ << ", the dimension of output: " << output_shape_.size(); 8570be168c0dSopenharmony_ci+ return RET_ERROR; 8571be168c0dSopenharmony_ci+ } 8572be168c0dSopenharmony_ci+ 8573be168c0dSopenharmony_ci+ return RET_OK; 8574be168c0dSopenharmony_ci+} 8575be168c0dSopenharmony_ci+ 8576be168c0dSopenharmony_ci+int CustomGatherDGradV2CPUKernel::Run() { 8577be168c0dSopenharmony_ci+ auto *index = reinterpret_cast<int *>(in_tensors_[index_idx_]->data()); 8578be168c0dSopenharmony_ci+ auto *grad = reinterpret_cast<float *>(in_tensors_[grad_idx_]->data()); 8579be168c0dSopenharmony_ci+ auto out = 
reinterpret_cast<float *>(out_tensors_[0]->data()); 8580be168c0dSopenharmony_ci+ int output_rank = output_shape_.size(); 8581be168c0dSopenharmony_ci+ if (axis_ >= output_rank || axis_ < -output_rank) { 8582be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "For '" << name_ << "', the value of 'dim' must be in [" << -output_rank << ", " << output_rank 8583be168c0dSopenharmony_ci+ << "), but got: " << axis_; 8584be168c0dSopenharmony_ci+ } 8585be168c0dSopenharmony_ci+ if (axis_ < 0) { 8586be168c0dSopenharmony_ci+ axis_ = axis_ + output_rank; 8587be168c0dSopenharmony_ci+ } 8588be168c0dSopenharmony_ci+ 8589be168c0dSopenharmony_ci+ // check index 8590be168c0dSopenharmony_ci+ size_t index_size = get_element_num(index_shape_); 8591be168c0dSopenharmony_ci+ int max_index = output_shape_[axis_]; 8592be168c0dSopenharmony_ci+ for (size_t i = 0; i < index_size; ++i) { 8593be168c0dSopenharmony_ci+ if (index[i] >= max_index || index[i] < -max_index) { 8594be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "For '" << name_ << "', the value of 'index' must be in [" << -max_index << ", " << max_index 8595be168c0dSopenharmony_ci+ << "), but got: " << index[i]; 8596be168c0dSopenharmony_ci+ } 8597be168c0dSopenharmony_ci+ if (index[i] < 0) { 8598be168c0dSopenharmony_ci+ index[i] = max_index + index[i]; 8599be168c0dSopenharmony_ci+ } 8600be168c0dSopenharmony_ci+ } 8601be168c0dSopenharmony_ci+ auto out_size = get_element_num(output_shape_); 8602be168c0dSopenharmony_ci+ memset(out, 0, out_size * sizeof(float)); 8603be168c0dSopenharmony_ci+ 8604be168c0dSopenharmony_ci+ // out_cargo_size 8605be168c0dSopenharmony_ci+ std::vector<size_t> out_cargo_size = std::vector<size_t>(output_shape_.size(), 1); 8606be168c0dSopenharmony_ci+ for (int i = static_cast<int>(out_cargo_size.size()) - 2; i >= 0; --i) { 8607be168c0dSopenharmony_ci+ out_cargo_size[i] = output_shape_[i + 1] * out_cargo_size[i + 1]; 8608be168c0dSopenharmony_ci+ } 8609be168c0dSopenharmony_ci+ // grad_cargo_size 8610be168c0dSopenharmony_ci+ 
std::vector<size_t> grad_cargo_size = std::vector<size_t>(grad_shape_.size(), 1); 8611be168c0dSopenharmony_ci+ for (int i = static_cast<int>(grad_cargo_size.size()) - 2; i >= 0; --i) { 8612be168c0dSopenharmony_ci+ grad_cargo_size[i] = grad_shape_[i + 1] * grad_cargo_size[i + 1]; 8613be168c0dSopenharmony_ci+ } 8614be168c0dSopenharmony_ci+ 8615be168c0dSopenharmony_ci+ // copy task 8616be168c0dSopenharmony_ci+ std::vector<size_t> pos(index_shape_.size(), 0); 8617be168c0dSopenharmony_ci+ GatherDGradCopyTask(0, &pos, out, index, axis_, grad, index_shape_, grad_cargo_size, out_cargo_size); 8618be168c0dSopenharmony_ci+ return RET_OK; 8619be168c0dSopenharmony_ci+} 8620be168c0dSopenharmony_ci+ 8621be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomGatherDGradV2, 8622be168c0dSopenharmony_ci+ LiteKernelCreator<CustomGatherDGradV2CPUKernel>) 8623be168c0dSopenharmony_ci+} // namespace mindspore::kernel 8624be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.h 8625be168c0dSopenharmony_cinew file mode 100644 8626be168c0dSopenharmony_ciindex 00000000..25666023 8627be168c0dSopenharmony_ci--- /dev/null 8628be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32_grad/custom_gather_d_grad_v2_fp32.h 8629be168c0dSopenharmony_ci@@ -0,0 +1,42 @@ 8630be168c0dSopenharmony_ci+/** 8631be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 8632be168c0dSopenharmony_ci+ * 8633be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8634be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
8635be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8636be168c0dSopenharmony_ci+ * 8637be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8638be168c0dSopenharmony_ci+ * 8639be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8640be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8641be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8642be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8643be168c0dSopenharmony_ci+ * limitations under the License. 8644be168c0dSopenharmony_ci+ */ 8645be168c0dSopenharmony_ci+ 8646be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_CUSTOM_GATHER_D_GRAD_V2_H_ 8647be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_CUSTOM_GATHER_D_GRAD_V2_H_ 8648be168c0dSopenharmony_ci+#include <vector> 8649be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h" 8650be168c0dSopenharmony_ci+ 8651be168c0dSopenharmony_ci+namespace mindspore::kernel { 8652be168c0dSopenharmony_ci+class CustomGatherDGradV2CPUKernel : public LiteKernel { 8653be168c0dSopenharmony_ci+ public: 8654be168c0dSopenharmony_ci+ CustomGatherDGradV2CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 8655be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 8656be168c0dSopenharmony_ci+ : LiteKernel(parameter, inputs, outputs, ctx) {} 8657be168c0dSopenharmony_ci+ ~CustomGatherDGradV2CPUKernel() override; 8658be168c0dSopenharmony_ci+ int Prepare() override; 8659be168c0dSopenharmony_ci+ int ReSize() override; 8660be168c0dSopenharmony_ci+ int Run() override; 8661be168c0dSopenharmony_ci+ 8662be168c0dSopenharmony_ci+ private: 8663be168c0dSopenharmony_ci+ int InitParamter(); 8664be168c0dSopenharmony_ci+ 8665be168c0dSopenharmony_ci+ std::vector<int> 
index_shape_; 8666be168c0dSopenharmony_ci+ std::vector<int> grad_shape_; 8667be168c0dSopenharmony_ci+ std::vector<int> output_shape_; 8668be168c0dSopenharmony_ci+ int axis_{0}; 8669be168c0dSopenharmony_ci+}; 8670be168c0dSopenharmony_ci+} // namespace mindspore::kernel 8671be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_GRAD_CUSTOM_GATHER_D_GRAD_V2_H_ 8672be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc 8673be168c0dSopenharmony_ciindex 48c037b2..7982f818 100644 8674be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/graph_fusion.cc 8675be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/graph_fusion.cc 8676be168c0dSopenharmony_ci@@ -25,6 +25,8 @@ 8677be168c0dSopenharmony_ci #include "src/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.h" 8678be168c0dSopenharmony_ci #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" 8679be168c0dSopenharmony_ci #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 8680be168c0dSopenharmony_ci+#include "src/train/optimizer/fusion/matmul_add_fusion_pass.h" 8681be168c0dSopenharmony_ci+#include "src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h" 8682be168c0dSopenharmony_ci 8683be168c0dSopenharmony_ci namespace mindspore { 8684be168c0dSopenharmony_ci namespace lite { 8685be168c0dSopenharmony_ci@@ -52,7 +54,9 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) { 8686be168c0dSopenharmony_ci Optimizer fusion_optimizer; 8687be168c0dSopenharmony_ci fusion_optimizer.AddPass(new (std::nothrow) ReshapeGatherReshapeFusionPass()); 8688be168c0dSopenharmony_ci fusion_optimizer.AddPass(new (std::nothrow) MatMulBiasAddFusionPass()); 8689be168c0dSopenharmony_ci+ fusion_optimizer.AddPass(new (std::nothrow) MatMulAddFusionPass()); 8690be168c0dSopenharmony_ci fusion_optimizer.AddPass(new (std::nothrow) MatMulActivationFusionPass()); 8691be168c0dSopenharmony_ci+ fusion_optimizer.AddPass(new 
(std::nothrow) MatMulMatMulAddFusionPass()); 8692be168c0dSopenharmony_ci fusion_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); 8693be168c0dSopenharmony_ci fusion_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); 8694be168c0dSopenharmony_ci auto status = fusion_optimizer.Run(graph); 8695be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.cc 8696be168c0dSopenharmony_cinew file mode 100644 8697be168c0dSopenharmony_ciindex 00000000..34bed911 8698be168c0dSopenharmony_ci--- /dev/null 8699be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.cc 8700be168c0dSopenharmony_ci@@ -0,0 +1,127 @@ 8701be168c0dSopenharmony_ci+/** 8702be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 8703be168c0dSopenharmony_ci+ * 8704be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8705be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 8706be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8707be168c0dSopenharmony_ci+ * 8708be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8709be168c0dSopenharmony_ci+ * 8710be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8711be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8712be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8713be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8714be168c0dSopenharmony_ci+ * limitations under the License. 
8715be168c0dSopenharmony_ci+ */ 8716be168c0dSopenharmony_ci+#include "src/train/optimizer/fusion/matmul_add_fusion_pass.h" 8717be168c0dSopenharmony_ci+#include <string> 8718be168c0dSopenharmony_ci+#include <unordered_map> 8719be168c0dSopenharmony_ci+#include <vector> 8720be168c0dSopenharmony_ci+#include <memory> 8721be168c0dSopenharmony_ci+#include "schema/inner/model_generated.h" 8722be168c0dSopenharmony_ci+#include "tools/common/meta_graph_utils.h" 8723be168c0dSopenharmony_ci+namespace { 8724be168c0dSopenharmony_ci+constexpr int kNumAddMatchPathLen = 2; 8725be168c0dSopenharmony_ci+constexpr std::string_view MulName = "MATMUL"; 8726be168c0dSopenharmony_ci+constexpr std::string_view AddName = "ADD"; 8727be168c0dSopenharmony_ci+} // namespace 8728be168c0dSopenharmony_ci+namespace mindspore { 8729be168c0dSopenharmony_ci+namespace lite { 8730be168c0dSopenharmony_ci+namespace { 8731be168c0dSopenharmony_ci+int CalNewCnodeBias(const std::unique_ptr<mindspore::schema::TensorT> &add_weight_tensor, 8732be168c0dSopenharmony_ci+ const std::unique_ptr<mindspore::schema::TensorT> &matmul_bias_tensor) { 8733be168c0dSopenharmony_ci+ if (add_weight_tensor->dataType != kNumberTypeFloat32 || matmul_bias_tensor->dataType != kNumberTypeFloat32) { 8734be168c0dSopenharmony_ci+ MS_LOG(INFO) << "only support float32 data type"; 8735be168c0dSopenharmony_ci+ return RET_ERROR; 8736be168c0dSopenharmony_ci+ } 8737be168c0dSopenharmony_ci+ std::vector<int32_t> matmul_bias_shape = matmul_bias_tensor->dims; 8738be168c0dSopenharmony_ci+ std::vector<int32_t> add_weight_shape = add_weight_tensor->dims; 8739be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_bias_shape == add_weight_shape, RET_ERROR); 8740be168c0dSopenharmony_ci+ auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->data.data()); 8741be168c0dSopenharmony_ci+ auto matmul_bias_data = reinterpret_cast<float *>(matmul_bias_tensor->data.data()); 8742be168c0dSopenharmony_ci+ int num = 
static_cast<int>(matmul_bias_tensor->data.size() / sizeof(float)); 8743be168c0dSopenharmony_ci+ for (int i = 0; i < num; ++i) { 8744be168c0dSopenharmony_ci+ matmul_bias_data[i] += add_weight_data[i]; 8745be168c0dSopenharmony_ci+ } 8746be168c0dSopenharmony_ci+ return RET_OK; 8747be168c0dSopenharmony_ci+} 8748be168c0dSopenharmony_ci+} // namespace 8749be168c0dSopenharmony_ci+STATUS MatMulAddFusionPass::Run(MetaGraphT *graph) { return FusionPass::Run(graph); } 8750be168c0dSopenharmony_ci+STATUS MatMulAddFusionPass::DefinePattern() { 8751be168c0dSopenharmony_ci+ auto mul_op = std::make_shared<PatternOp>(); 8752be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(mul_op != nullptr, RET_NULL_PTR); 8753be168c0dSopenharmony_ci+ mul_op->id = MulName; 8754be168c0dSopenharmony_ci+ mul_op->types = {schema::PrimitiveType_MatMulFusion}; 8755be168c0dSopenharmony_ci+ auto add_op = std::make_shared<PatternOp>(); 8756be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(add_op != nullptr, RET_NULL_PTR); 8757be168c0dSopenharmony_ci+ add_op->id = AddName; 8758be168c0dSopenharmony_ci+ add_op->types = {schema::PrimitiveType_AddFusion}; 8759be168c0dSopenharmony_ci+ add_op->left = mul_op; 8760be168c0dSopenharmony_ci+ std::unique_ptr<FusionPattern> fusion_pattern(new (std::nothrow) FusionPattern("MatMulAddFusion")); 8761be168c0dSopenharmony_ci+ if (fusion_pattern == nullptr) { 8762be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "new fusion_pattern failed"; 8763be168c0dSopenharmony_ci+ return RET_ERROR; 8764be168c0dSopenharmony_ci+ } 8765be168c0dSopenharmony_ci+ fusion_pattern->AddPatternOp(mul_op); 8766be168c0dSopenharmony_ci+ fusion_pattern->AddPatternOp(add_op); 8767be168c0dSopenharmony_ci+ fusion_pattern->Finish(); 8768be168c0dSopenharmony_ci+ this->patterns.emplace_back(fusion_pattern.release()); 8769be168c0dSopenharmony_ci+ return RET_OK; 8770be168c0dSopenharmony_ci+} 8771be168c0dSopenharmony_ci+STATUS MatMulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pattern_name, 8772be168c0dSopenharmony_ci+ 
const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) { 8773be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR); 8774be168c0dSopenharmony_ci+ if (matched_path.size() != kNumAddMatchPathLen) { 8775be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "MatMul-Add-Fusion should have two NodeIndex in matchedPair"; 8776be168c0dSopenharmony_ci+ return RET_PARAM_INVALID; 8777be168c0dSopenharmony_ci+ } 8778be168c0dSopenharmony_ci+ auto mul_path_iter = matched_path.find(std::string(MulName)); 8779be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(mul_path_iter != matched_path.end(), RET_NO_CHANGE); 8780be168c0dSopenharmony_ci+ auto &mul_path = mul_path_iter->second; 8781be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(mul_path != nullptr, RET_NULL_PTR); 8782be168c0dSopenharmony_ci+ auto add_path_iter = matched_path.find(std::string(AddName)); 8783be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(add_path_iter != matched_path.end(), RET_NO_CHANGE); 8784be168c0dSopenharmony_ci+ auto &add_path = add_path_iter->second; 8785be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(add_path != nullptr, RET_NULL_PTR); 8786be168c0dSopenharmony_ci+ auto mul_index = mul_path->nodeIdx; 8787be168c0dSopenharmony_ci+ auto add_index = add_path->nodeIdx; 8788be168c0dSopenharmony_ci+ auto &mul_node = graph->nodes.at(mul_index); 8789be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(mul_node != nullptr, RET_NULL_PTR); 8790be168c0dSopenharmony_ci+ auto &add_node = graph->nodes.at(add_index); 8791be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(add_node != nullptr, RET_NULL_PTR); 8792be168c0dSopenharmony_ci+ if (mul_node->quantType == schema::QuantType_QUANT_ALL || mul_node->quantType == schema::QuantType_QUANT_DYNAMIC || 8793be168c0dSopenharmony_ci+ add_node->quantType == schema::QuantType_QUANT_ALL || add_node->quantType == schema::QuantType_QUANT_DYNAMIC) { 8794be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "cannot fusion."; 8795be168c0dSopenharmony_ci+ return RET_NO_CHANGE; 8796be168c0dSopenharmony_ci+ } 
8797be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(mul_node->primitive != nullptr, RET_NULL_PTR); 8798be168c0dSopenharmony_ci+ auto matmul_type = mul_node->primitive->value.AsMatMulFusion(); 8799be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_type->activation_type == ActivationType::ActivationType_NO_ACTIVATION, RET_NO_CHANGE); 8800be168c0dSopenharmony_ci+ auto add_param_shape = graph->allTensors.at(add_node->inputIndex.at(SECOND_INPUT))->dims; 8801be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(add_param_shape.size() == DIMENSION_1D, RET_NO_CHANGE, "only support bias with shape size of 1."); 8802be168c0dSopenharmony_ci+ if (mul_node->inputIndex.size() == C3NUM) { 8803be168c0dSopenharmony_ci+ auto &mul_bias_tensor = graph->allTensors.at(mul_node->inputIndex.at(THIRD_INPUT)); 8804be168c0dSopenharmony_ci+ if (mul_bias_tensor->data.data() == nullptr) { 8805be168c0dSopenharmony_ci+ MS_LOG(INFO) << mul_node->name << "'s bias is not const"; 8806be168c0dSopenharmony_ci+ return RET_NO_CHANGE; 8807be168c0dSopenharmony_ci+ } 8808be168c0dSopenharmony_ci+ auto &add_weight_tensor = graph->allTensors.at(add_node->inputIndex.at(SECOND_INPUT)); 8809be168c0dSopenharmony_ci+ if (CalNewCnodeBias(add_weight_tensor, mul_bias_tensor) != RET_OK) { 8810be168c0dSopenharmony_ci+ MS_LOG(INFO) << add_node->name << " failed to fusion with " << mul_node->name; 8811be168c0dSopenharmony_ci+ return RET_NO_CHANGE; 8812be168c0dSopenharmony_ci+ } 8813be168c0dSopenharmony_ci+ } 8814be168c0dSopenharmony_ci+ auto add_tensor_index = add_node->inputIndex.at(SECOND_INPUT); 8815be168c0dSopenharmony_ci+ if (mul_node->inputIndex.size() == C2NUM) { 8816be168c0dSopenharmony_ci+ mul_node->inputIndex.push_back(add_tensor_index); 8817be168c0dSopenharmony_ci+ } 8818be168c0dSopenharmony_ci+ mul_node->outputIndex = {add_node->outputIndex}; 8819be168c0dSopenharmony_ci+ // cannot delete node here, otherwise will destroy order in other pattern's node index 8820be168c0dSopenharmony_ci+ // make it an isolated node to be removed 
in IsolatedNodeRemovePass 8821be168c0dSopenharmony_ci+ add_node->inputIndex.clear(); 8822be168c0dSopenharmony_ci+ add_node->outputIndex.clear(); 8823be168c0dSopenharmony_ci+ return RET_OK; 8824be168c0dSopenharmony_ci+} 8825be168c0dSopenharmony_ci+MatMulAddFusionPass::~MatMulAddFusionPass() = default; 8826be168c0dSopenharmony_ci+} // namespace lite 8827be168c0dSopenharmony_ci+} // namespace mindspore 8828be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.h 8829be168c0dSopenharmony_cinew file mode 100644 8830be168c0dSopenharmony_ciindex 00000000..8eb4ab2e 8831be168c0dSopenharmony_ci--- /dev/null 8832be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_add_fusion_pass.h 8833be168c0dSopenharmony_ci@@ -0,0 +1,37 @@ 8834be168c0dSopenharmony_ci+/** 8835be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 8836be168c0dSopenharmony_ci+ * 8837be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8838be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 8839be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8840be168c0dSopenharmony_ci+ * 8841be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8842be168c0dSopenharmony_ci+ * 8843be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8844be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8845be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8846be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8847be168c0dSopenharmony_ci+ * limitations under the License. 
8848be168c0dSopenharmony_ci+ */ 8849be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_MATMUL_ADD_FUSION_PASS_H_ 8850be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_MATMUL_ADD_FUSION_PASS_H_ 8851be168c0dSopenharmony_ci+#include <string> 8852be168c0dSopenharmony_ci+#include <unordered_map> 8853be168c0dSopenharmony_ci+#include <memory> 8854be168c0dSopenharmony_ci+#include <algorithm> 8855be168c0dSopenharmony_ci+#include <utility> 8856be168c0dSopenharmony_ci+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" 8857be168c0dSopenharmony_ci+namespace mindspore { 8858be168c0dSopenharmony_ci+namespace lite { 8859be168c0dSopenharmony_ci+class MatMulAddFusionPass : public FusionPass { 8860be168c0dSopenharmony_ci+ public: 8861be168c0dSopenharmony_ci+ MatMulAddFusionPass() = default; 8862be168c0dSopenharmony_ci+ ~MatMulAddFusionPass() override; 8863be168c0dSopenharmony_ci+ STATUS DefinePattern() override; 8864be168c0dSopenharmony_ci+ STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name, 8865be168c0dSopenharmony_ci+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override; 8866be168c0dSopenharmony_ci+ STATUS Run(MetaGraphT *graph) override; 8867be168c0dSopenharmony_ci+}; 8868be168c0dSopenharmony_ci+} // namespace lite 8869be168c0dSopenharmony_ci+} // namespace mindspore 8870be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_MATMUL_ADD_FUSION_PASS_H_ 8871be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc b/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc 8872be168c0dSopenharmony_cinew file mode 100644 8873be168c0dSopenharmony_ciindex 00000000..d1a63c2d 8874be168c0dSopenharmony_ci--- /dev/null 8875be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc 8876be168c0dSopenharmony_ci@@ 
-0,0 +1,163 @@ 8877be168c0dSopenharmony_ci+/** 8878be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 8879be168c0dSopenharmony_ci+ * 8880be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8881be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 8882be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8883be168c0dSopenharmony_ci+ * 8884be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8885be168c0dSopenharmony_ci+ * 8886be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8887be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8888be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8889be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8890be168c0dSopenharmony_ci+ * limitations under the License. 
8891be168c0dSopenharmony_ci+ */ 8892be168c0dSopenharmony_ci+ 8893be168c0dSopenharmony_ci+#include "src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h" 8894be168c0dSopenharmony_ci+#include <string> 8895be168c0dSopenharmony_ci+#include <unordered_map> 8896be168c0dSopenharmony_ci+#include <vector> 8897be168c0dSopenharmony_ci+#include <memory> 8898be168c0dSopenharmony_ci+#include "schema/inner/model_generated.h" 8899be168c0dSopenharmony_ci+#include "tools/common/meta_graph_utils.h" 8900be168c0dSopenharmony_ci+#include "src/train/optimizer/common/fusion_utils.h" 8901be168c0dSopenharmony_ci+namespace { 8902be168c0dSopenharmony_ci+constexpr std::string_view kFirstMatMulName = "MATMUL1"; 8903be168c0dSopenharmony_ci+constexpr std::string_view kSecondMatMulName = "MATMUL2"; 8904be168c0dSopenharmony_ci+constexpr std::string_view kAddName = "ADD"; 8905be168c0dSopenharmony_ci+} // namespace 8906be168c0dSopenharmony_ci+namespace mindspore { 8907be168c0dSopenharmony_ci+namespace lite { 8908be168c0dSopenharmony_ci+/* 8909be168c0dSopenharmony_ci+ * The subgraph such as the following. 
8910be168c0dSopenharmony_ci+ * any any 8911be168c0dSopenharmony_ci+ * / \ | 8912be168c0dSopenharmony_ci+ * matmul matmul matmul 8913be168c0dSopenharmony_ci+ * \ / ----> | 8914be168c0dSopenharmony_ci+ * add any 8915be168c0dSopenharmony_ci+ * | 8916be168c0dSopenharmony_ci+ * any 8917be168c0dSopenharmony_ci+ */ 8918be168c0dSopenharmony_ci+namespace { 8919be168c0dSopenharmony_ci+int CalNewMatMulNode(MetaGraphT *graph, const std::unique_ptr<mindspore::schema::CNodeT> &matmul_node1, 8920be168c0dSopenharmony_ci+ const std::unique_ptr<mindspore::schema::CNodeT> &matmul_node2) { 8921be168c0dSopenharmony_ci+ auto &matrix_b_1 = graph->allTensors.at(matmul_node1->inputIndex.at(opt::kInputIndexOne)); 8922be168c0dSopenharmony_ci+ auto &matrix_b_2 = graph->allTensors.at(matmul_node2->inputIndex.at(opt::kInputIndexOne)); 8923be168c0dSopenharmony_ci+ if (matrix_b_1->dims != matrix_b_2->dims) { 8924be168c0dSopenharmony_ci+ MS_LOG(INFO) << "currently, matmul fusion only support the same shape tensor"; 8925be168c0dSopenharmony_ci+ return RET_ERROR; 8926be168c0dSopenharmony_ci+ } 8927be168c0dSopenharmony_ci+ if (matrix_b_1->dataType != kNumberTypeFloat32 || matrix_b_2->dataType != kNumberTypeFloat32) { 8928be168c0dSopenharmony_ci+ MS_LOG(INFO) << "only support float32 data type"; 8929be168c0dSopenharmony_ci+ return RET_ERROR; 8930be168c0dSopenharmony_ci+ } 8931be168c0dSopenharmony_ci+ auto matrix_b_1_data = reinterpret_cast<float *>(matrix_b_1->data.data()); 8932be168c0dSopenharmony_ci+ auto matrix_b_2_data = reinterpret_cast<float *>(matrix_b_2->data.data()); 8933be168c0dSopenharmony_ci+ int num_b = static_cast<int>(matrix_b_1->data.size() / sizeof(float)); 8934be168c0dSopenharmony_ci+ for (int j = 0; j < num_b; ++j) { 8935be168c0dSopenharmony_ci+ matrix_b_1_data[j] += matrix_b_2_data[j]; 8936be168c0dSopenharmony_ci+ } 8937be168c0dSopenharmony_ci+ return RET_OK; 8938be168c0dSopenharmony_ci+} 8939be168c0dSopenharmony_ci+} // namespace 8940be168c0dSopenharmony_ci+STATUS 
MatMulMatMulAddFusionPass::DefinePattern() { 8941be168c0dSopenharmony_ci+ auto matmul_op1 = std::make_shared<PatternOp>(); 8942be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_op1 != nullptr, RET_NULL_PTR); 8943be168c0dSopenharmony_ci+ matmul_op1->id = kFirstMatMulName; 8944be168c0dSopenharmony_ci+ matmul_op1->types = {schema::PrimitiveType_MatMulFusion}; 8945be168c0dSopenharmony_ci+ auto matmul_op2 = std::make_shared<PatternOp>(); 8946be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_op2 != nullptr, RET_NULL_PTR); 8947be168c0dSopenharmony_ci+ matmul_op2->id = kSecondMatMulName; 8948be168c0dSopenharmony_ci+ matmul_op2->types = {schema::PrimitiveType_MatMulFusion}; 8949be168c0dSopenharmony_ci+ auto add_op = std::make_shared<PatternOp>(); 8950be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(add_op != nullptr, RET_NULL_PTR); 8951be168c0dSopenharmony_ci+ add_op->id = kAddName; 8952be168c0dSopenharmony_ci+ add_op->types = {schema::PrimitiveType_AddFusion}; 8953be168c0dSopenharmony_ci+ add_op->left = matmul_op1; 8954be168c0dSopenharmony_ci+ add_op->right = matmul_op2; 8955be168c0dSopenharmony_ci+ auto fusion_pattern = std::make_unique<FusionPattern>("MatMulMatMulAddFusion"); 8956be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(fusion_pattern != nullptr, RET_NULL_PTR, "new fusion_pattern failed"); 8957be168c0dSopenharmony_ci+ fusion_pattern->AddPatternOp(matmul_op1); 8958be168c0dSopenharmony_ci+ fusion_pattern->AddPatternOp(matmul_op2); 8959be168c0dSopenharmony_ci+ fusion_pattern->AddPatternOp(add_op); 8960be168c0dSopenharmony_ci+ fusion_pattern->Finish(); 8961be168c0dSopenharmony_ci+ this->patterns.emplace_back(fusion_pattern.release()); 8962be168c0dSopenharmony_ci+ return RET_OK; 8963be168c0dSopenharmony_ci+} 8964be168c0dSopenharmony_ci+ 8965be168c0dSopenharmony_ci+STATUS MatMulMatMulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &pattern_name, 8966be168c0dSopenharmony_ci+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) { 
8967be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR); 8968be168c0dSopenharmony_ci+ if (matched_path.size() != opt::kMatchPathLenThree) { 8969be168c0dSopenharmony_ci+ MS_LOG(INFO) << "MatMul-MatMul-Add-Fusion should have three NodeIndex in matchedPair"; 8970be168c0dSopenharmony_ci+ return RET_PARAM_INVALID; 8971be168c0dSopenharmony_ci+ } 8972be168c0dSopenharmony_ci+ 8973be168c0dSopenharmony_ci+ size_t matmul_index1 = 0; 8974be168c0dSopenharmony_ci+ auto ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(kFirstMatMulName), &matmul_index1); 8975be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get matmul_index1"); 8976be168c0dSopenharmony_ci+ auto &matmul_node1 = graph->nodes.at(matmul_index1); 8977be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(matmul_node1 != nullptr, RET_NULL_PTR, "matmul_node1 is nullptr"); 8978be168c0dSopenharmony_ci+ size_t matmul_index2 = 0; 8979be168c0dSopenharmony_ci+ ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(kSecondMatMulName), &matmul_index2); 8980be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "cannot get matmul_index2"); 8981be168c0dSopenharmony_ci+ auto &matmul_node2 = graph->nodes.at(matmul_index2); 8982be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(matmul_node2 != nullptr, RET_NULL_PTR, "matmul_node2 is nullptr"); 8983be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(matmul_node1->inputIndex.size() > C1NUM && matmul_node2->inputIndex.size() > C1NUM, 8984be168c0dSopenharmony_ci+ RET_PARAM_INVALID, "matmul should have two input at least"); 8985be168c0dSopenharmony_ci+ if (matmul_node1->inputIndex.size() < matmul_node2->inputIndex.size()) { 8986be168c0dSopenharmony_ci+ matmul_node1.swap(matmul_node2); 8987be168c0dSopenharmony_ci+ } 8988be168c0dSopenharmony_ci+ size_t add_index = 0; 8989be168c0dSopenharmony_ci+ ret = opt::GetMatchNodeIndex(graph, matched_path, std::string(kAddName), &add_index); 8990be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, 
"cannot get add_index"); 8991be168c0dSopenharmony_ci+ auto &add_node = graph->nodes.at(add_index); 8992be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(add_node != nullptr, RET_NULL_PTR, "add_node is nullptr"); 8993be168c0dSopenharmony_ci+ 8994be168c0dSopenharmony_ci+ if (matmul_node1->quantType == schema::QuantType_QUANT_ALL || 8995be168c0dSopenharmony_ci+ matmul_node1->quantType == schema::QuantType_QUANT_DYNAMIC || 8996be168c0dSopenharmony_ci+ matmul_node2->quantType == schema::QuantType_QUANT_ALL || 8997be168c0dSopenharmony_ci+ matmul_node2->quantType == schema::QuantType_QUANT_DYNAMIC || 8998be168c0dSopenharmony_ci+ add_node->quantType == schema::QuantType_QUANT_ALL || add_node->quantType == schema::QuantType_QUANT_DYNAMIC) { 8999be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "cannot fusion with quant node"; 9000be168c0dSopenharmony_ci+ return RET_NO_CHANGE; 9001be168c0dSopenharmony_ci+ } 9002be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_node1->primitive != nullptr, RET_NULL_PTR); 9003be168c0dSopenharmony_ci+ auto matmul_type1 = matmul_node1->primitive->value.AsMatMulFusion()->activation_type; 9004be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_node2->primitive != nullptr, RET_NULL_PTR); 9005be168c0dSopenharmony_ci+ auto matmul_type2 = matmul_node2->primitive->value.AsMatMulFusion()->activation_type; 9006be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(add_node->primitive != nullptr, RET_NULL_PTR); 9007be168c0dSopenharmony_ci+ auto add_type = add_node->primitive->value.AsAddFusion()->activation_type; 9008be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_type1 == ActivationType::ActivationType_NO_ACTIVATION && 9009be168c0dSopenharmony_ci+ matmul_type2 == ActivationType::ActivationType_NO_ACTIVATION && 9010be168c0dSopenharmony_ci+ add_type == ActivationType::ActivationType_NO_ACTIVATION, 9011be168c0dSopenharmony_ci+ RET_NO_CHANGE); 9012be168c0dSopenharmony_ci+ 9013be168c0dSopenharmony_ci+ if (matmul_node1->inputIndex.at(FIRST_INPUT) != 
matmul_node2->inputIndex.at(FIRST_INPUT)) { 9014be168c0dSopenharmony_ci+ MS_LOG(INFO) << "matmul should have the same first input"; 9015be168c0dSopenharmony_ci+ return RET_NO_CHANGE; 9016be168c0dSopenharmony_ci+ } 9017be168c0dSopenharmony_ci+ auto &matmul_left_b = graph->allTensors[matmul_node1->inputIndex.at(SECOND_INPUT)]; 9018be168c0dSopenharmony_ci+ auto &matmul_right_b = graph->allTensors[matmul_node2->inputIndex.at(SECOND_INPUT)]; 9019be168c0dSopenharmony_ci+ if (matmul_left_b->data.empty() || matmul_right_b->data.empty()) { 9020be168c0dSopenharmony_ci+ return RET_NO_CHANGE; 9021be168c0dSopenharmony_ci+ } 9022be168c0dSopenharmony_ci+ if (CalNewMatMulNode(graph, matmul_node1, matmul_node2) != RET_OK) { 9023be168c0dSopenharmony_ci+ MS_LOG(INFO) << "failed to fusion two matmul"; 9024be168c0dSopenharmony_ci+ return RET_NO_CHANGE; 9025be168c0dSopenharmony_ci+ } 9026be168c0dSopenharmony_ci+ 9027be168c0dSopenharmony_ci+ matmul_node1->outputIndex = {add_node->outputIndex}; 9028be168c0dSopenharmony_ci+ // cannot delete node here, otherwise will destroy order in other pattern's node index 9029be168c0dSopenharmony_ci+ // make it an isolated node to be removed in IsolatedNodeRemovePass 9030be168c0dSopenharmony_ci+ matmul_node2->inputIndex.clear(); 9031be168c0dSopenharmony_ci+ matmul_node2->outputIndex.clear(); 9032be168c0dSopenharmony_ci+ add_node->inputIndex.clear(); 9033be168c0dSopenharmony_ci+ add_node->outputIndex.clear(); 9034be168c0dSopenharmony_ci+ return RET_OK; 9035be168c0dSopenharmony_ci+} 9036be168c0dSopenharmony_ci+ 9037be168c0dSopenharmony_ci+MatMulMatMulAddFusionPass::~MatMulMatMulAddFusionPass() = default; 9038be168c0dSopenharmony_ci+} // namespace lite 9039be168c0dSopenharmony_ci+} // namespace mindspore 9040be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h b/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h 9041be168c0dSopenharmony_cinew file mode 100644 
9042be168c0dSopenharmony_ciindex 00000000..9ee6d711 9043be168c0dSopenharmony_ci--- /dev/null 9044be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h 9045be168c0dSopenharmony_ci@@ -0,0 +1,43 @@ 9046be168c0dSopenharmony_ci+/** 9047be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 9048be168c0dSopenharmony_ci+ * 9049be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 9050be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 9051be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 9052be168c0dSopenharmony_ci+ * 9053be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 9054be168c0dSopenharmony_ci+ * 9055be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 9056be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 9057be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9058be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 9059be168c0dSopenharmony_ci+ * limitations under the License. 
9060be168c0dSopenharmony_ci+ */ 9061be168c0dSopenharmony_ci+ 9062be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_MATMUL_ADD_FUSION_PASS_H_ 9063be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_MATMUL_ADD_FUSION_PASS_H_ 9064be168c0dSopenharmony_ci+ 9065be168c0dSopenharmony_ci+#include <string> 9066be168c0dSopenharmony_ci+#include <unordered_map> 9067be168c0dSopenharmony_ci+#include <memory> 9068be168c0dSopenharmony_ci+#include <algorithm> 9069be168c0dSopenharmony_ci+#include <utility> 9070be168c0dSopenharmony_ci+#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" 9071be168c0dSopenharmony_ci+ 9072be168c0dSopenharmony_ci+namespace mindspore { 9073be168c0dSopenharmony_ci+namespace lite { 9074be168c0dSopenharmony_ci+class MatMulMatMulAddFusionPass : public FusionPass { 9075be168c0dSopenharmony_ci+ public: 9076be168c0dSopenharmony_ci+ MatMulMatMulAddFusionPass() = default; 9077be168c0dSopenharmony_ci+ 9078be168c0dSopenharmony_ci+ ~MatMulMatMulAddFusionPass() override; 9079be168c0dSopenharmony_ci+ 9080be168c0dSopenharmony_ci+ STATUS DefinePattern() override; 9081be168c0dSopenharmony_ci+ 9082be168c0dSopenharmony_ci+ STATUS DoFusion(MetaGraphT *graph, const std::string &pattern_name, 9083be168c0dSopenharmony_ci+ const std::unordered_map<std::string, std::shared_ptr<Path>> &matched_path) override; 9084be168c0dSopenharmony_ci+}; 9085be168c0dSopenharmony_ci+} // namespace lite 9086be168c0dSopenharmony_ci+} // namespace mindspore 9087be168c0dSopenharmony_ci+ 9088be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_FUSION_MATMUL_MATMUL_ADD_FUSION_PASS_H_ 9089be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/train_export.cc b/mindspore/lite/src/train/train_export.cc 9090be168c0dSopenharmony_ciindex 7534ed2f..5bace006 100644 9091be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/train_export.cc 9092be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/train_export.cc 
9093be168c0dSopenharmony_ci@@ -151,11 +151,18 @@ int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tenso 9094be168c0dSopenharmony_ci return RET_OK; 9095be168c0dSopenharmony_ci } 9096be168c0dSopenharmony_ci 9097be168c0dSopenharmony_ci-std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite::Tensor *tensor, 9098be168c0dSopenharmony_ci- schema::Tensor *scTensor, int preferred_dim, 9099be168c0dSopenharmony_ci- const int tensor_quant_type) { 9100be168c0dSopenharmony_ci+std::unique_ptr<schema::TensorT> TrainExport::CreateTensor( 9101be168c0dSopenharmony_ci+ const mindspore::lite::Tensor *tensor, const std::vector<mindspore::lite::Tensor *> const_folded_output, 9102be168c0dSopenharmony_ci+ schema::Tensor *scTensor, int preferred_dim, const int tensor_quant_type) { 9103be168c0dSopenharmony_ci auto tensorT = std::make_unique<schema::TensorT>(); 9104be168c0dSopenharmony_ci- tensorT->nodeType = scTensor->nodeType(); 9105be168c0dSopenharmony_ci+ bool const_fold = false; 9106be168c0dSopenharmony_ci+ if (quant_type_ == QT_NONE && !const_folded_output.empty() && 9107be168c0dSopenharmony_ci+ std::find(const_folded_output.begin(), const_folded_output.end(), tensor) != const_folded_output.end()) { 9108be168c0dSopenharmony_ci+ tensorT->nodeType = NodeType_ValueNode; 9109be168c0dSopenharmony_ci+ const_fold = true; 9110be168c0dSopenharmony_ci+ } else { 9111be168c0dSopenharmony_ci+ tensorT->nodeType = scTensor->nodeType(); 9112be168c0dSopenharmony_ci+ } 9113be168c0dSopenharmony_ci tensorT->dims = tensor->shape(); 9114be168c0dSopenharmony_ci tensorT->format = static_cast<schema::Format>(tensor->format()); 9115be168c0dSopenharmony_ci tensorT->name = tensor->tensor_name(); 9116be168c0dSopenharmony_ci@@ -163,7 +170,8 @@ std::unique_ptr<schema::TensorT> TrainExport::CreateTensor(const mindspore::lite 9117be168c0dSopenharmony_ci tensorT->offset = 0; 9118be168c0dSopenharmony_ci tensorT->dataType = tensor->data_type(); 
9119be168c0dSopenharmony_ci tensorT->enableHuffmanCode = false; 9120be168c0dSopenharmony_ci- if ((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) { 9121be168c0dSopenharmony_ci+ if (((tensorT->nodeType == NodeType_ValueNode) && (scTensor->data() != nullptr) && (scTensor->data()->size() > 0)) || 9122be168c0dSopenharmony_ci+ const_fold) { 9123be168c0dSopenharmony_ci if (NeedQuantization(tensor, tensor_quant_type)) { 9124be168c0dSopenharmony_ci auto ret = QuantTensorData(tensorT.get(), tensor, preferred_dim); 9125be168c0dSopenharmony_ci if (ret != RET_OK) { 9126be168c0dSopenharmony_ci@@ -392,6 +400,7 @@ int TrainExport::KeepGraphInputsInOrder(const Model *model) { 9127be168c0dSopenharmony_ci return RET_OK; 9128be168c0dSopenharmony_ci } 9129be168c0dSopenharmony_ci int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::lite::Tensor *> &tensors, int offset, 9130be168c0dSopenharmony_ci+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9131be168c0dSopenharmony_ci const std::vector<std::pair<size_t, tensor_info>> &map_index, 9132be168c0dSopenharmony_ci const std::vector<std::string> &output_names, const std::set<size_t> &out_set) { 9133be168c0dSopenharmony_ci std::vector<mindspore::lite::Tensor *> in_tensors; 9134be168c0dSopenharmony_ci@@ -401,6 +410,7 @@ int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::l 9135be168c0dSopenharmony_ci mindspore::lite::Tensor *tensor = tensors.at(pid); 9136be168c0dSopenharmony_ci in_tensors.push_back(tensor); 9137be168c0dSopenharmony_ci } 9138be168c0dSopenharmony_ci+ std::map<std::string, uint32_t> ordered_output_names; 9139be168c0dSopenharmony_ci for (auto index : map_index) { 9140be168c0dSopenharmony_ci auto id = index.first; 9141be168c0dSopenharmony_ci size_t pid = id - static_cast<size_t>(offset); 9142be168c0dSopenharmony_ci@@ -408,7 +418,8 @@ int TrainExport::ExportTensor(const Model *model, const 
std::vector<mindspore::l 9143be168c0dSopenharmony_ci schema::Tensor *scTensor = model->graph_.all_tensors_.at(pid); 9144be168c0dSopenharmony_ci auto preferred_dim = WeightDecoder::GetPreferredDim(in_tensors, index.second.op_parameter, index.second.input_index, 9145be168c0dSopenharmony_ci tensor->shape(), model->graph_.version_); 9146be168c0dSopenharmony_ci- auto tensorT = CreateTensor(tensor, scTensor, preferred_dim, index.second.op_parameter->quant_type_); 9147be168c0dSopenharmony_ci+ auto tensorT = 9148be168c0dSopenharmony_ci+ CreateTensor(tensor, const_folded_output, scTensor, preferred_dim, index.second.op_parameter->quant_type_); 9149be168c0dSopenharmony_ci if (tensorT == nullptr) { 9150be168c0dSopenharmony_ci MS_LOG(ERROR) << "error in tensor creation"; 9151be168c0dSopenharmony_ci return RET_ERROR; 9152be168c0dSopenharmony_ci@@ -423,21 +434,27 @@ int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::l 9153be168c0dSopenharmony_ci } 9154be168c0dSopenharmony_ci // find output tensor 9155be168c0dSopenharmony_ci if (std::find(output_names.begin(), output_names.end(), tensor->tensor_name()) != output_names.end()) { 9156be168c0dSopenharmony_ci- meta_graph_->outputIndex.push_back(remap_[id]); 9157be168c0dSopenharmony_ci- if (!meta_graph_->subGraph.empty()) { 9158be168c0dSopenharmony_ci- meta_graph_->subGraph[0]->outputIndices.push_back(remap_[id]); 9159be168c0dSopenharmony_ci- } 9160be168c0dSopenharmony_ci+ ordered_output_names[tensor->tensor_name()] = remap_[id]; 9161be168c0dSopenharmony_ci } 9162be168c0dSopenharmony_ci meta_graph_->allTensors.emplace_back(std::move(tensorT)); 9163be168c0dSopenharmony_ci if (!meta_graph_->subGraph.empty()) { 9164be168c0dSopenharmony_ci meta_graph_->subGraph[0]->tensorIndices.push_back(meta_graph_->allTensors.size() - 1); 9165be168c0dSopenharmony_ci } 9166be168c0dSopenharmony_ci } 9167be168c0dSopenharmony_ci+ for (auto &output_name : output_names) { 9168be168c0dSopenharmony_ci+ if 
(ordered_output_names.find(output_name) != ordered_output_names.end()) { 9169be168c0dSopenharmony_ci+ meta_graph_->outputIndex.push_back(ordered_output_names[output_name]); 9170be168c0dSopenharmony_ci+ if (!meta_graph_->subGraph.empty()) { 9171be168c0dSopenharmony_ci+ meta_graph_->subGraph[0]->outputIndices.push_back(ordered_output_names[output_name]); 9172be168c0dSopenharmony_ci+ } 9173be168c0dSopenharmony_ci+ } 9174be168c0dSopenharmony_ci+ } 9175be168c0dSopenharmony_ci return RET_OK; 9176be168c0dSopenharmony_ci } 9177be168c0dSopenharmony_ci 9178be168c0dSopenharmony_ci int TrainExport::ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels, 9179be168c0dSopenharmony_ci const std::vector<mindspore::lite::Tensor *> &tensors, 9180be168c0dSopenharmony_ci+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9181be168c0dSopenharmony_ci const std::vector<std::string> &output_names, const Model *model, 9182be168c0dSopenharmony_ci QuantizationType quant_type, const Model *bb_model) { 9183be168c0dSopenharmony_ci std::vector<std::pair<size_t, tensor_info>> map_index; 9184be168c0dSopenharmony_ci@@ -498,7 +515,7 @@ int TrainExport::ExportNet(const std::vector<mindspore::kernel::KernelExec *> &k 9185be168c0dSopenharmony_ci } 9186be168c0dSopenharmony_ci } 9187be168c0dSopenharmony_ci 9188be168c0dSopenharmony_ci- auto status = ExportTensor(model, tensors, offset, map_index, output_names, out_set); 9189be168c0dSopenharmony_ci+ auto status = ExportTensor(model, tensors, offset, const_folded_output, map_index, output_names, out_set); 9190be168c0dSopenharmony_ci if (status != RET_OK) { 9191be168c0dSopenharmony_ci MS_LOG(ERROR) << "ExportTensor failed."; 9192be168c0dSopenharmony_ci return RET_ERROR; 9193be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/train_export.h b/mindspore/lite/src/train/train_export.h 9194be168c0dSopenharmony_ciindex b44f6526..8428c9b9 100644 9195be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/train_export.h 
9196be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/train_export.h 9197be168c0dSopenharmony_ci@@ -47,8 +47,10 @@ class TrainExport { 9198be168c0dSopenharmony_ci explicit TrainExport(Buffer *model_buffer) : model_buffer_(model_buffer) {} 9199be168c0dSopenharmony_ci virtual ~TrainExport(); 9200be168c0dSopenharmony_ci int ExportNet(const std::vector<mindspore::kernel::KernelExec *> &kernels, 9201be168c0dSopenharmony_ci- const std::vector<mindspore::lite::Tensor *> &tensors, const std::vector<std::string> &output_names, 9202be168c0dSopenharmony_ci- const Model *model, QuantizationType quant_type, const Model *bb_model = nullptr); 9203be168c0dSopenharmony_ci+ const std::vector<mindspore::lite::Tensor *> &tensors, 9204be168c0dSopenharmony_ci+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9205be168c0dSopenharmony_ci+ const std::vector<std::string> &output_names, const Model *model, QuantizationType quant_type, 9206be168c0dSopenharmony_ci+ const Model *bb_model = nullptr); 9207be168c0dSopenharmony_ci int ExportInit(const std::string model_name, std::string version); 9208be168c0dSopenharmony_ci int SaveToFile(); 9209be168c0dSopenharmony_ci int SaveToBuffer(); 9210be168c0dSopenharmony_ci@@ -75,7 +77,9 @@ class TrainExport { 9211be168c0dSopenharmony_ci int TopologicalSort(); 9212be168c0dSopenharmony_ci void PrepareRemap(int offset); 9213be168c0dSopenharmony_ci LiteGraph::Node *FindNode(const mindspore::kernel::KernelExec *kernel, const Model *model); 9214be168c0dSopenharmony_ci- std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, schema::Tensor *scTensor, int preferred_dim, 9215be168c0dSopenharmony_ci+ std::unique_ptr<schema::TensorT> CreateTensor(const Tensor *tensor, 9216be168c0dSopenharmony_ci+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9217be168c0dSopenharmony_ci+ schema::Tensor *scTensor, int preferred_dim, 9218be168c0dSopenharmony_ci const int tensor_quant_type); 9219be168c0dSopenharmony_ci 
std::unique_ptr<schema::CNodeT> CreateCNode(const mindspore::kernel::KernelExec *kernel, 9220be168c0dSopenharmony_ci std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex, 9221be168c0dSopenharmony_ci@@ -93,6 +97,7 @@ class TrainExport { 9222be168c0dSopenharmony_ci size_t *target_index); 9223be168c0dSopenharmony_ci int KeepGraphInputsInOrder(const Model *model); 9224be168c0dSopenharmony_ci int ExportTensor(const Model *model, const std::vector<mindspore::lite::Tensor *> &tensors, int offset, 9225be168c0dSopenharmony_ci+ const std::vector<mindspore::lite::Tensor *> const_folded_output, 9226be168c0dSopenharmony_ci const std::vector<std::pair<size_t, tensor_info>> &map_index, 9227be168c0dSopenharmony_ci const std::vector<std::string> &output_names, const std::set<size_t> &out_set); 9228be168c0dSopenharmony_ci virtual int QuantTensorData(schema::TensorT *dest_tensor, const mindspore::lite::Tensor *src_tensor, 9229be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc 9230be168c0dSopenharmony_ciindex b581b389..c123cba8 100644 9231be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/train_session.cc 9232be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/train_session.cc 9233be168c0dSopenharmony_ci@@ -399,6 +399,8 @@ int TrainSession::CompileTrainGraph(std::shared_ptr<Model> model) { 9234be168c0dSopenharmony_ci MS_LOG(ERROR) << "failed to allocate space"; 9235be168c0dSopenharmony_ci return RET_ERROR; 9236be168c0dSopenharmony_ci } 9237be168c0dSopenharmony_ci+ // Prepare a list of kernels which are const folded 9238be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(CompileConstFoldedKernels() == RET_OK, RET_ERROR, "CompileConstFoldedKernels failed."); 9239be168c0dSopenharmony_ci return RET_OK; 9240be168c0dSopenharmony_ci } 9241be168c0dSopenharmony_ci 9242be168c0dSopenharmony_ci@@ -697,20 +699,30 @@ void TrainSession::CompileEvalOutputs() { 9243be168c0dSopenharmony_ci } 9244be168c0dSopenharmony_ci if 
(is_loss) continue; 9245be168c0dSopenharmony_ci // insert if not already in 9246be168c0dSopenharmony_ci- if (eval_output_node_map_.find(in_kernel->name()) == eval_output_node_map_.end()) { 9247be168c0dSopenharmony_ci- auto *ms_tensor = in_kernel->out_tensors().at(0); 9248be168c0dSopenharmony_ci- if (ms_tensor != nullptr) { 9249be168c0dSopenharmony_ci- ms_tensor->set_init_ref_count(ms_tensor->init_ref_count() + 1); 9250be168c0dSopenharmony_ci- eval_output_node_map_[in_kernel->name()].emplace_back(ms_tensor); 9251be168c0dSopenharmony_ci- auto index = TSFindTensor(tensors_, ms_tensor); 9252be168c0dSopenharmony_ci- if (index != tensors_.size()) { 9253be168c0dSopenharmony_ci- if (!ms_tensor->tensor_name().empty()) { 9254be168c0dSopenharmony_ci- eval_output_tensor_map_.insert(std::make_pair(ms_tensor->tensor_name(), ms_tensor)); 9255be168c0dSopenharmony_ci- eval_output_tensor_names_.emplace_back(ms_tensor->tensor_name()); 9256be168c0dSopenharmony_ci- } else { 9257be168c0dSopenharmony_ci- eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), ms_tensor)); 9258be168c0dSopenharmony_ci- eval_output_tensor_names_.emplace_back(std::to_string(index)); 9259be168c0dSopenharmony_ci- } 9260be168c0dSopenharmony_ci+ auto out_tensors = TSFindTensors(in_kernel, kernel); 9261be168c0dSopenharmony_ci+ if (eval_output_node_map_.find(in_kernel->name()) != eval_output_node_map_.end()) { 9262be168c0dSopenharmony_ci+ auto exist_out_tensors = eval_output_node_map_[in_kernel->name()]; 9263be168c0dSopenharmony_ci+ std::vector<Tensor *> all_out_tensors; 9264be168c0dSopenharmony_ci+ auto kernel_all_out_tensors = in_kernel->out_tensors(); 9265be168c0dSopenharmony_ci+ eval_output_node_map_[in_kernel->name()] = {}; 9266be168c0dSopenharmony_ci+ for (auto tensor : kernel_all_out_tensors) { 9267be168c0dSopenharmony_ci+ if (std::find(out_tensors.begin(), out_tensors.end(), tensor) != out_tensors.end() || 9268be168c0dSopenharmony_ci+ std::find(exist_out_tensors.begin(), 
exist_out_tensors.end(), tensor) != exist_out_tensors.end()) { 9269be168c0dSopenharmony_ci+ eval_output_node_map_[in_kernel->name()].emplace_back(tensor); 9270be168c0dSopenharmony_ci+ } 9271be168c0dSopenharmony_ci+ } 9272be168c0dSopenharmony_ci+ } else { 9273be168c0dSopenharmony_ci+ eval_output_node_map_[in_kernel->name()] = out_tensors; 9274be168c0dSopenharmony_ci+ } 9275be168c0dSopenharmony_ci+ for (auto out_tensor : out_tensors) { 9276be168c0dSopenharmony_ci+ auto index = TSFindTensor(tensors_, out_tensor); 9277be168c0dSopenharmony_ci+ if (index != tensors_.size()) { 9278be168c0dSopenharmony_ci+ if (!out_tensor->tensor_name().empty()) { 9279be168c0dSopenharmony_ci+ eval_output_tensor_map_.insert(std::make_pair(out_tensor->tensor_name(), out_tensor)); 9280be168c0dSopenharmony_ci+ eval_output_tensor_names_.emplace_back(out_tensor->tensor_name()); 9281be168c0dSopenharmony_ci+ } else { 9282be168c0dSopenharmony_ci+ eval_output_tensor_map_.insert(std::make_pair(std::to_string(index), out_tensor)); 9283be168c0dSopenharmony_ci+ eval_output_tensor_names_.emplace_back(std::to_string(index)); 9284be168c0dSopenharmony_ci } 9285be168c0dSopenharmony_ci } 9286be168c0dSopenharmony_ci } 9287be168c0dSopenharmony_ci@@ -863,6 +875,35 @@ void TrainSession::CompileOptimizedKernels() { 9288be168c0dSopenharmony_ci } 9289be168c0dSopenharmony_ci } 9290be168c0dSopenharmony_ci 9291be168c0dSopenharmony_ci+int TrainSession::CompileConstFoldedKernels() { 9292be168c0dSopenharmony_ci+ const_output_tensors_.clear(); 9293be168c0dSopenharmony_ci+ for (auto kernel : this->inference_kernels_) { 9294be168c0dSopenharmony_ci+ bool is_input_const = true; 9295be168c0dSopenharmony_ci+ for (auto input : kernel->in_tensors()) { 9296be168c0dSopenharmony_ci+ if ((!input->IsConst() || input->IsGraphInput()) && 9297be168c0dSopenharmony_ci+ std::find(const_output_tensors_.begin(), const_output_tensors_.end(), input) == const_output_tensors_.end()) { 9298be168c0dSopenharmony_ci+ is_input_const = false; 
9299be168c0dSopenharmony_ci+ } 9300be168c0dSopenharmony_ci+ if (!is_input_const) { 9301be168c0dSopenharmony_ci+ const_fold_kernels_.emplace_back(kernel); 9302be168c0dSopenharmony_ci+ break; 9303be168c0dSopenharmony_ci+ } 9304be168c0dSopenharmony_ci+ } 9305be168c0dSopenharmony_ci+ if (is_input_const) { 9306be168c0dSopenharmony_ci+ auto ret = kernel->Execute(); 9307be168c0dSopenharmony_ci+ if (RET_OK != ret) { 9308be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name(); 9309be168c0dSopenharmony_ci+ return ret; 9310be168c0dSopenharmony_ci+ } 9311be168c0dSopenharmony_ci+ for (auto output : kernel->out_tensors()) { 9312be168c0dSopenharmony_ci+ const_output_tensors_.emplace_back(output); 9313be168c0dSopenharmony_ci+ output->set_category(Category::CONST_TENSOR); 9314be168c0dSopenharmony_ci+ } 9315be168c0dSopenharmony_ci+ } 9316be168c0dSopenharmony_ci+ } 9317be168c0dSopenharmony_ci+ return RET_OK; 9318be168c0dSopenharmony_ci+} 9319be168c0dSopenharmony_ci+ 9320be168c0dSopenharmony_ci void TrainSession::CompileTrainableParams() { 9321be168c0dSopenharmony_ci for (auto kernel : this->train_kernels_) { 9322be168c0dSopenharmony_ci if (!IsOptimizer(kernel)) { 9323be168c0dSopenharmony_ci@@ -1214,9 +1255,10 @@ int TrainSession::ExportByDifferentType(DestType destination, ModelType model_ty 9324be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export"); 9325be168c0dSopenharmony_ci if (!output_tensor_name.empty() && model_type == MT_INFERENCE) { 9326be168c0dSopenharmony_ci std::vector<kernel::KernelExec *> export_kernels = {}; 9327be168c0dSopenharmony_ci- status = FindExportKernels(&export_kernels, output_tensor_name, inference_kernels_); 9328be168c0dSopenharmony_ci+ status = FindExportKernels(&export_kernels, output_tensor_name, const_fold_kernels_); 9329be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed."); 9330be168c0dSopenharmony_ci- status = 
texport.ExportNet(export_kernels, tensors_, output_tensor_name, model_.get(), quant_type); 9331be168c0dSopenharmony_ci+ status = 9332be168c0dSopenharmony_ci+ texport.ExportNet(export_kernels, tensors_, const_output_tensors_, output_tensor_name, model_.get(), quant_type); 9333be168c0dSopenharmony_ci } else { 9334be168c0dSopenharmony_ci if (!output_tensor_name.empty() && model_type == MT_TRAIN) { 9335be168c0dSopenharmony_ci MS_LOG(WARNING) << "Train model does not support to export selected output tensor, and all of the train kernels " 9336be168c0dSopenharmony_ci@@ -1234,9 +1276,15 @@ int TrainSession::ExportByDifferentType(DestType destination, ModelType model_ty 9337be168c0dSopenharmony_ci } 9338be168c0dSopenharmony_ci return status; 9339be168c0dSopenharmony_ci } else { 9340be168c0dSopenharmony_ci- status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_, 9341be168c0dSopenharmony_ci- (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, 9342be168c0dSopenharmony_ci- model_.get(), quant_type); 9343be168c0dSopenharmony_ci+ if (quant_type == QT_NONE) { 9344be168c0dSopenharmony_ci+ status = texport.ExportNet( 9345be168c0dSopenharmony_ci+ (model_type == MT_TRAIN) ? train_kernels_ : const_fold_kernels_, tensors_, const_output_tensors_, 9346be168c0dSopenharmony_ci+ (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, model_.get(), quant_type); 9347be168c0dSopenharmony_ci+ } else { 9348be168c0dSopenharmony_ci+ status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_, {}, 9349be168c0dSopenharmony_ci+ (model_type == MT_TRAIN) ? 
train_output_tensor_names_ : eval_output_tensor_names_, 9350be168c0dSopenharmony_ci+ model_.get(), quant_type); 9351be168c0dSopenharmony_ci+ } 9352be168c0dSopenharmony_ci } 9353be168c0dSopenharmony_ci } 9354be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network."); 9355be168c0dSopenharmony_ci@@ -1322,14 +1370,13 @@ int TrainSession::ExportWeightsCollaborateWithMicro(const std::string &file_name 9356be168c0dSopenharmony_ci MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "File name cannot be empty"); 9357be168c0dSopenharmony_ci MS_CHECK_FALSE_MSG(model_type != mindspore::lite::MT_INFERENCE, RET_ERROR, 9358be168c0dSopenharmony_ci "Currently, can only export inference-model's weights."); 9359be168c0dSopenharmony_ci- int status = Eval(); 9360be168c0dSopenharmony_ci- TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed"); 9361be168c0dSopenharmony_ci 9362be168c0dSopenharmony_ci TrainExport texport(file_name); 9363be168c0dSopenharmony_ci- status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_); 9364be168c0dSopenharmony_ci+ auto status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_); 9365be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export"); 9366be168c0dSopenharmony_ci 9367be168c0dSopenharmony_ci- status = texport.ExportNet(inference_kernels_, tensors_, eval_output_tensor_names_, model_.get(), QT_DEFAULT); 9368be168c0dSopenharmony_ci+ status = texport.ExportNet(const_fold_kernels_, tensors_, const_output_tensors_, eval_output_tensor_names_, 9369be168c0dSopenharmony_ci+ model_.get(), QT_NONE); 9370be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network."); 9371be168c0dSopenharmony_ci status = texport.TrainModelDrop(); 9372be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed."); 
9373be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h 9374be168c0dSopenharmony_ciindex 24f10065..0bd14b21 100644 9375be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/train_session.h 9376be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/train_session.h 9377be168c0dSopenharmony_ci@@ -128,6 +128,7 @@ class TrainSession : virtual public lite::LiteSession { 9378be168c0dSopenharmony_ci virtual int CompileInferenceKernels(); 9379be168c0dSopenharmony_ci virtual void CompileOptimizedKernels(); 9380be168c0dSopenharmony_ci virtual void CompileTrainableParams(); 9381be168c0dSopenharmony_ci+ virtual int CompileConstFoldedKernels(); 9382be168c0dSopenharmony_ci virtual void CompileTrainOutputs(); 9383be168c0dSopenharmony_ci virtual void CompileEvalOutputs(); 9384be168c0dSopenharmony_ci virtual int InitCallBack(); 9385be168c0dSopenharmony_ci@@ -146,6 +147,8 @@ class TrainSession : virtual public lite::LiteSession { 9386be168c0dSopenharmony_ci 9387be168c0dSopenharmony_ci std::vector<kernel::KernelExec *> inference_kernels_; 9388be168c0dSopenharmony_ci std::vector<kernel::KernelExec *> train_kernels_; 9389be168c0dSopenharmony_ci+ std::vector<kernel::KernelExec *> const_fold_kernels_; 9390be168c0dSopenharmony_ci+ std::vector<lite::Tensor *> const_output_tensors_; 9391be168c0dSopenharmony_ci TrainCfg cfg_; 9392be168c0dSopenharmony_ci 9393be168c0dSopenharmony_ci private: 9394be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/train_utils.cc b/mindspore/lite/src/train/train_utils.cc 9395be168c0dSopenharmony_ciindex 32c4a502..cb7b669a 100644 9396be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/train_utils.cc 9397be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/train_utils.cc 9398be168c0dSopenharmony_ci@@ -204,5 +204,20 @@ int ScaleTensor(Tensor *tensor, float scale) { 9399be168c0dSopenharmony_ci MS_LOG(DEBUG) << "Scale tensor: " << tensor->tensor_name() << " " << scale; 
9400be168c0dSopenharmony_ci return tensor->Scale<float>(scale); 9401be168c0dSopenharmony_ci } 9402be168c0dSopenharmony_ci+ 9403be168c0dSopenharmony_ci+std::vector<Tensor *> TSFindTensors(const kernel::KernelExec *pre_kernel, const kernel::KernelExec *post_kernel) { 9404be168c0dSopenharmony_ci+ MS_ASSERT(pre_kernel != nullptr); 9405be168c0dSopenharmony_ci+ MS_ASSERT(post_kernel != nullptr); 9406be168c0dSopenharmony_ci+ auto out_tensors = pre_kernel->out_tensors(); 9407be168c0dSopenharmony_ci+ auto in_tensors = post_kernel->in_tensors(); 9408be168c0dSopenharmony_ci+ std::vector<Tensor *> res; 9409be168c0dSopenharmony_ci+ for (auto tensor : out_tensors) { 9410be168c0dSopenharmony_ci+ if (std::find(in_tensors.begin(), in_tensors.end(), tensor) == in_tensors.end()) { 9411be168c0dSopenharmony_ci+ continue; 9412be168c0dSopenharmony_ci+ } 9413be168c0dSopenharmony_ci+ res.push_back(tensor); 9414be168c0dSopenharmony_ci+ } 9415be168c0dSopenharmony_ci+ return res; 9416be168c0dSopenharmony_ci+} 9417be168c0dSopenharmony_ci } // namespace lite 9418be168c0dSopenharmony_ci } // namespace mindspore 9419be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/train_utils.h b/mindspore/lite/src/train/train_utils.h 9420be168c0dSopenharmony_ciindex 5c85738f..9b2d62dc 100644 9421be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/train_utils.h 9422be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/train_utils.h 9423be168c0dSopenharmony_ci@@ -36,6 +36,7 @@ float CalculateSparseClassification(lite::Tensor *input, lite::Tensor *output); 9424be168c0dSopenharmony_ci float CalculateOneHotClassification(lite::Tensor *input, lite::Tensor *output); 9425be168c0dSopenharmony_ci Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16); 9426be168c0dSopenharmony_ci int ScaleTensor(Tensor *tensor, float scale); 9427be168c0dSopenharmony_ci+std::vector<Tensor *> TSFindTensors(const kernel::KernelExec *pre_kernel, const kernel::KernelExec *post_kernel); 
9428be168c0dSopenharmony_ci } // namespace lite 9429be168c0dSopenharmony_ci } // namespace mindspore 9430be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_UTILS_H_ 9431be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc 9432be168c0dSopenharmony_ciindex 48191b4f..b1cb7b3e 100644 9433be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/transfer_session.cc 9434be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/transfer_session.cc 9435be168c0dSopenharmony_ci@@ -230,10 +230,10 @@ int TransferSession::ExportInner(DestType destination, ModelType model_type, Qua 9436be168c0dSopenharmony_ci MS_LOG(ERROR) << "FindExportKernels failed."; 9437be168c0dSopenharmony_ci return RET_ERROR; 9438be168c0dSopenharmony_ci } 9439be168c0dSopenharmony_ci- status = texport.ExportNet(export_kernels, tensors_, out_put_tensor_name, model_.get(), quant_type, 9440be168c0dSopenharmony_ci+ status = texport.ExportNet(export_kernels, tensors_, {}, out_put_tensor_name, model_.get(), quant_type, 9441be168c0dSopenharmony_ci backbone_session_->model_); 9442be168c0dSopenharmony_ci } else { 9443be168c0dSopenharmony_ci- status = texport.ExportNet(inference_kernels_, tensors_, GetOutputTensorNames(), model_.get(), quant_type, 9444be168c0dSopenharmony_ci+ status = texport.ExportNet(inference_kernels_, tensors_, {}, GetOutputTensorNames(), model_.get(), quant_type, 9445be168c0dSopenharmony_ci backbone_session_->model_); 9446be168c0dSopenharmony_ci } 9447be168c0dSopenharmony_ci if (status != RET_OK) { 9448be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/common/string_util.cc b/mindspore/lite/tools/common/string_util.cc 9449be168c0dSopenharmony_ciindex 8d7076e5..13cddb3a 100644 9450be168c0dSopenharmony_ci--- a/mindspore/lite/tools/common/string_util.cc 9451be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/common/string_util.cc 9452be168c0dSopenharmony_ci@@ -199,5 +199,9 @@ size_t Hex2ByteArray(const std::string 
&hex_str, unsigned char *byte_array, size 9453be168c0dSopenharmony_ci } 9454be168c0dSopenharmony_ci return byte_len; 9455be168c0dSopenharmony_ci } 9456be168c0dSopenharmony_ci+ 9457be168c0dSopenharmony_ci+bool IsNumber(const std::string &item) { 9458be168c0dSopenharmony_ci+ return std::all_of(item.begin(), item.end(), [](char ch) { return ch >= '0' && ch <= '9'; }); 9459be168c0dSopenharmony_ci+} 9460be168c0dSopenharmony_ci } // namespace lite 9461be168c0dSopenharmony_ci } // namespace mindspore 9462be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/common/string_util.h b/mindspore/lite/tools/common/string_util.h 9463be168c0dSopenharmony_ciindex 0fb9c0b2..95bdd742 100644 9464be168c0dSopenharmony_ci--- a/mindspore/lite/tools/common/string_util.h 9465be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/common/string_util.h 9466be168c0dSopenharmony_ci@@ -45,6 +45,8 @@ bool ConvertBool(std::string str, bool *value); 9467be168c0dSopenharmony_ci bool ConvertDoubleVector(const std::string &str, std::vector<double> *value); 9468be168c0dSopenharmony_ci 9469be168c0dSopenharmony_ci size_t Hex2ByteArray(const std::string &hex_str, unsigned char *byte_array, size_t max_len); 9470be168c0dSopenharmony_ci+ 9471be168c0dSopenharmony_ci+bool IsNumber(const std::string &item); 9472be168c0dSopenharmony_ci } // namespace lite 9473be168c0dSopenharmony_ci } // namespace mindspore 9474be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_TOOLS_COMMON_STRING_UTIL_H_ 9475be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc 9476be168c0dSopenharmony_ciindex c4f84163..b63912fa 100644 9477be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/anf_transform.cc 9478be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/anf_transform.cc 9479be168c0dSopenharmony_ci@@ -135,6 +135,7 @@ 9480be168c0dSopenharmony_ci #include "tools/common/string_util.h" 9481be168c0dSopenharmony_ci #include "src/common/common.h" 
9482be168c0dSopenharmony_ci #include "tools/optimizer/graph/miniaturization_pass.h" 9483be168c0dSopenharmony_ci+#include "tools/optimizer/fusion/tile_matmul_fusion.h" 9484be168c0dSopenharmony_ci 9485be168c0dSopenharmony_ci using std::string; 9486be168c0dSopenharmony_ci namespace mindspore::lite { 9487be168c0dSopenharmony_ci@@ -317,7 +318,8 @@ std::vector<opt::PassPtr> InitFusions(const std::shared_ptr<ConverterPara> ¶ 9488be168c0dSopenharmony_ci std::make_shared<opt::MulActivationFusion>(), 9489be168c0dSopenharmony_ci std::make_shared<opt::AddActivationFusion>(), 9490be168c0dSopenharmony_ci std::make_shared<opt::ExpandDimsReshapeFusion>(), 9491be168c0dSopenharmony_ci- std::make_shared<opt::SqueezeExpandDimsFusion>()}; 9492be168c0dSopenharmony_ci+ std::make_shared<opt::SqueezeExpandDimsFusion>(), 9493be168c0dSopenharmony_ci+ std::make_shared<opt::TileMatMulFusion>()}; 9494be168c0dSopenharmony_ci if (param->optimize_transformer) { 9495be168c0dSopenharmony_ci fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>()); 9496be168c0dSopenharmony_ci fusions.push_back(std::make_shared<opt::EncoderLayerFusion>()); 9497be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9498be168c0dSopenharmony_ciindex 2e7ca749..7b47fb8c 100644 9499be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9500be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9501be168c0dSopenharmony_ci@@ -19,10 +19,10 @@ 9502be168c0dSopenharmony_ci #include "include/errorcode.h" 9503be168c0dSopenharmony_ci #include "src/common/log_adapter.h" 9504be168c0dSopenharmony_ci #include "tools/converter/converter_context.h" 9505be168c0dSopenharmony_ci- 9506be168c0dSopenharmony_ci #include "tools/common/string_util.h" 9507be168c0dSopenharmony_ci #include "src/common/config_infos.h" 
9508be168c0dSopenharmony_ci #include "src/common/common.h" 9509be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 9510be168c0dSopenharmony_ci 9511be168c0dSopenharmony_ci namespace mindspore { 9512be168c0dSopenharmony_ci namespace lite { 9513be168c0dSopenharmony_ci@@ -208,6 +208,75 @@ void SetDynParams(const std::shared_ptr<mindspore::ConverterPara> ¶m, 9514be168c0dSopenharmony_ci } 9515be168c0dSopenharmony_ci } 9516be168c0dSopenharmony_ci 9517be168c0dSopenharmony_ci+int ParseInputShapeTemplate(const std::string &shape_template, std::set<std::string> *dynamic_symbols) { 9518be168c0dSopenharmony_ci+ // the inputs_shape config is like: input1:[d0,d1,3];input2:[4,d0] 9519be168c0dSopenharmony_ci+ auto graph_inputs_shape_vec = SplitStringToVector(shape_template, ';'); 9520be168c0dSopenharmony_ci+ for (const auto &graph_input_shape : graph_inputs_shape_vec) { 9521be168c0dSopenharmony_ci+ auto graph_input_shape_info = SplitStringToVector(graph_input_shape, ':'); 9522be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(graph_input_shape_info.size() == kIndex2, RET_INPUT_PARAM_INVALID, "the inputs_shape is invalid"); 9523be168c0dSopenharmony_ci+ auto input_shape = graph_input_shape_info[1]; 9524be168c0dSopenharmony_ci+ if (input_shape[0] != '[' || input_shape[input_shape.size() - 1] != ']') { 9525be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "the inputs_shape is invalid"; 9526be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9527be168c0dSopenharmony_ci+ } 9528be168c0dSopenharmony_ci+ input_shape = input_shape.substr(1, input_shape.size() - kIndex2); 9529be168c0dSopenharmony_ci+ auto input_shape_vec = SplitStringToVector(input_shape, ','); 9530be168c0dSopenharmony_ci+ for (const auto &shape : input_shape_vec) { 9531be168c0dSopenharmony_ci+ if (!IsNumber(shape)) { 9532be168c0dSopenharmony_ci+ dynamic_symbols->insert(shape); 9533be168c0dSopenharmony_ci+ } 9534be168c0dSopenharmony_ci+ } 9535be168c0dSopenharmony_ci+ } 9536be168c0dSopenharmony_ci+ return RET_OK; 
9537be168c0dSopenharmony_ci+} 9538be168c0dSopenharmony_ci+ 9539be168c0dSopenharmony_ci+int ParseDynmiacDimTemplate(const std::string &dims_template, std::set<std::string> *dynamic_symbols, 9540be168c0dSopenharmony_ci+ MicroParamString *micro_param_string) { 9541be168c0dSopenharmony_ci+ // the dynamic_dim_params config is like: d0:[1,3~6];d1:[1~8] 9542be168c0dSopenharmony_ci+ auto dim_info_vec = SplitStringToVector(dims_template, ';'); 9543be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(dim_info_vec.size() <= kIndex2, RET_NOT_SUPPORT, "currently, only support to set two dynamic dims"); 9544be168c0dSopenharmony_ci+ for (const auto &dim_info : dim_info_vec) { 9545be168c0dSopenharmony_ci+ auto dim_vec = SplitStringToVector(dim_info, ':'); 9546be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(dim_vec.size() == kIndex2, RET_INPUT_PARAM_INVALID, "the dynamic_dim_params is invalid"); 9547be168c0dSopenharmony_ci+ std::string symbol = dim_vec[0]; 9548be168c0dSopenharmony_ci+ if (dynamic_symbols->find(symbol) == dynamic_symbols->end()) { 9549be168c0dSopenharmony_ci+ MS_LOG(ERROR) << symbol << "is invalid, because it's not set in the inputs_shape."; 9550be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9551be168c0dSopenharmony_ci+ } 9552be168c0dSopenharmony_ci+ std::string dim_range = dim_vec[1]; 9553be168c0dSopenharmony_ci+ if (dim_range[0] != '[' || dim_range[dim_range.size() - 1] != ']') { 9554be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "the dynamic_dim_params is invalid"; 9555be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9556be168c0dSopenharmony_ci+ } 9557be168c0dSopenharmony_ci+ dim_range = dim_range.substr(1, dim_range.size() - kIndex2); 9558be168c0dSopenharmony_ci+ auto discrete_vec = SplitStringToVector(dim_range, ','); 9559be168c0dSopenharmony_ci+ for (const auto &dim : discrete_vec) { 9560be168c0dSopenharmony_ci+ auto continuous_dim = SplitStringToVector(dim, '~'); 9561be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(continuous_dim.size() == C1NUM || 
continuous_dim.size() == kIndex2, RET_INPUT_PARAM_INVALID, 9562be168c0dSopenharmony_ci+ "the dynamic_dim_params is invalid"); 9563be168c0dSopenharmony_ci+ if (continuous_dim.size() == C1NUM) { 9564be168c0dSopenharmony_ci+ if (!IsNumber(continuous_dim[0]) || std::stoi(continuous_dim[0]) <= 0) { 9565be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "the dynamic_dim_params range value must be greater than 0"; 9566be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9567be168c0dSopenharmony_ci+ } 9568be168c0dSopenharmony_ci+ micro_param_string->dynamic_symbols_map[symbol] += continuous_dim[0] + ","; 9569be168c0dSopenharmony_ci+ continue; 9570be168c0dSopenharmony_ci+ } 9571be168c0dSopenharmony_ci+ if (!IsNumber(continuous_dim[0]) || std::stoi(continuous_dim[0]) <= 0 || !IsNumber(continuous_dim[1]) || 9572be168c0dSopenharmony_ci+ std::stoi(continuous_dim[1]) <= 0) { 9573be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "the dynamic_dim_params range value must be greater than 0"; 9574be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9575be168c0dSopenharmony_ci+ } 9576be168c0dSopenharmony_ci+ auto start = std::stoi(continuous_dim[0]); 9577be168c0dSopenharmony_ci+ auto end = std::stoi(continuous_dim[1]); 9578be168c0dSopenharmony_ci+ for (auto i = start; i <= end; ++i) { 9579be168c0dSopenharmony_ci+ micro_param_string->dynamic_symbols_map[symbol] += std::to_string(i) + ","; 9580be168c0dSopenharmony_ci+ } 9581be168c0dSopenharmony_ci+ } 9582be168c0dSopenharmony_ci+ } 9583be168c0dSopenharmony_ci+ return RET_OK; 9584be168c0dSopenharmony_ci+} 9585be168c0dSopenharmony_ci+ 9586be168c0dSopenharmony_ci void ConfigFileParser::SetParamByConfigfile(const std::shared_ptr<mindspore::ConverterPara> ¶m, 9587be168c0dSopenharmony_ci const std::map<std::string, std::string> &ascend_map) { 9588be168c0dSopenharmony_ci std::string ascend_string = ""; 9589be168c0dSopenharmony_ci@@ -377,8 +446,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin 
9590be168c0dSopenharmony_ci } 9591be168c0dSopenharmony_ci 9592be168c0dSopenharmony_ci int ConfigFileParser::SetMapData(const std::map<std::string, std::string> &input_map, 9593be168c0dSopenharmony_ci- const std::map<std::string, std::string &> &parse_map, const std::string §ion) { 9594be168c0dSopenharmony_ci+ const std::map<std::string, std::string &> &parse_map, const std::string §ion, 9595be168c0dSopenharmony_ci+ const std::set<std::string> &dynamic_key) { 9596be168c0dSopenharmony_ci for (const auto &map : input_map) { 9597be168c0dSopenharmony_ci+ if (dynamic_key.find(map.first) != dynamic_key.end()) { 9598be168c0dSopenharmony_ci+ continue; 9599be168c0dSopenharmony_ci+ } 9600be168c0dSopenharmony_ci if (parse_map.find(map.first) == parse_map.end()) { 9601be168c0dSopenharmony_ci MS_LOG(ERROR) << "INPUT ILLEGAL: `" << map.first << "` is not supported in " 9602be168c0dSopenharmony_ci << "[" << section << "]"; 9603be168c0dSopenharmony_ci@@ -511,21 +584,34 @@ int ConfigFileParser::ParseAclOptionCfgString(const std::map<std::string, std::m 9604be168c0dSopenharmony_ci } 9605be168c0dSopenharmony_ci 9606be168c0dSopenharmony_ci int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps) { 9607be168c0dSopenharmony_ci- if (maps.find(kMicroParam) != maps.end()) { 9608be168c0dSopenharmony_ci- const auto &map = maps.at(kMicroParam); 9609be168c0dSopenharmony_ci- std::map<std::string, std::string &> parse_map{ 9610be168c0dSopenharmony_ci- {"target", micro_param_string_.target}, 9611be168c0dSopenharmony_ci- {"codegen_mode", micro_param_string_.codegen_mode}, 9612be168c0dSopenharmony_ci- {"debug_mode", micro_param_string_.debug_mode}, 9613be168c0dSopenharmony_ci- {"support_parallel", micro_param_string_.support_parallel}, 9614be168c0dSopenharmony_ci- {"enable_micro", micro_param_string_.enable_micro}, 9615be168c0dSopenharmony_ci- {"save_path", micro_param_string_.save_path}, 9616be168c0dSopenharmony_ci- {"project_name", 
micro_param_string_.project_name}, 9617be168c0dSopenharmony_ci- {"keep_original_weight", micro_param_string_.keep_original_weight}, 9618be168c0dSopenharmony_ci- {"changeable_weights_name", micro_param_string_.changeable_weights_name}}; 9619be168c0dSopenharmony_ci- return SetMapData(map, parse_map, kMicroParam); 9620be168c0dSopenharmony_ci+ if (maps.find(kMicroParam) == maps.end()) { 9621be168c0dSopenharmony_ci+ return RET_OK; 9622be168c0dSopenharmony_ci } 9623be168c0dSopenharmony_ci- return RET_OK; 9624be168c0dSopenharmony_ci+ const auto &map = maps.at(kMicroParam); 9625be168c0dSopenharmony_ci+ const std::string graph_inputs_shape_template = "inputs_shape"; 9626be168c0dSopenharmony_ci+ std::set<std::string> dynamic_symbols; 9627be168c0dSopenharmony_ci+ if (map.find(graph_inputs_shape_template) != map.end()) { 9628be168c0dSopenharmony_ci+ const auto &shape_template = map.at(graph_inputs_shape_template); 9629be168c0dSopenharmony_ci+ ParseInputShapeTemplate(shape_template, &dynamic_symbols); 9630be168c0dSopenharmony_ci+ } 9631be168c0dSopenharmony_ci+ const std::string dynamic_dims = "dynamic_dim_params"; 9632be168c0dSopenharmony_ci+ if (!dynamic_symbols.empty() && map.find(dynamic_dims) != map.end()) { 9633be168c0dSopenharmony_ci+ const auto &dims_template = map.at(dynamic_dims); 9634be168c0dSopenharmony_ci+ ParseDynmiacDimTemplate(dims_template, &dynamic_symbols, µ_param_string_); 9635be168c0dSopenharmony_ci+ } 9636be168c0dSopenharmony_ci+ std::map<std::string, std::string &> parse_map{ 9637be168c0dSopenharmony_ci+ {"target", micro_param_string_.target}, 9638be168c0dSopenharmony_ci+ {"codegen_mode", micro_param_string_.codegen_mode}, 9639be168c0dSopenharmony_ci+ {"debug_mode", micro_param_string_.debug_mode}, 9640be168c0dSopenharmony_ci+ {"support_parallel", micro_param_string_.support_parallel}, 9641be168c0dSopenharmony_ci+ {"enable_micro", micro_param_string_.enable_micro}, 9642be168c0dSopenharmony_ci+ {"save_path", micro_param_string_.save_path}, 
9643be168c0dSopenharmony_ci+ {"project_name", micro_param_string_.project_name}, 9644be168c0dSopenharmony_ci+ {"keep_original_weight", micro_param_string_.keep_original_weight}, 9645be168c0dSopenharmony_ci+ {"changeable_weights_name", micro_param_string_.changeable_weights_name}, 9646be168c0dSopenharmony_ci+ {"inputs_shape", micro_param_string_.inputs_shape}, 9647be168c0dSopenharmony_ci+ {"dynamic_dim_params", micro_param_string_.dynamic_dim_params}}; 9648be168c0dSopenharmony_ci+ return SetMapData(map, parse_map, kMicroParam); 9649be168c0dSopenharmony_ci } 9650be168c0dSopenharmony_ci 9651be168c0dSopenharmony_ci int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) { 9652be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9653be168c0dSopenharmony_ciindex 6997bac8..163782b7 100644 9654be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9655be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9656be168c0dSopenharmony_ci@@ -108,17 +108,20 @@ struct MicroParamString { 9657be168c0dSopenharmony_ci std::string project_name; 9658be168c0dSopenharmony_ci std::string keep_original_weight; 9659be168c0dSopenharmony_ci std::string changeable_weights_name; 9660be168c0dSopenharmony_ci+ std::string inputs_shape; 9661be168c0dSopenharmony_ci+ std::string dynamic_dim_params; 9662be168c0dSopenharmony_ci+ std::map<std::string, std::string> dynamic_symbols_map; 9663be168c0dSopenharmony_ci }; 9664be168c0dSopenharmony_ci 9665be168c0dSopenharmony_ci struct ThirdPartyModelString { 9666be168c0dSopenharmony_ci std::string input_dtypes; 9667be168c0dSopenharmony_ci std::string input_shapes; 9668be168c0dSopenharmony_ci- std::string input_names; // optional, default: "" 9669be168c0dSopenharmony_ci+ std::string input_names; // optional, 
default: "" 9670be168c0dSopenharmony_ci std::string input_formats; // optional, default: NHWC 9671be168c0dSopenharmony_ci std::string output_dtypes; 9672be168c0dSopenharmony_ci std::string output_shapes; 9673be168c0dSopenharmony_ci- std::string output_names; // optional, default: "" 9674be168c0dSopenharmony_ci- std::string output_formats; // optional, default: NHWC 9675be168c0dSopenharmony_ci+ std::string output_names; // optional, default: "" 9676be168c0dSopenharmony_ci+ std::string output_formats; // optional, default: NHWC 9677be168c0dSopenharmony_ci std::string extended_parameters; // format: {key1:value1;ker2:value2} 9678be168c0dSopenharmony_ci }; 9679be168c0dSopenharmony_ci 9680be168c0dSopenharmony_ci@@ -172,7 +175,8 @@ class ConfigFileParser { 9681be168c0dSopenharmony_ci int ParseRegistryInfoString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9682be168c0dSopenharmony_ci int ParseAclOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9683be168c0dSopenharmony_ci int SetMapData(const std::map<std::string, std::string> &input_map, 9684be168c0dSopenharmony_ci- const std::map<std::string, std::string &> &parse_map, const std::string §ion); 9685be168c0dSopenharmony_ci+ const std::map<std::string, std::string &> &parse_map, const std::string §ion, 9686be168c0dSopenharmony_ci+ const std::set<std::string> &dynamic_key = {}); 9687be168c0dSopenharmony_ci int ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9688be168c0dSopenharmony_ci int ParseThirdPartyParamString(const std::map<std::string, std::map<std::string, std::string>> §ions); 9689be168c0dSopenharmony_ci int ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9690be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 
9691be168c0dSopenharmony_ciindex c9998cc8..903f2863 100644 9692be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 9693be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.cc 9694be168c0dSopenharmony_ci@@ -19,6 +19,7 @@ 9695be168c0dSopenharmony_ci #include "tools/common/string_util.h" 9696be168c0dSopenharmony_ci #include "src/common/log_adapter.h" 9697be168c0dSopenharmony_ci #include "src/common/log_util.h" 9698be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 9699be168c0dSopenharmony_ci 9700be168c0dSopenharmony_ci namespace mindspore { 9701be168c0dSopenharmony_ci namespace lite { 9702be168c0dSopenharmony_ci@@ -115,6 +116,80 @@ STATUS MicroParamParser::ParseChangeableWeightsName(const std::string &changeabl 9703be168c0dSopenharmony_ci return RET_OK; 9704be168c0dSopenharmony_ci } 9705be168c0dSopenharmony_ci 9706be168c0dSopenharmony_ci+STATUS MicroParamParser::ParseGraphInputsShapeTemplate(const std::string &graph_inputs_shape_template, 9707be168c0dSopenharmony_ci+ const std::map<std::string, std::string> &dynamic_symbols_map, 9708be168c0dSopenharmony_ci+ micro::MicroParam *micro_param) { 9709be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "Micro record inputs shape: " << graph_inputs_shape_template; 9710be168c0dSopenharmony_ci+ if (!graph_inputs_shape_template.empty()) { 9711be168c0dSopenharmony_ci+ auto graph_inputs_shape_vec = SplitStringToVector(graph_inputs_shape_template, ';'); 9712be168c0dSopenharmony_ci+ std::map<std::string, std::vector<std::string>> graph_inputs_info; 9713be168c0dSopenharmony_ci+ std::vector<std::vector<std::string>> graph_inputs_shape; 9714be168c0dSopenharmony_ci+ std::vector<std::string> inputs_name; 9715be168c0dSopenharmony_ci+ for (const auto &graph_input_shape : graph_inputs_shape_vec) { 9716be168c0dSopenharmony_ci+ auto input_shape_info = SplitStringToVector(graph_input_shape, ':'); 9717be168c0dSopenharmony_ci+ std::string input_name = input_shape_info[0]; 
9718be168c0dSopenharmony_ci+ std::string input_shape = input_shape_info[1].substr(1, input_shape_info[1].size() - C2NUM); 9719be168c0dSopenharmony_ci+ auto input_shape_vec = SplitStringToVector(input_shape, ','); 9720be168c0dSopenharmony_ci+ graph_inputs_info[input_name] = input_shape_vec; 9721be168c0dSopenharmony_ci+ graph_inputs_shape.push_back(input_shape_vec); 9722be168c0dSopenharmony_ci+ inputs_name.push_back(input_name); 9723be168c0dSopenharmony_ci+ } 9724be168c0dSopenharmony_ci+ micro_param->graph_inputs_origin_info = graph_inputs_info; 9725be168c0dSopenharmony_ci+ micro_param->inputs_shape_by_scenes.clear(); 9726be168c0dSopenharmony_ci+ std::map<std::string, std::vector<int>> symbols_to_num; 9727be168c0dSopenharmony_ci+ std::map<std::string, int> symbols_index; 9728be168c0dSopenharmony_ci+ std::vector<std::string> symbols; 9729be168c0dSopenharmony_ci+ std::vector<size_t> scene_num_by_symbol; 9730be168c0dSopenharmony_ci+ int index = 0; 9731be168c0dSopenharmony_ci+ size_t scene_num = 1; 9732be168c0dSopenharmony_ci+ for (const auto &item : dynamic_symbols_map) { 9733be168c0dSopenharmony_ci+ symbols_index[item.first] = index++; 9734be168c0dSopenharmony_ci+ symbols.push_back(item.first); 9735be168c0dSopenharmony_ci+ auto num_str_list = SplitStringToVector(item.second, ','); 9736be168c0dSopenharmony_ci+ for (const auto &num_str : num_str_list) { 9737be168c0dSopenharmony_ci+ symbols_to_num[item.first].push_back(std::stoi(num_str)); 9738be168c0dSopenharmony_ci+ } 9739be168c0dSopenharmony_ci+ if (symbols_to_num[item.first].empty()) { 9740be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Micro param invalid, dynamic symbol must have value."; 9741be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9742be168c0dSopenharmony_ci+ } 9743be168c0dSopenharmony_ci+ scene_num_by_symbol.push_back(symbols_to_num[item.first].size()); 9744be168c0dSopenharmony_ci+ scene_num *= symbols_to_num[item.first].size(); 9745be168c0dSopenharmony_ci+ } 9746be168c0dSopenharmony_ci+ 
micro_param->dynamic_symbols = symbols; 9747be168c0dSopenharmony_ci+ micro_param->dynamic_symbols_num = scene_num_by_symbol; 9748be168c0dSopenharmony_ci+ std::vector<size_t> post_multi(symbols.size(), 1); 9749be168c0dSopenharmony_ci+ for (int i = static_cast<int>(post_multi.size()) - 2; i >= 0; --i) { 9750be168c0dSopenharmony_ci+ post_multi[i] = post_multi[i + 1] * scene_num_by_symbol[i + 1]; 9751be168c0dSopenharmony_ci+ } 9752be168c0dSopenharmony_ci+ std::vector<int> real_num(symbols.size()); 9753be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 9754be168c0dSopenharmony_ci+ size_t remain = i; 9755be168c0dSopenharmony_ci+ for (size_t j = 0; j < symbols.size(); ++j) { 9756be168c0dSopenharmony_ci+ real_num[j] = remain / post_multi[j]; 9757be168c0dSopenharmony_ci+ remain %= post_multi[j]; 9758be168c0dSopenharmony_ci+ } 9759be168c0dSopenharmony_ci+ for (size_t j = 0; j < graph_inputs_shape.size(); ++j) { 9760be168c0dSopenharmony_ci+ const auto &input_template = graph_inputs_shape[j]; 9761be168c0dSopenharmony_ci+ std::vector<int> input_shape; 9762be168c0dSopenharmony_ci+ for (const auto &dim : input_template) { 9763be168c0dSopenharmony_ci+ if (IsNumber(dim)) { 9764be168c0dSopenharmony_ci+ input_shape.push_back(std::stoi(dim)); 9765be168c0dSopenharmony_ci+ continue; 9766be168c0dSopenharmony_ci+ } 9767be168c0dSopenharmony_ci+ if (symbols_index.find(dim) == symbols_index.end()) { 9768be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Dynamic symbol cannot find real num."; 9769be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9770be168c0dSopenharmony_ci+ } 9771be168c0dSopenharmony_ci+ input_shape.push_back(symbols_to_num[dim][real_num[symbols_index[dim]]]); 9772be168c0dSopenharmony_ci+ } 9773be168c0dSopenharmony_ci+ micro_param->inputs_shape_by_scenes[inputs_name[j]].push_back(input_shape); 9774be168c0dSopenharmony_ci+ } 9775be168c0dSopenharmony_ci+ } 9776be168c0dSopenharmony_ci+ } 9777be168c0dSopenharmony_ci+ return RET_OK; 9778be168c0dSopenharmony_ci+} 
9779be168c0dSopenharmony_ci+ 9780be168c0dSopenharmony_ci STATUS MicroParamParser::ParseMicroParam(const MicroParamString µ_param_string, micro::MicroParam *micro_param) { 9781be168c0dSopenharmony_ci CHECK_NULL_RETURN(micro_param); 9782be168c0dSopenharmony_ci if (ParseTarget(micro_param_string.target, micro_param) != RET_OK) { 9783be168c0dSopenharmony_ci@@ -145,9 +220,11 @@ STATUS MicroParamParser::ParseMicroParam(const MicroParamString µ_param_str 9784be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse project name val failed: " << micro_param_string.project_name; 9785be168c0dSopenharmony_ci return RET_INPUT_PARAM_INVALID; 9786be168c0dSopenharmony_ci } 9787be168c0dSopenharmony_ci- if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) { 9788be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Parse keep_original_weight failed, the val: " << micro_param_string.keep_original_weight; 9789be168c0dSopenharmony_ci- return RET_INPUT_PARAM_INVALID; 9790be168c0dSopenharmony_ci+ if (!micro_param_string.keep_original_weight.empty()) { 9791be168c0dSopenharmony_ci+ if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) { 9792be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Parse keep_original_weight val; " << micro_param_string.keep_original_weight; 9793be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9794be168c0dSopenharmony_ci+ } 9795be168c0dSopenharmony_ci } 9796be168c0dSopenharmony_ci if (!micro_param_string.changeable_weights_name.empty() && !micro_param->keep_original_weight) { 9797be168c0dSopenharmony_ci MS_LOG(ERROR) << "When changeable_weights_name is set, the keep_original_weight must be true."; 9798be168c0dSopenharmony_ci@@ -157,6 +234,12 @@ STATUS MicroParamParser::ParseMicroParam(const MicroParamString µ_param_str 9799be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse changeable_weights_name failed, the val: " << micro_param_string.changeable_weights_name; 9800be168c0dSopenharmony_ci return 
RET_INPUT_PARAM_INVALID; 9801be168c0dSopenharmony_ci } 9802be168c0dSopenharmony_ci+ if (ParseGraphInputsShapeTemplate(micro_param_string.inputs_shape, micro_param_string.dynamic_symbols_map, 9803be168c0dSopenharmony_ci+ micro_param) != RET_OK) { 9804be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Parse inputs_shape & dynamic_dim_params failed, the inputs_shape val: " 9805be168c0dSopenharmony_ci+ << micro_param_string.inputs_shape; 9806be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9807be168c0dSopenharmony_ci+ } 9808be168c0dSopenharmony_ci return RET_OK; 9809be168c0dSopenharmony_ci } 9810be168c0dSopenharmony_ci } // namespace lite 9811be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h 9812be168c0dSopenharmony_ciindex b6efb4c7..eb95c571 100644 9813be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/config_parser/micro_param_parser.h 9814be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/micro_param_parser.h 9815be168c0dSopenharmony_ci@@ -37,6 +37,9 @@ class MicroParamParser { 9816be168c0dSopenharmony_ci STATUS ParseProjName(const std::string &debug_mode, micro::MicroParam *micro_param); 9817be168c0dSopenharmony_ci STATUS ParseKeepOriginalWeight(const std::string &keep_weight, micro::MicroParam *micro_param); 9818be168c0dSopenharmony_ci STATUS ParseChangeableWeightsName(const std::string &changeable_weights_name, micro::MicroParam *micro_param); 9819be168c0dSopenharmony_ci+ STATUS ParseGraphInputsShapeTemplate(const std::string &graph_inputs_shape_template, 9820be168c0dSopenharmony_ci+ const std::map<std::string, std::string> &dynamic_symbols_map, 9821be168c0dSopenharmony_ci+ micro::MicroParam *micro_param); 9822be168c0dSopenharmony_ci }; 9823be168c0dSopenharmony_ci } // namespace lite 9824be168c0dSopenharmony_ci } // namespace mindspore 9825be168c0dSopenharmony_cidiff --git 
a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc 9826be168c0dSopenharmony_ciindex a61bd51c..4703e889 100644 9827be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/converter.cc 9828be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/converter.cc 9829be168c0dSopenharmony_ci@@ -56,6 +56,7 @@ 9830be168c0dSopenharmony_ci #include "src/common/file_utils.h" 9831be168c0dSopenharmony_ci #include "ops/dynamic_shape.h" 9832be168c0dSopenharmony_ci #include "tools/common/parse_config_utils.h" 9833be168c0dSopenharmony_ci+#include "src/common/file_utils.h" 9834be168c0dSopenharmony_ci #include "tools/converter/converter_packed_node.h" 9835be168c0dSopenharmony_ci #include "tools/converter/config_parser/cpu_option_param_parser.h" 9836be168c0dSopenharmony_ci #include "tools/converter/export_model.h" 9837be168c0dSopenharmony_ci@@ -432,54 +433,34 @@ int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> ¶m, 9838be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse config param failed."; 9839be168c0dSopenharmony_ci return ret; 9840be168c0dSopenharmony_ci } 9841be168c0dSopenharmony_ci- ret = ParseParam(&config_parser, param, model_param_infos, maps); 9842be168c0dSopenharmony_ci- if (ret != RET_OK) { 9843be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Parse param failed."; 9844be168c0dSopenharmony_ci- return ret; 9845be168c0dSopenharmony_ci- } 9846be168c0dSopenharmony_ci- return RET_OK; 9847be168c0dSopenharmony_ci-} 9848be168c0dSopenharmony_ci- 9849be168c0dSopenharmony_ci-int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std::shared_ptr<ConverterPara> ¶m, 9850be168c0dSopenharmony_ci- const std::map<int, std::map<std::string, std::string>> *model_param_infos, 9851be168c0dSopenharmony_ci- const std::map<std::string, std::map<std::string, std::string>> maps) { 9852be168c0dSopenharmony_ci- param->config_infos = maps; 9853be168c0dSopenharmony_ci- auto ret = RET_OK; 9854be168c0dSopenharmony_ci if 
(model_param_infos->empty()) { 9855be168c0dSopenharmony_ci- ret = 9856be168c0dSopenharmony_ci- lite::PreprocessParser::ParsePreprocess(config_parser->GetDataPreProcessString(), ¶m->dataPreProcessParam); 9857be168c0dSopenharmony_ci+ ret = lite::PreprocessParser::ParsePreprocess(config_parser.GetDataPreProcessString(), ¶m->dataPreProcessParam); 9858be168c0dSopenharmony_ci if (ret != RET_OK) { 9859be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse preprocess failed."; 9860be168c0dSopenharmony_ci return ret; 9861be168c0dSopenharmony_ci } 9862be168c0dSopenharmony_ci- ret = lite::QuantParamParser::ParseCommonQuant(config_parser->GetCommonQuantString(), ¶m->commonQuantParam); 9863be168c0dSopenharmony_ci+ ret = lite::QuantParamParser::ParseCommonQuant(config_parser.GetCommonQuantString(), ¶m->commonQuantParam); 9864be168c0dSopenharmony_ci if (ret != RET_OK) { 9865be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse common quant param failed."; 9866be168c0dSopenharmony_ci return ret; 9867be168c0dSopenharmony_ci } 9868be168c0dSopenharmony_ci- ret = lite::QuantParamParser::ParseFullQuant(config_parser->GetFullQuantString(), ¶m->fullQuantParam); 9869be168c0dSopenharmony_ci+ ret = lite::QuantParamParser::ParseFullQuant(config_parser.GetFullQuantString(), ¶m->fullQuantParam); 9870be168c0dSopenharmony_ci if (ret != RET_OK) { 9871be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse full quant param failed."; 9872be168c0dSopenharmony_ci return ret; 9873be168c0dSopenharmony_ci } 9874be168c0dSopenharmony_ci- ret = lite::QuantParamParser::ParseWeightQuant(config_parser->GetWeightQuantString(), ¶m->weightQuantParam); 9875be168c0dSopenharmony_ci- if (ret != RET_OK) { 9876be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Parse full quant param failed."; 9877be168c0dSopenharmony_ci- return ret; 9878be168c0dSopenharmony_ci- } 9879be168c0dSopenharmony_ci- ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser->GetMixedBitWeightQuantString(), 9880be168c0dSopenharmony_ci+ ret = 
lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(), 9881be168c0dSopenharmony_ci ¶m->mixedBitWeightQuantParam); 9882be168c0dSopenharmony_ci if (ret != RET_OK) { 9883be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse mixed bit weight quant param failed."; 9884be168c0dSopenharmony_ci return ret; 9885be168c0dSopenharmony_ci } 9886be168c0dSopenharmony_ci- ret = lite::ThirdPartyParamParser::Parse(config_parser->GetThirdPartyModelString(), 9887be168c0dSopenharmony_ci- ¶m->thirdPartyModelParam); 9888be168c0dSopenharmony_ci+ ret = lite::ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), ¶m->thirdPartyModelParam); 9889be168c0dSopenharmony_ci if (ret != RET_OK) { 9890be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse third party param failed."; 9891be168c0dSopenharmony_ci return ret; 9892be168c0dSopenharmony_ci } 9893be168c0dSopenharmony_ci- ret = InitExtendedIntegrationInfo(param, *config_parser); 9894be168c0dSopenharmony_ci+ ret = InitExtendedIntegrationInfo(param, config_parser); 9895be168c0dSopenharmony_ci if (ret != RET_OK) { 9896be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse extended integration info failed."; 9897be168c0dSopenharmony_ci return ret; 9898be168c0dSopenharmony_ci@@ -490,7 +471,7 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std:: 9899be168c0dSopenharmony_ci param->aclModelOptionCfgParam.dump_model_name = 9900be168c0dSopenharmony_ci dir_pos != std::string::npos ? 
output_file.substr(dir_pos + 1) : output_file; 9901be168c0dSopenharmony_ci lite::AclOptionParamParser acl_param_parser; 9902be168c0dSopenharmony_ci- ret = acl_param_parser.ParseAclOptionCfg(config_parser->GetAclOptionCfgString(), ¶m->aclModelOptionCfgParam); 9903be168c0dSopenharmony_ci+ ret = acl_param_parser.ParseAclOptionCfg(config_parser.GetAclOptionCfgString(), ¶m->aclModelOptionCfgParam); 9904be168c0dSopenharmony_ci if (ret != RET_OK) { 9905be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse acl option param failed."; 9906be168c0dSopenharmony_ci return ret; 9907be168c0dSopenharmony_ci@@ -498,14 +479,14 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std:: 9908be168c0dSopenharmony_ci // parse ascend_context in config file, the priority is higher 9909be168c0dSopenharmony_ci if (maps.find("ascend_context") != maps.end()) { 9910be168c0dSopenharmony_ci auto map = maps.at("ascend_context"); 9911be168c0dSopenharmony_ci- config_parser->SetParamByConfigfile(param, map); 9912be168c0dSopenharmony_ci+ config_parser.SetParamByConfigfile(param, map); 9913be168c0dSopenharmony_ci } 9914be168c0dSopenharmony_ci if (!param->config_file.empty()) { 9915be168c0dSopenharmony_ci (void)CheckOfflineParallelConfig(param->config_file, ¶m->parallel_split_config); 9916be168c0dSopenharmony_ci } 9917be168c0dSopenharmony_ci 9918be168c0dSopenharmony_ci lite::CpuOptionParamParser cpu_param_parser; 9919be168c0dSopenharmony_ci- ret = cpu_param_parser.ParseCpuOptionCfg(config_parser->GetCpuOptionCfgString(), ¶m->cpuOptionCfgParam); 9920be168c0dSopenharmony_ci+ ret = cpu_param_parser.ParseCpuOptionCfg(config_parser.GetCpuOptionCfgString(), ¶m->cpuOptionCfgParam); 9921be168c0dSopenharmony_ci if (ret != RET_OK) { 9922be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse cpu option param failed."; 9923be168c0dSopenharmony_ci return ret; 9924be168c0dSopenharmony_ci@@ -515,29 +496,29 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std:: 
9925be168c0dSopenharmony_ci << "If there are multi models, only support micro_param and model_param, other configure can not take effect"; 9926be168c0dSopenharmony_ci 9927be168c0dSopenharmony_ci lite::MicroParamParser micro_param_parser; 9928be168c0dSopenharmony_ci- ret = micro_param_parser.ParseMicroParam(config_parser->GetMicroParamString(), ¶m->microParam); 9929be168c0dSopenharmony_ci+ ret = micro_param_parser.ParseMicroParam(config_parser.GetMicroParamString(), ¶m->microParam); 9930be168c0dSopenharmony_ci if (ret != RET_OK) { 9931be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse micro param failed."; 9932be168c0dSopenharmony_ci return ret; 9933be168c0dSopenharmony_ci } 9934be168c0dSopenharmony_ci ret = 9935be168c0dSopenharmony_ci- lite::QuantParamParser::ParseTransformQuant(config_parser->GetTransformQuantString(), ¶m->transformQuantParam); 9936be168c0dSopenharmony_ci+ lite::QuantParamParser::ParseTransformQuant(config_parser.GetTransformQuantString(), ¶m->transformQuantParam); 9937be168c0dSopenharmony_ci if (ret != RET_OK) { 9938be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse transform quant param failed."; 9939be168c0dSopenharmony_ci return ret; 9940be168c0dSopenharmony_ci } 9941be168c0dSopenharmony_ci- ret = lite::QuantParamParser::ParseAscendQuant(config_parser->GetAscendQuantString(), ¶m->ascendQuantParam); 9942be168c0dSopenharmony_ci+ ret = lite::QuantParamParser::ParseAscendQuant(config_parser.GetAscendQuantString(), ¶m->ascendQuantParam); 9943be168c0dSopenharmony_ci if (ret != RET_OK) { 9944be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse ascend quant param failed."; 9945be168c0dSopenharmony_ci return ret; 9946be168c0dSopenharmony_ci } 9947be168c0dSopenharmony_ci- ret = lite::QuantParamParser::ParseDynamicQuant(config_parser->GetDynamicQuantString(), ¶m->dynamicQuantParam); 9948be168c0dSopenharmony_ci+ ret = lite::QuantParamParser::ParseDynamicQuant(config_parser.GetDynamicQuantString(), ¶m->dynamicQuantParam); 9949be168c0dSopenharmony_ci if (ret != 
RET_OK) { 9950be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse dynamic quant param failed."; 9951be168c0dSopenharmony_ci return ret; 9952be168c0dSopenharmony_ci } 9953be168c0dSopenharmony_ci lite::GraphKernelParamParser graph_kernel_parser; 9954be168c0dSopenharmony_ci- ret = graph_kernel_parser.ParseGraphKernelCfg(config_parser->GetGraphKernelString(), ¶m->graphKernelParam); 9955be168c0dSopenharmony_ci+ ret = graph_kernel_parser.ParseGraphKernelCfg(config_parser.GetGraphKernelString(), ¶m->graphKernelParam); 9956be168c0dSopenharmony_ci if (ret != RET_OK) { 9957be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse graph kernel param failed."; 9958be168c0dSopenharmony_ci return ret; 9959be168c0dSopenharmony_ci@@ -708,9 +689,9 @@ int CheckFmkType(const std::shared_ptr<ConverterPara> ¶m) { 9960be168c0dSopenharmony_ci if (param != nullptr) { 9961be168c0dSopenharmony_ci return RET_OK; 9962be168c0dSopenharmony_ci } 9963be168c0dSopenharmony_ci- std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 9964be168c0dSopenharmony_ci- FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 9965be168c0dSopenharmony_ci- FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty}; 9966be168c0dSopenharmony_ci+ std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 9967be168c0dSopenharmony_ci+ FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 9968be168c0dSopenharmony_ci+ FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty}; 9969be168c0dSopenharmony_ci if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) { 9970be168c0dSopenharmony_ci MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be " 9971be168c0dSopenharmony_ci "TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|MSLITE|THIRDPARTY" 9972be168c0dSopenharmony_ci@@ -1010,7 +991,6 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, void **m 9973be168c0dSopenharmony_ci model_index++; 9974be168c0dSopenharmony_ci } 
9975be168c0dSopenharmony_ci } 9976be168c0dSopenharmony_ci- 9977be168c0dSopenharmony_ci return RET_OK; 9978be168c0dSopenharmony_ci } 9979be168c0dSopenharmony_ci 9980be168c0dSopenharmony_ci@@ -1045,7 +1025,6 @@ int ConverterImpl::HandleGraphCommon(const std::shared_ptr<ConverterPara> ¶m 9981be168c0dSopenharmony_ci MS_LOG(ERROR) << "Save graph failed: " << ret << " " << GetErrorInfo(ret); 9982be168c0dSopenharmony_ci return ret; 9983be168c0dSopenharmony_ci } 9984be168c0dSopenharmony_ci- 9985be168c0dSopenharmony_ci return RET_OK; 9986be168c0dSopenharmony_ci } 9987be168c0dSopenharmony_ci 9988be168c0dSopenharmony_ci@@ -1067,8 +1046,8 @@ int ConverterImpl::ExecuteMicro(const schema::MetaGraphT *meta_graph, const std: 9989be168c0dSopenharmony_ci } 9990be168c0dSopenharmony_ci auto status = 9991be168c0dSopenharmony_ci meta_graph != nullptr 9992be168c0dSopenharmony_ci- ? micro::Coder::MicroSourceCodeGeneration(*meta_graph, output_path, param->microParam, param->weight_fp16) 9993be168c0dSopenharmony_ci- : micro::Coder::MicroSourceCodeGeneration(param->model_file, output_path, param->microParam, param->weight_fp16); 9994be168c0dSopenharmony_ci+ ? 
micro::Coder::MicroSourceCodeGeneration(*meta_graph, output_path, ¶m->microParam, param->weight_fp16) 9995be168c0dSopenharmony_ci+ : micro::Coder::MicroSourceCodeGeneration(param->model_file, output_path, ¶m->microParam, param->weight_fp16); 9996be168c0dSopenharmony_ci if (status != RET_OK) { 9997be168c0dSopenharmony_ci MS_LOG(ERROR) << "Execute Micro failed."; 9998be168c0dSopenharmony_ci } 9999be168c0dSopenharmony_ci@@ -1123,7 +1102,6 @@ int ConverterImpl::SaveGraph(FuncGraphPtr graph, const std::shared_ptr<Converter 10000be168c0dSopenharmony_ci MS_LOG(ERROR) << "Save failed:" << status << " " << GetErrorInfo(status); 10001be168c0dSopenharmony_ci return status; 10002be168c0dSopenharmony_ci } 10003be168c0dSopenharmony_ci- 10004be168c0dSopenharmony_ci return RET_OK; 10005be168c0dSopenharmony_ci } 10006be168c0dSopenharmony_ci 10007be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/import/mindspore_importer.cc b/mindspore/lite/tools/converter/import/mindspore_importer.cc 10008be168c0dSopenharmony_ciindex 1d5afde4..aee0c854 100644 10009be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/import/mindspore_importer.cc 10010be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/import/mindspore_importer.cc 10011be168c0dSopenharmony_ci@@ -39,6 +39,7 @@ 10012be168c0dSopenharmony_ci #include "tools/optimizer/graph/redundant_op_remove_pass.h" 10013be168c0dSopenharmony_ci #include "nnacl/op_base.h" 10014be168c0dSopenharmony_ci #include "src/common/common.h" 10015be168c0dSopenharmony_ci+#include "tools/converter/import/to_custom_op_pass.h" 10016be168c0dSopenharmony_ci 10017be168c0dSopenharmony_ci namespace mindspore::lite { 10018be168c0dSopenharmony_ci namespace { 10019be168c0dSopenharmony_ci@@ -89,6 +90,13 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, 10020be168c0dSopenharmony_ci ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); 10021be168c0dSopenharmony_ci return RET_ERROR; 10022be168c0dSopenharmony_ci 
} 10023be168c0dSopenharmony_ci+ auto to_custom_op_pass = std::make_shared<mindspore::opt::ToCustomOpPass>(); 10024be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(to_custom_op_pass != nullptr, RET_NULL_PTR, "to_custom_op_pass is nullptr."); 10025be168c0dSopenharmony_ci+ if (!to_custom_op_pass->Run(func_graph)) { 10026be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "To custom op pass run failed!"; 10027be168c0dSopenharmony_ci+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); 10028be168c0dSopenharmony_ci+ return RET_ERROR; 10029be168c0dSopenharmony_ci+ } 10030be168c0dSopenharmony_ci return RET_OK; 10031be168c0dSopenharmony_ci } 10032be168c0dSopenharmony_ci 10033be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/import/to_custom_op_pass.cc b/mindspore/lite/tools/converter/import/to_custom_op_pass.cc 10034be168c0dSopenharmony_cinew file mode 100644 10035be168c0dSopenharmony_ciindex 00000000..55e524e6 10036be168c0dSopenharmony_ci--- /dev/null 10037be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/import/to_custom_op_pass.cc 10038be168c0dSopenharmony_ci@@ -0,0 +1,86 @@ 10039be168c0dSopenharmony_ci+/** 10040be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 10041be168c0dSopenharmony_ci+ * 10042be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 10043be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 10044be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 10045be168c0dSopenharmony_ci+ * 10046be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 10047be168c0dSopenharmony_ci+ * 10048be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 10049be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 10050be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10051be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 10052be168c0dSopenharmony_ci+ * limitations under the License. 10053be168c0dSopenharmony_ci+ */ 10054be168c0dSopenharmony_ci+ 10055be168c0dSopenharmony_ci+#include "tools/converter/import/to_custom_op_pass.h" 10056be168c0dSopenharmony_ci+#include "ops/grad/gather_d_grad_v2.h" 10057be168c0dSopenharmony_ci+#include "ops/masked_fill.h" 10058be168c0dSopenharmony_ci+#include "ops/custom.h" 10059be168c0dSopenharmony_ci+#include "ops/op_utils.h" 10060be168c0dSopenharmony_ci+#include "mindspore/ccsrc/include/common/utils/utils.h" 10061be168c0dSopenharmony_ci+#include "nnacl/custom_gather_d_grad_v2_parameter.h" 10062be168c0dSopenharmony_ci+ 10063be168c0dSopenharmony_ci+using mindspore::ops::kNameGatherDGradV2; 10064be168c0dSopenharmony_ci+using mindspore::ops::kNameMaskedFill; 10065be168c0dSopenharmony_ci+ 10066be168c0dSopenharmony_ci+namespace mindspore { 10067be168c0dSopenharmony_ci+namespace opt { 10068be168c0dSopenharmony_ci+bool ToCustomOpPass::Run(const FuncGraphPtr &graph) { 10069be168c0dSopenharmony_ci+ MS_ASSERT(graph != nullptr); 10070be168c0dSopenharmony_ci+ auto manager = graph->manager(); 10071be168c0dSopenharmony_ci+ MS_ASSERT(manager != nullptr); 10072be168c0dSopenharmony_ci+ auto node_list = TopoSort(graph->get_return()); 10073be168c0dSopenharmony_ci+ 10074be168c0dSopenharmony_ci+ for (auto &node : node_list) { 10075be168c0dSopenharmony_ci+ if (!utils::isa<CNodePtr>(node)) { 10076be168c0dSopenharmony_ci+ continue; 10077be168c0dSopenharmony_ci+ } 10078be168c0dSopenharmony_ci+ auto cnode = node->cast<CNodePtr>(); 10079be168c0dSopenharmony_ci+ MS_ASSERT(cnode != nullptr); 10080be168c0dSopenharmony_ci+ auto value_node = cnode->input(0); 10081be168c0dSopenharmony_ci+ auto prim = GetValueNode<PrimitivePtr>(value_node); 10082be168c0dSopenharmony_ci+ if (prim == nullptr) { 10083be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "this is a call cnode, which input[0] is fg."; 
10084be168c0dSopenharmony_ci+ continue; 10085be168c0dSopenharmony_ci+ } 10086be168c0dSopenharmony_ci+ 10087be168c0dSopenharmony_ci+ auto func = ToCustomOpRegistry::GetInstance()->GetToCustomOpFunc(prim->name()); 10088be168c0dSopenharmony_ci+ if (func == nullptr) { 10089be168c0dSopenharmony_ci+ continue; 10090be168c0dSopenharmony_ci+ } 10091be168c0dSopenharmony_ci+ 10092be168c0dSopenharmony_ci+ auto ret = func(cnode); 10093be168c0dSopenharmony_ci+ if (ret != RET_OK) { 10094be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "failed to convert normal cnode node to custom cnode"; 10095be168c0dSopenharmony_ci+ return false; 10096be168c0dSopenharmony_ci+ } 10097be168c0dSopenharmony_ci+ } 10098be168c0dSopenharmony_ci+ return true; 10099be168c0dSopenharmony_ci+} 10100be168c0dSopenharmony_ci+ 10101be168c0dSopenharmony_ci+int GatherDGradV2ToCustomOp(const CNodePtr &cnode) { 10102be168c0dSopenharmony_ci+ auto ori_prim = ops::GetOperator<ops::GatherDGradV2>(cnode->input(kAnfPrimitiveIndex)); 10103be168c0dSopenharmony_ci+ auto dim = ori_prim->get_dim(); 10104be168c0dSopenharmony_ci+ auto dim_str = std::to_string(dim); 10105be168c0dSopenharmony_ci+ std::map<std::string, std::vector<uint8_t>> attrs; 10106be168c0dSopenharmony_ci+ attrs["dim"] = std::vector<uint8_t>(dim_str.begin(), dim_str.end()); 10107be168c0dSopenharmony_ci+ auto custom_prim = std::make_shared<mindspore::ops::Custom>(); 10108be168c0dSopenharmony_ci+ custom_prim->set_type(kNameGatherDGradV2); 10109be168c0dSopenharmony_ci+ cnode->set_input(kAnfPrimitiveIndex, NewValueNode(custom_prim->GetPrim())); 10110be168c0dSopenharmony_ci+ custom_prim->set_attr(attrs); 10111be168c0dSopenharmony_ci+ return RET_OK; 10112be168c0dSopenharmony_ci+} 10113be168c0dSopenharmony_ci+ 10114be168c0dSopenharmony_ci+int MaskedFillToCustomOp(const CNodePtr &cnode) { 10115be168c0dSopenharmony_ci+ auto custom_prim = std::make_shared<mindspore::ops::Custom>(); 10116be168c0dSopenharmony_ci+ custom_prim->set_type(kNameMaskedFill); 
10117be168c0dSopenharmony_ci+ cnode->set_input(kAnfPrimitiveIndex, NewValueNode(custom_prim->GetPrim())); 10118be168c0dSopenharmony_ci+ return RET_OK; 10119be168c0dSopenharmony_ci+} 10120be168c0dSopenharmony_ci+ 10121be168c0dSopenharmony_ci+REGISTER_TO_CUSTOM_OP(kNameGatherDGradV2, GatherDGradV2ToCustomOp); 10122be168c0dSopenharmony_ci+REGISTER_TO_CUSTOM_OP(kNameMaskedFill, MaskedFillToCustomOp); 10123be168c0dSopenharmony_ci+} // namespace opt 10124be168c0dSopenharmony_ci+} // namespace mindspore 10125be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/import/to_custom_op_pass.h b/mindspore/lite/tools/converter/import/to_custom_op_pass.h 10126be168c0dSopenharmony_cinew file mode 100644 10127be168c0dSopenharmony_ciindex 00000000..7108e48b 10128be168c0dSopenharmony_ci--- /dev/null 10129be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/import/to_custom_op_pass.h 10130be168c0dSopenharmony_ci@@ -0,0 +1,68 @@ 10131be168c0dSopenharmony_ci+/** 10132be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 10133be168c0dSopenharmony_ci+ * 10134be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 10135be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 10136be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 10137be168c0dSopenharmony_ci+ * 10138be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 10139be168c0dSopenharmony_ci+ * 10140be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 10141be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 10142be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10143be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 10144be168c0dSopenharmony_ci+ * limitations under the License. 
10145be168c0dSopenharmony_ci+ */ 10146be168c0dSopenharmony_ci+ 10147be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_TO_CUSTOM_OP_PASS_H_ 10148be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_TO_CUSTOM_OP_PASS_H_ 10149be168c0dSopenharmony_ci+#include <string> 10150be168c0dSopenharmony_ci+#include "backend/common/optimizer/pass.h" 10151be168c0dSopenharmony_ci+#include "tools/optimizer/common/gllo_utils.h" 10152be168c0dSopenharmony_ci+ 10153be168c0dSopenharmony_ci+namespace mindspore { 10154be168c0dSopenharmony_ci+namespace opt { 10155be168c0dSopenharmony_ci+ 10156be168c0dSopenharmony_ci+typedef int (*ToCustomOpFunc)(const CNodePtr &cnode); 10157be168c0dSopenharmony_ci+class ToCustomOpRegistry { 10158be168c0dSopenharmony_ci+ public: 10159be168c0dSopenharmony_ci+ static ToCustomOpRegistry *GetInstance() { 10160be168c0dSopenharmony_ci+ static ToCustomOpRegistry registry; 10161be168c0dSopenharmony_ci+ return ®istry; 10162be168c0dSopenharmony_ci+ } 10163be168c0dSopenharmony_ci+ 10164be168c0dSopenharmony_ci+ void InsertToCustomOpMap(const std::string &key, ToCustomOpFunc creator) { to_custom_op_funcs_[key] = creator; } 10165be168c0dSopenharmony_ci+ 10166be168c0dSopenharmony_ci+ ToCustomOpFunc GetToCustomOpFunc(const std::string &key) { 10167be168c0dSopenharmony_ci+ if (to_custom_op_funcs_.find(key) != to_custom_op_funcs_.end()) { 10168be168c0dSopenharmony_ci+ return to_custom_op_funcs_[key]; 10169be168c0dSopenharmony_ci+ } else { 10170be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "Unsupported primitive type : " << key; 10171be168c0dSopenharmony_ci+ return nullptr; 10172be168c0dSopenharmony_ci+ } 10173be168c0dSopenharmony_ci+ } 10174be168c0dSopenharmony_ci+ 10175be168c0dSopenharmony_ci+ protected: 10176be168c0dSopenharmony_ci+ std::map<std::string, ToCustomOpFunc> to_custom_op_funcs_; 10177be168c0dSopenharmony_ci+}; 10178be168c0dSopenharmony_ci+ 10179be168c0dSopenharmony_ci+class RegistryToCustomOp { 10180be168c0dSopenharmony_ci+ 
public: 10181be168c0dSopenharmony_ci+ RegistryToCustomOp(const std::string &key, ToCustomOpFunc creator) { 10182be168c0dSopenharmony_ci+ ToCustomOpRegistry::GetInstance()->InsertToCustomOpMap(key, creator); 10183be168c0dSopenharmony_ci+ } 10184be168c0dSopenharmony_ci+ virtual ~RegistryToCustomOp() = default; 10185be168c0dSopenharmony_ci+}; 10186be168c0dSopenharmony_ci+ 10187be168c0dSopenharmony_ci+#define REGISTER_TO_CUSTOM_OP(type, to_custom_op_func) \ 10188be168c0dSopenharmony_ci+ RegistryToCustomOp g_##type##_to_custom_op(type, to_custom_op_func); 10189be168c0dSopenharmony_ci+ 10190be168c0dSopenharmony_ci+class ToCustomOpPass : public Pass { 10191be168c0dSopenharmony_ci+ public: 10192be168c0dSopenharmony_ci+ ToCustomOpPass() : Pass("ToCustomOpPass") {} 10193be168c0dSopenharmony_ci+ ~ToCustomOpPass() = default; 10194be168c0dSopenharmony_ci+ bool Run(const FuncGraphPtr &graph) override; 10195be168c0dSopenharmony_ci+}; 10196be168c0dSopenharmony_ci+} // namespace opt 10197be168c0dSopenharmony_ci+} // namespace mindspore 10198be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_IMPORT_TO_CUSTOM_OP_PASS_H_ 10199be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc 10200be168c0dSopenharmony_ciindex 8ea838cf..a551196d 100644 10201be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc 10202be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc 10203be168c0dSopenharmony_ci@@ -287,7 +287,6 @@ bool FusionPass::MatchTree(const schema::MetaGraphT &graph, size_t nodeIdx, cons 10204be168c0dSopenharmony_ci bool FusionPass::CheckMatchParams(const schema::MetaGraphT &graph, size_t nodeIdx, 10205be168c0dSopenharmony_ci const std::shared_ptr<PatternOp> &target, const std::vector<size_t> &sinkIdes, 10206be168c0dSopenharmony_ci const std::vector<size_t> 
&pathSinkIdes) { 10207be168c0dSopenharmony_ci- MS_ASSERT(target != nullptr); 10208be168c0dSopenharmony_ci MS_ASSERT(nodeIdx < graph.nodes.size()); 10209be168c0dSopenharmony_ci auto &scope = graph.nodes.at(nodeIdx); 10210be168c0dSopenharmony_ci MS_CHECK_TRUE_MSG(scope != nullptr, false, "Node in graph is nullptr"); 10211be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc 10212be168c0dSopenharmony_ciindex 371e93fb..ff99f1f4 100644 10213be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc 10214be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc 10215be168c0dSopenharmony_ci@@ -660,7 +660,9 @@ int InferShapePass::InitSearchTensor(const int64_t &subgraph_index, MetaGraphT * 10216be168c0dSopenharmony_ci } 10217be168c0dSopenharmony_ci auto &subgraph = graph->subGraph.at(subgraph_index); 10218be168c0dSopenharmony_ci for (uint32_t i = 0; i < tensors_.size(); i++) { 10219be168c0dSopenharmony_ci- if (IsContain(subgraph->inputIndices, i) || !graph->allTensors.at(i)->data.empty()) { 10220be168c0dSopenharmony_ci+ if (IsContain(subgraph->inputIndices, i) || !graph->allTensors.at(i)->data.empty() || 10221be168c0dSopenharmony_ci+ (graph->allTensors.at(i)->nodeType == NodeType_ValueNode && graph->allTensors.at(i)->dims.size() == 1 && 10222be168c0dSopenharmony_ci+ graph->allTensors.at(i)->dims[0] == 0)) { 10223be168c0dSopenharmony_ci tensors_[i].is_inferred_ = true; 10224be168c0dSopenharmony_ci } 10225be168c0dSopenharmony_ci } 10226be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake 10227be168c0dSopenharmony_ciindex c132460e..5dcf0bb7 100644 10228be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/cmake/file_list.cmake 
10229be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/cmake/file_list.cmake 10230be168c0dSopenharmony_ci@@ -4,6 +4,8 @@ set(CODER_SRC 10231be168c0dSopenharmony_ci ${MICRO_DIR}/coder/context.cc 10232be168c0dSopenharmony_ci ${MICRO_DIR}/coder/graph.cc 10233be168c0dSopenharmony_ci ${MICRO_DIR}/coder/session.cc 10234be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/shape_info_container.cc 10235be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/dynamic_mem_manager.cc 10236be168c0dSopenharmony_ci ${MICRO_DIR}/coder/utils/coder_utils.cc 10237be168c0dSopenharmony_ci ${MICRO_DIR}/coder/utils/dir_utils.cc 10238be168c0dSopenharmony_ci ${MICRO_DIR}/coder/utils/train_utils.cc 10239be168c0dSopenharmony_ci@@ -23,6 +25,7 @@ set(CODER_ALLOCATOR_SRC 10240be168c0dSopenharmony_ci set(CODER_GENERATOR_SRC 10241be168c0dSopenharmony_ci ${MICRO_DIR}/coder/generator/generator.cc 10242be168c0dSopenharmony_ci ${MICRO_DIR}/coder/generator/inference/inference_generator.cc 10243be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/generator/component/allocator_component.cc 10244be168c0dSopenharmony_ci ${MICRO_DIR}/coder/generator/component/common_component.cc 10245be168c0dSopenharmony_ci ${MICRO_DIR}/coder/generator/component/weight_component.cc 10246be168c0dSopenharmony_ci ${MICRO_DIR}/coder/generator/component/allocator_component.cc 10247be168c0dSopenharmony_ci@@ -66,6 +69,8 @@ set(CODER_OPCODERS_SRC 10248be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/base/stack_base_coder.cc 10249be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/base/unstack_base_coder.cc 10250be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/base/strided_slice_base_coder.cc 10251be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/base/reshape_dynamic_base_coder.cc 10252be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/base/strided_slice_dynamic_base_coder.cc 10253be168c0dSopenharmony_ci #### cmsis int8 coder 10254be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/cmsis-nn/int8/add_int8_coder.cc 
10255be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/cmsis-nn/int8/conv2d_base_coder.cc 10256be168c0dSopenharmony_ci@@ -81,23 +86,37 @@ set(CODER_OPCODERS_SRC 10257be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/arithmetic_fp16_coder.cc 10258be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/avg_pooling_fp16_coder.cc 10259be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/concat_fp16_coder.cc 10260be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc 10261be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc 10262be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc 10263be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc 10264be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc 10265be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc 10266be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc 10267be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/custom_gru_fp16_coder.cc 10268be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/deconv2d_fp16_coder.cc 10269be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc 10270be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 10271be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_coder.cc 10272be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc 10273be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/transpose_fp16_coder.cc 10274be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/layernorm_fp16_coder.cc 10275be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp16/reduce_fp16_coder.cc 10276be168c0dSopenharmony_ci 
${MICRO_DIR}/coder/opcoders/nnacl/fp16/resize_fp16_coder.cc 10277be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 10278be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_fp16_coder.cc 10279be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv2d_delegate_fp16_coder.cc 10280be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_fp16_coder.cc 10281be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_fp16_coder.cc 10282be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_winograd_fp16_coder.cc 10283be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_1x1_fp16_coder.cc 10284be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_3x3_fp16_coder.cc 10285be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv_depthwise_sw_fp16_coder.cc 10286be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp16/lstm_fp16_coder.cc 10287be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc 10288be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc 10289be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.cc 10290be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc 10291be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc 10292be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc 10293be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc 10294be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc 10295be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc 10296be168c0dSopenharmony_ci+ 
${MICRO_DIR}/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc 10297be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc 10298be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc 10299be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc 10300be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc 10301be168c0dSopenharmony_ci #### nnacl fp32 coder 10302be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/activation_fp32_coder.cc 10303be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/addn_fp32_coder.cc 10304be168c0dSopenharmony_ci@@ -122,6 +141,7 @@ set(CODER_OPCODERS_SRC 10305be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/lstm_fp32_coder.cc 10306be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/matmul_fp32_base_coder.cc 10307be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/matmul_fp32_coder.cc 10308be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc 10309be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/pad_fp32_coder.cc 10310be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/pooling_fp32_coder.cc 10311be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/power_fp32_coder.cc 10312be168c0dSopenharmony_ci@@ -133,17 +153,14 @@ set(CODER_OPCODERS_SRC 10313be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_fp32_coder.cc 10314be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/splice_fp32_coder.cc 10315be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/exp_fp32_coder.cc 10316be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc 10317be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/deconv2d_fp32_coder.cc 10318be168c0dSopenharmony_ci 
${MICRO_DIR}/coder/opcoders/nnacl/fp32/prelu_fp32_coder.cc 10319be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/fp32/layernorm_fp32_coder.cc 10320be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp32/ones_like_fp32_coder.cc 10321be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp32/fill_fp32_coder.cc 10322be168c0dSopenharmony_ci- #### nnacl fp32_grad coder 10323be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/activation_grad_coder.cc 10324be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/adam_coder.cc 10325be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/assign_coder.cc 10326be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/biasadd_grad_coder.cc 10327be168c0dSopenharmony_ci- ${MICRO_DIR}/coder/opcoders/nnacl/fp32_grad/softmax_cross_entropy_with_logits_coder.cc 10328be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc 10329be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.cc 10330be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc 10331be168c0dSopenharmony_ci+ ${MICRO_DIR}/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc 10332be168c0dSopenharmony_ci #### nnacl int8 coder 10333be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/int8/activation_int8_coder.cc 10334be168c0dSopenharmony_ci ${MICRO_DIR}/coder/opcoders/nnacl/int8/affine_int8_coder.cc 10335be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/coder.cc b/mindspore/lite/tools/converter/micro/coder/coder.cc 10336be168c0dSopenharmony_ciindex cc224ae5..a502500d 100644 10337be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/coder.cc 10338be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/coder.cc 10339be168c0dSopenharmony_ci@@ -42,6 +42,34 @@ std::shared_ptr<CoderSession> CreateCoderSession() { 
10340be168c0dSopenharmony_ci } 10341be168c0dSopenharmony_ci return session; 10342be168c0dSopenharmony_ci } 10343be168c0dSopenharmony_ci+ 10344be168c0dSopenharmony_ci+int ParseMicroDynamicShape(const schema::MetaGraphT &graph, micro::MicroParam *micro_param) { 10345be168c0dSopenharmony_ci+ for (auto index : graph.inputIndex) { 10346be168c0dSopenharmony_ci+ auto input_name = graph.allTensors.at(index)->name; 10347be168c0dSopenharmony_ci+ if (micro_param->graph_inputs_origin_info.find(input_name) == micro_param->graph_inputs_origin_info.end() || 10348be168c0dSopenharmony_ci+ micro_param->inputs_shape_by_scenes.find(input_name) == micro_param->inputs_shape_by_scenes.end()) { 10349be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Micro param: dynamic inputs name is invalid"; 10350be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 10351be168c0dSopenharmony_ci+ } 10352be168c0dSopenharmony_ci+ micro_param->graph_inputs_template.emplace_back(micro_param->graph_inputs_origin_info[input_name]); 10353be168c0dSopenharmony_ci+ micro_param->graph_inputs_shape_infos.emplace_back(micro_param->inputs_shape_by_scenes[input_name]); 10354be168c0dSopenharmony_ci+ } 10355be168c0dSopenharmony_ci+ return RET_OK; 10356be168c0dSopenharmony_ci+} 10357be168c0dSopenharmony_ci+ 10358be168c0dSopenharmony_ci+int ParseMicroDynamicShape(const Model &model, micro::MicroParam *micro_param) { 10359be168c0dSopenharmony_ci+ for (auto index : model.graph_.input_indices_) { 10360be168c0dSopenharmony_ci+ auto input_name = model.graph_.all_tensors_.at(index)->name()->str(); 10361be168c0dSopenharmony_ci+ if (micro_param->graph_inputs_origin_info.find(input_name) == micro_param->graph_inputs_origin_info.end() || 10362be168c0dSopenharmony_ci+ micro_param->inputs_shape_by_scenes.find(input_name) == micro_param->inputs_shape_by_scenes.end()) { 10363be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Micro param: dynamic inputs name is invalid"; 10364be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 
10365be168c0dSopenharmony_ci+ } 10366be168c0dSopenharmony_ci+ micro_param->graph_inputs_template.emplace_back(micro_param->graph_inputs_origin_info[input_name]); 10367be168c0dSopenharmony_ci+ micro_param->graph_inputs_shape_infos.emplace_back(micro_param->inputs_shape_by_scenes[input_name]); 10368be168c0dSopenharmony_ci+ } 10369be168c0dSopenharmony_ci+ return RET_OK; 10370be168c0dSopenharmony_ci+} 10371be168c0dSopenharmony_ci } // namespace 10372be168c0dSopenharmony_ci int Coder::Run(const void *model_buff, size_t size, const std::string &model_name, bool end_flag, bool enable_fp16) { 10373be168c0dSopenharmony_ci session_ = CreateCoderSession(); 10374be168c0dSopenharmony_ci@@ -109,29 +137,37 @@ bool Coder::InitPath(const std::string &output_path) { 10375be168c0dSopenharmony_ci return true; 10376be168c0dSopenharmony_ci } 10377be168c0dSopenharmony_ci 10378be168c0dSopenharmony_ci-int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path, 10379be168c0dSopenharmony_ci- const MicroParam ¶m, bool enable_fp16) { 10380be168c0dSopenharmony_ci+int Coder::MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path, MicroParam *param, 10381be168c0dSopenharmony_ci+ bool enable_fp16) { 10382be168c0dSopenharmony_ci flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); 10383be168c0dSopenharmony_ci auto offset = schema::MetaGraph::Pack(builder, &graph); 10384be168c0dSopenharmony_ci builder.Finish(offset); 10385be168c0dSopenharmony_ci schema::FinishMetaGraphBuffer(builder, offset); 10386be168c0dSopenharmony_ci size_t size = builder.GetSize(); 10387be168c0dSopenharmony_ci- if (ExecuteMicroGeneration(builder.GetBufferPointer(), size, output_path, param, enable_fp16) != RET_OK) { 10388be168c0dSopenharmony_ci+ if (!param->dynamic_symbols.empty()) { 10389be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ParseMicroDynamicShape(graph, param) == RET_OK, RET_ERROR, "ParseMicroDynamicShape failed."); 
10390be168c0dSopenharmony_ci+ } 10391be168c0dSopenharmony_ci+ if (ExecuteMicroGeneration(builder.GetBufferPointer(), size, output_path, *param, enable_fp16) != RET_OK) { 10392be168c0dSopenharmony_ci MS_LOG(ERROR) << "Execute Micro failed."; 10393be168c0dSopenharmony_ci return RET_ERROR; 10394be168c0dSopenharmony_ci } 10395be168c0dSopenharmony_ci return RET_OK; 10396be168c0dSopenharmony_ci } 10397be168c0dSopenharmony_ci 10398be168c0dSopenharmony_ci-int Coder::MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, 10399be168c0dSopenharmony_ci- const MicroParam ¶m, bool enable_fp16) { 10400be168c0dSopenharmony_ci+int Coder::MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, MicroParam *param, 10401be168c0dSopenharmony_ci+ bool enable_fp16) { 10402be168c0dSopenharmony_ci size_t buffer_size; 10403be168c0dSopenharmony_ci auto model_buf = lite::ReadFile(model_file.c_str(), &buffer_size); 10404be168c0dSopenharmony_ci if (model_buf == nullptr) { 10405be168c0dSopenharmony_ci MS_LOG(ERROR) << "Read model-file failed."; 10406be168c0dSopenharmony_ci return RET_NULL_PTR; 10407be168c0dSopenharmony_ci } 10408be168c0dSopenharmony_ci- auto ret = ExecuteMicroGeneration(model_buf, buffer_size, output_path, param, enable_fp16); 10409be168c0dSopenharmony_ci+ Model *model = lite::Model::Import(model_buf, buffer_size); 10410be168c0dSopenharmony_ci+ MS_CHECK_PTR(model); 10411be168c0dSopenharmony_ci+ if (!param->dynamic_symbols.empty()) { 10412be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ParseMicroDynamicShape(*model, param) == RET_OK, RET_ERROR, "ParseMicroDynamicShape failed."); 10413be168c0dSopenharmony_ci+ } 10414be168c0dSopenharmony_ci+ auto ret = ExecuteMicroGeneration(model_buf, buffer_size, output_path, *param, enable_fp16); 10415be168c0dSopenharmony_ci if (ret != RET_OK) { 10416be168c0dSopenharmony_ci MS_LOG(ERROR) << "Execute Micro failed."; 10417be168c0dSopenharmony_ci } 10418be168c0dSopenharmony_ci@@ -199,6 
+235,10 @@ int Coder::Init(const MicroParam ¶m) const { 10419be168c0dSopenharmony_ci DirectoryGenerator::GetInstance()->project_name()); 10420be168c0dSopenharmony_ci config->set_keep_original_weight(param.keep_original_weight); 10421be168c0dSopenharmony_ci config->set_changeable_weights_name(param.changeable_weights_name); 10422be168c0dSopenharmony_ci+ config->set_graph_inputs_shape_infos(param.graph_inputs_shape_infos); 10423be168c0dSopenharmony_ci+ config->set_dynamic_symbols(param.dynamic_symbols); 10424be168c0dSopenharmony_ci+ config->set_dynamic_symbols_num(param.dynamic_symbols_num); 10425be168c0dSopenharmony_ci+ config->set_user_graph_inputs_template(param.graph_inputs_template); 10426be168c0dSopenharmony_ci 10427be168c0dSopenharmony_ci auto print_parameter = [](auto name, auto value) { 10428be168c0dSopenharmony_ci MS_LOG(INFO) << std::setw(20) << std::left << name << "= " << value; 10429be168c0dSopenharmony_ci@@ -209,6 +249,7 @@ int Coder::Init(const MicroParam ¶m) const { 10430be168c0dSopenharmony_ci print_parameter("codePath", config->code_path()); 10431be168c0dSopenharmony_ci print_parameter("codeMode", config->code_mode()); 10432be168c0dSopenharmony_ci print_parameter("debugMode", config->debug_mode()); 10433be168c0dSopenharmony_ci+ print_parameter("keepOriginalWeight", config->keep_original_weight()); 10434be168c0dSopenharmony_ci return RET_OK; 10435be168c0dSopenharmony_ci } 10436be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 10437be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/coder.h b/mindspore/lite/tools/converter/micro/coder/coder.h 10438be168c0dSopenharmony_ciindex c360f4c1..fad479aa 100644 10439be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/coder.h 10440be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/coder.h 10441be168c0dSopenharmony_ci@@ -31,9 +31,9 @@ class Coder final { 10442be168c0dSopenharmony_ci 10443be168c0dSopenharmony_ci ~Coder() = default; 
10444be168c0dSopenharmony_ci static int MicroSourceCodeGeneration(const schema::MetaGraphT &graph, const std::string &output_path, 10445be168c0dSopenharmony_ci- const MicroParam ¶m, bool enable_fp16); 10446be168c0dSopenharmony_ci- static int MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, 10447be168c0dSopenharmony_ci- const MicroParam ¶m, bool enable_fp16); 10448be168c0dSopenharmony_ci+ MicroParam *param, bool enable_fp16); 10449be168c0dSopenharmony_ci+ static int MicroSourceCodeGeneration(const std::string &model_file, const std::string &output_path, MicroParam *param, 10450be168c0dSopenharmony_ci+ bool enable_fp16); 10451be168c0dSopenharmony_ci 10452be168c0dSopenharmony_ci private: 10453be168c0dSopenharmony_ci static int ExecuteMicroGeneration(const void *model_buf, size_t size, const std::string &output_path, 10454be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/config.h b/mindspore/lite/tools/converter/micro/coder/config.h 10455be168c0dSopenharmony_ciindex 9be56178..fb90a2fc 100644 10456be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/config.h 10457be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/config.h 10458be168c0dSopenharmony_ci@@ -34,6 +34,12 @@ struct MicroParam { 10459be168c0dSopenharmony_ci std::string project_name; 10460be168c0dSopenharmony_ci bool is_last_model{false}; 10461be168c0dSopenharmony_ci bool keep_original_weight{false}; 10462be168c0dSopenharmony_ci+ std::vector<std::vector<std::string>> graph_inputs_template; 10463be168c0dSopenharmony_ci+ std::map<std::string, std::vector<std::string>> graph_inputs_origin_info; 10464be168c0dSopenharmony_ci+ std::vector<std::string> dynamic_symbols; 10465be168c0dSopenharmony_ci+ std::vector<size_t> dynamic_symbols_num; 10466be168c0dSopenharmony_ci+ std::vector<std::vector<std::vector<int>>> graph_inputs_shape_infos; 10467be168c0dSopenharmony_ci+ std::map<std::string, 
std::vector<std::vector<int>>> inputs_shape_by_scenes; 10468be168c0dSopenharmony_ci }; 10469be168c0dSopenharmony_ci 10470be168c0dSopenharmony_ci class Configurator { 10471be168c0dSopenharmony_ci@@ -67,6 +73,29 @@ class Configurator { 10472be168c0dSopenharmony_ci void set_changeable_weights_name(const std::string &weights_name) { changeable_weights_name_ = weights_name; } 10473be168c0dSopenharmony_ci const std::string &changeable_weights_name() const { return changeable_weights_name_; } 10474be168c0dSopenharmony_ci 10475be168c0dSopenharmony_ci+ void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; } 10476be168c0dSopenharmony_ci+ bool dynamic_shape() const { return dynamic_shape_; } 10477be168c0dSopenharmony_ci+ 10478be168c0dSopenharmony_ci+ void set_dynamic_symbols(const std::vector<std::string> &dynamic_symbols) { dynamic_symbols_ = dynamic_symbols; } 10479be168c0dSopenharmony_ci+ const std::vector<std::string> &dynamic_symbols() const { return dynamic_symbols_; } 10480be168c0dSopenharmony_ci+ 10481be168c0dSopenharmony_ci+ void set_dynamic_symbols_num(const std::vector<size_t> &dynamic_symbols_num) { 10482be168c0dSopenharmony_ci+ dynamic_symbols_num_ = dynamic_symbols_num; 10483be168c0dSopenharmony_ci+ } 10484be168c0dSopenharmony_ci+ const std::vector<size_t> &dynamic_symbols_num() const { return dynamic_symbols_num_; } 10485be168c0dSopenharmony_ci+ 10486be168c0dSopenharmony_ci+ void set_user_graph_inputs_template(const std::vector<std::vector<std::string>> &graph_inputs_template) { 10487be168c0dSopenharmony_ci+ user_graph_inputs_template_ = graph_inputs_template; 10488be168c0dSopenharmony_ci+ } 10489be168c0dSopenharmony_ci+ const std::vector<std::vector<std::string>> &user_graph_inputs_template() const { 10490be168c0dSopenharmony_ci+ return user_graph_inputs_template_; 10491be168c0dSopenharmony_ci+ } 10492be168c0dSopenharmony_ci+ 10493be168c0dSopenharmony_ci+ void set_graph_inputs_shape_infos(const std::vector<std::vector<std::vector<int>>> 
&graph_inputs_shape_infos) { 10494be168c0dSopenharmony_ci+ graph_inputs_shape_infos_ = graph_inputs_shape_infos; 10495be168c0dSopenharmony_ci+ } 10496be168c0dSopenharmony_ci+ const std::vector<std::vector<std::vector<int>>> &graph_inputs_shape_infos() { return graph_inputs_shape_infos_; } 10497be168c0dSopenharmony_ci+ 10498be168c0dSopenharmony_ci private: 10499be168c0dSopenharmony_ci Configurator() = default; 10500be168c0dSopenharmony_ci ~Configurator() = default; 10501be168c0dSopenharmony_ci@@ -76,8 +105,13 @@ class Configurator { 10502be168c0dSopenharmony_ci bool support_parallel_{false}; 10503be168c0dSopenharmony_ci bool debug_mode_{false}; 10504be168c0dSopenharmony_ci bool keep_original_weight_{false}; 10505be168c0dSopenharmony_ci+ bool dynamic_shape_{false}; 10506be168c0dSopenharmony_ci std::string proj_dir_; 10507be168c0dSopenharmony_ci std::string changeable_weights_name_; 10508be168c0dSopenharmony_ci+ std::vector<std::string> dynamic_symbols_; 10509be168c0dSopenharmony_ci+ std::vector<size_t> dynamic_symbols_num_; 10510be168c0dSopenharmony_ci+ std::vector<std::vector<std::vector<int>>> graph_inputs_shape_infos_; 10511be168c0dSopenharmony_ci+ std::vector<std::vector<std::string>> user_graph_inputs_template_; 10512be168c0dSopenharmony_ci }; 10513be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 10514be168c0dSopenharmony_ci #endif // MICRO_CODER_CONFIG_H 10515be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/context.cc b/mindspore/lite/tools/converter/micro/coder/context.cc 10516be168c0dSopenharmony_ciindex 251b282f..7e7f640e 100644 10517be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/context.cc 10518be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/context.cc 10519be168c0dSopenharmony_ci@@ -50,4 +50,17 @@ std::vector<std::string> CoderContext::GetInitWeightSizeCode() const { 10520be168c0dSopenharmony_ci } 10521be168c0dSopenharmony_ci 10522be168c0dSopenharmony_ci void 
CoderContext::AppendInitWeightSizeCode(size_t w_buf_size) { weight_buffer_size_ += w_buf_size; } 10523be168c0dSopenharmony_ci+ 10524be168c0dSopenharmony_ci+const std::map<int, std::vector<int>> &CoderContext::shape_all_scenes() const { 10525be168c0dSopenharmony_ci+ return shape_info_container_->GetShapesWholeScenes(); 10526be168c0dSopenharmony_ci+} 10527be168c0dSopenharmony_ci+const std::map<const Tensor *, std::vector<std::string>> &CoderContext::shape_templates() { 10528be168c0dSopenharmony_ci+ return shape_info_container_->GetWholeTemplateShape(); 10529be168c0dSopenharmony_ci+} 10530be168c0dSopenharmony_ci+const std::map<int, std::vector<size_t>> &CoderContext::offset_all_scenes() { 10531be168c0dSopenharmony_ci+ return dynamic_mem_manager_->GetOffsetAllScenes(); 10532be168c0dSopenharmony_ci+} 10533be168c0dSopenharmony_ci+const std::vector<size_t> &CoderContext::buffer_sizes() const { return dynamic_mem_manager_->GetBufferSizes(); } 10534be168c0dSopenharmony_ci+const std::vector<size_t> &CoderContext::workspaces() const { return dynamic_mem_manager_->GetWorkSpaces(); } 10535be168c0dSopenharmony_ci+std::string CoderContext::tensor_addr(const Tensor *tensor) { return dynamic_mem_manager_->GetVarTensorAddr(tensor); } 10536be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 10537be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/context.h b/mindspore/lite/tools/converter/micro/coder/context.h 10538be168c0dSopenharmony_ciindex bad4ab40..b511eac1 100644 10539be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/context.h 10540be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/context.h 10541be168c0dSopenharmony_ci@@ -25,6 +25,8 @@ 10542be168c0dSopenharmony_ci #include <vector> 10543be168c0dSopenharmony_ci #include <algorithm> 10544be168c0dSopenharmony_ci #include "src/tensor.h" 10545be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/shape_info_container.h" 
10546be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 10547be168c0dSopenharmony_ci 10548be168c0dSopenharmony_ci namespace mindspore::lite::micro { 10549be168c0dSopenharmony_ci class CoderContext { 10550be168c0dSopenharmony_ci@@ -146,6 +148,17 @@ class CoderContext { 10551be168c0dSopenharmony_ci 10552be168c0dSopenharmony_ci bool end_flag() { return end_flag_; } 10553be168c0dSopenharmony_ci 10554be168c0dSopenharmony_ci+ void set_shape_info_container(ShapeInfoContainer *shape_info_container) { 10555be168c0dSopenharmony_ci+ shape_info_container_ = shape_info_container; 10556be168c0dSopenharmony_ci+ } 10557be168c0dSopenharmony_ci+ void set_dynamic_mem_manager(DynamicMemManager *dynamic_mem_manager) { dynamic_mem_manager_ = dynamic_mem_manager; } 10558be168c0dSopenharmony_ci+ const std::map<int, std::vector<int>> &shape_all_scenes() const; 10559be168c0dSopenharmony_ci+ const std::map<const Tensor *, std::vector<std::string>> &shape_templates(); 10560be168c0dSopenharmony_ci+ const std::map<int, std::vector<size_t>> &offset_all_scenes(); 10561be168c0dSopenharmony_ci+ const std::vector<size_t> &buffer_sizes() const; 10562be168c0dSopenharmony_ci+ const std::vector<size_t> &workspaces() const; 10563be168c0dSopenharmony_ci+ std::string tensor_addr(const Tensor *tensor); 10564be168c0dSopenharmony_ci+ 10565be168c0dSopenharmony_ci private: 10566be168c0dSopenharmony_ci std::string model_name_; 10567be168c0dSopenharmony_ci std::vector<Tensor *> graph_inputs_; 10568be168c0dSopenharmony_ci@@ -195,6 +208,8 @@ class CoderContext { 10569be168c0dSopenharmony_ci // operator C Lang files list, depended by the net.c. 
it will be add to CMakeLists.txt 10570be168c0dSopenharmony_ci static std::set<std::string> c_files_; 10571be168c0dSopenharmony_ci static size_t max_buffer_size_; 10572be168c0dSopenharmony_ci+ ShapeInfoContainer *shape_info_container_; 10573be168c0dSopenharmony_ci+ DynamicMemManager *dynamic_mem_manager_; 10574be168c0dSopenharmony_ci }; 10575be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 10576be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_MICRO_CODER_CONTEXT_H_ 10577be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.cc b/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.cc 10578be168c0dSopenharmony_cinew file mode 100644 10579be168c0dSopenharmony_ciindex 00000000..976bd852 10580be168c0dSopenharmony_ci--- /dev/null 10581be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.cc 10582be168c0dSopenharmony_ci@@ -0,0 +1,116 @@ 10583be168c0dSopenharmony_ci+/** 10584be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 10585be168c0dSopenharmony_ci+ * 10586be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 10587be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 10588be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 10589be168c0dSopenharmony_ci+ * 10590be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 10591be168c0dSopenharmony_ci+ * 10592be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 10593be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 10594be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10595be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 10596be168c0dSopenharmony_ci+ * limitations under the License. 
10597be168c0dSopenharmony_ci+ */ 10598be168c0dSopenharmony_ci+ 10599be168c0dSopenharmony_ci+#include "coder/dynamic_mem_manager.h" 10600be168c0dSopenharmony_ci+#include <vector> 10601be168c0dSopenharmony_ci+#include "coder/allocator/memory_manager.h" 10602be168c0dSopenharmony_ci+#include "coder/generator/component/component.h" 10603be168c0dSopenharmony_ci+ 10604be168c0dSopenharmony_ci+namespace mindspore::lite::micro { 10605be168c0dSopenharmony_ci+int DynamicMemManager::AllocDynamicMem(const std::vector<std::unique_ptr<OperatorCoder>> &nodes, 10606be168c0dSopenharmony_ci+ const std::vector<Tensor *> &graph_inputs, 10607be168c0dSopenharmony_ci+ const std::vector<Tensor *> &graph_outputs, 10608be168c0dSopenharmony_ci+ const ShapeInfoContainer *shape_info_container) { 10609be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(shape_info_container, RET_NULL_PTR, "ShapeInfoContainer is a nullptr."); 10610be168c0dSopenharmony_ci+ for (size_t i = 0; i < graph_inputs.size(); ++i) { 10611be168c0dSopenharmony_ci+ graph_inputs_.insert(std::make_pair(graph_inputs.at(i), kInputPrefixName + std::to_string(i))); 10612be168c0dSopenharmony_ci+ } 10613be168c0dSopenharmony_ci+ auto var_tensor_shapes = shape_info_container->GetVarTensorInfos(); 10614be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!var_tensor_shapes.empty(), RET_ERROR, "Cannot get var-tensor's shape-info"); 10615be168c0dSopenharmony_ci+ auto scene_num = var_tensor_shapes.begin()->second.size(); 10616be168c0dSopenharmony_ci+ for (const auto &item : var_tensor_shapes) { 10617be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(item.first, RET_NULL_PTR, "Find a nullptr in shape-infos"); 10618be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(item.second.size() == scene_num, RET_ERROR, "Shape-info is invalid."); 10619be168c0dSopenharmony_ci+ } 10620be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 10621be168c0dSopenharmony_ci+ for (const auto &item : var_tensor_shapes) { 10622be168c0dSopenharmony_ci+ item.first->ResetRefCount(); 
10623be168c0dSopenharmony_ci+ item.first->set_shape(item.second[i]); 10624be168c0dSopenharmony_ci+ } 10625be168c0dSopenharmony_ci+ auto ret = AllocDynamicMemCore(nodes, graph_outputs, i); 10626be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Alloc dynamic memory failed."); 10627be168c0dSopenharmony_ci+ } 10628be168c0dSopenharmony_ci+ return RET_OK; 10629be168c0dSopenharmony_ci+} 10630be168c0dSopenharmony_ci+ 10631be168c0dSopenharmony_ci+int DynamicMemManager::AllocDynamicMemCore(const std::vector<std::unique_ptr<OperatorCoder>> &nodes, 10632be168c0dSopenharmony_ci+ const std::vector<Tensor *> &graph_outputs, int scene_index) { 10633be168c0dSopenharmony_ci+ if (offsets_all_scenes_.find(scene_index) != offsets_all_scenes_.end()) { 10634be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Current scene has been processed."; 10635be168c0dSopenharmony_ci+ return RET_ERROR; 10636be168c0dSopenharmony_ci+ } 10637be168c0dSopenharmony_ci+ auto manager = std::make_unique<MemoryManager>(); 10638be168c0dSopenharmony_ci+ int ret = manager->AssignMemory(nodes, graph_outputs); 10639be168c0dSopenharmony_ci+ if (ret != RET_OK) { 10640be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "assign memory failed"; 10641be168c0dSopenharmony_ci+ return RET_ERROR; 10642be168c0dSopenharmony_ci+ } 10643be168c0dSopenharmony_ci+ std::map<Tensor *, size_t> offsets = manager->variables_offset(); 10644be168c0dSopenharmony_ci+ if (offset_index_.empty()) { 10645be168c0dSopenharmony_ci+ int index = 0; 10646be168c0dSopenharmony_ci+ for (auto &item : offsets) { 10647be168c0dSopenharmony_ci+ offset_index_[item.first] = index++; 10648be168c0dSopenharmony_ci+ offsets_all_scenes_[scene_index].push_back(item.second); 10649be168c0dSopenharmony_ci+ } 10650be168c0dSopenharmony_ci+ } else { 10651be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(offsets.size() == offset_index_.size(), RET_ERROR, "Tensors num is not same."); 10652be168c0dSopenharmony_ci+ for (auto &item : offsets) { 10653be168c0dSopenharmony_ci+ 
MS_CHECK_TRUE_MSG(offset_index_.find(item.first) != offset_index_.end(), RET_ERROR, "Tensor cannot be found."); 10654be168c0dSopenharmony_ci+ offsets_all_scenes_[scene_index].push_back(item.second); 10655be168c0dSopenharmony_ci+ } 10656be168c0dSopenharmony_ci+ } 10657be168c0dSopenharmony_ci+ buffer_sizes_.push_back(manager->GetAllocatedSize()); 10658be168c0dSopenharmony_ci+ offsets_all_scenes_[scene_index].push_back(manager->GetAllocatedSize()); 10659be168c0dSopenharmony_ci+ return RET_OK; 10660be168c0dSopenharmony_ci+} 10661be168c0dSopenharmony_ci+ 10662be168c0dSopenharmony_ci+std::string DynamicMemManager::GetVarTensorAddr(const Tensor *tensor) const { 10663be168c0dSopenharmony_ci+ if (graph_inputs_.find(tensor) != graph_inputs_.end()) { 10664be168c0dSopenharmony_ci+ return graph_inputs_.at(tensor); 10665be168c0dSopenharmony_ci+ } 10666be168c0dSopenharmony_ci+ if (offset_index_.find(tensor) == offset_index_.end()) { 10667be168c0dSopenharmony_ci+ return ""; 10668be168c0dSopenharmony_ci+ } 10669be168c0dSopenharmony_ci+ if (kBufferPrefixName == nullptr || kOffsetPrefixName == nullptr) { 10670be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Buffer or Offset is a nullptr."; 10671be168c0dSopenharmony_ci+ return ""; 10672be168c0dSopenharmony_ci+ } 10673be168c0dSopenharmony_ci+ return std::string(kBufferPrefixName) + " + " + kOffsetPrefixName + "[" + std::to_string(offset_index_.at(tensor)) + 10674be168c0dSopenharmony_ci+ "]"; 10675be168c0dSopenharmony_ci+} 10676be168c0dSopenharmony_ci+ 10677be168c0dSopenharmony_ci+std::string DynamicMemManager::AllocWorkSpace(size_t size, int index) { 10678be168c0dSopenharmony_ci+ if (index < 0 || static_cast<size_t>(index) >= buffer_sizes_.size()) { 10679be168c0dSopenharmony_ci+ return ""; 10680be168c0dSopenharmony_ci+ } 10681be168c0dSopenharmony_ci+ if (static_cast<size_t>(index) + 1 >= workspaces_.size()) { 10682be168c0dSopenharmony_ci+ workspaces_.insert(workspaces_.end(), index + 1 - workspaces_.size(), 0); 10683be168c0dSopenharmony_ci+ 
} 10684be168c0dSopenharmony_ci+ if (workspaces_[index] < size) { 10685be168c0dSopenharmony_ci+ workspaces_[index] = size; 10686be168c0dSopenharmony_ci+ } 10687be168c0dSopenharmony_ci+ if (kBufferPrefixName == nullptr) { 10688be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Buffer is a nullptr."; 10689be168c0dSopenharmony_ci+ return ""; 10690be168c0dSopenharmony_ci+ } 10691be168c0dSopenharmony_ci+ if (kOffsetPrefixName == nullptr) { 10692be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Offset is a nullptr."; 10693be168c0dSopenharmony_ci+ return ""; 10694be168c0dSopenharmony_ci+ } 10695be168c0dSopenharmony_ci+ return "(" + std::string(kBufferPrefixName) + " + " + kOffsetPrefixName + "[" + 10696be168c0dSopenharmony_ci+ std::to_string(offsets_all_scenes_.begin()->second.size() - 1) + "])"; 10697be168c0dSopenharmony_ci+} 10698be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro 10699be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.h b/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.h 10700be168c0dSopenharmony_cinew file mode 100644 10701be168c0dSopenharmony_ciindex 00000000..6db7cff5 10702be168c0dSopenharmony_ci--- /dev/null 10703be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/dynamic_mem_manager.h 10704be168c0dSopenharmony_ci@@ -0,0 +1,53 @@ 10705be168c0dSopenharmony_ci+/** 10706be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 10707be168c0dSopenharmony_ci+ * 10708be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 10709be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
10710be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 10711be168c0dSopenharmony_ci+ * 10712be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 10713be168c0dSopenharmony_ci+ * 10714be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 10715be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 10716be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10717be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 10718be168c0dSopenharmony_ci+ * limitations under the License. 10719be168c0dSopenharmony_ci+ */ 10720be168c0dSopenharmony_ci+ 10721be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_DYNAMIC_MEM_MANAGER_H_ 10722be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_DYNAMIC_MEM_MANAGER_H_ 10723be168c0dSopenharmony_ci+ 10724be168c0dSopenharmony_ci+#include <map> 10725be168c0dSopenharmony_ci+#include <vector> 10726be168c0dSopenharmony_ci+#include "src/tensor.h" 10727be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/shape_info_container.h" 10728be168c0dSopenharmony_ci+ 10729be168c0dSopenharmony_ci+namespace mindspore::lite::micro { 10730be168c0dSopenharmony_ci+class OperatorCoder; 10731be168c0dSopenharmony_ci+class DynamicMemManager { 10732be168c0dSopenharmony_ci+ public: 10733be168c0dSopenharmony_ci+ DynamicMemManager() = default; 10734be168c0dSopenharmony_ci+ virtual ~DynamicMemManager() = default; 10735be168c0dSopenharmony_ci+ int AllocDynamicMem(const std::vector<std::unique_ptr<OperatorCoder>> &nodes, 10736be168c0dSopenharmony_ci+ const std::vector<Tensor *> &graph_inputs, const std::vector<Tensor *> &graph_outputs, 10737be168c0dSopenharmony_ci+ const ShapeInfoContainer *shape_info_container); 10738be168c0dSopenharmony_ci+ 10739be168c0dSopenharmony_ci+ std::string GetVarTensorAddr(const Tensor *tensor) 
const; 10740be168c0dSopenharmony_ci+ std::string AllocWorkSpace(size_t size, int index); 10741be168c0dSopenharmony_ci+ 10742be168c0dSopenharmony_ci+ const std::vector<size_t> &GetBufferSizes() const { return buffer_sizes_; } 10743be168c0dSopenharmony_ci+ const std::vector<size_t> &GetWorkSpaces() const { return workspaces_; } 10744be168c0dSopenharmony_ci+ const std::map<int, std::vector<size_t>> &GetOffsetAllScenes() { return offsets_all_scenes_; } 10745be168c0dSopenharmony_ci+ 10746be168c0dSopenharmony_ci+ private: 10747be168c0dSopenharmony_ci+ int AllocDynamicMemCore(const std::vector<std::unique_ptr<OperatorCoder>> &nodes, 10748be168c0dSopenharmony_ci+ const std::vector<Tensor *> &graph_outputs, int scene_index); 10749be168c0dSopenharmony_ci+ std::map<int, std::vector<size_t>> offsets_all_scenes_; 10750be168c0dSopenharmony_ci+ std::map<const Tensor *, int> offset_index_; 10751be168c0dSopenharmony_ci+ std::map<const Tensor *, std::string> graph_inputs_; 10752be168c0dSopenharmony_ci+ std::vector<size_t> buffer_sizes_; 10753be168c0dSopenharmony_ci+ std::vector<size_t> workspaces_; 10754be168c0dSopenharmony_ci+ int model_id_; 10755be168c0dSopenharmony_ci+}; 10756be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro 10757be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_DYNAMIC_MEM_MANAGER_H_ 10758be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/cmake_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/cmake_component.cc 10759be168c0dSopenharmony_ciindex 643cf50b..831d4259 100644 10760be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/cmake_component.cc 10761be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/cmake_component.cc 10762be168c0dSopenharmony_ci@@ -5,7 +5,7 @@ 10763be168c0dSopenharmony_ci * you may not use this file except in compliance with the License. 
10764be168c0dSopenharmony_ci * You may obtain a copy of the License at 10765be168c0dSopenharmony_ci * 10766be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 10767be168c0dSopenharmony_ci+ * http://www.apache.objrg/licenses/LICENSE-2.0 10768be168c0dSopenharmony_ci * 10769be168c0dSopenharmony_ci * Unless required by applicable law or agreed to in writing, software 10770be168c0dSopenharmony_ci * distributed under the License is distributed on an "AS IS" BASIS, 10771be168c0dSopenharmony_ci@@ -29,32 +29,32 @@ void CodeCMakeNetLibrary(std::ofstream &ofs, const std::unique_ptr<CoderContext> 10772be168c0dSopenharmony_ci } 10773be168c0dSopenharmony_ci ofs << "set(OP_SRC\n"; 10774be168c0dSopenharmony_ci for (const std::string &c_file : ctx->c_files()) { 10775be168c0dSopenharmony_ci- ofs << " " << c_file << ".o\n"; 10776be168c0dSopenharmony_ci+ ofs << " " << c_file << ".obj\n"; 10777be168c0dSopenharmony_ci } 10778be168c0dSopenharmony_ci for (int i = 0; i <= ctx->GetCurModelIndex(); ++i) { 10779be168c0dSopenharmony_ci- ofs << " weight" << i << ".c.o\n" 10780be168c0dSopenharmony_ci- << " net" << i << ".c.o\n" 10781be168c0dSopenharmony_ci- << " model" << i << ".c.o\n"; 10782be168c0dSopenharmony_ci+ ofs << " weight" << i << ".c.obj\n" 10783be168c0dSopenharmony_ci+ << " net" << i << ".c.obj\n" 10784be168c0dSopenharmony_ci+ << " model" << i << ".c.obj\n"; 10785be168c0dSopenharmony_ci } 10786be168c0dSopenharmony_ci- ofs << " model.c.o\n" 10787be168c0dSopenharmony_ci- << " context.c.o\n" 10788be168c0dSopenharmony_ci- << " tensor.c.o\n"; 10789be168c0dSopenharmony_ci- if (config->target() != kCortex_M) { 10790be168c0dSopenharmony_ci- ofs << " allocator.c.o\n"; 10791be168c0dSopenharmony_ci+ ofs << " model.c.obj\n" 10792be168c0dSopenharmony_ci+ << " context.c.obj\n" 10793be168c0dSopenharmony_ci+ << " tensor.c.obj\n"; 10794be168c0dSopenharmony_ci+ if (config->target() != kCortex_M && !config->dynamic_shape()) { 10795be168c0dSopenharmony_ci+ ofs << " allocator.c.obj\n"; 
10796be168c0dSopenharmony_ci } 10797be168c0dSopenharmony_ci if (config->debug_mode()) { 10798be168c0dSopenharmony_ci- ofs << " debug_utils.c.o\n"; 10799be168c0dSopenharmony_ci+ ofs << " debug_utils.c.obj\n"; 10800be168c0dSopenharmony_ci } 10801be168c0dSopenharmony_ci if (config->support_parallel()) { 10802be168c0dSopenharmony_ci- ofs << " micro_core_affinity.c.o\n" 10803be168c0dSopenharmony_ci- " micro_thread_pool.c.o\n"; 10804be168c0dSopenharmony_ci+ ofs << " micro_core_affinity.c.obj\n" 10805be168c0dSopenharmony_ci+ " micro_thread_pool.c.obj\n"; 10806be168c0dSopenharmony_ci } 10807be168c0dSopenharmony_ci ofs << ")\n"; 10808be168c0dSopenharmony_ci std::set<std::string> kernel_cmake_asm_set_files = ctx->asm_files(); 10809be168c0dSopenharmony_ci if (!kernel_cmake_asm_set_files.empty() && (config->target() == kARM32 || config->target() == kARM64)) { 10810be168c0dSopenharmony_ci ofs << "set(ASSEMBLY_SRC\n"; 10811be168c0dSopenharmony_ci for (const std::string &asm_file : kernel_cmake_asm_set_files) { 10812be168c0dSopenharmony_ci- ofs << " " << asm_file << ".o\n"; 10813be168c0dSopenharmony_ci+ ofs << " " << asm_file << ".obj\n"; 10814be168c0dSopenharmony_ci } 10815be168c0dSopenharmony_ci ofs << ")\n" 10816be168c0dSopenharmony_ci << "set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)\n" 10817be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc 10818be168c0dSopenharmony_ciindex 774e8353..62c2f668 100644 10819be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc 10820be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.cc 10821be168c0dSopenharmony_ci@@ -16,6 +16,7 @@ 10822be168c0dSopenharmony_ci 10823be168c0dSopenharmony_ci #include "coder/generator/component/common_component.h" 10824be168c0dSopenharmony_ci 
#include <memory> 10825be168c0dSopenharmony_ci+#include "coder/generator/component/const_blocks/license.h" 10826be168c0dSopenharmony_ci #include "coder/generator/component/component.h" 10827be168c0dSopenharmony_ci #include "coder/utils/type_cast.h" 10828be168c0dSopenharmony_ci #include "coder/utils/coder_utils.h" 10829be168c0dSopenharmony_ci@@ -23,36 +24,59 @@ 10830be168c0dSopenharmony_ci #include "include/errorcode.h" 10831be168c0dSopenharmony_ci #include "nnacl/op_base.h" 10832be168c0dSopenharmony_ci #include "include/c_api/model_c.h" 10833be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 10834be168c0dSopenharmony_ci 10835be168c0dSopenharmony_ci namespace mindspore::lite::micro { 10836be168c0dSopenharmony_ci-const char handle_array_destroy_state[] = R"RAW( 10837be168c0dSopenharmony_ci-void MSTensorHandleArrayDestroy(MSTensorHandleArray inputs); 10838be168c0dSopenharmony_ci+const char model_runtime_init_source[] = R"RAW( 10839be168c0dSopenharmony_ci+typedef struct { 10840be168c0dSopenharmony_ci+ void *runtime_buffer; 10841be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray inputs; 10842be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray outputs; 10843be168c0dSopenharmony_ci+} MicroModel; 10844be168c0dSopenharmony_ci+OH_AI_ModelHandle OH_AI_ModelCreate() { 10845be168c0dSopenharmony_ci+ MicroModel *micro_model = (MicroModel *)malloc(sizeof(MicroModel)); 10846be168c0dSopenharmony_ci+ if (micro_model == NULL) { 10847be168c0dSopenharmony_ci+ return NULL; 10848be168c0dSopenharmony_ci+ } 10849be168c0dSopenharmony_ci+)RAW"; 10850be168c0dSopenharmony_ci+const char model_runtime_malloc_source[] = R"RAW( 10851be168c0dSopenharmony_ci+ int buffer_size = GetBufferSize(); 10852be168c0dSopenharmony_ci+ void *runtime_buffer = malloc(buffer_size); 10853be168c0dSopenharmony_ci+ if (runtime_buffer == NULL) { 10854be168c0dSopenharmony_ci+ return NULL; 10855be168c0dSopenharmony_ci+ } 10856be168c0dSopenharmony_ci+ micro_model->runtime_buffer = runtime_buffer; 
10857be168c0dSopenharmony_ci+ int ret = SetBuffer(runtime_buffer); 10858be168c0dSopenharmony_ci+ if (ret != OH_AI_STATUS_SUCCESS) { 10859be168c0dSopenharmony_ci+ return NULL; 10860be168c0dSopenharmony_ci+ } 10861be168c0dSopenharmony_ci+ 10862be168c0dSopenharmony_ci )RAW"; 10863be168c0dSopenharmony_ci 10864be168c0dSopenharmony_ci const char handle_array_destroy[] = R"RAW( 10865be168c0dSopenharmony_ci-void MSTensorHandleArrayDestroy(MSTensorHandleArray inputs) { 10866be168c0dSopenharmony_ci- if (inputs.handle_list == NULL) { 10867be168c0dSopenharmony_ci- return; 10868be168c0dSopenharmony_ci- } 10869be168c0dSopenharmony_ci- for (size_t i = 0; i < inputs.handle_num; i++) { 10870be168c0dSopenharmony_ci- MicroTensor *micro_tensor = inputs.handle_list[i]; 10871be168c0dSopenharmony_ci- if (micro_tensor == NULL) { 10872be168c0dSopenharmony_ci- continue; 10873be168c0dSopenharmony_ci- } 10874be168c0dSopenharmony_ci- if (micro_tensor->data != NULL && micro_tensor->owned) { 10875be168c0dSopenharmony_ci- free(micro_tensor->data); 10876be168c0dSopenharmony_ci- micro_tensor->data = NULL; 10877be168c0dSopenharmony_ci- micro_tensor->owned = false; 10878be168c0dSopenharmony_ci- } 10879be168c0dSopenharmony_ci- if (micro_tensor->shape != NULL) { 10880be168c0dSopenharmony_ci- free(micro_tensor->shape); 10881be168c0dSopenharmony_ci- micro_tensor->shape = NULL; 10882be168c0dSopenharmony_ci- } 10883be168c0dSopenharmony_ci- free(micro_tensor); 10884be168c0dSopenharmony_ci- micro_tensor = NULL; 10885be168c0dSopenharmony_ci- } 10886be168c0dSopenharmony_ci- free(inputs.handle_list); 10887be168c0dSopenharmony_ci- inputs.handle_list = NULL; 10888be168c0dSopenharmony_ci+void OH_AI_TensorHandleArrayDestroy(OH_AI_TensorHandleArray inputs) { 10889be168c0dSopenharmony_ci+ if (inputs.handle_list == NULL) { 10890be168c0dSopenharmony_ci+ return; 10891be168c0dSopenharmony_ci+ } 10892be168c0dSopenharmony_ci+ for (size_t i = 0; i < inputs.handle_num; i++) { 10893be168c0dSopenharmony_ci+ MicroTensor 
*micro_tensor = inputs.handle_list[i]; 10894be168c0dSopenharmony_ci+ if (micro_tensor == NULL) { 10895be168c0dSopenharmony_ci+ continue; 10896be168c0dSopenharmony_ci+ } 10897be168c0dSopenharmony_ci+ if (micro_tensor->data != NULL && micro_tensor->owned) { 10898be168c0dSopenharmony_ci+ free(micro_tensor->data); 10899be168c0dSopenharmony_ci+ micro_tensor->data = NULL; 10900be168c0dSopenharmony_ci+ micro_tensor->owned = false; 10901be168c0dSopenharmony_ci+ } 10902be168c0dSopenharmony_ci+ if (micro_tensor->shape) { 10903be168c0dSopenharmony_ci+ free(micro_tensor->shape); 10904be168c0dSopenharmony_ci+ micro_tensor->shape = NULL; 10905be168c0dSopenharmony_ci+ } 10906be168c0dSopenharmony_ci+ free(micro_tensor); 10907be168c0dSopenharmony_ci+ micro_tensor = NULL; 10908be168c0dSopenharmony_ci+ } 10909be168c0dSopenharmony_ci+ free(inputs.handle_list); 10910be168c0dSopenharmony_ci+ inputs.handle_list = NULL; 10911be168c0dSopenharmony_ci } 10912be168c0dSopenharmony_ci 10913be168c0dSopenharmony_ci )RAW"; 10914be168c0dSopenharmony_ci@@ -62,7 +86,7 @@ const char cortex_set_workspace[] = R"RAW( 10915be168c0dSopenharmony_ci if (micro_model == NULL) { 10916be168c0dSopenharmony_ci return; 10917be168c0dSopenharmony_ci } 10918be168c0dSopenharmony_ci- if (workspace_size < MSModelCalcWorkspaceSize(model)) { 10919be168c0dSopenharmony_ci+ if (workspace_size < OH_AI_ModelCalcWorkspaceSize(model)) { 10920be168c0dSopenharmony_ci return; 10921be168c0dSopenharmony_ci } 10922be168c0dSopenharmony_ci if (micro_model->inputs.handle_num != GRAPH_INPUTS_SIZE) { 10923be168c0dSopenharmony_ci@@ -75,29 +99,29 @@ const char cortex_set_workspace[] = R"RAW( 10924be168c0dSopenharmony_ci )RAW"; 10925be168c0dSopenharmony_ci 10926be168c0dSopenharmony_ci const char micro_model_build_state[] = R"RAW( 10927be168c0dSopenharmony_ci-typedef MSStatus (*ModelBuild)(MSModelHandle model, const void *model_data, 10928be168c0dSopenharmony_ci+typedef OH_AI_Status (*ModelBuild)(OH_AI_ModelHandle model, const void *model_data, 
10929be168c0dSopenharmony_ci size_t data_size, 10930be168c0dSopenharmony_ci- const MSContextHandle model_context); 10931be168c0dSopenharmony_ci+ const OH_AI_ContextHandle model_context); 10932be168c0dSopenharmony_ci )RAW"; 10933be168c0dSopenharmony_ci 10934be168c0dSopenharmony_ci const char micro_model_build_implement[] = R"RAW( 10935be168c0dSopenharmony_ci-MSStatus MSModelBuild(MSModelHandle model, const void *model_data, 10936be168c0dSopenharmony_ci- size_t data_size, MSModelType model_type, 10937be168c0dSopenharmony_ci- const MSContextHandle model_context) { 10938be168c0dSopenharmony_ci- if (model_type != kMSModelTypeMindIR) { 10939be168c0dSopenharmony_ci- return kMSStatusLiteNotSupport; 10940be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, 10941be168c0dSopenharmony_ci+ size_t data_size, OH_AI_ModelType model_type, 10942be168c0dSopenharmony_ci+ const OH_AI_ContextHandle model_context) { 10943be168c0dSopenharmony_ci+ if (model_type != OH_AI_MODELTYPE_MINDIR) { 10944be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NOT_SUPPORT; 10945be168c0dSopenharmony_ci } 10946be168c0dSopenharmony_ci if (model == NULL) { 10947be168c0dSopenharmony_ci- return kMSStatusLiteParamInvalid; 10948be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 10949be168c0dSopenharmony_ci } 10950be168c0dSopenharmony_ci )RAW"; 10951be168c0dSopenharmony_ci 10952be168c0dSopenharmony_ci const char micro_model_predict_state[] = R"RAW( 10953be168c0dSopenharmony_ci-typedef MSStatus (*ModelPredict)(MSModelHandle model, 10954be168c0dSopenharmony_ci- const MSTensorHandleArray inputs, 10955be168c0dSopenharmony_ci- MSTensorHandleArray *outputs, 10956be168c0dSopenharmony_ci- const MSKernelCallBackC before, 10957be168c0dSopenharmony_ci- const MSKernelCallBackC after); 10958be168c0dSopenharmony_ci+typedef OH_AI_Status (*ModelPredict)(OH_AI_ModelHandle model, 10959be168c0dSopenharmony_ci+ const OH_AI_TensorHandleArray inputs, 
10960be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray *outputs, 10961be168c0dSopenharmony_ci+ const OH_AI_KernelCallBack before, 10962be168c0dSopenharmony_ci+ const OH_AI_KernelCallBack after); 10963be168c0dSopenharmony_ci )RAW"; 10964be168c0dSopenharmony_ci 10965be168c0dSopenharmony_ci const char free_resource_state[] = R"RAW( 10966be168c0dSopenharmony_ci@@ -107,7 +131,7 @@ typedef void (*FreeResource)(); 10967be168c0dSopenharmony_ci void CodeMSModelCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, 10968be168c0dSopenharmony_ci const Configurator &config) { 10969be168c0dSopenharmony_ci if (config.target() == kCortex_M) { 10970be168c0dSopenharmony_ci- ofs << "size_t MSModelCalcWorkspaceSize(MSModelHandle model) {\n" 10971be168c0dSopenharmony_ci+ ofs << "size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {\n" 10972be168c0dSopenharmony_ci << " MicroModel *micro_model = (MicroModel *)model;\n" 10973be168c0dSopenharmony_ci << " if (micro_model == NULL) {\n" 10974be168c0dSopenharmony_ci << " return 0;\n" 10975be168c0dSopenharmony_ci@@ -118,13 +142,13 @@ void CodeMSModelCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<Code 10976be168c0dSopenharmony_ci << " return micro_model->calc_work_space(model);\n" 10977be168c0dSopenharmony_ci << "}\n"; 10978be168c0dSopenharmony_ci } else { 10979be168c0dSopenharmony_ci- ofs << "size_t MSModelCalcWorkspaceSize(MSModelHandle model) {\n return 0;\n}\n"; 10980be168c0dSopenharmony_ci+ ofs << "size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {\n return 0;\n}\n"; 10981be168c0dSopenharmony_ci } 10982be168c0dSopenharmony_ci ofs << "\n"; 10983be168c0dSopenharmony_ci } 10984be168c0dSopenharmony_ci 10985be168c0dSopenharmony_ci void CodeCortexCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) { 10986be168c0dSopenharmony_ci- ofs << "size_t MSModelCalcWorkspaceSize" << ctx->GetCurModelIndex() << "(MSModelHandle model) {\n" 10987be168c0dSopenharmony_ci+ ofs << 
"size_t OH_AI_ModelCalcWorkspaceSize" << ctx->GetCurModelIndex() << "(OH_AI_ModelHandle model) {\n" 10988be168c0dSopenharmony_ci << "size_t shape_size = 0;\n"; 10989be168c0dSopenharmony_ci std::vector<Tensor *> inputs = ctx->graph_inputs(); 10990be168c0dSopenharmony_ci for (size_t i = 0; i < inputs.size(); ++i) { 10991be168c0dSopenharmony_ci@@ -141,7 +165,7 @@ void CodeCortexCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<Coder 10992be168c0dSopenharmony_ci } 10993be168c0dSopenharmony_ci 10994be168c0dSopenharmony_ci void CodeMSModelSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 10995be168c0dSopenharmony_ci- ofs << "void MSModelSetWorkspace(MSModelHandle model, void *workspace, size_t workspace_size) {"; 10996be168c0dSopenharmony_ci+ ofs << "void OH_AI_ModelSetWorkspace(OH_AI_ModelHandle model, void *workspace, size_t workspace_size) {"; 10997be168c0dSopenharmony_ci if (config.target() == kCortex_M) { 10998be168c0dSopenharmony_ci ofs << " MicroModel *micro_model = (MicroModel *)model;\n" 10999be168c0dSopenharmony_ci << " if (micro_model == NULL) {\n" 11000be168c0dSopenharmony_ci@@ -156,8 +180,8 @@ void CodeMSModelSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderCont 11001be168c0dSopenharmony_ci } 11002be168c0dSopenharmony_ci 11003be168c0dSopenharmony_ci void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) { 11004be168c0dSopenharmony_ci- ofs << "void MSModelSetWorkspace" << ctx->GetCurModelIndex() 11005be168c0dSopenharmony_ci- << "(MSModelHandle model, void *workspace, size_t workspace_size) {\n"; 11006be168c0dSopenharmony_ci+ ofs << "void OH_AI_ModelSetWorkspace" << ctx->GetCurModelIndex() 11007be168c0dSopenharmony_ci+ << "(OH_AI_ModelHandle model, void *workspace, size_t workspace_size) {\n"; 11008be168c0dSopenharmony_ci ofs << cortex_set_workspace; 11009be168c0dSopenharmony_ci ofs << " micro_model->runtime_buffer = workspace;\n" 
11010be168c0dSopenharmony_ci " int buffer_size = GetBufferSize" 11011be168c0dSopenharmony_ci@@ -173,12 +197,12 @@ void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderConte 11012be168c0dSopenharmony_ci buffer_size += WEIGHT_BUF_SIZE; 11013be168c0dSopenharmony_ci buffer_size = UP_ROUND(buffer_size,4); 11014be168c0dSopenharmony_ci 11015be168c0dSopenharmony_ci- micro_model->inputs.handle_list = (MSTensorHandle *)&buf[buffer_size]; 11016be168c0dSopenharmony_ci+ micro_model->inputs.handle_list = (OH_AI_TensorHandle *)&buf[buffer_size]; 11017be168c0dSopenharmony_ci buffer_size += GRAPH_INPUTS_SIZE * sizeof(MicroTensor *); 11018be168c0dSopenharmony_ci buffer_size = UP_ROUND(buffer_size,4); 11019be168c0dSopenharmony_ci MicroTensor **input_tensors = (MicroTensor **)micro_model->inputs.handle_list; 11020be168c0dSopenharmony_ci 11021be168c0dSopenharmony_ci- micro_model->outputs.handle_list = (MSTensorHandle *)&buf[buffer_size]; 11022be168c0dSopenharmony_ci+ micro_model->outputs.handle_list = (OH_AI_TensorHandle *)&buf[buffer_size]; 11023be168c0dSopenharmony_ci buffer_size += GRAPH_OUTPUTS_SIZE * sizeof(MicroTensor *); 11024be168c0dSopenharmony_ci buffer_size = UP_ROUND(buffer_size,4); 11025be168c0dSopenharmony_ci MicroTensor **output_tensors = (MicroTensor **)micro_model->outputs.handle_list; 11026be168c0dSopenharmony_ci@@ -215,7 +239,7 @@ void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderConte 11027be168c0dSopenharmony_ci auto array_tostring = [&ofs](Tensor *tensor, const std::string &prefix, size_t index) { 11028be168c0dSopenharmony_ci ofs << kAlignedString << prefix << "_tensors[" << index << "]->type = " << EnumNameMSDataType(tensor->data_type()) 11029be168c0dSopenharmony_ci << ";\n"; 11030be168c0dSopenharmony_ci- ofs << kAlignedString << prefix << "_tensors[" << index << "]->format = kMSFormatNHWC;\n"; 11031be168c0dSopenharmony_ci+ ofs << kAlignedString << prefix << "_tensors[" << index << "]->format = OH_AI_FORMAT_NHWC;\n"; 
11032be168c0dSopenharmony_ci ofs << kAlignedString << prefix << "_tensors[" << index << "]->ndim = " << tensor->shape().size() << ";\n"; 11033be168c0dSopenharmony_ci size_t shape_size = tensor->shape().size(); 11034be168c0dSopenharmony_ci for (size_t i = 0; i < shape_size; i++) { 11035be168c0dSopenharmony_ci@@ -234,32 +258,31 @@ void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderConte 11036be168c0dSopenharmony_ci ofs << "}\n"; 11037be168c0dSopenharmony_ci } 11038be168c0dSopenharmony_ci 11039be168c0dSopenharmony_ci-void CodeMSTensorHandleArrayDestroyState(std::ofstream &ofs, const Configurator &config) { 11040be168c0dSopenharmony_ci- if (config.target() != kCortex_M) { 11041be168c0dSopenharmony_ci- ofs << handle_array_destroy_state; 11042be168c0dSopenharmony_ci- } 11043be168c0dSopenharmony_ci+void CodeMSModelCreateDefault(std::ofstream &ofs) { 11044be168c0dSopenharmony_ci+ ofs << "OH_AI_ModelHandle OH_AI_ModelCreate() { return model0; }\n"; 11045be168c0dSopenharmony_ci } 11046be168c0dSopenharmony_ci 11047be168c0dSopenharmony_ci-void CodeMSModelCreateDefault(std::ofstream &ofs) { ofs << "MSModelHandle MSModelCreate() { return model0; }\n"; } 11048be168c0dSopenharmony_ci- 11049be168c0dSopenharmony_ci void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11050be168c0dSopenharmony_ci if (config.target() != kCortex_M) { 11051be168c0dSopenharmony_ci- ofs << "MSStatus MSModelCreate" << ctx->GetCurModelIndex() << "(MicroModel *micro_model) {"; 11052be168c0dSopenharmony_ci+ ofs << "OH_AI_Status OH_AI_ModelCreate" << ctx->GetCurModelIndex() << "(MicroModel *micro_model) {"; 11053be168c0dSopenharmony_ci ofs << R"RAW( 11054be168c0dSopenharmony_ci if (micro_model == NULL) { 11055be168c0dSopenharmony_ci- return kMSStatusLiteNullptr; 11056be168c0dSopenharmony_ci- } 11057be168c0dSopenharmony_ci- 11058be168c0dSopenharmony_ci- void *runtime_buffer = GlobalMemory(); 11059be168c0dSopenharmony_ci- if 
(runtime_buffer == NULL) { 11060be168c0dSopenharmony_ci- return kMSStatusLiteNullptr; 11061be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 11062be168c0dSopenharmony_ci } 11063be168c0dSopenharmony_ci- micro_model->runtime_buffer = runtime_buffer; 11064be168c0dSopenharmony_ci )RAW"; 11065be168c0dSopenharmony_ci- ofs << " int ret = SetBuffer" << ctx->GetCurModelIndex() << "(((MemBlock *)runtime_buffer)->addr);\n" 11066be168c0dSopenharmony_ci- << " if (ret != kMSStatusSuccess) {\n" 11067be168c0dSopenharmony_ci- << " return kMSStatusLiteMemoryFailed;\n" 11068be168c0dSopenharmony_ci- << " }\n\n"; 11069be168c0dSopenharmony_ci+ if (!config.dynamic_shape()) { 11070be168c0dSopenharmony_ci+ ofs << "void *runtime_buffer = GlobalMemory();\n" 11071be168c0dSopenharmony_ci+ << "if (runtime_buffer == NULL) {\n" 11072be168c0dSopenharmony_ci+ << " return OH_AI_STATUS_LITE_NULLPTR;\n" 11073be168c0dSopenharmony_ci+ << " }\n" 11074be168c0dSopenharmony_ci+ << " micro_model->runtime_buffer = runtime_buffer;\n"; 11075be168c0dSopenharmony_ci+ ofs << " int ret = SetBuffer" << ctx->GetCurModelIndex() << "(((MemBlock *)runtime_buffer)->addr);\n" 11076be168c0dSopenharmony_ci+ << " if (ret != OH_AI_STATUS_SUCCESS) {\n" 11077be168c0dSopenharmony_ci+ << " return OH_AI_STATUS_LITE_MEMORY_FAILED;\n" 11078be168c0dSopenharmony_ci+ << " }\n\n"; 11079be168c0dSopenharmony_ci+ } else { 11080be168c0dSopenharmony_ci+ ofs << " micro_model->runtime_buffer = NULL;\n"; 11081be168c0dSopenharmony_ci+ } 11082be168c0dSopenharmony_ci if (config.code_mode() == CodeMode::Inference) { 11083be168c0dSopenharmony_ci ofs << " micro_model->train_mode = false;\n"; 11084be168c0dSopenharmony_ci } else if (config.code_mode() == CodeMode::Train) { 11085be168c0dSopenharmony_ci@@ -269,7 +292,7 @@ void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> & 11086be168c0dSopenharmony_ci ofs << kAlignedString << prefix << "_tensors[" << index << "] = malloc(sizeof(MicroTensor));\n"; 
11087be168c0dSopenharmony_ci ofs << kAlignedString << prefix << "_tensors[" << index << "]->type = " << EnumNameMSDataType(tensor->data_type()) 11088be168c0dSopenharmony_ci << ";\n"; 11089be168c0dSopenharmony_ci- ofs << kAlignedString << prefix << "_tensors[" << index << "]->format = kMSFormatNHWC;\n"; 11090be168c0dSopenharmony_ci+ ofs << kAlignedString << prefix << "_tensors[" << index << "]->format = OH_AI_FORMAT_NHWC;\n"; 11091be168c0dSopenharmony_ci ofs << kAlignedString << prefix << "_tensors[" << index << "]->ndim = " << tensor->shape().size() << ";\n"; 11092be168c0dSopenharmony_ci size_t shape_size = tensor->shape().size(); 11093be168c0dSopenharmony_ci ofs << kAlignedString << prefix << "_tensors[" << index << "]->shape = " 11094be168c0dSopenharmony_ci@@ -289,30 +312,30 @@ void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> & 11095be168c0dSopenharmony_ci outputs = ctx->graph_train_outputs(); 11096be168c0dSopenharmony_ci } 11097be168c0dSopenharmony_ci size_t inputs_size = inputs.size(); 11098be168c0dSopenharmony_ci- ofs << " MSTensorHandleArray model_inputs;\n"; 11099be168c0dSopenharmony_ci+ ofs << " OH_AI_TensorHandleArray model_inputs;\n"; 11100be168c0dSopenharmony_ci ofs << " model_inputs.handle_num = " << inputs_size << ";\n"; 11101be168c0dSopenharmony_ci ofs << " MicroTensor **input_tensors = malloc(" << inputs_size << " * sizeof(MicroTensor *));\n"; 11102be168c0dSopenharmony_ci- ofs << " model_inputs.handle_list = (MSTensorHandle *)(input_tensors);\n"; 11103be168c0dSopenharmony_ci+ ofs << " model_inputs.handle_list = (OH_AI_TensorHandle *)(input_tensors);\n"; 11104be168c0dSopenharmony_ci ofs << " micro_model->inputs = model_inputs;\n"; 11105be168c0dSopenharmony_ci for (size_t i = 0; i < inputs_size; ++i) { 11106be168c0dSopenharmony_ci Tensor *input = inputs[i]; 11107be168c0dSopenharmony_ci array_tostring(input, "input", i); 11108be168c0dSopenharmony_ci } 11109be168c0dSopenharmony_ci size_t outputs_size = outputs.size(); 
11110be168c0dSopenharmony_ci- ofs << " MSTensorHandleArray model_outputs;\n"; 11111be168c0dSopenharmony_ci+ ofs << " OH_AI_TensorHandleArray model_outputs;\n"; 11112be168c0dSopenharmony_ci ofs << " model_outputs.handle_num = " << outputs_size << ";\n"; 11113be168c0dSopenharmony_ci ofs << " MicroTensor **output_tensors = malloc(" << outputs_size << " * sizeof(MicroTensor *));\n"; 11114be168c0dSopenharmony_ci- ofs << " model_outputs.handle_list = (MSTensorHandle *)(output_tensors);\n"; 11115be168c0dSopenharmony_ci+ ofs << " model_outputs.handle_list = (OH_AI_TensorHandle *)(output_tensors);\n"; 11116be168c0dSopenharmony_ci ofs << " micro_model->outputs = model_outputs;\n"; 11117be168c0dSopenharmony_ci for (size_t i = 0; i < outputs_size; ++i) { 11118be168c0dSopenharmony_ci Tensor *output = outputs[i]; 11119be168c0dSopenharmony_ci array_tostring(output, "output", i); 11120be168c0dSopenharmony_ci } 11121be168c0dSopenharmony_ci- ofs << " return kMSStatusSuccess;\n"; 11122be168c0dSopenharmony_ci+ ofs << " return OH_AI_STATUS_SUCCESS;\n"; 11123be168c0dSopenharmony_ci } else { 11124be168c0dSopenharmony_ci- ofs << "MSStatus MSModelCreate" << ctx->GetCurModelIndex() << "(MicroModel *micro_model) {\n"; 11125be168c0dSopenharmony_ci+ ofs << "OH_AI_Status OH_AI_ModelCreate" << ctx->GetCurModelIndex() << "(MicroModel *micro_model) {\n"; 11126be168c0dSopenharmony_ci ofs << " micro_model->train_mode = false;\n"; 11127be168c0dSopenharmony_ci- ofs << " return kMSStatusSuccess;\n"; 11128be168c0dSopenharmony_ci+ ofs << " return OH_AI_STATUS_SUCCESS;\n"; 11129be168c0dSopenharmony_ci } 11130be168c0dSopenharmony_ci ofs << "}\n\n"; 11131be168c0dSopenharmony_ci } 11132be168c0dSopenharmony_ci@@ -324,20 +347,20 @@ void CodeMSModelBuildCommon(std::ofstream &ofs, const Configurator &config) { 11133be168c0dSopenharmony_ci ofs << R"RAW( 11134be168c0dSopenharmony_ci MicroModel *micro_model = (MicroModel *)model; 11135be168c0dSopenharmony_ci if (micro_model == NULL) { 11136be168c0dSopenharmony_ci- 
return kMSStatusLiteNullptr; 11137be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 11138be168c0dSopenharmony_ci } 11139be168c0dSopenharmony_ci if (micro_model->build == NULL) { 11140be168c0dSopenharmony_ci- return kMSStatusLiteNullptr; 11141be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 11142be168c0dSopenharmony_ci } 11143be168c0dSopenharmony_ci )RAW"; 11144be168c0dSopenharmony_ci- if (config.target() != kCortex_M) { 11145be168c0dSopenharmony_ci+ if (config.target() != kCortex_M && !config.dynamic_shape()) { 11146be168c0dSopenharmony_ci ofs << " IncRefCount();\n"; 11147be168c0dSopenharmony_ci } 11148be168c0dSopenharmony_ci ofs << R"RAW( 11149be168c0dSopenharmony_ci- MSStatus ret = 11150be168c0dSopenharmony_ci+ OH_AI_Status ret = 11151be168c0dSopenharmony_ci micro_model->build(model, model_data, data_size, model_context); 11152be168c0dSopenharmony_ci- if (ret != kMSStatusSuccess) { 11153be168c0dSopenharmony_ci- MSModelDestroy(model); 11154be168c0dSopenharmony_ci+ if (ret != OH_AI_STATUS_SUCCESS) { 11155be168c0dSopenharmony_ci+ OH_AI_ModelDestroy(&model); 11156be168c0dSopenharmony_ci } 11157be168c0dSopenharmony_ci return ret; 11158be168c0dSopenharmony_ci } 11159be168c0dSopenharmony_ci@@ -345,23 +368,23 @@ void CodeMSModelBuildCommon(std::ofstream &ofs, const Configurator &config) { 11160be168c0dSopenharmony_ci } 11161be168c0dSopenharmony_ci 11162be168c0dSopenharmony_ci void CodeMSModelBuild(std::ofstream &ofs, const int model_index, const size_t weight_size, const Configurator &config) { 11163be168c0dSopenharmony_ci- ofs << "MSStatus MSModelBuild" << model_index 11164be168c0dSopenharmony_ci- << "(MSModelHandle model, const void *model_data, size_t data_size,\n" 11165be168c0dSopenharmony_ci- " const MSContextHandle model_context) {\n" 11166be168c0dSopenharmony_ci+ ofs << "OH_AI_Status OH_AI_ModelBuild" << model_index 11167be168c0dSopenharmony_ci+ << "(OH_AI_ModelHandle model, const void *model_data, size_t data_size,\n" 
11168be168c0dSopenharmony_ci+ " const OH_AI_ContextHandle model_context) {\n" 11169be168c0dSopenharmony_ci " if (model == NULL) {\n" 11170be168c0dSopenharmony_ci- " return kMSStatusLiteParamInvalid;\n" 11171be168c0dSopenharmony_ci+ " return OH_AI_STATUS_LITE_PARAM_INVALID;\n" 11172be168c0dSopenharmony_ci " }\n"; 11173be168c0dSopenharmony_ci if (config.changeable_weights_name().empty()) { 11174be168c0dSopenharmony_ci ofs << " if (data_size != " << weight_size 11175be168c0dSopenharmony_ci << ") {\n" 11176be168c0dSopenharmony_ci- " return kMSStatusLiteInputParamInvalid;\n" 11177be168c0dSopenharmony_ci+ " return OH_AI_STATUS_LITE_INPUT_PARAM_INVALID;\n" 11178be168c0dSopenharmony_ci " }\n"; 11179be168c0dSopenharmony_ci } 11180be168c0dSopenharmony_ci ofs << " MicroModel *micro_model = (MicroModel *)model;\n" 11181be168c0dSopenharmony_ci- " int ret = MSModelCreate" 11182be168c0dSopenharmony_ci+ " int ret = OH_AI_ModelCreate" 11183be168c0dSopenharmony_ci << model_index 11184be168c0dSopenharmony_ci << "(micro_model);\n" 11185be168c0dSopenharmony_ci- " if (ret != kMSStatusSuccess) {\n" 11186be168c0dSopenharmony_ci+ " if (ret != OH_AI_STATUS_SUCCESS) {\n" 11187be168c0dSopenharmony_ci " return ret;\n" 11188be168c0dSopenharmony_ci " }\n"; 11189be168c0dSopenharmony_ci if (config.target() != kCortex_M) { 11190be168c0dSopenharmony_ci@@ -372,7 +395,7 @@ void CodeMSModelBuild(std::ofstream &ofs, const int model_index, const size_t we 11191be168c0dSopenharmony_ci if (config.support_parallel()) { 11192be168c0dSopenharmony_ci ofs << " MicroContext *micro_context = (MicroContext *)model_context;\n" 11193be168c0dSopenharmony_ci " if (micro_context == NULL) {\n" 11194be168c0dSopenharmony_ci- " return kMSStatusLiteNullptr;" 11195be168c0dSopenharmony_ci+ " return OH_AI_STATUS_LITE_NULLPTR;" 11196be168c0dSopenharmony_ci " }\n" 11197be168c0dSopenharmony_ci " ret = CreateThreadPool(micro_context->thread_num_);\n" 11198be168c0dSopenharmony_ci " if(ret != RET_OK) {\n" 
11199be168c0dSopenharmony_ci@@ -384,35 +407,172 @@ void CodeMSModelBuild(std::ofstream &ofs, const int model_index, const size_t we 11200be168c0dSopenharmony_ci ofs << "}\n"; 11201be168c0dSopenharmony_ci } 11202be168c0dSopenharmony_ci 11203be168c0dSopenharmony_ci+void CodeMSModelResizeInit(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11204be168c0dSopenharmony_ci+ auto &dynamic_symbols_num = config.dynamic_symbols_num(); 11205be168c0dSopenharmony_ci+ std::string array_index; 11206be168c0dSopenharmony_ci+ for (auto num : dynamic_symbols_num) { 11207be168c0dSopenharmony_ci+ array_index += "[" + std::to_string(num) + "]"; 11208be168c0dSopenharmony_ci+ } 11209be168c0dSopenharmony_ci+ auto shapes = ctx->shape_all_scenes(); 11210be168c0dSopenharmony_ci+ if (!shapes.empty()) { 11211be168c0dSopenharmony_ci+ auto num_of_each_scene = shapes.begin()->second.size(); 11212be168c0dSopenharmony_ci+ ofs << " static int shapes" << array_index << "[" + std::to_string(num_of_each_scene) + "] = {"; 11213be168c0dSopenharmony_ci+ for (auto &item : shapes) { 11214be168c0dSopenharmony_ci+ auto &shape_val = item.second; 11215be168c0dSopenharmony_ci+ for (size_t j = 0; j < shape_val.size(); ++j) { 11216be168c0dSopenharmony_ci+ ofs << shape_val[j] << ", "; 11217be168c0dSopenharmony_ci+ } 11218be168c0dSopenharmony_ci+ } 11219be168c0dSopenharmony_ci+ ofs << "};\n"; 11220be168c0dSopenharmony_ci+ } 11221be168c0dSopenharmony_ci+ auto offsets = ctx->offset_all_scenes(); 11222be168c0dSopenharmony_ci+ if (!offsets.empty()) { 11223be168c0dSopenharmony_ci+ auto num_of_each_scene = offsets.begin()->second.size(); 11224be168c0dSopenharmony_ci+ ofs << " static int offsets" << array_index << "[" + std::to_string(num_of_each_scene) + "] = {"; 11225be168c0dSopenharmony_ci+ for (auto &item : offsets) { 11226be168c0dSopenharmony_ci+ auto &offset_val = item.second; 11227be168c0dSopenharmony_ci+ for (size_t j = 0; j < offset_val.size(); ++j) { 
11228be168c0dSopenharmony_ci+ ofs << offset_val[j] << ", "; 11229be168c0dSopenharmony_ci+ } 11230be168c0dSopenharmony_ci+ } 11231be168c0dSopenharmony_ci+ ofs << "};\n"; 11232be168c0dSopenharmony_ci+ } 11233be168c0dSopenharmony_ci+ ofs << " size_t buffer_sizes" << array_index << " = {"; 11234be168c0dSopenharmony_ci+ auto buffer_size = ctx->buffer_sizes(); 11235be168c0dSopenharmony_ci+ auto workspace = ctx->workspaces(); 11236be168c0dSopenharmony_ci+ if (buffer_size.size() != workspace.size()) { 11237be168c0dSopenharmony_ci+ return; 11238be168c0dSopenharmony_ci+ } 11239be168c0dSopenharmony_ci+ for (size_t i = 0; i < buffer_size.size(); i++) { 11240be168c0dSopenharmony_ci+ ofs << buffer_size[i] + workspace[i] << ", "; 11241be168c0dSopenharmony_ci+ } 11242be168c0dSopenharmony_ci+ ofs << "};\n"; 11243be168c0dSopenharmony_ci+} 11244be168c0dSopenharmony_ci+ 11245be168c0dSopenharmony_ci+void CodeMSModelResize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11246be168c0dSopenharmony_ci+ auto &shape_templates = ctx->shape_templates(); 11247be168c0dSopenharmony_ci+ ofs << "OH_AI_Status OH_AI_ModelResize" << ctx->GetCurModelIndex() 11248be168c0dSopenharmony_ci+ << "(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_ShapeInfo *shape_infos, size_t " 11249be168c0dSopenharmony_ci+ "shape_info_num) {\n" 11250be168c0dSopenharmony_ci+ " if (model == NULL) {\n" 11251be168c0dSopenharmony_ci+ " return OH_AI_STATUS_LITE_PARAM_INVALID;\n" 11252be168c0dSopenharmony_ci+ " }\n"; 11253be168c0dSopenharmony_ci+ if (!config.dynamic_shape()) { 11254be168c0dSopenharmony_ci+ ofs << " return OH_AI_STATUS_LITE_NOT_SUPPORT;\n"; 11255be168c0dSopenharmony_ci+ } else { 11256be168c0dSopenharmony_ci+ ofs << " MicroModel *micro_model = (MicroModel *)model;\n" 11257be168c0dSopenharmony_ci+ << " if (micro_model == NULL) {\n" 11258be168c0dSopenharmony_ci+ " return OH_AI_STATUS_LITE_NULLPTR;\n" 11259be168c0dSopenharmony_ci+ " }\n"; 
11260be168c0dSopenharmony_ci+ CodeMSModelResizeInit(ofs, ctx, config); 11261be168c0dSopenharmony_ci+ std::map<std::string, std::vector<int>> symbol_to_indexes; 11262be168c0dSopenharmony_ci+ std::map<std::string, std::string> user_to_inner; 11263be168c0dSopenharmony_ci+ auto &user_graph_inputs_template = config.user_graph_inputs_template(); 11264be168c0dSopenharmony_ci+ for (size_t i = 0; i < ctx->graph_inputs().size(); ++i) { 11265be168c0dSopenharmony_ci+ auto cur_tensor = ctx->graph_inputs()[i]; 11266be168c0dSopenharmony_ci+ auto cur_shapes = shape_templates.at(cur_tensor); 11267be168c0dSopenharmony_ci+ for (size_t j = 0; j < cur_shapes.size(); ++j) { 11268be168c0dSopenharmony_ci+ if (IsNumber(cur_shapes.at(j))) { 11269be168c0dSopenharmony_ci+ continue; 11270be168c0dSopenharmony_ci+ } 11271be168c0dSopenharmony_ci+ ofs << " if (shape_infos[" << i << "].shape[" << j << "] <= 0) {\n" 11272be168c0dSopenharmony_ci+ << " return OH_AI_STATUS_LITE_PARAM_INVALID;\n" 11273be168c0dSopenharmony_ci+ << " }\n"; 11274be168c0dSopenharmony_ci+ ofs << " ((MicroTensor *)(inputs.handle_list[" << i << "]))->shape[" << j << "] = shape_infos[" << i 11275be168c0dSopenharmony_ci+ << "].shape[" << j << "];\n"; 11276be168c0dSopenharmony_ci+ if (symbol_to_indexes.find(cur_shapes.at(j)) != symbol_to_indexes.end()) { 11277be168c0dSopenharmony_ci+ continue; 11278be168c0dSopenharmony_ci+ } 11279be168c0dSopenharmony_ci+ symbol_to_indexes[cur_shapes.at(j)] = {static_cast<int>(i), static_cast<int>(j)}; 11280be168c0dSopenharmony_ci+ user_to_inner[user_graph_inputs_template[i][j]] = cur_shapes.at(j); 11281be168c0dSopenharmony_ci+ } 11282be168c0dSopenharmony_ci+ } 11283be168c0dSopenharmony_ci+ int index = 0; 11284be168c0dSopenharmony_ci+ std::map<std::string, std::string> inner_to_outer; 11285be168c0dSopenharmony_ci+ for (auto &item : symbol_to_indexes) { 11286be168c0dSopenharmony_ci+ ofs << " int dim" << index << " = shape_infos[" << item.second[0] << "].shape[" << item.second[1] << "];\n"; 
11287be168c0dSopenharmony_ci+ inner_to_outer[item.first] = "dim" + std::to_string(index); 11288be168c0dSopenharmony_ci+ ++index; 11289be168c0dSopenharmony_ci+ } 11290be168c0dSopenharmony_ci+ std::string condition; 11291be168c0dSopenharmony_ci+ index = 0; 11292be168c0dSopenharmony_ci+ for (; index < static_cast<int>(symbol_to_indexes.size()) - 1; ++index) { 11293be168c0dSopenharmony_ci+ condition += "store" + std::to_string(ctx->GetCurModelIndex()) + "_" + std::to_string(index) + " == dim" + 11294be168c0dSopenharmony_ci+ std::to_string(index) + " && "; 11295be168c0dSopenharmony_ci+ } 11296be168c0dSopenharmony_ci+ condition += "store" + std::to_string(ctx->GetCurModelIndex()) + "_" + std::to_string(index) + " == dim" + 11297be168c0dSopenharmony_ci+ std::to_string(index); 11298be168c0dSopenharmony_ci+ ofs << " if (" << condition << ") {\n" 11299be168c0dSopenharmony_ci+ << " return OH_AI_STATUS_SUCCESS;\n" 11300be168c0dSopenharmony_ci+ << " }\n"; 11301be168c0dSopenharmony_ci+ for (size_t i = 0; i < symbol_to_indexes.size(); ++i) { 11302be168c0dSopenharmony_ci+ ofs << " store" + std::to_string(ctx->GetCurModelIndex()) + "_" << i << " = dim" << i << ";\n"; 11303be168c0dSopenharmony_ci+ } 11304be168c0dSopenharmony_ci+ ofs << " if (" << kBufferPrefixName << " != NULL) {\n"; 11305be168c0dSopenharmony_ci+ ofs << " free(" << kBufferPrefixName << ");\n"; 11306be168c0dSopenharmony_ci+ ofs << " }\n"; 11307be168c0dSopenharmony_ci+ std::string real_array_index; 11308be168c0dSopenharmony_ci+ auto &dynamic_symbols = config.dynamic_symbols(); 11309be168c0dSopenharmony_ci+ for (auto &symbol : dynamic_symbols) { 11310be168c0dSopenharmony_ci+ real_array_index += "[" + inner_to_outer[user_to_inner[symbol]] + " - 1]"; 11311be168c0dSopenharmony_ci+ } 11312be168c0dSopenharmony_ci+ ofs << " " << kBufferPrefixName << " = malloc(buffer_sizes" << real_array_index << ");\n"; 11313be168c0dSopenharmony_ci+ ofs << " micro_model->runtime_buffer = " << kBufferPrefixName << ";\n"; 
11314be168c0dSopenharmony_ci+ ofs << " " << kShapePrefixName << " = &shapes" << real_array_index << "[0];\n"; 11315be168c0dSopenharmony_ci+ ofs << " " << kOffsetPrefixName << " = &offsets" << real_array_index << "[0];\n"; 11316be168c0dSopenharmony_ci+ ofs << " OH_AI_TensorHandleArray outputs = OH_AI_ModelGetOutputs(model);\n"; 11317be168c0dSopenharmony_ci+ for (size_t i = 0; i < ctx->graph_outputs().size(); ++i) { 11318be168c0dSopenharmony_ci+ ofs << " OH_AI_TensorSetData(outputs.handle_list[" << i << "], NULL);\n"; 11319be168c0dSopenharmony_ci+ auto cur_tensor = ctx->graph_outputs()[i]; 11320be168c0dSopenharmony_ci+ auto cur_shapes = shape_templates.at(cur_tensor); 11321be168c0dSopenharmony_ci+ for (size_t j = 0; j < cur_shapes.size(); ++j) { 11322be168c0dSopenharmony_ci+ if (IsNumber(cur_shapes.at(j))) { 11323be168c0dSopenharmony_ci+ continue; 11324be168c0dSopenharmony_ci+ } 11325be168c0dSopenharmony_ci+ ofs << " ((MicroTensor *)(outputs.handle_list[" << i << "]))->shape[" << j << "] = " << cur_shapes.at(j) 11326be168c0dSopenharmony_ci+ << ";\n"; 11327be168c0dSopenharmony_ci+ } 11328be168c0dSopenharmony_ci+ } 11329be168c0dSopenharmony_ci+ ofs << " return OH_AI_STATUS_SUCCESS;\n"; 11330be168c0dSopenharmony_ci+ } 11331be168c0dSopenharmony_ci+ ofs << "}\n"; 11332be168c0dSopenharmony_ci+} 11333be168c0dSopenharmony_ci+ 11334be168c0dSopenharmony_ci void CodeMSModelDestory(std::ofstream &ofs, const Configurator *config) { 11335be168c0dSopenharmony_ci- if (config->target() != kCortex_M) { 11336be168c0dSopenharmony_ci+ if (config->code_mode() == CodeMode::Inference && config->target() != kCortex_M) { 11337be168c0dSopenharmony_ci ofs << handle_array_destroy; 11338be168c0dSopenharmony_ci } 11339be168c0dSopenharmony_ci- ofs << "void MSModelDestroy(MSModelHandle *model) {\n"; 11340be168c0dSopenharmony_ci+ ofs << "void OH_AI_ModelDestroy(OH_AI_ModelHandle *model) {\n"; 11341be168c0dSopenharmony_ci+ ofs << " if (*model) {\n" 11342be168c0dSopenharmony_ci+ " MicroModel 
*micro_model = (MicroModel *)*model;\n"; 11343be168c0dSopenharmony_ci if (config->target() != kCortex_M) { 11344be168c0dSopenharmony_ci- ofs << " if (*model) {\n" 11345be168c0dSopenharmony_ci- " MicroModel *micro_model = (MicroModel *)*model;\n"; 11346be168c0dSopenharmony_ci- ofs << " if (micro_model->runtime_buffer) {\n" 11347be168c0dSopenharmony_ci- " micro_model->runtime_buffer = NULL;\n" 11348be168c0dSopenharmony_ci- " }\n"; 11349be168c0dSopenharmony_ci- ofs << " MSTensorHandleArrayDestroy(micro_model->inputs);\n" 11350be168c0dSopenharmony_ci- " MSTensorHandleArrayDestroy(micro_model->outputs);\n" 11351be168c0dSopenharmony_ci- " micro_model->inputs.handle_list = NULL;\n" 11352be168c0dSopenharmony_ci+ ofs << " if (micro_model->runtime_buffer) {\n"; 11353be168c0dSopenharmony_ci+ if (config->dynamic_shape()) { 11354be168c0dSopenharmony_ci+ ofs << " free(micro_model->runtime_buffer);\n"; 11355be168c0dSopenharmony_ci+ } else { 11356be168c0dSopenharmony_ci+ ofs << " micro_model->runtime_buffer = NULL;\n"; 11357be168c0dSopenharmony_ci+ } 11358be168c0dSopenharmony_ci+ ofs << " }\n"; 11359be168c0dSopenharmony_ci+ } 11360be168c0dSopenharmony_ci+ ofs << " OH_AI_TensorHandleArrayDestroy(micro_model->inputs);\n" 11361be168c0dSopenharmony_ci+ " OH_AI_TensorHandleArrayDestroy(micro_model->outputs);\n"; 11362be168c0dSopenharmony_ci+ if (config->code_mode() == CodeMode::Inference) { 11363be168c0dSopenharmony_ci+ ofs << " micro_model->inputs.handle_list = NULL;\n" 11364be168c0dSopenharmony_ci " micro_model->outputs.handle_list = NULL;\n" 11365be168c0dSopenharmony_ci- " micro_model->free_resource();\n" 11366be168c0dSopenharmony_ci- " DecRefCount();\n" 11367be168c0dSopenharmony_ci- " }\n"; 11368be168c0dSopenharmony_ci- 11369be168c0dSopenharmony_ci- if (config->support_parallel()) { 11370be168c0dSopenharmony_ci- ofs << " ClearThreadPool();\n"; 11371be168c0dSopenharmony_ci+ " micro_model->free_resource();\n"; 11372be168c0dSopenharmony_ci+ if (!config->dynamic_shape()) { 
11373be168c0dSopenharmony_ci+ ofs << " DecRefCount();\n"; 11374be168c0dSopenharmony_ci } 11375be168c0dSopenharmony_ci+ ofs << " }\n"; 11376be168c0dSopenharmony_ci } else { 11377be168c0dSopenharmony_ci- ofs << " if (*model) {\n" 11378be168c0dSopenharmony_ci- " MicroModel *micro_model = (MicroModel *)*model;\n"; 11379be168c0dSopenharmony_ci- ofs << " micro_model->runtime_buffer = NULL;\n" 11380be168c0dSopenharmony_ci+ ofs << " free(*model);\n" 11381be168c0dSopenharmony_ci " *model = NULL;\n" 11382be168c0dSopenharmony_ci " }\n"; 11383be168c0dSopenharmony_ci } 11384be168c0dSopenharmony_ci+ 11385be168c0dSopenharmony_ci+ if (config->support_parallel()) { 11386be168c0dSopenharmony_ci+ ofs << " ClearThreadPool();\n"; 11387be168c0dSopenharmony_ci+ } 11388be168c0dSopenharmony_ci ofs << "}\n"; 11389be168c0dSopenharmony_ci } 11390be168c0dSopenharmony_ci 11391be168c0dSopenharmony_ci@@ -420,14 +580,14 @@ void CodeMSModelPredictState(std::ofstream &ofs) { ofs << micro_model_predict_st 11392be168c0dSopenharmony_ci 11393be168c0dSopenharmony_ci void CodeMSModelPredictCommon(std::ofstream &ofs) { 11394be168c0dSopenharmony_ci ofs << R"RAW( 11395be168c0dSopenharmony_ci-MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, MSTensorHandleArray *outputs, 11396be168c0dSopenharmony_ci- const MSKernelCallBackC before, const MSKernelCallBackC after) { 11397be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_TensorHandleArray *outputs, 11398be168c0dSopenharmony_ci+ const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) { 11399be168c0dSopenharmony_ci MicroModel *micro_model = (MicroModel *)model; 11400be168c0dSopenharmony_ci if (micro_model == NULL) { 11401be168c0dSopenharmony_ci- return kMSStatusLiteNullptr; 11402be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 11403be168c0dSopenharmony_ci } 11404be168c0dSopenharmony_ci if (micro_model->predict == NULL) { 
11405be168c0dSopenharmony_ci- return kMSStatusLiteNullptr; 11406be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 11407be168c0dSopenharmony_ci } 11408be168c0dSopenharmony_ci return micro_model->predict(model, inputs, outputs, before, after); 11409be168c0dSopenharmony_ci } 11410be168c0dSopenharmony_ci@@ -438,35 +598,35 @@ MSStatus MSModelPredict(MSModelHandle model, const MSTensorHandleArray inputs, M 11411be168c0dSopenharmony_ci void CodeMSModelPredict(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11412be168c0dSopenharmony_ci auto inputs_num = ctx->graph_inputs().size(); 11413be168c0dSopenharmony_ci auto outputs_num = ctx->graph_outputs().size(); 11414be168c0dSopenharmony_ci- ofs << "MSStatus MSModelPredict" << ctx->GetCurModelIndex() 11415be168c0dSopenharmony_ci- << "(MSModelHandle model, const MSTensorHandleArray inputs, MSTensorHandleArray *outputs,\n" 11416be168c0dSopenharmony_ci- << " const MSKernelCallBackC before, const MSKernelCallBackC after) {\n"; 11417be168c0dSopenharmony_ci+ ofs << "OH_AI_Status OH_AI_ModelPredict" << ctx->GetCurModelIndex() 11418be168c0dSopenharmony_ci+ << "(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_TensorHandleArray *outputs,\n" 11419be168c0dSopenharmony_ci+ << " const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {\n"; 11420be168c0dSopenharmony_ci ofs << R"RAW( 11421be168c0dSopenharmony_ci MicroModel *micro_model = (MicroModel *)model; 11422be168c0dSopenharmony_ci if (micro_model == NULL) { 11423be168c0dSopenharmony_ci- return kMSStatusLiteNullptr; 11424be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 11425be168c0dSopenharmony_ci } 11426be168c0dSopenharmony_ci if (micro_model->runtime_buffer == NULL) { 11427be168c0dSopenharmony_ci- return kMSStatusLiteMemoryFailed; 11428be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_MEMORY_FAILED; 11429be168c0dSopenharmony_ci } 11430be168c0dSopenharmony_ci )RAW"; 
11431be168c0dSopenharmony_ci ofs << " if (inputs.handle_num != " << inputs_num << ") {\n"; 11432be168c0dSopenharmony_ci- ofs << " return kMSStatusLiteParamInvalid;\n"; 11433be168c0dSopenharmony_ci+ ofs << " return OH_AI_STATUS_LITE_PARAM_INVALID;\n"; 11434be168c0dSopenharmony_ci ofs << " }\n"; 11435be168c0dSopenharmony_ci ofs << " if (outputs->handle_num != " << outputs_num << ") {\n"; 11436be168c0dSopenharmony_ci- ofs << " return kMSStatusLiteParamInvalid;\n"; 11437be168c0dSopenharmony_ci+ ofs << " return OH_AI_STATUS_LITE_PARAM_INVALID;\n"; 11438be168c0dSopenharmony_ci ofs << " }\n"; 11439be168c0dSopenharmony_ci- if (config.target() != kCortex_M) { 11440be168c0dSopenharmony_ci+ if (config.target() != kCortex_M && !config.dynamic_shape()) { 11441be168c0dSopenharmony_ci ofs << " if (!LockBuffer(micro_model->runtime_buffer)) {\n" 11442be168c0dSopenharmony_ci << " void *buffer = Malloc(GetBufferSize" << ctx->GetCurModelIndex() << "());\n" 11443be168c0dSopenharmony_ci << " if (buffer == NULL) {\n" 11444be168c0dSopenharmony_ci- << " return kMSStatusLiteNullptr;\n" 11445be168c0dSopenharmony_ci+ << " return OH_AI_STATUS_LITE_NULLPTR;\n" 11446be168c0dSopenharmony_ci << " }\n" 11447be168c0dSopenharmony_ci << " if (micro_model->runtime_buffer != buffer) {\n" 11448be168c0dSopenharmony_ci << " micro_model->runtime_buffer = buffer;\n" 11449be168c0dSopenharmony_ci << " int ret = SetBuffer" << ctx->GetCurModelIndex() << "(((MemBlock *)buffer)->addr);\n" 11450be168c0dSopenharmony_ci- << " if (ret != kMSStatusSuccess) {\n" 11451be168c0dSopenharmony_ci- << " return kMSStatusLiteMemoryFailed;\n" 11452be168c0dSopenharmony_ci+ << " if (ret != OH_AI_STATUS_SUCCESS) {\n" 11453be168c0dSopenharmony_ci+ << " return OH_AI_STATUS_LITE_MEMORY_FAILED;\n" 11454be168c0dSopenharmony_ci << " }\n" 11455be168c0dSopenharmony_ci << " }\n" 11456be168c0dSopenharmony_ci << " }\n"; 11457be168c0dSopenharmony_ci@@ -495,8 +655,7 @@ void CodeMSModelPredict(std::ofstream &ofs, const 
std::unique_ptr<CoderContext> 11458be168c0dSopenharmony_ci ofs << " }\n"; 11459be168c0dSopenharmony_ci ofs << " }\n"; 11460be168c0dSopenharmony_ci ofs << "\n"; 11461be168c0dSopenharmony_ci- ofs << " void *outputs_data_array[" << outputs_num << "];\n"; 11462be168c0dSopenharmony_ci- ofs << " int expect_out_types[" << outputs_num << "] = {"; 11463be168c0dSopenharmony_ci+ ofs << " int cur_out_types[" << outputs_num << "] = {"; 11464be168c0dSopenharmony_ci for (size_t i = 0; i < outputs_num; ++i) { 11465be168c0dSopenharmony_ci ofs << ctx->graph_outputs().at(i)->data_type() << ", "; 11466be168c0dSopenharmony_ci } 11467be168c0dSopenharmony_ci@@ -506,21 +665,18 @@ void CodeMSModelPredict(std::ofstream &ofs, const std::unique_ptr<CoderContext> 11468be168c0dSopenharmony_ci ofs << "false, "; 11469be168c0dSopenharmony_ci } 11470be168c0dSopenharmony_ci ofs << "};\n"; 11471be168c0dSopenharmony_ci- ofs << " for (int i = 0; i < " << outputs_num << "; i++) {\n"; 11472be168c0dSopenharmony_ci- ofs << " outputs_data_array[i] = MSTensorGetMutableData(outputs->handle_list[i]);\n"; 11473be168c0dSopenharmony_ci- ofs << " }\n"; 11474be168c0dSopenharmony_ci- ofs << " CopyOutputsData" << ctx->GetCurModelIndex() 11475be168c0dSopenharmony_ci- << "(outputs, outputs_data_array, expect_out_types, out_type_changed);\n"; 11476be168c0dSopenharmony_ci- if (config.target() != kCortex_M) { 11477be168c0dSopenharmony_ci+ ofs << " OH_AI_Status ret = CopyOutputsData" << ctx->GetCurModelIndex() 11478be168c0dSopenharmony_ci+ << "(outputs, cur_out_types, out_type_changed);\n"; 11479be168c0dSopenharmony_ci+ if (config.target() != kCortex_M && !config.dynamic_shape()) { 11480be168c0dSopenharmony_ci ofs << " UnLockBuffer(micro_model->runtime_buffer);\n"; 11481be168c0dSopenharmony_ci } 11482be168c0dSopenharmony_ci- ofs << " return kMSStatusSuccess;\n"; 11483be168c0dSopenharmony_ci+ ofs << " return ret;\n"; 11484be168c0dSopenharmony_ci ofs << "}\n"; 11485be168c0dSopenharmony_ci } 11486be168c0dSopenharmony_ci 
11487be168c0dSopenharmony_ci void CodeCopyOutputsState(std::ofstream &ofs, const int model_index) { 11488be168c0dSopenharmony_ci- ofs << "int CopyOutputsData" << model_index 11489be168c0dSopenharmony_ci- << "(MSTensorHandleArray *outputs_ori, void **outputs, int *expect_types, bool *type_changed);\n\n"; 11490be168c0dSopenharmony_ci+ ofs << "OH_AI_Status CopyOutputsData" << model_index 11491be168c0dSopenharmony_ci+ << "(OH_AI_TensorHandleArray *outputs_ori, void **outputs, int *cur_out_types, bool *type_changed);\n\n"; 11492be168c0dSopenharmony_ci } 11493be168c0dSopenharmony_ci 11494be168c0dSopenharmony_ci void CodeCopyOutputsImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx) { 11495be168c0dSopenharmony_ci@@ -528,56 +684,60 @@ void CodeCopyOutputsImplement(std::ofstream &ofs, const std::unique_ptr<CoderCon 11496be168c0dSopenharmony_ci std::vector<Tensor *> outputs = ctx->graph_outputs(); 11497be168c0dSopenharmony_ci size_t outputs_size = outputs.size(); 11498be168c0dSopenharmony_ci 11499be168c0dSopenharmony_ci- ofs << "int CopyOutputsData" << ctx->GetCurModelIndex() 11500be168c0dSopenharmony_ci- << "(MSTensorHandleArray *outputs_ori, void **outputs, int *expect_types, bool *type_changed) {\n" 11501be168c0dSopenharmony_ci- " if (outputs_ori == NULL || outputs == NULL) {\n" 11502be168c0dSopenharmony_ci- " return RET_ERROR;\n" 11503be168c0dSopenharmony_ci+ ofs << "OH_AI_Status CopyOutputsData" << ctx->GetCurModelIndex() 11504be168c0dSopenharmony_ci+ << "(OH_AI_TensorHandleArray *outputs_ori, int *cur_out_types, bool *type_changed) {\n" 11505be168c0dSopenharmony_ci+ " if (outputs_ori == NULL || cur_out_types == NULL || type_changed == NULL) {\n" 11506be168c0dSopenharmony_ci+ " return OH_AI_STATUS_LITE_NULLPTR;\n" 11507be168c0dSopenharmony_ci " }\n"; 11508be168c0dSopenharmony_ci ofs << " unsigned char *buffer[" << outputs_size << "] = {"; 11509be168c0dSopenharmony_ci for (size_t i = 0; i < outputs_size; ++i) { 11510be168c0dSopenharmony_ci- ofs << 
tensor_map[outputs[i]] << ", "; 11511be168c0dSopenharmony_ci- } 11512be168c0dSopenharmony_ci- ofs << "};\n"; 11513be168c0dSopenharmony_ci- ofs << " size_t buffer_size[" << outputs_size << "] = {"; 11514be168c0dSopenharmony_ci- for (size_t i = 0; i < outputs_size; ++i) { 11515be168c0dSopenharmony_ci- Tensor *output = outputs[i]; 11516be168c0dSopenharmony_ci- MS_CHECK_PTR_IF_NULL(output); 11517be168c0dSopenharmony_ci- ofs << output->Size() << ", "; 11518be168c0dSopenharmony_ci+ auto out_str = ctx->tensor_addr(outputs[i]); 11519be168c0dSopenharmony_ci+ if (out_str.empty()) { 11520be168c0dSopenharmony_ci+ ofs << tensor_map[outputs[i]] << ", "; 11521be168c0dSopenharmony_ci+ } else { 11522be168c0dSopenharmony_ci+ ofs << out_str << ", "; 11523be168c0dSopenharmony_ci+ } 11524be168c0dSopenharmony_ci } 11525be168c0dSopenharmony_ci ofs << "};\n"; 11526be168c0dSopenharmony_ci ofs << " for (int i = 0; i < " << outputs_size << "; i++) {\n" 11527be168c0dSopenharmony_ci << " MicroTensor *micro_tensor = (MicroTensor *)outputs_ori->handle_list[i];\n" 11528be168c0dSopenharmony_ci- << " int cur_type = micro_tensor->type;\n" 11529be168c0dSopenharmony_ci- << " int expect_type = expect_types[i];\n"; 11530be168c0dSopenharmony_ci- ofs << " if (cur_type == expect_type) {\n" 11531be168c0dSopenharmony_ci- << " memcpy(outputs[i], buffer[i], buffer_size[i]);\n" 11532be168c0dSopenharmony_ci+ << " int expect_type = micro_tensor->type;\n" 11533be168c0dSopenharmony_ci+ << " int cur_type = cur_out_types[i];\n"; 11534be168c0dSopenharmony_ci+ ofs << " if (expect_type == cur_type) {\n" 11535be168c0dSopenharmony_ci+ << " micro_tensor->data = buffer[i];\n" 11536be168c0dSopenharmony_ci+ << " micro_tensor->owned = false;\n" 11537be168c0dSopenharmony_ci << " continue;\n" 11538be168c0dSopenharmony_ci << " }\n" 11539be168c0dSopenharmony_ci+ << "#ifdef ENABLE_FP16\n" 11540be168c0dSopenharmony_ci << " int shape_size = micro_tensor->ndim;\n" 11541be168c0dSopenharmony_ci << " int num = 1;\n" 
11542be168c0dSopenharmony_ci- << " for (int i = 0; i < shape_size; ++i) {\n" 11543be168c0dSopenharmony_ci- << " num *= micro_tensor->shape[i];\n" 11544be168c0dSopenharmony_ci+ << " for (int j = 0; j < shape_size; ++j) {\n" 11545be168c0dSopenharmony_ci+ << " num *= micro_tensor->shape[j];\n" 11546be168c0dSopenharmony_ci << " }\n"; 11547be168c0dSopenharmony_ci- ofs << " int type_trans_mode = TypeTransMode_MAX;\n" 11548be168c0dSopenharmony_ci- " if (expect_type == kMSDataTypeNumberTypeFloat16 && cur_type == kMSDataTypeNumberTypeFloat32) {\n" 11549be168c0dSopenharmony_ci- " type_trans_mode = TypeTransMode_FP32_TO_FP16;\n" 11550be168c0dSopenharmony_ci- " } else if (expect_type == kMSDataTypeNumberTypeFloat32 && cur_type == kMSDataTypeNumberTypeFloat16) {\n" 11551be168c0dSopenharmony_ci- " type_trans_mode = TypeTransMode_FP16_TO_FP32;\n" 11552be168c0dSopenharmony_ci- " }\n"; 11553be168c0dSopenharmony_ci+ ofs 11554be168c0dSopenharmony_ci+ << " int type_trans_mode = TypeTransMode_MAX;\n" 11555be168c0dSopenharmony_ci+ " if (expect_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT16 && cur_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT32) {\n" 11556be168c0dSopenharmony_ci+ " type_trans_mode = TypeTransMode_FP32_TO_FP16;\n" 11557be168c0dSopenharmony_ci+ " } else if (expect_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT32 && cur_type == " 11558be168c0dSopenharmony_ci+ "OH_AI_DATATYPE_NUMBERTYPE_FLOAT16) {\n" 11559be168c0dSopenharmony_ci+ " type_trans_mode = TypeTransMode_FP16_TO_FP32;\n" 11560be168c0dSopenharmony_ci+ " }\n"; 11561be168c0dSopenharmony_ci ofs << " if (type_trans_mode == TypeTransMode_UNSUPPORT) {\n" 11562be168c0dSopenharmony_ci- << " return kMSStatusLiteNotSupport;\n" 11563be168c0dSopenharmony_ci+ << " return OH_AI_STATUS_LITE_NOT_SUPPORT;\n" 11564be168c0dSopenharmony_ci << " }\n"; 11565be168c0dSopenharmony_ci- ofs << "#ifdef ENABLE_FP16\n" 11566be168c0dSopenharmony_ci- << " if (type_trans_mode == TypeTransMode_FP32_TO_FP16) {\n" 11567be168c0dSopenharmony_ci- << " Fp32CastToFp16((float 
*)(buffer[i]), (float16_t *)&outputs, num);\n" 11568be168c0dSopenharmony_ci+ ofs << " void *out_data = OH_AI_TensorGetMutableData(micro_tensor);\n"; 11569be168c0dSopenharmony_ci+ ofs << " if (type_trans_mode == TypeTransMode_FP32_TO_FP16) {\n" 11570be168c0dSopenharmony_ci+ << " Fp32CastToFp16((float *)(buffer[i]), (float16_t *)out_data, num);\n" 11571be168c0dSopenharmony_ci << " type_changed[i] = true;\n" 11572be168c0dSopenharmony_ci << " } else if (type_trans_mode == TypeTransMode_FP16_TO_FP32) {\n" 11573be168c0dSopenharmony_ci- << " Fp16CastToFp32((float16_t *)&outputs, (float *)(buffer[i]), num);\n" 11574be168c0dSopenharmony_ci+ << " Fp16CastToFp32((float16_t *)(buffer[i]), (float *)out_data, num);\n" 11575be168c0dSopenharmony_ci << " type_changed[i] = true;\n" 11576be168c0dSopenharmony_ci << " }\n" 11577be168c0dSopenharmony_ci+ << "#else\n" 11578be168c0dSopenharmony_ci+ << " return OH_AI_STATUS_LITE_NOT_SUPPORT;\n" 11579be168c0dSopenharmony_ci << "#endif\n" 11580be168c0dSopenharmony_ci << " }\n"; 11581be168c0dSopenharmony_ci- ofs << " return RET_OK;\n" 11582be168c0dSopenharmony_ci+ ofs << " return OH_AI_STATUS_SUCCESS;\n" 11583be168c0dSopenharmony_ci "}\n\n"; 11584be168c0dSopenharmony_ci } 11585be168c0dSopenharmony_ci 11586be168c0dSopenharmony_ci@@ -688,6 +848,16 @@ void CodeInitResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderCo 11587be168c0dSopenharmony_ci "}\n"; 11588be168c0dSopenharmony_ci } 11589be168c0dSopenharmony_ci 11590be168c0dSopenharmony_ci+void CodeResetImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config) { 11591be168c0dSopenharmony_ci+ ofs << "void Reset" << ctx->GetCurModelIndex() << "() {\n"; 11592be168c0dSopenharmony_ci+ auto &dynamic_symbols = config.dynamic_symbols(); 11593be168c0dSopenharmony_ci+ for (size_t i = 0; i < dynamic_symbols.size(); ++i) { 11594be168c0dSopenharmony_ci+ ofs << " store" << ctx->GetCurModelIndex() << "_" << i << " = -1;\n"; 11595be168c0dSopenharmony_ci+ } 
11596be168c0dSopenharmony_ci+ ofs << " FreeResource" << ctx->GetCurModelIndex() << "();\n"; 11597be168c0dSopenharmony_ci+ ofs << "}\n"; 11598be168c0dSopenharmony_ci+} 11599be168c0dSopenharmony_ci+ 11600be168c0dSopenharmony_ci void CodeFreeResourceState(std::ofstream &ofs) { ofs << free_resource_state; } 11601be168c0dSopenharmony_ci 11602be168c0dSopenharmony_ci void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, 11603be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.h b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.h 11604be168c0dSopenharmony_ciindex 56209f05..6f0c7736 100644 11605be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.h 11606be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/common_component.h 11607be168c0dSopenharmony_ci@@ -32,12 +32,13 @@ void CodeMSModelCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<Code 11608be168c0dSopenharmony_ci void CodeCortexCalcWorkspaceSize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx); 11609be168c0dSopenharmony_ci void CodeMSModelSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11610be168c0dSopenharmony_ci void CodeCortexSetWorkspace(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx); 11611be168c0dSopenharmony_ci-void CodeMSTensorHandleArrayDestroyState(std::ofstream &ofs, const Configurator &config); 11612be168c0dSopenharmony_ci void CodeMSModelCreateDefault(std::ofstream &ofs); 11613be168c0dSopenharmony_ci void CodeMSModelCreate(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11614be168c0dSopenharmony_ci void CodeMSModelBuildState(std::ofstream &ofs); 11615be168c0dSopenharmony_ci void CodeMSModelBuildCommon(std::ofstream &ofs, const Configurator 
&config); 11616be168c0dSopenharmony_ci void CodeMSModelBuild(std::ofstream &ofs, const int model_index, const size_t weight_size, const Configurator &config); 11617be168c0dSopenharmony_ci+void CodeMSModelResizeInit(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11618be168c0dSopenharmony_ci+void CodeMSModelResize(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11619be168c0dSopenharmony_ci void CodeMSModelDestory(std::ofstream &ofs, const Configurator *config); 11620be168c0dSopenharmony_ci void CodeMSModelPredictState(std::ofstream &ofs); 11621be168c0dSopenharmony_ci void CodeMSModelPredictCommon(std::ofstream &ofs); 11622be168c0dSopenharmony_ci@@ -57,6 +58,7 @@ void CodeGraphQuantArgsImplement(std::ofstream &ofs, const std::unique_ptr<Coder 11623be168c0dSopenharmony_ci void CodeManageResourceState(std::ofstream &ofs, const int model_index); 11624be168c0dSopenharmony_ci void CodeInitResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx); 11625be168c0dSopenharmony_ci 11626be168c0dSopenharmony_ci+void CodeResetImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, const Configurator &config); 11627be168c0dSopenharmony_ci void CodeFreeResourceState(std::ofstream &ofs); 11628be168c0dSopenharmony_ci void CodeFreeResourceImplement(std::ofstream &ofs, const std::unique_ptr<CoderContext> &ctx, 11629be168c0dSopenharmony_ci const Configurator &config); 11630be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/component.cc 11631be168c0dSopenharmony_ciindex b2ed21be..0ee02e0c 100644 11632be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/component.cc 11633be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/component.cc 11634be168c0dSopenharmony_ci@@ 
-24,6 +24,8 @@ const char *kOutputPrefixName = nullptr; 11635be168c0dSopenharmony_ci const char *kWeightPrefixName = nullptr; 11636be168c0dSopenharmony_ci const char *kBufferPrefixName = nullptr; 11637be168c0dSopenharmony_ci const char *kBufferPrefixNameAdd = nullptr; 11638be168c0dSopenharmony_ci+const char *kOffsetPrefixName = nullptr; 11639be168c0dSopenharmony_ci+const char *kShapePrefixName = nullptr; 11640be168c0dSopenharmony_ci 11641be168c0dSopenharmony_ci char *ModifyPrefixName(char *name, int model_index, const std::string &prefix) { 11642be168c0dSopenharmony_ci if (name != nullptr) { 11643be168c0dSopenharmony_ci@@ -57,6 +59,8 @@ void FreeGlobalVariable() { 11644be168c0dSopenharmony_ci Free(kWeightPrefixName); 11645be168c0dSopenharmony_ci Free(kBufferPrefixName); 11646be168c0dSopenharmony_ci Free(kBufferPrefixNameAdd); 11647be168c0dSopenharmony_ci+ Free(kOffsetPrefixName); 11648be168c0dSopenharmony_ci+ Free(kShapePrefixName) 11649be168c0dSopenharmony_ci } 11650be168c0dSopenharmony_ci 11651be168c0dSopenharmony_ci void InitGlobalVariable(int model_index) { 11652be168c0dSopenharmony_ci@@ -65,5 +69,7 @@ void InitGlobalVariable(int model_index) { 11653be168c0dSopenharmony_ci kWeightPrefixName = ModifyPrefixName(const_cast<char *>(kWeightPrefixName), model_index, "_weight"); 11654be168c0dSopenharmony_ci kBufferPrefixName = ModifyPrefixName(const_cast<char *>(kBufferPrefixName), model_index, "_buffer"); 11655be168c0dSopenharmony_ci kBufferPrefixNameAdd = ModifyPrefixName(const_cast<char *>(kBufferPrefixNameAdd), model_index, "_buffer + "); 11656be168c0dSopenharmony_ci+ kOffsetPrefixName = ModifyPrefixName(const_cast<char *>(kOffsetPrefixName), model_index, "_offset"); 11657be168c0dSopenharmony_ci+ kShapePrefixName = ModifyPrefixName(const_cast<char *>(kShapePrefixName), model_index, "_shape"); 11658be168c0dSopenharmony_ci } 11659be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 11660be168c0dSopenharmony_cidiff --git 
a/mindspore/lite/tools/converter/micro/coder/generator/component/component.h b/mindspore/lite/tools/converter/micro/coder/generator/component/component.h 11661be168c0dSopenharmony_ciindex 0e943317..e084d692 100644 11662be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/component.h 11663be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/component.h 11664be168c0dSopenharmony_ci@@ -16,7 +16,6 @@ 11665be168c0dSopenharmony_ci 11666be168c0dSopenharmony_ci #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_COMPONENT_H_ 11667be168c0dSopenharmony_ci #define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_GENERATOR_COMPONENT_COMPONENT_H_ 11668be168c0dSopenharmony_ci-#include <string> 11669be168c0dSopenharmony_ci 11670be168c0dSopenharmony_ci namespace mindspore::lite::micro { 11671be168c0dSopenharmony_ci extern const char *kInputPrefixName; 11672be168c0dSopenharmony_ci@@ -26,6 +25,8 @@ constexpr auto kPackWeightOffsetName = "w_offset"; 11673be168c0dSopenharmony_ci constexpr auto kPackWeightSizeName = "w_size"; 11674be168c0dSopenharmony_ci extern const char *kBufferPrefixName; 11675be168c0dSopenharmony_ci extern const char *kBufferPrefixNameAdd; 11676be168c0dSopenharmony_ci+extern const char *kOffsetPrefixName; 11677be168c0dSopenharmony_ci+extern const char *kShapePrefixName; 11678be168c0dSopenharmony_ci void FreeGlobalVariable(); 11679be168c0dSopenharmony_ci void InitGlobalVariable(int model_index); 11680be168c0dSopenharmony_ci 11681be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/benchmark.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/benchmark.cc 11682be168c0dSopenharmony_ciindex 91f2ca89..ad638276 100644 11683be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/benchmark.cc 11684be168c0dSopenharmony_ci+++ 
b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/benchmark.cc 11685be168c0dSopenharmony_ci@@ -53,7 +53,7 @@ const char benchmark_source[] = R"RAW(/** 11686be168c0dSopenharmony_ci 11687be168c0dSopenharmony_ci void usage() { 11688be168c0dSopenharmony_ci printf( 11689be168c0dSopenharmony_ci- "-- mindspore benchmark params usage:\n" 11690be168c0dSopenharmony_ci+ "-- mindspore benchmark paraOH_AI_ usage:\n" 11691be168c0dSopenharmony_ci "args[0]: executable file\n" 11692be168c0dSopenharmony_ci "args[1]: inputs binary file\n" 11693be168c0dSopenharmony_ci "args[2]: model weight binary file\n" 11694be168c0dSopenharmony_ci@@ -67,38 +67,38 @@ void usage() { 11695be168c0dSopenharmony_ci 11696be168c0dSopenharmony_ci uint64_t GetTimeUs() { 11697be168c0dSopenharmony_ci const int USEC = 1000000; 11698be168c0dSopenharmony_ci- const int MSEC = 1000; 11699be168c0dSopenharmony_ci+ const int OH_AI_EC = 1000; 11700be168c0dSopenharmony_ci struct timespec ts = {0, 0}; 11701be168c0dSopenharmony_ci if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { 11702be168c0dSopenharmony_ci return 0; 11703be168c0dSopenharmony_ci } 11704be168c0dSopenharmony_ci- uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); 11705be168c0dSopenharmony_ci+ uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / OH_AI_EC)); 11706be168c0dSopenharmony_ci return retval; 11707be168c0dSopenharmony_ci } 11708be168c0dSopenharmony_ci 11709be168c0dSopenharmony_ci-void PrintTensorHandle(MSTensorHandle tensor) { 11710be168c0dSopenharmony_ci- printf("name: %s, ", MSTensorGetName(tensor)); 11711be168c0dSopenharmony_ci- MSDataType data_type = MSTensorGetDataType(tensor); 11712be168c0dSopenharmony_ci+void PrintTensorHandle(OH_AI_TensorHandle tensor) { 11713be168c0dSopenharmony_ci+ printf("name: %s, ", OH_AI_TensorGetName(tensor)); 11714be168c0dSopenharmony_ci+ OH_AI_DataType data_type = OH_AI_TensorGetDataType(tensor); 11715be168c0dSopenharmony_ci printf("DataType: %d, ", 
data_type); 11716be168c0dSopenharmony_ci- size_t element_num = (size_t)(MSTensorGetElementNum(tensor)); 11717be168c0dSopenharmony_ci+ size_t element_num = (size_t)(OH_AI_TensorGetElementNum(tensor)); 11718be168c0dSopenharmony_ci printf("Elements: %zu, ", element_num); 11719be168c0dSopenharmony_ci printf("Shape: ["); 11720be168c0dSopenharmony_ci size_t shape_num = 0; 11721be168c0dSopenharmony_ci- const int64_t *dims = MSTensorGetShape(tensor, &shape_num); 11722be168c0dSopenharmony_ci+ const int64_t *dims = OH_AI_TensorGetShape(tensor, &shape_num); 11723be168c0dSopenharmony_ci for (size_t i = 0; i < shape_num; i++) { 11724be168c0dSopenharmony_ci printf("%d ", (int)dims[i]); 11725be168c0dSopenharmony_ci } 11726be168c0dSopenharmony_ci printf("], Data: \n"); 11727be168c0dSopenharmony_ci- void *data = MSTensorGetMutableData(tensor); 11728be168c0dSopenharmony_ci+ void *data = OH_AI_TensorGetMutableData(tensor); 11729be168c0dSopenharmony_ci element_num = element_num > 10 ? 10 : element_num; 11730be168c0dSopenharmony_ci switch (data_type) { 11731be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeFloat32: { 11732be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT32: { 11733be168c0dSopenharmony_ci for (size_t i = 0; i < element_num; i++) { 11734be168c0dSopenharmony_ci printf("%.6f, ", ((float *)data)[i]); 11735be168c0dSopenharmony_ci } 11736be168c0dSopenharmony_ci printf("\n"); 11737be168c0dSopenharmony_ci } break; 11738be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeFloat16: 11739be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT16: 11740be168c0dSopenharmony_ci #ifdef ENABLE_FP16 11741be168c0dSopenharmony_ci { 11742be168c0dSopenharmony_ci for (size_t i = 0; i < element_num; i++) { 11743be168c0dSopenharmony_ci@@ -107,25 +107,25 @@ void PrintTensorHandle(MSTensorHandle tensor) { 11744be168c0dSopenharmony_ci printf("\n"); 11745be168c0dSopenharmony_ci } break; 11746be168c0dSopenharmony_ci #endif 11747be168c0dSopenharmony_ci- case 
kMSDataTypeNumberTypeInt16: { 11748be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT16: { 11749be168c0dSopenharmony_ci for (size_t i = 0; i < element_num; i++) { 11750be168c0dSopenharmony_ci printf("%" PRId16, ((int16_t *)data)[i]); 11751be168c0dSopenharmony_ci } 11752be168c0dSopenharmony_ci printf("\n"); 11753be168c0dSopenharmony_ci } break; 11754be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeInt32: { 11755be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT32: { 11756be168c0dSopenharmony_ci for (size_t i = 0; i < element_num; i++) { 11757be168c0dSopenharmony_ci printf("%" PRId32, ((int32_t *)data)[i]); 11758be168c0dSopenharmony_ci } 11759be168c0dSopenharmony_ci printf("\n"); 11760be168c0dSopenharmony_ci } break; 11761be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeInt8: { 11762be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT8: { 11763be168c0dSopenharmony_ci for (size_t i = 0; i < element_num; i++) { 11764be168c0dSopenharmony_ci printf("%" PRIi8, ((int8_t *)data)[i]); 11765be168c0dSopenharmony_ci } 11766be168c0dSopenharmony_ci printf("\n"); 11767be168c0dSopenharmony_ci } break; 11768be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeUInt8: { 11769be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_UINT8: { 11770be168c0dSopenharmony_ci for (size_t i = 0; i < element_num; i++) { 11771be168c0dSopenharmony_ci printf("%u", ((uint8_t *)data)[i]); 11772be168c0dSopenharmony_ci } 11773be168c0dSopenharmony_ci@@ -141,31 +141,31 @@ int main(int argc, const char **argv) { 11774be168c0dSopenharmony_ci if (argc < 2) { 11775be168c0dSopenharmony_ci printf("input command is invalid\n"); 11776be168c0dSopenharmony_ci usage(); 11777be168c0dSopenharmony_ci- return kMSStatusLiteError; 11778be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 11779be168c0dSopenharmony_ci } 11780be168c0dSopenharmony_ci printf("=======run benchmark======\n"); 11781be168c0dSopenharmony_ci 11782be168c0dSopenharmony_ci- MSContextHandle ms_context_handle = 
MSContextCreate(); 11783be168c0dSopenharmony_ci+ OH_AI_ContextHandle ms_context_handle = OH_AI_ContextCreate(); 11784be168c0dSopenharmony_ci if (argc >= 6) { 11785be168c0dSopenharmony_ci int thread_num = atoi(argv[5]); 11786be168c0dSopenharmony_ci if (thread_num < 1 || thread_num > kMaxThreadNum) { 11787be168c0dSopenharmony_ci printf("Thread number error! It should be greater than 0 and less than 5\n"); 11788be168c0dSopenharmony_ci- return kMSStatusLiteParamInvalid; 11789be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 11790be168c0dSopenharmony_ci } 11791be168c0dSopenharmony_ci- MSContextSetThreadNum(ms_context_handle, thread_num); 11792be168c0dSopenharmony_ci+ OH_AI_ContextSetThreadNum(ms_context_handle, thread_num); 11793be168c0dSopenharmony_ci } 11794be168c0dSopenharmony_ci- printf("ThreadNum: %d.\n", MSContextGetThreadNum(ms_context_handle)); 11795be168c0dSopenharmony_ci+ printf("ThreadNum: %d.\n", OH_AI_ContextGetThreadNum(ms_context_handle)); 11796be168c0dSopenharmony_ci 11797be168c0dSopenharmony_ci int bind_mode = kBindDefault; 11798be168c0dSopenharmony_ci if (argc >= 7) { 11799be168c0dSopenharmony_ci bind_mode = atoi(argv[6]); 11800be168c0dSopenharmony_ci if (bind_mode < 0 || bind_mode > 2) { 11801be168c0dSopenharmony_ci printf("Thread bind mode error! 
0: No bind, 1: Bind hign cpu, 2: Bind mid cpu.\n"); 11802be168c0dSopenharmony_ci- return kMSStatusLiteParamInvalid; 11803be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 11804be168c0dSopenharmony_ci } 11805be168c0dSopenharmony_ci } 11806be168c0dSopenharmony_ci- MSContextSetThreadAffinityMode(ms_context_handle, bind_mode); 11807be168c0dSopenharmony_ci- printf("BindMode: %d.\n", MSContextGetThreadAffinityMode(ms_context_handle)); 11808be168c0dSopenharmony_ci+ OH_AI_ContextSetThreadAffinityMode(ms_context_handle, bind_mode); 11809be168c0dSopenharmony_ci+ printf("BindMode: %d.\n", OH_AI_ContextGetThreadAffinityMode(ms_context_handle)); 11810be168c0dSopenharmony_ci 11811be168c0dSopenharmony_ci void *model_buffer = NULL; 11812be168c0dSopenharmony_ci int model_size = 0; 11813be168c0dSopenharmony_ci@@ -174,14 +174,14 @@ int main(int argc, const char **argv) { 11814be168c0dSopenharmony_ci model_buffer = ReadInputData(argv[2], &model_size); 11815be168c0dSopenharmony_ci if (model_buffer == NULL) { 11816be168c0dSopenharmony_ci printf("Read model file failed."); 11817be168c0dSopenharmony_ci- return kMSStatusLiteParamInvalid; 11818be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 11819be168c0dSopenharmony_ci } 11820be168c0dSopenharmony_ci } 11821be168c0dSopenharmony_ci- MSModelHandle model_handle = MSModelCreate(); 11822be168c0dSopenharmony_ci- int ret = MSModelBuild(model_handle, model_buffer, model_size, kMSModelTypeMindIR, ms_context_handle); 11823be168c0dSopenharmony_ci- MSContextDestroy(&ms_context_handle); 11824be168c0dSopenharmony_ci- if (ret != kMSStatusSuccess) { 11825be168c0dSopenharmony_ci- printf("MSModelBuildFromFile failed, ret: %d\n", ret); 11826be168c0dSopenharmony_ci+ OH_AI_ModelHandle model_handle = OH_AI_ModelCreate(); 11827be168c0dSopenharmony_ci+ int ret = OH_AI_ModelBuild(model_handle, model_buffer, model_size, OH_AI_MODELTYPE_MINDIR, ms_context_handle); 11828be168c0dSopenharmony_ci+ OH_AI_ContextDestroy(&ms_context_handle); 
11829be168c0dSopenharmony_ci+ if (ret != OH_AI_STATUS_SUCCESS) { 11830be168c0dSopenharmony_ci+ printf("OH_AI_ModelBuild failed, ret: %d\n", ret); 11831be168c0dSopenharmony_ci free(model_buffer); 11832be168c0dSopenharmony_ci model_buffer = NULL; 11833be168c0dSopenharmony_ci return ret; 11834be168c0dSopenharmony_ci@@ -191,33 +191,33 @@ int main(int argc, const char **argv) { 11835be168c0dSopenharmony_ci model_buffer = NULL; 11836be168c0dSopenharmony_ci } 11837be168c0dSopenharmony_ci // set model inputs tensor data 11838be168c0dSopenharmony_ci- MSTensorHandleArray inputs_handle = MSModelGetInputs(model_handle); 11839be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray inputs_handle = OH_AI_ModelGetInputs(model_handle); 11840be168c0dSopenharmony_ci if (inputs_handle.handle_list == NULL) { 11841be168c0dSopenharmony_ci- printf("MSModelGetInputs failed, ret: %d", ret); 11842be168c0dSopenharmony_ci+ printf("OH_AI_ModelGetInputs failed, ret: %d", ret); 11843be168c0dSopenharmony_ci return ret; 11844be168c0dSopenharmony_ci } 11845be168c0dSopenharmony_ci size_t inputs_num = inputs_handle.handle_num; 11846be168c0dSopenharmony_ci void *inputs_binbuf[inputs_num]; 11847be168c0dSopenharmony_ci int inputs_size[inputs_num]; 11848be168c0dSopenharmony_ci for (size_t i = 0; i < inputs_num; ++i) { 11849be168c0dSopenharmony_ci- MSTensorHandle tensor = inputs_handle.handle_list[i]; 11850be168c0dSopenharmony_ci- inputs_size[i] = (int)MSTensorGetDataSize(tensor); 11851be168c0dSopenharmony_ci+ OH_AI_TensorHandle tensor = inputs_handle.handle_list[i]; 11852be168c0dSopenharmony_ci+ inputs_size[i] = (int)OH_AI_TensorGetDataSize(tensor); 11853be168c0dSopenharmony_ci } 11854be168c0dSopenharmony_ci ret = ReadInputsFile((char *)(argv[1]), inputs_binbuf, inputs_size, (int)inputs_num); 11855be168c0dSopenharmony_ci if (ret != 0) { 11856be168c0dSopenharmony_ci- MSModelDestroy(&model_handle); 11857be168c0dSopenharmony_ci+ OH_AI_ModelDestroy(&model_handle); 11858be168c0dSopenharmony_ci return ret; 
11859be168c0dSopenharmony_ci } 11860be168c0dSopenharmony_ci for (size_t i = 0; i < inputs_num; ++i) { 11861be168c0dSopenharmony_ci- void *input_data = MSTensorGetMutableData(inputs_handle.handle_list[i]); 11862be168c0dSopenharmony_ci+ void *input_data = OH_AI_TensorGetMutableData(inputs_handle.handle_list[i]); 11863be168c0dSopenharmony_ci memcpy(input_data, inputs_binbuf[i], inputs_size[i]); 11864be168c0dSopenharmony_ci free(inputs_binbuf[i]); 11865be168c0dSopenharmony_ci inputs_binbuf[i] = NULL; 11866be168c0dSopenharmony_ci } 11867be168c0dSopenharmony_ci 11868be168c0dSopenharmony_ci- MSTensorHandleArray outputs_handle = MSModelGetOutputs(model_handle); 11869be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray outputs_handle = OH_AI_ModelGetOutputs(model_handle); 11870be168c0dSopenharmony_ci if (!outputs_handle.handle_list) { 11871be168c0dSopenharmony_ci- printf("MSModelGetOutputs failed, ret: %d", ret); 11872be168c0dSopenharmony_ci+ printf("OH_AI_ModelGetOutputs failed, ret: %d", ret); 11873be168c0dSopenharmony_ci return ret; 11874be168c0dSopenharmony_ci } 11875be168c0dSopenharmony_ci 11876be168c0dSopenharmony_ci@@ -226,15 +226,15 @@ int main(int argc, const char **argv) { 11877be168c0dSopenharmony_ci warm_up_loop_count = atoi(argv[7]); 11878be168c0dSopenharmony_ci if (warm_up_loop_count < 0) { 11879be168c0dSopenharmony_ci printf("The warm up loop count error! 
Cannot be less than 0.\n"); 11880be168c0dSopenharmony_ci- return kMSStatusLiteParamInvalid; 11881be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 11882be168c0dSopenharmony_ci } 11883be168c0dSopenharmony_ci } 11884be168c0dSopenharmony_ci printf("Running warm up loops..."); 11885be168c0dSopenharmony_ci for (int i = 0; i < warm_up_loop_count; ++i) { 11886be168c0dSopenharmony_ci- ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11887be168c0dSopenharmony_ci- if (ret != kMSStatusSuccess) { 11888be168c0dSopenharmony_ci- MSModelDestroy(&model_handle); 11889be168c0dSopenharmony_ci- printf("MSModelPredict failed, ret: %d", ret); 11890be168c0dSopenharmony_ci+ ret = OH_AI_ModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11891be168c0dSopenharmony_ci+ if (ret != OH_AI_STATUS_SUCCESS) { 11892be168c0dSopenharmony_ci+ OH_AI_ModelDestroy(&model_handle); 11893be168c0dSopenharmony_ci+ printf("OH_AI_ModelPredict failed, ret: %d", ret); 11894be168c0dSopenharmony_ci return ret; 11895be168c0dSopenharmony_ci } 11896be168c0dSopenharmony_ci } 11897be168c0dSopenharmony_ci@@ -244,10 +244,10 @@ int main(int argc, const char **argv) { 11898be168c0dSopenharmony_ci printf("\nloop count: %d\n", loop_count); 11899be168c0dSopenharmony_ci uint64_t start_time = GetTimeUs(); 11900be168c0dSopenharmony_ci for (int i = 0; i < loop_count; ++i) { 11901be168c0dSopenharmony_ci- ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11902be168c0dSopenharmony_ci- if (ret != kMSStatusSuccess) { 11903be168c0dSopenharmony_ci- MSModelDestroy(&model_handle); 11904be168c0dSopenharmony_ci- printf("MSModelPredict failed, ret: %d", ret); 11905be168c0dSopenharmony_ci+ ret = OH_AI_ModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11906be168c0dSopenharmony_ci+ if (ret != OH_AI_STATUS_SUCCESS) { 11907be168c0dSopenharmony_ci+ OH_AI_ModelDestroy(&model_handle); 11908be168c0dSopenharmony_ci+ 
printf("OH_AI_ModelPredict failed, ret: %d", ret); 11909be168c0dSopenharmony_ci return ret; 11910be168c0dSopenharmony_ci } 11911be168c0dSopenharmony_ci } 11912be168c0dSopenharmony_ci@@ -255,23 +255,23 @@ int main(int argc, const char **argv) { 11913be168c0dSopenharmony_ci float total_time = (float)(end_time - start_time) / 1000.0f; 11914be168c0dSopenharmony_ci printf("total time: %.5fms, per time: %.5fms\n", total_time, total_time / loop_count); 11915be168c0dSopenharmony_ci } 11916be168c0dSopenharmony_ci- ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11917be168c0dSopenharmony_ci- if (ret != kMSStatusSuccess) { 11918be168c0dSopenharmony_ci- MSModelDestroy(&model_handle); 11919be168c0dSopenharmony_ci+ ret = OH_AI_ModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11920be168c0dSopenharmony_ci+ if (ret != OH_AI_STATUS_SUCCESS) { 11921be168c0dSopenharmony_ci+ OH_AI_ModelDestroy(&model_handle); 11922be168c0dSopenharmony_ci return ret; 11923be168c0dSopenharmony_ci } 11924be168c0dSopenharmony_ci printf("========run success=======\n"); 11925be168c0dSopenharmony_ci printf("\noutputs: \n"); 11926be168c0dSopenharmony_ci for (size_t i = 0; i < outputs_handle.handle_num; i++) { 11927be168c0dSopenharmony_ci- MSTensorHandle output = outputs_handle.handle_list[i]; 11928be168c0dSopenharmony_ci+ OH_AI_TensorHandle output = outputs_handle.handle_list[i]; 11929be168c0dSopenharmony_ci PrintTensorHandle(output); 11930be168c0dSopenharmony_ci } 11931be168c0dSopenharmony_ci if (argc >= 5) { 11932be168c0dSopenharmony_ci CalibTensor *calib_tensors; 11933be168c0dSopenharmony_ci int calib_num = 0; 11934be168c0dSopenharmony_ci ret = ReadCalibData(argv[4], &calib_tensors, &calib_num); 11935be168c0dSopenharmony_ci- if (ret != kMSStatusSuccess) { 11936be168c0dSopenharmony_ci- MSModelDestroy(&model_handle); 11937be168c0dSopenharmony_ci+ if (ret != OH_AI_STATUS_SUCCESS) { 11938be168c0dSopenharmony_ci+ OH_AI_ModelDestroy(&model_handle); 
11939be168c0dSopenharmony_ci return ret; 11940be168c0dSopenharmony_ci } 11941be168c0dSopenharmony_ci float cosine_distance_threshold = 0.9999; 11942be168c0dSopenharmony_ci@@ -279,15 +279,15 @@ int main(int argc, const char **argv) { 11943be168c0dSopenharmony_ci cosine_distance_threshold = atof(argv[8]); 11944be168c0dSopenharmony_ci } 11945be168c0dSopenharmony_ci ret = CompareOutputs(outputs_handle, &calib_tensors, calib_num, cosine_distance_threshold); 11946be168c0dSopenharmony_ci- if (ret != kMSStatusSuccess) { 11947be168c0dSopenharmony_ci- MSModelDestroy(&model_handle); 11948be168c0dSopenharmony_ci+ if (ret != OH_AI_STATUS_SUCCESS) { 11949be168c0dSopenharmony_ci+ OH_AI_ModelDestroy(&model_handle); 11950be168c0dSopenharmony_ci return ret; 11951be168c0dSopenharmony_ci } 11952be168c0dSopenharmony_ci FreeCalibTensors(&calib_tensors, calib_num); 11953be168c0dSopenharmony_ci } 11954be168c0dSopenharmony_ci printf("========run success=======\n"); 11955be168c0dSopenharmony_ci- MSModelDestroy(&model_handle); 11956be168c0dSopenharmony_ci- return kMSStatusSuccess; 11957be168c0dSopenharmony_ci+ OH_AI_ModelDestroy(&model_handle); 11958be168c0dSopenharmony_ci+ return OH_AI_STATUS_SUCCESS; 11959be168c0dSopenharmony_ci } 11960be168c0dSopenharmony_ci )RAW"; 11961be168c0dSopenharmony_ci 11962be168c0dSopenharmony_ci@@ -385,7 +385,7 @@ int benchmark() { 11963be168c0dSopenharmony_ci return kMSStatusLiteError; 11964be168c0dSopenharmony_ci } 11965be168c0dSopenharmony_ci MSModelSetWorkspace(model_handle, g_WorkSpace, WORK_SPACE_SIZE); 11966be168c0dSopenharmony_ci- ret = MSModelBuild(model_handle, NULL, 0, kMSModelTypeMindIR, NULL); 11967be168c0dSopenharmony_ci+ ret = OH_AI_ModelBuild(model_handle, NULL, 0, kMSModelTypeMindIR, NULL); 11968be168c0dSopenharmony_ci if (ret != kMSStatusSuccess) { 11969be168c0dSopenharmony_ci printf("MSModelBuildFromFile failed, ret : %d.\n", ret); 11970be168c0dSopenharmony_ci MSModelDestroy(&model_handle); 11971be168c0dSopenharmony_ci@@ -424,7 +424,7 @@ int 
benchmark() { 11972be168c0dSopenharmony_ci } 11973be168c0dSopenharmony_ci 11974be168c0dSopenharmony_ci printf("========Infer start=======\n"); 11975be168c0dSopenharmony_ci- ret = MSModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11976be168c0dSopenharmony_ci+ ret = OH_AI_ModelPredict(model_handle, inputs_handle, &outputs_handle, NULL, NULL); 11977be168c0dSopenharmony_ci if (ret != kMSStatusSuccess) { 11978be168c0dSopenharmony_ci MSModelDestroy(&model_handle); 11979be168c0dSopenharmony_ci return ret; 11980be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/calib_output.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/calib_output.cc 11981be168c0dSopenharmony_ciindex 71ca2287..66af9069 100644 11982be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/calib_output.cc 11983be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/calib_output.cc 11984be168c0dSopenharmony_ci@@ -48,7 +48,7 @@ typedef struct CalibTensor { 11985be168c0dSopenharmony_ci float *data_; 11986be168c0dSopenharmony_ci } CalibTensor; 11987be168c0dSopenharmony_ci int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensots, int *calib_num); 11988be168c0dSopenharmony_ci-int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int calib_num, 11989be168c0dSopenharmony_ci+int CompareOutputs(OH_AI_TensorHandleArray outputs, CalibTensor **calib_tensors, int calib_num, 11990be168c0dSopenharmony_ci float cosine_distance_threshold); 11991be168c0dSopenharmony_ci void FreeCalibTensors(CalibTensor **calib_tensors, int calib_num); 11992be168c0dSopenharmony_ci 11993be168c0dSopenharmony_ci@@ -89,12 +89,12 @@ int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensor_pointe 11994be168c0dSopenharmony_ci FILE *file = fopen(calib_data_path, "r"); 
11995be168c0dSopenharmony_ci if (!file) { 11996be168c0dSopenharmony_ci printf("Unable open %s", calib_data_path); 11997be168c0dSopenharmony_ci- return kMSStatusLiteError; 11998be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 11999be168c0dSopenharmony_ci } 12000be168c0dSopenharmony_ci CalibTensor *calib_tensors = (CalibTensor *)malloc(kMaxOutput * sizeof(CalibTensor)); 12001be168c0dSopenharmony_ci if(calib_tensors == NULL) { 12002be168c0dSopenharmony_ci printf("Malloc calib tensors failed."); 12003be168c0dSopenharmony_ci- return kMSStatusLiteError; 12004be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12005be168c0dSopenharmony_ci } 12006be168c0dSopenharmony_ci // read line by line 12007be168c0dSopenharmony_ci char line[kMaxTensorSize]; 12008be168c0dSopenharmony_ci@@ -111,7 +111,7 @@ int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensor_pointe 12009be168c0dSopenharmony_ci char* tensor_name = (char *)malloc(strlen(p)+1); 12010be168c0dSopenharmony_ci if(tensor_name == NULL) { 12011be168c0dSopenharmony_ci printf("Malloc tensor name failed."); 12012be168c0dSopenharmony_ci- return kMSStatusLiteError; 12013be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12014be168c0dSopenharmony_ci } 12015be168c0dSopenharmony_ci (void)strcpy(tensor_name, p); 12016be168c0dSopenharmony_ci calib_tensors[*calib_num].tensor_name = tensor_name; 12017be168c0dSopenharmony_ci@@ -134,7 +134,7 @@ int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensor_pointe 12018be168c0dSopenharmony_ci float *data = (float *)malloc(elements * sizeof(float)); 12019be168c0dSopenharmony_ci if(data == NULL) { 12020be168c0dSopenharmony_ci printf("Malloc tensor data failed."); 12021be168c0dSopenharmony_ci- return kMSStatusLiteError; 12022be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12023be168c0dSopenharmony_ci } 12024be168c0dSopenharmony_ci p = strtok(line, " "); 12025be168c0dSopenharmony_ci int k = 0; 12026be168c0dSopenharmony_ci@@ -152,43 +152,43 
@@ int ReadCalibData(const char *calib_data_path, CalibTensor **calib_tensor_pointe 12027be168c0dSopenharmony_ci } 12028be168c0dSopenharmony_ci *calib_tensor_pointers = calib_tensors; 12029be168c0dSopenharmony_ci fclose(file); 12030be168c0dSopenharmony_ci- return kMSStatusSuccess; 12031be168c0dSopenharmony_ci+ return OH_AI_STATUS_SUCCESS; 12032be168c0dSopenharmony_ci } 12033be168c0dSopenharmony_ci 12034be168c0dSopenharmony_ci-int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int calib_num, 12035be168c0dSopenharmony_ci+int CompareOutputs(OH_AI_TensorHandleArray outputs, CalibTensor **calib_tensors, int calib_num, 12036be168c0dSopenharmony_ci float cosine_distance_threshold) { 12037be168c0dSopenharmony_ci if (outputs.handle_num != (size_t)calib_num) { 12038be168c0dSopenharmony_ci printf("error, outputs and calibs size is mismatch\n"); 12039be168c0dSopenharmony_ci- return kMSStatusLiteError; 12040be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12041be168c0dSopenharmony_ci } 12042be168c0dSopenharmony_ci size_t outputs_num = outputs.handle_num; 12043be168c0dSopenharmony_ci bool is_success = true; 12044be168c0dSopenharmony_ci for (size_t i = 0; i < outputs_num; ++i) { 12045be168c0dSopenharmony_ci MicroTensor *output = (MicroTensor *)outputs.handle_list[i]; 12046be168c0dSopenharmony_ci if (!output || !output->data) { 12047be168c0dSopenharmony_ci- return kMSStatusLiteError; 12048be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12049be168c0dSopenharmony_ci } 12050be168c0dSopenharmony_ci CalibTensor *calib = calib_tensors[0]; 12051be168c0dSopenharmony_ci if (!calib || !calib[i].data_) { 12052be168c0dSopenharmony_ci- return kMSStatusLiteError; 12053be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12054be168c0dSopenharmony_ci } 12055be168c0dSopenharmony_ci if (strcmp(output->name, calib[i].tensor_name) != 0) { 12056be168c0dSopenharmony_ci printf("warning, output tensor name is not equal to calib\n"); 
12057be168c0dSopenharmony_ci } 12058be168c0dSopenharmony_ci- size_t elements = (size_t)MSTensorGetElementNum(output); 12059be168c0dSopenharmony_ci+ size_t elements = (size_t)OH_AI_TensorGetElementNum(output); 12060be168c0dSopenharmony_ci if (elements != (size_t)calib[i].elemets_num_) { 12061be168c0dSopenharmony_ci printf("error, output elements num is not equal to calib\n"); 12062be168c0dSopenharmony_ci- return kMSStatusLiteError; 12063be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12064be168c0dSopenharmony_ci } 12065be168c0dSopenharmony_ci float cosin = 0.f, dot = 0.f, normx = 0.f, normy = 0.f; 12066be168c0dSopenharmony_ci switch (output->type) { 12067be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeFloat32: { 12068be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT32: { 12069be168c0dSopenharmony_ci float *float_output = (float *)output->data; 12070be168c0dSopenharmony_ci for (size_t j = 0; j < elements; ++j) { 12071be168c0dSopenharmony_ci if (isnan(float_output[j]) || isinf(float_output[j]) || isnan(calib[i].data_[j]) || 12072be168c0dSopenharmony_ci isinf(calib[i].data_[j])) { 12073be168c0dSopenharmony_ci printf("error, output data is nan or inf\n"); 12074be168c0dSopenharmony_ci- return kMSStatusLiteError; 12075be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12076be168c0dSopenharmony_ci } 12077be168c0dSopenharmony_ci dot += float_output[j] * calib[i].data_[j]; 12078be168c0dSopenharmony_ci normx += float_output[j] * float_output[j]; 12079be168c0dSopenharmony_ci@@ -196,7 +196,7 @@ int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int 12080be168c0dSopenharmony_ci } 12081be168c0dSopenharmony_ci break; 12082be168c0dSopenharmony_ci } 12083be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeInt8: { 12084be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT8: { 12085be168c0dSopenharmony_ci int8_t *int_output = (int8_t *)output->data; 12086be168c0dSopenharmony_ci for (size_t j = 0; j < elements; ++j) { 
12087be168c0dSopenharmony_ci dot += (float) (int_output[j] * calib[i].data_[j]); 12088be168c0dSopenharmony_ci@@ -205,7 +205,7 @@ int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int 12089be168c0dSopenharmony_ci } 12090be168c0dSopenharmony_ci break; 12091be168c0dSopenharmony_ci } 12092be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeUInt8: { 12093be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_UINT8: { 12094be168c0dSopenharmony_ci uint8_t *int_output = (uint8_t *)output->data; 12095be168c0dSopenharmony_ci for (size_t j = 0; j < elements; ++j) { 12096be168c0dSopenharmony_ci dot += (float) (int_output[j] * calib[i].data_[j]); 12097be168c0dSopenharmony_ci@@ -214,8 +214,8 @@ int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int 12098be168c0dSopenharmony_ci } 12099be168c0dSopenharmony_ci break; 12100be168c0dSopenharmony_ci } 12101be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeInt32: 12102be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeUInt32: { 12103be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT32: 12104be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_UINT32: { 12105be168c0dSopenharmony_ci int32_t *int_output = (int32_t *)output->data; 12106be168c0dSopenharmony_ci for (size_t j = 0; j < elements; ++j) { 12107be168c0dSopenharmony_ci dot += (float) (int_output[j] * calib[i].data_[j]); 12108be168c0dSopenharmony_ci@@ -238,10 +238,10 @@ int CompareOutputs(MSTensorHandleArray outputs, CalibTensor **calib_tensors, int 12109be168c0dSopenharmony_ci } 12110be168c0dSopenharmony_ci if (!is_success) { 12111be168c0dSopenharmony_ci printf("compare outputs failed.\n"); 12112be168c0dSopenharmony_ci- return kMSStatusLiteError; 12113be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12114be168c0dSopenharmony_ci } 12115be168c0dSopenharmony_ci printf("compare outputs success.\n"); 12116be168c0dSopenharmony_ci- return kMSStatusSuccess; 12117be168c0dSopenharmony_ci+ return OH_AI_STATUS_SUCCESS; 
12118be168c0dSopenharmony_ci } 12119be168c0dSopenharmony_ci 12120be168c0dSopenharmony_ci void FreeCalibTensors(CalibTensor **calib_tensors_pointers, int calib_num) { 12121be168c0dSopenharmony_ci@@ -328,7 +328,7 @@ const char *calib_source_cortex = R"RAW(/** 12122be168c0dSopenharmony_ci int LoadCalibInputs(MSTensorHandleArray *inputs, TensorArray *tensor_array) { 12123be168c0dSopenharmony_ci if (inputs->handle_num != tensor_array->tensors_size_) { 12124be168c0dSopenharmony_ci printf("error, inputs and calibs size is mismatch.\n"); 12125be168c0dSopenharmony_ci- return kMSStatusLiteError; 12126be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12127be168c0dSopenharmony_ci } 12128be168c0dSopenharmony_ci Tensor *calib_tensors = tensor_array->tensors_; 12129be168c0dSopenharmony_ci if (calib_tensors == NULL) { 12130be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/cmake_lists.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/cmake_lists.cc 12131be168c0dSopenharmony_ciindex 79bfc485..f63e6f9e 100644 12132be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/cmake_lists.cc 12133be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/cmake_lists.cc 12134be168c0dSopenharmony_ci@@ -127,9 +127,9 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") 12135be168c0dSopenharmony_ci set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") 12136be168c0dSopenharmony_ci else() 12137be168c0dSopenharmony_ci message(STATUS "build benchmark release version") 12138be168c0dSopenharmony_ci- set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \ 12139be168c0dSopenharmony_ci+ set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -fstack-protector-strong -Wno-attributes \ 12140be168c0dSopenharmony_ci -Wno-deprecated-declarations 
-Wno-missing-braces ${CMAKE_C_FLAGS}") 12141be168c0dSopenharmony_ci- set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \ 12142be168c0dSopenharmony_ci+ set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -fstack-protector-strong -Wno-attributes \ 12143be168c0dSopenharmony_ci -Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}") 12144be168c0dSopenharmony_ci string(REPLACE "-g" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 12145be168c0dSopenharmony_ci string(REPLACE "-g" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 12146be168c0dSopenharmony_ci@@ -211,9 +211,9 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") 12147be168c0dSopenharmony_ci set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") 12148be168c0dSopenharmony_ci else() 12149be168c0dSopenharmony_ci message(STATUS "build net library release version") 12150be168c0dSopenharmony_ci- set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \ 12151be168c0dSopenharmony_ci+ set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -fstack-protector-strong -Wno-attributes \ 12152be168c0dSopenharmony_ci -Wno-deprecated-declarations -Wno-missing-braces ${CMAKE_C_FLAGS}") 12153be168c0dSopenharmony_ci- set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \ 12154be168c0dSopenharmony_ci+ set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -fstack-protector-strong -Wno-attributes \ 12155be168c0dSopenharmony_ci -Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}") 12156be168c0dSopenharmony_ci string(REPLACE "-g" "" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") 12157be168c0dSopenharmony_ci string(REPLACE "-g" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") 12158be168c0dSopenharmony_ci@@ -241,11 +241,11 @@ function(create_library) 12159be168c0dSopenharmony_ci endforeach() 
12160be168c0dSopenharmony_ci add_custom_command(TARGET net 12161be168c0dSopenharmony_ci POST_BUILD 12162be168c0dSopenharmony_ci- COMMAND ar cr ${library_name} *.o 12163be168c0dSopenharmony_ci+ COMMAND ar cr ${library_name} *.obj 12164be168c0dSopenharmony_ci COMMAND ranlib ${library_name} 12165be168c0dSopenharmony_ci COMMAND echo "new static library ${library_name} size:" 12166be168c0dSopenharmony_ci COMMAND ls -lh ${library_name} 12167be168c0dSopenharmony_ci- COMMAND rm -rf tmp && rm -rf *.o 12168be168c0dSopenharmony_ci+ COMMAND rm -rf tmp && rm -rf *.obj 12169be168c0dSopenharmony_ci COMMENT "generate specified static library ${library_name}" 12170be168c0dSopenharmony_ci ) 12171be168c0dSopenharmony_ci endfunction(create_library) 12172be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/load_input.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/load_input.cc 12173be168c0dSopenharmony_ciindex 9a2aeaa7..669cd8c1 100644 12174be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/load_input.cc 12175be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/load_input.cc 12176be168c0dSopenharmony_ci@@ -131,7 +131,7 @@ int ReadInputsFile(char *path, void **buffers, const int *inputs_size, int input 12177be168c0dSopenharmony_ci while ((token = strtok_r(path, delim, &path))) { 12178be168c0dSopenharmony_ci if (i >= inputs_num) { 12179be168c0dSopenharmony_ci printf("inputs num is error, need: %d\n", inputs_num); 12180be168c0dSopenharmony_ci- return kMSStatusLiteParamInvalid; 12181be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 12182be168c0dSopenharmony_ci } 12183be168c0dSopenharmony_ci inputs_path[i] = token; 12184be168c0dSopenharmony_ci printf("input %d: %s\n", i, inputs_path[i]); 12185be168c0dSopenharmony_ci@@ -144,7 +144,7 @@ int ReadInputsFile(char *path, void **buffers, 
const int *inputs_size, int input 12186be168c0dSopenharmony_ci if (size != inputs_size[i] || buffers[i] == NULL) { 12187be168c0dSopenharmony_ci printf("size mismatch, %s, input: %d, needed: %d\n", inputs_path[i], size, inputs_size[i]); 12188be168c0dSopenharmony_ci free(buffers[i]); 12189be168c0dSopenharmony_ci- return kMSStatusLiteError; 12190be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 12191be168c0dSopenharmony_ci } 12192be168c0dSopenharmony_ci } 12193be168c0dSopenharmony_ci return 0; 12194be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mcontext.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mcontext.cc 12195be168c0dSopenharmony_ciindex 856de855..d662e3a8 100644 12196be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mcontext.cc 12197be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mcontext.cc 12198be168c0dSopenharmony_ci@@ -73,24 +73,24 @@ const char context_source_cortex[] = R"RAW( 12199be168c0dSopenharmony_ci #include <stdlib.h> 12200be168c0dSopenharmony_ci #include <string.h> 12201be168c0dSopenharmony_ci 12202be168c0dSopenharmony_ci-MSContextHandle MSContextCreate() { 12203be168c0dSopenharmony_ci+OH_AI_ContextHandle OH_AI_ContextCreate() { 12204be168c0dSopenharmony_ci return NULL; 12205be168c0dSopenharmony_ci } 12206be168c0dSopenharmony_ci 12207be168c0dSopenharmony_ci-void MSContextDestroy(MSContextHandle *context) { 12208be168c0dSopenharmony_ci+void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 12209be168c0dSopenharmony_ci } 12210be168c0dSopenharmony_ci 12211be168c0dSopenharmony_ci-void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { 12212be168c0dSopenharmony_ci+void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) { 12213be168c0dSopenharmony_ci } 12214be168c0dSopenharmony_ci 
12215be168c0dSopenharmony_ci-int32_t MSContextGetThreadNum(const MSContextHandle context) { 12216be168c0dSopenharmony_ci+int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 12217be168c0dSopenharmony_ci return 1; 12218be168c0dSopenharmony_ci } 12219be168c0dSopenharmony_ci 12220be168c0dSopenharmony_ci-void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { 12221be168c0dSopenharmony_ci+void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 12222be168c0dSopenharmony_ci } 12223be168c0dSopenharmony_ci 12224be168c0dSopenharmony_ci-int MSContextGetThreadAffinityMode(const MSContextHandle context) { 12225be168c0dSopenharmony_ci+int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 12226be168c0dSopenharmony_ci return 0; 12227be168c0dSopenharmony_ci } 12228be168c0dSopenharmony_ci )RAW"; 12229be168c0dSopenharmony_ci@@ -116,7 +116,7 @@ const char context_source_no_parallel[] = R"RAW( 12230be168c0dSopenharmony_ci #include <stdlib.h> 12231be168c0dSopenharmony_ci #include <string.h> 12232be168c0dSopenharmony_ci 12233be168c0dSopenharmony_ci-MSContextHandle MSContextCreate() { 12234be168c0dSopenharmony_ci+OH_AI_ContextHandle OH_AI_ContextCreate() { 12235be168c0dSopenharmony_ci MicroContext *micro_context = (MicroContext *)malloc(sizeof(MicroContext)); 12236be168c0dSopenharmony_ci if (micro_context == NULL) { 12237be168c0dSopenharmony_ci return NULL; 12238be168c0dSopenharmony_ci@@ -129,7 +129,7 @@ MSContextHandle MSContextCreate() { 12239be168c0dSopenharmony_ci return micro_context; 12240be168c0dSopenharmony_ci } 12241be168c0dSopenharmony_ci 12242be168c0dSopenharmony_ci-void MSContextDestroy(MSContextHandle *context) { 12243be168c0dSopenharmony_ci+void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 12244be168c0dSopenharmony_ci MicroContext *micro_context = (MicroContext *)(*context); 12245be168c0dSopenharmony_ci if (micro_context) { 12246be168c0dSopenharmony_ci free(micro_context); 
12247be168c0dSopenharmony_ci@@ -137,17 +137,17 @@ void MSContextDestroy(MSContextHandle *context) { 12248be168c0dSopenharmony_ci } 12249be168c0dSopenharmony_ci } 12250be168c0dSopenharmony_ci 12251be168c0dSopenharmony_ci-void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { 12252be168c0dSopenharmony_ci+void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) { 12253be168c0dSopenharmony_ci } 12254be168c0dSopenharmony_ci 12255be168c0dSopenharmony_ci-int32_t MSContextGetThreadNum(const MSContextHandle context) { 12256be168c0dSopenharmony_ci+int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 12257be168c0dSopenharmony_ci return 1; 12258be168c0dSopenharmony_ci } 12259be168c0dSopenharmony_ci 12260be168c0dSopenharmony_ci-void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { 12261be168c0dSopenharmony_ci+void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 12262be168c0dSopenharmony_ci } 12263be168c0dSopenharmony_ci 12264be168c0dSopenharmony_ci-int MSContextGetThreadAffinityMode(const MSContextHandle context) { 12265be168c0dSopenharmony_ci+int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 12266be168c0dSopenharmony_ci return 0; 12267be168c0dSopenharmony_ci } 12268be168c0dSopenharmony_ci )RAW"; 12269be168c0dSopenharmony_ci@@ -176,7 +176,7 @@ const char context_source[] = R"RAW( 12270be168c0dSopenharmony_ci 12271be168c0dSopenharmony_ci #define MAX_THREAD_NUM 4 12272be168c0dSopenharmony_ci 12273be168c0dSopenharmony_ci-MSContextHandle MSContextCreate() { 12274be168c0dSopenharmony_ci+OH_AI_ContextHandle OH_AI_ContextCreate() { 12275be168c0dSopenharmony_ci MicroContext *micro_context = (MicroContext *)malloc(sizeof(MicroContext)); 12276be168c0dSopenharmony_ci if (micro_context == NULL) { 12277be168c0dSopenharmony_ci return NULL; 12278be168c0dSopenharmony_ci@@ -189,7 +189,7 @@ MSContextHandle MSContextCreate() { 12279be168c0dSopenharmony_ci return 
micro_context; 12280be168c0dSopenharmony_ci } 12281be168c0dSopenharmony_ci 12282be168c0dSopenharmony_ci-void MSContextDestroy(MSContextHandle *context) { 12283be168c0dSopenharmony_ci+void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 12284be168c0dSopenharmony_ci MicroContext *micro_context = (MicroContext *)(*context); 12285be168c0dSopenharmony_ci if (micro_context) { 12286be168c0dSopenharmony_ci if (micro_context->affinity_core_list_) { 12287be168c0dSopenharmony_ci@@ -201,7 +201,7 @@ void MSContextDestroy(MSContextHandle *context) { 12288be168c0dSopenharmony_ci } 12289be168c0dSopenharmony_ci } 12290be168c0dSopenharmony_ci 12291be168c0dSopenharmony_ci-void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { 12292be168c0dSopenharmony_ci+void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) { 12293be168c0dSopenharmony_ci MicroContext *micro_context = (MicroContext *)context; 12294be168c0dSopenharmony_ci if (micro_context) { 12295be168c0dSopenharmony_ci int core_num = GetCpuCoreNum(); 12296be168c0dSopenharmony_ci@@ -214,7 +214,7 @@ void MSContextSetThreadNum(MSContextHandle context, int32_t thread_num) { 12297be168c0dSopenharmony_ci } 12298be168c0dSopenharmony_ci } 12299be168c0dSopenharmony_ci 12300be168c0dSopenharmony_ci-int32_t MSContextGetThreadNum(const MSContextHandle context) { 12301be168c0dSopenharmony_ci+int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 12302be168c0dSopenharmony_ci MicroContext *micro_context = (MicroContext *)context; 12303be168c0dSopenharmony_ci if (micro_context) { 12304be168c0dSopenharmony_ci return micro_context->thread_num_; 12305be168c0dSopenharmony_ci@@ -222,7 +222,7 @@ int32_t MSContextGetThreadNum(const MSContextHandle context) { 12306be168c0dSopenharmony_ci return 0; 12307be168c0dSopenharmony_ci } 12308be168c0dSopenharmony_ci 12309be168c0dSopenharmony_ci-void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { 12310be168c0dSopenharmony_ci+void 
OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 12311be168c0dSopenharmony_ci MicroContext *micro_context = (MicroContext *)context; 12312be168c0dSopenharmony_ci if (micro_context) { 12313be168c0dSopenharmony_ci if (mode >= 0 && mode <= 2) { 12314be168c0dSopenharmony_ci@@ -233,7 +233,7 @@ void MSContextSetThreadAffinityMode(MSContextHandle context, int mode) { 12315be168c0dSopenharmony_ci } 12316be168c0dSopenharmony_ci } 12317be168c0dSopenharmony_ci 12318be168c0dSopenharmony_ci-int MSContextGetThreadAffinityMode(const MSContextHandle context) { 12319be168c0dSopenharmony_ci+int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 12320be168c0dSopenharmony_ci MicroContext *micro_context = (MicroContext *)context; 12321be168c0dSopenharmony_ci if (micro_context) { 12322be168c0dSopenharmony_ci return micro_context->affinity_mode; 12323be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/msession.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/msession.cc 12324be168c0dSopenharmony_ciindex 44273071..5cbe4507 100644 12325be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/msession.cc 12326be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/msession.cc 12327be168c0dSopenharmony_ci@@ -18,25 +18,25 @@ 12328be168c0dSopenharmony_ci 12329be168c0dSopenharmony_ci namespace mindspore::lite::micro { 12330be168c0dSopenharmony_ci const char model_runtime_other_source[] = R"RAW( 12331be168c0dSopenharmony_ci-MSTensorHandleArray MSModelGetInputs(const MSModelHandle model) { 12332be168c0dSopenharmony_ci+OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) { 12333be168c0dSopenharmony_ci MicroModel *micro_model = (MicroModel *)model; 12334be168c0dSopenharmony_ci if (micro_model == NULL) { 12335be168c0dSopenharmony_ci- 
MSTensorHandleArray tmp = {0, NULL}; 12336be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray tmp = {0, NULL}; 12337be168c0dSopenharmony_ci return tmp; 12338be168c0dSopenharmony_ci } 12339be168c0dSopenharmony_ci return micro_model->inputs; 12340be168c0dSopenharmony_ci } 12341be168c0dSopenharmony_ci 12342be168c0dSopenharmony_ci-MSTensorHandleArray MSModelGetOutputs(const MSModelHandle model) { 12343be168c0dSopenharmony_ci+OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) { 12344be168c0dSopenharmony_ci MicroModel *micro_model = (MicroModel *)model; 12345be168c0dSopenharmony_ci if (micro_model == NULL) { 12346be168c0dSopenharmony_ci- MSTensorHandleArray tmp = {0, NULL}; 12347be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray tmp = {0, NULL}; 12348be168c0dSopenharmony_ci return tmp; 12349be168c0dSopenharmony_ci } 12350be168c0dSopenharmony_ci return micro_model->outputs; 12351be168c0dSopenharmony_ci } 12352be168c0dSopenharmony_ci 12353be168c0dSopenharmony_ci-MSTensorHandle MSModelGetInputByTensorName(const MSModelHandle model, const char *tensor_name) { 12354be168c0dSopenharmony_ci+OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) { 12355be168c0dSopenharmony_ci MicroModel *micro_model = (MicroModel *)model; 12356be168c0dSopenharmony_ci if (micro_model == NULL || micro_model->inputs.handle_list == NULL) { 12357be168c0dSopenharmony_ci return NULL; 12358be168c0dSopenharmony_ci@@ -53,7 +53,7 @@ MSTensorHandle MSModelGetInputByTensorName(const MSModelHandle model, const char 12359be168c0dSopenharmony_ci return NULL; 12360be168c0dSopenharmony_ci } 12361be168c0dSopenharmony_ci 12362be168c0dSopenharmony_ci-MSTensorHandle MSModelGetOutputByTensorName(const MSModelHandle model, const char *tensor_name) { 12363be168c0dSopenharmony_ci+OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) { 12364be168c0dSopenharmony_ci MicroModel *micro_model = 
(MicroModel *)model; 12365be168c0dSopenharmony_ci if (micro_model == NULL || micro_model->outputs.handle_list == NULL) { 12366be168c0dSopenharmony_ci return NULL; 12367be168c0dSopenharmony_ci@@ -70,9 +70,16 @@ MSTensorHandle MSModelGetOutputByTensorName(const MSModelHandle model, const cha 12368be168c0dSopenharmony_ci return NULL; 12369be168c0dSopenharmony_ci } 12370be168c0dSopenharmony_ci 12371be168c0dSopenharmony_ci-MSStatus MSModelResize(MSModelHandle model, const MSTensorHandleArray inputs, MSShapeInfo *shape_infos, 12372be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, OH_AI_ShapeInfo *shape_infos, 12373be168c0dSopenharmony_ci size_t shape_info_num) { 12374be168c0dSopenharmony_ci- return kMSStatusLiteNotSupport; 12375be168c0dSopenharmony_ci+ MicroModel *micro_model = (MicroModel *)model; 12376be168c0dSopenharmony_ci+ if (micro_model == NULL) { 12377be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 12378be168c0dSopenharmony_ci+ } 12379be168c0dSopenharmony_ci+ if (micro_model->resize == NULL) { 12380be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 12381be168c0dSopenharmony_ci+ } 12382be168c0dSopenharmony_ci+ return micro_model->resize(model, inputs, shape_infos, shape_info_num); 12383be168c0dSopenharmony_ci } 12384be168c0dSopenharmony_ci 12385be168c0dSopenharmony_ci )RAW"; 12386be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mtensor.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mtensor.cc 12387be168c0dSopenharmony_ciindex b125b31d..e4581829 100644 12388be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mtensor.cc 12389be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/const_blocks/mtensor.cc 12390be168c0dSopenharmony_ci@@ -46,8 +46,8 @@ const char tensor_header[] = R"RAW( 
12391be168c0dSopenharmony_ci #endif 12392be168c0dSopenharmony_ci 12393be168c0dSopenharmony_ci typedef struct { 12394be168c0dSopenharmony_ci- enum MSDataType type; 12395be168c0dSopenharmony_ci- enum MSFormat format; 12396be168c0dSopenharmony_ci+ enum OH_AI_DataType type; 12397be168c0dSopenharmony_ci+ enum OH_AI_Format format; 12398be168c0dSopenharmony_ci char *name; 12399be168c0dSopenharmony_ci int ndim; 12400be168c0dSopenharmony_ci int64_t *shape; 12401be168c0dSopenharmony_ci@@ -76,7 +76,7 @@ enum TypeTransMode { 12402be168c0dSopenharmony_ci TypeTransMode_MAX = TypeTransMode_UNSUPPORT 12403be168c0dSopenharmony_ci }; 12404be168c0dSopenharmony_ci 12405be168c0dSopenharmony_ci-void *TransformInput(MSTensorHandle tensor, int expect_type, bool *type_changed); 12406be168c0dSopenharmony_ci+void *TransformInput(OH_AI_TensorHandle tensor, int expect_type, bool *type_changed); 12407be168c0dSopenharmony_ci 12408be168c0dSopenharmony_ci #ifdef ENABLE_FP16 12409be168c0dSopenharmony_ci void Fp32CastToFp16(const float *input, float16_t *output, int number); 12410be168c0dSopenharmony_ci@@ -109,37 +109,37 @@ const char tensor_source[] = R"RAW( 12411be168c0dSopenharmony_ci #include "string.h" 12412be168c0dSopenharmony_ci #include "tensor.h" 12413be168c0dSopenharmony_ci 12414be168c0dSopenharmony_ci-size_t DataTypeSize(const MSDataType type) { 12415be168c0dSopenharmony_ci+size_t DataTypeSize(const OH_AI_DataType type) { 12416be168c0dSopenharmony_ci switch (type) { 12417be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeFloat64: 12418be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT64: 12419be168c0dSopenharmony_ci return sizeof(double); 12420be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeFloat32: 12421be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT32: 12422be168c0dSopenharmony_ci return sizeof(float); 12423be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeInt8: 12424be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT8: 12425be168c0dSopenharmony_ci 
return sizeof(int8_t); 12426be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeUInt8: 12427be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_UINT8: 12428be168c0dSopenharmony_ci return sizeof(uint8_t); 12429be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeFloat16: 12430be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeInt16: 12431be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_FLOAT16: 12432be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT16: 12433be168c0dSopenharmony_ci return sizeof(int16_t); 12434be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeInt32: 12435be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT32: 12436be168c0dSopenharmony_ci return sizeof(int32_t); 12437be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeInt64: 12438be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_INT64: 12439be168c0dSopenharmony_ci return sizeof(int64_t); 12440be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeUInt16: 12441be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_UINT16: 12442be168c0dSopenharmony_ci return sizeof(uint16_t); 12443be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeUInt32: 12444be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_UINT32: 12445be168c0dSopenharmony_ci return sizeof(uint32_t); 12446be168c0dSopenharmony_ci- case kMSDataTypeNumberTypeUInt64: 12447be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_NUMBERTYPE_UINT64: 12448be168c0dSopenharmony_ci return sizeof(uint64_t); 12449be168c0dSopenharmony_ci- case kMSDataTypeObjectTypeString: 12450be168c0dSopenharmony_ci+ case OH_AI_DATATYPE_OBJECTTYPE_STRING: 12451be168c0dSopenharmony_ci return sizeof(char); 12452be168c0dSopenharmony_ci default: 12453be168c0dSopenharmony_ci return 0; 12454be168c0dSopenharmony_ci } 12455be168c0dSopenharmony_ci } 12456be168c0dSopenharmony_ci 12457be168c0dSopenharmony_ci-MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t *shape, size_t shape_num, 12458be168c0dSopenharmony_ci+OH_AI_TensorHandle 
OH_AI_TensorCreate(const char *name, OH_AI_DataType type, const int64_t *shape, size_t shape_num, 12459be168c0dSopenharmony_ci const void *data, size_t data_len) { 12460be168c0dSopenharmony_ci size_t data_type_len = DataTypeSize(type); 12461be168c0dSopenharmony_ci size_t acc_sum = 1; 12462be168c0dSopenharmony_ci@@ -160,16 +160,16 @@ MSTensorHandle MSTensorCreate(const char *name, MSDataType type, const int64_t * 12463be168c0dSopenharmony_ci memcpy(micro_tensor->data, data, data_len); 12464be168c0dSopenharmony_ci micro_tensor->shape = malloc(shape_num * sizeof(int64_t)); 12465be168c0dSopenharmony_ci memcpy(micro_tensor->shape, shape, shape_num * sizeof(int64_t)); 12466be168c0dSopenharmony_ci- micro_tensor->format = kMSFormatNHWC; 12467be168c0dSopenharmony_ci+ micro_tensor->format = OH_AI_FORMAT_NHWC; 12468be168c0dSopenharmony_ci return micro_tensor; 12469be168c0dSopenharmony_ci } 12470be168c0dSopenharmony_ci 12471be168c0dSopenharmony_ci-void MSTensorDestroy(MSTensorHandle *tensor) { 12472be168c0dSopenharmony_ci+void OH_AI_TensorDestroy(OH_AI_TensorHandle *tensor) { 12473be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(*tensor); 12474be168c0dSopenharmony_ci free(micro_tensor); 12475be168c0dSopenharmony_ci } 12476be168c0dSopenharmony_ci 12477be168c0dSopenharmony_ci-void MSTensorSetName(MSTensorHandle tensor, const char *name) { 12478be168c0dSopenharmony_ci+void OH_AI_TensorSetName(OH_AI_TensorHandle tensor, const char *name) { 12479be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12480be168c0dSopenharmony_ci if(micro_tensor->name != NULL) { 12481be168c0dSopenharmony_ci free(micro_tensor->name); 12482be168c0dSopenharmony_ci@@ -179,10 +179,10 @@ void MSTensorSetName(MSTensorHandle tensor, const char *name) { 12483be168c0dSopenharmony_ci memcpy(micro_tensor->name, name, len + 1); 12484be168c0dSopenharmony_ci } 12485be168c0dSopenharmony_ci 12486be168c0dSopenharmony_ci-MSTensorHandle MSTensorClone(MSTensorHandle tensor) { 
12487be168c0dSopenharmony_ci+OH_AI_TensorHandle OH_AI_TensorClone(OH_AI_TensorHandle tensor) { 12488be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12489be168c0dSopenharmony_ci MicroTensor *clone_tensor = malloc( sizeof(MicroTensor)); 12490be168c0dSopenharmony_ci- size_t tensor_data_size = MSTensorGetDataSize(micro_tensor); 12491be168c0dSopenharmony_ci+ size_t tensor_data_size = OH_AI_TensorGetDataSize(micro_tensor); 12492be168c0dSopenharmony_ci clone_tensor->data = malloc(tensor_data_size); 12493be168c0dSopenharmony_ci clone_tensor->owned = true; 12494be168c0dSopenharmony_ci memcpy(clone_tensor->data,micro_tensor->data,tensor_data_size); 12495be168c0dSopenharmony_ci@@ -195,26 +195,26 @@ MSTensorHandle MSTensorClone(MSTensorHandle tensor) { 12496be168c0dSopenharmony_ci clone_tensor->shape = clone_shape; 12497be168c0dSopenharmony_ci char* clone_name = malloc(strlen(micro_tensor->name)); 12498be168c0dSopenharmony_ci strcpy(clone_name,micro_tensor->name); 12499be168c0dSopenharmony_ci- clone_tensor->format = kMSFormatNHWC; 12500be168c0dSopenharmony_ci+ clone_tensor->format = OH_AI_FORMAT_NHWC; 12501be168c0dSopenharmony_ci return clone_tensor; 12502be168c0dSopenharmony_ci } 12503be168c0dSopenharmony_ci 12504be168c0dSopenharmony_ci-const char *MSTensorGetName(const MSTensorHandle tensor) { 12505be168c0dSopenharmony_ci+const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) { 12506be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12507be168c0dSopenharmony_ci return micro_tensor->name; 12508be168c0dSopenharmony_ci } 12509be168c0dSopenharmony_ci 12510be168c0dSopenharmony_ci-void MSTensorSetDataType(MSTensorHandle tensor, MSDataType type) { 12511be168c0dSopenharmony_ci+void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) { 12512be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12513be168c0dSopenharmony_ci micro_tensor->type = type; 12514be168c0dSopenharmony_ci } 
12515be168c0dSopenharmony_ci 12516be168c0dSopenharmony_ci-MSDataType MSTensorGetDataType(const MSTensorHandle tensor) { 12517be168c0dSopenharmony_ci+OH_AI_DataType OH_AI_TensorGetDataType(const OH_AI_TensorHandle tensor) { 12518be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12519be168c0dSopenharmony_ci return micro_tensor->type; 12520be168c0dSopenharmony_ci } 12521be168c0dSopenharmony_ci 12522be168c0dSopenharmony_ci-void MSTensorSetShape(MSTensorHandle tensor, const int64_t *shape, size_t shape_num) { 12523be168c0dSopenharmony_ci+void OH_AI_TensorSetShape(OH_AI_TensorHandle tensor, const int64_t *shape, size_t shape_num) { 12524be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12525be168c0dSopenharmony_ci if(micro_tensor->shape != NULL) { 12526be168c0dSopenharmony_ci free(micro_tensor->shape); 12527be168c0dSopenharmony_ci@@ -224,23 +224,23 @@ void MSTensorSetShape(MSTensorHandle tensor, const int64_t *shape, size_t shape_ 12528be168c0dSopenharmony_ci memcpy(micro_tensor->shape, shape, shape_num * sizeof(int64_t)); 12529be168c0dSopenharmony_ci } 12530be168c0dSopenharmony_ci 12531be168c0dSopenharmony_ci-const int64_t *MSTensorGetShape(const MSTensorHandle tensor, size_t *shape_num) { 12532be168c0dSopenharmony_ci+const int64_t *OH_AI_TensorGetShape(const OH_AI_TensorHandle tensor, size_t *shape_num) { 12533be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12534be168c0dSopenharmony_ci *shape_num = micro_tensor->ndim; 12535be168c0dSopenharmony_ci return micro_tensor->shape; 12536be168c0dSopenharmony_ci } 12537be168c0dSopenharmony_ci 12538be168c0dSopenharmony_ci-void MSTensorSetFormat(MSTensorHandle tensor, MSFormat format) { 12539be168c0dSopenharmony_ci+void OH_AI_TensorSetFormat(OH_AI_TensorHandle tensor, OH_AI_Format format) { 12540be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12541be168c0dSopenharmony_ci micro_tensor->format = format; 
12542be168c0dSopenharmony_ci } 12543be168c0dSopenharmony_ci 12544be168c0dSopenharmony_ci-MSFormat MSTensorGetFormat(const MSTensorHandle tensor) { 12545be168c0dSopenharmony_ci+OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor) { 12546be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12547be168c0dSopenharmony_ci return micro_tensor->format; 12548be168c0dSopenharmony_ci } 12549be168c0dSopenharmony_ci 12550be168c0dSopenharmony_ci-void MSTensorSetData(MSTensorHandle tensor, void *data) { 12551be168c0dSopenharmony_ci+void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) { 12552be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12553be168c0dSopenharmony_ci if (micro_tensor->data == data) { 12554be168c0dSopenharmony_ci return; 12555be168c0dSopenharmony_ci@@ -254,23 +254,23 @@ void MSTensorSetData(MSTensorHandle tensor, void *data) { 12556be168c0dSopenharmony_ci micro_tensor->data = data; 12557be168c0dSopenharmony_ci } 12558be168c0dSopenharmony_ci 12559be168c0dSopenharmony_ci-const void *MSTensorGetData(const MSTensorHandle tensor) { 12560be168c0dSopenharmony_ci+const void *OH_AI_TensorGetData(const OH_AI_TensorHandle tensor) { 12561be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12562be168c0dSopenharmony_ci return micro_tensor->data; 12563be168c0dSopenharmony_ci } 12564be168c0dSopenharmony_ci 12565be168c0dSopenharmony_ci-void *MSTensorGetMutableData(const MSTensorHandle tensor) { 12566be168c0dSopenharmony_ci+void *OH_AI_TensorGetMutableData(const OH_AI_TensorHandle tensor) { 12567be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12568be168c0dSopenharmony_ci if(micro_tensor->data) { 12569be168c0dSopenharmony_ci return micro_tensor->data; 12570be168c0dSopenharmony_ci } 12571be168c0dSopenharmony_ci- void* data = malloc(MSTensorGetDataSize(tensor)); 12572be168c0dSopenharmony_ci+ void* data = malloc(OH_AI_TensorGetDataSize(tensor)); 
12573be168c0dSopenharmony_ci micro_tensor->owned = true; 12574be168c0dSopenharmony_ci micro_tensor->data = data; 12575be168c0dSopenharmony_ci return data; 12576be168c0dSopenharmony_ci } 12577be168c0dSopenharmony_ci 12578be168c0dSopenharmony_ci-int64_t MSTensorGetElementNum(const MSTensorHandle tensor) { 12579be168c0dSopenharmony_ci+int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor) { 12580be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12581be168c0dSopenharmony_ci int64_t acc_sum = 1; 12582be168c0dSopenharmony_ci for(int i=0;i< micro_tensor->ndim;i++) { 12583be168c0dSopenharmony_ci@@ -279,10 +279,10 @@ int64_t MSTensorGetElementNum(const MSTensorHandle tensor) { 12584be168c0dSopenharmony_ci return acc_sum; 12585be168c0dSopenharmony_ci } 12586be168c0dSopenharmony_ci 12587be168c0dSopenharmony_ci-size_t MSTensorGetDataSize(const MSTensorHandle tensor) { 12588be168c0dSopenharmony_ci+size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor) { 12589be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 12590be168c0dSopenharmony_ci size_t data_type_size = DataTypeSize(micro_tensor->type); 12591be168c0dSopenharmony_ci- int64_t elements = MSTensorGetElementNum(tensor); 12592be168c0dSopenharmony_ci+ int64_t elements = OH_AI_TensorGetElementNum(tensor); 12593be168c0dSopenharmony_ci return data_type_size * elements; 12594be168c0dSopenharmony_ci } 12595be168c0dSopenharmony_ci 12596be168c0dSopenharmony_ci@@ -300,16 +300,16 @@ void Fp16CastToFp32(const float16_t *input, float *output, int number) { 12597be168c0dSopenharmony_ci } 12598be168c0dSopenharmony_ci #endif 12599be168c0dSopenharmony_ci 12600be168c0dSopenharmony_ci-void *TransformInput(MSTensorHandle tensor, int expect_type, bool *type_changed) { 12601be168c0dSopenharmony_ci+void *TransformInput(OH_AI_TensorHandle tensor, int expect_type, bool *type_changed) { 12602be168c0dSopenharmony_ci MicroTensor* micro_tensor = (MicroTensor*)(tensor); 
12603be168c0dSopenharmony_ci int cur_type = micro_tensor->type; 12604be168c0dSopenharmony_ci if (cur_type == expect_type) { 12605be168c0dSopenharmony_ci return micro_tensor->data; 12606be168c0dSopenharmony_ci } 12607be168c0dSopenharmony_ci int type_trans_mode = TypeTransMode_MAX; 12608be168c0dSopenharmony_ci- if (expect_type == kMSDataTypeNumberTypeFloat16 && cur_type == kMSDataTypeNumberTypeFloat32) { 12609be168c0dSopenharmony_ci+ if (expect_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT16 && cur_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT32) { 12610be168c0dSopenharmony_ci type_trans_mode = TypeTransMode_FP32_TO_FP16; 12611be168c0dSopenharmony_ci- } else if (expect_type == kMSDataTypeNumberTypeFloat32 && cur_type == kMSDataTypeNumberTypeFloat16) { 12612be168c0dSopenharmony_ci+ } else if (expect_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT32 && cur_type == OH_AI_DATATYPE_NUMBERTYPE_FLOAT16) { 12613be168c0dSopenharmony_ci type_trans_mode = TypeTransMode_FP16_TO_FP32; 12614be168c0dSopenharmony_ci } 12615be168c0dSopenharmony_ci if (type_trans_mode == TypeTransMode_UNSUPPORT) { 12616be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc 12617be168c0dSopenharmony_ciindex ac958750..6a131b52 100644 12618be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc 12619be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/component/weight_component.cc 12620be168c0dSopenharmony_ci@@ -61,6 +61,8 @@ void CodeWeightFileHeader(std::ofstream &ofs, const std::unique_ptr<CoderContext 12621be168c0dSopenharmony_ci << "#include <string.h>\n" 12622be168c0dSopenharmony_ci << "extern unsigned char *" << ctx->buffer_name() << ";\n" 12623be168c0dSopenharmony_ci << "extern uint8_t *" << ctx->weight_name() << ";\n" 12624be168c0dSopenharmony_ci+ << "extern int *" << kShapePrefixName 
<< ";\n" 12625be168c0dSopenharmony_ci+ << "extern int *" << kOffsetPrefixName << ";\n" 12626be168c0dSopenharmony_ci << "enum STATUS {\n" 12627be168c0dSopenharmony_ci " RET_OK = 0,\n" 12628be168c0dSopenharmony_ci " RET_ERROR = 1,\n" 12629be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc 12630be168c0dSopenharmony_ciindex dd66c333..23009e17 100644 12631be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/generator/generator.cc 12632be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/generator/generator.cc 12633be168c0dSopenharmony_ci@@ -43,20 +43,28 @@ const char micro_model_define_source[] = R"RAW( 12634be168c0dSopenharmony_ci typedef struct { 12635be168c0dSopenharmony_ci void *runtime_buffer; 12636be168c0dSopenharmony_ci bool train_mode; // true: train mode, false: eval mode 12637be168c0dSopenharmony_ci- MSTensorHandleArray inputs; 12638be168c0dSopenharmony_ci- MSTensorHandleArray outputs; 12639be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray inputs; 12640be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray outputs; 12641be168c0dSopenharmony_ci ModelBuild build; 12642be168c0dSopenharmony_ci+ ModelResize resize; 12643be168c0dSopenharmony_ci ModelSetWorkspace set_work_space; 12644be168c0dSopenharmony_ci ModelCalcWorkspaceSize calc_work_space; 12645be168c0dSopenharmony_ci FreeResource free_resource; 12646be168c0dSopenharmony_ci )RAW"; 12647be168c0dSopenharmony_ci 12648be168c0dSopenharmony_ci const char set_workspace_state[] = R"RAW( 12649be168c0dSopenharmony_ci-typedef void (*ModelSetWorkspace)(MSModelHandle model, void *workspace, size_t workspace_size); 12650be168c0dSopenharmony_ci+typedef void (*ModelSetWorkspace)(OH_AI_ModelHandle model, void *workspace, size_t workspace_size); 12651be168c0dSopenharmony_ci )RAW"; 12652be168c0dSopenharmony_ci 12653be168c0dSopenharmony_ci const char calc_workspace_state[] = R"RAW( 
12654be168c0dSopenharmony_ci-typedef size_t (*ModelCalcWorkspaceSize)(MSModelHandle model); 12655be168c0dSopenharmony_ci+typedef size_t (*ModelCalcWorkspaceSize)(OH_AI_ModelHandle model); 12656be168c0dSopenharmony_ci+)RAW"; 12657be168c0dSopenharmony_ci+ 12658be168c0dSopenharmony_ci+const char model_resize[] = R"RAW( 12659be168c0dSopenharmony_ci+typedef OH_AI_Status (*ModelResize)(OH_AI_ModelHandle model, 12660be168c0dSopenharmony_ci+ const OH_AI_TensorHandleArray inputs, 12661be168c0dSopenharmony_ci+ OH_AI_ShapeInfo *shape_infos, 12662be168c0dSopenharmony_ci+ size_t shape_info_num); 12663be168c0dSopenharmony_ci )RAW"; 12664be168c0dSopenharmony_ci 12665be168c0dSopenharmony_ci int WriteContentToFile(const std::string &file, const std::string &content) { 12666be168c0dSopenharmony_ci@@ -311,6 +319,7 @@ int Generator::CodeCommonModelFile() { 12667be168c0dSopenharmony_ci CodeFreeResourceState(hofs); 12668be168c0dSopenharmony_ci hofs << set_workspace_state; 12669be168c0dSopenharmony_ci hofs << calc_workspace_state; 12670be168c0dSopenharmony_ci+ hofs << model_resize; 12671be168c0dSopenharmony_ci hofs << micro_model_define_source; 12672be168c0dSopenharmony_ci if (config_->code_mode() == CodeMode::Inference) { 12673be168c0dSopenharmony_ci hofs << " ModelPredict predict;\n"; 12674be168c0dSopenharmony_ci@@ -321,7 +330,7 @@ int Generator::CodeCommonModelFile() { 12675be168c0dSopenharmony_ci } 12676be168c0dSopenharmony_ci hofs << "} MicroModel;\n"; 12677be168c0dSopenharmony_ci 12678be168c0dSopenharmony_ci- hofs << "void MSTensorHandleArrayDestroy(MSTensorHandleArray inputs);\n"; 12679be168c0dSopenharmony_ci+ hofs << "void MSTensorHandleArrayDestroy(OH_AI_TensorHandleArray inputs);\n"; 12680be168c0dSopenharmony_ci hofs << "#endif // MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_MODEL_H_\n\n"; 12681be168c0dSopenharmony_ci 12682be168c0dSopenharmony_ci // model source file 12683be168c0dSopenharmony_ci@@ -340,7 +349,7 @@ int Generator::CodeCommonModelFile() { 12684be168c0dSopenharmony_ci if 
(config_->support_parallel()) { 12685be168c0dSopenharmony_ci cofs << "#include \"" << kThreadWrapper << "\"\n"; 12686be168c0dSopenharmony_ci } 12687be168c0dSopenharmony_ci- if (config_->target() != kCortex_M) { 12688be168c0dSopenharmony_ci+ if (config_->target() != kCortex_M && !config_->dynamic_shape()) { 12689be168c0dSopenharmony_ci cofs << "#include \"src/allocator.h\"\n"; 12690be168c0dSopenharmony_ci } 12691be168c0dSopenharmony_ci CodeMSModelCalcWorkspaceSize(cofs, ctx_, *config_); 12692be168c0dSopenharmony_ci@@ -369,7 +378,7 @@ int Generator::CodeModelHandleHFile() { 12693be168c0dSopenharmony_ci "#define MINDSPORE_LITE_MICRO_LIBRARY_INCLUDE_MODEL_HANDLE_H_\n\n" 12694be168c0dSopenharmony_ci << "#include \"c_api/model_c.h\"\n\n"; 12695be168c0dSopenharmony_ci for (int i = 0; i <= ctx_->GetCurModelIndex(); ++i) { 12696be168c0dSopenharmony_ci- ofs << "extern MSModelHandle model" << std::to_string(i) << "; // " << ctx_->model_name() << "\n"; 12697be168c0dSopenharmony_ci+ ofs << "extern OH_AI_ModelHandle model" << std::to_string(i) << "; // " << ctx_->model_name() << "\n"; 12698be168c0dSopenharmony_ci } 12699be168c0dSopenharmony_ci ofs << "\n#endif // MINDSPORE_LITE_MICRO_LIBRARY_INCLUDE_MODEL_HANDLE_H_\n"; 12700be168c0dSopenharmony_ci return RET_OK; 12701be168c0dSopenharmony_ci@@ -386,7 +395,7 @@ int Generator::CodeMSModelImplement() { 12702be168c0dSopenharmony_ci ofs << "#include \"c_api/model_c.h\"\n"; 12703be168c0dSopenharmony_ci ofs << "#include \"src/model.h\"\n"; 12704be168c0dSopenharmony_ci ofs << "#include \"src/model" << ctx_->GetCurModelIndex() << "/" << net_inc_hfile_ << "\"\n"; 12705be168c0dSopenharmony_ci- if (config_->target() != kCortex_M) { 12706be168c0dSopenharmony_ci+ if (config_->target() != kCortex_M && !config_->dynamic_shape()) { 12707be168c0dSopenharmony_ci ofs << "#include \"src/allocator.h\"\n"; 12708be168c0dSopenharmony_ci } 12709be168c0dSopenharmony_ci if (config_->support_parallel()) { 12710be168c0dSopenharmony_ci@@ -399,33 +408,37 @@ int 
Generator::CodeMSModelImplement() { 12711be168c0dSopenharmony_ci ofs << "#define GRAPH_OUTPUTS_SIZE " << ctx_->graph_outputs().size() << "\n"; 12712be168c0dSopenharmony_ci ofs << "#define WEIGHT_BUF_SIZE " << ctx_->weight_buffer_size() << "\n"; 12713be168c0dSopenharmony_ci } 12714be168c0dSopenharmony_ci- ofs << "MSStatus MSModelBuild" << ctx_->GetCurModelIndex() << "(MSModelHandle model, const void *model_data,\n" 12715be168c0dSopenharmony_ci- << " size_t data_size, const MSContextHandle model_context);\n"; 12716be168c0dSopenharmony_ci+ ofs << "OH_AI_Status OH_AI_ModelBuild" << ctx_->GetCurModelIndex() << "(OH_AI_ModelHandle model, const void *model_data,\n" 12717be168c0dSopenharmony_ci+ << " size_t data_size, const OH_AI_ContextHandle model_context);\n"; 12718be168c0dSopenharmony_ci+ ofs << "OH_AI_Status OH_AI_ModelResize" << ctx_->GetCurModelIndex() << "(OH_AI_ModelHandle model, \n" 12719be168c0dSopenharmony_ci+ << " const OH_AI_TensorHandleArray inputs, OH_AI_ShapeInfo *shape_infos, size_t shape_info_num);\n"; 12720be168c0dSopenharmony_ci if (config_->code_mode() == CodeMode::Inference) { 12721be168c0dSopenharmony_ci- ofs << "MSStatus MSModelPredict" << ctx_->GetCurModelIndex() 12722be168c0dSopenharmony_ci- << "(MSModelHandle model, const MSTensorHandleArray inputs,\n" 12723be168c0dSopenharmony_ci- << " MSTensorHandleArray *output,\n" 12724be168c0dSopenharmony_ci- << " const MSKernelCallBackC before,\n" 12725be168c0dSopenharmony_ci- << " const MSKernelCallBackC after);\n"; 12726be168c0dSopenharmony_ci+ ofs << "OH_AI_Status OH_AI_ModelPredict" << ctx_->GetCurModelIndex() 12727be168c0dSopenharmony_ci+ << "(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,\n" 12728be168c0dSopenharmony_ci+ << " OH_AI_TensorHandleArray *output,\n" 12729be168c0dSopenharmony_ci+ << " const OH_AI_KernelCallBack before,\n" 12730be168c0dSopenharmony_ci+ << " const OH_AI_KernelCallBack after);\n"; 12731be168c0dSopenharmony_ci } else { 12732be168c0dSopenharmony_ci- ofs << 
"MSStatus MSModelRunStep" << ctx_->GetCurModelIndex() 12733be168c0dSopenharmony_ci- << "(MSModelHandle model,\n" 12734be168c0dSopenharmony_ci- " const MSKernelCallBackC before,\n" 12735be168c0dSopenharmony_ci- " const MSKernelCallBackC after);\n"; 12736be168c0dSopenharmony_ci- ofs << "MSStatus MSModelSetTrainMode" << ctx_->GetCurModelIndex() << "(MSModelHandle model, bool train);\n"; 12737be168c0dSopenharmony_ci- ofs << "MSStatus MSModelExportWeight" << ctx_->GetCurModelIndex() 12738be168c0dSopenharmony_ci- << "(MSModelHandle model, const char *export_path);\n"; 12739be168c0dSopenharmony_ci- } 12740be168c0dSopenharmony_ci+ ofs << "OH_AI_Status MSModelRunStep" << ctx_->GetCurModelIndex() 12741be168c0dSopenharmony_ci+ << "(OH_AI_ModelHandle model,\n" 12742be168c0dSopenharmony_ci+ " const OH_AI_KernelCallBack before,\n" 12743be168c0dSopenharmony_ci+ " const OH_AI_KernelCallBack after);\n"; 12744be168c0dSopenharmony_ci+ ofs << "OH_AI_Status MSModelSetTrainMode" << ctx_->GetCurModelIndex() << "(OH_AI_ModelHandle model, bool train);\n"; 12745be168c0dSopenharmony_ci+ ofs << "OH_AI_Status MSModelExportWeight" << ctx_->GetCurModelIndex() 12746be168c0dSopenharmony_ci+ << "(OH_AI_ModelHandle model, const char *export_path);\n"; 12747be168c0dSopenharmony_ci+ } 12748be168c0dSopenharmony_ci+ ofs << "void Reset" << ctx_->GetCurModelIndex() << "();\n"; 12749be168c0dSopenharmony_ci ofs << "void MSModelSetWorkspace" << ctx_->GetCurModelIndex() 12750be168c0dSopenharmony_ci- << "(MSModelHandle model, void *workspace, size_t workspace_size);\n"; 12751be168c0dSopenharmony_ci- ofs << "size_t MSModelCalcWorkspaceSize" << ctx_->GetCurModelIndex() << "(MSModelHandle model);\n"; 12752be168c0dSopenharmony_ci+ << "(OH_AI_ModelHandle model, void *workspace, size_t workspace_size);\n"; 12753be168c0dSopenharmony_ci+ ofs << "size_t MSModelCalcWorkspaceSize" << ctx_->GetCurModelIndex() << "(OH_AI_ModelHandle model);\n"; 12754be168c0dSopenharmony_ci ofs << "static MicroModel gModel" << 
ctx_->GetCurModelIndex() << " = {.runtime_buffer = NULL,\n" 12755be168c0dSopenharmony_ci << " .train_mode = false,\n" 12756be168c0dSopenharmony_ci << " .inputs = {" << ctx_->graph_inputs().size() << ", NULL},\n" 12757be168c0dSopenharmony_ci << " .outputs = {" << ctx_->graph_outputs().size() << ", NULL},\n" 12758be168c0dSopenharmony_ci- << " .build = MSModelBuild" << ctx_->GetCurModelIndex() << ",\n"; 12759be168c0dSopenharmony_ci+ << " .build = OH_AI_ModelBuild" << ctx_->GetCurModelIndex() << ",\n" 12760be168c0dSopenharmony_ci+ << " .resize = OH_AI_ModelResize" << ctx_->GetCurModelIndex() << ",\n"; 12761be168c0dSopenharmony_ci if (config_->code_mode() == CodeMode::Inference) { 12762be168c0dSopenharmony_ci- ofs << " .predict = MSModelPredict" << ctx_->GetCurModelIndex() << ",\n"; 12763be168c0dSopenharmony_ci+ ofs << " .predict = OH_AI_ModelPredict" << ctx_->GetCurModelIndex() << ",\n"; 12764be168c0dSopenharmony_ci } else { 12765be168c0dSopenharmony_ci ofs << " .run_step = MSModelRunStep" << ctx_->GetCurModelIndex() << ",\n" 12766be168c0dSopenharmony_ci << " .set_train_mode = MSModelSetTrainMode" << ctx_->GetCurModelIndex() << ",\n" 12767be168c0dSopenharmony_ci@@ -439,11 +452,16 @@ int Generator::CodeMSModelImplement() { 12768be168c0dSopenharmony_ci ofs << " .set_work_space = NULL,\n" 12769be168c0dSopenharmony_ci << " .calc_work_space = NULL,\n"; 12770be168c0dSopenharmony_ci } 12771be168c0dSopenharmony_ci- ofs << " .free_resource = FreeResource" << ctx_->GetCurModelIndex() << "};\n"; 12772be168c0dSopenharmony_ci- ofs << "MSModelHandle model" << ctx_->GetCurModelIndex() << " = &gModel" << ctx_->GetCurModelIndex() << ";\n\n"; 12773be168c0dSopenharmony_ci- 12774be168c0dSopenharmony_ci+ ofs << " .free_resource = Reset" << ctx_->GetCurModelIndex() << "};\n"; 12775be168c0dSopenharmony_ci+ ofs << "OH_AI_ModelHandle model" << ctx_->GetCurModelIndex() << " = &gModel" << ctx_->GetCurModelIndex() << ";\n\n"; 12776be168c0dSopenharmony_ci+ auto &dynamic_symbols = 
config_->dynamic_symbols(); 12777be168c0dSopenharmony_ci+ for (size_t i = 0; i < dynamic_symbols.size(); ++i) { 12778be168c0dSopenharmony_ci+ ofs << "static int store" << ctx_->GetCurModelIndex() << "_" << i << " = -1;\n"; 12779be168c0dSopenharmony_ci+ } 12780be168c0dSopenharmony_ci+ CodeResetImplement(ofs, ctx_, *config_); 12781be168c0dSopenharmony_ci CodeMSModelCreate(ofs, ctx_, *config_); 12782be168c0dSopenharmony_ci CodeMSModelBuild(ofs, ctx_->GetCurModelIndex(), weight_size_, *config_); 12783be168c0dSopenharmony_ci+ CodeMSModelResize(ofs, ctx_, *config_); 12784be168c0dSopenharmony_ci CodeCopyOutputsImplement(ofs, ctx_); 12785be168c0dSopenharmony_ci if (config_->target() == kCortex_M) { 12786be168c0dSopenharmony_ci CodeCortexCalcWorkspaceSize(ofs, ctx_); 12787be168c0dSopenharmony_ci@@ -483,6 +501,8 @@ int Generator::CodeWeightFile() { 12788be168c0dSopenharmony_ci if (config_->target() != kCortex_M) { 12789be168c0dSopenharmony_ci cofs << "unsigned char *" << ctx_->buffer_name() << " = 0; \n"; 12790be168c0dSopenharmony_ci cofs << "unsigned char *" << ctx_->weight_name() << " = 0; \n"; 12791be168c0dSopenharmony_ci+ cofs << "int *" << kShapePrefixName << " = 0; \n"; 12792be168c0dSopenharmony_ci+ cofs << "int *" << kOffsetPrefixName << " = 0; \n"; 12793be168c0dSopenharmony_ci std::string net_file = model_dir_ + "net" + std::to_string(ctx_->GetCurModelIndex()) + ".bin"; 12794be168c0dSopenharmony_ci SaveDataToNet(ctx_, net_file, config_->keep_original_weight(), &weight_size_); 12795be168c0dSopenharmony_ci } else { 12796be168c0dSopenharmony_ci@@ -598,8 +618,10 @@ int Generator::CreateCommonFiles() { 12797be168c0dSopenharmony_ci MS_CHECK_RET_CODE(CodeStaticContent(), "code static content failed."); 12798be168c0dSopenharmony_ci MS_CHECK_RET_CODE(CodeModelHandleHFile(), "code model_handle h file failed."); 12799be168c0dSopenharmony_ci MS_CHECK_RET_CODE(CodeCommonModelFile(), "code common model file failed."); 12800be168c0dSopenharmony_ci+ if (!config_->dynamic_shape()) { 
12801be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(CodeAllocatorFile(), "code allocator file failed."); 12802be168c0dSopenharmony_ci+ } 12803be168c0dSopenharmony_ci MS_CHECK_RET_CODE(CodeRegKernelHFile(), "code registered kernel header file failed."); 12804be168c0dSopenharmony_ci- MS_CHECK_RET_CODE(CodeAllocatorFile(), "code allocator file failed."); 12805be168c0dSopenharmony_ci MS_CHECK_RET_CODE(CodeSourceCMakeFile(), "code net cmake file failed."); 12806be168c0dSopenharmony_ci return RET_OK; 12807be168c0dSopenharmony_ci } 12808be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.cc 12809be168c0dSopenharmony_cinew file mode 100644 12810be168c0dSopenharmony_ciindex 00000000..108ba227 12811be168c0dSopenharmony_ci--- /dev/null 12812be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.cc 12813be168c0dSopenharmony_ci@@ -0,0 +1,116 @@ 12814be168c0dSopenharmony_ci+/** 12815be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 12816be168c0dSopenharmony_ci+ * 12817be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 12818be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 12819be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 12820be168c0dSopenharmony_ci+ * 12821be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 12822be168c0dSopenharmony_ci+ * 12823be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 12824be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 12825be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12826be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 12827be168c0dSopenharmony_ci+ * limitations under the License. 12828be168c0dSopenharmony_ci+ */ 12829be168c0dSopenharmony_ci+ 12830be168c0dSopenharmony_ci+#include "coder/opcoders/base/reshape_dynamic_base_coder.h" 12831be168c0dSopenharmony_ci+#include <string> 12832be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/serializer.h" 12833be168c0dSopenharmony_ci+#include "include/errorcode.h" 12834be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 12835be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 12836be168c0dSopenharmony_ci+ 12837be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_ExpandDims; 12838be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Flatten; 12839be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_FlattenGrad; 12840be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Reshape; 12841be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Squeeze; 12842be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Unsqueeze; 12843be168c0dSopenharmony_ci+ 12844be168c0dSopenharmony_ci+namespace mindspore::lite::micro { 12845be168c0dSopenharmony_ci+int ReshapeDynamicBaseCoder::Prepare(CoderContext *const context) { 12846be168c0dSopenharmony_ci+ if (input_tensors_.size() == C2NUM) { 12847be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst(), RET_NOT_SUPPORT, 12848be168c0dSopenharmony_ci+ "Currently, only support the first input of reshape is non-const when shape is dynamical."); 12849be168c0dSopenharmony_ci+ 12850be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32 || 12851be168c0dSopenharmony_ci+ input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt, 12852be168c0dSopenharmony_ci+ RET_ERROR, "The data-type of Reshape's second input must be int."); 12853be168c0dSopenharmony_ci+ } 
12854be168c0dSopenharmony_ci+ return RET_OK; 12855be168c0dSopenharmony_ci+} 12856be168c0dSopenharmony_ci+ 12857be168c0dSopenharmony_ci+int ReshapeDynamicBaseCoder::DoCode(CoderContext *const context) { 12858be168c0dSopenharmony_ci+ Serializer coder; 12859be168c0dSopenharmony_ci+ 12860be168c0dSopenharmony_ci+ int data_item_size = static_cast<int>(lite::DataTypeSize(input_tensor_->data_type())); 12861be168c0dSopenharmony_ci+ auto in_shape = shape_info_container_->GetTemplateShape(input_tensor_); 12862be168c0dSopenharmony_ci+ int64_t const_part = 1; 12863be168c0dSopenharmony_ci+ std::string non_const_part; 12864be168c0dSopenharmony_ci+ for (const auto &item : in_shape) { 12865be168c0dSopenharmony_ci+ if (IsNumber(item)) { 12866be168c0dSopenharmony_ci+ const_part *= std::stoi(item); 12867be168c0dSopenharmony_ci+ } else { 12868be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 12869be168c0dSopenharmony_ci+ non_const_part += " * "; 12870be168c0dSopenharmony_ci+ } 12871be168c0dSopenharmony_ci+ non_const_part += item; 12872be168c0dSopenharmony_ci+ } 12873be168c0dSopenharmony_ci+ } 12874be168c0dSopenharmony_ci+ std::string size = std::to_string(const_part * data_item_size) + " * " + non_const_part; 12875be168c0dSopenharmony_ci+ std::string input_data = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 12876be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!input_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 12877be168c0dSopenharmony_ci+ std::string output_data = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 12878be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!output_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 12879be168c0dSopenharmony_ci+ coder.CodeFunction("memcpy", output_data, input_data, size); 12880be168c0dSopenharmony_ci+ 12881be168c0dSopenharmony_ci+ context->AppendCode(coder.str()); 12882be168c0dSopenharmony_ci+ return RET_OK; 12883be168c0dSopenharmony_ci+} 12884be168c0dSopenharmony_ci+ 
12885be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Reshape, 12886be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12887be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Reshape, 12888be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12889be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Reshape, 12890be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12891be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Reshape, 12892be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12893be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Flatten, 12894be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12895be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Flatten, 12896be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12897be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_Flatten, 12898be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12899be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_ExpandDims, 12900be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12901be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_ExpandDims, 12902be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12903be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_ExpandDims, 12904be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12905be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_ExpandDims, 12906be168c0dSopenharmony_ci+ 
CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12907be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_ExpandDims, 12908be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12909be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Squeeze, 12910be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12911be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Squeeze, 12912be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12913be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Squeeze, 12914be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12915be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Squeeze, 12916be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12917be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_Squeeze, 12918be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12919be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Unsqueeze, 12920be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12921be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Unsqueeze, 12922be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12923be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Unsqueeze, 12924be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12925be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Unsqueeze, 12926be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12927be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt8, PrimitiveType_Unsqueeze, 
12928be168c0dSopenharmony_ci+ CPUOpCoderCreator<ReshapeDynamicBaseCoder>) 12929be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro 12930be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.h 12931be168c0dSopenharmony_cinew file mode 100644 12932be168c0dSopenharmony_ciindex 00000000..aaae22eb 12933be168c0dSopenharmony_ci--- /dev/null 12934be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/reshape_dynamic_base_coder.h 12935be168c0dSopenharmony_ci@@ -0,0 +1,38 @@ 12936be168c0dSopenharmony_ci+/** 12937be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 12938be168c0dSopenharmony_ci+ * 12939be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 12940be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 12941be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 12942be168c0dSopenharmony_ci+ * 12943be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 12944be168c0dSopenharmony_ci+ * 12945be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 12946be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 12947be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12948be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 12949be168c0dSopenharmony_ci+ * limitations under the License. 
12950be168c0dSopenharmony_ci+ */ 12951be168c0dSopenharmony_ci+ 12952be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_RESHAPE_DYNAMIC_BASE_CODER_H_ 12953be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_RESHAPE_DYNAMIC_BASE_CODER_H_ 12954be168c0dSopenharmony_ci+ 12955be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/opcoders/op_coder.h" 12956be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/shape_info_container.h" 12957be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 12958be168c0dSopenharmony_ci+ 12959be168c0dSopenharmony_ci+namespace mindspore::lite::micro { 12960be168c0dSopenharmony_ci+class ReshapeDynamicBaseCoder final : public OperatorCoder { 12961be168c0dSopenharmony_ci+ public: 12962be168c0dSopenharmony_ci+ ReshapeDynamicBaseCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 12963be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 12964be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 12965be168c0dSopenharmony_ci+ 12966be168c0dSopenharmony_ci+ ~ReshapeDynamicBaseCoder() override = default; 12967be168c0dSopenharmony_ci+ 12968be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 12969be168c0dSopenharmony_ci+ 12970be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 12971be168c0dSopenharmony_ci+}; 12972be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro 12973be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_RESHAPE_DYNAMIC_BASE_CODER_H_ 12974be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc 12975be168c0dSopenharmony_cinew file mode 100644 
12976be168c0dSopenharmony_ciindex 00000000..4b2b0abe 12977be168c0dSopenharmony_ci--- /dev/null 12978be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.cc 12979be168c0dSopenharmony_ci@@ -0,0 +1,115 @@ 12980be168c0dSopenharmony_ci+/** 12981be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 12982be168c0dSopenharmony_ci+ * 12983be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 12984be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 12985be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 12986be168c0dSopenharmony_ci+ * 12987be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 12988be168c0dSopenharmony_ci+ * 12989be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 12990be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 12991be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12992be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 12993be168c0dSopenharmony_ci+ * limitations under the License. 
12994be168c0dSopenharmony_ci+ */ 12995be168c0dSopenharmony_ci+ 12996be168c0dSopenharmony_ci+#include "coder/opcoders/base/strided_slice_dynamic_base_coder.h" 12997be168c0dSopenharmony_ci+#include <cmath> 12998be168c0dSopenharmony_ci+#include <string> 12999be168c0dSopenharmony_ci+#include "mindspore/lite/src/common/log_util.h" 13000be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 13001be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 13002be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 13003be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 13004be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 13005be168c0dSopenharmony_ci+#include "base/float16.h" 13006be168c0dSopenharmony_ci+ 13007be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_StridedSlice; 13008be168c0dSopenharmony_ci+ 13009be168c0dSopenharmony_ci+namespace mindspore::lite::micro { 13010be168c0dSopenharmony_ci+namespace { 13011be168c0dSopenharmony_ci+size_t GetInnerSize(TypeId type_id, size_t inner_elements) { 13012be168c0dSopenharmony_ci+ switch (type_id) { 13013be168c0dSopenharmony_ci+ case kNumberTypeInt8: 13014be168c0dSopenharmony_ci+ return inner_elements * sizeof(int8_t); 13015be168c0dSopenharmony_ci+ case kNumberTypeFloat32: 13016be168c0dSopenharmony_ci+ return inner_elements * sizeof(float); 13017be168c0dSopenharmony_ci+ case kNumberTypeInt32: 13018be168c0dSopenharmony_ci+ return inner_elements * sizeof(int32_t); 13019be168c0dSopenharmony_ci+ case kNumberTypeFloat16: 13020be168c0dSopenharmony_ci+ return inner_elements * sizeof(float16); 13021be168c0dSopenharmony_ci+ default: 13022be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Not supported data type: " << type_id; 13023be168c0dSopenharmony_ci+ return 0; 13024be168c0dSopenharmony_ci+ } 13025be168c0dSopenharmony_ci+} 13026be168c0dSopenharmony_ci+} // namespace 13027be168c0dSopenharmony_ci+ 13028be168c0dSopenharmony_ci+int 
StridedSliceDynamicBaseCoder::Prepare(CoderContext *context) { 13029be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(input_tensors_.size(), C2NUM); 13030be168c0dSopenharmony_ci+ for (size_t i = 1; i < input_tensors_.size(); ++i) { 13031be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[i]->IsConst(), RET_PARAM_INVALID, 13032be168c0dSopenharmony_ci+ "The " << i << " input of strided slice should be const."); 13033be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeInt32, RET_PARAM_INVALID, 13034be168c0dSopenharmony_ci+ "The " << i << " input tensor data type should be int32."); 13035be168c0dSopenharmony_ci+ } 13036be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(output_tensors_.size(), C1NUM); 13037be168c0dSopenharmony_ci+ strided_slice_param_ = reinterpret_cast<StridedSliceParameter *>(parameter_); 13038be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(strided_slice_param_); 13039be168c0dSopenharmony_ci+ auto begin_tensor = input_tensors_.at(1); 13040be168c0dSopenharmony_ci+ input_shape_ = shape_info_container_->GetTemplateShape(input_tensor_); 13041be168c0dSopenharmony_ci+ if (input_shape_.size() > DIMENSION_8D || begin_tensor->shape().size() > DIMENSION_8D) { 13042be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "StridedSlice not support input rank or begin num exceeds " << DIMENSION_8D; 13043be168c0dSopenharmony_ci+ return RET_ERROR; 13044be168c0dSopenharmony_ci+ } 13045be168c0dSopenharmony_ci+ dynamic_param_.in_shape_ = "{"; 13046be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_shape_.size(); ++i) { 13047be168c0dSopenharmony_ci+ dynamic_param_.in_shape_ += input_shape_[i] + ", "; 13048be168c0dSopenharmony_ci+ } 13049be168c0dSopenharmony_ci+ dynamic_param_.in_shape_ += "}"; 13050be168c0dSopenharmony_ci+ return RET_OK; 13051be168c0dSopenharmony_ci+} 13052be168c0dSopenharmony_ci+ 13053be168c0dSopenharmony_ci+int StridedSliceDynamicBaseCoder::DoCode(CoderContext *ctx) { 13054be168c0dSopenharmony_ci+ inner_size_ = 
GetInnerSize(input_tensor_->data_type(), inner_); 13055be168c0dSopenharmony_ci+ Collect(ctx, 13056be168c0dSopenharmony_ci+ { 13057be168c0dSopenharmony_ci+ "nnacl/fp32/strided_slice_fp32.h", 13058be168c0dSopenharmony_ci+ }, 13059be168c0dSopenharmony_ci+ { 13060be168c0dSopenharmony_ci+ "strided_slice_fp32.c", 13061be168c0dSopenharmony_ci+ }); 13062be168c0dSopenharmony_ci+ switch (input_tensor_->data_type()) { 13063be168c0dSopenharmony_ci+ case kNumberTypeInt8: 13064be168c0dSopenharmony_ci+ strided_slice_param_->data_type = ::kNumberTypeInt8; 13065be168c0dSopenharmony_ci+ break; 13066be168c0dSopenharmony_ci+ case kNumberTypeFloat32: 13067be168c0dSopenharmony_ci+ strided_slice_param_->data_type = ::kNumberTypeFloat32; 13068be168c0dSopenharmony_ci+ break; 13069be168c0dSopenharmony_ci+ case kNumberTypeInt32: 13070be168c0dSopenharmony_ci+ strided_slice_param_->data_type = ::kNumberTypeInt32; 13071be168c0dSopenharmony_ci+ break; 13072be168c0dSopenharmony_ci+ case kNumberTypeFloat16: 13073be168c0dSopenharmony_ci+ strided_slice_param_->data_type = ::kNumberTypeFloat16; 13074be168c0dSopenharmony_ci+ break; 13075be168c0dSopenharmony_ci+ default: 13076be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Not supported data type: " << input_tensor_->data_type(); 13077be168c0dSopenharmony_ci+ return RET_ERROR; 13078be168c0dSopenharmony_ci+ } 13079be168c0dSopenharmony_ci+ nnacl::NNaclFp32Serializer code; 13080be168c0dSopenharmony_ci+ code.CodeStruct("strided_slice_parameter", *strided_slice_param_, dynamic_param_); 13081be168c0dSopenharmony_ci+ std::string input_data = GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 13082be168c0dSopenharmony_ci+ std::string output_data = GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 13083be168c0dSopenharmony_ci+ code.CodeFunction("DoStridedSlice", input_data, output_data, "&strided_slice_parameter"); 13084be168c0dSopenharmony_ci+ ctx->AppendCode(code.str()); 
13085be168c0dSopenharmony_ci+ return RET_OK; 13086be168c0dSopenharmony_ci+} 13087be168c0dSopenharmony_ci+ 13088be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_StridedSlice, 13089be168c0dSopenharmony_ci+ CPUOpCoderCreator<StridedSliceDynamicBaseCoder>) 13090be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_StridedSlice, 13091be168c0dSopenharmony_ci+ CPUOpCoderCreator<StridedSliceDynamicBaseCoder>) 13092be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_StridedSlice, 13093be168c0dSopenharmony_ci+ CPUOpCoderCreator<StridedSliceDynamicBaseCoder>) 13094be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro 13095be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h 13096be168c0dSopenharmony_cinew file mode 100644 13097be168c0dSopenharmony_ciindex 00000000..d41cff4f 13098be168c0dSopenharmony_ci--- /dev/null 13099be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/base/strided_slice_dynamic_base_coder.h 13100be168c0dSopenharmony_ci@@ -0,0 +1,45 @@ 13101be168c0dSopenharmony_ci+/** 13102be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13103be168c0dSopenharmony_ci+ * 13104be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13105be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
13106be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13107be168c0dSopenharmony_ci+ * 13108be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13109be168c0dSopenharmony_ci+ * 13110be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13111be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13112be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13113be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13114be168c0dSopenharmony_ci+ * limitations under the License. 13115be168c0dSopenharmony_ci+ */ 13116be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STRIDED_SLICE_BASE_CODER_H_ 13117be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STRIDED_SLICE_BASE_CODER_H_ 13118be168c0dSopenharmony_ci+#include <vector> 13119be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 13120be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h" 13121be168c0dSopenharmony_ci+#include "nnacl/strided_slice_parameter.h" 13122be168c0dSopenharmony_ci+ 13123be168c0dSopenharmony_ci+namespace mindspore::lite::micro { 13124be168c0dSopenharmony_ci+class StridedSliceDynamicBaseCoder final : public OperatorCoder { 13125be168c0dSopenharmony_ci+ public: 13126be168c0dSopenharmony_ci+ StridedSliceDynamicBaseCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 13127be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 13128be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 13129be168c0dSopenharmony_ci+ 13130be168c0dSopenharmony_ci+ ~StridedSliceDynamicBaseCoder() override = default; 13131be168c0dSopenharmony_ci+ 13132be168c0dSopenharmony_ci+ int 
Prepare(CoderContext *context) override; 13133be168c0dSopenharmony_ci+ 13134be168c0dSopenharmony_ci+ int DoCode(CoderContext *context) override; 13135be168c0dSopenharmony_ci+ 13136be168c0dSopenharmony_ci+ private: 13137be168c0dSopenharmony_ci+ StridedSliceParameter *strided_slice_param_{nullptr}; 13138be168c0dSopenharmony_ci+ StridedSliceDynamicParameter dynamic_param_; 13139be168c0dSopenharmony_ci+ size_t inner_{1}; 13140be168c0dSopenharmony_ci+ size_t inner_size_{1}; 13141be168c0dSopenharmony_ci+ std::vector<std::string> input_shape_; 13142be168c0dSopenharmony_ci+ std::vector<std::string> output_shape_; 13143be168c0dSopenharmony_ci+}; 13144be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro 13145be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_BASE_STRIDED_SLICE_BASE_CODER_H_ 13146be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h 13147be168c0dSopenharmony_cinew file mode 100644 13148be168c0dSopenharmony_ciindex 00000000..1e9e4f8d 13149be168c0dSopenharmony_ci--- /dev/null 13150be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h 13151be168c0dSopenharmony_ci@@ -0,0 +1,43 @@ 13152be168c0dSopenharmony_ci+/** 13153be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13154be168c0dSopenharmony_ci+ * 13155be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13156be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
13157be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13158be168c0dSopenharmony_ci+ * 13159be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13160be168c0dSopenharmony_ci+ * 13161be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13162be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13163be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13164be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13165be168c0dSopenharmony_ci+ * limitations under the License. 13166be168c0dSopenharmony_ci+ */ 13167be168c0dSopenharmony_ci+ 13168be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_ARITHMETIC_DYNAMIC_PARAMETER_H_ 13169be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_ARITHMETIC_DYNAMIC_PARAMETER_H_ 13170be168c0dSopenharmony_ci+#include <string> 13171be168c0dSopenharmony_ci+ 13172be168c0dSopenharmony_ci+typedef struct ArithmeticDynamicParameter { 13173be168c0dSopenharmony_ci+ std::string in_shape0_; 13174be168c0dSopenharmony_ci+ std::string in_elements_num0_; 13175be168c0dSopenharmony_ci+ std::string in_shape1_; 13176be168c0dSopenharmony_ci+ std::string in_elements_num1_; 13177be168c0dSopenharmony_ci+ 13178be168c0dSopenharmony_ci+ std::string out_shape_; 13179be168c0dSopenharmony_ci+ std::string out_elements_num_; 13180be168c0dSopenharmony_ci+ 13181be168c0dSopenharmony_ci+ std::string in_strides0_; 13182be168c0dSopenharmony_ci+ std::string in_strides1_; 13183be168c0dSopenharmony_ci+ std::string out_strides_; 13184be168c0dSopenharmony_ci+ 13185be168c0dSopenharmony_ci+ std::string multiples0_; 13186be168c0dSopenharmony_ci+ std::string multiples1_; 13187be168c0dSopenharmony_ci+} ArithmeticDynamicParameter; 13188be168c0dSopenharmony_ci+ 
13189be168c0dSopenharmony_ci+typedef struct BroadcastDynamicShapeInfo { 13190be168c0dSopenharmony_ci+ std::string input_shape_; 13191be168c0dSopenharmony_ci+ std::string output_shape_; 13192be168c0dSopenharmony_ci+} BroadcastDynamicShapeInfo; 13193be168c0dSopenharmony_ci+ 13194be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_ARITHMETIC_DYNAMIC_PARAMETER_H_ 13195be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h 13196be168c0dSopenharmony_cinew file mode 100644 13197be168c0dSopenharmony_ciindex 00000000..a05ab848 13198be168c0dSopenharmony_ci--- /dev/null 13199be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h 13200be168c0dSopenharmony_ci@@ -0,0 +1,26 @@ 13201be168c0dSopenharmony_ci+/** 13202be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13203be168c0dSopenharmony_ci+ * 13204be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13205be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 13206be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13207be168c0dSopenharmony_ci+ * 13208be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13209be168c0dSopenharmony_ci+ * 13210be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13211be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13212be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13213be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13214be168c0dSopenharmony_ci+ * limitations under the License. 
13215be168c0dSopenharmony_ci+ */ 13216be168c0dSopenharmony_ci+ 13217be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_CONV_DYNAMIC_PARAMETER_H_ 13218be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_CONV_DYNAMIC_PARAMETER_H_ 13219be168c0dSopenharmony_ci+#include <string> 13220be168c0dSopenharmony_ci+ 13221be168c0dSopenharmony_ci+typedef struct ConvDynamicParameter { 13222be168c0dSopenharmony_ci+ std::string input_batch_; 13223be168c0dSopenharmony_ci+ std::string output_batch_; 13224be168c0dSopenharmony_ci+} ConvDynamicParameter; 13225be168c0dSopenharmony_ci+ 13226be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_CONV_DYNAMIC_PARAMETER_H_ 13227be168c0dSopenharmony_ci\ No newline at end of file 13228be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h 13229be168c0dSopenharmony_cinew file mode 100644 13230be168c0dSopenharmony_ciindex 00000000..970a863a 13231be168c0dSopenharmony_ci--- /dev/null 13232be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h 13233be168c0dSopenharmony_ci@@ -0,0 +1,28 @@ 13234be168c0dSopenharmony_ci+/** 13235be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13236be168c0dSopenharmony_ci+ * 13237be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13238be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
13239be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13240be168c0dSopenharmony_ci+ * 13241be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13242be168c0dSopenharmony_ci+ * 13243be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13244be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13245be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13246be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13247be168c0dSopenharmony_ci+ * limitations under the License. 13248be168c0dSopenharmony_ci+ */ 13249be168c0dSopenharmony_ci+ 13250be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_DYNAMIC_LSTM_PARAMETER_H_ 13251be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_DYNAMIC_LSTM_PARAMETER_H_ 13252be168c0dSopenharmony_ci+ 13253be168c0dSopenharmony_ci+typedef struct DynamicLstmParameter { 13254be168c0dSopenharmony_ci+ std::string seq_len_; 13255be168c0dSopenharmony_ci+ std::string batch_; 13256be168c0dSopenharmony_ci+ std::string input_row_align_; 13257be168c0dSopenharmony_ci+ std::string state_row_align_; 13258be168c0dSopenharmony_ci+ std::string output_step_; 13259be168c0dSopenharmony_ci+} DynamicLstmParameter; 13260be168c0dSopenharmony_ci+ 13261be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_DYNAMIC_LSTM_PARAMETER_H_ 13262be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/matmul_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/matmul_dynamic_parameter.h 13263be168c0dSopenharmony_cinew file mode 100644 13264be168c0dSopenharmony_ciindex 00000000..d99b0cf9 13265be168c0dSopenharmony_ci--- 
/dev/null 13266be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/matmul_dynamic_parameter.h 13267be168c0dSopenharmony_ci@@ -0,0 +1,25 @@ 13268be168c0dSopenharmony_ci+/** 13269be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13270be168c0dSopenharmony_ci+ * 13271be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13272be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 13273be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13274be168c0dSopenharmony_ci+ * 13275be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13276be168c0dSopenharmony_ci+ * 13277be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13278be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13279be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13280be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13281be168c0dSopenharmony_ci+ * limitations under the License. 
13282be168c0dSopenharmony_ci+ */ 13283be168c0dSopenharmony_ci+ 13284be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_MATMUL_DYNAMIC_PARAMETER_H_ 13285be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_MATMUL_DYNAMIC_PARAMETER_H_ 13286be168c0dSopenharmony_ci+ 13287be168c0dSopenharmony_ci+typedef struct MatmulDynamicParameter { 13288be168c0dSopenharmony_ci+ std::string row_; 13289be168c0dSopenharmony_ci+ std::string batch_; 13290be168c0dSopenharmony_ci+} MatmulDynamicParameter; 13291be168c0dSopenharmony_ci+ 13292be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_MATMUL_DYNAMIC_PARAMETER_H_ 13293be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h 13294be168c0dSopenharmony_cinew file mode 100644 13295be168c0dSopenharmony_ciindex 00000000..f2636e55 13296be168c0dSopenharmony_ci--- /dev/null 13297be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h 13298be168c0dSopenharmony_ci@@ -0,0 +1,33 @@ 13299be168c0dSopenharmony_ci+/** 13300be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13301be168c0dSopenharmony_ci+ * 13302be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13303be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
13304be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13305be168c0dSopenharmony_ci+ * 13306be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13307be168c0dSopenharmony_ci+ * 13308be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13309be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13310be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13311be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13312be168c0dSopenharmony_ci+ * limitations under the License. 13313be168c0dSopenharmony_ci+ */ 13314be168c0dSopenharmony_ci+ 13315be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_POOLING_DYNAMIC_PARAMETER_H_ 13316be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_POOLING_DYNAMIC_PARAMETER_H_ 13317be168c0dSopenharmony_ci+#include <string> 13318be168c0dSopenharmony_ci+ 13319be168c0dSopenharmony_ci+typedef struct PoolingDynamicParameter { 13320be168c0dSopenharmony_ci+ int avg_mode_; 13321be168c0dSopenharmony_ci+ bool global_; 13322be168c0dSopenharmony_ci+ int window_w_; 13323be168c0dSopenharmony_ci+ int window_h_; 13324be168c0dSopenharmony_ci+ int stride_w_; 13325be168c0dSopenharmony_ci+ int stride_h_; 13326be168c0dSopenharmony_ci+ 13327be168c0dSopenharmony_ci+ std::string input_batch_; 13328be168c0dSopenharmony_ci+ std::string output_batch_; 13329be168c0dSopenharmony_ci+} PoolingDynamicParameter; 13330be168c0dSopenharmony_ci+ 13331be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_POOLING_DYNAMIC_PARAMETER_H_ 13332be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h 13333be168c0dSopenharmony_cinew file mode 100644 13334be168c0dSopenharmony_ciindex 00000000..e8728383 13335be168c0dSopenharmony_ci--- /dev/null 13336be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h 13337be168c0dSopenharmony_ci@@ -0,0 +1,26 @@ 13338be168c0dSopenharmony_ci+/** 13339be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13340be168c0dSopenharmony_ci+ * 13341be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13342be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 13343be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13344be168c0dSopenharmony_ci+ * 13345be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13346be168c0dSopenharmony_ci+ * 13347be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13348be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13349be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13350be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13351be168c0dSopenharmony_ci+ * limitations under the License. 
13352be168c0dSopenharmony_ci+ */ 13353be168c0dSopenharmony_ci+ 13354be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SCALE_DYNAMIC_PARAMETER_H_ 13355be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SCALE_DYNAMIC_PARAMETER_H_ 13356be168c0dSopenharmony_ci+#include <string> 13357be168c0dSopenharmony_ci+ 13358be168c0dSopenharmony_ci+typedef struct ScaleDynamicParameter { 13359be168c0dSopenharmony_ci+ std::string outer_size_; 13360be168c0dSopenharmony_ci+ std::string axis_size_; 13361be168c0dSopenharmony_ci+ std::string inner_size_; 13362be168c0dSopenharmony_ci+} ScaleDynamicParameter; 13363be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SCALE_DYNAMIC_PARAMETER_H_ 13364be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h 13365be168c0dSopenharmony_cinew file mode 100644 13366be168c0dSopenharmony_ciindex 00000000..f17993d4 13367be168c0dSopenharmony_ci--- /dev/null 13368be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h 13369be168c0dSopenharmony_ci@@ -0,0 +1,27 @@ 13370be168c0dSopenharmony_ci+/** 13371be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13372be168c0dSopenharmony_ci+ * 13373be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13374be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
13375be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13376be168c0dSopenharmony_ci+ * 13377be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13378be168c0dSopenharmony_ci+ * 13379be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13380be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13381be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13382be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13383be168c0dSopenharmony_ci+ * limitations under the License. 13384be168c0dSopenharmony_ci+ */ 13385be168c0dSopenharmony_ci+ 13386be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SLICE_DYNAMIC_PARAMETER_H_ 13387be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SLICE_DYNAMIC_PARAMETER_H_ 13388be168c0dSopenharmony_ci+#include <string> 13389be168c0dSopenharmony_ci+ 13390be168c0dSopenharmony_ci+typedef struct SliceDynamicParameter { 13391be168c0dSopenharmony_ci+ std::string shape_; 13392be168c0dSopenharmony_ci+ std::string size_; 13393be168c0dSopenharmony_ci+ std::string end_; 13394be168c0dSopenharmony_ci+} SliceDynamicParameter; 13395be168c0dSopenharmony_ci+ 13396be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SLICE_DYNAMIC_PARAMETER_H_ 13397be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/softmax_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/softmax_dynamic_parameter.h 13398be168c0dSopenharmony_cinew file mode 100644 13399be168c0dSopenharmony_ciindex 00000000..92dfaf21 13400be168c0dSopenharmony_ci--- /dev/null 13401be168c0dSopenharmony_ci+++ 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/softmax_dynamic_parameter.h 13402be168c0dSopenharmony_ci@@ -0,0 +1,26 @@ 13403be168c0dSopenharmony_ci+/** 13404be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13405be168c0dSopenharmony_ci+ * 13406be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13407be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 13408be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13409be168c0dSopenharmony_ci+ * 13410be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13411be168c0dSopenharmony_ci+ * 13412be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13413be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13414be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13415be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13416be168c0dSopenharmony_ci+ * limitations under the License. 
13417be168c0dSopenharmony_ci+ */ 13418be168c0dSopenharmony_ci+ 13419be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SOFTMAX_DYNAMIC_PARAMETER_H_ 13420be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SOFTMAX_DYNAMIC_PARAMETER_H_ 13421be168c0dSopenharmony_ci+#include <string> 13422be168c0dSopenharmony_ci+ 13423be168c0dSopenharmony_ci+typedef struct SoftmaxDynamicParameter { 13424be168c0dSopenharmony_ci+ std::string input_shape_; 13425be168c0dSopenharmony_ci+ std::string element_size_; 13426be168c0dSopenharmony_ci+} SoftmaxDynamicParameter; 13427be168c0dSopenharmony_ci+ 13428be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SOFTMAX_DYNAMIC_PARAMETER_H_ 13429be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h 13430be168c0dSopenharmony_cinew file mode 100644 13431be168c0dSopenharmony_ciindex 00000000..b97097ad 13432be168c0dSopenharmony_ci--- /dev/null 13433be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h 13434be168c0dSopenharmony_ci@@ -0,0 +1,26 @@ 13435be168c0dSopenharmony_ci+/** 13436be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13437be168c0dSopenharmony_ci+ * 13438be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13439be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
13440be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13441be168c0dSopenharmony_ci+ * 13442be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13443be168c0dSopenharmony_ci+ * 13444be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13445be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13446be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13447be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13448be168c0dSopenharmony_ci+ * limitations under the License. 13449be168c0dSopenharmony_ci+ */ 13450be168c0dSopenharmony_ci+ 13451be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SPLIT_DYNAMIC_PARAMETER_H_ 13452be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SPLIT_DYNAMIC_PARAMETER_H_ 13453be168c0dSopenharmony_ci+#include <string> 13454be168c0dSopenharmony_ci+ 13455be168c0dSopenharmony_ci+typedef struct SplitDynamicParameter { 13456be168c0dSopenharmony_ci+ std::string strides_; 13457be168c0dSopenharmony_ci+ std::string split_count_; 13458be168c0dSopenharmony_ci+} SplitDynamicParameter; 13459be168c0dSopenharmony_ci+ 13460be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_SPLIT_DYNAMIC_PARAMETER_H_ 13461be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h 13462be168c0dSopenharmony_cinew file mode 100644 13463be168c0dSopenharmony_ciindex 00000000..202ee7dd 13464be168c0dSopenharmony_ci--- /dev/null 13465be168c0dSopenharmony_ci+++ 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h 13466be168c0dSopenharmony_ci@@ -0,0 +1,25 @@ 13467be168c0dSopenharmony_ci+/** 13468be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13469be168c0dSopenharmony_ci+ * 13470be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13471be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 13472be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13473be168c0dSopenharmony_ci+ * 13474be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13475be168c0dSopenharmony_ci+ * 13476be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13477be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13478be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13479be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13480be168c0dSopenharmony_ci+ * limitations under the License. 
13481be168c0dSopenharmony_ci+ */ 13482be168c0dSopenharmony_ci+ 13483be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_STRIDED_SLICE_DYNAMIC_PARAMETER_H_ 13484be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_STRIDED_SLICE_DYNAMIC_PARAMETER_H_ 13485be168c0dSopenharmony_ci+#include <string> 13486be168c0dSopenharmony_ci+ 13487be168c0dSopenharmony_ci+typedef struct StridedSliceDynamicParameter { 13488be168c0dSopenharmony_ci+ std::string in_shape_; 13489be168c0dSopenharmony_ci+} StridedSliceDynamicParameter; 13490be168c0dSopenharmony_ci+ 13491be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_STRIDED_SLICE_DYNAMIC_PARAMETER_H_ 13492be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h 13493be168c0dSopenharmony_cinew file mode 100644 13494be168c0dSopenharmony_ciindex 00000000..ed4f21f2 13495be168c0dSopenharmony_ci--- /dev/null 13496be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h 13497be168c0dSopenharmony_ci@@ -0,0 +1,28 @@ 13498be168c0dSopenharmony_ci+/** 13499be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13500be168c0dSopenharmony_ci+ * 13501be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13502be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
13503be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13504be168c0dSopenharmony_ci+ * 13505be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13506be168c0dSopenharmony_ci+ * 13507be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13508be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13509be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13510be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13511be168c0dSopenharmony_ci+ * limitations under the License. 13512be168c0dSopenharmony_ci+ */ 13513be168c0dSopenharmony_ci+ 13514be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_TRANSPOSE_DYNAMIC_PARAMETER_H_ 13515be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_TRANSPOSE_DYNAMIC_PARAMETER_H_ 13516be168c0dSopenharmony_ci+#include <string> 13517be168c0dSopenharmony_ci+ 13518be168c0dSopenharmony_ci+typedef struct TransposeDynamicParameter { 13519be168c0dSopenharmony_ci+ // shape correlative 13520be168c0dSopenharmony_ci+ std::string strides_; 13521be168c0dSopenharmony_ci+ std::string out_strides_; 13522be168c0dSopenharmony_ci+ std::string data_num_; 13523be168c0dSopenharmony_ci+} TransposeDynamicParameter; 13524be168c0dSopenharmony_ci+ 13525be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_DYNAMIC_PARAMETER_TRANSPOSE_DYNAMIC_PARAMETER_H_ 13526be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc 13527be168c0dSopenharmony_cinew file mode 100644 13528be168c0dSopenharmony_ciindex 00000000..86048179 13529be168c0dSopenharmony_ci--- 
/dev/null 13530be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.cc 13531be168c0dSopenharmony_ci@@ -0,0 +1,93 @@ 13532be168c0dSopenharmony_ci+/** 13533be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13534be168c0dSopenharmony_ci+ * 13535be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13536be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 13537be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13538be168c0dSopenharmony_ci+ * 13539be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13540be168c0dSopenharmony_ci+ * 13541be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13542be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13543be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13544be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13545be168c0dSopenharmony_ci+ * limitations under the License. 
13546be168c0dSopenharmony_ci+ */ 13547be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.h" 13548be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 13549be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 13550be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 13551be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 13552be168c0dSopenharmony_ci+ 13553be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Activation; 13554be168c0dSopenharmony_ci+ 13555be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 13556be168c0dSopenharmony_ci+int ActivationDynamicFP16Coder::Prepare(CoderContext *const context) { 13557be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 13558be168c0dSopenharmony_ci+ "Input tensor data type is invalid."); 13559be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(output_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 13560be168c0dSopenharmony_ci+ "Output tensor data type is invalid."); 13561be168c0dSopenharmony_ci+ return RET_OK; 13562be168c0dSopenharmony_ci+} 13563be168c0dSopenharmony_ci+ 13564be168c0dSopenharmony_ci+int ActivationDynamicFP16Coder::DoCode(CoderContext *const context) { 13565be168c0dSopenharmony_ci+ Collect(context, 13566be168c0dSopenharmony_ci+ { 13567be168c0dSopenharmony_ci+ "nnacl/fp16/activation_fp16.h", 13568be168c0dSopenharmony_ci+ }, 13569be168c0dSopenharmony_ci+ { 13570be168c0dSopenharmony_ci+ "activation_fp16.c", 13571be168c0dSopenharmony_ci+ }); 13572be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 13573be168c0dSopenharmony_ci+ // attribute 13574be168c0dSopenharmony_ci+ auto *activation_parameter = reinterpret_cast<ActivationParameter *>(parameter_); 13575be168c0dSopenharmony_ci+ MS_CHECK_PTR(activation_parameter); 13576be168c0dSopenharmony_ci+ auto in_shape = 
shape_info_container_->GetTemplateShape(input_tensor_); 13577be168c0dSopenharmony_ci+ count_ = AccumulateShape(in_shape, 0, in_shape.size()); 13578be168c0dSopenharmony_ci+ input_data_ = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 13579be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!input_data_.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 13580be168c0dSopenharmony_ci+ output_data_ = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 13581be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!output_data_.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 13582be168c0dSopenharmony_ci+ input_data_ = "(float16_t *)(" + input_data_ + ")"; 13583be168c0dSopenharmony_ci+ output_data_ = "(float16_t *)(" + output_data_ + ")"; 13584be168c0dSopenharmony_ci+ 13585be168c0dSopenharmony_ci+ switch (activation_parameter->type_) { 13586be168c0dSopenharmony_ci+ case schema::ActivationType_RELU: 13587be168c0dSopenharmony_ci+ code.CodeFunction("ReluFp16", input_data_, output_data_, count_); 13588be168c0dSopenharmony_ci+ break; 13589be168c0dSopenharmony_ci+ case schema::ActivationType_RELU6: 13590be168c0dSopenharmony_ci+ code.CodeFunction("Relu6Fp16", input_data_, output_data_, count_); 13591be168c0dSopenharmony_ci+ break; 13592be168c0dSopenharmony_ci+ case schema::ActivationType_LEAKY_RELU: 13593be168c0dSopenharmony_ci+ code.CodeFunction("LReluFp16", input_data_, output_data_, count_, activation_parameter->alpha_); 13594be168c0dSopenharmony_ci+ break; 13595be168c0dSopenharmony_ci+ case schema::ActivationType_SIGMOID: 13596be168c0dSopenharmony_ci+ code.CodeFunction("SigmoidFp16", input_data_, output_data_, count_); 13597be168c0dSopenharmony_ci+ break; 13598be168c0dSopenharmony_ci+ case schema::ActivationType_TANH: 13599be168c0dSopenharmony_ci+ code.CodeFunction("TanhFp16", input_data_, output_data_, count_); 13600be168c0dSopenharmony_ci+ break; 13601be168c0dSopenharmony_ci+ case schema::ActivationType_HSWISH: 13602be168c0dSopenharmony_ci+ 
code.CodeFunction("HSwishFp16", input_data_, output_data_, count_); 13603be168c0dSopenharmony_ci+ break; 13604be168c0dSopenharmony_ci+ case schema::ActivationType_SWISH: 13605be168c0dSopenharmony_ci+ code.CodeFunction("SwishFp16", input_data_, output_data_, count_); 13606be168c0dSopenharmony_ci+ break; 13607be168c0dSopenharmony_ci+ case schema::ActivationType_HSIGMOID: 13608be168c0dSopenharmony_ci+ code.CodeFunction("HSigmoidFp16", input_data_, output_data_, count_); 13609be168c0dSopenharmony_ci+ break; 13610be168c0dSopenharmony_ci+ case schema::ActivationType_ELU: 13611be168c0dSopenharmony_ci+ code.CodeFunction("EluFp16", input_data_, output_data_, count_, activation_parameter->alpha_); 13612be168c0dSopenharmony_ci+ break; 13613be168c0dSopenharmony_ci+ default: 13614be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Activation type error"; 13615be168c0dSopenharmony_ci+ return RET_ERROR; 13616be168c0dSopenharmony_ci+ } 13617be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "ActivationFP16Code has been called"; 13618be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 13619be168c0dSopenharmony_ci+ return lite::RET_OK; 13620be168c0dSopenharmony_ci+} 13621be168c0dSopenharmony_ci+ 13622be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Activation, 13623be168c0dSopenharmony_ci+ CPUOpCoderCreator<ActivationDynamicFP16Coder>) 13624be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 13625be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.h 13626be168c0dSopenharmony_cinew file mode 100644 13627be168c0dSopenharmony_ciindex 00000000..c881567f 13628be168c0dSopenharmony_ci--- /dev/null 13629be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/activation_dynamic_fp16_coder.h 13630be168c0dSopenharmony_ci@@ -0,0 +1,37 @@ 
13631be168c0dSopenharmony_ci+/** 13632be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13633be168c0dSopenharmony_ci+ * 13634be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13635be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 13636be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13637be168c0dSopenharmony_ci+ * 13638be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13639be168c0dSopenharmony_ci+ * 13640be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13641be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13642be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13643be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13644be168c0dSopenharmony_ci+ * limitations under the License. 
13645be168c0dSopenharmony_ci+ */ 13646be168c0dSopenharmony_ci+ 13647be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ACTIVATION_DYNAMIC_FP16_CODER_H_ 13648be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ACTIVATION_DYNAMIC_FP16_CODER_H_ 13649be168c0dSopenharmony_ci+ 13650be168c0dSopenharmony_ci+#include <vector> 13651be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h" 13652be168c0dSopenharmony_ci+ 13653be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 13654be168c0dSopenharmony_ci+class ActivationDynamicFP16Coder final : public ActivationDynamicFP32Coder { 13655be168c0dSopenharmony_ci+ public: 13656be168c0dSopenharmony_ci+ ActivationDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 13657be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 13658be168c0dSopenharmony_ci+ : ActivationDynamicFP32Coder(in_tensors, out_tensors, node, node_index, target) {} 13659be168c0dSopenharmony_ci+ 13660be168c0dSopenharmony_ci+ ~ActivationDynamicFP16Coder() override = default; 13661be168c0dSopenharmony_ci+ 13662be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 13663be168c0dSopenharmony_ci+ 13664be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 13665be168c0dSopenharmony_ci+}; 13666be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 13667be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ACTIVATION_DYNAMIC_FP16_CODER_H_ 13668be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc 13669be168c0dSopenharmony_cinew file mode 100644 
13670be168c0dSopenharmony_ciindex 00000000..7050b8b0 13671be168c0dSopenharmony_ci--- /dev/null 13672be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.cc 13673be168c0dSopenharmony_ci@@ -0,0 +1,369 @@ 13674be168c0dSopenharmony_ci+/** 13675be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 13676be168c0dSopenharmony_ci+ * 13677be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 13678be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 13679be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 13680be168c0dSopenharmony_ci+ * 13681be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 13682be168c0dSopenharmony_ci+ * 13683be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 13684be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 13685be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13686be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 13687be168c0dSopenharmony_ci+ * limitations under the License. 
13688be168c0dSopenharmony_ci+ */ 13689be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h" 13690be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 13691be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 13692be168c0dSopenharmony_ci+#include "coder/log.h" 13693be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 13694be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 13695be168c0dSopenharmony_ci+ 13696be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 13697be168c0dSopenharmony_ci+namespace { 13698be168c0dSopenharmony_ci+std::string wrap_void(const std::string &a) { return "(void *)(" + a + ")"; } 13699be168c0dSopenharmony_ci+} // namespace 13700be168c0dSopenharmony_ci+ 13701be168c0dSopenharmony_ci+void ArithmeticDynamicFP16Coder::InitFunTable() { 13702be168c0dSopenharmony_ci+ fun_table_ = { 13703be168c0dSopenharmony_ci+ {PrimitiveType_MulFusion, schema::ActivationType_RELU, "ElementMulReluFp16", "", "", "", ""}, 13704be168c0dSopenharmony_ci+ {PrimitiveType_MulFusion, schema::ActivationType_RELU6, "ElementMulRelu6Fp16", "", "", "", ""}, 13705be168c0dSopenharmony_ci+ {PrimitiveType_MulFusion, schema::ActivationType_NO_ACTIVATION, "ElementMulFp16", "", "", "", ""}, 13706be168c0dSopenharmony_ci+ {PrimitiveType_AddFusion, schema::ActivationType_RELU, "ElementAddReluFp16", "", "", "", ""}, 13707be168c0dSopenharmony_ci+ {PrimitiveType_AddFusion, schema::ActivationType_RELU6, "ElementAddRelu6Fp16", "", "", "", ""}, 13708be168c0dSopenharmony_ci+ {PrimitiveType_AddFusion, schema::ActivationType_NO_ACTIVATION, "ElementAddFp16", "", "", "", ""}, 13709be168c0dSopenharmony_ci+ {PrimitiveType_SubFusion, schema::ActivationType_RELU, "ElementSubReluFp16", "", "", "", ""}, 13710be168c0dSopenharmony_ci+ {PrimitiveType_SubFusion, schema::ActivationType_RELU6, "ElementSubRelu6Fp16", "", "", "", ""}, 13711be168c0dSopenharmony_ci+ {PrimitiveType_SubFusion, 
schema::ActivationType_NO_ACTIVATION, "ElementSubFp16", "", "", "", ""}, 13712be168c0dSopenharmony_ci+ {PrimitiveType_DivFusion, schema::ActivationType_RELU, "ElementDivReluFp16", "", "", "", ""}, 13713be168c0dSopenharmony_ci+ {PrimitiveType_DivFusion, schema::ActivationType_RELU6, "ElementDivRelu6Fp16", "", "", "", ""}, 13714be168c0dSopenharmony_ci+ {PrimitiveType_DivFusion, schema::ActivationType_NO_ACTIVATION, "ElementDivFp16", "", "", "", ""}, 13715be168c0dSopenharmony_ci+ {PrimitiveType_RealDiv, schema::ActivationType_RELU, "ElementDivReluFp16", "", "", "", ""}, 13716be168c0dSopenharmony_ci+ {PrimitiveType_RealDiv, schema::ActivationType_RELU6, "ElementDivRelu6Fp16", "", "", "", ""}, 13717be168c0dSopenharmony_ci+ {PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, "ElementDivFp16", "", "", "", ""}, 13718be168c0dSopenharmony_ci+ {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, "ElementLogicalAndFp16", "", "", "", ""}, 13719be168c0dSopenharmony_ci+ {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, "ElementLogicalOrFp16", "", "", "", ""}, 13720be168c0dSopenharmony_ci+ {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, "ElementMaximumFp16", "", "", "", ""}, 13721be168c0dSopenharmony_ci+ {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, "ElementMinimumFp16", "", "", "", ""}, 13722be168c0dSopenharmony_ci+ {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, "ElementFloorModFp16", "", "", "", ""}, 13723be168c0dSopenharmony_ci+ {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, "ElementFloorDivFp16", "", "", "", ""}, 13724be168c0dSopenharmony_ci+ {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, "ElementSquaredDifferenceFp16", "", "", "", 13725be168c0dSopenharmony_ci+ ""}}; 13726be168c0dSopenharmony_ci+} 13727be168c0dSopenharmony_ci+ 13728be168c0dSopenharmony_ci+int ArithmeticDynamicFP16Coder::Prepare(CoderContext *const context) { 13729be168c0dSopenharmony_ci+ 
CHECK_LESS_RETURN(input_tensors_.size(), C2NUM); 13730be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(output_tensors_.size(), 1); 13731be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 13732be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 13733be168c0dSopenharmony_ci+ "Tensor data type is invalid"); 13734be168c0dSopenharmony_ci+ } 13735be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(output_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 13736be168c0dSopenharmony_ci+ "Tensor data type is invalid"); 13737be168c0dSopenharmony_ci+ filter_tensor_ = input_tensors_.at(SECOND_INPUT); 13738be168c0dSopenharmony_ci+ MS_CHECK_PTR(filter_tensor_); 13739be168c0dSopenharmony_ci+ param_ = reinterpret_cast<ArithmeticParameter *>(parameter_); 13740be168c0dSopenharmony_ci+ MS_CHECK_PTR(param_); 13741be168c0dSopenharmony_ci+ auto primitive_type = param_->op_parameter_.type_; 13742be168c0dSopenharmony_ci+ if (primitive_type == schema::PrimitiveType_Eltwise) { 13743be168c0dSopenharmony_ci+ switch (param_->eltwise_mode_) { 13744be168c0dSopenharmony_ci+ case schema::EltwiseMode_PROD: 13745be168c0dSopenharmony_ci+ primitive_type = schema::PrimitiveType_MulFusion; 13746be168c0dSopenharmony_ci+ break; 13747be168c0dSopenharmony_ci+ case schema::EltwiseMode_SUM: 13748be168c0dSopenharmony_ci+ primitive_type = schema::PrimitiveType_AddFusion; 13749be168c0dSopenharmony_ci+ break; 13750be168c0dSopenharmony_ci+ case schema::EltwiseMode_MAXIMUM: 13751be168c0dSopenharmony_ci+ primitive_type = schema::PrimitiveType_Maximum; 13752be168c0dSopenharmony_ci+ break; 13753be168c0dSopenharmony_ci+ default: 13754be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Eltwise mode not support, mode:" << param_->eltwise_mode_; 13755be168c0dSopenharmony_ci+ return RET_ERROR; 13756be168c0dSopenharmony_ci+ } 13757be168c0dSopenharmony_ci+ } 13758be168c0dSopenharmony_ci+ InitRunFunction(primitive_type); 
13759be168c0dSopenharmony_ci+ InitDynamicParams(); 13760be168c0dSopenharmony_ci+ ResetStatus(); 13761be168c0dSopenharmony_ci+ CalcMultiplesAndStrides(); 13762be168c0dSopenharmony_ci+ return RET_OK; 13763be168c0dSopenharmony_ci+} 13764be168c0dSopenharmony_ci+ 13765be168c0dSopenharmony_ci+int ArithmeticDynamicFP16Coder::DoCode(CoderContext *const context) { 13766be168c0dSopenharmony_ci+ input0_ptr_str_ = GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 13767be168c0dSopenharmony_ci+ input1_ptr_str_ = GetTensorAddr(filter_tensor_, filter_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 13768be168c0dSopenharmony_ci+ output_ptr_str_ = GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 13769be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 13770be168c0dSopenharmony_ci+ Collect(context, 13771be168c0dSopenharmony_ci+ { 13772be168c0dSopenharmony_ci+ "nnacl/fp16/arithmetic_fp16.h", 13773be168c0dSopenharmony_ci+ "nnacl/base/broadcast_to.h", 13774be168c0dSopenharmony_ci+ }, 13775be168c0dSopenharmony_ci+ { 13776be168c0dSopenharmony_ci+ "arithmetic_fp16.c", 13777be168c0dSopenharmony_ci+ "arithmetic_base.c", 13778be168c0dSopenharmony_ci+ "broadcast_to.c", 13779be168c0dSopenharmony_ci+ }); 13780be168c0dSopenharmony_ci+ 13781be168c0dSopenharmony_ci+ // all elements eltwise calculation 13782be168c0dSopenharmony_ci+ arithmetic_func_str_ = wrap_void(arithmetic_run_); 13783be168c0dSopenharmony_ci+ // run broadcast 13784be168c0dSopenharmony_ci+ auto in0_shape = shape_info_container_->GetTemplateShape(input_tensor_); 13785be168c0dSopenharmony_ci+ std::vector<std::string> in1_shape; 13786be168c0dSopenharmony_ci+ if (filter_tensor_->IsConst()) { 13787be168c0dSopenharmony_ci+ for (auto dim : filter_tensor_->shape()) { 13788be168c0dSopenharmony_ci+ in1_shape.emplace_back(std::to_string(dim)); 13789be168c0dSopenharmony_ci+ } 13790be168c0dSopenharmony_ci+ } else { 13791be168c0dSopenharmony_ci+ in1_shape = 
shape_info_container_->GetTemplateShape(filter_tensor_); 13792be168c0dSopenharmony_ci+ } 13793be168c0dSopenharmony_ci+ auto out_shape = shape_info_container_->GetTemplateShape(output_tensor_); 13794be168c0dSopenharmony_ci+ broadcast_info_.output_shape_size_ = static_cast<int>(out_shape_.size()); 13795be168c0dSopenharmony_ci+ if (in0_shape != out_shape) { 13796be168c0dSopenharmony_ci+ broadcast_info_.input_shape_size_ = static_cast<int>(in0_shape.size()); 13797be168c0dSopenharmony_ci+ dynamic_shape_info_.input_shape_ = dynamic_param_.in_shape0_; 13798be168c0dSopenharmony_ci+ dynamic_shape_info_.output_shape_ = dynamic_param_.out_shape_; 13799be168c0dSopenharmony_ci+ code.CodeStruct("in0_broadcast_info", broadcast_info_, dynamic_shape_info_); 13800be168c0dSopenharmony_ci+ code.CodeFunction("BroadcastToSize16", input0_ptr_str_, "&in0_broadcast_info", output_ptr_str_); 13801be168c0dSopenharmony_ci+ input0_ptr_str_ = output_ptr_str_; 13802be168c0dSopenharmony_ci+ } 13803be168c0dSopenharmony_ci+ if (in1_shape != out_shape) { 13804be168c0dSopenharmony_ci+ broadcast_info_.input_shape_size_ = static_cast<int>(in1_shape.size()); 13805be168c0dSopenharmony_ci+ dynamic_shape_info_.input_shape_ = dynamic_param_.in_shape1_; 13806be168c0dSopenharmony_ci+ dynamic_shape_info_.output_shape_ = dynamic_param_.out_shape_; 13807be168c0dSopenharmony_ci+ code.CodeStruct("in1_broadcast_info", broadcast_info_, dynamic_shape_info_); 13808be168c0dSopenharmony_ci+ auto temp = output_ptr_str_; 13809be168c0dSopenharmony_ci+ if (input0_ptr_str_ == output_ptr_str_) { 13810be168c0dSopenharmony_ci+ std::map<std::string, std::vector<int>> real_nums; 13811be168c0dSopenharmony_ci+ size_t scene_num = 0; 13812be168c0dSopenharmony_ci+ for (auto &dim_template : out_shape) { 13813be168c0dSopenharmony_ci+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 13814be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 
13815be168c0dSopenharmony_ci+ real_nums[dim_template] = dim_nums; 13816be168c0dSopenharmony_ci+ scene_num = std::max(scene_num, dim_nums.size()); 13817be168c0dSopenharmony_ci+ } 13818be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 13819be168c0dSopenharmony_ci+ int out_element_num = 1; 13820be168c0dSopenharmony_ci+ for (size_t j = 0; j < out_shape.size(); ++j) { 13821be168c0dSopenharmony_ci+ if (IsNumber(out_shape[j])) { 13822be168c0dSopenharmony_ci+ out_element_num *= std::stoi(out_shape[j]); 13823be168c0dSopenharmony_ci+ } else { 13824be168c0dSopenharmony_ci+ out_element_num *= real_nums[out_shape[j]][i % real_nums[out_shape[j]].size()]; 13825be168c0dSopenharmony_ci+ } 13826be168c0dSopenharmony_ci+ } 13827be168c0dSopenharmony_ci+ int workspace = out_element_num * DataTypeSize(kNumberTypeFloat16); 13828be168c0dSopenharmony_ci+ temp = dynamic_mem_manager_->AllocWorkSpace(workspace, i); 13829be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!temp.empty(), RET_ERROR, "Arithmetic cannot alloc workspace."); 13830be168c0dSopenharmony_ci+ } 13831be168c0dSopenharmony_ci+ } 13832be168c0dSopenharmony_ci+ code.CodeFunction("BroadcastToSize16", input1_ptr_str_, "&in1_broadcast_info", temp); 13833be168c0dSopenharmony_ci+ input1_ptr_str_ = temp; 13834be168c0dSopenharmony_ci+ } 13835be168c0dSopenharmony_ci+ return ExecuteCode("(float16_t *)(" + input0_ptr_str_ + ")", "(float16_t *)(" + input1_ptr_str_ + ")", 13836be168c0dSopenharmony_ci+ "(float16_t *)(" + output_ptr_str_ + ")", dynamic_param_.out_elements_num_, context, &code); 13837be168c0dSopenharmony_ci+} 13838be168c0dSopenharmony_ci+ 13839be168c0dSopenharmony_ci+void ArithmeticDynamicFP16Coder::InitDynamicParams() { 13840be168c0dSopenharmony_ci+ auto in0_shape = shape_info_container_->GetTemplateShape(input_tensor_); 13841be168c0dSopenharmony_ci+ std::vector<std::string> in1_shape; 13842be168c0dSopenharmony_ci+ if (filter_tensor_->IsConst()) { 13843be168c0dSopenharmony_ci+ for (auto dim : filter_tensor_->shape()) 
{ 13844be168c0dSopenharmony_ci+ in1_shape.emplace_back(std::to_string(dim)); 13845be168c0dSopenharmony_ci+ } 13846be168c0dSopenharmony_ci+ } else { 13847be168c0dSopenharmony_ci+ in1_shape = shape_info_container_->GetTemplateShape(filter_tensor_); 13848be168c0dSopenharmony_ci+ } 13849be168c0dSopenharmony_ci+ auto out_shape = shape_info_container_->GetTemplateShape(output_tensor_); 13850be168c0dSopenharmony_ci+ dynamic_param_.in_shape0_ = "{"; 13851be168c0dSopenharmony_ci+ dynamic_param_.in_shape1_ = "{"; 13852be168c0dSopenharmony_ci+ dynamic_param_.out_shape_ = "{"; 13853be168c0dSopenharmony_ci+ for (auto shape : in0_shape) { 13854be168c0dSopenharmony_ci+ dynamic_param_.in_shape0_ += shape + ", "; 13855be168c0dSopenharmony_ci+ } 13856be168c0dSopenharmony_ci+ for (auto shape : in1_shape) { 13857be168c0dSopenharmony_ci+ dynamic_param_.in_shape1_ += shape + ", "; 13858be168c0dSopenharmony_ci+ } 13859be168c0dSopenharmony_ci+ for (auto shape : out_shape) { 13860be168c0dSopenharmony_ci+ dynamic_param_.out_shape_ += shape + ", "; 13861be168c0dSopenharmony_ci+ } 13862be168c0dSopenharmony_ci+ dynamic_param_.in_shape0_ += "}"; 13863be168c0dSopenharmony_ci+ dynamic_param_.in_shape1_ += "}"; 13864be168c0dSopenharmony_ci+ dynamic_param_.out_shape_ += "}"; 13865be168c0dSopenharmony_ci+ dynamic_param_.in_elements_num0_ = AccumulateShape(in0_shape, 0, in0_shape.size()); 13866be168c0dSopenharmony_ci+ dynamic_param_.in_elements_num1_ = AccumulateShape(in1_shape, 0, in1_shape.size()); 13867be168c0dSopenharmony_ci+ dynamic_param_.out_elements_num_ = AccumulateShape(out_shape, 0, out_shape.size()); 13868be168c0dSopenharmony_ci+} 13869be168c0dSopenharmony_ci+ 13870be168c0dSopenharmony_ci+void ArithmeticDynamicFP16Coder::InitRunFunction(int primitive_type) { 13871be168c0dSopenharmony_ci+ InitFunTable(); 13872be168c0dSopenharmony_ci+ for (size_t i = 0; i < fun_table_.size(); i++) { 13873be168c0dSopenharmony_ci+ if (fun_table_[i].primitive_type_ == primitive_type && 
fun_table_[i].activation_type_ == param_->activation_type_) { 13874be168c0dSopenharmony_ci+ arithmetic_run_ = fun_table_[i].func_; 13875be168c0dSopenharmony_ci+ arithmetic_run_int_ = fun_table_[i].int_func_; 13876be168c0dSopenharmony_ci+ arithmetic_run_bool_ = fun_table_[i].bool_func_; 13877be168c0dSopenharmony_ci+ arithmetic_opt_run_ = fun_table_[i].opt_func_; 13878be168c0dSopenharmony_ci+ arithmetic_opt_run_int_ = fun_table_[i].opt_int_func_; 13879be168c0dSopenharmony_ci+ } 13880be168c0dSopenharmony_ci+ } 13881be168c0dSopenharmony_ci+ arithmetic_func_type_ = kArithmeticFuncFloat; 13882be168c0dSopenharmony_ci+} 13883be168c0dSopenharmony_ci+ 13884be168c0dSopenharmony_ci+void ArithmeticDynamicFP16Coder::ResetStatus() { 13885be168c0dSopenharmony_ci+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 13886be168c0dSopenharmony_ci+ std::vector<std::string> filter_shape; 13887be168c0dSopenharmony_ci+ if (filter_tensor_->IsConst()) { 13888be168c0dSopenharmony_ci+ for (auto dim : filter_tensor_->shape()) { 13889be168c0dSopenharmony_ci+ filter_shape.emplace_back(std::to_string(dim)); 13890be168c0dSopenharmony_ci+ } 13891be168c0dSopenharmony_ci+ } else { 13892be168c0dSopenharmony_ci+ filter_shape = shape_info_container_->GetTemplateShape(filter_tensor_); 13893be168c0dSopenharmony_ci+ } 13894be168c0dSopenharmony_ci+ auto dim_num = input_shape.size() >= filter_shape.size() ? 
input_shape.size() : filter_shape.size(); 13895be168c0dSopenharmony_ci+ for (size_t i = 0; i < dim_num - input_shape.size(); ++i) { 13896be168c0dSopenharmony_ci+ in0_shape_.emplace_back("1"); 13897be168c0dSopenharmony_ci+ } 13898be168c0dSopenharmony_ci+ in0_shape_.insert(in0_shape_.end(), input_shape.begin(), input_shape.end()); 13899be168c0dSopenharmony_ci+ for (size_t i = 0; i < dim_num - filter_shape.size(); ++i) { 13900be168c0dSopenharmony_ci+ in1_shape_.emplace_back("1"); 13901be168c0dSopenharmony_ci+ } 13902be168c0dSopenharmony_ci+ in1_shape_.insert(in1_shape_.end(), filter_shape.begin(), filter_shape.end()); 13903be168c0dSopenharmony_ci+} 13904be168c0dSopenharmony_ci+ 13905be168c0dSopenharmony_ci+void ArithmeticDynamicFP16Coder::CalcMultiplesAndStrides() { 13906be168c0dSopenharmony_ci+ out_shape_ = shape_info_container_->GetTemplateShape(output_tensor_); 13907be168c0dSopenharmony_ci+ dynamic_param_.multiples0_ = "{"; 13908be168c0dSopenharmony_ci+ dynamic_param_.multiples1_ = "{"; 13909be168c0dSopenharmony_ci+ for (size_t i = 0; i < param_->ndim_; i++) { 13910be168c0dSopenharmony_ci+ if (in0_shape_[i] != "0") { 13911be168c0dSopenharmony_ci+ dynamic_param_.multiples0_ += out_shape_[i] + " / " + in0_shape_[i] + ", "; 13912be168c0dSopenharmony_ci+ } 13913be168c0dSopenharmony_ci+ if (in1_shape_[i] != "0") { 13914be168c0dSopenharmony_ci+ dynamic_param_.multiples1_ += out_shape_[i] + " / " + in1_shape_[i] + ", "; 13915be168c0dSopenharmony_ci+ } 13916be168c0dSopenharmony_ci+ } 13917be168c0dSopenharmony_ci+ dynamic_param_.multiples0_ += "}"; 13918be168c0dSopenharmony_ci+ dynamic_param_.multiples1_ += "}"; 13919be168c0dSopenharmony_ci+ 13920be168c0dSopenharmony_ci+ // cal strides 13921be168c0dSopenharmony_ci+ in0_strides_.resize(param_->ndim_); 13922be168c0dSopenharmony_ci+ in1_strides_.resize(param_->ndim_); 13923be168c0dSopenharmony_ci+ out_strides_.resize(param_->ndim_); 13924be168c0dSopenharmony_ci+ ComputeStrides(in0_shape_, in0_strides_); 
13925be168c0dSopenharmony_ci+ ComputeStrides(in1_shape_, in1_strides_); 13926be168c0dSopenharmony_ci+ ComputeStrides(out_shape_, out_strides_); 13927be168c0dSopenharmony_ci+ dynamic_param_.in_strides0_ = "{"; 13928be168c0dSopenharmony_ci+ dynamic_param_.in_strides1_ = "{"; 13929be168c0dSopenharmony_ci+ dynamic_param_.out_strides_ = "{"; 13930be168c0dSopenharmony_ci+ for (size_t i = 0; i < param_->ndim_; ++i) { 13931be168c0dSopenharmony_ci+ dynamic_param_.in_strides0_ += in0_strides_[i] + ", "; 13932be168c0dSopenharmony_ci+ dynamic_param_.in_strides1_ += in1_strides_[i] + ", "; 13933be168c0dSopenharmony_ci+ dynamic_param_.out_strides_ += out_strides_[i] + ", "; 13934be168c0dSopenharmony_ci+ } 13935be168c0dSopenharmony_ci+ dynamic_param_.in_strides0_ += "}"; 13936be168c0dSopenharmony_ci+ dynamic_param_.in_strides1_ += "}"; 13937be168c0dSopenharmony_ci+ dynamic_param_.out_strides_ += "}"; 13938be168c0dSopenharmony_ci+} 13939be168c0dSopenharmony_ci+ 13940be168c0dSopenharmony_ci+void ArithmeticDynamicFP16Coder::ComputeStrides(const std::vector<std::string> &shape, 13941be168c0dSopenharmony_ci+ std::vector<std::string> &strides) { 13942be168c0dSopenharmony_ci+ std::string stride = "1"; 13943be168c0dSopenharmony_ci+ for (int i = param_->ndim_ - 1; i >= 0; i--) { 13944be168c0dSopenharmony_ci+ strides[i] = stride; 13945be168c0dSopenharmony_ci+ stride += "*=" + shape[i]; 13946be168c0dSopenharmony_ci+ } 13947be168c0dSopenharmony_ci+} 13948be168c0dSopenharmony_ci+ 13949be168c0dSopenharmony_ci+int ArithmeticDynamicFP16Coder::ExecuteCode(const std::string &input0, const std::string &input1, 13950be168c0dSopenharmony_ci+ const std::string &output, const std::string size, 13951be168c0dSopenharmony_ci+ CoderContext *const context, NNaclFp32Serializer *const code) { 13952be168c0dSopenharmony_ci+ if (arithmetic_func_str_.empty()) { 13953be168c0dSopenharmony_ci+ return RET_ERROR; 13954be168c0dSopenharmony_ci+ } 13955be168c0dSopenharmony_ci+ for (size_t i = 0; i < fun_table_.size(); 
i++) { 13956be168c0dSopenharmony_ci+ if (fun_table_[i].primitive_type_ == param_->op_parameter_.type_ && 13957be168c0dSopenharmony_ci+ fun_table_[i].activation_type_ == param_->activation_type_) { 13958be168c0dSopenharmony_ci+ code->CodeFunction(fun_table_[i].func_, input0, input1, output, size); 13959be168c0dSopenharmony_ci+ break; 13960be168c0dSopenharmony_ci+ } 13961be168c0dSopenharmony_ci+ } 13962be168c0dSopenharmony_ci+ context->AppendCode(code->str()); 13963be168c0dSopenharmony_ci+ return RET_OK; 13964be168c0dSopenharmony_ci+} 13965be168c0dSopenharmony_ci+ 13966be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_AddFusion, 13967be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13968be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_MulFusion, 13969be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13970be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_SubFusion, 13971be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13972be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_DivFusion, 13973be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13974be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_RealDiv, 13975be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13976be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_LogicalAnd, 13977be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13978be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_LogicalOr, 13979be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13980be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Maximum, 
13981be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13982be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Minimum, 13983be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13984be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_FloorDiv, 13985be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13986be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_FloorMod, 13987be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13988be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_SquaredDifference, 13989be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13990be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Equal, 13991be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13992be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_NotEqual, 13993be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13994be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Less, 13995be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13996be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_LessEqual, 13997be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 13998be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Greater, 13999be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14000be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_GreaterEqual, 14001be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14002be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, 
kNumberTypeFloat16, PrimitiveType_Eltwise, 14003be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14004be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_AddFusion, 14005be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14006be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_MulFusion, 14007be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14008be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_SubFusion, 14009be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14010be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_DivFusion, 14011be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14012be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_RealDiv, 14013be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14014be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LogicalAnd, 14015be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14016be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LogicalOr, 14017be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14018be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Maximum, 14019be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14020be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Minimum, 14021be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14022be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_FloorDiv, 14023be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 
14024be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_FloorMod, 14025be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14026be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_SquaredDifference, 14027be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14028be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Equal, 14029be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14030be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_NotEqual, 14031be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14032be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Less, 14033be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14034be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LessEqual, 14035be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14036be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Greater, 14037be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14038be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_GreaterEqual, 14039be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14040be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Eltwise, 14041be168c0dSopenharmony_ci+ CPUOpCoderCreator<ArithmeticDynamicFP16Coder>) 14042be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 14043be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h 
14044be168c0dSopenharmony_cinew file mode 100644 14045be168c0dSopenharmony_ciindex 00000000..87e43687 14046be168c0dSopenharmony_ci--- /dev/null 14047be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/arithmetic_dynamic_fp16_coder.h 14048be168c0dSopenharmony_ci@@ -0,0 +1,132 @@ 14049be168c0dSopenharmony_ci+/** 14050be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 14051be168c0dSopenharmony_ci+ * 14052be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 14053be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 14054be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 14055be168c0dSopenharmony_ci+ * 14056be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 14057be168c0dSopenharmony_ci+ * 14058be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 14059be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 14060be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14061be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 14062be168c0dSopenharmony_ci+ * limitations under the License. 
14063be168c0dSopenharmony_ci+ */ 14064be168c0dSopenharmony_ci+ 14065be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ARITHMETIC_DYNAMIC_FP16_CODER_H_ 14066be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ARITHMETIC_DYNAMIC_FP16_CODER_H_ 14067be168c0dSopenharmony_ci+ 14068be168c0dSopenharmony_ci+#include <vector> 14069be168c0dSopenharmony_ci+#include <string> 14070be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 14071be168c0dSopenharmony_ci+#include "nnacl/base/cast_base.h" 14072be168c0dSopenharmony_ci+#include "nnacl/arithmetic_parameter.h" 14073be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 14074be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h" 14075be168c0dSopenharmony_ci+#include "nnacl/broadcast_to_parameter.h" 14076be168c0dSopenharmony_ci+ 14077be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 14078be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_AddFusion; 14079be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_DivFusion; 14080be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Eltwise; 14081be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Equal; 14082be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_FloorDiv; 14083be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_FloorMod; 14084be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Greater; 14085be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_GreaterEqual; 14086be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Less; 14087be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_LessEqual; 14088be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_LogicalAnd; 14089be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_LogicalOr; 
14090be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Maximum; 14091be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Minimum; 14092be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Mod; 14093be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_MulFusion; 14094be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_NotEqual; 14095be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_RealDiv; 14096be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_SquaredDifference; 14097be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_SubFusion; 14098be168c0dSopenharmony_ci+ 14099be168c0dSopenharmony_ci+class ArithmeticDynamicFP16Coder final : public OperatorCoder { 14100be168c0dSopenharmony_ci+ typedef struct { 14101be168c0dSopenharmony_ci+ int primitive_type_; 14102be168c0dSopenharmony_ci+ int activation_type_; 14103be168c0dSopenharmony_ci+ std::string func_; 14104be168c0dSopenharmony_ci+ std::string int_func_; 14105be168c0dSopenharmony_ci+ std::string bool_func_; 14106be168c0dSopenharmony_ci+ std::string opt_func_; 14107be168c0dSopenharmony_ci+ std::string opt_int_func_; 14108be168c0dSopenharmony_ci+ } ARITHMETIC_FUNC_INFO_FP16; 14109be168c0dSopenharmony_ci+ 14110be168c0dSopenharmony_ci+ // typedef struct MATRIC_INFO { 14111be168c0dSopenharmony_ci+ // bool is_const{false}; 14112be168c0dSopenharmony_ci+ // bool is_valid{false}; 14113be168c0dSopenharmony_ci+ // void *data{nullptr}; 14114be168c0dSopenharmony_ci+ // int64_t inner_size{1}; // the element num of once batch 14115be168c0dSopenharmony_ci+ // std::vector<int64_t> shape; 14116be168c0dSopenharmony_ci+ // std::vector<int64_t> batch_post_sum; 14117be168c0dSopenharmony_ci+ // void Reset() { 14118be168c0dSopenharmony_ci+ // is_valid = false; 14119be168c0dSopenharmony_ci+ // data = nullptr; 14120be168c0dSopenharmony_ci+ // inner_size = 1; 14121be168c0dSopenharmony_ci+ // shape.clear(); 14122be168c0dSopenharmony_ci+ // 
batch_post_sum.clear(); 14123be168c0dSopenharmony_ci+ // } 14124be168c0dSopenharmony_ci+ // } MATRIC_INFO; 14125be168c0dSopenharmony_ci+ 14126be168c0dSopenharmony_ci+ public: 14127be168c0dSopenharmony_ci+ ArithmeticDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 14128be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 14129be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 14130be168c0dSopenharmony_ci+ 14131be168c0dSopenharmony_ci+ ~ArithmeticDynamicFP16Coder() override = default; 14132be168c0dSopenharmony_ci+ 14133be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 14134be168c0dSopenharmony_ci+ 14135be168c0dSopenharmony_ci+ private: 14136be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 14137be168c0dSopenharmony_ci+ 14138be168c0dSopenharmony_ci+ void InitFunTable(); 14139be168c0dSopenharmony_ci+ 14140be168c0dSopenharmony_ci+ void InitRunFunction(int primitive_type); 14141be168c0dSopenharmony_ci+ 14142be168c0dSopenharmony_ci+ void InitDynamicParams(); 14143be168c0dSopenharmony_ci+ 14144be168c0dSopenharmony_ci+ void ResetStatus(); 14145be168c0dSopenharmony_ci+ 14146be168c0dSopenharmony_ci+ void CalcMultiplesAndStrides(); 14147be168c0dSopenharmony_ci+ 14148be168c0dSopenharmony_ci+ void ComputeStrides(const std::vector<std::string> &shape, std::vector<std::string> &strides); 14149be168c0dSopenharmony_ci+ 14150be168c0dSopenharmony_ci+ int ExecuteCode(const std::string &input0, const std::string &input1, const std::string &output, 14151be168c0dSopenharmony_ci+ const std::string size, CoderContext *const context, NNaclFp32Serializer *const code); 14152be168c0dSopenharmony_ci+ 14153be168c0dSopenharmony_ci+ std::vector<ARITHMETIC_FUNC_INFO_FP16> fun_table_; 14154be168c0dSopenharmony_ci+ ArithmeticFuncType arithmetic_func_type_{kArithmeticFuncUnknow}; 14155be168c0dSopenharmony_ci+ 
ArithmeticParameter *param_{nullptr}; 14156be168c0dSopenharmony_ci+ ArithmeticDynamicParameter dynamic_param_; 14157be168c0dSopenharmony_ci+ BroadcastShapeInfo broadcast_info_; 14158be168c0dSopenharmony_ci+ BroadcastDynamicShapeInfo dynamic_shape_info_; 14159be168c0dSopenharmony_ci+ Tensor *filter_tensor_{nullptr}; 14160be168c0dSopenharmony_ci+ std::string input0_ptr_str_; 14161be168c0dSopenharmony_ci+ std::string input1_ptr_str_; 14162be168c0dSopenharmony_ci+ std::string output_ptr_str_; 14163be168c0dSopenharmony_ci+ std::string arithmetic_run_; 14164be168c0dSopenharmony_ci+ std::string arithmetic_run_int_; 14165be168c0dSopenharmony_ci+ std::string arithmetic_opt_run_; 14166be168c0dSopenharmony_ci+ std::string arithmetic_opt_run_int_; 14167be168c0dSopenharmony_ci+ std::string arithmetic_run_bool_; 14168be168c0dSopenharmony_ci+ std::string arithmetic_func_str_; 14169be168c0dSopenharmony_ci+ std::vector<std::string> in0_shape_; 14170be168c0dSopenharmony_ci+ std::vector<std::string> in1_shape_; 14171be168c0dSopenharmony_ci+ std::vector<std::string> out_shape_; 14172be168c0dSopenharmony_ci+ std::vector<std::string> in0_strides_; 14173be168c0dSopenharmony_ci+ std::vector<std::string> in1_strides_; 14174be168c0dSopenharmony_ci+ std::vector<std::string> out_strides_; 14175be168c0dSopenharmony_ci+ // MATRIC_INFO a_matric_; 14176be168c0dSopenharmony_ci+ // MATRIC_INFO b_matric_; 14177be168c0dSopenharmony_ci+ // MATRIC_INFO c_matric_; 14178be168c0dSopenharmony_ci+}; 14179be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 14180be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_ARITHMETIC_DYNAMIC_FP16_CODER_H_ 14181be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc 14182be168c0dSopenharmony_cinew file mode 100644 
14183be168c0dSopenharmony_ciindex 00000000..bf8bd06b 14184be168c0dSopenharmony_ci--- /dev/null 14185be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.cc 14186be168c0dSopenharmony_ci@@ -0,0 +1,92 @@ 14187be168c0dSopenharmony_ci+/** 14188be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 14189be168c0dSopenharmony_ci+ * 14190be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 14191be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 14192be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 14193be168c0dSopenharmony_ci+ * 14194be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 14195be168c0dSopenharmony_ci+ * 14196be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 14197be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 14198be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14199be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 14200be168c0dSopenharmony_ci+ * limitations under the License. 
14201be168c0dSopenharmony_ci+ */ 14202be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h" 14203be168c0dSopenharmony_ci+#include <string> 14204be168c0dSopenharmony_ci+#include <vector> 14205be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 14206be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 14207be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 14208be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 14209be168c0dSopenharmony_ci+ 14210be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Concat; 14211be168c0dSopenharmony_ci+ 14212be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 14213be168c0dSopenharmony_ci+int ConcatDynamicFP16Coder::Prepare(CoderContext *const context) { 14214be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 14215be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_.at(i)->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 14216be168c0dSopenharmony_ci+ "input tensor data type is invalid."); 14217be168c0dSopenharmony_ci+ } 14218be168c0dSopenharmony_ci+ concat_param_ = reinterpret_cast<ConcatParameter *>(parameter_); 14219be168c0dSopenharmony_ci+ MS_CHECK_PTR(concat_param_); 14220be168c0dSopenharmony_ci+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 14221be168c0dSopenharmony_ci+ axis_ = 14222be168c0dSopenharmony_ci+ concat_param_->axis_ >= 0 ? 
concat_param_->axis_ : static_cast<int>(input_shape.size()) + concat_param_->axis_; 14223be168c0dSopenharmony_ci+ return RET_OK; 14224be168c0dSopenharmony_ci+} 14225be168c0dSopenharmony_ci+ 14226be168c0dSopenharmony_ci+int ConcatDynamicFP16Coder::DoCode(CoderContext *const context) { 14227be168c0dSopenharmony_ci+ Collect(context, 14228be168c0dSopenharmony_ci+ { 14229be168c0dSopenharmony_ci+ "nnacl/base/concat_base.h", 14230be168c0dSopenharmony_ci+ }, 14231be168c0dSopenharmony_ci+ { 14232be168c0dSopenharmony_ci+ "concat_base.c", 14233be168c0dSopenharmony_ci+ }); 14234be168c0dSopenharmony_ci+ 14235be168c0dSopenharmony_ci+ size_t input_num = input_tensors_.size(); 14236be168c0dSopenharmony_ci+ 14237be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 14238be168c0dSopenharmony_ci+ code << "\t\tvoid *inputs_addr[] = {"; 14239be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_num; ++i) { 14240be168c0dSopenharmony_ci+ code << "(void *)(" 14241be168c0dSopenharmony_ci+ << GetTensorAddr(input_tensors_.at(i), input_tensors_.at(i)->IsConst(), dynamic_mem_manager_, allocator_) 14242be168c0dSopenharmony_ci+ << "), "; 14243be168c0dSopenharmony_ci+ } 14244be168c0dSopenharmony_ci+ code << "};\n"; 14245be168c0dSopenharmony_ci+ 14246be168c0dSopenharmony_ci+ size_t i; 14247be168c0dSopenharmony_ci+ for (i = 0; i < input_num; ++i) { 14248be168c0dSopenharmony_ci+ code << "\t\tint shape_" << i << "[] = {"; 14249be168c0dSopenharmony_ci+ auto in_shape = shape_info_container_->GetTemplateShape(input_tensors_.at(i)); 14250be168c0dSopenharmony_ci+ for (auto &shape : in_shape) { 14251be168c0dSopenharmony_ci+ code << shape << ", "; 14252be168c0dSopenharmony_ci+ } 14253be168c0dSopenharmony_ci+ code << "};\n"; 14254be168c0dSopenharmony_ci+ } 14255be168c0dSopenharmony_ci+ 14256be168c0dSopenharmony_ci+ auto out_shape = shape_info_container_->GetTemplateShape(output_tensor_); 14257be168c0dSopenharmony_ci+ code << "\t\tint shape_" << i << "[] = {"; 14258be168c0dSopenharmony_ci+ for (auto &shape : 
out_shape) { 14259be168c0dSopenharmony_ci+ code << shape << ", "; 14260be168c0dSopenharmony_ci+ } 14261be168c0dSopenharmony_ci+ code << "};\n"; 14262be168c0dSopenharmony_ci+ 14263be168c0dSopenharmony_ci+ code << "\t\tint *inputs_output_shape[] = {"; 14264be168c0dSopenharmony_ci+ for (i = 0; i <= input_num; ++i) { 14265be168c0dSopenharmony_ci+ code << "shape_" << i << ", "; 14266be168c0dSopenharmony_ci+ } 14267be168c0dSopenharmony_ci+ code << "};\n"; 14268be168c0dSopenharmony_ci+ std::string output_data = GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 14269be168c0dSopenharmony_ci+ code.CodeFunction("Concat", "inputs_addr", input_num, axis_, "inputs_output_shape", out_shape.size(), output_data, 0, 14270be168c0dSopenharmony_ci+ 1, sizeof(uint16_t)); 14271be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 14272be168c0dSopenharmony_ci+ return RET_OK; 14273be168c0dSopenharmony_ci+} 14274be168c0dSopenharmony_ci+ 14275be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Concat, CPUOpCoderCreator<ConcatDynamicFP16Coder>) 14276be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Concat, CPUOpCoderCreator<ConcatDynamicFP16Coder>) 14277be168c0dSopenharmony_ci+ 14278be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 14279be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h 14280be168c0dSopenharmony_cinew file mode 100644 14281be168c0dSopenharmony_ciindex 00000000..bd1b7ff6 14282be168c0dSopenharmony_ci--- /dev/null 14283be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/concat_dynamic_fp16_coder.h 14284be168c0dSopenharmony_ci@@ -0,0 +1,40 @@ 14285be168c0dSopenharmony_ci+/** 14286be168c0dSopenharmony_ci+ * Copyright 2023 Huawei 
Technologies Co., Ltd 14287be168c0dSopenharmony_ci+ * 14288be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 14289be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 14290be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 14291be168c0dSopenharmony_ci+ * 14292be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 14293be168c0dSopenharmony_ci+ * 14294be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 14295be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 14296be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14297be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 14298be168c0dSopenharmony_ci+ * limitations under the License. 14299be168c0dSopenharmony_ci+ */ 14300be168c0dSopenharmony_ci+ 14301be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONCAT_DYNAMIC_FP16_CODER_H_ 14302be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONCAT_DYNAMIC_FP16_CODER_H_ 14303be168c0dSopenharmony_ci+ 14304be168c0dSopenharmony_ci+#include <vector> 14305be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 14306be168c0dSopenharmony_ci+#include "nnacl/concat_parameter.h" 14307be168c0dSopenharmony_ci+ 14308be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 14309be168c0dSopenharmony_ci+class ConcatDynamicFP16Coder final : public OperatorCoder { 14310be168c0dSopenharmony_ci+ public: 14311be168c0dSopenharmony_ci+ ConcatDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 14312be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 14313be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, 
target) {} 14314be168c0dSopenharmony_ci+ ~ConcatDynamicFP16Coder() override = default; 14315be168c0dSopenharmony_ci+ 14316be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 14317be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 14318be168c0dSopenharmony_ci+ 14319be168c0dSopenharmony_ci+ private: 14320be168c0dSopenharmony_ci+ int axis_{0}; 14321be168c0dSopenharmony_ci+ ConcatParameter *concat_param_{nullptr}; 14322be168c0dSopenharmony_ci+}; 14323be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 14324be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONCAT_DYNAMIC_FP16_CODER_H_ 14325be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc 14326be168c0dSopenharmony_cinew file mode 100644 14327be168c0dSopenharmony_ciindex 00000000..2f4e42e7 14328be168c0dSopenharmony_ci--- /dev/null 14329be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.cc 14330be168c0dSopenharmony_ci@@ -0,0 +1,155 @@ 14331be168c0dSopenharmony_ci+/** 14332be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 14333be168c0dSopenharmony_ci+ * 14334be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 14335be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
14336be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 14337be168c0dSopenharmony_ci+ * 14338be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 14339be168c0dSopenharmony_ci+ * 14340be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 14341be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 14342be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14343be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 14344be168c0dSopenharmony_ci+ * limitations under the License. 14345be168c0dSopenharmony_ci+ */ 14346be168c0dSopenharmony_ci+ 14347be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h" 14348be168c0dSopenharmony_ci+#include "src/common/version_manager.h" 14349be168c0dSopenharmony_ci+#include "src/common/tensor_util.h" 14350be168c0dSopenharmony_ci+#include "src/common/ops/populate/populate_register.h" 14351be168c0dSopenharmony_ci+#include "nnacl/fp32/winograd_utils.h" 14352be168c0dSopenharmony_ci+#include "nnacl/base/conv_common_base.h" 14353be168c0dSopenharmony_ci+#include "nnacl/infer/conv2d_infer.h" 14354be168c0dSopenharmony_ci+#include "coder/shape_info_container.h" 14355be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h" 14356be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h" 14357be168c0dSopenharmony_ci+ 14358be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Conv2DFusion; 14359be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 14360be168c0dSopenharmony_ci+int ConvDelegateDynamicFP16Coder::Prepare(CoderContext *const context) { 14361be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 14362be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == 
kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 14363be168c0dSopenharmony_ci+ "Input tensor data type is invalid"); 14364be168c0dSopenharmony_ci+ } 14365be168c0dSopenharmony_ci+ for (size_t i = 0; i < output_tensors_.size(); ++i) { 14366be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(output_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 14367be168c0dSopenharmony_ci+ "Output tensor data type is invalid"); 14368be168c0dSopenharmony_ci+ } 14369be168c0dSopenharmony_ci+ // Update shape info of input and output 14370be168c0dSopenharmony_ci+ ConvDynamicParameter dynamic_param; 14371be168c0dSopenharmony_ci+ SetInputOutputShapeInfo(reinterpret_cast<ConvParameter *>(parameter_), dynamic_param, input_tensor_, output_tensor_); 14372be168c0dSopenharmony_ci+ if (conv_coder_ == nullptr) { 14373be168c0dSopenharmony_ci+ // need to select actual execute coder here 14374be168c0dSopenharmony_ci+ conv_coder_ = 14375be168c0dSopenharmony_ci+ CPUConvFP16DynamicCoderSelect(input_tensors_, output_tensors_, node_, node_index(), target_, schema_version_); 14376be168c0dSopenharmony_ci+ MS_CHECK_PTR(conv_coder_); 14377be168c0dSopenharmony_ci+ ConvParameter *op_parameter = static_cast<ConvParameter *>(malloc(sizeof(ConvParameter))); 14378be168c0dSopenharmony_ci+ if (op_parameter == nullptr) { 14379be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc ConvParameter failed."; 14380be168c0dSopenharmony_ci+ return RET_ERROR; 14381be168c0dSopenharmony_ci+ } 14382be168c0dSopenharmony_ci+ if (memcpy_s(op_parameter, sizeof(ConvParameter), parameter_, sizeof(ConvParameter)) != EOK) { 14383be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "memcpy_s failed."; 14384be168c0dSopenharmony_ci+ free(op_parameter); 14385be168c0dSopenharmony_ci+ return RET_ERROR; 14386be168c0dSopenharmony_ci+ } 14387be168c0dSopenharmony_ci+ conv_coder_->set_type(GetPrimitiveType(node_->primitive_, schema_version_)); 14388be168c0dSopenharmony_ci+ conv_coder_->set_thread_num(thread_num_); 14389be168c0dSopenharmony_ci+ 
conv_coder_->set_parameter(reinterpret_cast<OpParameter *>(op_parameter)); 14390be168c0dSopenharmony_ci+ conv_coder_->set_shape_info_container(shape_info_container_); 14391be168c0dSopenharmony_ci+ conv_coder_->set_dynamic_mem_manager(dynamic_mem_manager_); 14392be168c0dSopenharmony_ci+ } 14393be168c0dSopenharmony_ci+ return conv_coder_->Prepare(context); 14394be168c0dSopenharmony_ci+} 14395be168c0dSopenharmony_ci+ 14396be168c0dSopenharmony_ci+int ConvDelegateDynamicFP16Coder::DoCode(CoderContext *const context) { return conv_coder_->DoCode(context); } 14397be168c0dSopenharmony_ci+ 14398be168c0dSopenharmony_ci+void ConvDelegateDynamicFP16Coder::SetInputOutputShapeInfo(ConvParameter *conv_param, 14399be168c0dSopenharmony_ci+ ConvDynamicParameter &dynamic_param, 14400be168c0dSopenharmony_ci+ const lite::Tensor *input, const lite::Tensor *output) { 14401be168c0dSopenharmony_ci+ dynamic_param.input_batch_ = shape_info_container_->GetTemplateShape(input_tensor_).at(0); 14402be168c0dSopenharmony_ci+ conv_param->input_h_ = input->Height(); 14403be168c0dSopenharmony_ci+ conv_param->input_w_ = input->Width(); 14404be168c0dSopenharmony_ci+ conv_param->input_channel_ = input->Channel(); 14405be168c0dSopenharmony_ci+ dynamic_param.output_batch_ = shape_info_container_->GetTemplateShape(output_tensor_).at(0); 14406be168c0dSopenharmony_ci+ conv_param->output_h_ = output->Height(); 14407be168c0dSopenharmony_ci+ conv_param->output_w_ = output->Width(); 14408be168c0dSopenharmony_ci+ conv_param->output_channel_ = output->Channel(); 14409be168c0dSopenharmony_ci+} 14410be168c0dSopenharmony_ci+ 14411be168c0dSopenharmony_ci+std::unique_ptr<OperatorCoder> CPUConvFP16DynamicCoderSelect(const std::vector<lite::Tensor *> &in_tensors, 14412be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &out_tensors, 14413be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, 14414be168c0dSopenharmony_ci+ Target target, int schema_version) { 14415be168c0dSopenharmony_ci+ const 
void *primitive = node->primitive_; 14416be168c0dSopenharmony_ci+ if (primitive == nullptr) { 14417be168c0dSopenharmony_ci+ return nullptr; 14418be168c0dSopenharmony_ci+ } 14419be168c0dSopenharmony_ci+ ParameterGen paramGen = PopulateRegistry::GetInstance()->GetParameterCreator( 14420be168c0dSopenharmony_ci+ GetPrimitiveType(node->primitive_, schema_version), schema_version); 14421be168c0dSopenharmony_ci+ MS_CHECK_PTR_RET_NULL(paramGen); 14422be168c0dSopenharmony_ci+ auto conv_param = reinterpret_cast<ConvParameter *>(paramGen(node->primitive_)); 14423be168c0dSopenharmony_ci+ MS_CHECK_PTR_RET_NULL(conv_param); 14424be168c0dSopenharmony_ci+ int kernel_h = conv_param->kernel_h_; 14425be168c0dSopenharmony_ci+ int kernel_w = conv_param->kernel_w_; 14426be168c0dSopenharmony_ci+ conv_param->input_h_ = in_tensors.at(kInputIndex)->Height(); 14427be168c0dSopenharmony_ci+ conv_param->input_w_ = in_tensors.at(kInputIndex)->Width(); 14428be168c0dSopenharmony_ci+ conv_param->input_channel_ = in_tensors.at(kInputIndex)->Channel(); 14429be168c0dSopenharmony_ci+ conv_param->output_h_ = out_tensors.at(kOutputIndex)->Height(); 14430be168c0dSopenharmony_ci+ conv_param->output_w_ = out_tensors.at(kOutputIndex)->Width(); 14431be168c0dSopenharmony_ci+ conv_param->output_channel_ = out_tensors.at(kOutputIndex)->Channel(); 14432be168c0dSopenharmony_ci+ conv_param->op_parameter_.thread_num_ = 1; 14433be168c0dSopenharmony_ci+ free(conv_param); 14434be168c0dSopenharmony_ci+ std::unique_ptr<OperatorCoder> coder; 14435be168c0dSopenharmony_ci+ if (kernel_h == 1 && kernel_w == 1) { 14436be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "create Convolution1x1DynamicFP16CPUKernel"; 14437be168c0dSopenharmony_ci+ coder = CPUOpCoderCreator<Convolution1x1DynamicFP16Coder>(in_tensors, out_tensors, node, node_index, target, 14438be168c0dSopenharmony_ci+ schema_version); 14439be168c0dSopenharmony_ci+ } else { 14440be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "create ConvolutionDynamicFP16Coder"; 
14441be168c0dSopenharmony_ci+ coder = 14442be168c0dSopenharmony_ci+ CPUOpCoderCreator<ConvolutionDynamicFP16Coder>(in_tensors, out_tensors, node, node_index, target, schema_version); 14443be168c0dSopenharmony_ci+ } 14444be168c0dSopenharmony_ci+ return coder; 14445be168c0dSopenharmony_ci+} 14446be168c0dSopenharmony_ci+ 14447be168c0dSopenharmony_ci+std::unique_ptr<OperatorCoder> CreateConvDelegateFp16(const std::vector<lite::Tensor *> &in_tensors, 14448be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &out_tensors, 14449be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target, 14450be168c0dSopenharmony_ci+ int schema_version) { 14451be168c0dSopenharmony_ci+ return CPUOpCoderCreator<ConvDelegateDynamicFP16Coder>(in_tensors, out_tensors, node, node_index, target, 14452be168c0dSopenharmony_ci+ schema_version); 14453be168c0dSopenharmony_ci+} 14454be168c0dSopenharmony_ci+ 14455be168c0dSopenharmony_ci+std::unique_ptr<OperatorCoder> CPUConv2DFusionDynamicFP16CoderCreator(const std::vector<lite::Tensor *> &in_tensors, 14456be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &out_tensors, 14457be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, 14458be168c0dSopenharmony_ci+ Target target, int schema_version) { 14459be168c0dSopenharmony_ci+ const void *primitive = node->primitive_; 14460be168c0dSopenharmony_ci+ if (primitive == nullptr) { 14461be168c0dSopenharmony_ci+ return nullptr; 14462be168c0dSopenharmony_ci+ } 14463be168c0dSopenharmony_ci+ ParameterGen param_gen = PopulateRegistry::GetInstance()->GetParameterCreator( 14464be168c0dSopenharmony_ci+ GetPrimitiveType(node->primitive_, schema_version), schema_version); 14465be168c0dSopenharmony_ci+ if (param_gen == nullptr) { 14466be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "parameter generator is null"; 14467be168c0dSopenharmony_ci+ return nullptr; 14468be168c0dSopenharmony_ci+ } 14469be168c0dSopenharmony_ci+ auto conv_param = reinterpret_cast<ConvParameter 
*>(param_gen(node->primitive_)); 14470be168c0dSopenharmony_ci+ std::unique_ptr<OperatorCoder> coder; 14471be168c0dSopenharmony_ci+ if (conv_param->group_ == 1) { 14472be168c0dSopenharmony_ci+ coder = CreateConvDelegateFp16(in_tensors, out_tensors, node, node_index, target, schema_version); 14473be168c0dSopenharmony_ci+ } else { 14474be168c0dSopenharmony_ci+ // GroupConv 14475be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "currently, only support conv_param->group_ == 1 in dynamic coder scene"; 14476be168c0dSopenharmony_ci+ return nullptr; 14477be168c0dSopenharmony_ci+ } 14478be168c0dSopenharmony_ci+ return coder; 14479be168c0dSopenharmony_ci+} 14480be168c0dSopenharmony_ci+ 14481be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Conv2DFusion, 14482be168c0dSopenharmony_ci+ CPUConv2DFusionDynamicFP16CoderCreator) 14483be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Conv2DFusion, 14484be168c0dSopenharmony_ci+ CPUConv2DFusionDynamicFP16CoderCreator) 14485be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 14486be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h 14487be168c0dSopenharmony_cinew file mode 100644 14488be168c0dSopenharmony_ciindex 00000000..c352c469 14489be168c0dSopenharmony_ci--- /dev/null 14490be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/conv2d_delegate_dynamic_fp16_coder.h 14491be168c0dSopenharmony_ci@@ -0,0 +1,56 @@ 14492be168c0dSopenharmony_ci+/** 14493be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 14494be168c0dSopenharmony_ci+ * 14495be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 14496be168c0dSopenharmony_ci+ * you may not use this file except in compliance 
with the License. 14497be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 14498be168c0dSopenharmony_ci+ * 14499be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 14500be168c0dSopenharmony_ci+ * 14501be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 14502be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 14503be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14504be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 14505be168c0dSopenharmony_ci+ * limitations under the License. 14506be168c0dSopenharmony_ci+ */ 14507be168c0dSopenharmony_ci+ 14508be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONV2D_DELEGATE_DYNAMIC_FP16_CODER_H_ 14509be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONV2D_DELEGATE_DYNAMIC_FP16_CODER_H_ 14510be168c0dSopenharmony_ci+#include <vector> 14511be168c0dSopenharmony_ci+#include <memory> 14512be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 14513be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" 14514be168c0dSopenharmony_ci+#include "nnacl/conv_parameter.h" 14515be168c0dSopenharmony_ci+ 14516be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 14517be168c0dSopenharmony_ci+class ConvDelegateDynamicFP16Coder : public OperatorCoder { 14518be168c0dSopenharmony_ci+ public: 14519be168c0dSopenharmony_ci+ ConvDelegateDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 14520be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 14521be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 14522be168c0dSopenharmony_ci+ 14523be168c0dSopenharmony_ci+ 
~ConvDelegateDynamicFP16Coder() override = default; 14524be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 14525be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 14526be168c0dSopenharmony_ci+ 14527be168c0dSopenharmony_ci+ protected: 14528be168c0dSopenharmony_ci+ std::unique_ptr<OperatorCoder> conv_coder_ = nullptr; 14529be168c0dSopenharmony_ci+ ConvParameter *conv_param_{nullptr}; 14530be168c0dSopenharmony_ci+ ConvDynamicParameter dynamic_param_; 14531be168c0dSopenharmony_ci+ 14532be168c0dSopenharmony_ci+ private: 14533be168c0dSopenharmony_ci+ void SetInputOutputShapeInfo(ConvParameter *conv_param, ConvDynamicParameter &dynamic_param, 14534be168c0dSopenharmony_ci+ const lite::Tensor *input, const lite::Tensor *output); 14535be168c0dSopenharmony_ci+}; 14536be168c0dSopenharmony_ci+ 14537be168c0dSopenharmony_ci+std::unique_ptr<OperatorCoder> CPUConvFP16DynamicCoderSelect(const std::vector<lite::Tensor *> &in_tensors, 14538be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &out_tensors, 14539be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, 14540be168c0dSopenharmony_ci+ Target target, int schema_version); 14541be168c0dSopenharmony_ci+ 14542be168c0dSopenharmony_ci+std::unique_ptr<OperatorCoder> CPUConv2DFusionDynamicFP16CoderCreator(const std::vector<lite::Tensor *> &in_tensors, 14543be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &out_tensors, 14544be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, 14545be168c0dSopenharmony_ci+ Target target, int schema_version); 14546be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 14547be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONV2D_DELEGATE_DYNAMIC_FP16_CODER_H_ 14548be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc 14549be168c0dSopenharmony_cinew file mode 100644 14550be168c0dSopenharmony_ciindex 00000000..c682b2ed 14551be168c0dSopenharmony_ci--- /dev/null 14552be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.cc 14553be168c0dSopenharmony_ci@@ -0,0 +1,252 @@ 14554be168c0dSopenharmony_ci+/** 14555be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 14556be168c0dSopenharmony_ci+ * 14557be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 14558be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 14559be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 14560be168c0dSopenharmony_ci+ * 14561be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 14562be168c0dSopenharmony_ci+ * 14563be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 14564be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 14565be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14566be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 14567be168c0dSopenharmony_ci+ * limitations under the License. 
14568be168c0dSopenharmony_ci+ */ 14569be168c0dSopenharmony_ci+ 14570be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h" 14571be168c0dSopenharmony_ci+#include <string> 14572be168c0dSopenharmony_ci+#include <vector> 14573be168c0dSopenharmony_ci+#include "nnacl/fp32/winograd_utils.h" 14574be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 14575be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 14576be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 14577be168c0dSopenharmony_ci+ 14578be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 14579be168c0dSopenharmony_ci+int Convolution1x1DynamicFP16Coder::Prepare(CoderContext *const context) { 14580be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(input_tensors_.size(), C2NUM); 14581be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(output_tensors_.size(), 1); 14582be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 14583be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 14584be168c0dSopenharmony_ci+ "Tensor data type is invalid"); 14585be168c0dSopenharmony_ci+ } 14586be168c0dSopenharmony_ci+ for (size_t i = 0; i < output_tensors_.size(); ++i) { 14587be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(output_tensors_[i]->data_type() == kNumberTypeFloat16, RET_PARAM_INVALID, 14588be168c0dSopenharmony_ci+ "Tensor data type is invalid"); 14589be168c0dSopenharmony_ci+ } 14590be168c0dSopenharmony_ci+ if (target_ == kARM64) { 14591be168c0dSopenharmony_ci+ row_tile_ = (output_tensor_->format() == NC4HW4) ? C16NUM : C12NUM; 14592be168c0dSopenharmony_ci+ col_tile_ = (output_tensor_->format() == NC4HW4) ? 
C8NUM : C16NUM; 14593be168c0dSopenharmony_ci+ } 14594be168c0dSopenharmony_ci+ if (matmul_param_ == nullptr) { 14595be168c0dSopenharmony_ci+ matmul_param_ = new (std::nothrow) MatMulParameter(); 14596be168c0dSopenharmony_ci+ if (matmul_param_ == nullptr) { 14597be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Init matmul_param_ failed."; 14598be168c0dSopenharmony_ci+ return RET_ERROR; 14599be168c0dSopenharmony_ci+ } 14600be168c0dSopenharmony_ci+ } 14601be168c0dSopenharmony_ci+ conv_param_ = reinterpret_cast<ConvParameter *>(parameter_); 14602be168c0dSopenharmony_ci+ filter_tensor_ = input_tensors_.at(kWeightIndex); 14603be168c0dSopenharmony_ci+ MS_CHECK_PTR(filter_tensor_); 14604be168c0dSopenharmony_ci+ if (input_tensors_.size() == kInputSize2) { 14605be168c0dSopenharmony_ci+ bias_tensor_ = input_tensors_.at(kBiasIndex); 14606be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_tensor_); 14607be168c0dSopenharmony_ci+ } else { 14608be168c0dSopenharmony_ci+ MS_CHECK_TRUE(input_tensors_.size() == kInputSize1, "wrong input size"); 14609be168c0dSopenharmony_ci+ } 14610be168c0dSopenharmony_ci+ dynamic_param_.input_batch_ = shape_info_container_->GetTemplateShape(input_tensor_)[0]; 14611be168c0dSopenharmony_ci+ conv_param_->input_h_ = input_tensor_->Height(); 14612be168c0dSopenharmony_ci+ conv_param_->input_w_ = input_tensor_->Width(); 14613be168c0dSopenharmony_ci+ conv_param_->input_channel_ = input_tensor_->Channel(); 14614be168c0dSopenharmony_ci+ dynamic_param_.output_batch_ = shape_info_container_->GetTemplateShape(output_tensor_)[0]; 14615be168c0dSopenharmony_ci+ conv_param_->output_h_ = output_tensor_->Height(); 14616be168c0dSopenharmony_ci+ conv_param_->output_w_ = output_tensor_->Width(); 14617be168c0dSopenharmony_ci+ conv_param_->output_channel_ = output_tensor_->Channel(); 14618be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(InitWeightBias(context), "Init weight bias failed."); 14619be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(InitMatmulParam(), "Init matmul param failed."); 
14620be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(InitTmpBuffer(context), "Init tmp buffer failed."); 14621be168c0dSopenharmony_ci+ return RET_OK; 14622be168c0dSopenharmony_ci+} 14623be168c0dSopenharmony_ci+ 14624be168c0dSopenharmony_ci+int Convolution1x1DynamicFP16Coder::DoCode(CoderContext *const context) { 14625be168c0dSopenharmony_ci+ CollectFilesForFunc(context); 14626be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 14627be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(ComputeWorkspace(), "ComputeWorkspace failed."); 14628be168c0dSopenharmony_ci+ auto tmp_input_str = "(float16_t *)(" + allocator_->GetRuntimeAddr(static_cast<float16 *>(tmp_input_)) + ")"; 14629be168c0dSopenharmony_ci+ auto input_str = 14630be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 14631be168c0dSopenharmony_ci+ auto output_str = 14632be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 14633be168c0dSopenharmony_ci+ auto packed_weight_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(packed_weight_)); 14634be168c0dSopenharmony_ci+ 14635be168c0dSopenharmony_ci+ code << " for (int batch_index = 0; batch_index < " << dynamic_param_.input_batch_ << "; batch_index++) {\n"; 14636be168c0dSopenharmony_ci+ output_ptr_ = output_str + " + batch_index * " + std::to_string(matmul_param_->row_ * matmul_param_->col_); 14637be168c0dSopenharmony_ci+ auto batch_in = input_str + " + batch_index * " + 14638be168c0dSopenharmony_ci+ std::to_string(conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_); 14639be168c0dSopenharmony_ci+ if (pre_trans_input_) { 14640be168c0dSopenharmony_ci+ code.CodeStruct("conv_parameter", *conv_param_, dynamic_param_); 14641be168c0dSopenharmony_ci+ code.CodeFunction("Conv1x1InputPack", batch_in, tmp_input_str, "&conv_parameter", DataTypeSize(data_type_)); 
14642be168c0dSopenharmony_ci+ } else { 14643be168c0dSopenharmony_ci+ tmp_input_str = batch_in; 14644be168c0dSopenharmony_ci+ } 14645be168c0dSopenharmony_ci+ 14646be168c0dSopenharmony_ci+ if (output_tensor_->format() == NC4HW4) { 14647be168c0dSopenharmony_ci+ code.CodeFunction(target_ == kARM64 ? "RowMajor2Col16MajorFp16Opt" : "RowMajor2Col12MajorFp16Opt", tmp_input_str, 14648be168c0dSopenharmony_ci+ "(float16_t *)(" + pack_input_str_ + ")", matmul_param_->row_, matmul_param_->deep_); 14649be168c0dSopenharmony_ci+ } else { 14650be168c0dSopenharmony_ci+ code.CodeFunction("RowMajor2Col12MajorFp16Opt", tmp_input_str, "(float16_t *)(" + pack_input_str_ + ")", 14651be168c0dSopenharmony_ci+ matmul_param_->row_, matmul_param_->deep_); 14652be168c0dSopenharmony_ci+ } 14653be168c0dSopenharmony_ci+ 14654be168c0dSopenharmony_ci+ if (output_tensor_->format() == NC4HW4) { 14655be168c0dSopenharmony_ci+ code.CodeStruct("matmul_param", *matmul_param_); 14656be168c0dSopenharmony_ci+ code.CodeFunction("Conv1x1OutNc8hw8MultiThreadByWeightFp16", tmp_input_str, 14657be168c0dSopenharmony_ci+ "(float16_t *)(" + pack_input_str_ + ")", packed_weight_str, bias_data_, output_ptr_, 14658be168c0dSopenharmony_ci+ kDefaultTaskId, "&matmul_param"); 14659be168c0dSopenharmony_ci+ } else { 14660be168c0dSopenharmony_ci+ code.CodeFunction(target_ == kARM64 ? 
"MatMul12x16Fp16Opt" : "MatMul12x8A32Fp16", 14661be168c0dSopenharmony_ci+ "(float16_t *)(" + pack_input_str_ + ")", packed_weight_str, output_ptr_, bias_data_, 14662be168c0dSopenharmony_ci+ matmul_param_->act_type_, matmul_param_->deep_, matmul_param_->row_, matmul_param_->col_, 14663be168c0dSopenharmony_ci+ matmul_param_->col_, OutType_Nhwc); 14664be168c0dSopenharmony_ci+ } 14665be168c0dSopenharmony_ci+ code << " }\n"; 14666be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 14667be168c0dSopenharmony_ci+ return RET_OK; 14668be168c0dSopenharmony_ci+} 14669be168c0dSopenharmony_ci+ 14670be168c0dSopenharmony_ci+Convolution1x1DynamicFP16Coder::~Convolution1x1DynamicFP16Coder() { 14671be168c0dSopenharmony_ci+ FreeTmpBuffer(); 14672be168c0dSopenharmony_ci+ if (matmul_param_ != nullptr) { 14673be168c0dSopenharmony_ci+ delete matmul_param_; 14674be168c0dSopenharmony_ci+ matmul_param_ = nullptr; 14675be168c0dSopenharmony_ci+ } 14676be168c0dSopenharmony_ci+ return; 14677be168c0dSopenharmony_ci+} 14678be168c0dSopenharmony_ci+ 14679be168c0dSopenharmony_ci+void Convolution1x1DynamicFP16Coder::FreeTmpBuffer() { 14680be168c0dSopenharmony_ci+ if (pre_trans_input_ && tmp_input_ != nullptr) { 14681be168c0dSopenharmony_ci+ free(tmp_input_); 14682be168c0dSopenharmony_ci+ tmp_input_ = nullptr; 14683be168c0dSopenharmony_ci+ } 14684be168c0dSopenharmony_ci+ return; 14685be168c0dSopenharmony_ci+} 14686be168c0dSopenharmony_ci+ 14687be168c0dSopenharmony_ci+int Convolution1x1DynamicFP16Coder::ComputeWorkspace() { 14688be168c0dSopenharmony_ci+ pack_input_size_ = matmul_param_->row_align_ * matmul_param_->deep_ * DataTypeSize(data_type_); 14689be168c0dSopenharmony_ci+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 14690be168c0dSopenharmony_ci+ size_t scene_num = 0; 14691be168c0dSopenharmony_ci+ for (auto &dim_template : input_shape) { 14692be168c0dSopenharmony_ci+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 14693be168c0dSopenharmony_ci+ 
MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 14694be168c0dSopenharmony_ci+ scene_num = std::max(scene_num, dim_nums.size()); 14695be168c0dSopenharmony_ci+ } 14696be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 14697be168c0dSopenharmony_ci+ pack_input_str_ = dynamic_mem_manager_->AllocWorkSpace(pack_input_size_, i); 14698be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!pack_input_str_.empty(), RET_ERROR, "Convolution cannot alloc workspace."); 14699be168c0dSopenharmony_ci+ } 14700be168c0dSopenharmony_ci+ return RET_OK; 14701be168c0dSopenharmony_ci+} 14702be168c0dSopenharmony_ci+ 14703be168c0dSopenharmony_ci+int Convolution1x1DynamicFP16Coder::InitMatmulParam() { 14704be168c0dSopenharmony_ci+ matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; 14705be168c0dSopenharmony_ci+ matmul_param_->col_ = conv_param_->output_channel_; 14706be168c0dSopenharmony_ci+ matmul_param_->deep_ = conv_param_->input_channel_; 14707be168c0dSopenharmony_ci+ matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_); 14708be168c0dSopenharmony_ci+ matmul_param_->col_align_ = UP_ROUND(matmul_param_->col_, col_tile_); 14709be168c0dSopenharmony_ci+ matmul_param_->act_type_ = conv_param_->act_type_; 14710be168c0dSopenharmony_ci+ return RET_OK; 14711be168c0dSopenharmony_ci+} 14712be168c0dSopenharmony_ci+ 14713be168c0dSopenharmony_ci+int Convolution1x1DynamicFP16Coder::InitWeightBias(CoderContext *const context) { 14714be168c0dSopenharmony_ci+ auto input_channel = filter_tensor_->Channel(); 14715be168c0dSopenharmony_ci+ auto output_channel = filter_tensor_->Batch(); 14716be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(input_channel > 0 && output_channel > 0, RET_ERROR); 14717be168c0dSopenharmony_ci+ pack_weight_size_ = input_channel * UP_ROUND(output_channel, col_tile_) * DataTypeSize(data_type_); 14718be168c0dSopenharmony_ci+ packed_weight_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 
14719be168c0dSopenharmony_ci+ MS_CHECK_PTR(packed_weight_); 14720be168c0dSopenharmony_ci+ 14721be168c0dSopenharmony_ci+ NNaclFp32Serializer init_code; 14722be168c0dSopenharmony_ci+ std::string ori_weight_addr = allocator_->GetRuntimeAddr(filter_tensor_); 14723be168c0dSopenharmony_ci+ size_t w_buf_size = 0; 14724be168c0dSopenharmony_ci+ w_buf_size += pack_weight_size_; 14725be168c0dSopenharmony_ci+ auto packed_weight_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(packed_weight_)); 14726be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(packed_weight_, context->weight_name(), context->weight_offset_name(), 14727be168c0dSopenharmony_ci+ context->weight_size_name(), pack_weight_size_); 14728be168c0dSopenharmony_ci+ if (target_ == kARM64 && output_tensor_->format() != NC4HW4) { 14729be168c0dSopenharmony_ci+ init_code.CodeFunction("RowMajor2Col16MajorFp16Opt", ori_weight_addr, packed_weight_str, output_channel, 14730be168c0dSopenharmony_ci+ input_channel); 14731be168c0dSopenharmony_ci+ } else { 14732be168c0dSopenharmony_ci+ init_code.CodeFunction("ColMajor2Row8MajorFp16", ori_weight_addr, packed_weight_str, input_channel, output_channel, 14733be168c0dSopenharmony_ci+ true); 14734be168c0dSopenharmony_ci+ } 14735be168c0dSopenharmony_ci+ bias_data_size_ = UP_ROUND(output_channel, col_tile_) * DataTypeSize(data_type_); 14736be168c0dSopenharmony_ci+ if (input_tensors_.size() == kInputSize2) { 14737be168c0dSopenharmony_ci+ bias_data_ = 14738be168c0dSopenharmony_ci+ allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); 14739be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_data_); 14740be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(bias_data_, context->weight_name(), context->weight_offset_name(), 14741be168c0dSopenharmony_ci+ context->weight_size_name(), bias_data_size_); 14742be168c0dSopenharmony_ci+ w_buf_size += bias_data_size_; 14743be168c0dSopenharmony_ci+ auto 
bias_data_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(bias_data_)); 14744be168c0dSopenharmony_ci+ std::string bias_tensor_str = allocator_->GetRuntimeAddr(bias_tensor_); 14745be168c0dSopenharmony_ci+ init_code.CodeFunction("memcpy", bias_data_str, bias_tensor_str, bias_tensor_->Size()); 14746be168c0dSopenharmony_ci+ } else { 14747be168c0dSopenharmony_ci+ bias_data_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); 14748be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_data_); 14749be168c0dSopenharmony_ci+ init_code.CodeFunction("memset", bias_data_, 0, bias_data_size_); 14750be168c0dSopenharmony_ci+ } 14751be168c0dSopenharmony_ci+ context->AppendInitWeightSizeCode(w_buf_size); 14752be168c0dSopenharmony_ci+ context->AppendInitCode(init_code.str()); 14753be168c0dSopenharmony_ci+ return RET_OK; 14754be168c0dSopenharmony_ci+} 14755be168c0dSopenharmony_ci+ 14756be168c0dSopenharmony_ci+int Convolution1x1DynamicFP16Coder::InitTmpBuffer(CoderContext *const context) { 14757be168c0dSopenharmony_ci+ NNaclFp32Serializer init_code; 14758be168c0dSopenharmony_ci+ pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || 14759be168c0dSopenharmony_ci+ conv_param_->stride_w_ != 1); 14760be168c0dSopenharmony_ci+ size_t w_size = 0; 14761be168c0dSopenharmony_ci+ if (pre_trans_input_) { 14762be168c0dSopenharmony_ci+ tmp_input_size_ = matmul_param_->row_ * matmul_param_->deep_ * DataTypeSize(data_type_); 14763be168c0dSopenharmony_ci+ tmp_input_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 14764be168c0dSopenharmony_ci+ MS_CHECK_PTR(tmp_input_); 14765be168c0dSopenharmony_ci+ w_size += tmp_input_size_; 14766be168c0dSopenharmony_ci+ auto tmp_input_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(tmp_input_)); 14767be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(tmp_input_, context->weight_name(), context->weight_offset_name(), 
14768be168c0dSopenharmony_ci+ context->weight_size_name(), tmp_input_size_); 14769be168c0dSopenharmony_ci+ init_code.CodeFunction("memset", tmp_input_, 0, tmp_input_size_); 14770be168c0dSopenharmony_ci+ } 14771be168c0dSopenharmony_ci+ context->AppendInitWeightSizeCode(w_size); 14772be168c0dSopenharmony_ci+ context->AppendInitCode(init_code.str()); 14773be168c0dSopenharmony_ci+ return RET_OK; 14774be168c0dSopenharmony_ci+} 14775be168c0dSopenharmony_ci+ 14776be168c0dSopenharmony_ci+void Convolution1x1DynamicFP16Coder::CollectFilesForFunc(CoderContext *const context) { 14777be168c0dSopenharmony_ci+ if (target_ == kARM64) { 14778be168c0dSopenharmony_ci+ Collect(context, {}, {}, 14779be168c0dSopenharmony_ci+ { 14780be168c0dSopenharmony_ci+ "MatmulFp16.S", 14781be168c0dSopenharmony_ci+ "MatmulFp16Opt.S", 14782be168c0dSopenharmony_ci+ "Matmul12X16Fp16.S", 14783be168c0dSopenharmony_ci+ }); 14784be168c0dSopenharmony_ci+ } else { 14785be168c0dSopenharmony_ci+ Collect(context, {}, {}, 14786be168c0dSopenharmony_ci+ { 14787be168c0dSopenharmony_ci+ "Matmul12x8Fp16.S", 14788be168c0dSopenharmony_ci+ }); 14789be168c0dSopenharmony_ci+ } 14790be168c0dSopenharmony_ci+ Collect(context, 14791be168c0dSopenharmony_ci+ { 14792be168c0dSopenharmony_ci+ "nnacl/fp16/matmul_fp16.h", 14793be168c0dSopenharmony_ci+ "nnacl/conv_parameter.h", 14794be168c0dSopenharmony_ci+ "nnacl/op_base.h", 14795be168c0dSopenharmony_ci+ "nnacl/fp16/conv_fp16.h", 14796be168c0dSopenharmony_ci+ "nnacl/base/conv1x1_base.h", 14797be168c0dSopenharmony_ci+ }, 14798be168c0dSopenharmony_ci+ { 14799be168c0dSopenharmony_ci+ "common_func.c", 14800be168c0dSopenharmony_ci+ "matmul_fp16.c", 14801be168c0dSopenharmony_ci+ "conv_fp16.c", 14802be168c0dSopenharmony_ci+ "conv1x1_base.c", 14803be168c0dSopenharmony_ci+ }); 14804be168c0dSopenharmony_ci+} 14805be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 14806be168c0dSopenharmony_cidiff --git 
a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h 14807be168c0dSopenharmony_cinew file mode 100644 14808be168c0dSopenharmony_ciindex 00000000..558eea53 14809be168c0dSopenharmony_ci--- /dev/null 14810be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_1x1_dynamic_fp16_coder.h 14811be168c0dSopenharmony_ci@@ -0,0 +1,68 @@ 14812be168c0dSopenharmony_ci+/** 14813be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 14814be168c0dSopenharmony_ci+ * 14815be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 14816be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 14817be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 14818be168c0dSopenharmony_ci+ * 14819be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 14820be168c0dSopenharmony_ci+ * 14821be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 14822be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 14823be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14824be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 14825be168c0dSopenharmony_ci+ * limitations under the License. 
14826be168c0dSopenharmony_ci+ */ 14827be168c0dSopenharmony_ci+ 14828be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_1X1_DYNAMIC_FP16_CODER_H_ 14829be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_1X1_DYNAMIC_FP16_CODER_H_ 14830be168c0dSopenharmony_ci+ 14831be168c0dSopenharmony_ci+#include <vector> 14832be168c0dSopenharmony_ci+#include <string> 14833be168c0dSopenharmony_ci+#include "nnacl/conv_parameter.h" 14834be168c0dSopenharmony_ci+#include "nnacl/matmul_parameter.h" 14835be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 14836be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 14837be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" 14838be168c0dSopenharmony_ci+#include "base/float16.h" 14839be168c0dSopenharmony_ci+ 14840be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 14841be168c0dSopenharmony_ci+class Convolution1x1DynamicFP16Coder final : public OperatorCoder { 14842be168c0dSopenharmony_ci+ public: 14843be168c0dSopenharmony_ci+ Convolution1x1DynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 14844be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 14845be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 14846be168c0dSopenharmony_ci+ ~Convolution1x1DynamicFP16Coder() override; 14847be168c0dSopenharmony_ci+ 14848be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 14849be168c0dSopenharmony_ci+ 14850be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 14851be168c0dSopenharmony_ci+ 14852be168c0dSopenharmony_ci+ private: 14853be168c0dSopenharmony_ci+ void CollectFilesForFunc(CoderContext *const context); 14854be168c0dSopenharmony_ci+ int 
InitWeightBias(CoderContext *const context); 14855be168c0dSopenharmony_ci+ int InitMatmulParam(); 14856be168c0dSopenharmony_ci+ int InitTmpBuffer(CoderContext *const context); 14857be168c0dSopenharmony_ci+ void FreeTmpBuffer(); 14858be168c0dSopenharmony_ci+ int ComputeWorkspace(); 14859be168c0dSopenharmony_ci+ MatMulParameter *matmul_param_{nullptr}; 14860be168c0dSopenharmony_ci+ ConvParameter *conv_param_{nullptr}; 14861be168c0dSopenharmony_ci+ ConvDynamicParameter dynamic_param_; 14862be168c0dSopenharmony_ci+ Tensor *filter_tensor_{nullptr}; 14863be168c0dSopenharmony_ci+ Tensor *bias_tensor_{nullptr}; 14864be168c0dSopenharmony_ci+ int row_tile_{C12NUM}; 14865be168c0dSopenharmony_ci+ int col_tile_{C8NUM}; 14866be168c0dSopenharmony_ci+ void *packed_weight_{nullptr}; 14867be168c0dSopenharmony_ci+ void *bias_data_{nullptr}; 14868be168c0dSopenharmony_ci+ std::string pack_input_str_; 14869be168c0dSopenharmony_ci+ void *tmp_input_{nullptr}; 14870be168c0dSopenharmony_ci+ size_t pack_weight_size_{0}; 14871be168c0dSopenharmony_ci+ size_t bias_data_size_{0}; 14872be168c0dSopenharmony_ci+ size_t tmp_input_size_{0}; 14873be168c0dSopenharmony_ci+ size_t pack_input_size_{0}; 14874be168c0dSopenharmony_ci+ bool pre_trans_input_{false}; 14875be168c0dSopenharmony_ci+ std::string output_ptr_; 14876be168c0dSopenharmony_ci+ TypeId data_type_ = kNumberTypeFloat16; 14877be168c0dSopenharmony_ci+}; 14878be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 14879be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_1X1_DYNAMIC_FP16_CODER_H_ 14880be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc 14881be168c0dSopenharmony_cinew file mode 100644 14882be168c0dSopenharmony_ciindex 00000000..c917b89a 14883be168c0dSopenharmony_ci--- /dev/null 
14884be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.cc 14885be168c0dSopenharmony_ci@@ -0,0 +1,172 @@ 14886be168c0dSopenharmony_ci+/** 14887be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 14888be168c0dSopenharmony_ci+ * 14889be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 14890be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 14891be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 14892be168c0dSopenharmony_ci+ * 14893be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 14894be168c0dSopenharmony_ci+ * 14895be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 14896be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 14897be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14898be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 14899be168c0dSopenharmony_ci+ * limitations under the License. 
14900be168c0dSopenharmony_ci+ */ 14901be168c0dSopenharmony_ci+ 14902be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h" 14903be168c0dSopenharmony_ci+#include <string> 14904be168c0dSopenharmony_ci+#include <vector> 14905be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.h" 14906be168c0dSopenharmony_ci+#include "nnacl/fp32/winograd_utils.h" 14907be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 14908be168c0dSopenharmony_ci+#include "coder/log.h" 14909be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 14910be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 14911be168c0dSopenharmony_ci+#include "base/float16.h" 14912be168c0dSopenharmony_ci+ 14913be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Conv2DFusion; 14914be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 14915be168c0dSopenharmony_ci+int ConvolutionDynamicFP16Coder::Prepare(CoderContext *const context) { 14916be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(input_tensors_.size(), C2NUM); 14917be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(output_tensors_.size(), 1); 14918be168c0dSopenharmony_ci+ if (target_ == kARM64) { 14919be168c0dSopenharmony_ci+ row_tile_ = C16NUM; 14920be168c0dSopenharmony_ci+ } 14921be168c0dSopenharmony_ci+ conv_param_ = reinterpret_cast<ConvParameter *>(parameter_); 14922be168c0dSopenharmony_ci+ MS_CHECK_PTR(conv_param_); 14923be168c0dSopenharmony_ci+ dynamic_param_.input_batch_ = shape_info_container_->GetTemplateShape(input_tensor_)[0]; 14924be168c0dSopenharmony_ci+ conv_param_->input_h_ = input_tensor_->Height(); 14925be168c0dSopenharmony_ci+ conv_param_->input_w_ = input_tensor_->Width(); 14926be168c0dSopenharmony_ci+ conv_param_->input_channel_ = input_tensor_->Channel(); 14927be168c0dSopenharmony_ci+ dynamic_param_.output_batch_ = shape_info_container_->GetTemplateShape(output_tensor_)[0]; 14928be168c0dSopenharmony_ci+ conv_param_->output_h_ = 
output_tensor_->Height(); 14929be168c0dSopenharmony_ci+ conv_param_->output_w_ = output_tensor_->Width(); 14930be168c0dSopenharmony_ci+ conv_param_->output_channel_ = output_tensor_->Channel(); 14931be168c0dSopenharmony_ci+ conv_param_->thread_num_ = 1; 14932be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(InitWeightBias(context), "Init weight bias failed."); 14933be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(InitTmpBuffer(), "Init tmp buffer failed."); 14934be168c0dSopenharmony_ci+ return RET_OK; 14935be168c0dSopenharmony_ci+} 14936be168c0dSopenharmony_ci+ 14937be168c0dSopenharmony_ci+int ConvolutionDynamicFP16Coder::InitTmpBuffer() { 14938be168c0dSopenharmony_ci+ int uint_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * 14939be168c0dSopenharmony_ci+ conv_param_->thread_num_; 14940be168c0dSopenharmony_ci+ packed_input_size_ = uint_size * DataTypeSize(data_type_); 14941be168c0dSopenharmony_ci+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 14942be168c0dSopenharmony_ci+ size_t scene_num = 0; 14943be168c0dSopenharmony_ci+ for (auto &dim_template : input_shape) { 14944be168c0dSopenharmony_ci+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 14945be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 14946be168c0dSopenharmony_ci+ scene_num = std::max(scene_num, dim_nums.size()); 14947be168c0dSopenharmony_ci+ } 14948be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 14949be168c0dSopenharmony_ci+ packed_input_str_ = dynamic_mem_manager_->AllocWorkSpace(packed_input_size_ * 2, i); 14950be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!packed_input_str_.empty(), RET_ERROR, "Convolution cannot alloc workspace."); 14951be168c0dSopenharmony_ci+ } 14952be168c0dSopenharmony_ci+ col_major_input_str_ = packed_input_str_ + " + " + std::to_string(packed_input_size_); 14953be168c0dSopenharmony_ci+ return RET_OK; 
14954be168c0dSopenharmony_ci+} 14955be168c0dSopenharmony_ci+ 14956be168c0dSopenharmony_ci+int ConvolutionDynamicFP16Coder::InitWeightBias(CoderContext *const context) { 14957be168c0dSopenharmony_ci+ filter_tensor_ = input_tensors_.at(kWeightIndex); 14958be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(filter_tensor_); 14959be168c0dSopenharmony_ci+ auto shape = filter_tensor_->shape(); 14960be168c0dSopenharmony_ci+ if (std::find(shape.begin(), shape.end(), -1) != shape.end()) { 14961be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "The shape of weight tensor is not ready, the weight and bias would be inited in runtime."; 14962be168c0dSopenharmony_ci+ return RET_OK; 14963be168c0dSopenharmony_ci+ } 14964be168c0dSopenharmony_ci+ int in_channel = filter_tensor_->Channel(); 14965be168c0dSopenharmony_ci+ int out_channel = filter_tensor_->Batch(); 14966be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(in_channel > 0 && out_channel > 0, RET_ERROR); 14967be168c0dSopenharmony_ci+ conv_param_->input_channel_ = in_channel; 14968be168c0dSopenharmony_ci+ conv_param_->output_channel_ = out_channel; 14969be168c0dSopenharmony_ci+ int oc8 = UP_ROUND(out_channel, col_tile_); 14970be168c0dSopenharmony_ci+ int kernel_plane = filter_tensor_->Height() * filter_tensor_->Width(); 14971be168c0dSopenharmony_ci+ pack_weight_size_ = oc8 * in_channel * kernel_plane * DataTypeSize(data_type_); 14972be168c0dSopenharmony_ci+ // init weight 14973be168c0dSopenharmony_ci+ packed_weight_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 14974be168c0dSopenharmony_ci+ MS_CHECK_PTR(packed_weight_); 14975be168c0dSopenharmony_ci+ NNaclFp32Serializer init_code; 14976be168c0dSopenharmony_ci+ std::string ori_weight_addr = allocator_->GetRuntimeAddr(filter_tensor_); 14977be168c0dSopenharmony_ci+ size_t w_buf_size = 0; 14978be168c0dSopenharmony_ci+ w_buf_size += pack_weight_size_; 14979be168c0dSopenharmony_ci+ auto packed_weight_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(packed_weight_)); 
14980be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(packed_weight_, context->weight_name(), context->weight_offset_name(), 14981be168c0dSopenharmony_ci+ context->weight_size_name(), pack_weight_size_); 14982be168c0dSopenharmony_ci+ init_code.CodeFunction("RowMajor2Col8MajorFp16", ori_weight_addr, packed_weight_str, out_channel, 14983be168c0dSopenharmony_ci+ in_channel * kernel_plane, false); 14984be168c0dSopenharmony_ci+ if (input_tensors_.size() == C3NUM) { 14985be168c0dSopenharmony_ci+ bias_tensor_ = input_tensors_.at(kBiasIndex); 14986be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_tensor_); 14987be168c0dSopenharmony_ci+ bias_data_ = 14988be168c0dSopenharmony_ci+ allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); 14989be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_data_); 14990be168c0dSopenharmony_ci+ } else { 14991be168c0dSopenharmony_ci+ bias_data_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); 14992be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_data_); 14993be168c0dSopenharmony_ci+ } 14994be168c0dSopenharmony_ci+ auto bias_data_size = static_cast<size_t>(oc8 * DataTypeSize(data_type_)); 14995be168c0dSopenharmony_ci+ w_buf_size += bias_data_size; 14996be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(bias_data_, context->weight_name(), context->weight_offset_name(), 14997be168c0dSopenharmony_ci+ context->weight_size_name(), bias_data_size); 14998be168c0dSopenharmony_ci+ bias_data_str_ = allocator_->GetRuntimeAddr(bias_data_); 14999be168c0dSopenharmony_ci+ if (input_tensors_.size() == C3NUM) { 15000be168c0dSopenharmony_ci+ auto origin_bias_str = allocator_->GetRuntimeAddr(bias_tensor_); 15001be168c0dSopenharmony_ci+ init_code.CodeFunction("memcpy", bias_data_str_, origin_bias_str, bias_tensor_->Size()); 15002be168c0dSopenharmony_ci+ } else { 15003be168c0dSopenharmony_ci+ init_code.CodeFunction("memset", bias_data_str_, 0, 
bias_data_size); 15004be168c0dSopenharmony_ci+ } 15005be168c0dSopenharmony_ci+ context->AppendInitWeightSizeCode(w_buf_size); 15006be168c0dSopenharmony_ci+ context->AppendInitCode(init_code.str()); 15007be168c0dSopenharmony_ci+ return RET_OK; 15008be168c0dSopenharmony_ci+} 15009be168c0dSopenharmony_ci+ 15010be168c0dSopenharmony_ci+void ConvolutionDynamicFP16Coder::CollectFilesForFunc(CoderContext *const context) { 15011be168c0dSopenharmony_ci+ Collect(context, {}, {}, 15012be168c0dSopenharmony_ci+ { 15013be168c0dSopenharmony_ci+ "MatmulFp16.S", 15014be168c0dSopenharmony_ci+ "MatmulFp16Opt.S", 15015be168c0dSopenharmony_ci+ "MatVecMulFp16.S", 15016be168c0dSopenharmony_ci+ "Matmul12X16Fp16.S", 15017be168c0dSopenharmony_ci+ }); 15018be168c0dSopenharmony_ci+ Collect(context, 15019be168c0dSopenharmony_ci+ { 15020be168c0dSopenharmony_ci+ "nnacl/fp16/matmul_fp16.h", 15021be168c0dSopenharmony_ci+ "nnacl/conv_parameter.h", 15022be168c0dSopenharmony_ci+ "nnacl/op_base.h", 15023be168c0dSopenharmony_ci+ "nnacl/fp16/conv_fp16.h", 15024be168c0dSopenharmony_ci+ }, 15025be168c0dSopenharmony_ci+ { 15026be168c0dSopenharmony_ci+ "common_func.c", 15027be168c0dSopenharmony_ci+ "matmul_fp16.c", 15028be168c0dSopenharmony_ci+ "pack_fp16.c", 15029be168c0dSopenharmony_ci+ "conv_fp16.c", 15030be168c0dSopenharmony_ci+ }); 15031be168c0dSopenharmony_ci+} 15032be168c0dSopenharmony_ci+ 15033be168c0dSopenharmony_ci+int ConvolutionDynamicFP16Coder::DoCode(CoderContext *const context) { 15034be168c0dSopenharmony_ci+ CollectFilesForFunc(context); 15035be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 15036be168c0dSopenharmony_ci+ // call the op function 15037be168c0dSopenharmony_ci+ auto packed_weight_str = allocator_->GetRuntimeAddr(static_cast<float16 *>(packed_weight_)); 15038be168c0dSopenharmony_ci+ auto input_str = 15039be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 15040be168c0dSopenharmony_ci+ auto 
output_str = 15041be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 15042be168c0dSopenharmony_ci+ // code.CodeFunction("memset", packed_input_str_, "0", packed_input_size_); 15043be168c0dSopenharmony_ci+ // code.CodeFunction("memset", col_major_input_str_, "0", packed_input_size_); 15044be168c0dSopenharmony_ci+ code.CodeStruct("conv_parameter", *conv_param_, dynamic_param_); 15045be168c0dSopenharmony_ci+ packed_input_str_ = "(float16_t *)(" + packed_input_str_ + ")"; 15046be168c0dSopenharmony_ci+ col_major_input_str_ = "(float16_t *)(" + col_major_input_str_ + ")"; 15047be168c0dSopenharmony_ci+ if (output_tensor_->format() == NC4HW4) { 15048be168c0dSopenharmony_ci+ code.CodeFunction("ConvOutNc8hw8Fp16", input_str, packed_input_str_, packed_weight_str, bias_data_str_, 15049be168c0dSopenharmony_ci+ col_major_input_str_, output_str, kDefaultTaskId, "&conv_parameter"); 15050be168c0dSopenharmony_ci+ } else { 15051be168c0dSopenharmony_ci+ code.CodeFunction("ConvFp16", input_str, packed_input_str_, packed_weight_str, bias_data_str_, col_major_input_str_, 15052be168c0dSopenharmony_ci+ output_str, kDefaultTaskId, "&conv_parameter"); 15053be168c0dSopenharmony_ci+ } 15054be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 15055be168c0dSopenharmony_ci+ return RET_OK; 15056be168c0dSopenharmony_ci+} 15057be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 15058be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h 15059be168c0dSopenharmony_cinew file mode 100644 15060be168c0dSopenharmony_ciindex 00000000..29d70796 15061be168c0dSopenharmony_ci--- /dev/null 15062be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/convolution_dynamic_fp16_coder.h 
15063be168c0dSopenharmony_ci@@ -0,0 +1,59 @@ 15064be168c0dSopenharmony_ci+/** 15065be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 15066be168c0dSopenharmony_ci+ * 15067be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 15068be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 15069be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 15070be168c0dSopenharmony_ci+ * 15071be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 15072be168c0dSopenharmony_ci+ * 15073be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 15074be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 15075be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15076be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 15077be168c0dSopenharmony_ci+ * limitations under the License. 
15078be168c0dSopenharmony_ci+ */ 15079be168c0dSopenharmony_ci+ 15080be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_DYNAMIC_FP16_CODER_H_ 15081be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_DYNAMIC_FP16_CODER_H_ 15082be168c0dSopenharmony_ci+ 15083be168c0dSopenharmony_ci+#include <vector> 15084be168c0dSopenharmony_ci+#include <string> 15085be168c0dSopenharmony_ci+#include "nnacl/conv_parameter.h" 15086be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 15087be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 15088be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" 15089be168c0dSopenharmony_ci+ 15090be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 15091be168c0dSopenharmony_ci+class ConvolutionDynamicFP16Coder final : public OperatorCoder { 15092be168c0dSopenharmony_ci+ public: 15093be168c0dSopenharmony_ci+ ConvolutionDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 15094be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 15095be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 15096be168c0dSopenharmony_ci+ 15097be168c0dSopenharmony_ci+ ~ConvolutionDynamicFP16Coder() override = default; 15098be168c0dSopenharmony_ci+ 15099be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 15100be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 15101be168c0dSopenharmony_ci+ 15102be168c0dSopenharmony_ci+ private: 15103be168c0dSopenharmony_ci+ void CollectFilesForFunc(CoderContext *const context); 15104be168c0dSopenharmony_ci+ int InitWeightBias(CoderContext *const context); 15105be168c0dSopenharmony_ci+ int InitTmpBuffer(); 15106be168c0dSopenharmony_ci+ ConvParameter 
*conv_param_{nullptr}; 15107be168c0dSopenharmony_ci+ ConvDynamicParameter dynamic_param_; 15108be168c0dSopenharmony_ci+ TypeId data_type_{kNumberTypeFloat16}; 15109be168c0dSopenharmony_ci+ int row_tile_{C12NUM}; 15110be168c0dSopenharmony_ci+ int col_tile_{C8NUM}; 15111be168c0dSopenharmony_ci+ Tensor *filter_tensor_{nullptr}; 15112be168c0dSopenharmony_ci+ Tensor *bias_tensor_{nullptr}; 15113be168c0dSopenharmony_ci+ size_t pack_weight_size_{0}; 15114be168c0dSopenharmony_ci+ size_t packed_input_size_{0}; 15115be168c0dSopenharmony_ci+ void *packed_weight_{nullptr}; 15116be168c0dSopenharmony_ci+ void *bias_data_{nullptr}; 15117be168c0dSopenharmony_ci+ std::string packed_input_str_; 15118be168c0dSopenharmony_ci+ std::string col_major_input_str_; 15119be168c0dSopenharmony_ci+ std::string bias_data_str_; 15120be168c0dSopenharmony_ci+}; 15121be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 15122be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_CONVOLUTION_DYNAMIC_FP16_CODER_H_ 15123be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc 15124be168c0dSopenharmony_cinew file mode 100644 15125be168c0dSopenharmony_ciindex 00000000..8c4cc31b 15126be168c0dSopenharmony_ci--- /dev/null 15127be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.cc 15128be168c0dSopenharmony_ci@@ -0,0 +1,366 @@ 15129be168c0dSopenharmony_ci+/** 15130be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 15131be168c0dSopenharmony_ci+ * 15132be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 15133be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
15134be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 15135be168c0dSopenharmony_ci+ * 15136be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 15137be168c0dSopenharmony_ci+ * 15138be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 15139be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 15140be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15141be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 15142be168c0dSopenharmony_ci+ * limitations under the License. 15143be168c0dSopenharmony_ci+ */ 15144be168c0dSopenharmony_ci+ 15145be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h" 15146be168c0dSopenharmony_ci+#include <cfloat> 15147be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 15148be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 15149be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 15150be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 15151be168c0dSopenharmony_ci+ 15152be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_LSTM; 15153be168c0dSopenharmony_ci+ 15154be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 15155be168c0dSopenharmony_ci+namespace { 15156be168c0dSopenharmony_ci+constexpr size_t kMindirInputTensorNum = 4; 15157be168c0dSopenharmony_ci+} // namespace 15158be168c0dSopenharmony_ci+ 15159be168c0dSopenharmony_ci+int LstmMindirDynamicFP16Coder::Prepare(CoderContext *const context) { 15160be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(context); 15161be168c0dSopenharmony_ci+ CHECK_NOT_EQUAL_RETURN(input_tensors_.size(), kMindirInputTensorNum); 15162be168c0dSopenharmony_ci+ for (auto in : input_tensors_) { 15163be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in != nullptr, RET_INPUT_TENSOR_ERROR, 
"LstmMindirDynamicFP16Coder input is a nullptr."); 15164be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in->data_type() == kNumberTypeFloat16, RET_INPUT_TENSOR_ERROR, 15165be168c0dSopenharmony_ci+ "LstmMindirDynamicFP16Coder input must be fp16."); 15166be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in->shape().size() == C3NUM, RET_INPUT_TENSOR_ERROR, 15167be168c0dSopenharmony_ci+ "LstmMindirDynamicFP16Coder input must be 3D."); 15168be168c0dSopenharmony_ci+ } 15169be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[FOURTH_INPUT]->IsConst(), RET_INPUT_TENSOR_ERROR, 15170be168c0dSopenharmony_ci+ "LstmMindirDynamicFP16Coder last three inputs must be all constant."); 15171be168c0dSopenharmony_ci+ lstm_param_ = reinterpret_cast<LstmParameter *>(parameter_); 15172be168c0dSopenharmony_ci+ return InitParam(); 15173be168c0dSopenharmony_ci+} 15174be168c0dSopenharmony_ci+ 15175be168c0dSopenharmony_ci+int LstmMindirDynamicFP16Coder::DoCode(CoderContext *const context) { 15176be168c0dSopenharmony_ci+ Collect(context, 15177be168c0dSopenharmony_ci+ { 15178be168c0dSopenharmony_ci+ "nnacl/lstm_parameter.h", 15179be168c0dSopenharmony_ci+ "nnacl/fp16/lstm_fp16.h", 15180be168c0dSopenharmony_ci+ }, 15181be168c0dSopenharmony_ci+ {"lstm_fp16.c", "activation_fp16.c", "arithmetic_fp16.c", "matmul_fp16.c", "pack_fp16.c"}, 15182be168c0dSopenharmony_ci+ {"MatmulBaseFp16Neon.S"}); 15183be168c0dSopenharmony_ci+ 15184be168c0dSopenharmony_ci+ auto ret = InitInputWeightBias(context); 15185be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Lstm InitInputWeightBias failed."); 15186be168c0dSopenharmony_ci+ ret = InitStateWeightBias(context); 15187be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Lstm InitStateWeightBias failed."); 15188be168c0dSopenharmony_ci+ ret = InitProjectWeight(context); 15189be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Lstm InitProjectWeight failed."); 15190be168c0dSopenharmony_ci+ ret = ComputeWorkSpace(); 
15191be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Lstm ComputeWorkSpace failed."); 15192be168c0dSopenharmony_ci+ CreateBufferAddrStr(); 15193be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 15194be168c0dSopenharmony_ci+ code << "float16_t *buffer[7] = {"; 15195be168c0dSopenharmony_ci+ for (const auto &buf : buffers_str_) { 15196be168c0dSopenharmony_ci+ code << "(float16_t *)(" << buf << "), "; 15197be168c0dSopenharmony_ci+ } 15198be168c0dSopenharmony_ci+ code << "};\n"; 15199be168c0dSopenharmony_ci+ 15200be168c0dSopenharmony_ci+ auto input1 = dynamic_mem_manager_->GetVarTensorAddr(input_tensors_[FIRST_INPUT]); 15201be168c0dSopenharmony_ci+ auto hidden_init = input_tensors_[SECOND_INPUT]->IsConst() 15202be168c0dSopenharmony_ci+ ? allocator_->GetRuntimeAddr(input_tensors_[SECOND_INPUT], true) 15203be168c0dSopenharmony_ci+ : dynamic_mem_manager_->GetVarTensorAddr(input_tensors_[SECOND_INPUT]); 15204be168c0dSopenharmony_ci+ auto cell_init = input_tensors_[THIRD_INPUT]->IsConst() 15205be168c0dSopenharmony_ci+ ? 
allocator_->GetRuntimeAddr(input_tensors_[THIRD_INPUT], true) 15206be168c0dSopenharmony_ci+ : dynamic_mem_manager_->GetVarTensorAddr(input_tensors_[THIRD_INPUT]); 15207be168c0dSopenharmony_ci+ auto output1 = dynamic_mem_manager_->GetVarTensorAddr(output_tensors_[FIRST_INPUT]); 15208be168c0dSopenharmony_ci+ auto hidden_output = dynamic_mem_manager_->GetVarTensorAddr(output_tensors_[SECOND_INPUT]); 15209be168c0dSopenharmony_ci+ auto cell_output = dynamic_mem_manager_->GetVarTensorAddr(output_tensors_[THIRD_INPUT]); 15210be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!input1.empty() && !hidden_init.empty() && !cell_init.empty() && !output1.empty() && 15211be168c0dSopenharmony_ci+ !hidden_output.empty() && !cell_output.empty(), 15212be168c0dSopenharmony_ci+ RET_ERROR, "Lstm cannot get addr."); 15213be168c0dSopenharmony_ci+ code.CodeStruct("lstm_param", *lstm_param_, dynamic_lstm_param_); 15214be168c0dSopenharmony_ci+ auto input_shape2 = shape_info_container_->GetTemplateShape(input_tensors_[SECOND_INPUT]); 15215be168c0dSopenharmony_ci+ int64_t const_part = 1; 15216be168c0dSopenharmony_ci+ std::string non_const_part; 15217be168c0dSopenharmony_ci+ for (const auto &item : input_shape2) { 15218be168c0dSopenharmony_ci+ if (IsNumber(item)) { 15219be168c0dSopenharmony_ci+ const_part *= std::stoi(item); 15220be168c0dSopenharmony_ci+ } else { 15221be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 15222be168c0dSopenharmony_ci+ non_const_part += " * "; 15223be168c0dSopenharmony_ci+ } 15224be168c0dSopenharmony_ci+ non_const_part += item; 15225be168c0dSopenharmony_ci+ } 15226be168c0dSopenharmony_ci+ } 15227be168c0dSopenharmony_ci+ code.CodeFunction("memcpy", hidden_output, hidden_init, 15228be168c0dSopenharmony_ci+ non_const_part + " * " + std::to_string(const_part * DataTypeSize(kNumberTypeFloat16))); 15229be168c0dSopenharmony_ci+ auto input_shape3 = shape_info_container_->GetTemplateShape(input_tensors_[THIRD_INPUT]); 15230be168c0dSopenharmony_ci+ const_part = 1; 
15231be168c0dSopenharmony_ci+ non_const_part = ""; 15232be168c0dSopenharmony_ci+ for (const auto &item : input_shape3) { 15233be168c0dSopenharmony_ci+ if (IsNumber(item)) { 15234be168c0dSopenharmony_ci+ const_part *= std::stoi(item); 15235be168c0dSopenharmony_ci+ } else { 15236be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 15237be168c0dSopenharmony_ci+ non_const_part += " * "; 15238be168c0dSopenharmony_ci+ } 15239be168c0dSopenharmony_ci+ non_const_part += item; 15240be168c0dSopenharmony_ci+ } 15241be168c0dSopenharmony_ci+ } 15242be168c0dSopenharmony_ci+ code.CodeFunction("memcpy", cell_output, cell_init, 15243be168c0dSopenharmony_ci+ non_const_part + " * " + std::to_string(const_part * DataTypeSize(kNumberTypeFloat16))); 15244be168c0dSopenharmony_ci+ auto weight_i_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(weight_i_ptr_)); 15245be168c0dSopenharmony_ci+ auto weight_h_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(weight_h_ptr_)); 15246be168c0dSopenharmony_ci+ auto weight_pro_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(weight_project_ptr_)); 15247be168c0dSopenharmony_ci+ auto input_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(input_bias_)); 15248be168c0dSopenharmony_ci+ auto state_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(hh_bias_)); 15249be168c0dSopenharmony_ci+ auto pro_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(static_cast<float16 *>(project_bias_)); 15250be168c0dSopenharmony_ci+ 15251be168c0dSopenharmony_ci+ code.CodeFunction("LstmFp16", "(float16_t *)(" + output1 + ")", "(float16_t *)(" + input1 + ")", weight_i_str, 15252be168c0dSopenharmony_ci+ weight_h_str, input_bias_str, state_bias_str, weight_pro_str, pro_bias_str, 15253be168c0dSopenharmony_ci+ "(float16_t *)(" + hidden_output + ")", "(float16_t *)(" + cell_output + ")", "buffer", 15254be168c0dSopenharmony_ci+ 
"&lstm_param"); 15255be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 15256be168c0dSopenharmony_ci+ return RET_OK; 15257be168c0dSopenharmony_ci+} 15258be168c0dSopenharmony_ci+ 15259be168c0dSopenharmony_ci+int LstmMindirDynamicFP16Coder::InitParam() { 15260be168c0dSopenharmony_ci+ auto in_shape1 = shape_info_container_->GetTemplateShape(input_tensors_[FIRST_INPUT]); 15261be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in_shape1.size() == C3NUM, RET_INPUT_TENSOR_ERROR, "LstmMindir first input's dim must be 3D."); 15262be168c0dSopenharmony_ci+ dynamic_lstm_param_.batch_ = in_shape1[1]; 15263be168c0dSopenharmony_ci+ dynamic_lstm_param_.seq_len_ = in_shape1[0]; 15264be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(IsNumber(in_shape1[C2NUM]), RET_NOT_SUPPORT, 15265be168c0dSopenharmony_ci+ "LstmMindir doesn't support input_size is dynamical in micro."); 15266be168c0dSopenharmony_ci+ lstm_param_->input_size_ = std::atoi(in_shape1[C2NUM].c_str()); 15267be168c0dSopenharmony_ci+ 15268be168c0dSopenharmony_ci+ auto h_init_shape = input_tensors_[SECOND_INPUT]->shape(); 15269be168c0dSopenharmony_ci+ auto c_init_shape = input_tensors_[THIRD_INPUT]->shape(); 15270be168c0dSopenharmony_ci+ lstm_param_->hidden_size_ = c_init_shape.back(); 15271be168c0dSopenharmony_ci+ lstm_param_->output_size_ = h_init_shape.back(); 15272be168c0dSopenharmony_ci+ 15273be168c0dSopenharmony_ci+ lstm_param_->output_step_ = lstm_param_->bidirectional_ ? C2NUM * lstm_param_->batch_ * lstm_param_->output_size_ 15274be168c0dSopenharmony_ci+ : lstm_param_->batch_ * lstm_param_->output_size_; 15275be168c0dSopenharmony_ci+ weight_segment_num_ = lstm_param_->bidirectional_ ? 
C8NUM : C4NUM; 15276be168c0dSopenharmony_ci+ dynamic_lstm_param_.input_row_align_ = 15277be168c0dSopenharmony_ci+ "(" + dynamic_lstm_param_.batch_ + " * " + dynamic_lstm_param_.seq_len_ + " + 3) / 4 * 4"; 15278be168c0dSopenharmony_ci+ lstm_param_->input_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C4NUM); 15279be168c0dSopenharmony_ci+ 15280be168c0dSopenharmony_ci+ dynamic_lstm_param_.state_row_align_ = "(" + dynamic_lstm_param_.batch_ + " + 3) / 4 * 4"; 15281be168c0dSopenharmony_ci+ lstm_param_->state_col_align_ = UP_ROUND(lstm_param_->hidden_size_, C4NUM); 15282be168c0dSopenharmony_ci+ lstm_param_->proj_col_align_ = UP_ROUND(lstm_param_->project_size_, C4NUM); 15283be168c0dSopenharmony_ci+ dynamic_lstm_param_.output_step_ = 15284be168c0dSopenharmony_ci+ std::to_string((lstm_param_->bidirectional_ ? C2NUM : C1NUM) * lstm_param_->output_size_) + " * " + 15285be168c0dSopenharmony_ci+ dynamic_lstm_param_.batch_; 15286be168c0dSopenharmony_ci+ size_t scale = lstm_param_->bidirectional_ ? 
C2NUM : C1NUM; 15287be168c0dSopenharmony_ci+ hi_size_ = scale * C4NUM * lstm_param_->hidden_size_ * lstm_param_->input_size_; 15288be168c0dSopenharmony_ci+ hh_size_ = scale * C4NUM * lstm_param_->hidden_size_ * lstm_param_->output_size_; 15289be168c0dSopenharmony_ci+ hp_size_ = scale * lstm_param_->project_size_ * lstm_param_->hidden_size_; 15290be168c0dSopenharmony_ci+ bias_size_ = scale * C8NUM * lstm_param_->hidden_size_; 15291be168c0dSopenharmony_ci+ auto real_whole_size = input_tensors_[FOURTH_INPUT]->ElementsNum(); 15292be168c0dSopenharmony_ci+ gpu_state_ = (hi_size_ + hh_size_ + hp_size_ + bias_size_) == static_cast<size_t>(real_whole_size); 15293be168c0dSopenharmony_ci+ if (gpu_state_) { 15294be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "LstmMindirDynamicFP16Coder doesn't suuport model which exported from GPU."; 15295be168c0dSopenharmony_ci+ return RET_NOT_SUPPORT; 15296be168c0dSopenharmony_ci+ } 15297be168c0dSopenharmony_ci+ if (hi_size_ + hh_size_ + hp_size_ == static_cast<size_t>(real_whole_size)) { 15298be168c0dSopenharmony_ci+ bias_size_ = 0; 15299be168c0dSopenharmony_ci+ return RET_OK; 15300be168c0dSopenharmony_ci+ } 15301be168c0dSopenharmony_ci+ bias_size_ /= C2NUM; 15302be168c0dSopenharmony_ci+ if ((hi_size_ + hh_size_ + hp_size_ + bias_size_) != static_cast<size_t>(real_whole_size)) { 15303be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Bias of LstmMindir exported from cpu only exist in hi-part."; 15304be168c0dSopenharmony_ci+ return RET_INPUT_TENSOR_ERROR; 15305be168c0dSopenharmony_ci+ } 15306be168c0dSopenharmony_ci+ return RET_OK; 15307be168c0dSopenharmony_ci+} 15308be168c0dSopenharmony_ci+ 15309be168c0dSopenharmony_ci+int LstmMindirDynamicFP16Coder::InitInputWeightBias(CoderContext *const context) { 15310be168c0dSopenharmony_ci+ NNaclFp32Serializer init_code; 15311be168c0dSopenharmony_ci+ 15312be168c0dSopenharmony_ci+ size_t weight_hi_size = 15313be168c0dSopenharmony_ci+ weight_segment_num_ * lstm_param_->input_col_align_ * lstm_param_->input_size_ * 
DataTypeSize(data_type_); 15314be168c0dSopenharmony_ci+ weight_i_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15315be168c0dSopenharmony_ci+ MS_CHECK_PTR(weight_i_ptr_); 15316be168c0dSopenharmony_ci+ 15317be168c0dSopenharmony_ci+ size_t w_buf_size = 0; 15318be168c0dSopenharmony_ci+ 15319be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(weight_i_ptr_, context->weight_name(), context->weight_offset_name(), 15320be168c0dSopenharmony_ci+ context->weight_size_name(), weight_hi_size); 15321be168c0dSopenharmony_ci+ w_buf_size += weight_hi_size; 15322be168c0dSopenharmony_ci+ auto weight_i_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[FOURTH_INPUT]); 15323be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!weight_i_str.empty(), RET_INPUT_TENSOR_ERROR, "Lstm cannot get weight."); 15324be168c0dSopenharmony_ci+ auto packed_weight_i_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(weight_i_ptr_)); 15325be168c0dSopenharmony_ci+ init_code << " int32_t order[4] = {0, 2, 3, 1};\n"; 15326be168c0dSopenharmony_ci+ init_code.CodeFunction("PackLstmWeightFp16", packed_weight_i_str, weight_i_str, weight_segment_num_, 15327be168c0dSopenharmony_ci+ lstm_param_->input_size_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, "order"); 15328be168c0dSopenharmony_ci+ 15329be168c0dSopenharmony_ci+ auto bias_stride = hi_size_ + hh_size_ + hp_size_; 15330be168c0dSopenharmony_ci+ input_bias_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15331be168c0dSopenharmony_ci+ MS_CHECK_PTR(input_bias_); 15332be168c0dSopenharmony_ci+ size_t bias_i_size = weight_segment_num_ * lstm_param_->input_col_align_ * DataTypeSize(data_type_); 15333be168c0dSopenharmony_ci+ w_buf_size += bias_i_size; 15334be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(input_bias_, context->weight_name(), context->weight_offset_name(), 15335be168c0dSopenharmony_ci+ context->weight_size_name(), bias_i_size); 
15336be168c0dSopenharmony_ci+ auto input_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(input_bias_)); 15337be168c0dSopenharmony_ci+ init_code.CodeFunction("memset", input_bias_str, 0, bias_i_size); 15338be168c0dSopenharmony_ci+ if (bias_size_ != 0) { 15339be168c0dSopenharmony_ci+ init_code.CodeFunction("PackLstmBiasFp16", input_bias_str, weight_i_str + " + " + std::to_string(bias_stride), 15340be168c0dSopenharmony_ci+ weight_segment_num_, lstm_param_->hidden_size_, lstm_param_->input_col_align_, 15341be168c0dSopenharmony_ci+ lstm_param_->bidirectional_, "order"); 15342be168c0dSopenharmony_ci+ } 15343be168c0dSopenharmony_ci+ 15344be168c0dSopenharmony_ci+ context->AppendInitWeightSizeCode(w_buf_size); 15345be168c0dSopenharmony_ci+ context->AppendInitCode(init_code.str()); 15346be168c0dSopenharmony_ci+ return RET_OK; 15347be168c0dSopenharmony_ci+} 15348be168c0dSopenharmony_ci+ 15349be168c0dSopenharmony_ci+int LstmMindirDynamicFP16Coder::InitStateWeightBias(CoderContext *const context) { 15350be168c0dSopenharmony_ci+ NNaclFp32Serializer init_code; 15351be168c0dSopenharmony_ci+ 15352be168c0dSopenharmony_ci+ size_t weight_hh_size = 15353be168c0dSopenharmony_ci+ weight_segment_num_ * lstm_param_->state_col_align_ * lstm_param_->project_size_ * DataTypeSize(data_type_); 15354be168c0dSopenharmony_ci+ weight_h_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15355be168c0dSopenharmony_ci+ MS_CHECK_PTR(weight_h_ptr_); 15356be168c0dSopenharmony_ci+ 15357be168c0dSopenharmony_ci+ size_t w_buf_size = 0; 15358be168c0dSopenharmony_ci+ 15359be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(weight_h_ptr_, context->weight_name(), context->weight_offset_name(), 15360be168c0dSopenharmony_ci+ context->weight_size_name(), weight_hh_size); 15361be168c0dSopenharmony_ci+ w_buf_size += weight_hh_size; 15362be168c0dSopenharmony_ci+ auto weight_hh_str = 
MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[FOURTH_INPUT]); 15363be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!weight_hh_str.empty(), RET_INPUT_TENSOR_ERROR, "Lstm cannot get weight."); 15364be168c0dSopenharmony_ci+ auto packed_weight_hh_str = 15365be168c0dSopenharmony_ci+ MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(weight_h_ptr_)); 15366be168c0dSopenharmony_ci+ init_code << " int32_t order[4] = {0, 2, 3, 1};\n"; 15367be168c0dSopenharmony_ci+ init_code.CodeFunction("PackLstmWeightFp16", packed_weight_hh_str, weight_hh_str + " + " + std::to_string(hi_size_), 15368be168c0dSopenharmony_ci+ weight_segment_num_, lstm_param_->project_size_, lstm_param_->hidden_size_, 15369be168c0dSopenharmony_ci+ lstm_param_->state_col_align_, "order"); 15370be168c0dSopenharmony_ci+ 15371be168c0dSopenharmony_ci+ hh_bias_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15372be168c0dSopenharmony_ci+ MS_CHECK_PTR(hh_bias_); 15373be168c0dSopenharmony_ci+ size_t bias_hh_size = weight_segment_num_ * lstm_param_->state_col_align_ * DataTypeSize(data_type_); 15374be168c0dSopenharmony_ci+ w_buf_size += bias_hh_size; 15375be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(hh_bias_, context->weight_name(), context->weight_offset_name(), 15376be168c0dSopenharmony_ci+ context->weight_size_name(), bias_hh_size); 15377be168c0dSopenharmony_ci+ auto hh_bias_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(hh_bias_)); 15378be168c0dSopenharmony_ci+ init_code.CodeFunction("memset", hh_bias_str, 0, bias_hh_size); 15379be168c0dSopenharmony_ci+ 15380be168c0dSopenharmony_ci+ context->AppendInitWeightSizeCode(w_buf_size); 15381be168c0dSopenharmony_ci+ context->AppendInitCode(init_code.str()); 15382be168c0dSopenharmony_ci+ return RET_OK; 15383be168c0dSopenharmony_ci+} 15384be168c0dSopenharmony_ci+ 15385be168c0dSopenharmony_ci+int LstmMindirDynamicFP16Coder::InitProjectWeight(CoderContext *const context) { 
15386be168c0dSopenharmony_ci+ if (hp_size_ == 0) { 15387be168c0dSopenharmony_ci+ return RET_OK; 15388be168c0dSopenharmony_ci+ } 15389be168c0dSopenharmony_ci+ 15390be168c0dSopenharmony_ci+ NNaclFp32Serializer init_code; 15391be168c0dSopenharmony_ci+ size_t w_buf_size = 0; 15392be168c0dSopenharmony_ci+ int scale = lstm_param_->bidirectional_ ? C2NUM : C1NUM; 15393be168c0dSopenharmony_ci+ int col_align = UP_ROUND(lstm_param_->project_size_, C8NUM); 15394be168c0dSopenharmony_ci+ size_t weight_pro_size = scale * lstm_param_->hidden_size_ * col_align * DataTypeSize(data_type_); 15395be168c0dSopenharmony_ci+ weight_project_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 15396be168c0dSopenharmony_ci+ MS_CHECK_PTR(weight_project_ptr_); 15397be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(weight_project_ptr_, context->weight_name(), context->weight_offset_name(), 15398be168c0dSopenharmony_ci+ context->weight_size_name(), weight_pro_size); 15399be168c0dSopenharmony_ci+ w_buf_size += weight_pro_size; 15400be168c0dSopenharmony_ci+ auto weight_hp_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[FOURTH_INPUT]); 15401be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!weight_hp_str.empty(), RET_INPUT_TENSOR_ERROR, "Lstm cannot get weight."); 15402be168c0dSopenharmony_ci+ auto weight_pro_str = 15403be168c0dSopenharmony_ci+ MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(weight_project_ptr_)); 15404be168c0dSopenharmony_ci+ init_code.CodeFunction("PackLstmWeightFp16", weight_pro_str, 15405be168c0dSopenharmony_ci+ weight_hp_str + " + " + std::to_string(hi_size_ + hh_size_), scale, lstm_param_->hidden_size_, 15406be168c0dSopenharmony_ci+ lstm_param_->project_size_, col_align, "NULL"); 15407be168c0dSopenharmony_ci+ 15408be168c0dSopenharmony_ci+ size_t bias_pro_size = col_align * DataTypeSize(data_type_); 15409be168c0dSopenharmony_ci+ project_bias_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight); 
15410be168c0dSopenharmony_ci+ MS_CHECK_PTR(project_bias_); 15411be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(project_bias_, context->weight_name(), context->weight_offset_name(), 15412be168c0dSopenharmony_ci+ context->weight_size_name(), bias_pro_size); 15413be168c0dSopenharmony_ci+ w_buf_size += bias_pro_size; 15414be168c0dSopenharmony_ci+ auto bias_pro_str = MemoryAllocator::GetInstance()->GetRuntimeAddr(reinterpret_cast<float16 *>(project_bias_)); 15415be168c0dSopenharmony_ci+ init_code.CodeFunction("memset", bias_pro_str, 0, bias_pro_size); 15416be168c0dSopenharmony_ci+ 15417be168c0dSopenharmony_ci+ context->AppendInitWeightSizeCode(w_buf_size); 15418be168c0dSopenharmony_ci+ context->AppendInitCode(init_code.str()); 15419be168c0dSopenharmony_ci+ return RET_OK; 15420be168c0dSopenharmony_ci+} 15421be168c0dSopenharmony_ci+ 15422be168c0dSopenharmony_ci+int LstmMindirDynamicFP16Coder::ComputeWorkSpace() { 15423be168c0dSopenharmony_ci+ auto in_shape1 = shape_info_container_->GetTemplateShape(input_tensors_[FIRST_INPUT]); 15424be168c0dSopenharmony_ci+ auto seq_lens = shape_info_container_->GetRealNums(in_shape1[0]); 15425be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!seq_lens.empty(), RET_ERROR, "Lstm cannot get seq_len"); 15426be168c0dSopenharmony_ci+ auto batches = shape_info_container_->GetRealNums(in_shape1[1]); 15427be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!batches.empty(), RET_ERROR, "Lstm cannot get batch"); 15428be168c0dSopenharmony_ci+ size_t scene_num = seq_lens.size() > batches.size() ? seq_lens.size() : batches.size(); 15429be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 15430be168c0dSopenharmony_ci+ int seq_len = seq_lens[i % seq_lens.size()]; 15431be168c0dSopenharmony_ci+ int batch = batches[i % batches.size()]; 15432be168c0dSopenharmony_ci+ size_t buffer1 = 15433be168c0dSopenharmony_ci+ seq_len * batch <= C3NUM ? 
0 : seq_len * batch * lstm_param_->input_size_ * DataTypeSize(data_type_); 15434be168c0dSopenharmony_ci+ size_t buffer2 = C4NUM * seq_len * batch * lstm_param_->hidden_size_ * DataTypeSize(data_type_); 15435be168c0dSopenharmony_ci+ size_t buffer3 = batch <= C3NUM ? 0 : batch * lstm_param_->output_size_ * DataTypeSize(data_type_); 15436be168c0dSopenharmony_ci+ size_t buffer4 = C4NUM * batch * lstm_param_->hidden_size_ * DataTypeSize(data_type_); 15437be168c0dSopenharmony_ci+ size_t buffer5 = (lstm_param_->zoneout_cell_ >= -FLT_EPSILON && lstm_param_->zoneout_cell_ <= FLT_EPSILON) 15438be168c0dSopenharmony_ci+ ? 0 15439be168c0dSopenharmony_ci+ : batch * lstm_param_->hidden_size_ * DataTypeSize(data_type_); 15440be168c0dSopenharmony_ci+ size_t buffer6 = (lstm_param_->zoneout_hidden_ >= -FLT_EPSILON && lstm_param_->zoneout_hidden_ <= FLT_EPSILON) 15441be168c0dSopenharmony_ci+ ? 0 15442be168c0dSopenharmony_ci+ : batch * lstm_param_->output_size_ * DataTypeSize(data_type_); 15443be168c0dSopenharmony_ci+ size_t buffer7 = (batch <= C3NUM || lstm_param_->project_size_ == 0) 15444be168c0dSopenharmony_ci+ ? 
0 15445be168c0dSopenharmony_ci+ : batch * lstm_param_->hidden_size_ * DataTypeSize(data_type_); 15446be168c0dSopenharmony_ci+ auto whole_size = buffer1 + buffer2 + buffer3 + buffer4 + buffer5 + buffer6 + buffer7; 15447be168c0dSopenharmony_ci+ buffers_start_ = dynamic_mem_manager_->AllocWorkSpace(whole_size, i); 15448be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!buffers_start_.empty(), RET_ERROR, "Lstm cannot alloc workspace."); 15449be168c0dSopenharmony_ci+ } 15450be168c0dSopenharmony_ci+ 15451be168c0dSopenharmony_ci+ return RET_OK; 15452be168c0dSopenharmony_ci+} 15453be168c0dSopenharmony_ci+ 15454be168c0dSopenharmony_ci+void LstmMindirDynamicFP16Coder::CreateBufferAddrStr() { 15455be168c0dSopenharmony_ci+ auto in_shape1 = shape_info_container_->GetTemplateShape(input_tensors_[FIRST_INPUT]); 15456be168c0dSopenharmony_ci+ auto seq_len = in_shape1[0]; 15457be168c0dSopenharmony_ci+ auto batch = in_shape1[1]; 15458be168c0dSopenharmony_ci+ auto input_row_align = "(" + seq_len + " * " + batch + " + 3) / 4 * 4"; 15459be168c0dSopenharmony_ci+ auto state_row_align = "(" + batch + " + 3) / 4 * 4"; 15460be168c0dSopenharmony_ci+ buffers_str_.push_back("(" + seq_len + " * " + batch + " <= 3) ? NULL : " + buffers_start_); 15461be168c0dSopenharmony_ci+ auto offset = "((" + seq_len + " * " + batch + " <= 3) ? 0 : (" + seq_len + " * " + batch + ") * " + 15462be168c0dSopenharmony_ci+ std::to_string(lstm_param_->input_size_ * DataTypeSize(data_type_)) + ")"; 15463be168c0dSopenharmony_ci+ buffers_str_.push_back(buffers_start_ + " + " + offset); 15464be168c0dSopenharmony_ci+ offset = "(" + offset + " + " + seq_len + " * " + batch + " * " + 15465be168c0dSopenharmony_ci+ std::to_string(C4NUM * lstm_param_->hidden_size_ * DataTypeSize(data_type_)) + ")"; 15466be168c0dSopenharmony_ci+ buffers_str_.push_back(batch + " <= 3 ? NULL : (" + buffers_start_ + " + " + offset + ")"); 15467be168c0dSopenharmony_ci+ offset = "(" + offset + " + (" + batch + " <= 3 ? 
0 : (" + batch + ") * " + 15468be168c0dSopenharmony_ci+ std::to_string(lstm_param_->output_size_ * DataTypeSize(data_type_)) + "))"; 15469be168c0dSopenharmony_ci+ buffers_str_.push_back(buffers_start_ + " + " + offset); 15470be168c0dSopenharmony_ci+ offset = "(" + offset + " + " + batch + " * " + 15471be168c0dSopenharmony_ci+ std::to_string(C4NUM * lstm_param_->hidden_size_ * DataTypeSize(data_type_)) + ")"; 15472be168c0dSopenharmony_ci+ if (lstm_param_->zoneout_cell_ < -FLT_EPSILON || lstm_param_->zoneout_cell_ > FLT_EPSILON) { 15473be168c0dSopenharmony_ci+ buffers_str_.push_back(buffers_start_ + " + " + offset); 15474be168c0dSopenharmony_ci+ offset = 15475be168c0dSopenharmony_ci+ "(" + offset + " + " + batch + " * " + std::to_string(lstm_param_->hidden_size_ * DataTypeSize(data_type_)) + ")"; 15476be168c0dSopenharmony_ci+ } else { 15477be168c0dSopenharmony_ci+ buffers_str_.emplace_back("NULL"); 15478be168c0dSopenharmony_ci+ } 15479be168c0dSopenharmony_ci+ if (lstm_param_->zoneout_hidden_ < -FLT_EPSILON && lstm_param_->zoneout_hidden_ > FLT_EPSILON) { 15480be168c0dSopenharmony_ci+ buffers_str_.push_back(buffers_start_ + " + " + offset); 15481be168c0dSopenharmony_ci+ offset = 15482be168c0dSopenharmony_ci+ "(" + offset + " + " + batch + " * " + std::to_string(lstm_param_->output_size_ * DataTypeSize(data_type_)) + ")"; 15483be168c0dSopenharmony_ci+ } else { 15484be168c0dSopenharmony_ci+ buffers_str_.emplace_back("NULL"); 15485be168c0dSopenharmony_ci+ } 15486be168c0dSopenharmony_ci+ if (lstm_param_->project_size_ == 0) { 15487be168c0dSopenharmony_ci+ buffers_str_.emplace_back("NULL"); 15488be168c0dSopenharmony_ci+ } else { 15489be168c0dSopenharmony_ci+ buffers_str_.emplace_back(batch + " <= 3 ? 
NULL : " + "(" + buffers_start_ + " + " + offset + ")"); 15490be168c0dSopenharmony_ci+ } 15491be168c0dSopenharmony_ci+} 15492be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LSTM, 15493be168c0dSopenharmony_ci+ CPUOpCoderCreator<LstmMindirDynamicFP16Coder>) 15494be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 15495be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h 15496be168c0dSopenharmony_cinew file mode 100644 15497be168c0dSopenharmony_ciindex 00000000..1084fa82 15498be168c0dSopenharmony_ci--- /dev/null 15499be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/lstm_mindir_dynamic_fp16_coder.h 15500be168c0dSopenharmony_ci@@ -0,0 +1,66 @@ 15501be168c0dSopenharmony_ci+/** 15502be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 15503be168c0dSopenharmony_ci+ * 15504be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 15505be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 15506be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 15507be168c0dSopenharmony_ci+ * 15508be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 15509be168c0dSopenharmony_ci+ * 15510be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 15511be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 15512be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15513be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 15514be168c0dSopenharmony_ci+ * limitations under the License. 
15515be168c0dSopenharmony_ci+ */ 15516be168c0dSopenharmony_ci+ 15517be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_LSTM_DYNAMIC_FP16_CODER_H 15518be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_LSTM_DYNAMIC_FP16_CODER_H 15519be168c0dSopenharmony_ci+ 15520be168c0dSopenharmony_ci+#include <vector> 15521be168c0dSopenharmony_ci+#include <string> 15522be168c0dSopenharmony_ci+#include "nnacl/lstm_parameter.h" 15523be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h" 15524be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 15525be168c0dSopenharmony_ci+ 15526be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 15527be168c0dSopenharmony_ci+ 15528be168c0dSopenharmony_ci+class LstmMindirDynamicFP16Coder : public OperatorCoder { 15529be168c0dSopenharmony_ci+ public: 15530be168c0dSopenharmony_ci+ LstmMindirDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 15531be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 15532be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 15533be168c0dSopenharmony_ci+ 15534be168c0dSopenharmony_ci+ ~LstmMindirDynamicFP16Coder() override = default; 15535be168c0dSopenharmony_ci+ 15536be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 15537be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 15538be168c0dSopenharmony_ci+ 15539be168c0dSopenharmony_ci+ private: 15540be168c0dSopenharmony_ci+ int InitParam(); 15541be168c0dSopenharmony_ci+ int ComputeWorkSpace(); 15542be168c0dSopenharmony_ci+ void CreateBufferAddrStr(); 15543be168c0dSopenharmony_ci+ int InitInputWeightBias(CoderContext *const context); 15544be168c0dSopenharmony_ci+ int InitStateWeightBias(CoderContext *const context); 15545be168c0dSopenharmony_ci+ int 
InitProjectWeight(CoderContext *const context); 15546be168c0dSopenharmony_ci+ bool gpu_state_{false}; 15547be168c0dSopenharmony_ci+ TypeId data_type_{kNumberTypeFloat16}; 15548be168c0dSopenharmony_ci+ int weight_segment_num_{0}; 15549be168c0dSopenharmony_ci+ size_t hi_size_{0}; 15550be168c0dSopenharmony_ci+ size_t hh_size_{0}; 15551be168c0dSopenharmony_ci+ size_t hp_size_{0}; 15552be168c0dSopenharmony_ci+ size_t bias_size_{0}; 15553be168c0dSopenharmony_ci+ void *weight_i_ptr_{nullptr}; 15554be168c0dSopenharmony_ci+ void *weight_h_ptr_{nullptr}; 15555be168c0dSopenharmony_ci+ void *weight_project_ptr_{nullptr}; 15556be168c0dSopenharmony_ci+ void *input_bias_{nullptr}; 15557be168c0dSopenharmony_ci+ void *hh_bias_{nullptr}; 15558be168c0dSopenharmony_ci+ void *project_bias_{nullptr}; 15559be168c0dSopenharmony_ci+ LstmParameter *lstm_param_{nullptr}; 15560be168c0dSopenharmony_ci+ DynamicLstmParameter dynamic_lstm_param_; 15561be168c0dSopenharmony_ci+ std::string buffers_start_; 15562be168c0dSopenharmony_ci+ std::vector<std::string> buffers_str_; 15563be168c0dSopenharmony_ci+}; 15564be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 15565be168c0dSopenharmony_ci+ 15566be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_LSTM_DYNAMIC_FP16_CODER_H 15567be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc 15568be168c0dSopenharmony_cinew file mode 100644 15569be168c0dSopenharmony_ciindex 00000000..f6c56f86 15570be168c0dSopenharmony_ci--- /dev/null 15571be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.cc 15572be168c0dSopenharmony_ci@@ -0,0 +1,228 @@ 15573be168c0dSopenharmony_ci+/** 15574be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 
15575be168c0dSopenharmony_ci+ * 15576be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 15577be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 15578be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 15579be168c0dSopenharmony_ci+ * 15580be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 15581be168c0dSopenharmony_ci+ * 15582be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 15583be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 15584be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15585be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 15586be168c0dSopenharmony_ci+ * limitations under the License. 15587be168c0dSopenharmony_ci+ */ 15588be168c0dSopenharmony_ci+ 15589be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h" 15590be168c0dSopenharmony_ci+#include <string> 15591be168c0dSopenharmony_ci+#include <vector> 15592be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/log.h" 15593be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/opcoders/file_collector.h" 15594be168c0dSopenharmony_ci+#include "base/float16.h" 15595be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 15596be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 15597be168c0dSopenharmony_ci+ 15598be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_MatMulFusion; 15599be168c0dSopenharmony_ci+ 15600be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 15601be168c0dSopenharmony_ci+int MatMulDynamicFP16BaseCoder::Prepare(CoderContext *const context) { 15602be168c0dSopenharmony_ci+ row_tile_ = C1NUM; 15603be168c0dSopenharmony_ci+ col_tile_ = C4NUM; 15604be168c0dSopenharmony_ci+ auto ret = 
InitAShape(); 15605be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "init A-metrics' info failed"); 15606be168c0dSopenharmony_ci+ ret = InitBShape(); 15607be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "init B-metrics' info failed"); 15608be168c0dSopenharmony_ci+ params_->col_align_ = UP_ROUND(params_->col_, col_tile_); 15609be168c0dSopenharmony_ci+ return RET_OK; 15610be168c0dSopenharmony_ci+} 15611be168c0dSopenharmony_ci+ 15612be168c0dSopenharmony_ci+int MatMulDynamicFP16BaseCoder::DoCode(CoderContext *const context) { 15613be168c0dSopenharmony_ci+ CollectFilesForTarget(context); 15614be168c0dSopenharmony_ci+ auto ret = InitMatrixB(context); 15615be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "InitMatrixB failed."); 15616be168c0dSopenharmony_ci+ ret = InitBiasData(context); 15617be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "InitBiasData failed."); 15618be168c0dSopenharmony_ci+ 15619be168c0dSopenharmony_ci+ ret = ComputeWorkSpace(); 15620be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Matmul alloc workspace failed."); 15621be168c0dSopenharmony_ci+ auto input_a_str = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 15622be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!input_a_str.empty(), RET_ERROR, "Matmul cannot get matrixA"); 15623be168c0dSopenharmony_ci+ auto output_str = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 15624be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!output_str.empty(), RET_ERROR, "Matmul cannot get output"); 15625be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 15626be168c0dSopenharmony_ci+ if (params_->a_transpose_) { 15627be168c0dSopenharmony_ci+ code << " if (" << dynamic_params_.row_ << " == 1) {\n"; 15628be168c0dSopenharmony_ci+ code << " if (" << dynamic_params_.batch_ << " <= 3) {\n"; 15629be168c0dSopenharmony_ci+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + input_a_str + ")", input_b_pack_str_, 15630be168c0dSopenharmony_ci+ 
"(float16_t *)(" + output_str + ")", bias_str_, params_->act_type_, params_->deep_, 15631be168c0dSopenharmony_ci+ dynamic_params_.batch_, params_->col_, params_->col_, OutType_Nhwc); 15632be168c0dSopenharmony_ci+ code << " } else {\n"; 15633be168c0dSopenharmony_ci+ code.CodeFunction("RowMajor2ColLadder12MajorFp16", "(float16_t *)(" + input_a_str + ")", 15634be168c0dSopenharmony_ci+ "(float16_t *)(" + buffer_start_ + ")", dynamic_params_.batch_, params_->deep_); 15635be168c0dSopenharmony_ci+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + buffer_start_ + ")", input_b_pack_str_, 15636be168c0dSopenharmony_ci+ "(float16_t *)(" + output_str + ")", bias_str_, params_->act_type_, params_->deep_, 15637be168c0dSopenharmony_ci+ dynamic_params_.batch_, params_->col_, params_->col_, OutType_Nhwc); 15638be168c0dSopenharmony_ci+ code << " } else {\n"; 15639be168c0dSopenharmony_ci+ code << " int in_stride = " << dynamic_params_.row_ << " * " << params_->deep_ << ";\n"; 15640be168c0dSopenharmony_ci+ code << " int out_stride = " << dynamic_params_.row_ << " * " << params_->col_ << ";\n"; 15641be168c0dSopenharmony_ci+ code << " for (int i = 0; i < " << dynamic_params_.batch_ << "; ++i) {\n"; 15642be168c0dSopenharmony_ci+ code.CodeFunction("RowMajor2RowLadder12MajorFp16", "(float16_t *)(" + input_a_str + ")" + " + in_stride * i", 15643be168c0dSopenharmony_ci+ "(float16_t *)(" + buffer_start_ + ")", params_->deep_, dynamic_params_.row_); 15644be168c0dSopenharmony_ci+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + buffer_start_ + ")", input_b_pack_str_, 15645be168c0dSopenharmony_ci+ "(float16_t *)(" + output_str + ")" + " + out_stride * i", bias_str_, params_->act_type_, 15646be168c0dSopenharmony_ci+ params_->deep_, dynamic_params_.row_, params_->col_, OutType_Nhwc); 15647be168c0dSopenharmony_ci+ code << " }\n"; 15648be168c0dSopenharmony_ci+ code << " }\n"; 15649be168c0dSopenharmony_ci+ } else { 15650be168c0dSopenharmony_ci+ code << " if (" << dynamic_params_.batch_ << 
" * " << dynamic_params_.row_ << " <= 3) {\n"; 15651be168c0dSopenharmony_ci+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + input_a_str + ")", input_b_pack_str_, 15652be168c0dSopenharmony_ci+ "(float16_t *)(" + output_str + ")", bias_str_, params_->act_type_, params_->deep_, 15653be168c0dSopenharmony_ci+ dynamic_params_.batch_ + " * " + dynamic_params_.row_, params_->col_, params_->col_, 15654be168c0dSopenharmony_ci+ OutType_Nhwc); 15655be168c0dSopenharmony_ci+ code << " } else {\n"; 15656be168c0dSopenharmony_ci+ code.CodeFunction("RowMajor2ColLadder12MajorFp16", "(float16_t *)(" + input_a_str + ")", 15657be168c0dSopenharmony_ci+ "(float16_t *)(" + buffer_start_ + ")", dynamic_params_.batch_ + " * " + dynamic_params_.row_, 15658be168c0dSopenharmony_ci+ params_->deep_); 15659be168c0dSopenharmony_ci+ code.CodeFunction("MatmulFp16OptV2", "(float16_t *)(" + buffer_start_ + ")", input_b_pack_str_, 15660be168c0dSopenharmony_ci+ "(float16_t *)(" + output_str + ")", bias_str_, params_->act_type_, params_->deep_, 15661be168c0dSopenharmony_ci+ dynamic_params_.batch_ + " * " + dynamic_params_.row_, params_->col_, params_->col_, 15662be168c0dSopenharmony_ci+ OutType_Nhwc); 15663be168c0dSopenharmony_ci+ } 15664be168c0dSopenharmony_ci+ code << " }\n"; 15665be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 15666be168c0dSopenharmony_ci+ return RET_OK; 15667be168c0dSopenharmony_ci+} 15668be168c0dSopenharmony_ci+ 15669be168c0dSopenharmony_ci+int MatMulDynamicFP16BaseCoder::InitMatrixB(CoderContext *const context) { 15670be168c0dSopenharmony_ci+ NNaclFp32Serializer init_code; 15671be168c0dSopenharmony_ci+ if (b_pack_ptr_ != nullptr) { 15672be168c0dSopenharmony_ci+ return RET_OK; 15673be168c0dSopenharmony_ci+ } 15674be168c0dSopenharmony_ci+ auto b_pack_ptr_size = static_cast<size_t>(params_->col_align_ * params_->deep_ * DataTypeSize(data_type_)); 15675be168c0dSopenharmony_ci+ b_pack_ptr_ = allocator_->GetSharedWeightAddr(filter_tensor_); 15676be168c0dSopenharmony_ci+ 
if (b_pack_ptr_ == nullptr) { 15677be168c0dSopenharmony_ci+ b_pack_ptr_ = allocator_->Malloc(data_type_, b_pack_ptr_size, kOnlinePackWeight, 15678be168c0dSopenharmony_ci+ filter_tensor_->tensor_name() + "_online_pack"); 15679be168c0dSopenharmony_ci+ allocator_->MarkSharedWeight(filter_tensor_, b_pack_ptr_); 15680be168c0dSopenharmony_ci+ } 15681be168c0dSopenharmony_ci+ MS_CHECK_PTR(b_pack_ptr_); 15682be168c0dSopenharmony_ci+ std::string input_b_str = allocator_->GetRuntimeAddr(filter_tensor_); 15683be168c0dSopenharmony_ci+ input_b_pack_str_ = allocator_->GetRuntimeAddr(static_cast<float16 *>(b_pack_ptr_)); 15684be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(b_pack_ptr_, context->weight_name(), context->weight_offset_name(), 15685be168c0dSopenharmony_ci+ context->weight_size_name(), b_pack_ptr_size); 15686be168c0dSopenharmony_ci+ if (b_batch_ == C1NUM) { 15687be168c0dSopenharmony_ci+ if (params_->b_transpose_) { 15688be168c0dSopenharmony_ci+ init_code.CodeFunction("RowMajor2ColNMajorFp16", input_b_str, input_b_pack_str_, params_->col_, params_->deep_, 15689be168c0dSopenharmony_ci+ "false"); 15690be168c0dSopenharmony_ci+ } else { 15691be168c0dSopenharmony_ci+ init_code.CodeFunction("RowMajor2RowNMajorFp16", input_b_str, input_b_pack_str_, params_->deep_, params_->col_, 15692be168c0dSopenharmony_ci+ "false"); 15693be168c0dSopenharmony_ci+ } 15694be168c0dSopenharmony_ci+ } else { 15695be168c0dSopenharmony_ci+ init_code << " for (int i = 0; i < " << b_batch_ << "; i++) {\n" 15696be168c0dSopenharmony_ci+ << " float16_t *src = " << input_b_str << " + i * " << params_->deep_ * params_->col_ << ";\n" 15697be168c0dSopenharmony_ci+ << " float16_t *dst = " << input_b_pack_str_ << " + i * " << params_->deep_ * params_->col_align_ 15698be168c0dSopenharmony_ci+ << ";\n"; 15699be168c0dSopenharmony_ci+ if (params_->b_transpose_) { 15700be168c0dSopenharmony_ci+ init_code << " RowMajor2ColNMajorFp16(src, dst, " << params_->col_ << ", " << params_->deep_ << ", 
false);\n"; 15701be168c0dSopenharmony_ci+ } else { 15702be168c0dSopenharmony_ci+ init_code << " RowMajor2RowNMajorFp16(src, dst, " << params_->deep_ << ", " << params_->col_ << ", false);\n"; 15703be168c0dSopenharmony_ci+ } 15704be168c0dSopenharmony_ci+ init_code << " }\n"; 15705be168c0dSopenharmony_ci+ } 15706be168c0dSopenharmony_ci+ context->AppendInitWeightSizeCode(b_pack_ptr_size); 15707be168c0dSopenharmony_ci+ context->AppendInitCode(init_code.str()); 15708be168c0dSopenharmony_ci+ return RET_OK; 15709be168c0dSopenharmony_ci+} 15710be168c0dSopenharmony_ci+ 15711be168c0dSopenharmony_ci+int MatMulDynamicFP16BaseCoder::InitBiasData(CoderContext *const context) { 15712be168c0dSopenharmony_ci+ NNaclFp32Serializer init_code; 15713be168c0dSopenharmony_ci+ if (bias_ptr_ != nullptr) { 15714be168c0dSopenharmony_ci+ return RET_OK; 15715be168c0dSopenharmony_ci+ } 15716be168c0dSopenharmony_ci+ auto bias_pack_ptr_size = static_cast<size_t>(params_->col_align_ * DataTypeSize(data_type_)); 15717be168c0dSopenharmony_ci+ if (input_tensors_.size() == C3NUM) { 15718be168c0dSopenharmony_ci+ bias_ptr_ = 15719be168c0dSopenharmony_ci+ allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, bias_tensor_->tensor_name() + "_online_pack"); 15720be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_ptr_); 15721be168c0dSopenharmony_ci+ } else { 15722be168c0dSopenharmony_ci+ bias_ptr_ = allocator_->Malloc(data_type_, kOnlineSize, kOnlinePackWeight, node_->name_ + "_bias_online_pack"); 15723be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_ptr_); 15724be168c0dSopenharmony_ci+ } 15725be168c0dSopenharmony_ci+ init_code.CodeBufferOffsetExpression(bias_ptr_, context->weight_name(), context->weight_offset_name(), 15726be168c0dSopenharmony_ci+ context->weight_size_name(), bias_pack_ptr_size); 15727be168c0dSopenharmony_ci+ bias_str_ = allocator_->GetRuntimeAddr(bias_ptr_); 15728be168c0dSopenharmony_ci+ if (input_tensors_.size() == DIMENSION_3D) { 15729be168c0dSopenharmony_ci+ auto origin_bias_str = 
allocator_->GetRuntimeAddr(bias_tensor_); 15730be168c0dSopenharmony_ci+ init_code.CodeFunction("memcpy", bias_str_, origin_bias_str, bias_tensor_->Size()); 15731be168c0dSopenharmony_ci+ } else { 15732be168c0dSopenharmony_ci+ init_code.CodeFunction("memset", bias_str_, 0, bias_pack_ptr_size); 15733be168c0dSopenharmony_ci+ } 15734be168c0dSopenharmony_ci+ context->AppendInitWeightSizeCode(bias_pack_ptr_size); 15735be168c0dSopenharmony_ci+ context->AppendInitCode(init_code.str()); 15736be168c0dSopenharmony_ci+ return RET_OK; 15737be168c0dSopenharmony_ci+} 15738be168c0dSopenharmony_ci+ 15739be168c0dSopenharmony_ci+int MatMulDynamicFP16BaseCoder::ComputeWorkSpace() { 15740be168c0dSopenharmony_ci+ auto a_shape = shape_info_container_->GetTemplateShape(input_tensor_); 15741be168c0dSopenharmony_ci+ std::map<std::string, std::vector<int>> real_nums; 15742be168c0dSopenharmony_ci+ size_t scene_num = 0; 15743be168c0dSopenharmony_ci+ for (auto &dim_template : a_shape) { 15744be168c0dSopenharmony_ci+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 15745be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 15746be168c0dSopenharmony_ci+ real_nums[dim_template] = dim_nums; 15747be168c0dSopenharmony_ci+ scene_num = std::max(scene_num, dim_nums.size()); 15748be168c0dSopenharmony_ci+ } 15749be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 15750be168c0dSopenharmony_ci+ std::vector<int> real_shape(a_shape.size()); 15751be168c0dSopenharmony_ci+ for (size_t j = 0; j < a_shape.size(); ++j) { 15752be168c0dSopenharmony_ci+ if (IsNumber(a_shape[j])) { 15753be168c0dSopenharmony_ci+ real_shape[j] = std::stoi(a_shape[j]); 15754be168c0dSopenharmony_ci+ } else { 15755be168c0dSopenharmony_ci+ real_shape[j] = real_nums[a_shape[j]][i % real_nums[a_shape[j]].size()]; 15756be168c0dSopenharmony_ci+ } 15757be168c0dSopenharmony_ci+ } 15758be168c0dSopenharmony_ci+ int a_batch = 1; 
15759be168c0dSopenharmony_ci+ for (size_t j = 0; j < a_shape.size() - C2NUM; ++j) { 15760be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(a_batch, real_shape[j], RET_ERROR); 15761be168c0dSopenharmony_ci+ a_batch *= real_shape[j]; 15762be168c0dSopenharmony_ci+ } 15763be168c0dSopenharmony_ci+ int row = params_->a_transpose_ ? real_shape.back() : real_shape[real_shape.size() - C2NUM]; 15764be168c0dSopenharmony_ci+ int deep = params_->a_transpose_ ? real_shape[real_shape.size() - C2NUM] : real_shape.back(); 15765be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(deep == params_->deep_, RET_INPUT_TENSOR_ERROR, 15766be168c0dSopenharmony_ci+ "Matmul's matrixA doesn't match matrixB, becase their deeps are not same."); 15767be168c0dSopenharmony_ci+ int workspace = 0; 15768be168c0dSopenharmony_ci+ if (params_->a_transpose_) { 15769be168c0dSopenharmony_ci+ workspace = (row == 1 ? (a_batch <= C3NUM ? 0 : UP_ROUND(a_batch, row_tile_)) : UP_ROUND(row, row_tile_)) * deep; 15770be168c0dSopenharmony_ci+ } else { 15771be168c0dSopenharmony_ci+ workspace = (a_batch * row <= C3NUM ? 
0 : UP_ROUND(a_batch * row, row_tile_)) * deep; 15772be168c0dSopenharmony_ci+ } 15773be168c0dSopenharmony_ci+ buffer_start_ = dynamic_mem_manager_->AllocWorkSpace(workspace, i); 15774be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!buffer_start_.empty(), RET_ERROR, "Matmul cannot alloc workspace."); 15775be168c0dSopenharmony_ci+ } 15776be168c0dSopenharmony_ci+ return RET_OK; 15777be168c0dSopenharmony_ci+} 15778be168c0dSopenharmony_ci+ 15779be168c0dSopenharmony_ci+int MatMulDynamicFP16BaseCoder::CollectFilesForTarget(CoderContext *const context) { 15780be168c0dSopenharmony_ci+ Collect(context, 15781be168c0dSopenharmony_ci+ { 15782be168c0dSopenharmony_ci+ "nnacl/fp16/pack_fp16.h", 15783be168c0dSopenharmony_ci+ "nnacl/fp16/matmul_fp16.h", 15784be168c0dSopenharmony_ci+ }, 15785be168c0dSopenharmony_ci+ { 15786be168c0dSopenharmony_ci+ "pack_fp16.c", 15787be168c0dSopenharmony_ci+ "matmul_fp16.c", 15788be168c0dSopenharmony_ci+ }); 15789be168c0dSopenharmony_ci+ if (target_ == kARM32) { 15790be168c0dSopenharmony_ci+ Collect(context, {}, {}, 15791be168c0dSopenharmony_ci+ { 15792be168c0dSopenharmony_ci+ "Matmul12x8Fp16.S", 15793be168c0dSopenharmony_ci+ "MatVecMulFp16.S", 15794be168c0dSopenharmony_ci+ }); 15795be168c0dSopenharmony_ci+ } else if (target_ == kARM64) { 15796be168c0dSopenharmony_ci+ Collect(context, {}, {}, {"MatmulFp16OptV2.S"}); 15797be168c0dSopenharmony_ci+ } 15798be168c0dSopenharmony_ci+ return RET_OK; 15799be168c0dSopenharmony_ci+} 15800be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 15801be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h 15802be168c0dSopenharmony_cinew file mode 100644 15803be168c0dSopenharmony_ciindex 00000000..f73cfff7 15804be168c0dSopenharmony_ci--- /dev/null 15805be168c0dSopenharmony_ci+++ 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h 15806be168c0dSopenharmony_ci@@ -0,0 +1,73 @@ 15807be168c0dSopenharmony_ci+/** 15808be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 15809be168c0dSopenharmony_ci+ * 15810be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 15811be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 15812be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 15813be168c0dSopenharmony_ci+ * 15814be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 15815be168c0dSopenharmony_ci+ * 15816be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 15817be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 15818be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15819be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 15820be168c0dSopenharmony_ci+ * limitations under the License. 
15821be168c0dSopenharmony_ci+ */ 15822be168c0dSopenharmony_ci+ 15823be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_FP16_BASE_CODER_H_ 15824be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_FP16_BASE_CODER_H_ 15825be168c0dSopenharmony_ci+ 15826be168c0dSopenharmony_ci+#include <vector> 15827be168c0dSopenharmony_ci+#include <string> 15828be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/opcoders/op_coder.h" 15829be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 15830be168c0dSopenharmony_ci+#include "nnacl/matmul_parameter.h" 15831be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/shape_info_container.h" 15832be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 15833be168c0dSopenharmony_ci+#include "base/float16.h" 15834be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/matmul_dynamic_parameter.h" 15835be168c0dSopenharmony_ci+ 15836be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 15837be168c0dSopenharmony_ci+class MatMulDynamicFP16BaseCoder : public OperatorCoder { 15838be168c0dSopenharmony_ci+ public: 15839be168c0dSopenharmony_ci+ MatMulDynamicFP16BaseCoder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 15840be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 15841be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 15842be168c0dSopenharmony_ci+ 15843be168c0dSopenharmony_ci+ ~MatMulDynamicFP16BaseCoder() override = default; 15844be168c0dSopenharmony_ci+ 15845be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 15846be168c0dSopenharmony_ci+ 15847be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 15848be168c0dSopenharmony_ci+ 
15849be168c0dSopenharmony_ci+ private: 15850be168c0dSopenharmony_ci+ int InitBiasData(CoderContext *const context); 15851be168c0dSopenharmony_ci+ int InitMatrixB(CoderContext *const context); 15852be168c0dSopenharmony_ci+ int CollectFilesForTarget(CoderContext *const context); 15853be168c0dSopenharmony_ci+ int ComputeWorkSpace(); 15854be168c0dSopenharmony_ci+ 15855be168c0dSopenharmony_ci+ protected: 15856be168c0dSopenharmony_ci+ virtual int InitAShape() = 0; 15857be168c0dSopenharmony_ci+ virtual int InitBShape() = 0; 15858be168c0dSopenharmony_ci+ 15859be168c0dSopenharmony_ci+ protected: 15860be168c0dSopenharmony_ci+ Tensor *filter_tensor_{nullptr}; 15861be168c0dSopenharmony_ci+ Tensor *bias_tensor_{nullptr}; 15862be168c0dSopenharmony_ci+ MatMulParameter *params_{nullptr}; 15863be168c0dSopenharmony_ci+ MatmulDynamicParameter dynamic_params_; 15864be168c0dSopenharmony_ci+ void *a_pack_ptr_ = nullptr; 15865be168c0dSopenharmony_ci+ void *b_pack_ptr_ = nullptr; 15866be168c0dSopenharmony_ci+ void *bias_ptr_{nullptr}; 15867be168c0dSopenharmony_ci+ int col_tile_{0}; 15868be168c0dSopenharmony_ci+ int row_tile_{0}; 15869be168c0dSopenharmony_ci+ size_t a_pack_ptr_size_{0}; 15870be168c0dSopenharmony_ci+ TypeId data_type_{kNumberTypeFloat16}; 15871be168c0dSopenharmony_ci+ int a_batch_; 15872be168c0dSopenharmony_ci+ int b_batch_; 15873be168c0dSopenharmony_ci+ std::string buffer_start_; 15874be168c0dSopenharmony_ci+ std::string bias_str_; 15875be168c0dSopenharmony_ci+ std::string input_a_pack_str_; 15876be168c0dSopenharmony_ci+ std::string input_b_pack_str_; 15877be168c0dSopenharmony_ci+}; 15878be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 15879be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_FP16_BASE_CODER_H_ 15880be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.cc 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.cc 15881be168c0dSopenharmony_cinew file mode 100644 15882be168c0dSopenharmony_ciindex 00000000..24cf7120 15883be168c0dSopenharmony_ci--- /dev/null 15884be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.cc 15885be168c0dSopenharmony_ci@@ -0,0 +1,100 @@ 15886be168c0dSopenharmony_ci+/** 15887be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 15888be168c0dSopenharmony_ci+ * 15889be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 15890be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 15891be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 15892be168c0dSopenharmony_ci+ * 15893be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 15894be168c0dSopenharmony_ci+ * 15895be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 15896be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 15897be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15898be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 15899be168c0dSopenharmony_ci+ * limitations under the License. 
15900be168c0dSopenharmony_ci+ */ 15901be168c0dSopenharmony_ci+ 15902be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h" 15903be168c0dSopenharmony_ci+#include <vector> 15904be168c0dSopenharmony_ci+#include "coder/log.h" 15905be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 15906be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 15907be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 15908be168c0dSopenharmony_ci+ 15909be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_MatMulFusion; 15910be168c0dSopenharmony_ci+ 15911be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 15912be168c0dSopenharmony_ci+int MatMulDynamicFP16Coder::InitAShape() { 15913be168c0dSopenharmony_ci+ auto a_shape = shape_info_container_->GetTemplateShape(input_tensor_); 15914be168c0dSopenharmony_ci+ auto a_shape_size = a_shape.size(); 15915be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(a_shape_size >= DIMENSION_2D, RET_NOT_SUPPORT, "Matmul's a_shape_size must be not less than two."); 15916be168c0dSopenharmony_ci+ int64_t const_part = 1; 15917be168c0dSopenharmony_ci+ std::string non_const_part; 15918be168c0dSopenharmony_ci+ for (size_t i = 0; i < a_shape_size - C2NUM; ++i) { 15919be168c0dSopenharmony_ci+ if (IsNumber(a_shape[i])) { 15920be168c0dSopenharmony_ci+ const_part *= std::atoi(a_shape[i].c_str()); 15921be168c0dSopenharmony_ci+ } else { 15922be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 15923be168c0dSopenharmony_ci+ non_const_part += " * "; 15924be168c0dSopenharmony_ci+ } 15925be168c0dSopenharmony_ci+ non_const_part += a_shape[i]; 15926be168c0dSopenharmony_ci+ } 15927be168c0dSopenharmony_ci+ } 15928be168c0dSopenharmony_ci+ dynamic_params_.batch_ = non_const_part + " * " + std::to_string(const_part); 15929be168c0dSopenharmony_ci+ dynamic_params_.row_ = params_->a_transpose_ ? 
a_shape[a_shape.size() - C1NUM] : a_shape[a_shape.size() - C2NUM]; 15930be168c0dSopenharmony_ci+ return RET_OK; 15931be168c0dSopenharmony_ci+} 15932be168c0dSopenharmony_ci+ 15933be168c0dSopenharmony_ci+int MatMulDynamicFP16Coder::InitBShape() { 15934be168c0dSopenharmony_ci+ std::vector<int> b_shape = filter_tensor_->shape(); 15935be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(b_shape.size() >= DIMENSION_2D, RET_NOT_SUPPORT, 15936be168c0dSopenharmony_ci+ "Matmul's b_shape_size must be not less than two."); 15937be168c0dSopenharmony_ci+ int batch = 1; 15938be168c0dSopenharmony_ci+ for (size_t i = 0; i < b_shape.size() - DIMENSION_2D; ++i) { 15939be168c0dSopenharmony_ci+ batch *= b_shape[i]; 15940be168c0dSopenharmony_ci+ } 15941be168c0dSopenharmony_ci+ if (batch != 1) { 15942be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Currently, Matmul only support matrixB's batch is 1."; 15943be168c0dSopenharmony_ci+ } 15944be168c0dSopenharmony_ci+ b_batch_ = batch; 15945be168c0dSopenharmony_ci+ params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - C2NUM] : b_shape[b_shape.size() - C1NUM]; 15946be168c0dSopenharmony_ci+ params_->col_8_ = UP_ROUND(params_->col_, C8NUM); 15947be168c0dSopenharmony_ci+ params_->deep_ = params_->b_transpose_ ? 
b_shape[b_shape.size() - C1NUM] : b_shape[b_shape.size() - C2NUM]; 15948be168c0dSopenharmony_ci+ return RET_OK; 15949be168c0dSopenharmony_ci+} 15950be168c0dSopenharmony_ci+ 15951be168c0dSopenharmony_ci+int MatMulDynamicFP16Coder::Prepare(CoderContext *const context) { 15952be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 15953be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 15954be168c0dSopenharmony_ci+ "Input tensor data type is invalid."); 15955be168c0dSopenharmony_ci+ } 15956be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(output_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 15957be168c0dSopenharmony_ci+ "Input tensor data type is invalid."); 15958be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_.size() == C2NUM || input_tensors_.size() == C3NUM, RET_INPUT_PARAM_INVALID, 15959be168c0dSopenharmony_ci+ "MatMul's input-num must be 2 or 3."); 15960be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst(), RET_NOT_SUPPORT, 15961be168c0dSopenharmony_ci+ "Currently, only support the first input of matmul is non-const when shape is dynamical."); 15962be168c0dSopenharmony_ci+ if (input_tensors_.size() == C3NUM) { 15963be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[THIRD_INPUT]->IsConst(), RET_NOT_SUPPORT, 15964be168c0dSopenharmony_ci+ "Currently, only support the first input of matmul is non-const when shape is dynamical."); 15965be168c0dSopenharmony_ci+ } 15966be168c0dSopenharmony_ci+ params_ = reinterpret_cast<MatMulParameter *>(parameter_); 15967be168c0dSopenharmony_ci+ filter_tensor_ = input_tensors_.at(kWeightIndex); 15968be168c0dSopenharmony_ci+ MS_CHECK_PTR(filter_tensor_); 15969be168c0dSopenharmony_ci+ if (input_tensors_.size() == kInputSize2) { 15970be168c0dSopenharmony_ci+ bias_tensor_ = input_tensors_.at(kBiasIndex); 15971be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_tensor_); 
15972be168c0dSopenharmony_ci+ MS_CHECK_PTR(bias_tensor_->data()); 15973be168c0dSopenharmony_ci+ } 15974be168c0dSopenharmony_ci+ params_->a_const_ = (input_tensor_->data() != nullptr); 15975be168c0dSopenharmony_ci+ params_->b_const_ = (filter_tensor_->data() != nullptr); 15976be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(MatMulDynamicFP16BaseCoder::Prepare(context), "MatMulDynamicFP16Coder prepare failed"); 15977be168c0dSopenharmony_ci+ return RET_OK; 15978be168c0dSopenharmony_ci+} 15979be168c0dSopenharmony_ci+ 15980be168c0dSopenharmony_ci+int MatMulDynamicFP16Coder::DoCode(CoderContext *const context) { return MatMulDynamicFP16BaseCoder::DoCode(context); } 15981be168c0dSopenharmony_ci+ 15982be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_MatMulFusion, 15983be168c0dSopenharmony_ci+ CPUOpCoderCreator<MatMulDynamicFP16Coder>) 15984be168c0dSopenharmony_ci+// REG_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_MatMulFusion, CPUOpCoderCreator<MatMulDynamicFP16Coder>) 15985be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 15986be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h 15987be168c0dSopenharmony_cinew file mode 100644 15988be168c0dSopenharmony_ciindex 00000000..1a16798c 15989be168c0dSopenharmony_ci--- /dev/null 15990be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_coder.h 15991be168c0dSopenharmony_ci@@ -0,0 +1,44 @@ 15992be168c0dSopenharmony_ci+/** 15993be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 15994be168c0dSopenharmony_ci+ * 15995be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 15996be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
15997be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 15998be168c0dSopenharmony_ci+ * 15999be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16000be168c0dSopenharmony_ci+ * 16001be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16002be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16003be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16004be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16005be168c0dSopenharmony_ci+ * limitations under the License. 16006be168c0dSopenharmony_ci+ */ 16007be168c0dSopenharmony_ci+ 16008be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_DYNAMIC_FP16_CODER_H_ 16009be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_DYNAMIC_FP16_CODER_H_ 16010be168c0dSopenharmony_ci+ 16011be168c0dSopenharmony_ci+#include <vector> 16012be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_dynamic_fp16_base_coder.h" 16013be168c0dSopenharmony_ci+#include "nnacl/matmul_parameter.h" 16014be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/shape_info_container.h" 16015be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 16016be168c0dSopenharmony_ci+ 16017be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16018be168c0dSopenharmony_ci+class MatMulDynamicFP16Coder final : public MatMulDynamicFP16BaseCoder { 16019be168c0dSopenharmony_ci+ public: 16020be168c0dSopenharmony_ci+ MatMulDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16021be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 16022be168c0dSopenharmony_ci+ : MatMulDynamicFP16BaseCoder(in_tensors, out_tensors, 
node, node_index, target) {} 16023be168c0dSopenharmony_ci+ 16024be168c0dSopenharmony_ci+ ~MatMulDynamicFP16Coder() override = default; 16025be168c0dSopenharmony_ci+ 16026be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 16027be168c0dSopenharmony_ci+ 16028be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 16029be168c0dSopenharmony_ci+ 16030be168c0dSopenharmony_ci+ private: 16031be168c0dSopenharmony_ci+ int InitAShape() override; 16032be168c0dSopenharmony_ci+ int InitBShape() override; 16033be168c0dSopenharmony_ci+}; 16034be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16035be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_MATMUL_DYNAMIC_FP16_CODER_H_ 16036be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 16037be168c0dSopenharmony_ciindex 67f633fe..415e912d 100644 16038be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 16039be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/matmul_fp16_base_coder.cc 16040be168c0dSopenharmony_ci@@ -102,14 +102,15 @@ std::string MatMulFP16BaseCoder::InitMatrixA(NNaclFp32Serializer *const code, NN 16041be168c0dSopenharmony_ci if (a_batch_ == 1) { 16042be168c0dSopenharmony_ci if (params_.a_transpose_) { 16043be168c0dSopenharmony_ci if (target_ == kARM64) { 16044be168c0dSopenharmony_ci- pack_code.CodeFunction("RowMajor2RowNMajorFp16", input_a_str, input_a_pack_str, params_.deep_, params_.row_); 16045be168c0dSopenharmony_ci+ pack_code.CodeFunction("RowMajor2RowNMajorFp16", input_a_str, input_a_pack_str, params_.deep_, params_.row_, 16046be168c0dSopenharmony_ci+ "false"); 16047be168c0dSopenharmony_ci } else { 16048be168c0dSopenharmony_ci 
pack_code.CodeFunction("RowMajor2Row12MajorFp16", input_a_str, input_a_pack_str, params_.deep_, params_.row_, 16049be168c0dSopenharmony_ci false); 16050be168c0dSopenharmony_ci } 16051be168c0dSopenharmony_ci } else { 16052be168c0dSopenharmony_ci if (target_ == kARM64) { 16053be168c0dSopenharmony_ci- pack_code.CodeFunction("RowMajor2ColNMajorFp16", input_a_str, input_a_pack_str, params_.row_, params_.deep_); 16054be168c0dSopenharmony_ci+ pack_code.CodeFunction("RowMajor2ColNMajorFp16", input_a_str, input_a_pack_str, params_.row_, params_.deep_, false); 16055be168c0dSopenharmony_ci } else { 16056be168c0dSopenharmony_ci pack_code.CodeFunction("RowMajor2Col12MajorFp16", input_a_str, input_a_pack_str, params_.row_, params_.deep_, 16057be168c0dSopenharmony_ci false); 16058be168c0dSopenharmony_ci@@ -122,13 +123,13 @@ std::string MatMulFP16BaseCoder::InitMatrixA(NNaclFp32Serializer *const code, NN 16059be168c0dSopenharmony_ci << ";\n"; 16060be168c0dSopenharmony_ci if (params_.a_transpose_) { 16061be168c0dSopenharmony_ci if (target_ == kARM64) { 16062be168c0dSopenharmony_ci- pack_code << " RowMajor2RowNMajorFp16(src, dst, " << params_.deep_ << ", " << params_.row_ << ");\n"; 16063be168c0dSopenharmony_ci+ pack_code << " RowMajor2RowNMajorFp16(src, dst, " << params_.deep_ << ", " << params_.row_ << ", false);\n"; 16064be168c0dSopenharmony_ci } else { 16065be168c0dSopenharmony_ci pack_code << " RowMajor2Row12MajorFp16(src, dst, " << params_.deep_ << ", " << params_.row_ << ", false);\n"; 16066be168c0dSopenharmony_ci } 16067be168c0dSopenharmony_ci } else { 16068be168c0dSopenharmony_ci if (target_ == kARM64) { 16069be168c0dSopenharmony_ci- pack_code << " RowMajor2ColNMajorFp16(src, dst, " << params_.row_ << ", " << params_.deep_ << ");\n"; 16070be168c0dSopenharmony_ci+ pack_code << " RowMajor2ColNMajorFp16(src, dst, " << params_.row_ << ", " << params_.deep_ << ", false);\n"; 16071be168c0dSopenharmony_ci } else { 16072be168c0dSopenharmony_ci pack_code << " 
RowMajor2Col12MajorFp16(src, dst, " << params_.row_ << ", " << params_.deep_ << ", false);\n"; 16073be168c0dSopenharmony_ci } 16074be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc 16075be168c0dSopenharmony_cinew file mode 100644 16076be168c0dSopenharmony_ciindex 00000000..c565f5b2 16077be168c0dSopenharmony_ci--- /dev/null 16078be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.cc 16079be168c0dSopenharmony_ci@@ -0,0 +1,89 @@ 16080be168c0dSopenharmony_ci+/** 16081be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16082be168c0dSopenharmony_ci+ * 16083be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16084be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 16085be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16086be168c0dSopenharmony_ci+ * 16087be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16088be168c0dSopenharmony_ci+ * 16089be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16090be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16091be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16092be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16093be168c0dSopenharmony_ci+ * limitations under the License. 
16094be168c0dSopenharmony_ci+ */ 16095be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h" 16096be168c0dSopenharmony_ci+#include <cfloat> 16097be168c0dSopenharmony_ci+#include <string> 16098be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16099be168c0dSopenharmony_ci+#include "coder/log.h" 16100be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 16101be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 16102be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 16103be168c0dSopenharmony_ci+ 16104be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_AvgPoolFusion; 16105be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_MaxPoolFusion; 16106be168c0dSopenharmony_ci+ 16107be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16108be168c0dSopenharmony_ci+int PoolingDynamicFP16Coder::Prepare(CoderContext *const context) { 16109be168c0dSopenharmony_ci+ if (input_tensor_->data_type() != kNumberTypeFloat16 || output_tensor_->data_type() != kNumberTypeFloat16) { 16110be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Tensor data type is invalid"; 16111be168c0dSopenharmony_ci+ return lite::RET_INPUT_PARAM_INVALID; 16112be168c0dSopenharmony_ci+ } 16113be168c0dSopenharmony_ci+ param_ = reinterpret_cast<PoolingParameter *>(parameter_); 16114be168c0dSopenharmony_ci+ MS_CHECK_PTR(param_); 16115be168c0dSopenharmony_ci+ dynamic_param_.input_batch_ = shape_info_container_->GetTemplateShape(input_tensor_)[0]; 16116be168c0dSopenharmony_ci+ compute_.input_channel_ = input_tensor_->Channel(); 16117be168c0dSopenharmony_ci+ compute_.input_h_ = input_tensor_->Height(); 16118be168c0dSopenharmony_ci+ compute_.input_w_ = input_tensor_->Width(); 16119be168c0dSopenharmony_ci+ dynamic_param_.output_batch_ = shape_info_container_->GetTemplateShape(output_tensor_)[0]; 16120be168c0dSopenharmony_ci+ compute_.output_channel_ = output_tensor_->Channel(); 
16121be168c0dSopenharmony_ci+ compute_.output_h_ = output_tensor_->Height(); 16122be168c0dSopenharmony_ci+ compute_.output_w_ = output_tensor_->Width(); 16123be168c0dSopenharmony_ci+ if (param_->global_) { 16124be168c0dSopenharmony_ci+ param_->window_h_ = compute_.input_h_; 16125be168c0dSopenharmony_ci+ param_->window_w_ = compute_.input_w_; 16126be168c0dSopenharmony_ci+ } 16127be168c0dSopenharmony_ci+ return RET_OK; 16128be168c0dSopenharmony_ci+} 16129be168c0dSopenharmony_ci+ 16130be168c0dSopenharmony_ci+int PoolingDynamicFP16Coder::DoCode(CoderContext *const context) { 16131be168c0dSopenharmony_ci+ Collect(context, 16132be168c0dSopenharmony_ci+ { 16133be168c0dSopenharmony_ci+ "nnacl/fp16/pooling_fp16.h", 16134be168c0dSopenharmony_ci+ }, 16135be168c0dSopenharmony_ci+ { 16136be168c0dSopenharmony_ci+ "pooling_fp16.c", 16137be168c0dSopenharmony_ci+ }); 16138be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 16139be168c0dSopenharmony_ci+ code.CodeStruct("pooling_parameter", *param_); 16140be168c0dSopenharmony_ci+ code.CodeStruct("pooling_compute", compute_, dynamic_param_); 16141be168c0dSopenharmony_ci+ 16142be168c0dSopenharmony_ci+ auto input_data = 16143be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16144be168c0dSopenharmony_ci+ auto output_data = 16145be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16146be168c0dSopenharmony_ci+ if (param_->pool_mode_ == PoolMode_MaxPool) { 16147be168c0dSopenharmony_ci+ code.CodeFunction("MaxPoolingFp16", input_data, output_data, "&pooling_parameter", "&pooling_compute", 16148be168c0dSopenharmony_ci+ kDefaultTaskId, param_->op_parameter_.thread_num_); 16149be168c0dSopenharmony_ci+ } else if (param_->pool_mode_ == PoolMode_AvgPool) { 16150be168c0dSopenharmony_ci+ code.CodeFunction("AvgPoolingFp16", input_data, output_data, "&pooling_parameter", 
"&pooling_compute", 16151be168c0dSopenharmony_ci+ kDefaultTaskId, param_->op_parameter_.thread_num_); 16152be168c0dSopenharmony_ci+ } else { 16153be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported pooling mode."; 16154be168c0dSopenharmony_ci+ return lite::RET_ERROR; 16155be168c0dSopenharmony_ci+ } 16156be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 16157be168c0dSopenharmony_ci+ return lite::RET_OK; 16158be168c0dSopenharmony_ci+} 16159be168c0dSopenharmony_ci+ 16160be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_AvgPoolFusion, 16161be168c0dSopenharmony_ci+ CPUOpCoderCreator<PoolingDynamicFP16Coder>) 16162be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_AvgPoolFusion, 16163be168c0dSopenharmony_ci+ CPUOpCoderCreator<PoolingDynamicFP16Coder>) 16164be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_MaxPoolFusion, 16165be168c0dSopenharmony_ci+ CPUOpCoderCreator<PoolingDynamicFP16Coder>) 16166be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_MaxPoolFusion, 16167be168c0dSopenharmony_ci+ CPUOpCoderCreator<PoolingDynamicFP16Coder>) 16168be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16169be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h 16170be168c0dSopenharmony_cinew file mode 100644 16171be168c0dSopenharmony_ciindex 00000000..7b138b61 16172be168c0dSopenharmony_ci--- /dev/null 16173be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/pooling_dynamic_fp16_coder.h 16174be168c0dSopenharmony_ci@@ -0,0 +1,44 @@ 16175be168c0dSopenharmony_ci+/** 16176be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16177be168c0dSopenharmony_ci+ * 
16178be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16179be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 16180be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16181be168c0dSopenharmony_ci+ * 16182be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16183be168c0dSopenharmony_ci+ * 16184be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16185be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16186be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16187be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16188be168c0dSopenharmony_ci+ * limitations under the License. 16189be168c0dSopenharmony_ci+ */ 16190be168c0dSopenharmony_ci+ 16191be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_POOLING_DYNAMIC_FP16_CODER_H_ 16192be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_POOLING_DYNAMIC_FP16_CODER_H_ 16193be168c0dSopenharmony_ci+ 16194be168c0dSopenharmony_ci+#include <vector> 16195be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 16196be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h" 16197be168c0dSopenharmony_ci+#include "nnacl/pooling_parameter.h" 16198be168c0dSopenharmony_ci+#include "nnacl/kernel/pooling.h" 16199be168c0dSopenharmony_ci+ 16200be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16201be168c0dSopenharmony_ci+class PoolingDynamicFP16Coder final : public OperatorCoder { 16202be168c0dSopenharmony_ci+ public: 16203be168c0dSopenharmony_ci+ PoolingDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16204be168c0dSopenharmony_ci+ const LiteGraph::Node *node, 
size_t node_index, Target target) 16205be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 16206be168c0dSopenharmony_ci+ ~PoolingDynamicFP16Coder() override = default; 16207be168c0dSopenharmony_ci+ 16208be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 16209be168c0dSopenharmony_ci+ 16210be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 16211be168c0dSopenharmony_ci+ 16212be168c0dSopenharmony_ci+ private: 16213be168c0dSopenharmony_ci+ PoolingParameter *param_{nullptr}; 16214be168c0dSopenharmony_ci+ PoolingComputeParam compute_; 16215be168c0dSopenharmony_ci+ PoolingDynamicParameter dynamic_param_; 16216be168c0dSopenharmony_ci+}; 16217be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16218be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_POOLING_DYNAMIC_FP16_CODER_H_ 16219be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc 16220be168c0dSopenharmony_cinew file mode 100644 16221be168c0dSopenharmony_ciindex 00000000..733cf49d 16222be168c0dSopenharmony_ci--- /dev/null 16223be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.cc 16224be168c0dSopenharmony_ci@@ -0,0 +1,128 @@ 16225be168c0dSopenharmony_ci+/** 16226be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16227be168c0dSopenharmony_ci+ * 16228be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16229be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
16230be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16231be168c0dSopenharmony_ci+ * 16232be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16233be168c0dSopenharmony_ci+ * 16234be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16235be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16236be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16237be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16238be168c0dSopenharmony_ci+ * limitations under the License. 16239be168c0dSopenharmony_ci+ */ 16240be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h" 16241be168c0dSopenharmony_ci+#include <string> 16242be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16243be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 16244be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 16245be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 16246be168c0dSopenharmony_ci+ 16247be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_ScaleFusion; 16248be168c0dSopenharmony_ci+ 16249be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16250be168c0dSopenharmony_ci+int ScaleDynamicFP16Coder::Prepare(CoderContext *const context) { 16251be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 16252be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16253be168c0dSopenharmony_ci+ "Input tensor data type should be fp16, now is " << input_tensors_[i]->data_type()); 16254be168c0dSopenharmony_ci+ } 16255be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(output_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16256be168c0dSopenharmony_ci+ "Output tensor data 
type should be fp16, now is " << output_tensor_->data_type()); 16257be168c0dSopenharmony_ci+ 16258be168c0dSopenharmony_ci+ scale_param_ = reinterpret_cast<ScaleParameter *>(parameter_); 16259be168c0dSopenharmony_ci+ MS_CHECK_PTR(scale_param_); 16260be168c0dSopenharmony_ci+ scale_struct_.base_.param_ = parameter_; 16261be168c0dSopenharmony_ci+ if (input_tensors_.size() < DIMENSION_2D || input_tensors_.size() > DIMENSION_3D) { 16262be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "inputs to Scale operator should be 2 or 3, but " << input_tensors_.size() << " is given."; 16263be168c0dSopenharmony_ci+ return RET_ERROR; 16264be168c0dSopenharmony_ci+ } 16265be168c0dSopenharmony_ci+ scale_tensor_ = input_tensors_.at(kWeightIndex); 16266be168c0dSopenharmony_ci+ MS_CHECK_PTR(scale_tensor_); 16267be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(CalculateParameter(), "Scale fp16 CalculateParameter failed."); 16268be168c0dSopenharmony_ci+ return RET_OK; 16269be168c0dSopenharmony_ci+} 16270be168c0dSopenharmony_ci+ 16271be168c0dSopenharmony_ci+int ScaleDynamicFP16Coder::DoCode(CoderContext *const context) { 16272be168c0dSopenharmony_ci+ // init struct ScaleParameters 16273be168c0dSopenharmony_ci+ Collect(context, 16274be168c0dSopenharmony_ci+ { 16275be168c0dSopenharmony_ci+ "nnacl/kernel/scale.h", 16276be168c0dSopenharmony_ci+ "nnacl/fp16/scale_fp16.h", 16277be168c0dSopenharmony_ci+ }, 16278be168c0dSopenharmony_ci+ { 16279be168c0dSopenharmony_ci+ "scale_fp16.c", 16280be168c0dSopenharmony_ci+ }); 16281be168c0dSopenharmony_ci+ 16282be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 16283be168c0dSopenharmony_ci+ code.CodeStruct("scale_struct", scale_struct_, dynamic_param_); 16284be168c0dSopenharmony_ci+ 16285be168c0dSopenharmony_ci+ auto scale = GetTensorAddr(scale_tensor_, scale_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 16286be168c0dSopenharmony_ci+ std::string offset{"NULL"}; 16287be168c0dSopenharmony_ci+ if (input_tensors_.size() == DIMENSION_3D) { 16288be168c0dSopenharmony_ci+ 
auto offset_tensor = input_tensors_.at(kBiasIndex); 16289be168c0dSopenharmony_ci+ offset = GetTensorAddr(offset_tensor, offset_tensor->IsConst(), dynamic_mem_manager_, allocator_); 16290be168c0dSopenharmony_ci+ } 16291be168c0dSopenharmony_ci+ std::string input_str = 16292be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16293be168c0dSopenharmony_ci+ std::string output_str = 16294be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16295be168c0dSopenharmony_ci+ switch (scale_param_->activation_type_) { 16296be168c0dSopenharmony_ci+ case schema::ActivationType_RELU6: 16297be168c0dSopenharmony_ci+ code.CodeFunction("DoScaleRelu6Fp16", input_str, output_str, scale, offset, kDefaultTaskId, "&scale_struct"); 16298be168c0dSopenharmony_ci+ break; 16299be168c0dSopenharmony_ci+ case schema::ActivationType_RELU: 16300be168c0dSopenharmony_ci+ code.CodeFunction("Fp16DoScaleRelu", input_str, output_str, scale, offset, kDefaultTaskId, "&scale_struct"); 16301be168c0dSopenharmony_ci+ break; 16302be168c0dSopenharmony_ci+ case schema::ActivationType_NO_ACTIVATION: 16303be168c0dSopenharmony_ci+ code.CodeFunction("DoScaleFp16", input_str, output_str, scale, offset, kDefaultTaskId, "&scale_struct"); 16304be168c0dSopenharmony_ci+ break; 16305be168c0dSopenharmony_ci+ default: 16306be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_; 16307be168c0dSopenharmony_ci+ return RET_ERROR; 16308be168c0dSopenharmony_ci+ } 16309be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 16310be168c0dSopenharmony_ci+ return RET_OK; 16311be168c0dSopenharmony_ci+} 16312be168c0dSopenharmony_ci+ 16313be168c0dSopenharmony_ci+int ScaleDynamicFP16Coder::CalculateParameter() { 16314be168c0dSopenharmony_ci+ auto in_shape = 
shape_info_container_->GetTemplateShape(input_tensor_); 16315be168c0dSopenharmony_ci+ std::vector<std::string> scale_shape; 16316be168c0dSopenharmony_ci+ if (scale_tensor_->IsConst()) { 16317be168c0dSopenharmony_ci+ for (auto dim : scale_tensor_->shape()) { 16318be168c0dSopenharmony_ci+ scale_shape.emplace_back(std::to_string(dim)); 16319be168c0dSopenharmony_ci+ } 16320be168c0dSopenharmony_ci+ } else { 16321be168c0dSopenharmony_ci+ scale_shape = shape_info_container_->GetTemplateShape(scale_tensor_); 16322be168c0dSopenharmony_ci+ } 16323be168c0dSopenharmony_ci+ if (scale_param_->axis_ < 0) { 16324be168c0dSopenharmony_ci+ scale_struct_.axis_ = scale_param_->axis_ + in_shape.size(); 16325be168c0dSopenharmony_ci+ } 16326be168c0dSopenharmony_ci+ if (scale_shape.size() + scale_struct_.axis_ > in_shape.size()) { 16327be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Scale tensor shape is incorrect."; 16328be168c0dSopenharmony_ci+ return RET_ERROR; 16329be168c0dSopenharmony_ci+ } 16330be168c0dSopenharmony_ci+ dynamic_param_.outer_size_ = AccumulateShape(in_shape, 0, scale_struct_.axis_); 16331be168c0dSopenharmony_ci+ if (scale_tensor_->IsConst() && scale_tensor_->shape().size() == 1) { 16332be168c0dSopenharmony_ci+ dynamic_param_.axis_size_ = in_shape.at(scale_struct_.axis_); 16333be168c0dSopenharmony_ci+ } else { 16334be168c0dSopenharmony_ci+ dynamic_param_.axis_size_ = "{"; 16335be168c0dSopenharmony_ci+ for (size_t i = 0; i < scale_shape.size(); i++) { 16336be168c0dSopenharmony_ci+ if (in_shape.at(i + scale_struct_.axis_) != scale_shape.at(i)) { 16337be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Scale tensor shape is incorrect."; 16338be168c0dSopenharmony_ci+ return RET_ERROR; 16339be168c0dSopenharmony_ci+ } 16340be168c0dSopenharmony_ci+ dynamic_param_.axis_size_ += in_shape.at(i + scale_struct_.axis_) + ", "; 16341be168c0dSopenharmony_ci+ } 16342be168c0dSopenharmony_ci+ dynamic_param_.axis_size_ += "}"; 16343be168c0dSopenharmony_ci+ } 16344be168c0dSopenharmony_ci+ 
dynamic_param_.inner_size_ = AccumulateShape(in_shape, scale_struct_.axis_ + scale_shape.size(), in_shape.size()); 16345be168c0dSopenharmony_ci+ return RET_OK; 16346be168c0dSopenharmony_ci+} 16347be168c0dSopenharmony_ci+ 16348be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_ScaleFusion, 16349be168c0dSopenharmony_ci+ CPUOpCoderCreator<ScaleDynamicFP16Coder>) 16350be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_ScaleFusion, 16351be168c0dSopenharmony_ci+ CPUOpCoderCreator<ScaleDynamicFP16Coder>) 16352be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16353be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h 16354be168c0dSopenharmony_cinew file mode 100644 16355be168c0dSopenharmony_ciindex 00000000..02ec35ba 16356be168c0dSopenharmony_ci--- /dev/null 16357be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/scale_dynamic_fp16_coder.h 16358be168c0dSopenharmony_ci@@ -0,0 +1,46 @@ 16359be168c0dSopenharmony_ci+/** 16360be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16361be168c0dSopenharmony_ci+ * 16362be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16363be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
16364be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16365be168c0dSopenharmony_ci+ * 16366be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16367be168c0dSopenharmony_ci+ * 16368be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16369be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16370be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16371be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16372be168c0dSopenharmony_ci+ * limitations under the License. 16373be168c0dSopenharmony_ci+ */ 16374be168c0dSopenharmony_ci+ 16375be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SCALE_DYNAMIC_FP16_CODER_H_ 16376be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SCALE_DYNAMIC_FP16_CODER_H_ 16377be168c0dSopenharmony_ci+ 16378be168c0dSopenharmony_ci+#include <vector> 16379be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 16380be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h" 16381be168c0dSopenharmony_ci+#include "nnacl/kernel/scale.h" 16382be168c0dSopenharmony_ci+#include "nnacl/scale_parameter.h" 16383be168c0dSopenharmony_ci+ 16384be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16385be168c0dSopenharmony_ci+class ScaleDynamicFP16Coder final : public OperatorCoder { 16386be168c0dSopenharmony_ci+ public: 16387be168c0dSopenharmony_ci+ ScaleDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16388be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 16389be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 16390be168c0dSopenharmony_ci+ ~ScaleDynamicFP16Coder() override = default; 
16391be168c0dSopenharmony_ci+ 16392be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 16393be168c0dSopenharmony_ci+ 16394be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 16395be168c0dSopenharmony_ci+ 16396be168c0dSopenharmony_ci+ private: 16397be168c0dSopenharmony_ci+ int CalculateParameter(); 16398be168c0dSopenharmony_ci+ ScaleParameter *scale_param_{nullptr}; 16399be168c0dSopenharmony_ci+ ScaleStruct scale_struct_; 16400be168c0dSopenharmony_ci+ ScaleDynamicParameter dynamic_param_; 16401be168c0dSopenharmony_ci+ Tensor *scale_tensor_{nullptr}; 16402be168c0dSopenharmony_ci+}; 16403be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16404be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SCALE_DYNAMIC_FP16_CODER_H_ 16405be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc 16406be168c0dSopenharmony_cinew file mode 100644 16407be168c0dSopenharmony_ciindex 00000000..1c6969b2 16408be168c0dSopenharmony_ci--- /dev/null 16409be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.cc 16410be168c0dSopenharmony_ci@@ -0,0 +1,160 @@ 16411be168c0dSopenharmony_ci+/** 16412be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16413be168c0dSopenharmony_ci+ * 16414be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16415be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
16416be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16417be168c0dSopenharmony_ci+ * 16418be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16419be168c0dSopenharmony_ci+ * 16420be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16421be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16422be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16423be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16424be168c0dSopenharmony_ci+ * limitations under the License. 16425be168c0dSopenharmony_ci+ */ 16426be168c0dSopenharmony_ci+ 16427be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h" 16428be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 16429be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16430be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 16431be168c0dSopenharmony_ci+ 16432be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_SliceFusion; 16433be168c0dSopenharmony_ci+ 16434be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16435be168c0dSopenharmony_ci+int SliceDynamicFP16Coder::Prepare(CoderContext *const context) { 16436be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(input_tensors_.size(), C3NUM); 16437be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(output_tensors_.size(), 1); 16438be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(input_tensors_[FIRST_INPUT]); 16439be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(input_tensors_[SECOND_INPUT]); 16440be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(input_tensors_[THIRD_INPUT]); 16441be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(output_tensor_); 16442be168c0dSopenharmony_ci+ param_ = reinterpret_cast<SliceParameter *>(parameter_); 16443be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(param_); 
16444be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst() && input_tensors_[THIRD_INPUT]->IsConst(), RET_NOT_SUPPORT, 16445be168c0dSopenharmony_ci+ "The second and third input of slice is non-const."); 16446be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32 && 16447be168c0dSopenharmony_ci+ input_tensors_[THIRD_INPUT]->data_type() == kNumberTypeInt32, 16448be168c0dSopenharmony_ci+ RET_INPUT_PARAM_INVALID, "second or third input tensor data type need to be int32."); 16449be168c0dSopenharmony_ci+ if (input_tensor_->data_type() != kNumberTypeFloat16 || output_tensor_->data_type() != kNumberTypeFloat16) { 16450be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Tensor data type is invalid"; 16451be168c0dSopenharmony_ci+ return lite::RET_INPUT_PARAM_INVALID; 16452be168c0dSopenharmony_ci+ } 16453be168c0dSopenharmony_ci+ return Init(); 16454be168c0dSopenharmony_ci+} 16455be168c0dSopenharmony_ci+ 16456be168c0dSopenharmony_ci+int SliceDynamicFP16Coder::DoCode(CoderContext *const context) { 16457be168c0dSopenharmony_ci+ Collect(context, 16458be168c0dSopenharmony_ci+ { 16459be168c0dSopenharmony_ci+ "nnacl/base/slice_base.h", 16460be168c0dSopenharmony_ci+ }, 16461be168c0dSopenharmony_ci+ { 16462be168c0dSopenharmony_ci+ "slice_base.c", 16463be168c0dSopenharmony_ci+ }); 16464be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 16465be168c0dSopenharmony_ci+ code.CodeStruct("slice_param", *param_, dynamic_param_); 16466be168c0dSopenharmony_ci+ std::string input_data = GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 16467be168c0dSopenharmony_ci+ std::string output_data = GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 16468be168c0dSopenharmony_ci+ if (!support_parallel_) { 16469be168c0dSopenharmony_ci+ code.CodeFunction("DoSliceNoParallel", input_data, output_data, "&slice_param", 16470be168c0dSopenharmony_ci+ 
DataTypeSize(input_tensor_->data_type())); 16471be168c0dSopenharmony_ci+ } 16472be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 16473be168c0dSopenharmony_ci+ return NNACL_OK; 16474be168c0dSopenharmony_ci+} 16475be168c0dSopenharmony_ci+ 16476be168c0dSopenharmony_ci+int SliceDynamicFP16Coder::Init() { 16477be168c0dSopenharmony_ci+ auto begin_tensor = input_tensors_[SECOND_INPUT]; 16478be168c0dSopenharmony_ci+ auto size_tensor = input_tensors_[THIRD_INPUT]; 16479be168c0dSopenharmony_ci+ data_shape_ = shape_info_container_->GetTemplateShape(input_tensor_); 16480be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(data_shape_.size() == static_cast<size_t>(begin_tensor->ElementsNum()), RET_ERROR, 16481be168c0dSopenharmony_ci+ "The begin tensor is invalid."); 16482be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(data_shape_.size() == static_cast<size_t>(size_tensor->ElementsNum()), RET_ERROR, 16483be168c0dSopenharmony_ci+ "The size tensor is invalid."); 16484be168c0dSopenharmony_ci+ auto begin = reinterpret_cast<int32_t *>(begin_tensor->data()); 16485be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(begin); 16486be168c0dSopenharmony_ci+ auto size = reinterpret_cast<int32_t *>(size_tensor->data()); 16487be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(size); 16488be168c0dSopenharmony_ci+ param_->param_length_ = static_cast<int>(data_shape_.size()); 16489be168c0dSopenharmony_ci+ if (param_->param_length_ > DIMENSION_8D) { 16490be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_8D; 16491be168c0dSopenharmony_ci+ return RET_ERROR; 16492be168c0dSopenharmony_ci+ } 16493be168c0dSopenharmony_ci+ dynamic_param_.shape_ = "{"; 16494be168c0dSopenharmony_ci+ dynamic_param_.size_ = "{"; 16495be168c0dSopenharmony_ci+ dynamic_param_.end_ = "{"; 16496be168c0dSopenharmony_ci+ for (int i = 0; i < param_->param_length_; ++i) { 16497be168c0dSopenharmony_ci+ dynamic_param_.shape_ += data_shape_[i] + ", "; 16498be168c0dSopenharmony_ci+ param_->begin_[i] = begin[i]; 
16499be168c0dSopenharmony_ci+ if (size[i] < 0) { 16500be168c0dSopenharmony_ci+ std::string cur_size = data_shape_[i] + " - " + std::to_string(begin[i]); 16501be168c0dSopenharmony_ci+ slice_size_.emplace_back(cur_size); 16502be168c0dSopenharmony_ci+ dynamic_param_.size_ += cur_size + ", "; 16503be168c0dSopenharmony_ci+ } else { 16504be168c0dSopenharmony_ci+ slice_size_.emplace_back(std::to_string(size[i])); 16505be168c0dSopenharmony_ci+ dynamic_param_.size_ += std::to_string(size[i]) + ", "; 16506be168c0dSopenharmony_ci+ } 16507be168c0dSopenharmony_ci+ std::string cur_end = std::to_string(param_->begin_[i]) + " + " + slice_size_[i]; 16508be168c0dSopenharmony_ci+ end_.emplace_back(cur_end); 16509be168c0dSopenharmony_ci+ dynamic_param_.end_ += cur_end + ", "; 16510be168c0dSopenharmony_ci+ } 16511be168c0dSopenharmony_ci+ dynamic_param_.shape_ += "}"; 16512be168c0dSopenharmony_ci+ dynamic_param_.size_ += "}"; 16513be168c0dSopenharmony_ci+ dynamic_param_.end_ += "}"; 16514be168c0dSopenharmony_ci+ if (param_->param_length_ < DIMENSION_8D) { 16515be168c0dSopenharmony_ci+ PadSliceParameterTo8D(); 16516be168c0dSopenharmony_ci+ } 16517be168c0dSopenharmony_ci+ return RET_OK; 16518be168c0dSopenharmony_ci+} 16519be168c0dSopenharmony_ci+ 16520be168c0dSopenharmony_ci+void SliceDynamicFP16Coder::PadSliceParameterTo8D() { 16521be168c0dSopenharmony_ci+ std::vector<int32_t> begin(DIMENSION_8D, 0); 16522be168c0dSopenharmony_ci+ std::vector<std::string> end(DIMENSION_8D, ""); 16523be168c0dSopenharmony_ci+ std::vector<std::string> slice_size(DIMENSION_8D, ""); 16524be168c0dSopenharmony_ci+ std::vector<std::string> data_shape(DIMENSION_8D, ""); 16525be168c0dSopenharmony_ci+ for (int32_t i = 0; i < param_->param_length_; ++i) { 16526be168c0dSopenharmony_ci+ begin[i] = param_->begin_[i]; 16527be168c0dSopenharmony_ci+ end[i] = end_[i]; 16528be168c0dSopenharmony_ci+ slice_size[i] = 16529be168c0dSopenharmony_ci+ slice_size_[i] + " < 0 ? 
" + data_shape[i] + " - " + std::to_string(begin[i]) + " : " + slice_size_[i]; 16530be168c0dSopenharmony_ci+ data_shape[i] = data_shape_[i]; 16531be168c0dSopenharmony_ci+ } 16532be168c0dSopenharmony_ci+ data_shape_.resize(DIMENSION_8D); 16533be168c0dSopenharmony_ci+ slice_size_.resize(DIMENSION_8D); 16534be168c0dSopenharmony_ci+ end_.resize(DIMENSION_8D); 16535be168c0dSopenharmony_ci+ int32_t real_index = param_->param_length_ - 1; 16536be168c0dSopenharmony_ci+ for (int32_t i = DIMENSION_8D - 1; i >= 0; --i) { 16537be168c0dSopenharmony_ci+ if (real_index >= 0) { 16538be168c0dSopenharmony_ci+ param_->begin_[i] = begin[real_index]; 16539be168c0dSopenharmony_ci+ end_[i] = end[real_index]; 16540be168c0dSopenharmony_ci+ slice_size_[i] = slice_size[real_index]; 16541be168c0dSopenharmony_ci+ data_shape_[i] = data_shape[real_index--]; 16542be168c0dSopenharmony_ci+ } else { 16543be168c0dSopenharmony_ci+ param_->begin_[i] = 0; 16544be168c0dSopenharmony_ci+ end_[i] = "1"; 16545be168c0dSopenharmony_ci+ slice_size_[i] = "1"; 16546be168c0dSopenharmony_ci+ data_shape_[i] = "1"; 16547be168c0dSopenharmony_ci+ } 16548be168c0dSopenharmony_ci+ } 16549be168c0dSopenharmony_ci+ param_->param_length_ = DIMENSION_8D; 16550be168c0dSopenharmony_ci+ dynamic_param_.shape_.clear(); 16551be168c0dSopenharmony_ci+ dynamic_param_.size_.clear(); 16552be168c0dSopenharmony_ci+ dynamic_param_.end_.clear(); 16553be168c0dSopenharmony_ci+ dynamic_param_.shape_ = "{"; 16554be168c0dSopenharmony_ci+ dynamic_param_.size_ = "{"; 16555be168c0dSopenharmony_ci+ dynamic_param_.end_ = "{"; 16556be168c0dSopenharmony_ci+ for (int i = 0; i < DIMENSION_8D; ++i) { 16557be168c0dSopenharmony_ci+ dynamic_param_.end_ += end_[i] + ", "; 16558be168c0dSopenharmony_ci+ dynamic_param_.size_ += slice_size_[i] + ", "; 16559be168c0dSopenharmony_ci+ dynamic_param_.shape_ += data_shape_[i] + ", "; 16560be168c0dSopenharmony_ci+ } 16561be168c0dSopenharmony_ci+ dynamic_param_.shape_ += "}"; 16562be168c0dSopenharmony_ci+ 
dynamic_param_.size_ += "}"; 16563be168c0dSopenharmony_ci+ dynamic_param_.end_ += "}"; 16564be168c0dSopenharmony_ci+} 16565be168c0dSopenharmony_ci+ 16566be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_SliceFusion, 16567be168c0dSopenharmony_ci+ CPUOpCoderCreator<SliceDynamicFP16Coder>) 16568be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_SliceFusion, 16569be168c0dSopenharmony_ci+ CPUOpCoderCreator<SliceDynamicFP16Coder>) 16570be168c0dSopenharmony_ci+}; // namespace mindspore::lite::micro::nnacl 16571be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h 16572be168c0dSopenharmony_cinew file mode 100644 16573be168c0dSopenharmony_ciindex 00000000..21b1b27b 16574be168c0dSopenharmony_ci--- /dev/null 16575be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/slice_dynamic_fp16_coder.h 16576be168c0dSopenharmony_ci@@ -0,0 +1,51 @@ 16577be168c0dSopenharmony_ci+/** 16578be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16579be168c0dSopenharmony_ci+ * 16580be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16581be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 16582be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16583be168c0dSopenharmony_ci+ * 16584be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16585be168c0dSopenharmony_ci+ * 16586be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16587be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16588be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16589be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16590be168c0dSopenharmony_ci+ * limitations under the License. 16591be168c0dSopenharmony_ci+ */ 16592be168c0dSopenharmony_ci+ 16593be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SLICE_DYNAMIC_FP16_CODER_H_ 16594be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SLICE_DYNAMIC_FP16_CODER_H_ 16595be168c0dSopenharmony_ci+ 16596be168c0dSopenharmony_ci+#include <vector> 16597be168c0dSopenharmony_ci+#include "mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h" 16598be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h" 16599be168c0dSopenharmony_ci+#include "nnacl/slice_parameter.h" 16600be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 16601be168c0dSopenharmony_ci+ 16602be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16603be168c0dSopenharmony_ci+class SliceDynamicFP16Coder final : public OperatorCoder { 16604be168c0dSopenharmony_ci+ public: 16605be168c0dSopenharmony_ci+ SliceDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16606be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 16607be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 16608be168c0dSopenharmony_ci+ 16609be168c0dSopenharmony_ci+ ~SliceDynamicFP16Coder() override = default; 16610be168c0dSopenharmony_ci+ 16611be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 16612be168c0dSopenharmony_ci+ 16613be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 16614be168c0dSopenharmony_ci+ 16615be168c0dSopenharmony_ci+ protected: 16616be168c0dSopenharmony_ci+ int Init(); 16617be168c0dSopenharmony_ci+ void PadSliceParameterTo8D(); 16618be168c0dSopenharmony_ci+ SliceParameter 
*param_{nullptr}; 16619be168c0dSopenharmony_ci+ SliceDynamicParameter dynamic_param_; 16620be168c0dSopenharmony_ci+ std::vector<std::string> in_shapes_; 16621be168c0dSopenharmony_ci+ std::vector<std::string> out_shapes_; 16622be168c0dSopenharmony_ci+ std::vector<std::string> data_shape_; 16623be168c0dSopenharmony_ci+ std::vector<std::string> slice_size_; 16624be168c0dSopenharmony_ci+ std::vector<std::string> end_; 16625be168c0dSopenharmony_ci+}; 16626be168c0dSopenharmony_ci+}; // namespace mindspore::lite::micro::nnacl 16627be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SLICE_DYNAMIC_FP16_CODER_H_ 16628be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc 16629be168c0dSopenharmony_cinew file mode 100644 16630be168c0dSopenharmony_ciindex 00000000..1bd09fb5 16631be168c0dSopenharmony_ci--- /dev/null 16632be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.cc 16633be168c0dSopenharmony_ci@@ -0,0 +1,137 @@ 16634be168c0dSopenharmony_ci+/** 16635be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16636be168c0dSopenharmony_ci+ * 16637be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16638be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
16639be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16640be168c0dSopenharmony_ci+ * 16641be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16642be168c0dSopenharmony_ci+ * 16643be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16644be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16645be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16646be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16647be168c0dSopenharmony_ci+ * limitations under the License. 16648be168c0dSopenharmony_ci+ */ 16649be168c0dSopenharmony_ci+ 16650be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h" 16651be168c0dSopenharmony_ci+#include <string> 16652be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16653be168c0dSopenharmony_ci+#include "schema/inner/ops_generated.h" 16654be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 16655be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 16656be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 16657be168c0dSopenharmony_ci+#include "base/float16.h" 16658be168c0dSopenharmony_ci+ 16659be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_LogSoftmax; 16660be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Softmax; 16661be168c0dSopenharmony_ci+ 16662be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16663be168c0dSopenharmony_ci+int SoftmaxDynamicFP16Coder::Prepare(CoderContext *const context) { 16664be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_tensors_.size(); ++i) { 16665be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16666be168c0dSopenharmony_ci+ "Input tensor data type is invalid"); 
16667be168c0dSopenharmony_ci+ } 16668be168c0dSopenharmony_ci+ for (size_t i = 0; i < output_tensors_.size(); ++i) { 16669be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(output_tensors_[i]->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16670be168c0dSopenharmony_ci+ "Output tensor data type is invalid"); 16671be168c0dSopenharmony_ci+ } 16672be168c0dSopenharmony_ci+ auto ret = Init(); 16673be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(ret, "Init failed!"); 16674be168c0dSopenharmony_ci+ return RET_OK; 16675be168c0dSopenharmony_ci+} 16676be168c0dSopenharmony_ci+ 16677be168c0dSopenharmony_ci+int SoftmaxDynamicFP16Coder::DoCode(CoderContext *const context) { 16678be168c0dSopenharmony_ci+ Collect(context, 16679be168c0dSopenharmony_ci+ { 16680be168c0dSopenharmony_ci+ "nnacl/fp16/softmax_fp16.h", 16681be168c0dSopenharmony_ci+ "nnacl/fp16/log_softmax_fp16.h", 16682be168c0dSopenharmony_ci+ }, 16683be168c0dSopenharmony_ci+ { 16684be168c0dSopenharmony_ci+ "softmax_fp16.c", 16685be168c0dSopenharmony_ci+ "log_softmax_fp16.c", 16686be168c0dSopenharmony_ci+ "exp_fp16.c", 16687be168c0dSopenharmony_ci+ }); 16688be168c0dSopenharmony_ci+ 16689be168c0dSopenharmony_ci+ auto ret = ComputeWorkSpace(); 16690be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(ret, "ComputeWorkSpace failed!"); 16691be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 16692be168c0dSopenharmony_ci+ sum_data_str_ = "(float16_t *)(" + buffer_start_ + ")"; 16693be168c0dSopenharmony_ci+ auto primitive_type = param_->op_parameter_.type_; 16694be168c0dSopenharmony_ci+ std::string input_data = 16695be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16696be168c0dSopenharmony_ci+ std::string output_data = 16697be168c0dSopenharmony_ci+ "(float16_t *)(" + GetTensorAddr(output_tensor_, output_tensor_->IsConst(), dynamic_mem_manager_, allocator_) + ")"; 16698be168c0dSopenharmony_ci+ code << " int input_shape[" << input_shape_.size() << "] = " 
<< dynamic_param_.input_shape_ << ";\n"; 16699be168c0dSopenharmony_ci+ if (primitive_type == schema::PrimitiveType_Softmax) { 16700be168c0dSopenharmony_ci+ code.CodeFunction("SoftmaxFp16", input_data, output_data, sum_data_str_, softmax_struct_.axis_, 16701be168c0dSopenharmony_ci+ softmax_struct_.n_dim_, "&input_shape"); 16702be168c0dSopenharmony_ci+ } else { 16703be168c0dSopenharmony_ci+ code.CodeFunction("LogSoftmaxFp16", input_data, output_data, sum_data_str_, "&input_shape", softmax_struct_.n_dim_, 16704be168c0dSopenharmony_ci+ softmax_struct_.axis_); 16705be168c0dSopenharmony_ci+ } 16706be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 16707be168c0dSopenharmony_ci+ return RET_OK; 16708be168c0dSopenharmony_ci+} 16709be168c0dSopenharmony_ci+ 16710be168c0dSopenharmony_ci+int SoftmaxDynamicFP16Coder::Init() { 16711be168c0dSopenharmony_ci+ param_ = reinterpret_cast<SoftmaxParameter *>(parameter_); 16712be168c0dSopenharmony_ci+ MS_CHECK_PTR(param_); 16713be168c0dSopenharmony_ci+ softmax_struct_.base_.param_ = parameter_; 16714be168c0dSopenharmony_ci+ input_shape_ = shape_info_container_->GetTemplateShape(input_tensor_); 16715be168c0dSopenharmony_ci+ size_t in_dims = input_shape_.size(); 16716be168c0dSopenharmony_ci+ softmax_struct_.n_dim_ = in_dims; 16717be168c0dSopenharmony_ci+ softmax_struct_.axis_ = param_->axis_ < 0 ? 
param_->axis_ + softmax_struct_.n_dim_ : param_->axis_; 16718be168c0dSopenharmony_ci+ dynamic_param_.element_size_ = AccumulateShape(input_shape_, 0, input_shape_.size()); 16719be168c0dSopenharmony_ci+ dynamic_param_.input_shape_ = "{"; 16720be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_shape_.size(); ++i) { 16721be168c0dSopenharmony_ci+ dynamic_param_.input_shape_ += input_shape_[i] + ", "; 16722be168c0dSopenharmony_ci+ } 16723be168c0dSopenharmony_ci+ dynamic_param_.input_shape_ += "}"; 16724be168c0dSopenharmony_ci+ return RET_OK; 16725be168c0dSopenharmony_ci+} 16726be168c0dSopenharmony_ci+ 16727be168c0dSopenharmony_ci+int SoftmaxDynamicFP16Coder::ComputeWorkSpace() { 16728be168c0dSopenharmony_ci+ std::map<std::string, std::vector<int>> real_nums; 16729be168c0dSopenharmony_ci+ size_t scene_num = 0; 16730be168c0dSopenharmony_ci+ for (auto &dim_template : input_shape_) { 16731be168c0dSopenharmony_ci+ auto dim_nums = shape_info_container_->GetRealNums(dim_template); 16732be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!dim_nums.empty(), RET_ERROR, "Dynamic shape's num must be greater than 0."); 16733be168c0dSopenharmony_ci+ real_nums[dim_template] = dim_nums; 16734be168c0dSopenharmony_ci+ scene_num = std::max(scene_num, dim_nums.size()); 16735be168c0dSopenharmony_ci+ } 16736be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 16737be168c0dSopenharmony_ci+ std::vector<int> real_shape(input_shape_.size()); 16738be168c0dSopenharmony_ci+ for (size_t j = 0; j < input_shape_.size(); ++j) { 16739be168c0dSopenharmony_ci+ if (IsNumber(input_shape_[j])) { 16740be168c0dSopenharmony_ci+ real_shape[j] = std::stoi(input_shape_[j]); 16741be168c0dSopenharmony_ci+ } else { 16742be168c0dSopenharmony_ci+ real_shape[j] = real_nums[input_shape_[j]][i % real_nums[input_shape_[j]].size()]; 16743be168c0dSopenharmony_ci+ } 16744be168c0dSopenharmony_ci+ } 16745be168c0dSopenharmony_ci+ int out_plane_size = 1; 16746be168c0dSopenharmony_ci+ for (int j = 0; j < 
softmax_struct_.axis_; ++j) { 16747be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(out_plane_size, real_shape[j], RET_ERROR); 16748be168c0dSopenharmony_ci+ out_plane_size *= real_shape[j]; 16749be168c0dSopenharmony_ci+ } 16750be168c0dSopenharmony_ci+ int in_plane_size = 1; 16751be168c0dSopenharmony_ci+ for (int j = softmax_struct_.axis_ + 1; j < softmax_struct_.n_dim_; ++j) { 16752be168c0dSopenharmony_ci+ MS_CHECK_INT_MUL_NOT_OVERFLOW(in_plane_size, real_shape[j], RET_ERROR); 16753be168c0dSopenharmony_ci+ in_plane_size *= real_shape[j]; 16754be168c0dSopenharmony_ci+ } 16755be168c0dSopenharmony_ci+ int workspace = out_plane_size * in_plane_size * sizeof(float16); 16756be168c0dSopenharmony_ci+ buffer_start_ = dynamic_mem_manager_->AllocWorkSpace(workspace, i); 16757be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!buffer_start_.empty(), RET_ERROR, "Softmax cannot alloc workspace."); 16758be168c0dSopenharmony_ci+ } 16759be168c0dSopenharmony_ci+ return RET_OK; 16760be168c0dSopenharmony_ci+} 16761be168c0dSopenharmony_ci+ 16762be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Softmax, 16763be168c0dSopenharmony_ci+ CPUOpCoderCreator<SoftmaxDynamicFP16Coder>) 16764be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Softmax, 16765be168c0dSopenharmony_ci+ CPUOpCoderCreator<SoftmaxDynamicFP16Coder>) 16766be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_LogSoftmax, 16767be168c0dSopenharmony_ci+ CPUOpCoderCreator<SoftmaxDynamicFP16Coder>) 16768be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_LogSoftmax, 16769be168c0dSopenharmony_ci+ CPUOpCoderCreator<SoftmaxDynamicFP16Coder>) 16770be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16771be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h 16772be168c0dSopenharmony_cinew file mode 100644 16773be168c0dSopenharmony_ciindex 00000000..913f5ad4 16774be168c0dSopenharmony_ci--- /dev/null 16775be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/softmax_dynamic_fp16_coder.h 16776be168c0dSopenharmony_ci@@ -0,0 +1,50 @@ 16777be168c0dSopenharmony_ci+/** 16778be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16779be168c0dSopenharmony_ci+ * 16780be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16781be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 16782be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16783be168c0dSopenharmony_ci+ * 16784be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16785be168c0dSopenharmony_ci+ * 16786be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16787be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16788be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16789be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16790be168c0dSopenharmony_ci+ * limitations under the License. 
16791be168c0dSopenharmony_ci+ */ 16792be168c0dSopenharmony_ci+ 16793be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SOFTMAX_DYNAMIC_FP16_CODER_H_ 16794be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SOFTMAX_DYNAMIC_FP16_CODER_H_ 16795be168c0dSopenharmony_ci+ 16796be168c0dSopenharmony_ci+#include <vector> 16797be168c0dSopenharmony_ci+#include <string> 16798be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 16799be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/softmax_dynamic_parameter.h" 16800be168c0dSopenharmony_ci+#include "nnacl/softmax_parameter.h" 16801be168c0dSopenharmony_ci+#include "nnacl/kernel/softmax.h" 16802be168c0dSopenharmony_ci+ 16803be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16804be168c0dSopenharmony_ci+class SoftmaxDynamicFP16Coder final : public OperatorCoder { 16805be168c0dSopenharmony_ci+ public: 16806be168c0dSopenharmony_ci+ SoftmaxDynamicFP16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16807be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 16808be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 16809be168c0dSopenharmony_ci+ ~SoftmaxDynamicFP16Coder() override = default; 16810be168c0dSopenharmony_ci+ 16811be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 16812be168c0dSopenharmony_ci+ 16813be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 16814be168c0dSopenharmony_ci+ 16815be168c0dSopenharmony_ci+ private: 16816be168c0dSopenharmony_ci+ int Init(); 16817be168c0dSopenharmony_ci+ int ComputeWorkSpace(); 16818be168c0dSopenharmony_ci+ SoftmaxParameter *param_{nullptr}; 16819be168c0dSopenharmony_ci+ SoftmaxStruct softmax_struct_; 16820be168c0dSopenharmony_ci+ SoftmaxDynamicParameter dynamic_param_; 16821be168c0dSopenharmony_ci+ 
std::vector<std::string> input_shape_; 16822be168c0dSopenharmony_ci+ std::string buffer_start_; 16823be168c0dSopenharmony_ci+ std::string sum_data_str_; 16824be168c0dSopenharmony_ci+}; 16825be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16826be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_SOFTMAX_DYNAMIC_FP16_CODER_H_ 16827be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc 16828be168c0dSopenharmony_cinew file mode 100644 16829be168c0dSopenharmony_ciindex 00000000..59c8d8b8 16830be168c0dSopenharmony_ci--- /dev/null 16831be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.cc 16832be168c0dSopenharmony_ci@@ -0,0 +1,76 @@ 16833be168c0dSopenharmony_ci+/** 16834be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16835be168c0dSopenharmony_ci+ * 16836be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16837be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 16838be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16839be168c0dSopenharmony_ci+ * 16840be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16841be168c0dSopenharmony_ci+ * 16842be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16843be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16844be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16845be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16846be168c0dSopenharmony_ci+ * limitations under the License. 
16847be168c0dSopenharmony_ci+ */ 16848be168c0dSopenharmony_ci+ 16849be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h" 16850be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16851be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 16852be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 16853be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 16854be168c0dSopenharmony_ci+ 16855be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Transpose; 16856be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16857be168c0dSopenharmony_ci+int TransposeDynamicFp16Coder::Prepare(CoderContext *const context) { 16858be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensor_->data_type() == kNumberTypeFloat16, RET_INPUT_PARAM_INVALID, 16859be168c0dSopenharmony_ci+ "Input tensor data type is invalid."); 16860be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32, RET_INPUT_PARAM_INVALID, 16861be168c0dSopenharmony_ci+ "Perm tensor data type is invalid."); 16862be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG( 16863be168c0dSopenharmony_ci+ output_tensor_->data_type() == kNumberTypeInt32 || output_tensor_->data_type() == kNumberTypeFloat16, 16864be168c0dSopenharmony_ci+ RET_INPUT_PARAM_INVALID, "Output tensor data type is invalid."); 16865be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst(), RET_NOT_SUPPORT, 16866be168c0dSopenharmony_ci+ "The second input of transpose is non-const."); 16867be168c0dSopenharmony_ci+ thread_num_ = 1; 16868be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(Init(), "init failed"); 16869be168c0dSopenharmony_ci+ return RET_OK; 16870be168c0dSopenharmony_ci+} 16871be168c0dSopenharmony_ci+ 16872be168c0dSopenharmony_ci+int TransposeDynamicFp16Coder::DoCode(CoderContext *const context) { 16873be168c0dSopenharmony_ci+ Collect(context, 
16874be168c0dSopenharmony_ci+ { 16875be168c0dSopenharmony_ci+ "nnacl/transpose_parameter.h", 16876be168c0dSopenharmony_ci+ "nnacl/errorcode.h", 16877be168c0dSopenharmony_ci+ "nnacl/fp16/transpose_fp16.h", 16878be168c0dSopenharmony_ci+ }, 16879be168c0dSopenharmony_ci+ { 16880be168c0dSopenharmony_ci+ "transpose_fp16.c", 16881be168c0dSopenharmony_ci+ }); 16882be168c0dSopenharmony_ci+ 16883be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 16884be168c0dSopenharmony_ci+ dims_ = static_cast<int>(out_shapes_.size()); 16885be168c0dSopenharmony_ci+ code << "const int32_t output_shape[" << dims_ << "] = {"; 16886be168c0dSopenharmony_ci+ for (size_t i = 0; i < out_shapes_.size(); ++i) { 16887be168c0dSopenharmony_ci+ code << out_shapes_[i] << ", "; 16888be168c0dSopenharmony_ci+ } 16889be168c0dSopenharmony_ci+ code << "};\n"; 16890be168c0dSopenharmony_ci+ code.CodeStruct("trans_param", *param_, dynamic_param_); 16891be168c0dSopenharmony_ci+ auto input_str = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 16892be168c0dSopenharmony_ci+ auto output_str = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 16893be168c0dSopenharmony_ci+ if (param_->num_axes_ > DIMENSION_6D) { 16894be168c0dSopenharmony_ci+ code.CodeFunction("TransposeDimsFp16", input_str, output_str, "output_shape", "trans_param.perm_", 16895be168c0dSopenharmony_ci+ "trans_param.strides_", "trans_param.out_strides_", "trans_param.num_axes_", kDefaultTaskId, 16896be168c0dSopenharmony_ci+ kDefaultThreadNum); 16897be168c0dSopenharmony_ci+ } else { 16898be168c0dSopenharmony_ci+ code.CodeFunction("DoTransposeFp16", input_str, output_str, "output_shape", "trans_param.perm_", 16899be168c0dSopenharmony_ci+ "trans_param.strides_", "trans_param.out_strides_", "trans_param.data_num_", 16900be168c0dSopenharmony_ci+ "trans_param.num_axes_"); 16901be168c0dSopenharmony_ci+ } 16902be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 16903be168c0dSopenharmony_ci+ return RET_OK; 16904be168c0dSopenharmony_ci+} 
16905be168c0dSopenharmony_ci+ 16906be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Transpose, 16907be168c0dSopenharmony_ci+ CPUOpCoderCreator<TransposeDynamicFp16Coder>) 16908be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16909be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h 16910be168c0dSopenharmony_cinew file mode 100644 16911be168c0dSopenharmony_ciindex 00000000..e008a794 16912be168c0dSopenharmony_ci--- /dev/null 16913be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp16/transpose_dynamic_fp16_coder.h 16914be168c0dSopenharmony_ci@@ -0,0 +1,37 @@ 16915be168c0dSopenharmony_ci+/** 16916be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16917be168c0dSopenharmony_ci+ * 16918be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16919be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 16920be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16921be168c0dSopenharmony_ci+ * 16922be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16923be168c0dSopenharmony_ci+ * 16924be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16925be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16926be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16927be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16928be168c0dSopenharmony_ci+ * limitations under the License. 
16929be168c0dSopenharmony_ci+ */ 16930be168c0dSopenharmony_ci+ 16931be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_TRANSPOSE_DYNAMIC_FP16_CODER_H_ 16932be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_TRANSPOSE_DYNAMIC_FP16_CODER_H_ 16933be168c0dSopenharmony_ci+#include <vector> 16934be168c0dSopenharmony_ci+#include <string> 16935be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h" 16936be168c0dSopenharmony_ci+ 16937be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16938be168c0dSopenharmony_ci+class TransposeDynamicFp16Coder : public TransposeDynamicFp32Coder { 16939be168c0dSopenharmony_ci+ public: 16940be168c0dSopenharmony_ci+ TransposeDynamicFp16Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 16941be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 16942be168c0dSopenharmony_ci+ : TransposeDynamicFp32Coder(in_tensors, out_tensors, node, node_index, target) {} 16943be168c0dSopenharmony_ci+ 16944be168c0dSopenharmony_ci+ ~TransposeDynamicFp16Coder() override = default; 16945be168c0dSopenharmony_ci+ 16946be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 16947be168c0dSopenharmony_ci+ 16948be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 16949be168c0dSopenharmony_ci+}; 16950be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 16951be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP16_TRANSPOSE_DYNAMIC_FP16_CODER_H_ 16952be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.cc 16953be168c0dSopenharmony_cinew file mode 100644 16954be168c0dSopenharmony_ciindex 
00000000..1dd33bbd 16955be168c0dSopenharmony_ci--- /dev/null 16956be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.cc 16957be168c0dSopenharmony_ci@@ -0,0 +1,112 @@ 16958be168c0dSopenharmony_ci+/** 16959be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 16960be168c0dSopenharmony_ci+ * 16961be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 16962be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 16963be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 16964be168c0dSopenharmony_ci+ * 16965be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 16966be168c0dSopenharmony_ci+ * 16967be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 16968be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 16969be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16970be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 16971be168c0dSopenharmony_ci+ * limitations under the License. 
16972be168c0dSopenharmony_ci+ */ 16973be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h" 16974be168c0dSopenharmony_ci+#include <string> 16975be168c0dSopenharmony_ci+#include "nnacl/fp32/activation_fp32.h" 16976be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 16977be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 16978be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 16979be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 16980be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 16981be168c0dSopenharmony_ci+ 16982be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Activation; 16983be168c0dSopenharmony_ci+ 16984be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 16985be168c0dSopenharmony_ci+int ActivationDynamicFP32Coder::Preprocess() { 16986be168c0dSopenharmony_ci+ // attribute 16987be168c0dSopenharmony_ci+ auto in_shape = shape_info_container_->GetTemplateShape(input_tensor_); 16988be168c0dSopenharmony_ci+ int64_t const_part = 1; 16989be168c0dSopenharmony_ci+ std::string non_const_part; 16990be168c0dSopenharmony_ci+ for (const auto &item : in_shape) { 16991be168c0dSopenharmony_ci+ if (IsNumber(item)) { 16992be168c0dSopenharmony_ci+ const_part *= std::atoi(item.c_str()); 16993be168c0dSopenharmony_ci+ } else { 16994be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 16995be168c0dSopenharmony_ci+ non_const_part += " * "; 16996be168c0dSopenharmony_ci+ } 16997be168c0dSopenharmony_ci+ non_const_part += item; 16998be168c0dSopenharmony_ci+ } 16999be168c0dSopenharmony_ci+ } 17000be168c0dSopenharmony_ci+ count_ = std::to_string(const_part) + " * " + non_const_part; 17001be168c0dSopenharmony_ci+ input_data_ = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 17002be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!input_data_.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 
17003be168c0dSopenharmony_ci+ output_data_ = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 17004be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!output_data_.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17005be168c0dSopenharmony_ci+ return RET_OK; 17006be168c0dSopenharmony_ci+} 17007be168c0dSopenharmony_ci+ 17008be168c0dSopenharmony_ci+int ActivationDynamicFP32Coder::DoCode(CoderContext *const context) { 17009be168c0dSopenharmony_ci+ Collect(context, 17010be168c0dSopenharmony_ci+ { 17011be168c0dSopenharmony_ci+ "wrapper/fp32/activation_fp32_wrapper.h", 17012be168c0dSopenharmony_ci+ "nnacl/fp32/activation_fp32.h", 17013be168c0dSopenharmony_ci+ }, 17014be168c0dSopenharmony_ci+ { 17015be168c0dSopenharmony_ci+ "activation_fp32_wrapper.c", 17016be168c0dSopenharmony_ci+ "activation_fp32.c", 17017be168c0dSopenharmony_ci+ }); 17018be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 17019be168c0dSopenharmony_ci+ auto *activation_parameter = reinterpret_cast<ActivationParameter *>(parameter_); 17020be168c0dSopenharmony_ci+ int ret = Preprocess(); 17021be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, RET_ERROR, "Preprocess failed"); 17022be168c0dSopenharmony_ci+ 17023be168c0dSopenharmony_ci+ switch (activation_parameter->type_) { 17024be168c0dSopenharmony_ci+ case schema::ActivationType_RELU: 17025be168c0dSopenharmony_ci+ code.CodeFunction("Fp32Relu", input_data_, count_, output_data_); 17026be168c0dSopenharmony_ci+ break; 17027be168c0dSopenharmony_ci+ case schema::ActivationType_RELU6: 17028be168c0dSopenharmony_ci+ code.CodeFunction("Fp32Relu6", input_data_, count_, output_data_); 17029be168c0dSopenharmony_ci+ break; 17030be168c0dSopenharmony_ci+ case schema::ActivationType_LEAKY_RELU: 17031be168c0dSopenharmony_ci+ code.CodeFunction("LRelu", input_data_, count_, output_data_, activation_parameter->alpha_); 17032be168c0dSopenharmony_ci+ break; 17033be168c0dSopenharmony_ci+ case schema::ActivationType_SIGMOID: 17034be168c0dSopenharmony_ci+ if 
(!support_parallel_) { 17035be168c0dSopenharmony_ci+ code.CodeFunction("Sigmoid", input_data_, count_, output_data_); 17036be168c0dSopenharmony_ci+ } else { 17037be168c0dSopenharmony_ci+ code.CodeStruct("activation_param", *activation_parameter); 17038be168c0dSopenharmony_ci+ code.CodeBaseStruct("ActivationFp32Args", kRunArgs, input_data_, count_, output_data_, 0.0f, 17039be168c0dSopenharmony_ci+ "&activation_param"); 17040be168c0dSopenharmony_ci+ code.CodeFunction(kParallelLaunch, "DoSigmoid", kRunArgsAddr, "activation_param.op_parameter_.thread_num_"); 17041be168c0dSopenharmony_ci+ } 17042be168c0dSopenharmony_ci+ break; 17043be168c0dSopenharmony_ci+ case schema::ActivationType_TANH: 17044be168c0dSopenharmony_ci+ code.CodeFunction("Tanh", input_data_, count_, output_data_); 17045be168c0dSopenharmony_ci+ break; 17046be168c0dSopenharmony_ci+ case schema::ActivationType_HSWISH: 17047be168c0dSopenharmony_ci+ code.CodeFunction("HSwish", input_data_, count_, output_data_); 17048be168c0dSopenharmony_ci+ break; 17049be168c0dSopenharmony_ci+ case schema::ActivationType_SWISH: 17050be168c0dSopenharmony_ci+ code.CodeFunction("Swish", input_data_, count_, output_data_); 17051be168c0dSopenharmony_ci+ break; 17052be168c0dSopenharmony_ci+ case schema::ActivationType_HSIGMOID: 17053be168c0dSopenharmony_ci+ code.CodeFunction("HSigmoid", input_data_, count_, output_data_); 17054be168c0dSopenharmony_ci+ break; 17055be168c0dSopenharmony_ci+ case schema::ActivationType_ELU: 17056be168c0dSopenharmony_ci+ code.CodeFunction("Elu", input_data_, count_, output_data_, activation_parameter->alpha_); 17057be168c0dSopenharmony_ci+ break; 17058be168c0dSopenharmony_ci+ default: 17059be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Activation type error"; 17060be168c0dSopenharmony_ci+ return RET_ERROR; 17061be168c0dSopenharmony_ci+ } 17062be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "ActivationFP32Code has been called"; 17063be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 
17064be168c0dSopenharmony_ci+ return lite::RET_OK; 17065be168c0dSopenharmony_ci+} 17066be168c0dSopenharmony_ci+ 17067be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Activation, 17068be168c0dSopenharmony_ci+ CPUOpCoderCreator<ActivationDynamicFP32Coder>) 17069be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 17070be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h 17071be168c0dSopenharmony_cinew file mode 100644 17072be168c0dSopenharmony_ciindex 00000000..1560afbb 17073be168c0dSopenharmony_ci--- /dev/null 17074be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/activation_dynamic_fp32_coder.h 17075be168c0dSopenharmony_ci@@ -0,0 +1,46 @@ 17076be168c0dSopenharmony_ci+/** 17077be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 17078be168c0dSopenharmony_ci+ * 17079be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 17080be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 17081be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 17082be168c0dSopenharmony_ci+ * 17083be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 17084be168c0dSopenharmony_ci+ * 17085be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 17086be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 17087be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17088be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 17089be168c0dSopenharmony_ci+ * limitations under the License. 
17090be168c0dSopenharmony_ci+ */ 17091be168c0dSopenharmony_ci+ 17092be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_ACTIVATION_DYNAMIC_FP32_CODER_H_ 17093be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_ACTIVATION_DYNAMIC_FP32_CODER_H_ 17094be168c0dSopenharmony_ci+ 17095be168c0dSopenharmony_ci+#include <string> 17096be168c0dSopenharmony_ci+#include <vector> 17097be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/opcoders/op_coder.h" 17098be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/shape_info_container.h" 17099be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 17100be168c0dSopenharmony_ci+ 17101be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 17102be168c0dSopenharmony_ci+class ActivationDynamicFP32Coder : public OperatorCoder { 17103be168c0dSopenharmony_ci+ public: 17104be168c0dSopenharmony_ci+ ActivationDynamicFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 17105be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 17106be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 17107be168c0dSopenharmony_ci+ 17108be168c0dSopenharmony_ci+ ~ActivationDynamicFP32Coder() override = default; 17109be168c0dSopenharmony_ci+ 17110be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override { return RET_OK; } 17111be168c0dSopenharmony_ci+ 17112be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 17113be168c0dSopenharmony_ci+ 17114be168c0dSopenharmony_ci+ protected: 17115be168c0dSopenharmony_ci+ int Preprocess(); 17116be168c0dSopenharmony_ci+ std::string count_; 17117be168c0dSopenharmony_ci+ std::string input_data_; 17118be168c0dSopenharmony_ci+ std::string output_data_; 17119be168c0dSopenharmony_ci+}; 17120be168c0dSopenharmony_ci+} // namespace 
mindspore::lite::micro::nnacl 17121be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_ACTIVATION_DYNAMIC_FP32_CODER_H_ 17122be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 17123be168c0dSopenharmony_ciindex c15d3101..1b827283 100644 17124be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 17125be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/convolution_winograd_fp32_coder.cc 17126be168c0dSopenharmony_ci@@ -266,7 +266,6 @@ void ConvolutionWinogradFP32Coder::CollectFilesForFunc(CoderContext *const conte 17127be168c0dSopenharmony_ci } else if (target_ == kARM64) { 17128be168c0dSopenharmony_ci Collect(context, {}, {}, 17129be168c0dSopenharmony_ci { 17130be168c0dSopenharmony_ci- "BigMatmulFp32Opt.S", 17131be168c0dSopenharmony_ci "MatmulFp32.S", 17132be168c0dSopenharmony_ci "MatmulFp32Opt.S", 17133be168c0dSopenharmony_ci "PreSum4x16Int8Peroc.S", 17134be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc 17135be168c0dSopenharmony_cinew file mode 100644 17136be168c0dSopenharmony_ciindex 00000000..57d7a5dd 17137be168c0dSopenharmony_ci--- /dev/null 17138be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.cc 17139be168c0dSopenharmony_ci@@ -0,0 +1,106 @@ 17140be168c0dSopenharmony_ci+/** 17141be168c0dSopenharmony_ci+ * Copyright 2021-2022 Huawei Technologies Co., Ltd 17142be168c0dSopenharmony_ci+ * 17143be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 
17144be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 17145be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 17146be168c0dSopenharmony_ci+ * 17147be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 17148be168c0dSopenharmony_ci+ * 17149be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 17150be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 17151be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17152be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 17153be168c0dSopenharmony_ci+ * limitations under the License. 17154be168c0dSopenharmony_ci+ */ 17155be168c0dSopenharmony_ci+ 17156be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h" 17157be168c0dSopenharmony_ci+#include <string> 17158be168c0dSopenharmony_ci+#include "nnacl/gather_parameter.h" 17159be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 17160be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 17161be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 17162be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 17163be168c0dSopenharmony_ci+ 17164be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Gather; 17165be168c0dSopenharmony_ci+ 17166be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 17167be168c0dSopenharmony_ci+int GatherDynamicFP32Coder::Prepare(CoderContext *const context) { 17168be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_.size() == C3NUM, RET_ERROR, "Gather's input-num must be 3."); 17169be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[FIRST_INPUT]->IsConst() && input_tensors_[THIRD_INPUT]->IsConst(), RET_NOT_SUPPORT, 17170be168c0dSopenharmony_ci+ "Currently, only support the second 
input of gather is non-const when shape is dynamical."); 17171be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[THIRD_INPUT]->data_type() == kNumberTypeInt32 || 17172be168c0dSopenharmony_ci+ input_tensors_[THIRD_INPUT]->data_type() == kNumberTypeInt, 17173be168c0dSopenharmony_ci+ RET_ERROR, "The data-type of Gather's third input must be int."); 17174be168c0dSopenharmony_ci+ auto axis = input_tensors_[THIRD_INPUT]->data(); 17175be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(axis != nullptr, RET_NULL_PTR, "Gather has no axis."); 17176be168c0dSopenharmony_ci+ axis_ = *(static_cast<int *>(axis)); 17177be168c0dSopenharmony_ci+ auto in_shape0 = input_tensors_[FIRST_INPUT]->shape(); 17178be168c0dSopenharmony_ci+ axis_ = axis_ >= 0 ? axis_ : axis_ + static_cast<int>(in_shape0.size()); 17179be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(axis_ >= 0 && axis_ < static_cast<int>(in_shape0.size()), RET_INPUT_TENSOR_ERROR, 17180be168c0dSopenharmony_ci+ "Gather's axis is out of range."); 17181be168c0dSopenharmony_ci+ return RET_OK; 17182be168c0dSopenharmony_ci+} 17183be168c0dSopenharmony_ci+ 17184be168c0dSopenharmony_ci+int GatherDynamicFP32Coder::DoCode(CoderContext *const context) { 17185be168c0dSopenharmony_ci+ Collect(context, 17186be168c0dSopenharmony_ci+ { 17187be168c0dSopenharmony_ci+ "nnacl/base/gather_base.h", 17188be168c0dSopenharmony_ci+ }, 17189be168c0dSopenharmony_ci+ { 17190be168c0dSopenharmony_ci+ "gather_base.c", 17191be168c0dSopenharmony_ci+ }); 17192be168c0dSopenharmony_ci+ auto in_shape0 = input_tensors_[FIRST_INPUT]->shape(); 17193be168c0dSopenharmony_ci+ auto data_item_size = static_cast<int>(lite::DataTypeSize(input_tensors_[FIRST_INPUT]->data_type())); 17194be168c0dSopenharmony_ci+ int64_t out_size = 1; 17195be168c0dSopenharmony_ci+ for (size_t i = 0; i < static_cast<size_t>(axis_); ++i) { 17196be168c0dSopenharmony_ci+ out_size *= in_shape0[i]; 17197be168c0dSopenharmony_ci+ } 17198be168c0dSopenharmony_ci+ int64_t byte_inner_size = data_item_size; 
17199be168c0dSopenharmony_ci+ for (size_t i = axis_ + 1; i < in_shape0.size(); ++i) { 17200be168c0dSopenharmony_ci+ byte_inner_size *= in_shape0[i]; 17201be168c0dSopenharmony_ci+ } 17202be168c0dSopenharmony_ci+ int64_t limit = in_shape0[axis_]; 17203be168c0dSopenharmony_ci+ auto in_shape1 = shape_info_container_->GetTemplateShape(input_tensors_[SECOND_INPUT]); 17204be168c0dSopenharmony_ci+ int64_t const_part = 1; 17205be168c0dSopenharmony_ci+ std::string non_const_part; 17206be168c0dSopenharmony_ci+ for (const auto &item : in_shape1) { 17207be168c0dSopenharmony_ci+ if (IsNumber(item)) { 17208be168c0dSopenharmony_ci+ const_part *= std::stoi(item); 17209be168c0dSopenharmony_ci+ } else { 17210be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 17211be168c0dSopenharmony_ci+ non_const_part += " * "; 17212be168c0dSopenharmony_ci+ } 17213be168c0dSopenharmony_ci+ non_const_part += item; 17214be168c0dSopenharmony_ci+ } 17215be168c0dSopenharmony_ci+ } 17216be168c0dSopenharmony_ci+ std::string byte_out_stride_str = std::to_string(const_part * byte_inner_size); 17217be168c0dSopenharmony_ci+ std::string index_num_str = std::to_string(const_part); 17218be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 17219be168c0dSopenharmony_ci+ byte_out_stride_str += " * " + non_const_part; 17220be168c0dSopenharmony_ci+ index_num_str += " * " + non_const_part; 17221be168c0dSopenharmony_ci+ } 17222be168c0dSopenharmony_ci+ std::string input0_data = MemoryAllocator::GetInstance()->GetRuntimeAddr(input_tensors_[FIRST_INPUT], true); 17223be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!input0_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17224be168c0dSopenharmony_ci+ std::string input1_data = dynamic_mem_manager_->GetVarTensorAddr(input_tensors_[SECOND_INPUT]); 17225be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!input1_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17226be168c0dSopenharmony_ci+ std::string output_data = 
dynamic_mem_manager_->GetVarTensorAddr(output_tensors_[FIRST_INPUT]); 17227be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!output_data.empty(), RET_ERROR, "pointer is not allocated by the allocator"); 17228be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 17229be168c0dSopenharmony_ci+ code << "\t\tconst int8_t *int8_in = (const int8_t *)(" << input0_data << ");\n"; 17230be168c0dSopenharmony_ci+ code << "\t\tconst int *index_data = (const int *)(" << input1_data << ");\n"; 17231be168c0dSopenharmony_ci+ code << "\t\tint8_t *int8_out = (int8_t *)(" << output_data << ");\n"; 17232be168c0dSopenharmony_ci+ // call the op function 17233be168c0dSopenharmony_ci+ code.CodeFunction("Gather", "int8_in", out_size, byte_inner_size, limit, "index_data", index_num_str, "int8_out", 17234be168c0dSopenharmony_ci+ byte_out_stride_str); 17235be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 17236be168c0dSopenharmony_ci+ return RET_OK; 17237be168c0dSopenharmony_ci+} 17238be168c0dSopenharmony_ci+ 17239be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Gather, 17240be168c0dSopenharmony_ci+ CPUOpCoderCreator<GatherDynamicFP32Coder>) 17241be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Gather, 17242be168c0dSopenharmony_ci+ CPUOpCoderCreator<GatherDynamicFP32Coder>) 17243be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat16, PrimitiveType_Gather, CPUOpCoderCreator<GatherDynamicFP32Coder>) 17244be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM32, kNumberTypeFloat16, PrimitiveType_Gather, CPUOpCoderCreator<GatherDynamicFP32Coder>) 17245be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 17246be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h 17247be168c0dSopenharmony_cinew file mode 
100644 17248be168c0dSopenharmony_ciindex 00000000..9e58e1fa 17249be168c0dSopenharmony_ci--- /dev/null 17250be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/gather_dynamic_fp32_coder.h 17251be168c0dSopenharmony_ci@@ -0,0 +1,42 @@ 17252be168c0dSopenharmony_ci+/** 17253be168c0dSopenharmony_ci+ * Copyright 2021 Huawei Technologies Co., Ltd 17254be168c0dSopenharmony_ci+ * 17255be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 17256be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 17257be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 17258be168c0dSopenharmony_ci+ * 17259be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 17260be168c0dSopenharmony_ci+ * 17261be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 17262be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 17263be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17264be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 17265be168c0dSopenharmony_ci+ * limitations under the License. 
17266be168c0dSopenharmony_ci+ */ 17267be168c0dSopenharmony_ci+ 17268be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_DYNAMIC_FP32_CODER_H_ 17269be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_DYNAMIC_FP32_CODER_H_ 17270be168c0dSopenharmony_ci+ 17271be168c0dSopenharmony_ci+#include <string> 17272be168c0dSopenharmony_ci+#include <vector> 17273be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 17274be168c0dSopenharmony_ci+#include "nnacl/base/tile_base.h" 17275be168c0dSopenharmony_ci+ 17276be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 17277be168c0dSopenharmony_ci+class GatherDynamicFP32Coder final : public OperatorCoder { 17278be168c0dSopenharmony_ci+ public: 17279be168c0dSopenharmony_ci+ GatherDynamicFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 17280be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 17281be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 17282be168c0dSopenharmony_ci+ 17283be168c0dSopenharmony_ci+ ~GatherDynamicFP32Coder() override = default; 17284be168c0dSopenharmony_ci+ 17285be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 17286be168c0dSopenharmony_ci+ 17287be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 17288be168c0dSopenharmony_ci+ 17289be168c0dSopenharmony_ci+ private: 17290be168c0dSopenharmony_ci+ int axis_{0}; 17291be168c0dSopenharmony_ci+}; 17292be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 17293be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_GATHER_DYNAMIC_FP32_CODER_H_ 17294be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc 
b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc 17295be168c0dSopenharmony_cinew file mode 100644 17296be168c0dSopenharmony_ciindex 00000000..4ec7f317 17297be168c0dSopenharmony_ci--- /dev/null 17298be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.cc 17299be168c0dSopenharmony_ci@@ -0,0 +1,94 @@ 17300be168c0dSopenharmony_ci+/** 17301be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd 17302be168c0dSopenharmony_ci+ * 17303be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 17304be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 17305be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 17306be168c0dSopenharmony_ci+ * 17307be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 17308be168c0dSopenharmony_ci+ * 17309be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 17310be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 17311be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17312be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 17313be168c0dSopenharmony_ci+ * limitations under the License. 
17314be168c0dSopenharmony_ci+ */ 17315be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h" 17316be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 17317be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 17318be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 17319be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 17320be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 17321be168c0dSopenharmony_ci+ 17322be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Split; 17323be168c0dSopenharmony_ci+ 17324be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 17325be168c0dSopenharmony_ci+int SplitDynamicFP32Coder::Prepare(CoderContext *const context) { 17326be168c0dSopenharmony_ci+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 17327be168c0dSopenharmony_ci+ int in_shape_size = static_cast<int>(input_shape.size()); 17328be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(in_shape_size, 1); 17329be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(SPLIT_STRIDES_SIZE - 1, in_shape_size); 17330be168c0dSopenharmony_ci+ param_ = reinterpret_cast<SplitParameter *>(parameter_); 17331be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(param_); 17332be168c0dSopenharmony_ci+ 17333be168c0dSopenharmony_ci+ auto split_dim = param_->split_dim_; 17334be168c0dSopenharmony_ci+ param_->split_dim_ = split_dim >= 0 ? 
split_dim : in_shape_size + split_dim; 17335be168c0dSopenharmony_ci+ std::vector<std::string> strides(in_shape_size); 17336be168c0dSopenharmony_ci+ strides[in_shape_size - 1] = "1"; 17337be168c0dSopenharmony_ci+ for (int i = static_cast<int>(in_shape_size) - C2NUM; i >= 0; i--) { 17338be168c0dSopenharmony_ci+ strides[i] = strides[i + 1] + " * " + input_shape[i + 1]; 17339be168c0dSopenharmony_ci+ } 17340be168c0dSopenharmony_ci+ dynamic_param_.strides_ = "{"; 17341be168c0dSopenharmony_ci+ for (int i = 0; i < in_shape_size; ++i) { 17342be168c0dSopenharmony_ci+ dynamic_param_.strides_ += strides[i] + ", "; 17343be168c0dSopenharmony_ci+ } 17344be168c0dSopenharmony_ci+ dynamic_param_.strides_ += "}"; 17345be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(in_shape_size, param_->split_dim_ + 1); 17346be168c0dSopenharmony_ci+ if (input_shape.at(param_->split_dim_) == "0") { 17347be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "input_shape[" << param_->split_dim_ << "] must not be zero!"; 17348be168c0dSopenharmony_ci+ return RET_ERROR; 17349be168c0dSopenharmony_ci+ } 17350be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(SPLIT_STRIDES_SIZE, param_->split_dim_ + 1); 17351be168c0dSopenharmony_ci+ if (strides[param_->split_dim_] == "0") { 17352be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "strides[" << param_->split_dim_ << "] must not be zero!"; 17353be168c0dSopenharmony_ci+ return RET_ERROR; 17354be168c0dSopenharmony_ci+ } 17355be168c0dSopenharmony_ci+ dynamic_param_.split_count_ = strides[0] + " * " + input_shape[0] + " / (" + input_shape.at(param_->split_dim_) + 17356be168c0dSopenharmony_ci+ " * " + strides[param_->split_dim_] + ")"; 17357be168c0dSopenharmony_ci+ param_->n_dims_ = static_cast<int>(input_shape.size()); 17358be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(param_->num_split_, 1); 17359be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(param_->split_sizes_[0] != 0 && param_->split_sizes_[param_->num_split_ - 1] != -1, 17360be168c0dSopenharmony_ci+ lite::RET_PARAM_INVALID, "Currently, split not 
support split_size 0 or -1"); 17361be168c0dSopenharmony_ci+ return RET_OK; 17362be168c0dSopenharmony_ci+} 17363be168c0dSopenharmony_ci+ 17364be168c0dSopenharmony_ci+int SplitDynamicFP32Coder::DoCode(CoderContext *const context) { 17365be168c0dSopenharmony_ci+ Collect(context, {"nnacl/base/split_base.h"}, {"split_base.c"}); 17366be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 17367be168c0dSopenharmony_ci+ code << " void *output_ptrs[" << output_tensors_.size() << "] = {"; 17368be168c0dSopenharmony_ci+ for (int i = 0; i < param_->num_split_; i++) { 17369be168c0dSopenharmony_ci+ code << GetTensorAddr(output_tensors_.at(i), output_tensors_.at(i)->IsConst(), dynamic_mem_manager_, allocator_) 17370be168c0dSopenharmony_ci+ << ", "; 17371be168c0dSopenharmony_ci+ } 17372be168c0dSopenharmony_ci+ code << "};\n"; 17373be168c0dSopenharmony_ci+ auto input_shape = shape_info_container_->GetTemplateShape(input_tensor_); 17374be168c0dSopenharmony_ci+ code << " int input_dim[" << input_shape.size() << "] = {"; 17375be168c0dSopenharmony_ci+ for (auto &dim : input_shape) { 17376be168c0dSopenharmony_ci+ code << dim << ", "; 17377be168c0dSopenharmony_ci+ } 17378be168c0dSopenharmony_ci+ code << "};\n"; 17379be168c0dSopenharmony_ci+ std::string input_data = GetTensorAddr(input_tensor_, input_tensor_->IsConst(), dynamic_mem_manager_, allocator_); 17380be168c0dSopenharmony_ci+ std::string num_unit = dynamic_param_.split_count_ + " * " + std::to_string(param_->num_split_); 17381be168c0dSopenharmony_ci+ code.CodeStruct("split_param", *param_, dynamic_param_); 17382be168c0dSopenharmony_ci+ code.CodeFunction("DoSplit", input_data, "output_ptrs", "input_dim", "0", num_unit, "&split_param", 17383be168c0dSopenharmony_ci+ lite::DataTypeSize(input_tensor_->data_type())); 17384be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 17385be168c0dSopenharmony_ci+ return RET_OK; 17386be168c0dSopenharmony_ci+} 17387be168c0dSopenharmony_ci+ 
17388be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat32, PrimitiveType_Split, 17389be168c0dSopenharmony_ci+ CPUOpCoderCreator<SplitDynamicFP32Coder>) 17390be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeInt32, PrimitiveType_Split, CPUOpCoderCreator<SplitDynamicFP32Coder>) 17391be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kAllTargets, kNumberTypeFloat16, PrimitiveType_Split, 17392be168c0dSopenharmony_ci+ CPUOpCoderCreator<SplitDynamicFP32Coder>) 17393be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 17394be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h 17395be168c0dSopenharmony_cinew file mode 100644 17396be168c0dSopenharmony_ciindex 00000000..e3e64cb3 17397be168c0dSopenharmony_ci--- /dev/null 17398be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/split_dynamic_fp32_coder.h 17399be168c0dSopenharmony_ci@@ -0,0 +1,42 @@ 17400be168c0dSopenharmony_ci+/** 17401be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 17402be168c0dSopenharmony_ci+ * 17403be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 17404be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 17405be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 17406be168c0dSopenharmony_ci+ * 17407be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 17408be168c0dSopenharmony_ci+ * 17409be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 17410be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 17411be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
17412be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 17413be168c0dSopenharmony_ci+ * limitations under the License. 17414be168c0dSopenharmony_ci+ */ 17415be168c0dSopenharmony_ci+ 17416be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_SPLIT_DYNAMIC_FP32_CODER_H_ 17417be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_SPLIT_DYNAMIC_FP32_CODER_H_ 17418be168c0dSopenharmony_ci+ 17419be168c0dSopenharmony_ci+#include <vector> 17420be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 17421be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h" 17422be168c0dSopenharmony_ci+#include "nnacl/split_parameter.h" 17423be168c0dSopenharmony_ci+ 17424be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 17425be168c0dSopenharmony_ci+class SplitDynamicFP32Coder : public OperatorCoder { 17426be168c0dSopenharmony_ci+ public: 17427be168c0dSopenharmony_ci+ SplitDynamicFP32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 17428be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 17429be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 17430be168c0dSopenharmony_ci+ ~SplitDynamicFP32Coder() override = default; 17431be168c0dSopenharmony_ci+ 17432be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 17433be168c0dSopenharmony_ci+ 17434be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 17435be168c0dSopenharmony_ci+ 17436be168c0dSopenharmony_ci+ protected: 17437be168c0dSopenharmony_ci+ SplitParameter *param_{nullptr}; 17438be168c0dSopenharmony_ci+ SplitDynamicParameter dynamic_param_; 17439be168c0dSopenharmony_ci+}; 17440be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 17441be168c0dSopenharmony_ci+#endif // 
MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_SPLIT_DYNAMIC_FP32_CODER_H_ 17442be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc 17443be168c0dSopenharmony_cinew file mode 100644 17444be168c0dSopenharmony_ciindex 00000000..7fb160d5 17445be168c0dSopenharmony_ci--- /dev/null 17446be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.cc 17447be168c0dSopenharmony_ci@@ -0,0 +1,171 @@ 17448be168c0dSopenharmony_ci+/** 17449be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 17450be168c0dSopenharmony_ci+ * 17451be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 17452be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 17453be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 17454be168c0dSopenharmony_ci+ * 17455be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 17456be168c0dSopenharmony_ci+ * 17457be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 17458be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 17459be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17460be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 17461be168c0dSopenharmony_ci+ * limitations under the License. 
17462be168c0dSopenharmony_ci+ */ 17463be168c0dSopenharmony_ci+ 17464be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h" 17465be168c0dSopenharmony_ci+#include <vector> 17466be168c0dSopenharmony_ci+#include <unordered_set> 17467be168c0dSopenharmony_ci+#include "coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h" 17468be168c0dSopenharmony_ci+#include "coder/opcoders/file_collector.h" 17469be168c0dSopenharmony_ci+#include "coder/opcoders/parallel.h" 17470be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 17471be168c0dSopenharmony_ci+ 17472be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Transpose; 17473be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 17474be168c0dSopenharmony_ci+int TransposeDynamicFp32Coder::Prepare(CoderContext *const context) { 17475be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensor_->data_type() == kNumberTypeInt32 || input_tensor_->data_type() == kNumberTypeFloat32, 17476be168c0dSopenharmony_ci+ RET_INPUT_PARAM_INVALID, "Input tensor data type is invalid."); 17477be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->data_type() == kNumberTypeInt32, RET_INPUT_PARAM_INVALID, 17478be168c0dSopenharmony_ci+ "Perm tensor data type is invalid."); 17479be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG( 17480be168c0dSopenharmony_ci+ output_tensor_->data_type() == kNumberTypeInt32 || output_tensor_->data_type() == kNumberTypeFloat32, 17481be168c0dSopenharmony_ci+ RET_INPUT_PARAM_INVALID, "Output tensor data type is invalid."); 17482be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(input_tensors_[SECOND_INPUT]->IsConst(), RET_NOT_SUPPORT, 17483be168c0dSopenharmony_ci+ "The second input of transpose is non-const."); 17484be168c0dSopenharmony_ci+ thread_num_ = 1; 17485be168c0dSopenharmony_ci+ MS_CHECK_RET_CODE(Init(), "init failed"); 17486be168c0dSopenharmony_ci+ return RET_OK; 17487be168c0dSopenharmony_ci+} 17488be168c0dSopenharmony_ci+ 
17489be168c0dSopenharmony_ci+int TransposeDynamicFp32Coder::DoCode(CoderContext *const context) { 17490be168c0dSopenharmony_ci+ Collect(context, 17491be168c0dSopenharmony_ci+ { 17492be168c0dSopenharmony_ci+ "nnacl/transpose_parameter.h", 17493be168c0dSopenharmony_ci+ "nnacl/errorcode.h", 17494be168c0dSopenharmony_ci+ "nnacl/fp32/transpose_fp32.h", 17495be168c0dSopenharmony_ci+ }, 17496be168c0dSopenharmony_ci+ { 17497be168c0dSopenharmony_ci+ "transpose_fp32.c", 17498be168c0dSopenharmony_ci+ }); 17499be168c0dSopenharmony_ci+ 17500be168c0dSopenharmony_ci+ NNaclFp32Serializer code; 17501be168c0dSopenharmony_ci+ dims_ = static_cast<int>(out_shapes_.size()); 17502be168c0dSopenharmony_ci+ code << "const int32_t output_shape[" << dims_ << "] = {"; 17503be168c0dSopenharmony_ci+ for (size_t i = 0; i < out_shapes_.size(); ++i) { 17504be168c0dSopenharmony_ci+ code << out_shapes_[i] << ", "; 17505be168c0dSopenharmony_ci+ } 17506be168c0dSopenharmony_ci+ code << "};\n"; 17507be168c0dSopenharmony_ci+ code.CodeStruct("trans_param", *param_, dynamic_param_); 17508be168c0dSopenharmony_ci+ auto input_str = dynamic_mem_manager_->GetVarTensorAddr(input_tensor_); 17509be168c0dSopenharmony_ci+ auto output_str = dynamic_mem_manager_->GetVarTensorAddr(output_tensor_); 17510be168c0dSopenharmony_ci+ if (param_->num_axes_ > DIMENSION_6D) { 17511be168c0dSopenharmony_ci+ code.CodeFunction("TransposeDimsFp32", input_str, output_str, "output_shape", "trans_param.perm_", 17512be168c0dSopenharmony_ci+ "trans_param.strides_", "trans_param.out_strides_", "trans_param.num_axes_", kDefaultTaskId, 17513be168c0dSopenharmony_ci+ kDefaultThreadNum); 17514be168c0dSopenharmony_ci+ } else { 17515be168c0dSopenharmony_ci+ code.CodeFunction("DoTransposeFp32", input_str, output_str, "output_shape", "trans_param.perm_", 17516be168c0dSopenharmony_ci+ "trans_param.strides_", "trans_param.out_strides_", "trans_param.data_num_", 17517be168c0dSopenharmony_ci+ "trans_param.num_axes_"); 17518be168c0dSopenharmony_ci+ } 
17519be168c0dSopenharmony_ci+ context->AppendCode(code.str()); 17520be168c0dSopenharmony_ci+ return RET_OK; 17521be168c0dSopenharmony_ci+} 17522be168c0dSopenharmony_ci+ 17523be168c0dSopenharmony_ci+int TransposeDynamicFp32Coder::Init() { 17524be168c0dSopenharmony_ci+ param_ = reinterpret_cast<TransposeParameter *>(parameter_); 17525be168c0dSopenharmony_ci+ MS_CHECK_PTR(param_); 17526be168c0dSopenharmony_ci+ param_->num_axes_ = 0; 17527be168c0dSopenharmony_ci+ if (input_tensors_.size() == C2NUM) { 17528be168c0dSopenharmony_ci+ param_->num_axes_ = input_tensors_[SECOND_INPUT]->ElementsNum(); 17529be168c0dSopenharmony_ci+ } 17530be168c0dSopenharmony_ci+ if (input_tensor_->shape().size() != static_cast<size_t>(param_->num_axes_)) { 17531be168c0dSopenharmony_ci+ return RET_OK; 17532be168c0dSopenharmony_ci+ } 17533be168c0dSopenharmony_ci+ // get perm data 17534be168c0dSopenharmony_ci+ auto ret = ResetStatus(); 17535be168c0dSopenharmony_ci+ if (ret != RET_OK) { 17536be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Do transpose reset failed."; 17537be168c0dSopenharmony_ci+ return ret; 17538be168c0dSopenharmony_ci+ } 17539be168c0dSopenharmony_ci+ 17540be168c0dSopenharmony_ci+ ret = ComputeOfflineInfo(); 17541be168c0dSopenharmony_ci+ if (ret != RET_OK) { 17542be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Do compute transpose offline info failed."; 17543be168c0dSopenharmony_ci+ return ret; 17544be168c0dSopenharmony_ci+ } 17545be168c0dSopenharmony_ci+ return RET_OK; 17546be168c0dSopenharmony_ci+} 17547be168c0dSopenharmony_ci+ 17548be168c0dSopenharmony_ci+int TransposeDynamicFp32Coder::ResetStatus() { 17549be168c0dSopenharmony_ci+ auto in_shape = shape_info_container_->GetTemplateShape(input_tensor_); 17550be168c0dSopenharmony_ci+ if (in_shape.size() > MAX_TRANSPOSE_DIM_SIZE) { 17551be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "input shape out of range."; 17552be168c0dSopenharmony_ci+ return RET_ERROR; 17553be168c0dSopenharmony_ci+ } 17554be168c0dSopenharmony_ci+ int 
trans_nd[MAX_TRANSPOSE_DIM_SIZE] = {0, 2, 1}; 17555be168c0dSopenharmony_ci+ int *perm_data{nullptr}; 17556be168c0dSopenharmony_ci+ if (in_shape.size() != static_cast<size_t>(param_->num_axes_)) { 17557be168c0dSopenharmony_ci+ perm_data = trans_nd; 17558be168c0dSopenharmony_ci+ if (in_shape.size() == C3NUM && param_->num_axes_ == C4NUM) { 17559be168c0dSopenharmony_ci+ param_->num_axes_ = C3NUM; 17560be168c0dSopenharmony_ci+ } 17561be168c0dSopenharmony_ci+ if (param_->num_axes_ == 0) { 17562be168c0dSopenharmony_ci+ for (int i = 0; i < static_cast<int>(in_shape.size()); ++i) { 17563be168c0dSopenharmony_ci+ trans_nd[i] = static_cast<int>(in_shape.size()) - 1 - i; 17564be168c0dSopenharmony_ci+ } 17565be168c0dSopenharmony_ci+ param_->num_axes_ = static_cast<int>(in_shape.size()); 17566be168c0dSopenharmony_ci+ } 17567be168c0dSopenharmony_ci+ } else { 17568be168c0dSopenharmony_ci+ if (input_tensors_.size() != C2NUM) { 17569be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "input tensors size is not equal to 2."; 17570be168c0dSopenharmony_ci+ return RET_ERROR; 17571be168c0dSopenharmony_ci+ } 17572be168c0dSopenharmony_ci+ auto perm_tensor = input_tensors_.at(SECOND_INPUT); 17573be168c0dSopenharmony_ci+ perm_data = reinterpret_cast<int *>(perm_tensor->data()); 17574be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(perm_data); 17575be168c0dSopenharmony_ci+ std::vector<int> perm(perm_data, perm_data + input_tensors_[SECOND_INPUT]->ElementsNum()); 17576be168c0dSopenharmony_ci+ if (perm.size() != std::unordered_set<int>(perm.cbegin(), perm.cend()).size()) { 17577be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid perm, the same element exits in perm."; 17578be168c0dSopenharmony_ci+ return RET_ERROR; 17579be168c0dSopenharmony_ci+ } 17580be168c0dSopenharmony_ci+ } 17581be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(param_->num_axes_ <= MAX_TRANSPOSE_DIM_SIZE, RET_ERROR, "transpose perm is invalid."); 17582be168c0dSopenharmony_ci+ for (int i = 0; i < param_->num_axes_; ++i) { 
17583be168c0dSopenharmony_ci+ param_->perm_[i] = perm_data[i]; 17584be168c0dSopenharmony_ci+ } 17585be168c0dSopenharmony_ci+ return RET_OK; 17586be168c0dSopenharmony_ci+} 17587be168c0dSopenharmony_ci+ 17588be168c0dSopenharmony_ci+int TransposeDynamicFp32Coder::ComputeOfflineInfo() { 17589be168c0dSopenharmony_ci+ in_shapes_ = shape_info_container_->GetTemplateShape(input_tensor_); 17590be168c0dSopenharmony_ci+ out_shapes_ = shape_info_container_->GetTemplateShape(output_tensor_); 17591be168c0dSopenharmony_ci+ const int ori_stride = 1; 17592be168c0dSopenharmony_ci+ dynamic_param_.strides_ = std::to_string(ori_stride) + ", "; 17593be168c0dSopenharmony_ci+ dynamic_param_.out_strides_ = std::to_string(ori_stride) + ", "; 17594be168c0dSopenharmony_ci+ dynamic_param_.data_num_ = AccumulateShape(in_shapes_, 0, in_shapes_.size()); 17595be168c0dSopenharmony_ci+ std::vector<std::string> strides(param_->num_axes_); 17596be168c0dSopenharmony_ci+ std::vector<std::string> out_strides(param_->num_axes_); 17597be168c0dSopenharmony_ci+ strides[param_->num_axes_ - 1] = "1"; 17598be168c0dSopenharmony_ci+ out_strides[param_->num_axes_ - 1] = "1"; 17599be168c0dSopenharmony_ci+ for (int i = param_->num_axes_ - C2NUM; i >= 0; --i) { 17600be168c0dSopenharmony_ci+ strides[i] = in_shapes_[i + 1] + " * " + strides[i + 1]; 17601be168c0dSopenharmony_ci+ out_strides[i] = out_shapes_[i + 1] + " * " + out_strides[i + 1]; 17602be168c0dSopenharmony_ci+ } 17603be168c0dSopenharmony_ci+ dynamic_param_.strides_ = "{"; 17604be168c0dSopenharmony_ci+ dynamic_param_.out_strides_ = "{"; 17605be168c0dSopenharmony_ci+ for (int i = 0; i < param_->num_axes_; ++i) { 17606be168c0dSopenharmony_ci+ dynamic_param_.strides_ += strides[i] + ", "; 17607be168c0dSopenharmony_ci+ dynamic_param_.out_strides_ += out_strides[i] + ", "; 17608be168c0dSopenharmony_ci+ } 17609be168c0dSopenharmony_ci+ dynamic_param_.strides_ += "}"; 17610be168c0dSopenharmony_ci+ dynamic_param_.out_strides_ += "}"; 17611be168c0dSopenharmony_ci+ 
return RET_OK; 17612be168c0dSopenharmony_ci+} 17613be168c0dSopenharmony_ci+ 17614be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeFloat32, PrimitiveType_Transpose, 17615be168c0dSopenharmony_ci+ CPUOpCoderCreator<TransposeDynamicFp32Coder>) 17616be168c0dSopenharmony_ci+REG_DYNAMIC_OPERATOR_CODER(kARM64, kNumberTypeInt32, PrimitiveType_Transpose, 17617be168c0dSopenharmony_ci+ CPUOpCoderCreator<TransposeDynamicFp32Coder>) 17618be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 17619be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h 17620be168c0dSopenharmony_cinew file mode 100644 17621be168c0dSopenharmony_ciindex 00000000..9230b8e3 17622be168c0dSopenharmony_ci--- /dev/null 17623be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/nnacl/fp32/transpose_dynamic_fp32_coder.h 17624be168c0dSopenharmony_ci@@ -0,0 +1,49 @@ 17625be168c0dSopenharmony_ci+/** 17626be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 17627be168c0dSopenharmony_ci+ * 17628be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 17629be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 17630be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 17631be168c0dSopenharmony_ci+ * 17632be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 17633be168c0dSopenharmony_ci+ * 17634be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 17635be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 17636be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
17637be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 17638be168c0dSopenharmony_ci+ * limitations under the License. 17639be168c0dSopenharmony_ci+ */ 17640be168c0dSopenharmony_ci+ 17641be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_TRANSPOSE_DYNAMIC_FP32_CODER_H_ 17642be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_TRANSPOSE_DYNAMIC_FP32_CODER_H_ 17643be168c0dSopenharmony_ci+#include <vector> 17644be168c0dSopenharmony_ci+#include <string> 17645be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 17646be168c0dSopenharmony_ci+#include "nnacl/transpose_parameter.h" 17647be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h" 17648be168c0dSopenharmony_ci+ 17649be168c0dSopenharmony_ci+namespace mindspore::lite::micro::nnacl { 17650be168c0dSopenharmony_ci+class TransposeDynamicFp32Coder : public OperatorCoder { 17651be168c0dSopenharmony_ci+ public: 17652be168c0dSopenharmony_ci+ TransposeDynamicFp32Coder(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, 17653be168c0dSopenharmony_ci+ const LiteGraph::Node *node, size_t node_index, Target target) 17654be168c0dSopenharmony_ci+ : OperatorCoder(in_tensors, out_tensors, node, node_index, target) {} 17655be168c0dSopenharmony_ci+ 17656be168c0dSopenharmony_ci+ ~TransposeDynamicFp32Coder() override = default; 17657be168c0dSopenharmony_ci+ 17658be168c0dSopenharmony_ci+ int Prepare(CoderContext *const context) override; 17659be168c0dSopenharmony_ci+ 17660be168c0dSopenharmony_ci+ int DoCode(CoderContext *const context) override; 17661be168c0dSopenharmony_ci+ 17662be168c0dSopenharmony_ci+ protected: 17663be168c0dSopenharmony_ci+ int Init(); 17664be168c0dSopenharmony_ci+ int ResetStatus(); 17665be168c0dSopenharmony_ci+ int ComputeOfflineInfo(); 17666be168c0dSopenharmony_ci+ TransposeParameter *param_{nullptr}; 
17667be168c0dSopenharmony_ci+ TransposeDynamicParameter dynamic_param_; 17668be168c0dSopenharmony_ci+ int dims_{0}; 17669be168c0dSopenharmony_ci+ std::vector<std::string> in_shapes_; 17670be168c0dSopenharmony_ci+ std::vector<std::string> out_shapes_; 17671be168c0dSopenharmony_ci+}; 17672be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro::nnacl 17673be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_NNACL_FP32_TRANSPOSE_DYNAMIC_FP32_CODER_H_ 17674be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h 17675be168c0dSopenharmony_ciindex dffaf14b..fa59e483 100644 17676be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h 17677be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder.h 17678be168c0dSopenharmony_ci@@ -28,6 +28,8 @@ 17679be168c0dSopenharmony_ci #include "securec/include/securec.h" 17680be168c0dSopenharmony_ci #include "tools/converter/micro/coder/opcoders/op_coder_register.h" 17681be168c0dSopenharmony_ci #include "tools/converter/micro/coder/log.h" 17682be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/shape_info_container.h" 17683be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/dynamic_mem_manager.h" 17684be168c0dSopenharmony_ci 17685be168c0dSopenharmony_ci namespace mindspore::lite::micro { 17686be168c0dSopenharmony_ci constexpr int kPrecision = 19; 17687be168c0dSopenharmony_ci@@ -71,6 +73,8 @@ class OperatorCoder { 17688be168c0dSopenharmony_ci 17689be168c0dSopenharmony_ci void set_parameter(OpParameter *parameter); 17690be168c0dSopenharmony_ci 17691be168c0dSopenharmony_ci+ OpParameter *get_parameter() const { return parameter_; } 17692be168c0dSopenharmony_ci+ 17693be168c0dSopenharmony_ci const LiteGraph::Node *node() const { return this->node_; } 17694be168c0dSopenharmony_ci 17695be168c0dSopenharmony_ci void 
AddInitialParameters(Tensor *parameter) { initial_parameters_.push_back(parameter); } 17696be168c0dSopenharmony_ci@@ -88,6 +92,12 @@ class OperatorCoder { 17697be168c0dSopenharmony_ci 17698be168c0dSopenharmony_ci void set_thread_num(int thread_num); 17699be168c0dSopenharmony_ci 17700be168c0dSopenharmony_ci+ void set_shape_info_container(ShapeInfoContainer *shape_info_container) { 17701be168c0dSopenharmony_ci+ shape_info_container_ = shape_info_container; 17702be168c0dSopenharmony_ci+ } 17703be168c0dSopenharmony_ci+ 17704be168c0dSopenharmony_ci+ void set_dynamic_mem_manager(DynamicMemManager *dynamic_mem_manager) { dynamic_mem_manager_ = dynamic_mem_manager; } 17705be168c0dSopenharmony_ci+ 17706be168c0dSopenharmony_ci protected: 17707be168c0dSopenharmony_ci std::vector<Tensor *> input_tensors_; 17708be168c0dSopenharmony_ci std::vector<Tensor *> output_tensors_; 17709be168c0dSopenharmony_ci@@ -103,6 +113,8 @@ class OperatorCoder { 17710be168c0dSopenharmony_ci bool support_parallel_{false}; 17711be168c0dSopenharmony_ci int thread_num_{1}; 17712be168c0dSopenharmony_ci int schema_version_ = lite::SCHEMA_VERSION::SCHEMA_CUR; 17713be168c0dSopenharmony_ci+ ShapeInfoContainer *shape_info_container_{nullptr}; 17714be168c0dSopenharmony_ci+ DynamicMemManager *dynamic_mem_manager_{nullptr}; 17715be168c0dSopenharmony_ci 17716be168c0dSopenharmony_ci private: 17717be168c0dSopenharmony_ci size_t node_index_{0}; 17718be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc 17719be168c0dSopenharmony_ciindex 45b2e37f..e2d70c12 100644 17720be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc 17721be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.cc 17722be168c0dSopenharmony_ci@@ -35,7 +35,7 @@ std::unique_ptr<OperatorCoder> OpCoderBuilder::build(int schema_version) { 
17723be168c0dSopenharmony_ci } 17724be168c0dSopenharmony_ci coder_key = CoderKey(target_, data_type_, schema::PrimitiveType_Custom, custom_type->str()); 17725be168c0dSopenharmony_ci } 17726be168c0dSopenharmony_ci- CoderCreatorFunc creator_func = OpCoderFactory::GetInstance()->FindOpCoder(coder_key); 17727be168c0dSopenharmony_ci+ CoderCreatorFunc creator_func = OpCoderFactory::GetInstance()->FindOpCoder(coder_key, dynamic_); 17728be168c0dSopenharmony_ci if (creator_func == nullptr) { 17729be168c0dSopenharmony_ci MS_LOG(ERROR) << "caught unsupported layer: " << node_->name_; 17730be168c0dSopenharmony_ci return nullptr; 17731be168c0dSopenharmony_ci@@ -125,5 +125,10 @@ OpCoderBuilder &OpCoderBuilder::is_builtin_custom(bool builtin_custom) { 17732be168c0dSopenharmony_ci return *this; 17733be168c0dSopenharmony_ci } 17734be168c0dSopenharmony_ci 17735be168c0dSopenharmony_ci+OpCoderBuilder &OpCoderBuilder::is_dynamic(bool dynamic) { 17736be168c0dSopenharmony_ci+ dynamic_ = dynamic; 17737be168c0dSopenharmony_ci+ return *this; 17738be168c0dSopenharmony_ci+} 17739be168c0dSopenharmony_ci+ 17740be168c0dSopenharmony_ci void OpCoderBuilder::Reset() {} 17741be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 17742be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h 17743be168c0dSopenharmony_ciindex d85f1c32..bdd815ef 100644 17744be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h 17745be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_builder.h 17746be168c0dSopenharmony_ci@@ -48,6 +48,8 @@ class OpCoderBuilder { 17747be168c0dSopenharmony_ci 17748be168c0dSopenharmony_ci OpCoderBuilder &is_builtin_custom(bool builtin_custom); 17749be168c0dSopenharmony_ci 17750be168c0dSopenharmony_ci+ OpCoderBuilder &is_dynamic(bool dynamic); 17751be168c0dSopenharmony_ci+ 
17752be168c0dSopenharmony_ci void Reset(); 17753be168c0dSopenharmony_ci 17754be168c0dSopenharmony_ci private: 17755be168c0dSopenharmony_ci@@ -74,6 +76,8 @@ class OpCoderBuilder { 17756be168c0dSopenharmony_ci bool support_parallel_{false}; 17757be168c0dSopenharmony_ci 17758be168c0dSopenharmony_ci bool builtin_custom_{false}; 17759be168c0dSopenharmony_ci+ 17760be168c0dSopenharmony_ci+ bool dynamic_{false}; 17761be168c0dSopenharmony_ci }; 17762be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 17763be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_BUILDER_H_ 17764be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc 17765be168c0dSopenharmony_ciindex cf26d51d..1dac9c73 100644 17766be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc 17767be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.cc 17768be168c0dSopenharmony_ci@@ -37,33 +37,38 @@ OpCoderFactory *OpCoderFactory::GetInstance() { 17769be168c0dSopenharmony_ci } 17770be168c0dSopenharmony_ci 17771be168c0dSopenharmony_ci int OpCoderFactory::RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type, 17772be168c0dSopenharmony_ci- const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func) { 17773be168c0dSopenharmony_ci+ const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func, 17774be168c0dSopenharmony_ci+ bool dynamic) { 17775be168c0dSopenharmony_ci+ auto &op_sets = dynamic ? 
dynamic_opcoder_sets_ : static_opcoder_sets_; 17776be168c0dSopenharmony_ci // check key 17777be168c0dSopenharmony_ci CoderKey key(target, data_type, operator_type, builtin_custom_type); 17778be168c0dSopenharmony_ci // insert pair to registry 17779be168c0dSopenharmony_ci- if (this->opcoder_sets_.find(key) != this->opcoder_sets_.end()) { 17780be168c0dSopenharmony_ci+ if (op_sets.find(key) != op_sets.end()) { 17781be168c0dSopenharmony_ci MS_LOG(ERROR) << "coder already exist: " << key.ToString(); 17782be168c0dSopenharmony_ci return RET_ERROR; 17783be168c0dSopenharmony_ci } 17784be168c0dSopenharmony_ci- this->opcoder_sets_.insert(std::pair<CoderKey, CoderCreatorFunc>(key, creator_func)); 17785be168c0dSopenharmony_ci+ op_sets.insert(std::pair<CoderKey, CoderCreatorFunc>(key, creator_func)); 17786be168c0dSopenharmony_ci return RET_OK; 17787be168c0dSopenharmony_ci } 17788be168c0dSopenharmony_ci 17789be168c0dSopenharmony_ci-CoderCreatorFunc OpCoderFactory::FindOpCoder(const CoderKey &key) { 17790be168c0dSopenharmony_ci- auto iterator = this->opcoder_sets_.find(key); 17791be168c0dSopenharmony_ci- if (iterator != this->opcoder_sets_.end()) { 17792be168c0dSopenharmony_ci+CoderCreatorFunc OpCoderFactory::FindOpCoder(const CoderKey &key, bool dynamic) { 17793be168c0dSopenharmony_ci+ const auto &op_sets = dynamic ? 
dynamic_opcoder_sets_ : static_opcoder_sets_; 17794be168c0dSopenharmony_ci+ auto iterator = op_sets.find(key); 17795be168c0dSopenharmony_ci+ if (iterator != op_sets.end()) { 17796be168c0dSopenharmony_ci return iterator->second; 17797be168c0dSopenharmony_ci } 17798be168c0dSopenharmony_ci // matching kAllTargets 17799be168c0dSopenharmony_ci- iterator = this->opcoder_sets_.find(key.AllKey()); 17800be168c0dSopenharmony_ci- if (iterator != this->opcoder_sets_.end()) { 17801be168c0dSopenharmony_ci+ iterator = op_sets.find(key.AllKey()); 17802be168c0dSopenharmony_ci+ if (iterator != op_sets.end()) { 17803be168c0dSopenharmony_ci return iterator->second; 17804be168c0dSopenharmony_ci } 17805be168c0dSopenharmony_ci return nullptr; 17806be168c0dSopenharmony_ci } 17807be168c0dSopenharmony_ci 17808be168c0dSopenharmony_ci OpCoderRegister::OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, 17809be168c0dSopenharmony_ci- const std::string &builtin_custom_type, const CoderCreatorFunc &creatorFunc) { 17810be168c0dSopenharmony_ci- OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, builtin_custom_type, creatorFunc); 17811be168c0dSopenharmony_ci+ const std::string &builtin_custom_type, const CoderCreatorFunc &creatorFunc, 17812be168c0dSopenharmony_ci+ bool dynamic) { 17813be168c0dSopenharmony_ci+ OpCoderFactory::GetInstance()->RegistOpCoder(target, data_type, operator_type, builtin_custom_type, creatorFunc, 17814be168c0dSopenharmony_ci+ dynamic); 17815be168c0dSopenharmony_ci } 17816be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 17817be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 17818be168c0dSopenharmony_ciindex 30c8a64d..b616e287 100644 17819be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 17820be168c0dSopenharmony_ci+++ 
b/mindspore/lite/tools/converter/micro/coder/opcoders/op_coder_register.h 17821be168c0dSopenharmony_ci@@ -65,15 +65,19 @@ class OpCoderFactory { 17822be168c0dSopenharmony_ci static OpCoderFactory *GetInstance(); 17823be168c0dSopenharmony_ci 17824be168c0dSopenharmony_ci int RegistOpCoder(Target target, TypeId data_type, schema::PrimitiveType operator_type, 17825be168c0dSopenharmony_ci- const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func); 17826be168c0dSopenharmony_ci+ const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func, bool dynamic); 17827be168c0dSopenharmony_ci 17828be168c0dSopenharmony_ci- CoderCreatorFunc FindOpCoder(const CoderKey &key); 17829be168c0dSopenharmony_ci+ CoderCreatorFunc FindOpCoder(const CoderKey &key, bool dynamic = false); 17830be168c0dSopenharmony_ci 17831be168c0dSopenharmony_ci- ~OpCoderFactory() { opcoder_sets_.clear(); } 17832be168c0dSopenharmony_ci+ ~OpCoderFactory() { 17833be168c0dSopenharmony_ci+ static_opcoder_sets_.clear(); 17834be168c0dSopenharmony_ci+ dynamic_opcoder_sets_.clear(); 17835be168c0dSopenharmony_ci+ } 17836be168c0dSopenharmony_ci 17837be168c0dSopenharmony_ci private: 17838be168c0dSopenharmony_ci // target || data type || primitive type 17839be168c0dSopenharmony_ci- std::map<CoderKey, CoderCreatorFunc> opcoder_sets_; 17840be168c0dSopenharmony_ci+ std::map<CoderKey, CoderCreatorFunc> static_opcoder_sets_; 17841be168c0dSopenharmony_ci+ std::map<CoderKey, CoderCreatorFunc> dynamic_opcoder_sets_; 17842be168c0dSopenharmony_ci }; 17843be168c0dSopenharmony_ci 17844be168c0dSopenharmony_ci class OpCoderRegister { 17845be168c0dSopenharmony_ci@@ -81,16 +85,20 @@ class OpCoderRegister { 17846be168c0dSopenharmony_ci OpCoderRegister() = delete; 17847be168c0dSopenharmony_ci 17848be168c0dSopenharmony_ci OpCoderRegister(Target target, TypeId data_type, schema::PrimitiveType operator_type, 17849be168c0dSopenharmony_ci- const std::string &builtin_custom_type, const CoderCreatorFunc 
&creator_func); 17850be168c0dSopenharmony_ci+ const std::string &builtin_custom_type, const CoderCreatorFunc &creator_func, bool dynamic = false); 17851be168c0dSopenharmony_ci 17852be168c0dSopenharmony_ci ~OpCoderRegister() = default; 17853be168c0dSopenharmony_ci }; 17854be168c0dSopenharmony_ci-#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ 17855be168c0dSopenharmony_ci- static OpCoderRegister g_##target##data_type##operator_type##Creator(target, data_type, operator_type, "", \ 17856be168c0dSopenharmony_ci- creator_func); 17857be168c0dSopenharmony_ci+#define REG_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ 17858be168c0dSopenharmony_ci+ static OpCoderRegister g_##target##data_type##operator_type##StaticCreator(target, data_type, operator_type, "", \ 17859be168c0dSopenharmony_ci+ creator_func); 17860be168c0dSopenharmony_ci 17861be168c0dSopenharmony_ci #define REG_BUILIN_CUSTOM_CODER(target, data_type, custom_type, creator_func) \ 17862be168c0dSopenharmony_ci static OpCoderRegister g_##target##data_type##operator_type##Creator( \ 17863be168c0dSopenharmony_ci target, data_type, schema::PrimitiveType_Custom, custom_type, creator_func); 17864be168c0dSopenharmony_ci+ 17865be168c0dSopenharmony_ci+#define REG_DYNAMIC_OPERATOR_CODER(target, data_type, operator_type, creator_func) \ 17866be168c0dSopenharmony_ci+ static OpCoderRegister g_##target##data_type##operator_type##DynamicCreator(target, data_type, operator_type, "", \ 17867be168c0dSopenharmony_ci+ creator_func, true); 17868be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 17869be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_OPCODERS_OP_CODER_REGISTER_H_ 17870be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc 
17871be168c0dSopenharmony_ciindex a3743b48..920f2723 100644 17872be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc 17873be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.cc 17874be168c0dSopenharmony_ci@@ -38,6 +38,15 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const PoolingCompu 17875be168c0dSopenharmony_ci pooling_compute.maxf); 17876be168c0dSopenharmony_ci } 17877be168c0dSopenharmony_ci 17878be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const PoolingComputeParam &pooling_compute, 17879be168c0dSopenharmony_ci+ const PoolingDynamicParameter &dynamic_pooling_param) { 17880be168c0dSopenharmony_ci+ CodeBaseStruct<false>("PoolingComputeParam", name, pooling_compute.input_w_, pooling_compute.input_h_, 17881be168c0dSopenharmony_ci+ dynamic_pooling_param.input_batch_, pooling_compute.input_channel_, pooling_compute.output_w_, 17882be168c0dSopenharmony_ci+ pooling_compute.output_h_, dynamic_pooling_param.output_batch_, pooling_compute.output_channel_, 17883be168c0dSopenharmony_ci+ pooling_compute.window_w_, pooling_compute.window_h_, pooling_compute.minf, 17884be168c0dSopenharmony_ci+ pooling_compute.maxf); 17885be168c0dSopenharmony_ci+} 17886be168c0dSopenharmony_ci+ 17887be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const BatchNormParameter &batch_norm_parameter) { 17888be168c0dSopenharmony_ci CodeBaseStruct("BatchNormParameter", name, batch_norm_parameter.op_parameter_, batch_norm_parameter.epsilon_, 17889be168c0dSopenharmony_ci batch_norm_parameter.momentum_, batch_norm_parameter.unit_, batch_norm_parameter.units_, 17890be168c0dSopenharmony_ci@@ -85,6 +94,29 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParamete 17891be168c0dSopenharmony_ci conv_parameter.output_padding_w_, 
conv_parameter.output_padding_h_); 17892be168c0dSopenharmony_ci } 17893be168c0dSopenharmony_ci 17894be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParameter &conv_parameter, 17895be168c0dSopenharmony_ci+ const ConvDynamicParameter &dynamic_conv_param) { 17896be168c0dSopenharmony_ci+ CodeBaseStruct<false>( 17897be168c0dSopenharmony_ci+ "ConvParameter", name, conv_parameter.op_parameter_, "{0}", conv_parameter.kernel_h_, conv_parameter.kernel_w_, 17898be168c0dSopenharmony_ci+ conv_parameter.stride_h_, conv_parameter.stride_w_, conv_parameter.dilation_h_, conv_parameter.dilation_w_, 17899be168c0dSopenharmony_ci+ conv_parameter.pad_u_, conv_parameter.pad_d_, conv_parameter.pad_l_, conv_parameter.pad_r_, conv_parameter.group_, 17900be168c0dSopenharmony_ci+ conv_parameter.tile_num_, dynamic_conv_param.input_batch_, conv_parameter.input_h_, conv_parameter.input_w_, 17901be168c0dSopenharmony_ci+ conv_parameter.input_channel_, dynamic_conv_param.output_batch_, conv_parameter.output_h_, conv_parameter.output_w_, 17902be168c0dSopenharmony_ci+ conv_parameter.output_channel_, conv_parameter.thread_num_, conv_parameter.input_unit_, conv_parameter.output_unit_, 17903be168c0dSopenharmony_ci+ conv_parameter.pad_mode_, conv_parameter.act_type_, conv_parameter.channel_multiplie_, 17904be168c0dSopenharmony_ci+ conv_parameter.output_padding_w_, conv_parameter.output_padding_h_); 17905be168c0dSopenharmony_ci+} 17906be168c0dSopenharmony_ci+ 17907be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter) { 17908be168c0dSopenharmony_ci+ CodeBaseStruct<false>( 17909be168c0dSopenharmony_ci+ "MatMulParameter", name, mat_mul_parameter.op_parameter_, mat_mul_parameter.has_bias_, mat_mul_parameter.use_axis_, 17910be168c0dSopenharmony_ci+ mat_mul_parameter.a_transpose_, mat_mul_parameter.b_transpose_, mat_mul_parameter.act_type_, mat_mul_parameter.row_, 
17911be168c0dSopenharmony_ci+ mat_mul_parameter.col_, mat_mul_parameter.row_4_, mat_mul_parameter.row_16_, mat_mul_parameter.row_align_, 17912be168c0dSopenharmony_ci+ mat_mul_parameter.col_8_, mat_mul_parameter.col_align_, mat_mul_parameter.deep_, mat_mul_parameter.deep_4_, 17913be168c0dSopenharmony_ci+ mat_mul_parameter.deep_16_, mat_mul_parameter.deep_align_, mat_mul_parameter.batch, mat_mul_parameter.a_const_, 17914be168c0dSopenharmony_ci+ mat_mul_parameter.b_const_, mat_mul_parameter.axis_, mat_mul_parameter.matmul_type_); 17915be168c0dSopenharmony_ci+} 17916be168c0dSopenharmony_ci+ 17917be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const MicroMatmulParameter µ_matmul_parameter) { 17918be168c0dSopenharmony_ci CodeBaseStruct<false>("MicroMatmulParameter", name, micro_matmul_parameter.act_type_, 17919be168c0dSopenharmony_ci micro_matmul_parameter.thread_num_, micro_matmul_parameter.row_, micro_matmul_parameter.col_, 17920be168c0dSopenharmony_ci@@ -102,18 +134,41 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const ScaleStruct 17921be168c0dSopenharmony_ci scale_struct.outer_size_, scale_struct.inner_size_); 17922be168c0dSopenharmony_ci } 17923be168c0dSopenharmony_ci 17924be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const ScaleStruct &scale_struct, 17925be168c0dSopenharmony_ci+ const ScaleDynamicParameter &dynamic_scale_param) { 17926be168c0dSopenharmony_ci+ CodeBaseStruct<false>("ScaleStruct", name, "{}", scale_struct.axis_, scale_struct.data_type_, 17927be168c0dSopenharmony_ci+ dynamic_scale_param.axis_size_, dynamic_scale_param.outer_size_, 17928be168c0dSopenharmony_ci+ dynamic_scale_param.inner_size_); 17929be168c0dSopenharmony_ci+} 17930be168c0dSopenharmony_ci+ 17931be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const SliceParameter &slice_parameter) { 17932be168c0dSopenharmony_ci CodeBaseStruct("SliceParameter", name, 
slice_parameter.op_parameter_, ToString(slice_parameter.shape_), 17933be168c0dSopenharmony_ci ToString(slice_parameter.begin_), ToString(slice_parameter.end_), ToString(slice_parameter.size_), 17934be168c0dSopenharmony_ci "{0}", slice_parameter.param_length_); 17935be168c0dSopenharmony_ci } 17936be168c0dSopenharmony_ci 17937be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const SliceParameter &slice_parameter, 17938be168c0dSopenharmony_ci+ const SliceDynamicParameter &dynamic_slice_param) { 17939be168c0dSopenharmony_ci+ CodeBaseStruct<false>("SliceParameter", name, slice_parameter.op_parameter_, dynamic_slice_param.shape_, 17940be168c0dSopenharmony_ci+ ToString(slice_parameter.begin_), dynamic_slice_param.end_, dynamic_slice_param.size_, "{0}", 17941be168c0dSopenharmony_ci+ slice_parameter.param_length_); 17942be168c0dSopenharmony_ci+} 17943be168c0dSopenharmony_ci+ 17944be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const SplitParameter &split_parameter) { 17945be168c0dSopenharmony_ci CodeBaseStruct("SplitParameter", name, split_parameter.op_parameter_, split_parameter.num_split_, "split_sizes", 17946be168c0dSopenharmony_ci split_parameter.split_dim_, ToString(split_parameter.strides_), "{0}", split_parameter.n_dims_, 17947be168c0dSopenharmony_ci split_parameter.split_count_); 17948be168c0dSopenharmony_ci } 17949be168c0dSopenharmony_ci 17950be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const SplitParameter &split_parameter, 17951be168c0dSopenharmony_ci+ const SplitDynamicParameter &dynamic_split_param) { 17952be168c0dSopenharmony_ci+ CodeArray("split_sizes", split_parameter.split_sizes_, split_parameter.num_split_, false); 17953be168c0dSopenharmony_ci+ CodeBaseStruct<false>("SplitParameter", name, split_parameter.op_parameter_, split_parameter.num_split_, nullptr, 17954be168c0dSopenharmony_ci+ split_parameter.split_dim_, dynamic_split_param.strides_, 
"{0}", split_parameter.n_dims_, 17955be168c0dSopenharmony_ci+ dynamic_split_param.split_count_); 17956be168c0dSopenharmony_ci+ code << " " << name << ".split_sizes_ = split_sizes;\n"; 17957be168c0dSopenharmony_ci+} 17958be168c0dSopenharmony_ci+ 17959be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const TileParameter &tile_parameter) { 17960be168c0dSopenharmony_ci CodeBaseStruct("TileParameter", name, tile_parameter.op_parameter_, ToString(tile_parameter.multiples_), 17961be168c0dSopenharmony_ci ToString(tile_parameter.in_shape_), ToString(tile_parameter.out_shape_), 17962be168c0dSopenharmony_ci@@ -127,12 +182,32 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransposePar 17963be168c0dSopenharmony_ci ToString(transpose_parameter.out_strides_), transpose_parameter.num_axes_, transpose_parameter.data_num_); 17964be168c0dSopenharmony_ci } 17965be168c0dSopenharmony_ci 17966be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const TransposeParameter &transpose_param, 17967be168c0dSopenharmony_ci+ const TransposeDynamicParameter &dynamic_transpose_param) { 17968be168c0dSopenharmony_ci+ CodeBaseStruct<false>("TransposeParameter", name, transpose_param.op_parameter_, ToString(transpose_param.perm_), 17969be168c0dSopenharmony_ci+ transpose_param.perm_size_, transpose_param.conjugate_, dynamic_transpose_param.strides_, 17970be168c0dSopenharmony_ci+ dynamic_transpose_param.out_strides_, transpose_param.num_axes_, 17971be168c0dSopenharmony_ci+ dynamic_transpose_param.data_num_); 17972be168c0dSopenharmony_ci+} 17973be168c0dSopenharmony_ci+ 17974be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const LstmParameter &lstm_parameter) { 17975be168c0dSopenharmony_ci CodeBaseStruct("LstmParameter", name, lstm_parameter.op_parameter_, lstm_parameter.input_size_, 17976be168c0dSopenharmony_ci- lstm_parameter.hidden_size_, lstm_parameter.project_size_, 
lstm_parameter.seq_len_, 17977be168c0dSopenharmony_ci- lstm_parameter.batch_, lstm_parameter.output_step_, lstm_parameter.bidirectional_, 17978be168c0dSopenharmony_ci- lstm_parameter.zoneout_cell_, lstm_parameter.zoneout_hidden_, lstm_parameter.input_row_align_, 17979be168c0dSopenharmony_ci- lstm_parameter.input_col_align_, lstm_parameter.state_row_align_, lstm_parameter.state_col_align_); 17980be168c0dSopenharmony_ci+ lstm_parameter.hidden_size_, lstm_parameter.project_size_, lstm_parameter.output_size_, 17981be168c0dSopenharmony_ci+ lstm_parameter.seq_len_, lstm_parameter.batch_, lstm_parameter.output_step_, 17982be168c0dSopenharmony_ci+ lstm_parameter.bidirectional_, lstm_parameter.zoneout_cell_, lstm_parameter.zoneout_hidden_, 17983be168c0dSopenharmony_ci+ lstm_parameter.input_row_align_, lstm_parameter.input_col_align_, lstm_parameter.state_row_align_, 17984be168c0dSopenharmony_ci+ lstm_parameter.state_col_align_, lstm_parameter.proj_col_align_, lstm_parameter.has_bias_); 17985be168c0dSopenharmony_ci+} 17986be168c0dSopenharmony_ci+ 17987be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const LstmParameter &lstm_parameter, 17988be168c0dSopenharmony_ci+ const DynamicLstmParameter &dynamic_lstm_param) { 17989be168c0dSopenharmony_ci+ CodeBaseStruct("LstmParameter", name, lstm_parameter.op_parameter_, lstm_parameter.input_size_, 17990be168c0dSopenharmony_ci+ lstm_parameter.hidden_size_, lstm_parameter.project_size_, lstm_parameter.output_size_, 17991be168c0dSopenharmony_ci+ dynamic_lstm_param.seq_len_, dynamic_lstm_param.batch_, dynamic_lstm_param.output_step_, 17992be168c0dSopenharmony_ci+ lstm_parameter.bidirectional_, lstm_parameter.zoneout_cell_, lstm_parameter.zoneout_hidden_, 17993be168c0dSopenharmony_ci+ dynamic_lstm_param.input_row_align_, lstm_parameter.input_col_align_, 17994be168c0dSopenharmony_ci+ dynamic_lstm_param.state_row_align_, lstm_parameter.state_col_align_, lstm_parameter.proj_col_align_, 
17995be168c0dSopenharmony_ci+ lstm_parameter.has_bias_); 17996be168c0dSopenharmony_ci } 17997be168c0dSopenharmony_ci 17998be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const DeQuantArg &de_quant_arg) { 17999be168c0dSopenharmony_ci@@ -165,6 +240,17 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const StridedSlice 18000be168c0dSopenharmony_ci strided_slice_parameter.newAxisMask_, strided_slice_parameter.shrinkAxisMask_); 18001be168c0dSopenharmony_ci } 18002be168c0dSopenharmony_ci 18003be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter, 18004be168c0dSopenharmony_ci+ const StridedSliceDynamicParameter &dynamic_strided_slice_param) { 18005be168c0dSopenharmony_ci+ CodeBaseStruct<false>("StridedSliceParameter", name, strided_slice_parameter.op_parameter_, 18006be168c0dSopenharmony_ci+ ToString(strided_slice_parameter.begins_), ToString(strided_slice_parameter.ends_), 18007be168c0dSopenharmony_ci+ ToString(strided_slice_parameter.strides_), strided_slice_parameter.isScale, 18008be168c0dSopenharmony_ci+ strided_slice_parameter.in_shape_length_, dynamic_strided_slice_param.in_shape_, 18009be168c0dSopenharmony_ci+ strided_slice_parameter.num_axes_, strided_slice_parameter.data_type, 18010be168c0dSopenharmony_ci+ strided_slice_parameter.begins_mask_, strided_slice_parameter.ellipsisMask_, 18011be168c0dSopenharmony_ci+ strided_slice_parameter.newAxisMask_, strided_slice_parameter.shrinkAxisMask_); 18012be168c0dSopenharmony_ci+} 18013be168c0dSopenharmony_ci+ 18014be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const ArithmeticWrapperInfo &arithmetic_wrapper_info) { 18015be168c0dSopenharmony_ci CodeBaseStruct("ArithmeticWrapperInfo", name, arithmetic_wrapper_info.offset0_, arithmetic_wrapper_info.stride0_, 18016be168c0dSopenharmony_ci arithmetic_wrapper_info.offset1_, 
arithmetic_wrapper_info.stride1_, 18017be168c0dSopenharmony_ci@@ -207,6 +293,12 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const BroadcastSha 18018be168c0dSopenharmony_ci ToString(param.output_shape_), param.output_shape_size_); 18019be168c0dSopenharmony_ci } 18020be168c0dSopenharmony_ci 18021be168c0dSopenharmony_ci+void NNaclFp32Serializer::CodeStruct(const std::string &name, const BroadcastShapeInfo &op_param, 18022be168c0dSopenharmony_ci+ const BroadcastDynamicShapeInfo &dynamic_param) { 18023be168c0dSopenharmony_ci+ CodeBaseStruct<false>("BroadcastShapeInfo", name, dynamic_param.input_shape_, op_param.input_shape_size_, 18024be168c0dSopenharmony_ci+ dynamic_param.output_shape_, op_param.output_shape_size_); 18025be168c0dSopenharmony_ci+} 18026be168c0dSopenharmony_ci+ 18027be168c0dSopenharmony_ci void NNaclFp32Serializer::CodeStruct(const std::string &name, const CustomGruParameter &op_param) { 18028be168c0dSopenharmony_ci CodeBaseStruct<false>("CustomGruParameter", name, op_param.op_parameter_, op_param.num_step, op_param.batch_size, 18029be168c0dSopenharmony_ci op_param.input_size, op_param.hidden_size); 18030be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h 18031be168c0dSopenharmony_ciindex d1435dea..2b1536c6 100644 18032be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h 18033be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/opcoders/serializers/nnacl_serializer/nnacl_fp32_serializer.h 18034be168c0dSopenharmony_ci@@ -53,6 +53,15 @@ 18035be168c0dSopenharmony_ci #include "nnacl/kernel/pooling.h" 18036be168c0dSopenharmony_ci #include "nnacl/kernel/layer_norm.h" 18037be168c0dSopenharmony_ci #include "nnacl/kernel/fill.h" 
18038be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/dynamic_lstm_parameter.h" 18039be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/transpose_dynamic_parameter.h" 18040be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/slice_dynamic_parameter.h" 18041be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/split_dynamic_parameter.h" 18042be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/strided_slice_dynamic_parameter.h" 18043be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/scale_dynamic_parameter.h" 18044be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/conv_dynamic_parameter.h" 18045be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/arithmetic_dynamic_parameter.h" 18046be168c0dSopenharmony_ci+#include "coder/opcoders/nnacl/dynamic_parameter/pooling_dynamic_parameter.h" 18047be168c0dSopenharmony_ci 18048be168c0dSopenharmony_ci namespace mindspore::lite::micro::nnacl { 18049be168c0dSopenharmony_ci class NNaclFp32Serializer : public Serializer { 18050be168c0dSopenharmony_ci@@ -66,6 +75,7 @@ class NNaclFp32Serializer : public Serializer { 18051be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const InstanceNormParameter ¶m); 18052be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const ArithmeticParameter &arithmetic_parameter); 18053be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const ConvParameter &conv_parameter); 18054be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const MatMulParameter &mat_mul_parameter); 18055be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const MicroMatmulParameter µ_matmul_parameter); 18056be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const LstmParameter &lstm_parameter); 18057be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const ScaleStruct 
&scale_struct); 18058be168c0dSopenharmony_ci@@ -89,6 +99,24 @@ class NNaclFp32Serializer : public Serializer { 18059be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const SlidingWindowParam ¶m); 18060be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const UnstackParameter ¶m); 18061be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const FillStruct ¶m); 18062be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const TransposeParameter &transpose_param, 18063be168c0dSopenharmony_ci+ const TransposeDynamicParameter &dynamic_transpose_param); 18064be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const SplitParameter &split_parameter, 18065be168c0dSopenharmony_ci+ const SplitDynamicParameter &dynamic_split_param); 18066be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const BroadcastShapeInfo ¶m, 18067be168c0dSopenharmony_ci+ const BroadcastDynamicShapeInfo &dynamic_param); 18068be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const LstmParameter &lstm_param, 18069be168c0dSopenharmony_ci+ const DynamicLstmParameter &dynamic_lstm_param); 18070be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const SliceParameter &slice_parameter, 18071be168c0dSopenharmony_ci+ const SliceDynamicParameter &dynamic_slice_param); 18072be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const StridedSliceParameter &strided_slice_parameter, 18073be168c0dSopenharmony_ci+ const StridedSliceDynamicParameter &dynamic_strided_slice_param); 18074be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const ScaleStruct &scale_struct, 18075be168c0dSopenharmony_ci+ const ScaleDynamicParameter &dynamic_scale_param); 18076be168c0dSopenharmony_ci+ void CodeStruct(const std::string &name, const ConvParameter &conv_parameter, 18077be168c0dSopenharmony_ci+ const ConvDynamicParameter &dynamic_conv_param); 18078be168c0dSopenharmony_ci+ void CodeStruct(const 
std::string &name, const PoolingComputeParam &pooling_compute, 18079be168c0dSopenharmony_ci+ const PoolingDynamicParameter &dynamic_pooling_param); 18080be168c0dSopenharmony_ci void CodeStruct(const std::string &name, const int *list, int size); 18081be168c0dSopenharmony_ci void CodeArrayStruct(const std::string &name, TensorC *tensorC, std::vector<Tensor *> tensor); 18082be168c0dSopenharmony_ci 18083be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/session.cc b/mindspore/lite/tools/converter/micro/coder/session.cc 18084be168c0dSopenharmony_ciindex 55df7a22..374f662d 100644 18085be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/session.cc 18086be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/session.cc 18087be168c0dSopenharmony_ci@@ -75,7 +75,10 @@ int CoderSession::PassArgsToContext(const std::string &model_name) { 18088be168c0dSopenharmony_ci context_->set_total_buffer_size(final_total_size); 18089be168c0dSopenharmony_ci context_->set_graph_inputs(coder_graph_->input_tensors()); 18090be168c0dSopenharmony_ci context_->set_graph_outputs(coder_graph_->output_tensors()); 18091be168c0dSopenharmony_ci- if (Configurator::GetInstance()->debug_mode()) { 18092be168c0dSopenharmony_ci+ context_->set_shape_info_container(&shape_info_container_); 18093be168c0dSopenharmony_ci+ context_->set_dynamic_mem_manager(&dynamic_mem_manager_); 18094be168c0dSopenharmony_ci+ Configurator *config = Configurator::GetInstance(); 18095be168c0dSopenharmony_ci+ if (config->debug_mode()) { 18096be168c0dSopenharmony_ci std::vector<std::string> blocks; 18097be168c0dSopenharmony_ci blocks = AddDumpDataInfo(context_->code_blocks(), op_coders_); 18098be168c0dSopenharmony_ci if (blocks.size() == 0) { 18099be168c0dSopenharmony_ci@@ -100,7 +103,16 @@ int CoderSession::Preprocess() { 18100be168c0dSopenharmony_ci Configurator::GetInstance()->changeable_weights_name()); 18101be168c0dSopenharmony_ci MS_CHECK_RET_CODE(ret, "assign 
memory failed"); 18102be168c0dSopenharmony_ci 18103be168c0dSopenharmony_ci- // prepare, init model parameters 18104be168c0dSopenharmony_ci+ if (dynamic_) { 18105be168c0dSopenharmony_ci+ auto config = Configurator::GetInstance(); 18106be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(config != nullptr, RET_NULL_PTR, "Config is a nullptr."); 18107be168c0dSopenharmony_ci+ ret = shape_info_container_.Init(op_coders_, graph_inputs_shape_infos_); 18108be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Init ShapeInfoContainer failed."); 18109be168c0dSopenharmony_ci+ auto outputs = coder_graph_->output_tensors(); 18110be168c0dSopenharmony_ci+ ret = dynamic_mem_manager_.AllocDynamicMem(op_coders_, inputs, outputs, &shape_info_container_); 18111be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "DynamicMemManager AllocDynamicMem failed."); 18112be168c0dSopenharmony_ci+ } 18113be168c0dSopenharmony_ci+ // 2. prepare, init model parameters 18114be168c0dSopenharmony_ci for (const auto &op_coder : op_coders_) { 18115be168c0dSopenharmony_ci MS_CHECK_PTR(op_coder); 18116be168c0dSopenharmony_ci MS_LOG(DEBUG) << "prepare: " << op_coder->name(); 18117be168c0dSopenharmony_ci@@ -133,7 +145,7 @@ int CoderSession::Run(const std::string &model_name) { 18118be168c0dSopenharmony_ci ret = PassArgsToContext(model_name); 18119be168c0dSopenharmony_ci MS_CHECK_RET_CODE(ret, "PassArgsToContext failed"); 18120be168c0dSopenharmony_ci MS_LOG(INFO) << "run opcoders success"; 18121be168c0dSopenharmony_ci- return RET_OK; 18122be168c0dSopenharmony_ci+ return ret; 18123be168c0dSopenharmony_ci } 18124be168c0dSopenharmony_ci 18125be168c0dSopenharmony_ci int CoderSession::GenerateCode() { 18126be168c0dSopenharmony_ci@@ -161,6 +173,9 @@ int CoderSession::Init(const void *content, int size, const int model_index, boo 18127be168c0dSopenharmony_ci context_ = std::make_unique<CoderContext>(model_index); 18128be168c0dSopenharmony_ci context_->set_end_flag(end_flag); 18129be168c0dSopenharmony_ci 
enable_fp16_ = enable_fp16; 18130be168c0dSopenharmony_ci+ Configurator *config = Configurator::GetInstance(); 18131be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(config != nullptr, RET_NULL_PTR, "Config is a nullptr."); 18132be168c0dSopenharmony_ci+ dynamic_ = !config->graph_inputs_shape_infos().empty(); 18133be168c0dSopenharmony_ci MS_LOG(INFO) << "CoderSession::Init done"; 18134be168c0dSopenharmony_ci return RET_OK; 18135be168c0dSopenharmony_ci } 18136be168c0dSopenharmony_ci@@ -227,6 +242,7 @@ int CoderSession::InitTensorsRef() { 18137be168c0dSopenharmony_ci } 18138be168c0dSopenharmony_ci } 18139be168c0dSopenharmony_ci tensor->set_ref_count(refcount); 18140be168c0dSopenharmony_ci+ tensor->set_init_ref_count(refcount); 18141be168c0dSopenharmony_ci } 18142be168c0dSopenharmony_ci return RET_OK; 18143be168c0dSopenharmony_ci } 18144be168c0dSopenharmony_ci@@ -325,6 +341,7 @@ int CoderSession::CreateOpCoders() { 18145be168c0dSopenharmony_ci .input_indices(input_indices) 18146be168c0dSopenharmony_ci .output_indices(output_indices) 18147be168c0dSopenharmony_ci .is_builtin_custom(is_built_in_custom_op) 18148be168c0dSopenharmony_ci+ .is_dynamic(dynamic_) 18149be168c0dSopenharmony_ci .build(schema_version_); 18150be168c0dSopenharmony_ci if (op_coder == nullptr) { 18151be168c0dSopenharmony_ci coder_graph_->DumpUnSupportLayer(code_target); 18152be168c0dSopenharmony_ci@@ -348,6 +365,20 @@ int CoderSession::CompileGraph() { 18153be168c0dSopenharmony_ci MS_CHECK_RET_CODE(InitCodeGraph(), "InitGraphInOutTensors failed"); 18154be168c0dSopenharmony_ci MS_CHECK_RET_CODE(CreateOpCoders(), "CreateOpCoders failed!"); 18155be168c0dSopenharmony_ci MS_CHECK_RET_CODE(InitTensorsRef(), "InitTensorsRefcount failed!"); 18156be168c0dSopenharmony_ci+ if (dynamic_) { 18157be168c0dSopenharmony_ci+ Configurator::GetInstance()->set_dynamic_shape(true); 18158be168c0dSopenharmony_ci+ std::vector<lite::Tensor *> inputs = coder_graph_->input_tensors(); 18159be168c0dSopenharmony_ci+ auto 
&graph_inputs_shape_infos = Configurator::GetInstance()->graph_inputs_shape_infos(); 18160be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(inputs.size() == graph_inputs_shape_infos.size(), RET_ERROR, 18161be168c0dSopenharmony_ci+ "Config graph_inputs_shape's num cannot match."); 18162be168c0dSopenharmony_ci+ for (size_t i = 0; i < inputs.size(); ++i) { 18163be168c0dSopenharmony_ci+ graph_inputs_shape_infos_[inputs[i]] = graph_inputs_shape_infos[i]; 18164be168c0dSopenharmony_ci+ } 18165be168c0dSopenharmony_ci+ } 18166be168c0dSopenharmony_ci+ for (auto &op_coder : op_coders_) { 18167be168c0dSopenharmony_ci+ op_coder->set_shape_info_container(&shape_info_container_); 18168be168c0dSopenharmony_ci+ op_coder->set_dynamic_mem_manager(&dynamic_mem_manager_); 18169be168c0dSopenharmony_ci+ } 18170be168c0dSopenharmony_ci return RET_OK; 18171be168c0dSopenharmony_ci } 18172be168c0dSopenharmony_ci CoderSession::~CoderSession() { allocator_->Free(); } 18173be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/session.h b/mindspore/lite/tools/converter/micro/coder/session.h 18174be168c0dSopenharmony_ciindex 98a8d008..452e3245 100644 18175be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/session.h 18176be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/session.h 18177be168c0dSopenharmony_ci@@ -65,6 +65,10 @@ class CoderSession { 18178be168c0dSopenharmony_ci private: 18179be168c0dSopenharmony_ci int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR; 18180be168c0dSopenharmony_ci bool enable_fp16_{false}; 18181be168c0dSopenharmony_ci+ bool dynamic_{false}; 18182be168c0dSopenharmony_ci+ DynamicMemManager dynamic_mem_manager_; 18183be168c0dSopenharmony_ci+ ShapeInfoContainer shape_info_container_; 18184be168c0dSopenharmony_ci+ std::map<Tensor *, std::vector<std::vector<int>>> graph_inputs_shape_infos_; 18185be168c0dSopenharmony_ci }; 18186be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 18187be168c0dSopenharmony_ci 
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SESSION_H_ 18188be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/shape_info_container.cc b/mindspore/lite/tools/converter/micro/coder/shape_info_container.cc 18189be168c0dSopenharmony_cinew file mode 100644 18190be168c0dSopenharmony_ciindex 00000000..c914be6c 18191be168c0dSopenharmony_ci--- /dev/null 18192be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/shape_info_container.cc 18193be168c0dSopenharmony_ci@@ -0,0 +1,131 @@ 18194be168c0dSopenharmony_ci+/** 18195be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 18196be168c0dSopenharmony_ci+ * 18197be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 18198be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 18199be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 18200be168c0dSopenharmony_ci+ * 18201be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 18202be168c0dSopenharmony_ci+ * 18203be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 18204be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 18205be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18206be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 18207be168c0dSopenharmony_ci+ * limitations under the License. 
18208be168c0dSopenharmony_ci+ */ 18209be168c0dSopenharmony_ci+ 18210be168c0dSopenharmony_ci+#include "coder/shape_info_container.h" 18211be168c0dSopenharmony_ci+#include "src/litert/infer_manager.h" 18212be168c0dSopenharmony_ci+#include "coder/opcoders/op_coder.h" 18213be168c0dSopenharmony_ci+#include "coder/utils/coder_utils.h" 18214be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 18215be168c0dSopenharmony_ci+ 18216be168c0dSopenharmony_ci+namespace mindspore::lite::micro { 18217be168c0dSopenharmony_ci+int ShapeInfoContainer::Init(const std::vector<std::unique_ptr<OperatorCoder>> &nodes_coder, 18218be168c0dSopenharmony_ci+ const std::map<Tensor *, std::vector<std::vector<int>>> &graph_inputs) { 18219be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!graph_inputs.empty(), RET_ERROR, "Cannot get graph_inputs's shape-info"); 18220be168c0dSopenharmony_ci+ auto scene_num = graph_inputs.begin()->second.size(); 18221be168c0dSopenharmony_ci+ for (const auto &item : graph_inputs) { 18222be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(item.first, RET_NULL_PTR, "Find a nullptr in graph_inputs"); 18223be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(item.second.size() == scene_num, RET_ERROR, "Graph inputs are invalid."); 18224be168c0dSopenharmony_ci+ } 18225be168c0dSopenharmony_ci+ var_tensor_shapes_.insert(graph_inputs.begin(), graph_inputs.end()); 18226be168c0dSopenharmony_ci+ for (size_t i = 0; i < scene_num; ++i) { 18227be168c0dSopenharmony_ci+ for (const auto &item : graph_inputs) { 18228be168c0dSopenharmony_ci+ item.first->set_shape(item.second[i]); 18229be168c0dSopenharmony_ci+ } 18230be168c0dSopenharmony_ci+ for (const auto &node_coder : nodes_coder) { 18231be168c0dSopenharmony_ci+ auto in_tensors = node_coder->input_tensors(); 18232be168c0dSopenharmony_ci+ auto out_tensors = node_coder->output_tensors(); 18233be168c0dSopenharmony_ci+ auto op_param = node_coder->get_parameter(); 18234be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(op_param, RET_NULL_PTR, "NodeCoder's op_param 
is a nullptr."); 18235be168c0dSopenharmony_ci+ auto node = node_coder->node(); 18236be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(node, RET_NULL_PTR, "NodeCoder's node is a nullptr."); 18237be168c0dSopenharmony_ci+ auto prim = node->primitive_; 18238be168c0dSopenharmony_ci+ auto ret = DoInferShape(in_tensors, out_tensors, op_param, prim); 18239be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "ShapeInfoContainer Init failed."); 18240be168c0dSopenharmony_ci+ } 18241be168c0dSopenharmony_ci+ } 18242be168c0dSopenharmony_ci+ auto ret = DetermineShapeVarInfos(); 18243be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "DetermineShapeVarInfos failed."); 18244be168c0dSopenharmony_ci+ return RET_OK; 18245be168c0dSopenharmony_ci+} 18246be168c0dSopenharmony_ci+ 18247be168c0dSopenharmony_ci+int ShapeInfoContainer::DoInferShape(const std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors, 18248be168c0dSopenharmony_ci+ OpParameter *op_param, const void *primitive) { 18249be168c0dSopenharmony_ci+ auto ret = KernelInferShape(in_tensors, out_tensors, primitive, {}, lite::SCHEMA_CUR); 18250be168c0dSopenharmony_ci+ if (ret == lite::RET_NOT_SUPPORT) { 18251be168c0dSopenharmony_ci+ ret = KernelInferShape(in_tensors, out_tensors, op_param); 18252be168c0dSopenharmony_ci+ } 18253be168c0dSopenharmony_ci+ if (ret != RET_OK) { 18254be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Infer shape failed."; 18255be168c0dSopenharmony_ci+ return ret; 18256be168c0dSopenharmony_ci+ } 18257be168c0dSopenharmony_ci+ for (const auto out_tensor : out_tensors) { 18258be168c0dSopenharmony_ci+ var_tensor_shapes_[out_tensor].push_back(out_tensor->shape()); 18259be168c0dSopenharmony_ci+ } 18260be168c0dSopenharmony_ci+ return RET_OK; 18261be168c0dSopenharmony_ci+} 18262be168c0dSopenharmony_ci+ 18263be168c0dSopenharmony_ci+int ShapeInfoContainer::DetermineShapeVarInfos() { 18264be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(kShapePrefixName, RET_NULL_PTR, "kShapePrefixName is a 
nullptr."); 18265be168c0dSopenharmony_ci+ int index = 0; 18266be168c0dSopenharmony_ci+ for (const auto &item : var_tensor_shapes_) { 18267be168c0dSopenharmony_ci+ auto &tensor = item.first; 18268be168c0dSopenharmony_ci+ auto &shapes = item.second; 18269be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(!shapes.empty(), RET_ERROR, "Cannot get some tensor's shape."); 18270be168c0dSopenharmony_ci+ auto shape = shapes.front(); 18271be168c0dSopenharmony_ci+ auto dims = shape.size(); 18272be168c0dSopenharmony_ci+ auto is_same_dim = 18273be168c0dSopenharmony_ci+ std::all_of(shapes.begin(), shapes.end(), [dims](const std::vector<int> &item) { return item.size() == dims; }); 18274be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(is_same_dim, RET_ERROR, "Tensor's shape-dims-num are not same."); 18275be168c0dSopenharmony_ci+ std::vector<std::string> shape_symbols; 18276be168c0dSopenharmony_ci+ for (size_t i = 0; i < dims; ++i) { 18277be168c0dSopenharmony_ci+ int dim = shape[i]; 18278be168c0dSopenharmony_ci+ std::vector<int> real_nums; 18279be168c0dSopenharmony_ci+ auto is_same_pos = 18280be168c0dSopenharmony_ci+ std::all_of(shapes.begin(), shapes.end(), [dim, i](const std::vector<int> &item) { return item[i] == dim; }); 18281be168c0dSopenharmony_ci+ if (is_same_pos) { 18282be168c0dSopenharmony_ci+ shape_symbols.push_back(std::to_string(dim)); 18283be168c0dSopenharmony_ci+ continue; 18284be168c0dSopenharmony_ci+ } 18285be168c0dSopenharmony_ci+ (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(real_nums), 18286be168c0dSopenharmony_ci+ [i](const std::vector<int> &item) { return item[i]; }); 18287be168c0dSopenharmony_ci+ std::string shape_symbol; 18288be168c0dSopenharmony_ci+ for (const auto &shape_to_num : shape_to_nums_) { 18289be168c0dSopenharmony_ci+ if (shape_to_num.second == real_nums) { 18290be168c0dSopenharmony_ci+ shape_symbol = shape_to_num.first; 18291be168c0dSopenharmony_ci+ break; 18292be168c0dSopenharmony_ci+ } 18293be168c0dSopenharmony_ci+ } 
18294be168c0dSopenharmony_ci+ if (shape_symbol.empty()) { 18295be168c0dSopenharmony_ci+ for (size_t scene_index = 0; scene_index < real_nums.size(); ++scene_index) { 18296be168c0dSopenharmony_ci+ shapes_whole_scenes_[scene_index].push_back(real_nums[scene_index]); 18297be168c0dSopenharmony_ci+ } 18298be168c0dSopenharmony_ci+ shape_symbol = std::string(kShapePrefixName) + "[" + std::to_string(index++) + "]"; 18299be168c0dSopenharmony_ci+ shape_to_nums_[shape_symbol] = real_nums; 18300be168c0dSopenharmony_ci+ } 18301be168c0dSopenharmony_ci+ shape_symbols.push_back(shape_symbol); 18302be168c0dSopenharmony_ci+ } 18303be168c0dSopenharmony_ci+ shape_templates_[tensor] = shape_symbols; 18304be168c0dSopenharmony_ci+ } 18305be168c0dSopenharmony_ci+ return RET_OK; 18306be168c0dSopenharmony_ci+} 18307be168c0dSopenharmony_ci+ 18308be168c0dSopenharmony_ci+std::vector<std::string> ShapeInfoContainer::GetTemplateShape(const Tensor *tensor) const { 18309be168c0dSopenharmony_ci+ if (shape_templates_.find(tensor) == shape_templates_.end()) { 18310be168c0dSopenharmony_ci+ return {}; 18311be168c0dSopenharmony_ci+ } 18312be168c0dSopenharmony_ci+ return shape_templates_.at(tensor); 18313be168c0dSopenharmony_ci+} 18314be168c0dSopenharmony_ci+ 18315be168c0dSopenharmony_ci+std::vector<int> ShapeInfoContainer::GetRealNums(const std::string &shape_var) const { 18316be168c0dSopenharmony_ci+ if (IsNumber(shape_var)) { 18317be168c0dSopenharmony_ci+ return {std::stoi(shape_var)}; 18318be168c0dSopenharmony_ci+ } 18319be168c0dSopenharmony_ci+ if (shape_to_nums_.find(shape_var) == shape_to_nums_.end()) { 18320be168c0dSopenharmony_ci+ return {}; 18321be168c0dSopenharmony_ci+ } 18322be168c0dSopenharmony_ci+ return shape_to_nums_.at(shape_var); 18323be168c0dSopenharmony_ci+} 18324be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro 18325be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/shape_info_container.h 
b/mindspore/lite/tools/converter/micro/coder/shape_info_container.h 18326be168c0dSopenharmony_cinew file mode 100644 18327be168c0dSopenharmony_ciindex 00000000..9268b249 18328be168c0dSopenharmony_ci--- /dev/null 18329be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/shape_info_container.h 18330be168c0dSopenharmony_ci@@ -0,0 +1,59 @@ 18331be168c0dSopenharmony_ci+/** 18332be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 18333be168c0dSopenharmony_ci+ * 18334be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 18335be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 18336be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 18337be168c0dSopenharmony_ci+ * 18338be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 18339be168c0dSopenharmony_ci+ * 18340be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 18341be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 18342be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18343be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 18344be168c0dSopenharmony_ci+ * limitations under the License. 
18345be168c0dSopenharmony_ci+ */ 18346be168c0dSopenharmony_ci+ 18347be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SHAPE_INFO_CONTAINER_H_ 18348be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SHAPE_INFO_CONTAINER_H_ 18349be168c0dSopenharmony_ci+ 18350be168c0dSopenharmony_ci+#include <vector> 18351be168c0dSopenharmony_ci+#include <string> 18352be168c0dSopenharmony_ci+#include <map> 18353be168c0dSopenharmony_ci+#include "tools/converter/micro/coder/config.h" 18354be168c0dSopenharmony_ci+#include "include/model.h" 18355be168c0dSopenharmony_ci+#include "src/tensor.h" 18356be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 18357be168c0dSopenharmony_ci+ 18358be168c0dSopenharmony_ci+namespace mindspore::lite::micro { 18359be168c0dSopenharmony_ci+class OperatorCoder; 18360be168c0dSopenharmony_ci+class ShapeInfoContainer { 18361be168c0dSopenharmony_ci+ public: 18362be168c0dSopenharmony_ci+ ShapeInfoContainer() = default; 18363be168c0dSopenharmony_ci+ ~ShapeInfoContainer() = default; 18364be168c0dSopenharmony_ci+ 18365be168c0dSopenharmony_ci+ int Init(const std::vector<std::unique_ptr<OperatorCoder>> &nodes_coder, 18366be168c0dSopenharmony_ci+ const std::map<Tensor *, std::vector<std::vector<int>>> &graph_inputs); 18367be168c0dSopenharmony_ci+ 18368be168c0dSopenharmony_ci+ const std::map<Tensor *, std::vector<std::vector<int>>> &GetVarTensorInfos() const { return var_tensor_shapes_; } 18369be168c0dSopenharmony_ci+ 18370be168c0dSopenharmony_ci+ std::vector<std::string> GetTemplateShape(const Tensor *tensor) const; 18371be168c0dSopenharmony_ci+ 18372be168c0dSopenharmony_ci+ const std::map<const Tensor *, std::vector<std::string>> &GetWholeTemplateShape() { return shape_templates_; } 18373be168c0dSopenharmony_ci+ 18374be168c0dSopenharmony_ci+ std::vector<int> GetRealNums(const std::string &shape_var) const; 18375be168c0dSopenharmony_ci+ 18376be168c0dSopenharmony_ci+ const std::map<int, std::vector<int>> 
&GetShapesWholeScenes() const { return shapes_whole_scenes_; } 18377be168c0dSopenharmony_ci+ 18378be168c0dSopenharmony_ci+ private: 18379be168c0dSopenharmony_ci+ int DoInferShape(const std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &out_tensors, OpParameter *op_param, 18380be168c0dSopenharmony_ci+ const void *primitive); 18381be168c0dSopenharmony_ci+ int DetermineShapeVarInfos(); 18382be168c0dSopenharmony_ci+ std::map<Tensor *, std::vector<std::vector<int>>> var_tensor_shapes_; 18383be168c0dSopenharmony_ci+ std::map<const Tensor *, std::vector<std::string>> shape_templates_; 18384be168c0dSopenharmony_ci+ std::map<std::string, std::vector<int>> shape_to_nums_; 18385be168c0dSopenharmony_ci+ std::map<int, std::vector<int>> shapes_whole_scenes_; 18386be168c0dSopenharmony_ci+ Model *model_{nullptr}; 18387be168c0dSopenharmony_ci+}; 18388be168c0dSopenharmony_ci+} // namespace mindspore::lite::micro 18389be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_SHAPE_INFO_CONTAINER_H_ 18390be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.cc b/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.cc 18391be168c0dSopenharmony_ciindex c86a967d..a4c15c83 100644 18392be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.cc 18393be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.cc 18394be168c0dSopenharmony_ci@@ -1,5 +1,5 @@ 18395be168c0dSopenharmony_ci /** 18396be168c0dSopenharmony_ci- * Copyright 2021-2022 Huawei Technologies Co., Ltd 18397be168c0dSopenharmony_ci+ * Copyright 2021 Huawei Technologies Co., Ltd 18398be168c0dSopenharmony_ci * 18399be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License"); 18400be168c0dSopenharmony_ci * you may not use this file except in compliance with the License. 
18401be168c0dSopenharmony_ci@@ -22,6 +22,7 @@ 18402be168c0dSopenharmony_ci #include "tools/converter/micro/coder/log.h" 18403be168c0dSopenharmony_ci #include "tools/converter/micro/coder/utils/type_cast.h" 18404be168c0dSopenharmony_ci #include "tools/converter/micro/coder/allocator/allocator.h" 18405be168c0dSopenharmony_ci+#include "tools/common/string_util.h" 18406be168c0dSopenharmony_ci 18407be168c0dSopenharmony_ci namespace mindspore::lite::micro { 18408be168c0dSopenharmony_ci bool CheckConstantTensor(const Tensor *const tensor) { 18409be168c0dSopenharmony_ci@@ -145,4 +146,36 @@ std::vector<std::string> SplitString(std::string str, const std::string &pattern 18410be168c0dSopenharmony_ci } 18411be168c0dSopenharmony_ci return results; 18412be168c0dSopenharmony_ci } 18413be168c0dSopenharmony_ci+ 18414be168c0dSopenharmony_ci+std::string AccumulateShape(const std::vector<std::string> &shape_template, size_t start_index, size_t end_index) { 18415be168c0dSopenharmony_ci+ int64_t const_part = 1; 18416be168c0dSopenharmony_ci+ std::string non_const_part; 18417be168c0dSopenharmony_ci+ for (size_t i = start_index; i < end_index; ++i) { 18418be168c0dSopenharmony_ci+ auto item = shape_template[i]; 18419be168c0dSopenharmony_ci+ if (IsNumber(item)) { 18420be168c0dSopenharmony_ci+ const_part *= std::stoi(item); 18421be168c0dSopenharmony_ci+ } else { 18422be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 18423be168c0dSopenharmony_ci+ non_const_part += " * "; 18424be168c0dSopenharmony_ci+ } 18425be168c0dSopenharmony_ci+ non_const_part += item; 18426be168c0dSopenharmony_ci+ } 18427be168c0dSopenharmony_ci+ } 18428be168c0dSopenharmony_ci+ std::string accumulate_shape = std::to_string(const_part); 18429be168c0dSopenharmony_ci+ if (!non_const_part.empty()) { 18430be168c0dSopenharmony_ci+ accumulate_shape += " * " + non_const_part; 18431be168c0dSopenharmony_ci+ } 18432be168c0dSopenharmony_ci+ return accumulate_shape; 18433be168c0dSopenharmony_ci+} 18434be168c0dSopenharmony_ci+ 
18435be168c0dSopenharmony_ci+std::string GetTensorAddr(lite::Tensor *tensor, bool is_const, DynamicMemManager *dynamic_mem_manager, 18436be168c0dSopenharmony_ci+ MemoryAllocator *allocator) { 18437be168c0dSopenharmony_ci+ if (is_const) { 18438be168c0dSopenharmony_ci+ return allocator->GetRuntimeAddr(tensor, true); 18439be168c0dSopenharmony_ci+ } 18440be168c0dSopenharmony_ci+ if (dynamic_mem_manager == nullptr) { 18441be168c0dSopenharmony_ci+ return allocator->GetRuntimeAddr(tensor); 18442be168c0dSopenharmony_ci+ } 18443be168c0dSopenharmony_ci+ return dynamic_mem_manager->GetVarTensorAddr(tensor); 18444be168c0dSopenharmony_ci+} 18445be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 18446be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.h b/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.h 18447be168c0dSopenharmony_ciindex eabae70e..70a973cb 100644 18448be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.h 18449be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/utils/coder_utils.h 18450be168c0dSopenharmony_ci@@ -41,5 +41,10 @@ std::string ArrayToString(std::vector<T> array) { 18451be168c0dSopenharmony_ci std::for_each(array.begin(), array.end(), [&result](const T &t) { result += std::to_string(t) + ", "; }); 18452be168c0dSopenharmony_ci return "{" + result + "}"; 18453be168c0dSopenharmony_ci } 18454be168c0dSopenharmony_ci+ 18455be168c0dSopenharmony_ci+std::string AccumulateShape(const std::vector<std::string> &shape_template, size_t start_index, size_t end_index); 18456be168c0dSopenharmony_ci+ 18457be168c0dSopenharmony_ci+std::string GetTensorAddr(lite::Tensor *tensor, bool is_const, DynamicMemManager *dynamic_mem_manager, 18458be168c0dSopenharmony_ci+ MemoryAllocator *allocator); 18459be168c0dSopenharmony_ci } // namespace mindspore::lite::micro 18460be168c0dSopenharmony_ci #endif // 
MINDSPORE_LITE_TOOLS_CONVERTER_MICRO_CODER_UTILS_CODER_UTILS_H_ 18461be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc 18462be168c0dSopenharmony_ciindex 61b22bae..1d3c02a0 100644 18463be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc 18464be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/micro/coder/utils/type_cast.cc 18465be168c0dSopenharmony_ci@@ -54,32 +54,30 @@ std::string EnumNameDataType(TypeId type) { 18466be168c0dSopenharmony_ci std::string EnumNameMSDataType(TypeId type) { 18467be168c0dSopenharmony_ci switch (type) { 18468be168c0dSopenharmony_ci case kNumberTypeInt: 18469be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeInt32"; 18470be168c0dSopenharmony_ci+ case kNumberTypeInt32: 18471be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_INT32"; 18472be168c0dSopenharmony_ci case kNumberTypeInt8: 18473be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeInt8"; 18474be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_INT8"; 18475be168c0dSopenharmony_ci case kNumberTypeInt16: 18476be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeInt16"; 18477be168c0dSopenharmony_ci- case kNumberTypeInt32: 18478be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeInt32"; 18479be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_INT16"; 18480be168c0dSopenharmony_ci case kNumberTypeInt64: 18481be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeUInt64"; 18482be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_INT64"; 18483be168c0dSopenharmony_ci case kNumberTypeUInt: 18484be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeUInt32"; 18485be168c0dSopenharmony_ci+ case kNumberTypeUInt32: 18486be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_UINT32"; 18487be168c0dSopenharmony_ci case kNumberTypeUInt8: 18488be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeUInt8"; 
18489be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_UINT8"; 18490be168c0dSopenharmony_ci case kNumberTypeUInt16: 18491be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeUInt16"; 18492be168c0dSopenharmony_ci- case kNumberTypeUInt32: 18493be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeUInt32"; 18494be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_UINT16"; 18495be168c0dSopenharmony_ci case kNumberTypeFloat: 18496be168c0dSopenharmony_ci case kNumberTypeFloat32: 18497be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeFloat32"; 18498be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_FLOAT32"; 18499be168c0dSopenharmony_ci case kNumberTypeFloat16: 18500be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeFloat16"; 18501be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_FLOAT16"; 18502be168c0dSopenharmony_ci case kNumberTypeFloat64: 18503be168c0dSopenharmony_ci- return "kMSDataTypeNumberTypeFloat64"; 18504be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_NUMBERTYPE_FLOAT64"; 18505be168c0dSopenharmony_ci case kTypeUnknown: 18506be168c0dSopenharmony_ci- return "kMSDataTypeUnknown"; 18507be168c0dSopenharmony_ci+ return "OH_AI_DATATYPE_UNKNOWN"; 18508be168c0dSopenharmony_ci default: 18509be168c0dSopenharmony_ci return "unsupported"; 18510be168c0dSopenharmony_ci } 18511be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 18512be168c0dSopenharmony_ciindex 652db4af..a82feb07 100644 18513be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 18514be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 18515be168c0dSopenharmony_ci@@ -62,7 +62,7 @@ STATUS ThirdPartyModelParser::InitConfig(const std::string &config_file) { 18516be168c0dSopenharmony_ci MS_LOG(ERROR) << "Missing config file in 
converting third party model"; 18517be168c0dSopenharmony_ci return RET_ERROR; 18518be168c0dSopenharmony_ci } 18519be168c0dSopenharmony_ci- auto ret = config_parser.ParseConfigFile(config_file); 18520be168c0dSopenharmony_ci+ auto ret = config_parser.ParseConfigFile(config_file, nullptr); 18521be168c0dSopenharmony_ci if (ret != RET_OK) { 18522be168c0dSopenharmony_ci MS_LOG(ERROR) << "Get third party model section from config file failed"; 18523be168c0dSopenharmony_ci return RET_ERROR; 18524be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.cc 18525be168c0dSopenharmony_cinew file mode 100644 18526be168c0dSopenharmony_ciindex 00000000..4caef237 18527be168c0dSopenharmony_ci--- /dev/null 18528be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.cc 18529be168c0dSopenharmony_ci@@ -0,0 +1,120 @@ 18530be168c0dSopenharmony_ci+/** 18531be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 18532be168c0dSopenharmony_ci+ * 18533be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 18534be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 18535be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 18536be168c0dSopenharmony_ci+ * 18537be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 18538be168c0dSopenharmony_ci+ * 18539be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 18540be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 18541be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18542be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 18543be168c0dSopenharmony_ci+ * limitations under the License. 
18544be168c0dSopenharmony_ci+ */ 18545be168c0dSopenharmony_ci+ 18546be168c0dSopenharmony_ci+#define USE_DEPRECATED_API 18547be168c0dSopenharmony_ci+#include "tools/optimizer/fusion/tile_matmul_fusion.h" 18548be168c0dSopenharmony_ci+#include <memory> 18549be168c0dSopenharmony_ci+#include "tools/optimizer/common/gllo_utils.h" 18550be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 18551be168c0dSopenharmony_ci+#include "tools/lite_exporter/fetch_content.h" 18552be168c0dSopenharmony_ci+#include "ops/op_utils.h" 18553be168c0dSopenharmony_ci+#include "ops/lite_ops.h" 18554be168c0dSopenharmony_ci+#include "ops/fusion/tile_fusion.h" 18555be168c0dSopenharmony_ci+#include "ops/fusion/mat_mul_fusion.h" 18556be168c0dSopenharmony_ci+ 18557be168c0dSopenharmony_ci+namespace mindspore { 18558be168c0dSopenharmony_ci+namespace opt { 18559be168c0dSopenharmony_ci+bool TileMatMulFusion::CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { 18560be168c0dSopenharmony_ci+ auto tile_cnode = node->cast<CNodePtr>(); 18561be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(tile_cnode != nullptr, false); 18562be168c0dSopenharmony_ci+ auto tile_primc = ops::GetOperator<ops::TileFusion>(tile_cnode->input(0)); 18563be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(tile_primc != nullptr, false); 18564be168c0dSopenharmony_ci+ auto tile_prim_c = tile_primc->GetPrim(); 18565be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(tile_prim_c != nullptr, false); 18566be168c0dSopenharmony_ci+ if (IsQuantParameterNode(tile_prim_c)) { 18567be168c0dSopenharmony_ci+ MS_LOG(INFO) << tile_primc->name() << " is quant node"; 18568be168c0dSopenharmony_ci+ return false; 18569be168c0dSopenharmony_ci+ } 18570be168c0dSopenharmony_ci+ auto manager = func_graph->manager(); 18571be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(manager != nullptr, false); 18572be168c0dSopenharmony_ci+ auto node_users = manager->node_users()[tile_cnode]; 18573be168c0dSopenharmony_ci+ for (auto &node_user : node_users) { 18574be168c0dSopenharmony_ci+ 
auto post_node = node_user.first; 18575be168c0dSopenharmony_ci+ auto post_node_index = node_user.second; 18576be168c0dSopenharmony_ci+ if (!utils::isa<CNode>(post_node) || !CheckPrimitiveType(post_node, prim::kPrimMatMulFusion) || 18577be168c0dSopenharmony_ci+ post_node_index != C2NUM) { 18578be168c0dSopenharmony_ci+ MS_LOG(INFO) << "The post node of tile must be matmul's matirxB."; 18579be168c0dSopenharmony_ci+ return false; 18580be168c0dSopenharmony_ci+ } 18581be168c0dSopenharmony_ci+ auto matmul_primc = ops::GetOperator<ops::MatMulFusion>(GetInputs(post_node).at(0)); 18582be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_primc != nullptr, false); 18583be168c0dSopenharmony_ci+ auto matmul_prim_c = matmul_primc->GetPrim(); 18584be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(matmul_prim_c != nullptr, false); 18585be168c0dSopenharmony_ci+ if (IsQuantParameterNode(matmul_prim_c)) { 18586be168c0dSopenharmony_ci+ MS_LOG(INFO) << matmul_prim_c->name() << " is quant node"; 18587be168c0dSopenharmony_ci+ return false; 18588be168c0dSopenharmony_ci+ } 18589be168c0dSopenharmony_ci+ } 18590be168c0dSopenharmony_ci+ 18591be168c0dSopenharmony_ci+ lite::DataInfo data_info; 18592be168c0dSopenharmony_ci+ auto status = lite::FetchConstData(tile_cnode, C2NUM, converter::kFmkTypeMs, &data_info, false); 18593be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(status == RET_OK, false, "Fetch tile_cnode third input's const data failed."); 18594be168c0dSopenharmony_ci+ if ((data_info.data_type_ != kNumberTypeInt32 && data_info.data_type_ != kNumberTypeInt) || 18595be168c0dSopenharmony_ci+ data_info.data_.size() / sizeof(int) < DIMENSION_2D) { 18596be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Tile index data is invalid."; 18597be168c0dSopenharmony_ci+ return false; 18598be168c0dSopenharmony_ci+ } 18599be168c0dSopenharmony_ci+ auto data = reinterpret_cast<int *>(data_info.data_.data()); 18600be168c0dSopenharmony_ci+ int dim = static_cast<int>(data_info.data_.size() / sizeof(int)); 
18601be168c0dSopenharmony_ci+ for (int i = dim - C1NUM; i > dim - C3NUM; --i) { 18602be168c0dSopenharmony_ci+ if (data[i] != C1NUM) { 18603be168c0dSopenharmony_ci+ return false; 18604be168c0dSopenharmony_ci+ } 18605be168c0dSopenharmony_ci+ } 18606be168c0dSopenharmony_ci+ lite::DataInfo weights_info; 18607be168c0dSopenharmony_ci+ auto left_pre_node = tile_cnode->input(C1NUM); 18608be168c0dSopenharmony_ci+ if (left_pre_node->isa<Parameter>() || left_pre_node->isa<ValueNode>()) { 18609be168c0dSopenharmony_ci+ status = lite::FetchConstData(tile_cnode, C1NUM, converter::kFmkTypeMs, &weights_info, false); 18610be168c0dSopenharmony_ci+ } else { 18611be168c0dSopenharmony_ci+ status = lite::FetchDataFromCNode(tile_cnode, C1NUM, &weights_info); 18612be168c0dSopenharmony_ci+ } 18613be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(status == RET_OK, false); 18614be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(weights_info.shape_.size() == static_cast<size_t>(dim), false, 18615be168c0dSopenharmony_ci+ "Tile_cnode second input's shape size is invalid."); 18616be168c0dSopenharmony_ci+ for (int i = 0; i < dim - C2NUM; i++) { 18617be168c0dSopenharmony_ci+ if (data[i] != C1NUM && weights_info.shape_[i] != C1NUM) { 18618be168c0dSopenharmony_ci+ return false; 18619be168c0dSopenharmony_ci+ } 18620be168c0dSopenharmony_ci+ } 18621be168c0dSopenharmony_ci+ return true; 18622be168c0dSopenharmony_ci+} 18623be168c0dSopenharmony_ci+ 18624be168c0dSopenharmony_ci+bool TileMatMulFusion::Run(const FuncGraphPtr &func_graph) { 18625be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(func_graph != nullptr, false); 18626be168c0dSopenharmony_ci+ auto node_list = TopoSort(func_graph->get_return()); 18627be168c0dSopenharmony_ci+ for (auto &node : node_list) { 18628be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(node != nullptr, false); 18629be168c0dSopenharmony_ci+ if (!utils::isa<CNode>(node)) { 18630be168c0dSopenharmony_ci+ continue; 18631be168c0dSopenharmony_ci+ } 18632be168c0dSopenharmony_ci+ if (!CheckPrimitiveType(node, 
prim::kPrimTileFusion)) { 18633be168c0dSopenharmony_ci+ continue; 18634be168c0dSopenharmony_ci+ } 18635be168c0dSopenharmony_ci+ if (!CheckCanFuse(func_graph, node)) { 18636be168c0dSopenharmony_ci+ continue; 18637be168c0dSopenharmony_ci+ } 18638be168c0dSopenharmony_ci+ auto tile_cnode = node->cast<CNodePtr>(); 18639be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(tile_cnode != nullptr, false); 18640be168c0dSopenharmony_ci+ auto left_pre_node = tile_cnode->input(SECOND_INPUT); 18641be168c0dSopenharmony_ci+ auto manage = func_graph->manager(); 18642be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(manage != nullptr, false); 18643be168c0dSopenharmony_ci+ auto success = manage->Replace(tile_cnode, left_pre_node); 18644be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(success, false, "Replace old node failed."); 18645be168c0dSopenharmony_ci+ } 18646be168c0dSopenharmony_ci+ return true; 18647be168c0dSopenharmony_ci+} 18648be168c0dSopenharmony_ci+} // namespace opt 18649be168c0dSopenharmony_ci+} // namespace mindspore 18650be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.h b/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.h 18651be168c0dSopenharmony_cinew file mode 100644 18652be168c0dSopenharmony_ciindex 00000000..280dc265 18653be168c0dSopenharmony_ci--- /dev/null 18654be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/optimizer/fusion/tile_matmul_fusion.h 18655be168c0dSopenharmony_ci@@ -0,0 +1,37 @@ 18656be168c0dSopenharmony_ci+/** 18657be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 18658be168c0dSopenharmony_ci+ * 18659be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 18660be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
18661be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 18662be168c0dSopenharmony_ci+ * 18663be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 18664be168c0dSopenharmony_ci+ * 18665be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 18666be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 18667be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18668be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 18669be168c0dSopenharmony_ci+ * limitations under the License. 18670be168c0dSopenharmony_ci+ */ 18671be168c0dSopenharmony_ci+ 18672be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TILE_MATMUL_FUSION_H_ 18673be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TILE_MATMUL_FUSION_H_ 18674be168c0dSopenharmony_ci+ 18675be168c0dSopenharmony_ci+#include <string> 18676be168c0dSopenharmony_ci+#include "tools/optimizer/common/multiple_pattern_process_pass.h" 18677be168c0dSopenharmony_ci+#include "utils/check_convert_utils.h" 18678be168c0dSopenharmony_ci+ 18679be168c0dSopenharmony_ci+namespace mindspore { 18680be168c0dSopenharmony_ci+namespace opt { 18681be168c0dSopenharmony_ci+class TileMatMulFusion : public Pass { 18682be168c0dSopenharmony_ci+ public: 18683be168c0dSopenharmony_ci+ TileMatMulFusion() : Pass("TileMatMulFusion") {} 18684be168c0dSopenharmony_ci+ ~TileMatMulFusion() override = default; 18685be168c0dSopenharmony_ci+ bool Run(const FuncGraphPtr &func_graph) override; 18686be168c0dSopenharmony_ci+ 18687be168c0dSopenharmony_ci+ private: 18688be168c0dSopenharmony_ci+ bool CheckCanFuse(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; 18689be168c0dSopenharmony_ci+}; 18690be168c0dSopenharmony_ci+} // namespace opt 18691be168c0dSopenharmony_ci+} // namespace mindspore 18692be168c0dSopenharmony_ci+#endif // 
MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TILE_MATMUL_FUSION_H_ 18693be168c0dSopenharmony_cidiff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py 18694be168c0dSopenharmony_ciindex 59c9c883..5714b832 100644 18695be168c0dSopenharmony_ci--- a/mindspore/python/mindspore/ops/operations/_grad_ops.py 18696be168c0dSopenharmony_ci+++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py 18697be168c0dSopenharmony_ci@@ -1521,7 +1521,7 @@ class LSTMGrad(Primitive): 18698be168c0dSopenharmony_ci """Computes the data and weight gradients of LSTM.""" 18699be168c0dSopenharmony_ci 18700be168c0dSopenharmony_ci @prim_attr_register 18701be168c0dSopenharmony_ci- def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): 18702be168c0dSopenharmony_ci+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0): 18703be168c0dSopenharmony_ci self.input_size = validator.check_positive_int(input_size, 'input_size', self.name) 18704be168c0dSopenharmony_ci self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name) 18705be168c0dSopenharmony_ci self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name) 18706be168c0dSopenharmony_ci@@ -1529,12 +1529,53 @@ class LSTMGrad(Primitive): 18707be168c0dSopenharmony_ci self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) 18708be168c0dSopenharmony_ci self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) 18709be168c0dSopenharmony_ci self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name) 18710be168c0dSopenharmony_ci+ self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, Rel.INC_LEFT, 18711be168c0dSopenharmony_ci+ 'proj_size', self.name) 18712be168c0dSopenharmony_ci+ 18713be168c0dSopenharmony_ci 18714be168c0dSopenharmony_ci if 
bidirectional: 18715be168c0dSopenharmony_ci self.num_directions = 2 18716be168c0dSopenharmony_ci else: 18717be168c0dSopenharmony_ci self.num_directions = 1 18718be168c0dSopenharmony_ci 18719be168c0dSopenharmony_ci+ def infer_shape(self, x_shape, hx_shape, cx_shape, w_shape, y_shape, hy_shape, cy_shape, dy_shape, dhy_shape, 18720be168c0dSopenharmony_ci+ dcy_shape, reserve_shape): 18721be168c0dSopenharmony_ci+ # dhy and dcy should be same shape 18722be168c0dSopenharmony_ci+ validator.check_equal_int(len(dhy_shape), 3, "h_shape", self.name) 18723be168c0dSopenharmony_ci+ validator.check_equal_int(len(dhy_shape), len(dcy_shape), "h_shape", self.name) 18724be168c0dSopenharmony_ci+ if self.proj_size == 0: 18725be168c0dSopenharmony_ci+ validator.check_equal_int(dhy_shape[0], dcy_shape[0], "h_shape[0]", self.name) 18726be168c0dSopenharmony_ci+ validator.check_equal_int(dhy_shape[1], dcy_shape[1], "h_shape[1]", self.name) 18727be168c0dSopenharmony_ci+ validator.check_equal_int(dhy_shape[2], dcy_shape[2], "h_shape[2]", self.name) 18728be168c0dSopenharmony_ci+ 18729be168c0dSopenharmony_ci+ real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size 18730be168c0dSopenharmony_ci+ validator.check_int(dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h_shape[0]", self.name) 18731be168c0dSopenharmony_ci+ validator.check_equal_int(dhy_shape[2], real_hidden_size, "h_shape[2]", self.name) 18732be168c0dSopenharmony_ci+ 18733be168c0dSopenharmony_ci+ validator.check_equal_int(len(dy_shape), 3, "dy_shape", self.name) 18734be168c0dSopenharmony_ci+ validator.check_equal_int(dy_shape[1], dhy_shape[1], "dy[1]", self.name) 18735be168c0dSopenharmony_ci+ validator.check_int(dy_shape[2], real_hidden_size * self.num_directions, Rel.EQ, "dy[2]", self.name) 18736be168c0dSopenharmony_ci+ 18737be168c0dSopenharmony_ci+ dx_shape = (y_shape[0], y_shape[1], self.input_size) 18738be168c0dSopenharmony_ci+ dhx_shape = dhy_shape 18739be168c0dSopenharmony_ci+ dcx_shape = dcy_shape 
18740be168c0dSopenharmony_ci+ weight_size = 0 18741be168c0dSopenharmony_ci+ gate_size = 4 * self.hidden_size 18742be168c0dSopenharmony_ci+ for layer in range(self.num_layers): 18743be168c0dSopenharmony_ci+ for _ in range(self.num_directions): 18744be168c0dSopenharmony_ci+ input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions 18745be168c0dSopenharmony_ci+ weight_size += gate_size * input_layer_size 18746be168c0dSopenharmony_ci+ weight_size += gate_size * real_hidden_size 18747be168c0dSopenharmony_ci+ if self.proj_size > 0: 18748be168c0dSopenharmony_ci+ weight_size += self.proj_size * self.hidden_size 18749be168c0dSopenharmony_ci+ if self.has_bias: 18750be168c0dSopenharmony_ci+ weight_size += gate_size 18751be168c0dSopenharmony_ci+ 18752be168c0dSopenharmony_ci+ return (dx_shape, dhx_shape, dcx_shape, (weight_size, 1, 1)) 18753be168c0dSopenharmony_ci+ 18754be168c0dSopenharmony_ci+ def infer_dtype(self, x_dtype, hx_dtype, cx_dtype, w_dtype, y_dtype, hy_dtype, cy_dtype, dy_dtype, dhy_dtype, 18755be168c0dSopenharmony_ci+ dcy_dtype, reserve_dtype): 18756be168c0dSopenharmony_ci+ return (dy_dtype, dy_dtype, dy_dtype, hx_dtype) 18757be168c0dSopenharmony_ci 18758be168c0dSopenharmony_ci class DynamicRNNGrad(Primitive): 18759be168c0dSopenharmony_ci """Computes the input gradients of DynamicRNN.""" 18760be168c0dSopenharmony_cidiff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py 18761be168c0dSopenharmony_ciindex 3a0eb3d6..8ae747be 100644 18762be168c0dSopenharmony_ci--- a/mindspore/python/mindspore/ops/operations/nn_ops.py 18763be168c0dSopenharmony_ci+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py 18764be168c0dSopenharmony_ci@@ -4356,7 +4356,7 @@ class LSTM(Primitive): 18765be168c0dSopenharmony_ci """ 18766be168c0dSopenharmony_ci 18767be168c0dSopenharmony_ci @prim_attr_register 18768be168c0dSopenharmony_ci- def __init__(self, input_size, hidden_size, num_layers, 
has_bias, bidirectional, dropout): 18769be168c0dSopenharmony_ci+ def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0): 18770be168c0dSopenharmony_ci """Initialize LSTM.""" 18771be168c0dSopenharmony_ci self.input_size = validator.check_positive_int(input_size, "input_size", self.name) 18772be168c0dSopenharmony_ci self.hidden_size = validator.check_positive_int(hidden_size, "hidden_size", self.name) 18773be168c0dSopenharmony_ci@@ -4365,12 +4365,40 @@ class LSTM(Primitive): 18774be168c0dSopenharmony_ci self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) 18775be168c0dSopenharmony_ci self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) 18776be168c0dSopenharmony_ci self.dropout = validator.check_float_range(dropout, 0, 1, validator.INC_BOTH, 'dropout', self.name) 18777be168c0dSopenharmony_ci+ self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT, 18778be168c0dSopenharmony_ci+ 'proj_size', self.name) 18779be168c0dSopenharmony_ci 18780be168c0dSopenharmony_ci if bidirectional: 18781be168c0dSopenharmony_ci self.num_directions = 2 18782be168c0dSopenharmony_ci else: 18783be168c0dSopenharmony_ci self.num_directions = 1 18784be168c0dSopenharmony_ci 18785be168c0dSopenharmony_ci+ def infer_shape(self, x_shape, h_shape, c_shape, w_shape): 18786be168c0dSopenharmony_ci+ validator.check_equal_int(len(x_shape), 3, "x rank", self.name) 18787be168c0dSopenharmony_ci+ validator.check_equal_int(x_shape[2], self.input_size, "x[2]", self.name) 18788be168c0dSopenharmony_ci+ 18789be168c0dSopenharmony_ci+ # h and c should be same shape 18790be168c0dSopenharmony_ci+ validator.check_equal_int(len(h_shape), 3, "h rank", self.name) 18791be168c0dSopenharmony_ci+ if self.proj_size == 0: 18792be168c0dSopenharmony_ci+ validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name) 18793be168c0dSopenharmony_ci+ 
18794be168c0dSopenharmony_ci+ real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size 18795be168c0dSopenharmony_ci+ validator.check_int(h_shape[0], self.num_layers * self.num_directions, Rel.EQ, "h[0]", self.name) 18796be168c0dSopenharmony_ci+ validator.check_equal_int(h_shape[1], x_shape[1], "h[1]", self.name) 18797be168c0dSopenharmony_ci+ validator.check_int(h_shape[2], real_hidden_size, Rel.EQ, "h[2]", self.name) 18798be168c0dSopenharmony_ci+ 18799be168c0dSopenharmony_ci+ y_shape = (x_shape[0], x_shape[1], real_hidden_size * self.num_directions) 18800be168c0dSopenharmony_ci+ 18801be168c0dSopenharmony_ci+ # set arbitrary shape for reserved space 18802be168c0dSopenharmony_ci+ reserved_shape = (1, 1) 18803be168c0dSopenharmony_ci+ state_shape = (1, 1) 18804be168c0dSopenharmony_ci+ return y_shape, h_shape, c_shape, reserved_shape, state_shape 18805be168c0dSopenharmony_ci+ 18806be168c0dSopenharmony_ci+ def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype): 18807be168c0dSopenharmony_ci+ args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype} 18808be168c0dSopenharmony_ci+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32, mstype.float16), self.name) 18809be168c0dSopenharmony_ci+ return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype 18810be168c0dSopenharmony_ci+ 18811be168c0dSopenharmony_ci 18812be168c0dSopenharmony_ci class SigmoidCrossEntropyWithLogits(Primitive): 18813be168c0dSopenharmony_ci r""" 18814