1be168c0dSopenharmony_ciFrom aa1ba4e66099334f25fa83d7c27c92900b735b4e Mon Sep 17 00:00:00 2001 2be168c0dSopenharmony_ciFrom: qinzheng4 <qinzheng4@huawei.com> 3be168c0dSopenharmony_ciDate: Thu, 28 Dec 2023 22:23:50 +0800 4be168c0dSopenharmony_ciSubject: [PATCH] 0006-remove-lite-expression-fix-double-loadso 5be168c0dSopenharmony_ci 6be168c0dSopenharmony_ci--- 7be168c0dSopenharmony_ci include/api/graph.h | 15 - 8be168c0dSopenharmony_ci include/api/model.h | 14 - 9be168c0dSopenharmony_ci include/api/net.h | 142 ------- 10be168c0dSopenharmony_ci .../cpu/kernel/nnacl/infer/reshape_infer.c | 2 +- 11be168c0dSopenharmony_ci mindspore/lite/BUILD.gn | 48 +-- 12be168c0dSopenharmony_ci mindspore/lite/src/CMakeLists.txt | 32 +- 13be168c0dSopenharmony_ci .../common/ops/populate/custom_populate.cc | 9 +- 14be168c0dSopenharmony_ci mindspore/lite/src/expression/cfg.h | 68 ---- 15be168c0dSopenharmony_ci mindspore/lite/src/expression/export.cc | 76 ---- 16be168c0dSopenharmony_ci mindspore/lite/src/expression/export.h | 52 --- 17be168c0dSopenharmony_ci mindspore/lite/src/expression/expr.cc | 98 ----- 18be168c0dSopenharmony_ci mindspore/lite/src/expression/expr.h | 70 ---- 19be168c0dSopenharmony_ci mindspore/lite/src/expression/import.cc | 180 --------- 20be168c0dSopenharmony_ci mindspore/lite/src/expression/import.h | 61 --- 21be168c0dSopenharmony_ci mindspore/lite/src/expression/net.cc | 268 ------------- 22be168c0dSopenharmony_ci mindspore/lite/src/expression/net.h | 114 ------ 23be168c0dSopenharmony_ci mindspore/lite/src/expression/node.cc | 271 ------------- 24be168c0dSopenharmony_ci mindspore/lite/src/expression/node.h | 156 -------- 25be168c0dSopenharmony_ci mindspore/lite/src/expression/ops.cc | 66 ---- 26be168c0dSopenharmony_ci mindspore/lite/src/expression/ops.h | 69 ---- 27be168c0dSopenharmony_ci .../lite/src/expression/ops/activation.cc | 133 ------- 28be168c0dSopenharmony_ci .../lite/src/expression/ops/activation.h | 44 --- 29be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/adam.cc | 142 ------- 30be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/adam.h | 46 --- 31be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/addn.cc | 42 -- 32be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/addn.h | 34 -- 33be168c0dSopenharmony_ci .../lite/src/expression/ops/arithmetic.cc | 223 ----------- 34be168c0dSopenharmony_ci .../lite/src/expression/ops/arithmetic.h | 75 ---- 35be168c0dSopenharmony_ci .../src/expression/ops/arithmetic_self.cc | 72 ---- 36be168c0dSopenharmony_ci .../lite/src/expression/ops/arithmetic_self.h | 46 --- 37be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/assign.cc | 60 --- 38be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/assign.h | 35 -- 39be168c0dSopenharmony_ci .../lite/src/expression/ops/batchnorm.cc | 135 ------- 40be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/batchnorm.h | 43 -- 41be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/biasadd.cc | 93 ----- 42be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/biasadd.h | 44 --- 43be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/conv.cc | 241 ------------ 44be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/conv.h | 58 --- 45be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/dense.cc | 151 ------- 46be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/dense.h | 44 --- 47be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/depend.cc | 43 -- 48be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/dropout.cc | 91 ----- 49be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/dropout.h | 42 -- 50be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/flatten.cc | 71 ---- 51be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/flatten.h | 36 -- 52be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/pooling.cc | 215 ---------- 53be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/pooling.h | 74 ---- 54be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/reduce.cc | 126 ------ 55be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/reduce.h | 42 -- 56be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/reshape.cc | 74 ---- 57be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/reshape.h | 37 -- 58be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/softmax.cc | 119 ------ 59be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/softmax.h | 39 -- 60be168c0dSopenharmony_ci .../lite/src/expression/ops/softmaxCE.cc | 93 ----- 61be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/softmaxCE.h | 47 --- 62be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/tile.cc | 62 --- 63be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/tile.h | 40 -- 64be168c0dSopenharmony_ci .../lite/src/expression/ops/transpose.cc | 88 ----- 65be168c0dSopenharmony_ci mindspore/lite/src/expression/ops/transpose.h | 59 --- 66be168c0dSopenharmony_ci mindspore/lite/src/expression/ops_utils.cc | 275 ------------- 67be168c0dSopenharmony_ci mindspore/lite/src/expression/ops_utils.h | 69 ---- 68be168c0dSopenharmony_ci mindspore/lite/src/expression/param.cc | 70 ---- 69be168c0dSopenharmony_ci mindspore/lite/src/expression/param.h | 60 --- 70be168c0dSopenharmony_ci mindspore/lite/src/expression/sequential.cc | 30 -- 71be168c0dSopenharmony_ci mindspore/lite/src/expression/sequential.h | 32 -- 72be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/expression/net.cc | 145 ------- 73be168c0dSopenharmony_ci .../src/litert/cxx_api/expression/net_impl.cc | 220 ----------- 74be168c0dSopenharmony_ci .../src/litert/cxx_api/expression/net_impl.h | 95 ----- 75be168c0dSopenharmony_ci .../litert/cxx_api/expression/node_impl.cc | 50 --- 76be168c0dSopenharmony_ci .../src/litert/cxx_api/expression/node_impl.h | 71 ---- 77be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/graph/graph.cc | 9 - 78be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/graph/net_data.cc | 21 - 79be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/graph/net_data.h | 35 -- 80be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/model/model.cc | 1 - 81be168c0dSopenharmony_ci .../src/litert/cxx_api/model/model_impl.cc | 8 - 82be168c0dSopenharmony_ci .../src/litert/cxx_api/model/model_impl.h | 4 - 83be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/serialization.cc | 39 +- 84be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/train/model.cc | 1 - 85be168c0dSopenharmony_ci .../src/litert/cxx_api/train/model_build.cc | 28 -- 86be168c0dSopenharmony_ci .../litert/cxx_api/train/model_build_impl.cc | 28 -- 87be168c0dSopenharmony_ci .../src/litert/cxx_api/train/model_impl.cc | 1 - 88be168c0dSopenharmony_ci mindspore/lite/src/train/graph_fusion.cc | 7 + 89be168c0dSopenharmony_ci .../fusion/remove_redundant_tensor.cc | 89 +++++ 90be168c0dSopenharmony_ci .../fusion/remove_redundant_tensor.h} | 25 +- 91be168c0dSopenharmony_ci mindspore/lite/src/train/train_session.cc | 90 ++--- 92be168c0dSopenharmony_ci mindspore/lite/src/train/train_session.h | 3 - 93be168c0dSopenharmony_ci .../test/config_level0/models_ms_train.cfg | 2 - 94be168c0dSopenharmony_ci .../lite/tools/benchmark_train/CMakeLists.txt | 7 - 95be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_runner.cc | 371 ------------------ 96be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_runner.h | 81 ---- 97be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_train.cc | 1 - 98be168c0dSopenharmony_ci 91 files changed, 162 insertions(+), 6776 deletions(-) 99be168c0dSopenharmony_ci delete mode 100644 include/api/net.h 100be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/cfg.h 101be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/export.cc 102be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/export.h 103be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/expr.cc 104be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/expr.h 105be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/import.cc 106be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/import.h 107be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/net.cc 108be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/net.h 109be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/node.cc 110be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/node.h 111be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops.cc 112be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops.h 113be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/activation.cc 114be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/activation.h 115be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/adam.cc 116be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/adam.h 117be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/addn.cc 118be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/addn.h 119be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/arithmetic.cc 120be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/arithmetic.h 121be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/arithmetic_self.cc 122be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/arithmetic_self.h 123be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/assign.cc 124be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/assign.h 125be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/batchnorm.cc 126be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/batchnorm.h 127be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/biasadd.cc 128be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/biasadd.h 129be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/conv.cc 130be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/conv.h 131be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/dense.cc 132be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/dense.h 133be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/depend.cc 134be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/dropout.cc 135be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/dropout.h 136be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/flatten.cc 137be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/flatten.h 138be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/pooling.cc 139be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/pooling.h 140be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/reduce.cc 141be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/reduce.h 142be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/reshape.cc 143be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/reshape.h 144be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/softmax.cc 145be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/softmax.h 146be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/softmaxCE.cc 147be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/softmaxCE.h 148be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/tile.cc 149be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/tile.h 150be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/transpose.cc 151be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops/transpose.h 152be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops_utils.cc 153be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/ops_utils.h 154be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/param.cc 155be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/param.h 156be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/sequential.cc 157be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/expression/sequential.h 158be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/litert/cxx_api/expression/net.cc 159be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/litert/cxx_api/expression/net_impl.cc 160be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/litert/cxx_api/expression/net_impl.h 161be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/litert/cxx_api/expression/node_impl.cc 162be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/litert/cxx_api/expression/node_impl.h 163be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/litert/cxx_api/graph/net_data.cc 164be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/src/litert/cxx_api/graph/net_data.h 165be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/train/optimizer/fusion/remove_redundant_tensor.cc 166be168c0dSopenharmony_ci rename mindspore/lite/src/{expression/ops/depend.h => train/optimizer/fusion/remove_redundant_tensor.h} (55%) 167be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/tools/benchmark_train/net_runner.cc 168be168c0dSopenharmony_ci delete mode 100644 mindspore/lite/tools/benchmark_train/net_runner.h 169be168c0dSopenharmony_ci 170be168c0dSopenharmony_cidiff --git a/include/api/graph.h b/include/api/graph.h 171be168c0dSopenharmony_ciindex 05548890..f25a6217 100644 172be168c0dSopenharmony_ci--- a/include/api/graph.h 173be168c0dSopenharmony_ci+++ b/include/api/graph.h 174be168c0dSopenharmony_ci@@ -24,38 +24,23 @@ 175be168c0dSopenharmony_ci #include "include/api/types.h" 176be168c0dSopenharmony_ci 177be168c0dSopenharmony_ci namespace mindspore { 178be168c0dSopenharmony_ci-class NetData; 179be168c0dSopenharmony_ci-class Net; 180be168c0dSopenharmony_ci- 181be168c0dSopenharmony_ci class MS_API Graph { 182be168c0dSopenharmony_ci public: 183be168c0dSopenharmony_ci class GraphData; 184be168c0dSopenharmony_ci- enum Type : uint32_t { 185be168c0dSopenharmony_ci- kExpressionGraph = 0, ///< graph as expression - can auto grad 186be168c0dSopenharmony_ci- kExecutableGraph = 1, ///< graph is loaded as is 187be168c0dSopenharmony_ci- kUnknownTypeGraph = 0xffffffff 188be168c0dSopenharmony_ci- }; 189be168c0dSopenharmony_ci Graph(); 190be168c0dSopenharmony_ci explicit Graph(const std::shared_ptr<GraphData> &graph_data); 191be168c0dSopenharmony_ci explicit Graph(std::shared_ptr<GraphData> &&graph_data); 192be168c0dSopenharmony_ci explicit Graph(std::nullptr_t); 193be168c0dSopenharmony_ci ~Graph(); 194be168c0dSopenharmony_ci- explicit Graph(Type executable); 195be168c0dSopenharmony_ci- explicit Graph(Net *net); 196be168c0dSopenharmony_ci 197be168c0dSopenharmony_ci enum ModelType ModelType() const; 198be168c0dSopenharmony_ci bool operator==(std::nullptr_t) const; 199be168c0dSopenharmony_ci bool operator!=(std::nullptr_t) const; 200be168c0dSopenharmony_ci- bool IsExecutable() { return graph_type_ == kExecutableGraph; } 201be168c0dSopenharmony_ci 202be168c0dSopenharmony_ci private: 203be168c0dSopenharmony_ci friend class GraphCell; 204be168c0dSopenharmony_ci friend class ModelImpl; 205be168c0dSopenharmony_ci- friend class NetImpl; 206be168c0dSopenharmony_ci- friend class Model; 207be168c0dSopenharmony_ci std::shared_ptr<GraphData> graph_data_; 208be168c0dSopenharmony_ci- std::shared_ptr<NetData> net_data_; 209be168c0dSopenharmony_ci- Type graph_type_ = kExecutableGraph; 210be168c0dSopenharmony_ci }; 211be168c0dSopenharmony_ci } // namespace mindspore 212be168c0dSopenharmony_ci #endif // MINDSPORE_INCLUDE_API_GRAPH_H 213be168c0dSopenharmony_cidiff --git a/include/api/model.h b/include/api/model.h 214be168c0dSopenharmony_ciindex 4c4359f4..64c52bb1 100644 215be168c0dSopenharmony_ci--- a/include/api/model.h 216be168c0dSopenharmony_ci+++ b/include/api/model.h 217be168c0dSopenharmony_ci@@ -33,9 +33,6 @@ 218be168c0dSopenharmony_ci namespace mindspore { 219be168c0dSopenharmony_ci class ModelImpl; 220be168c0dSopenharmony_ci class Metrics; 221be168c0dSopenharmony_ci-class Net; 222be168c0dSopenharmony_ci-class Node; 223be168c0dSopenharmony_ci-class Expr; 224be168c0dSopenharmony_ci 225be168c0dSopenharmony_ci namespace dataset { 226be168c0dSopenharmony_ci class Dataset; 227be168c0dSopenharmony_ci@@ -112,17 +109,6 @@ class MS_API Model { 228be168c0dSopenharmony_ci Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr, 229be168c0dSopenharmony_ci const std::shared_ptr<TrainCfg> &train_cfg = nullptr); 230be168c0dSopenharmony_ci 231be168c0dSopenharmony_ci- /// \brief Build train model 232be168c0dSopenharmony_ci- /// 233be168c0dSopenharmony_ci- /// \param[in] graph A forward network 234be168c0dSopenharmony_ci- /// \param[in] optimizer An optimizer node 235be168c0dSopenharmony_ci- /// \param[in] inputs Inputs expression for the trained network (ex: input, label ) 236be168c0dSopenharmony_ci- /// \param[in] model_context A context used to store options during execution. 237be168c0dSopenharmony_ci- /// \param[in] train_cfg A config used by training 238be168c0dSopenharmony_ci- /// \return Status 239be168c0dSopenharmony_ci- Status Build(GraphCell graph, Node *optimizer, std::vector<Expr *> inputs, 240be168c0dSopenharmony_ci- const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg); 241be168c0dSopenharmony_ci- 242be168c0dSopenharmony_ci /// \brief Build a Transfer Learning model where the backbone weights are fixed and the head weights are trainable 243be168c0dSopenharmony_ci /// 244be168c0dSopenharmony_ci /// \param[in] backbone The static, non-learnable part of the graph 245be168c0dSopenharmony_cidiff --git a/include/api/net.h b/include/api/net.h 246be168c0dSopenharmony_cideleted file mode 100644 247be168c0dSopenharmony_ciindex 61990ae0..00000000 248be168c0dSopenharmony_ci--- a/include/api/net.h 249be168c0dSopenharmony_ci+++ /dev/null 250be168c0dSopenharmony_ci@@ -1,142 +0,0 @@ 251be168c0dSopenharmony_ci-/** 252be168c0dSopenharmony_ci- * Copyright 2022-2023 Huawei Technologies Co., Ltd 253be168c0dSopenharmony_ci- * 254be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 255be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 256be168c0dSopenharmony_ci- * You may obtain a copy of the License at 257be168c0dSopenharmony_ci- * 258be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 259be168c0dSopenharmony_ci- * 260be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 261be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 262be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 263be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 264be168c0dSopenharmony_ci- * limitations under the License. 265be168c0dSopenharmony_ci- */ 266be168c0dSopenharmony_ci- 267be168c0dSopenharmony_ci-#ifndef MINDSPORE_INCLUDE_API_NET_H 268be168c0dSopenharmony_ci-#define MINDSPORE_INCLUDE_API_NET_H 269be168c0dSopenharmony_ci- 270be168c0dSopenharmony_ci-#include <memory> 271be168c0dSopenharmony_ci-#include <vector> 272be168c0dSopenharmony_ci-#include <unordered_set> 273be168c0dSopenharmony_ci-#include <string> 274be168c0dSopenharmony_ci-#include "include/api/types.h" 275be168c0dSopenharmony_ci-#include "include/api/data_type.h" 276be168c0dSopenharmony_ci-#include "include/api/cfg.h" 277be168c0dSopenharmony_ci- 278be168c0dSopenharmony_ci-namespace mindspore { 279be168c0dSopenharmony_ci-/// \brief Register node or sub network 280be168c0dSopenharmony_ci-#define REG(_name) Register(_name, #_name) 281be168c0dSopenharmony_ci- 282be168c0dSopenharmony_ci-class Expr; 283be168c0dSopenharmony_ci-class NodeImpl; 284be168c0dSopenharmony_ci-class NetImpl; 285be168c0dSopenharmony_ci-class NodeSet; 286be168c0dSopenharmony_ci-class Graph; 287be168c0dSopenharmony_ci-class NetData; 288be168c0dSopenharmony_ci- 289be168c0dSopenharmony_ci-class MS_API NetBase { 290be168c0dSopenharmony_ci- public: 291be168c0dSopenharmony_ci- NetBase() = default; 292be168c0dSopenharmony_ci- virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0; 293be168c0dSopenharmony_ci- virtual uint32_t type() = 0; 294be168c0dSopenharmony_ci-}; 295be168c0dSopenharmony_ci- 296be168c0dSopenharmony_ci-class MS_API Node : public NetBase { 297be168c0dSopenharmony_ci- public: 298be168c0dSopenharmony_ci- Node(); 299be168c0dSopenharmony_ci- virtual ~Node(); 300be168c0dSopenharmony_ci- /// \brief Create output expression from node 301be168c0dSopenharmony_ci- 302be168c0dSopenharmony_ci- /// \param[in] name Name of input (like "labels" etc.) 303be168c0dSopenharmony_ci- /// 304be168c0dSopenharmony_ci- /// \return Expression 305be168c0dSopenharmony_ci- Expr *Create(std::string name); 306be168c0dSopenharmony_ci- /// \brief Run node on inputs. This operator is used in Net::construct() 307be168c0dSopenharmony_ci- /// 308be168c0dSopenharmony_ci- /// \param[in] inputs Inputs expression for the node. 309be168c0dSopenharmony_ci- /// \return Output node expression vector 310be168c0dSopenharmony_ci- std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override; 311be168c0dSopenharmony_ci- uint32_t type() final; 312be168c0dSopenharmony_ci- 313be168c0dSopenharmony_ci- private: 314be168c0dSopenharmony_ci- friend NodeImpl; 315be168c0dSopenharmony_ci- std::shared_ptr<NodeImpl> impl_ = nullptr; 316be168c0dSopenharmony_ci-}; 317be168c0dSopenharmony_ci- 318be168c0dSopenharmony_ci-class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> { 319be168c0dSopenharmony_ci- public: 320be168c0dSopenharmony_ci- Net(); 321be168c0dSopenharmony_ci- virtual ~Net(); 322be168c0dSopenharmony_ci- explicit Net(std::string name); 323be168c0dSopenharmony_ci- explicit Net(const Graph &g); 324be168c0dSopenharmony_ci- /// \brief Define the relation between network inputs and outputs 325be168c0dSopenharmony_ci- /// 326be168c0dSopenharmony_ci- /// \param[in] inputs expression vector 327be168c0dSopenharmony_ci- /// 328be168c0dSopenharmony_ci- /// \return expression vector 329be168c0dSopenharmony_ci- 330be168c0dSopenharmony_ci- virtual std::vector<Expr *> construct(const std::vector<Expr *> &inputs); 331be168c0dSopenharmony_ci- /// \brief Addition operation 332be168c0dSopenharmony_ci- /// 333be168c0dSopenharmony_ci- /// \param[in] inputs Two elements to add 334be168c0dSopenharmony_ci- /// 335be168c0dSopenharmony_ci- /// \return expression vector (single element) 336be168c0dSopenharmony_ci- 337be168c0dSopenharmony_ci- /// \brief Execution operator. Connect inputs to outputs via user defined construct 338be168c0dSopenharmony_ci- /// 339be168c0dSopenharmony_ci- /// \return expression vector 340be168c0dSopenharmony_ci- 341be168c0dSopenharmony_ci- std::vector<Expr *> operator()(const std::vector<Expr *> &inputs); 342be168c0dSopenharmony_ci- void Register(Net *net, std::string &&name); 343be168c0dSopenharmony_ci- void Register(Node *node, std::string &&name); 344be168c0dSopenharmony_ci- /// \brief Find the trainable params for the trained network 345be168c0dSopenharmony_ci- /// 346be168c0dSopenharmony_ci- /// \return NodeSet for all trainable nodes 347be168c0dSopenharmony_ci- std::shared_ptr<NodeSet> trainable_params(); 348be168c0dSopenharmony_ci- virtual void Add(NetBase *element); 349be168c0dSopenharmony_ci- /// \brief Input shape 350be168c0dSopenharmony_ci- /// 351be168c0dSopenharmony_ci- /// \param[in] idx input index 352be168c0dSopenharmony_ci- /// 353be168c0dSopenharmony_ci- /// \return Specific input shape vector 354be168c0dSopenharmony_ci- const std::vector<int> InputShape(int idx); 355be168c0dSopenharmony_ci- /// \brief Output shape 356be168c0dSopenharmony_ci- /// 357be168c0dSopenharmony_ci- /// \param[in] idx Output index 358be168c0dSopenharmony_ci- /// 359be168c0dSopenharmony_ci- /// \return Specific output shape vector 360be168c0dSopenharmony_ci- const std::vector<int> OutputShape(int idx); 361be168c0dSopenharmony_ci- uint32_t type() final; 362be168c0dSopenharmony_ci- 363be168c0dSopenharmony_ci- private: 364be168c0dSopenharmony_ci- friend NetImpl; 365be168c0dSopenharmony_ci- friend NetData; 366be168c0dSopenharmony_ci- std::shared_ptr<NetImpl> impl_; 367be168c0dSopenharmony_ci-}; 368be168c0dSopenharmony_ci- 369be168c0dSopenharmony_ci-class MS_API SoftMaxCrossEntropyCfg { 370be168c0dSopenharmony_ci- public: 371be168c0dSopenharmony_ci- std::string reduction = "mean"; /**< Specifies reduction mode. The optional values are "none", "mean", "sum" */ 372be168c0dSopenharmony_ci-}; 373be168c0dSopenharmony_ci- 374be168c0dSopenharmony_ci-class MS_API AdamConfig { 375be168c0dSopenharmony_ci- public: 376be168c0dSopenharmony_ci- float learning_rate_ = 1e-3; 377be168c0dSopenharmony_ci- float beta1_ = 0.9; 378be168c0dSopenharmony_ci- float beta2_ = 0.999; 379be168c0dSopenharmony_ci- float eps_ = 1e-08; 380be168c0dSopenharmony_ci- bool use_nesterov_ = false; 381be168c0dSopenharmony_ci-}; 382be168c0dSopenharmony_ci- 383be168c0dSopenharmony_ci-namespace NN { 384be168c0dSopenharmony_ci-MS_API Net *NetWithLoss(Net *net, Node *loss); 385be168c0dSopenharmony_ci-MS_API Graph *GraphWithLoss(Graph *g, Node *loss); 386be168c0dSopenharmony_ci-MS_API Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg); 387be168c0dSopenharmony_ci-MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg); 388be168c0dSopenharmony_ci-MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32, 389be168c0dSopenharmony_ci- int fmt = NHWC); 390be168c0dSopenharmony_ci-}; // namespace NN 391be168c0dSopenharmony_ci-} // namespace mindspore 392be168c0dSopenharmony_ci-#endif // MINDSPORE_INCLUDE_API_NET_H 393be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 394be168c0dSopenharmony_ciindex 3c192df7..37aaa410 100644 395be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 396be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/reshape_infer.c 397be168c0dSopenharmony_ci@@ -182,7 +182,7 @@ int ReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC 398be168c0dSopenharmony_ci size_t out_shape_size = 0; 399be168c0dSopenharmony_ci if (inputs_size == 2) { 400be168c0dSopenharmony_ci const TensorC *shape_tensor = inputs[1]; 401be168c0dSopenharmony_ci- if (GetElementNum(input) == 1 && input->shape_size_ == 0) { 402be168c0dSopenharmony_ci+ if (GetElementNum(input) == 1) { 403be168c0dSopenharmony_ci if (shape_tensor->data_ == NULL || (shape_tensor->shape_size_ == 1 && shape_tensor->shape_[0] == 0)) { 404be168c0dSopenharmony_ci SetShapeArray(output, out_shape, out_shape_size); 405be168c0dSopenharmony_ci return NNACL_OK; 406be168c0dSopenharmony_cidiff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn 407be168c0dSopenharmony_ciindex 9318d54e..d7fe4f55 100644 408be168c0dSopenharmony_ci--- a/mindspore/lite/BUILD.gn 409be168c0dSopenharmony_ci+++ b/mindspore/lite/BUILD.gn 410be168c0dSopenharmony_ci@@ -99,7 +99,6 @@ cxx_api_sources = [ 411be168c0dSopenharmony_ci "src/litert/cxx_api/model/model_group_impl.cc", 412be168c0dSopenharmony_ci "src/litert/cxx_api/model/model_impl.cc", 413be168c0dSopenharmony_ci "src/litert/cxx_api/graph/graph.cc", 414be168c0dSopenharmony_ci- "src/litert/cxx_api/graph/net_data.cc", 415be168c0dSopenharmony_ci "src/litert/cxx_api/tensor/tensor_impl.cc", 416be168c0dSopenharmony_ci ] 417be168c0dSopenharmony_ci 418be168c0dSopenharmony_ci@@ -532,50 +531,6 @@ ohos_shared_library("mindspore_ndk") { 419be168c0dSopenharmony_ci } 420be168c0dSopenharmony_ci 421be168c0dSopenharmony_ci # Train library 422be168c0dSopenharmony_ci-expression_cxx_api_sources = [ 423be168c0dSopenharmony_ci- "src/litert/cxx_api/expression/net.cc", 424be168c0dSopenharmony_ci- "src/litert/cxx_api/expression/net_impl.cc", 425be168c0dSopenharmony_ci- "src/litert/cxx_api/expression/node_impl.cc", 426be168c0dSopenharmony_ci-] 427be168c0dSopenharmony_ci- 428be168c0dSopenharmony_ci-expression_op_sources = [ 429be168c0dSopenharmony_ci- "src/expression/ops/activation.cc", 430be168c0dSopenharmony_ci- "src/expression/ops/adam.cc", 431be168c0dSopenharmony_ci- "src/expression/ops/addn.cc", 432be168c0dSopenharmony_ci- "src/expression/ops/arithmetic.cc", 433be168c0dSopenharmony_ci- "src/expression/ops/arithmetic_self.cc", 434be168c0dSopenharmony_ci- "src/expression/ops/assign.cc", 435be168c0dSopenharmony_ci- "src/expression/ops/batchnorm.cc", 436be168c0dSopenharmony_ci- "src/expression/ops/biasadd.cc", 437be168c0dSopenharmony_ci- "src/expression/ops/conv.cc", 438be168c0dSopenharmony_ci- "src/expression/ops/dense.cc", 439be168c0dSopenharmony_ci- "src/expression/ops/depend.cc", 440be168c0dSopenharmony_ci- "src/expression/ops/dropout.cc", 441be168c0dSopenharmony_ci- "src/expression/ops/flatten.cc", 442be168c0dSopenharmony_ci- "src/expression/ops/pooling.cc", 443be168c0dSopenharmony_ci- "src/expression/ops/reduce.cc", 444be168c0dSopenharmony_ci- "src/expression/ops/reshape.cc", 445be168c0dSopenharmony_ci- "src/expression/ops/softmax.cc", 446be168c0dSopenharmony_ci- "src/expression/ops/softmaxCE.cc", 447be168c0dSopenharmony_ci- "src/expression/ops/tile.cc", 448be168c0dSopenharmony_ci- "src/expression/ops/transpose.cc", 449be168c0dSopenharmony_ci-] 450be168c0dSopenharmony_ci- 451be168c0dSopenharmony_ci-all_expression_sources = [ 452be168c0dSopenharmony_ci- "src/expression/export.cc", 453be168c0dSopenharmony_ci- "src/expression/expr.cc", 454be168c0dSopenharmony_ci- "src/expression/import.cc", 455be168c0dSopenharmony_ci- "src/expression/net.cc", 456be168c0dSopenharmony_ci- "src/expression/node.cc", 457be168c0dSopenharmony_ci- "src/expression/ops.cc", 458be168c0dSopenharmony_ci- "src/expression/ops_utils.cc", 459be168c0dSopenharmony_ci- "src/expression/param.cc", 460be168c0dSopenharmony_ci- "src/expression/sequential.cc", 461be168c0dSopenharmony_ci-] 462be168c0dSopenharmony_ci- 463be168c0dSopenharmony_ci-all_expression_sources += expression_cxx_api_sources 464be168c0dSopenharmony_ci-all_expression_sources += expression_op_sources 465be168c0dSopenharmony_ci- 466be168c0dSopenharmony_ci all_train_sources = [ 467be168c0dSopenharmony_ci # ${API_TRAIN_SRC} is empty. 468be168c0dSopenharmony_ci # ${TRAIN_SRC_WITH_MD} is empty. 469be168c0dSopenharmony_ci@@ -604,6 +559,7 @@ all_train_sources = [ 470be168c0dSopenharmony_ci "src/train/optimizer/fusion/gru_fusion_pass.cc", 471be168c0dSopenharmony_ci "src/train/optimizer/fusion/matmul_add_fusion_pass.cc", 472be168c0dSopenharmony_ci "src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc", 473be168c0dSopenharmony_ci+ "src/train/optimizer/fusion/remove_redundant_tensor.cc", 474be168c0dSopenharmony_ci "src/common/storage.cc", 475be168c0dSopenharmony_ci "tools/converter/optimizer.cc", 476be168c0dSopenharmony_ci "tools/converter/legacy_optimizer/fusion/fusion_pass.cc", 477be168c0dSopenharmony_ci@@ -616,8 +572,6 @@ all_train_sources = [ 478be168c0dSopenharmony_ci "tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc", 479be168c0dSopenharmony_ci ] 480be168c0dSopenharmony_ci 481be168c0dSopenharmony_ci-all_train_sources += all_expression_sources 482be168c0dSopenharmony_ci- 483be168c0dSopenharmony_ci fp16_train_kernel_sources = [ 484be168c0dSopenharmony_ci "src/litert/kernel/cpu/fp16_grad/activation_fp16_grad.cc", 485be168c0dSopenharmony_ci "src/litert/kernel/cpu/fp16_grad/arithmetic_fp16_grad.cc", 486be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt 487be168c0dSopenharmony_ciindex 469bcb6b..47033473 100644 488be168c0dSopenharmony_ci--- a/mindspore/lite/src/CMakeLists.txt 489be168c0dSopenharmony_ci+++ b/mindspore/lite/src/CMakeLists.txt 490be168c0dSopenharmony_ci@@ -289,33 +289,9 @@ if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") 491be168c0dSopenharmony_ci ) 492be168c0dSopenharmony_ci endif() 493be168c0dSopenharmony_ci 494be168c0dSopenharmony_ci- 495be168c0dSopenharmony_ci-file(GLOB CXX_API_EXPRESSION 496be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/expression/*.cc 497be168c0dSopenharmony_ci- ) 498be168c0dSopenharmony_ci- 499be168c0dSopenharmony_ci-file(GLOB EXPRESSION_OPS 500be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/ops/*.cc 501be168c0dSopenharmony_ci- ) 502be168c0dSopenharmony_ci- 503be168c0dSopenharmony_ci-set(EXPRESSION_SRC 504be168c0dSopenharmony_ci- ${CXX_API_EXPRESSION} 505be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/export.cc 506be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/expr.cc 507be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/import.cc 508be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/net.cc 509be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/node.cc 510be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/ops.cc 511be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/ops_utils.cc 512be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/param.cc 513be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/expression/sequential.cc 514be168c0dSopenharmony_ci- ${EXPRESSION_OPS} 515be168c0dSopenharmony_ci- ) 516be168c0dSopenharmony_ci- 517be168c0dSopenharmony_ci set(TRAIN_SRC 518be168c0dSopenharmony_ci ${API_TRAIN_SRC} 519be168c0dSopenharmony_ci ${TRAIN_SRC_WITH_MD} 520be168c0dSopenharmony_ci- ${EXPRESSION_SRC} 521be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/metrics/accuracy.cc 522be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/train/model_build.cc 523be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/litert/cxx_api/train/model_build_impl.cc 524be168c0dSopenharmony_ci@@ -340,6 +316,7 @@ set(TRAIN_SRC 525be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_add_fusion_pass.cc 526be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/matmul_matmul_add_fusion_pass.cc 527be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/reshape_gather_reshape_fusion_pass.cc 528be168c0dSopenharmony_ci+ ${CMAKE_CURRENT_SOURCE_DIR}/train/optimizer/fusion/remove_redundant_tensor.cc 529be168c0dSopenharmony_ci ${TOOLS_DIR}/converter/optimizer.cc 530be168c0dSopenharmony_ci ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pass.cc 531be168c0dSopenharmony_ci ${TOOLS_DIR}/converter/legacy_optimizer/fusion/fusion_pattern.cc 532be168c0dSopenharmony_ci@@ -399,13 +376,6 @@ if(NOT MSLITE_ENABLE_COREML) 533be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/litert/delegate/coreml/stub/coreml_delegate_stub.cc) 534be168c0dSopenharmony_ci endif() 535be168c0dSopenharmony_ci 536be168c0dSopenharmony_ci-if(MSVC) 537be168c0dSopenharmony_ci- set(LITE_SRC 538be168c0dSopenharmony_ci- ${LITE_SRC} 539be168c0dSopenharmony_ci- ${EXPRESSION_SRC} 540be168c0dSopenharmony_ci- ) 541be168c0dSopenharmony_ci-endif() 542be168c0dSopenharmony_ci- 543be168c0dSopenharmony_ci add_subdirectory(litert/kernel/cpu) 544be168c0dSopenharmony_ci add_subdirectory(common) 545be168c0dSopenharmony_ci 546be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc 547be168c0dSopenharmony_ciindex 6c490130..9933f6ab 100644 548be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc 549be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc 550be168c0dSopenharmony_ci@@ -23,6 +23,7 @@ 551be168c0dSopenharmony_ci #include "nnacl/custom_is_inf_parameter.h" 552be168c0dSopenharmony_ci #include "nnacl/custom_tensor_scatter_max_parameter.h" 553be168c0dSopenharmony_ci #include "nnacl/custom_gather_d_grad_v2_parameter.h" 554be168c0dSopenharmony_ci+#include "nnacl/scatter_nd_parameter.h" 555be168c0dSopenharmony_ci using mindspore::schema::PrimitiveType_Custom; 556be168c0dSopenharmony_ci 557be168c0dSopenharmony_ci namespace mindspore { 558be168c0dSopenharmony_ci@@ -108,13 +109,13 @@ OpParameter *CreateCustomIsInfParameter() { 559be168c0dSopenharmony_ci } 560be168c0dSopenharmony_ci 561be168c0dSopenharmony_ci OpParameter *CreateCustomTensorScatterMaxParameter() { 562be168c0dSopenharmony_ci- auto *param = static_cast<CustomTensorScatterMaxParameter *>(malloc(sizeof(CustomTensorScatterMaxParameter))); 563be168c0dSopenharmony_ci+ auto *param = static_cast<ScatterNDParameter *>(malloc(sizeof(ScatterNDParameter))); 564be168c0dSopenharmony_ci if (param == nullptr) { 565be168c0dSopenharmony_ci- MS_LOG(ERROR) << "malloc CustomTensorScatterMaxParameter failed."; 566be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc ScatterNDParameter failed."; 567be168c0dSopenharmony_ci return nullptr; 568be168c0dSopenharmony_ci } 569be168c0dSopenharmony_ci- memset(param, 0, sizeof(CustomTensorScatterMaxParameter)); 570be168c0dSopenharmony_ci- param->op_parameter_.type_ = PrimType_Inner_CustomTensorScatterMax; 571be168c0dSopenharmony_ci+ memset(param, 0, sizeof(ScatterNDParameter)); 572be168c0dSopenharmony_ci+ param->op_parameter.type_ = PrimType_Inner_CustomTensorScatterMax; 573be168c0dSopenharmony_ci return reinterpret_cast<OpParameter *>(param); 574be168c0dSopenharmony_ci } 575be168c0dSopenharmony_ci 576be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/cfg.h b/mindspore/lite/src/expression/cfg.h 577be168c0dSopenharmony_cideleted file mode 100644 578be168c0dSopenharmony_ciindex e590d2b7..00000000 579be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/cfg.h 580be168c0dSopenharmony_ci+++ /dev/null 581be168c0dSopenharmony_ci@@ -1,68 +0,0 @@ 582be168c0dSopenharmony_ci-/** 583be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 584be168c0dSopenharmony_ci- * 585be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 586be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 587be168c0dSopenharmony_ci- * You may obtain a copy of the License at 588be168c0dSopenharmony_ci- * 589be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 590be168c0dSopenharmony_ci- * 591be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 592be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 593be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 594be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 595be168c0dSopenharmony_ci- * limitations under the License. 596be168c0dSopenharmony_ci- */ 597be168c0dSopenharmony_ci- 598be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_ 599be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_ 600be168c0dSopenharmony_ci- 601be168c0dSopenharmony_ci-#include <vector> 602be168c0dSopenharmony_ci-#include <string> 603be168c0dSopenharmony_ci- 604be168c0dSopenharmony_ci-namespace mindspore { 605be168c0dSopenharmony_ci-namespace lite { 606be168c0dSopenharmony_ci-class ConvConfig { 607be168c0dSopenharmony_ci- public: 608be168c0dSopenharmony_ci- ConvConfig() = default; 609be168c0dSopenharmony_ci- int in_channel_ = 3; /**< The channel number of the input of the Conv2d layer */ 610be168c0dSopenharmony_ci- int out_channel_ = 3; /**< The channel number of the output tensor of the Conv2d layer */ 611be168c0dSopenharmony_ci- std::vector<int64_t> kernel_size_ = {3, 3}; /**< Specifies the height and width of the 2D convolution kernel. */ 612be168c0dSopenharmony_ci- std::vector<int64_t> stride_ = {1, 1}; /**< The movement stride of the 2D convolution kernel */ 613be168c0dSopenharmony_ci- std::vector<int64_t> padding_ = {0, 0, 0, 0}; /**< The top, bottom, left, and right padding input */ 614be168c0dSopenharmony_ci- std::vector<int64_t> dilation_ = {1, 1}; /**< diletion height and width*/ 615be168c0dSopenharmony_ci- int group_ = 1; // < Splits filter into groups, `in_channels` and `out_channels` must be 616be168c0dSopenharmony_ci- // divisible by `group`. If the group is equal to `in_channels` and `out_channels`, 617be168c0dSopenharmony_ci- // this 2D convolution layer also can be called 2D depthwise convolution layer */ 618be168c0dSopenharmony_ci- bool has_bias = false; /** < Whether the Conv2d layer has a bias parameter */ 619be168c0dSopenharmony_ci- std::string weight_init_ = 620be168c0dSopenharmony_ci- "normal"; /**< Initialization method of weight parameter ("normal","uniform", "ones", "zeros") */ 621be168c0dSopenharmony_ci- std::string pad_mode_ = "same"; /**< Specifies padding mode. The optional values are "same", "valid", "pad" */ 622be168c0dSopenharmony_ci- 623be168c0dSopenharmony_ci- private: 624be168c0dSopenharmony_ci- std::string bias_init_ = "zeros"; 625be168c0dSopenharmony_ci- std::string data_format; 626be168c0dSopenharmony_ci-}; 627be168c0dSopenharmony_ci- 628be168c0dSopenharmony_ci-class DenseConfig { 629be168c0dSopenharmony_ci- public: 630be168c0dSopenharmony_ci- int in_channels_; /**< The number of channels in the input space */ 631be168c0dSopenharmony_ci- int out_channels_; /**< The number of channels in the output space */ 632be168c0dSopenharmony_ci- bool has_bias_ = false; /** Specifies whether the layer uses a bias vector **/ 633be168c0dSopenharmony_ci- private: 634be168c0dSopenharmony_ci- std::string weight_init_ = "normal"; 635be168c0dSopenharmony_ci- std::string bias_init_ = "zeros"; 636be168c0dSopenharmony_ci- std::string activation_ = "none"; 637be168c0dSopenharmony_ci-}; 638be168c0dSopenharmony_ci- 639be168c0dSopenharmony_ci-class PoolingConfig { 640be168c0dSopenharmony_ci- public: 641be168c0dSopenharmony_ci- PoolingConfig() = default; 642be168c0dSopenharmony_ci- std::vector<int64_t> kernel_size_ = {1, 1}; /**< Specifies the height and width of the 2D kernel. */ 643be168c0dSopenharmony_ci- std::vector<int64_t> stride_ = {1, 1}; /**< The movement stride of the 2D kernel */ 644be168c0dSopenharmony_ci- std::string pad_mode_ = "same"; /**< Specifies padding mode. The optional values are "same", "valid" */ 645be168c0dSopenharmony_ci-}; 646be168c0dSopenharmony_ci-} // namespace lite 647be168c0dSopenharmony_ci-} // namespace mindspore 648be168c0dSopenharmony_ci- 649be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_CFG_H_ 650be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/export.cc b/mindspore/lite/src/expression/export.cc 651be168c0dSopenharmony_cideleted file mode 100644 652be168c0dSopenharmony_ciindex a86c54a6..00000000 653be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/export.cc 654be168c0dSopenharmony_ci+++ /dev/null 655be168c0dSopenharmony_ci@@ -1,76 +0,0 @@ 656be168c0dSopenharmony_ci-/** 657be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 658be168c0dSopenharmony_ci- * 659be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 660be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 661be168c0dSopenharmony_ci- * You may obtain a copy of the License at 662be168c0dSopenharmony_ci- * 663be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 664be168c0dSopenharmony_ci- * 665be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 666be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 667be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 668be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 669be168c0dSopenharmony_ci- * limitations under the License. 670be168c0dSopenharmony_ci- */ 671be168c0dSopenharmony_ci- 672be168c0dSopenharmony_ci-#include <utility> 673be168c0dSopenharmony_ci-#include "src/expression/export.h" 674be168c0dSopenharmony_ci-#include "src/expression/ops.h" 675be168c0dSopenharmony_ci-#include "src/common/utils.h" 676be168c0dSopenharmony_ci-#include "nnacl/conv_parameter.h" 677be168c0dSopenharmony_ci-#include "include/errorcode.h" 678be168c0dSopenharmony_ci- 679be168c0dSopenharmony_ci-namespace mindspore { 680be168c0dSopenharmony_ci-namespace lite { 681be168c0dSopenharmony_ci-constexpr static int kFmkVal = 3; 682be168c0dSopenharmony_ci- 683be168c0dSopenharmony_ci-int ExportSession::Init(const std::string model_name, std::string version) { 684be168c0dSopenharmony_ci- meta_graph_ = new (std::nothrow) mindspore::schema::MetaGraphT(); 685be168c0dSopenharmony_ci- if (meta_graph_ == nullptr) { 686be168c0dSopenharmony_ci- MS_LOG(ERROR) << "cannot allocate meta_graph"; 687be168c0dSopenharmony_ci- return RET_ERROR; 688be168c0dSopenharmony_ci- } 689be168c0dSopenharmony_ci- meta_graph_->fmkType = kFmkVal; 690be168c0dSopenharmony_ci- meta_graph_->name = model_name; 691be168c0dSopenharmony_ci- meta_graph_->version = version; 692be168c0dSopenharmony_ci- return RET_OK; 693be168c0dSopenharmony_ci-} 694be168c0dSopenharmony_ci- 695be168c0dSopenharmony_ci-bool ExportSession::IsToDependOnly(EXPR *expr) { 696be168c0dSopenharmony_ci- auto itr = outmap_.find(expr); 697be168c0dSopenharmony_ci- if (itr != outmap_.end() && !itr->second.empty()) { 698be168c0dSopenharmony_ci- for (auto expr : itr->second) { 699be168c0dSopenharmony_ci- auto node = expr->node(); 700be168c0dSopenharmony_ci- if (node->primitive() != schema::PrimitiveType_Depend) return false; 701be168c0dSopenharmony_ci- } 702be168c0dSopenharmony_ci- return true; 703be168c0dSopenharmony_ci- } 704be168c0dSopenharmony_ci- return false; 705be168c0dSopenharmony_ci-} 706be168c0dSopenharmony_ci- 707be168c0dSopenharmony_ci-int ExportSession::SetInputOutput(const std::vector<EXPR *> &inputs, const std::vector<EXPR *> &outputs) { 708be168c0dSopenharmony_ci- for (auto &in : inputs) { 709be168c0dSopenharmony_ci- auto id = GetOutput(in); 710be168c0dSopenharmony_ci- meta_graph_->inputIndex.push_back(id); 711be168c0dSopenharmony_ci- } 712be168c0dSopenharmony_ci- for (auto &out : outputs) { 713be168c0dSopenharmony_ci- auto id = GetOutput(out); 714be168c0dSopenharmony_ci- meta_graph_->outputIndex.push_back(id); 715be168c0dSopenharmony_ci- } 716be168c0dSopenharmony_ci- auto sub_graph = std::make_unique<mindspore::schema::SubGraphT>(); 717be168c0dSopenharmony_ci- if (sub_graph == nullptr) { 718be168c0dSopenharmony_ci- MS_LOG(ERROR) << "cannot allocate SubGraphT"; 719be168c0dSopenharmony_ci- return RET_ERROR; 720be168c0dSopenharmony_ci- } 721be168c0dSopenharmony_ci- auto model_name = meta_graph_->name; 722be168c0dSopenharmony_ci- sub_graph->name = model_name + "_subgraph"; 723be168c0dSopenharmony_ci- sub_graph->inputIndices = meta_graph_->inputIndex; 724be168c0dSopenharmony_ci- sub_graph->outputIndices = meta_graph_->outputIndex; 725be168c0dSopenharmony_ci- for (size_t i = 0; i < meta_graph_->nodes.size(); i++) sub_graph->nodeIndices.push_back(i); 726be168c0dSopenharmony_ci- for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) sub_graph->tensorIndices.push_back(i); 727be168c0dSopenharmony_ci- meta_graph_->subGraph.emplace_back(std::move(sub_graph)); 728be168c0dSopenharmony_ci- return RET_OK; 729be168c0dSopenharmony_ci-} 730be168c0dSopenharmony_ci-} // namespace lite 731be168c0dSopenharmony_ci-} // namespace mindspore 732be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/export.h b/mindspore/lite/src/expression/export.h 733be168c0dSopenharmony_cideleted file mode 100644 734be168c0dSopenharmony_ciindex 8009e080..00000000 735be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/export.h 736be168c0dSopenharmony_ci+++ /dev/null 737be168c0dSopenharmony_ci@@ -1,52 +0,0 @@ 738be168c0dSopenharmony_ci-/** 739be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 740be168c0dSopenharmony_ci- * 741be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 742be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 743be168c0dSopenharmony_ci- * You may obtain a copy of the License at 744be168c0dSopenharmony_ci- * 745be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 746be168c0dSopenharmony_ci- * 747be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 748be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 749be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 750be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 751be168c0dSopenharmony_ci- * limitations under the License. 752be168c0dSopenharmony_ci- */ 753be168c0dSopenharmony_ci- 754be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_ 755be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_ 756be168c0dSopenharmony_ci- 757be168c0dSopenharmony_ci-#include <vector> 758be168c0dSopenharmony_ci-#include <memory> 759be168c0dSopenharmony_ci-#include <unordered_map> 760be168c0dSopenharmony_ci-#include <string> 761be168c0dSopenharmony_ci-#include <list> 762be168c0dSopenharmony_ci-#include <map> 763be168c0dSopenharmony_ci-#include <iostream> 764be168c0dSopenharmony_ci-#include "src/expression/expr.h" 765be168c0dSopenharmony_ci- 766be168c0dSopenharmony_ci-namespace mindspore { 767be168c0dSopenharmony_ci-namespace schema { 768be168c0dSopenharmony_ci-struct MetaGraphT; 769be168c0dSopenharmony_ci-} 770be168c0dSopenharmony_ci-namespace lite { 771be168c0dSopenharmony_ci-class ExportSession { 772be168c0dSopenharmony_ci- public: 773be168c0dSopenharmony_ci- explicit ExportSession(std::map<EXPR *, std::list<EXPR *>> &outmap) : outmap_(outmap) {} 774be168c0dSopenharmony_ci- int Init(const std::string model_name, std::string version); 775be168c0dSopenharmony_ci- void UpdateOutput(EXPR *expr, int id) { output_tensors_[expr] = id; } 776be168c0dSopenharmony_ci- int GetOutput(EXPR *expr) { return output_tensors_.at(expr); } 777be168c0dSopenharmony_ci- schema::MetaGraphT *&meta_graph() { return meta_graph_; } 778be168c0dSopenharmony_ci- int SetInputOutput(const std::vector<EXPR *> &inputs, const std::vector<EXPR *> &outputs); 779be168c0dSopenharmony_ci- bool IsToDependOnly(EXPR *expr); 780be168c0dSopenharmony_ci- 781be168c0dSopenharmony_ci- private: 782be168c0dSopenharmony_ci- schema::MetaGraphT *meta_graph_{nullptr}; 783be168c0dSopenharmony_ci- std::unordered_map<EXPR *, int> output_tensors_; // output tensors per EXPR 784be168c0dSopenharmony_ci- std::map<EXPR *, std::list<EXPR *>> &outmap_; 785be168c0dSopenharmony_ci-}; 786be168c0dSopenharmony_ci-} // namespace lite 787be168c0dSopenharmony_ci-} // namespace mindspore 788be168c0dSopenharmony_ci- 789be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_EXPORT_H_ 790be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/expr.cc b/mindspore/lite/src/expression/expr.cc 791be168c0dSopenharmony_cideleted file mode 100644 792be168c0dSopenharmony_ciindex b27853d2..00000000 793be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/expr.cc 794be168c0dSopenharmony_ci+++ /dev/null 795be168c0dSopenharmony_ci@@ -1,98 +0,0 @@ 796be168c0dSopenharmony_ci-/** 797be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 798be168c0dSopenharmony_ci- * 799be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 800be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 801be168c0dSopenharmony_ci- * You may obtain a copy of the License at 802be168c0dSopenharmony_ci- * 803be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 804be168c0dSopenharmony_ci- * 805be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 806be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 807be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 808be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 809be168c0dSopenharmony_ci- * limitations under the License. 810be168c0dSopenharmony_ci- */ 811be168c0dSopenharmony_ci- 812be168c0dSopenharmony_ci-#include <string> 813be168c0dSopenharmony_ci-#include <iostream> 814be168c0dSopenharmony_ci-#include "src/expression/expr.h" 815be168c0dSopenharmony_ci-#include "src/expression/node.h" 816be168c0dSopenharmony_ci- 817be168c0dSopenharmony_ci-namespace mindspore { 818be168c0dSopenharmony_ci-namespace lite { 819be168c0dSopenharmony_ci-std::string EXPR::name() { return node_->name(); } 820be168c0dSopenharmony_ci-void EXPR::Travers(std::function<bool(EXPR *e, EXPR *itr)> cb) { 821be168c0dSopenharmony_ci- if (!visited) { 822be168c0dSopenharmony_ci- visited = true; 823be168c0dSopenharmony_ci- for (auto &itr : params_) { 824be168c0dSopenharmony_ci- if (cb(this, itr)) { 825be168c0dSopenharmony_ci- itr->Travers(cb); 826be168c0dSopenharmony_ci- } 827be168c0dSopenharmony_ci- } 828be168c0dSopenharmony_ci- } 829be168c0dSopenharmony_ci-} 830be168c0dSopenharmony_ci- 831be168c0dSopenharmony_ci-void EXPR::Replace(EXPR **old, EXPR **n, std::vector<Node *> *to_delete) { 832be168c0dSopenharmony_ci- if (!visited) { 833be168c0dSopenharmony_ci- visited = true; 834be168c0dSopenharmony_ci- for (auto &itr : params_) 835be168c0dSopenharmony_ci- if (itr == *old) { 836be168c0dSopenharmony_ci- to_delete->push_back(itr->node()); 837be168c0dSopenharmony_ci- itr = *n; 838be168c0dSopenharmony_ci- } 839be168c0dSopenharmony_ci- for (auto &itr : params_) itr->Replace(old, n, to_delete); 840be168c0dSopenharmony_ci- } 841be168c0dSopenharmony_ci-} 842be168c0dSopenharmony_ci- 843be168c0dSopenharmony_ci-void EXPR::Replace(const std::vector<EXPR *> &vec, std::vector<EXPR *> *old, std::vector<EXPR *> *n) { 844be168c0dSopenharmony_ci- std::vector<Node *> to_delete; 845be168c0dSopenharmony_ci- for (auto &e : vec) { 846be168c0dSopenharmony_ci- for (std::size_t i = 0; i < old->size(); i++) { 847be168c0dSopenharmony_ci- e->Replace(&old->at(i), &n->at(i), &to_delete); 848be168c0dSopenharmony_ci- } 849be168c0dSopenharmony_ci- } 850be168c0dSopenharmony_ci- for (auto &itr : to_delete) delete itr; 851be168c0dSopenharmony_ci- for (auto e : vec) e->Clear(); 852be168c0dSopenharmony_ci-} 853be168c0dSopenharmony_ci- 854be168c0dSopenharmony_ci-void EXPR::Clear() { 855be168c0dSopenharmony_ci- EXPR *item = this; 856be168c0dSopenharmony_ci- if (visited == false) return; 857be168c0dSopenharmony_ci- visited = false; 858be168c0dSopenharmony_ci- while (item->params_.size() == 1) { 859be168c0dSopenharmony_ci- item = item->params_.front(); 860be168c0dSopenharmony_ci- if (item->visited == false) return; 861be168c0dSopenharmony_ci- item->visited = false; 862be168c0dSopenharmony_ci- } 863be168c0dSopenharmony_ci- for (auto &itr : item->params_) itr->Clear(); 864be168c0dSopenharmony_ci-} 865be168c0dSopenharmony_ci- 866be168c0dSopenharmony_ci-void EXPR::Clear(std::vector<EXPR *> vec) { 867be168c0dSopenharmony_ci- for (auto e : vec) e->Clear(); 868be168c0dSopenharmony_ci-} 869be168c0dSopenharmony_ci- 870be168c0dSopenharmony_ci-void EXPR::CreateOutputMap(std::vector<EXPR *> vec, std::map<EXPR *, std::list<EXPR *>> *outmap) { 871be168c0dSopenharmony_ci- for (auto e : vec) { 872be168c0dSopenharmony_ci- e->Travers([&](EXPR *e, EXPR *itr) { 873be168c0dSopenharmony_ci- (*outmap)[itr].push_back(e); 874be168c0dSopenharmony_ci- return true; 875be168c0dSopenharmony_ci- }); 876be168c0dSopenharmony_ci- } 877be168c0dSopenharmony_ci- Clear(vec); 878be168c0dSopenharmony_ci-} 879be168c0dSopenharmony_ci- 880be168c0dSopenharmony_ci-void EXPR::PrintDot(std::vector<EXPR *> vec) { 881be168c0dSopenharmony_ci- std::cout << "digraph \"expr\" { " << std::endl; 882be168c0dSopenharmony_ci- for (auto e : vec) { 883be168c0dSopenharmony_ci- e->Travers([](EXPR *e, EXPR *itr) { 884be168c0dSopenharmony_ci- std::cout << "\"" << itr->node_->name() << "\"->" 885be168c0dSopenharmony_ci- << "\"" << e->node_->name() << "\"" << std::endl; 886be168c0dSopenharmony_ci- return true; 887be168c0dSopenharmony_ci- }); 888be168c0dSopenharmony_ci- } 889be168c0dSopenharmony_ci- std::cout << "}" << std::endl; 890be168c0dSopenharmony_ci- Clear(vec); 891be168c0dSopenharmony_ci-} 892be168c0dSopenharmony_ci-} // namespace lite 893be168c0dSopenharmony_ci-} // namespace mindspore 894be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/expr.h b/mindspore/lite/src/expression/expr.h 895be168c0dSopenharmony_cideleted file mode 100644 896be168c0dSopenharmony_ciindex 8c76befd..00000000 897be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/expr.h 898be168c0dSopenharmony_ci+++ /dev/null 899be168c0dSopenharmony_ci@@ -1,70 +0,0 @@ 900be168c0dSopenharmony_ci-/** 901be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 902be168c0dSopenharmony_ci- * 903be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 904be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 905be168c0dSopenharmony_ci- * You may obtain a copy of the License at 906be168c0dSopenharmony_ci- * 907be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 908be168c0dSopenharmony_ci- * 909be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 910be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 911be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 912be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 913be168c0dSopenharmony_ci- * limitations under the License. 914be168c0dSopenharmony_ci- */ 915be168c0dSopenharmony_ci- 916be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_ 917be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_ 918be168c0dSopenharmony_ci- 919be168c0dSopenharmony_ci-#include <vector> 920be168c0dSopenharmony_ci-#include <list> 921be168c0dSopenharmony_ci-#include <memory> 922be168c0dSopenharmony_ci-#include <map> 923be168c0dSopenharmony_ci-#include <functional> 924be168c0dSopenharmony_ci-#include <string> 925be168c0dSopenharmony_ci-#include "include/api/format.h" 926be168c0dSopenharmony_ci-#include "mindapi/base/type_id.h" 927be168c0dSopenharmony_ci- 928be168c0dSopenharmony_ci-namespace mindspore { 929be168c0dSopenharmony_ci-namespace lite { 930be168c0dSopenharmony_ci-class Node; 931be168c0dSopenharmony_ci- 932be168c0dSopenharmony_ci-class EXPR { 933be168c0dSopenharmony_ci- public: 934be168c0dSopenharmony_ci- explicit EXPR(Node *node) : node_(node) { SetSize(1); } 935be168c0dSopenharmony_ci- static void PrintDot(std::vector<EXPR *> vec); 936be168c0dSopenharmony_ci- static void Replace(const std::vector<EXPR *> &vec, std::vector<EXPR *> *old, std::vector<EXPR *> *n); 937be168c0dSopenharmony_ci- static void CreateOutputMap(std::vector<EXPR *> vec, std::map<EXPR *, std::list<EXPR *>> *outmap); 938be168c0dSopenharmony_ci- static void Clear(std::vector<EXPR *> vec); 939be168c0dSopenharmony_ci- void Travers(std::function<bool(EXPR *e, EXPR *itr)> cb); 940be168c0dSopenharmony_ci- std::string name(); 941be168c0dSopenharmony_ci- EXPR *GetInput(int idx) { return params_.at(idx); } 942be168c0dSopenharmony_ci- void set_node(Node *node) { node_ = node; } 943be168c0dSopenharmony_ci- Node *node() { return node_; } 944be168c0dSopenharmony_ci- bool visited = false; 945be168c0dSopenharmony_ci- void set_params(std::vector<EXPR *> params) { params_ = params; } 946be168c0dSopenharmony_ci- void set_params(int idx, EXPR *expr) { params_[idx] = expr; } 947be168c0dSopenharmony_ci- void add_params(EXPR *e) { params_.push_back(e); } 948be168c0dSopenharmony_ci- std::vector<EXPR *> ¶ms() { return params_; } 949be168c0dSopenharmony_ci- EXPR *params(int i) { return params_[i]; } 950be168c0dSopenharmony_ci- void SetSize(int n) { params_.resize(n); } 951be168c0dSopenharmony_ci- void SetDims(std::vector<int> dims) { dims_ = dims; } 952be168c0dSopenharmony_ci- std::vector<int> &dims() { return dims_; } 953be168c0dSopenharmony_ci- void set_format(int fmt) { format_ = fmt; } 954be168c0dSopenharmony_ci- int format() { return format_; } 955be168c0dSopenharmony_ci- void set_data_type(TypeId data_type) { data_type_ = data_type; } 956be168c0dSopenharmony_ci- TypeId data_type() { return data_type_; } 957be168c0dSopenharmony_ci- 958be168c0dSopenharmony_ci- private: 959be168c0dSopenharmony_ci- void Replace(EXPR **old, EXPR **n, std::vector<Node *> *to_delete); 960be168c0dSopenharmony_ci- std::vector<EXPR *> params_; 961be168c0dSopenharmony_ci- Node *node_{nullptr}; 962be168c0dSopenharmony_ci- void Clear(); 963be168c0dSopenharmony_ci- std::vector<int> dims_; 964be168c0dSopenharmony_ci- int format_ = NHWC; 965be168c0dSopenharmony_ci- TypeId data_type_ = kNumberTypeFloat32; 966be168c0dSopenharmony_ci-}; 967be168c0dSopenharmony_ci-} // namespace lite 968be168c0dSopenharmony_ci-} // namespace mindspore 969be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_EXPR_H_ 970be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/import.cc b/mindspore/lite/src/expression/import.cc 971be168c0dSopenharmony_cideleted file mode 100644 972be168c0dSopenharmony_ciindex b3c20839..00000000 973be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/import.cc 974be168c0dSopenharmony_ci+++ /dev/null 975be168c0dSopenharmony_ci@@ -1,180 +0,0 @@ 976be168c0dSopenharmony_ci-/** 977be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 978be168c0dSopenharmony_ci- * 979be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 980be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 981be168c0dSopenharmony_ci- * You may obtain a copy of the License at 982be168c0dSopenharmony_ci- * 983be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 984be168c0dSopenharmony_ci- * 985be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 986be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 987be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 988be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 989be168c0dSopenharmony_ci- * limitations under the License. 990be168c0dSopenharmony_ci- */ 991be168c0dSopenharmony_ci- 992be168c0dSopenharmony_ci-#include <vector> 993be168c0dSopenharmony_ci-#include "src/expression/import.h" 994be168c0dSopenharmony_ci-#include "common/ops/populate/populate_register.h" 995be168c0dSopenharmony_ci-#include "src/expression/ops.h" 996be168c0dSopenharmony_ci-#include "src/expression/ops/activation.h" 997be168c0dSopenharmony_ci-#include "src/expression/ops/batchnorm.h" 998be168c0dSopenharmony_ci-#include "src/expression/ops/biasadd.h" 999be168c0dSopenharmony_ci-#include "src/expression/ops/conv.h" 1000be168c0dSopenharmony_ci-#include "src/expression/ops/dense.h" 1001be168c0dSopenharmony_ci-#include "src/expression/ops/pooling.h" 1002be168c0dSopenharmony_ci-#include "src/expression/ops/reshape.h" 1003be168c0dSopenharmony_ci-#include "src/expression/ops/transpose.h" 1004be168c0dSopenharmony_ci- 1005be168c0dSopenharmony_ci-namespace mindspore { 1006be168c0dSopenharmony_ci-namespace lite { 1007be168c0dSopenharmony_ci-std::unordered_map<mindspore::schema::PrimitiveType, import_func> ImportReg::import_map_; 1008be168c0dSopenharmony_ci- 1009be168c0dSopenharmony_ci-import_func ImportReg::GetImportFunc(mindspore::schema::PrimitiveType type) { 1010be168c0dSopenharmony_ci- auto f = import_map_.find(type); 1011be168c0dSopenharmony_ci- if (f == import_map_.end()) { 1012be168c0dSopenharmony_ci- return nullptr; 1013be168c0dSopenharmony_ci- } 1014be168c0dSopenharmony_ci- return f->second; 1015be168c0dSopenharmony_ci-} 1016be168c0dSopenharmony_ci- 1017be168c0dSopenharmony_ci-OpParameter *Import::GetAttr(const schema::Primitive *prim) { 1018be168c0dSopenharmony_ci- auto parameter_gen = PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR); 1019be168c0dSopenharmony_ci- if (parameter_gen == nullptr) { 1020be168c0dSopenharmony_ci- MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " << schema::EnumNamePrimitiveType(prim->value_type()); 1021be168c0dSopenharmony_ci- return nullptr; 1022be168c0dSopenharmony_ci- } 1023be168c0dSopenharmony_ci- auto parameter = parameter_gen(prim); 1024be168c0dSopenharmony_ci- if (parameter == nullptr) { 1025be168c0dSopenharmony_ci- MS_LOG(ERROR) << "parameter is nullptr."; 1026be168c0dSopenharmony_ci- return nullptr; 1027be168c0dSopenharmony_ci- } 1028be168c0dSopenharmony_ci- return parameter; 1029be168c0dSopenharmony_ci-} 1030be168c0dSopenharmony_ci- 1031be168c0dSopenharmony_ci-std::unique_ptr<Node> Import::CreateNode(const schema::CNode *cnode) { 1032be168c0dSopenharmony_ci- auto param = GetAttr(cnode->primitive()); 1033be168c0dSopenharmony_ci- auto type = cnode->primitive()->value_type(); 1034be168c0dSopenharmony_ci- auto fn = ImportReg::GetImportFunc(type); 1035be168c0dSopenharmony_ci- if (fn == nullptr) { 1036be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot find importer for " << schema::EnumNamePrimitiveType(type); 1037be168c0dSopenharmony_ci- return nullptr; 1038be168c0dSopenharmony_ci- } 1039be168c0dSopenharmony_ci- auto node = fn(); 1040be168c0dSopenharmony_ci- if (node == nullptr) { 1041be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate node" << cnode->name()->str(); 1042be168c0dSopenharmony_ci- return nullptr; 1043be168c0dSopenharmony_ci- } 1044be168c0dSopenharmony_ci- node->SetOpParam(param); 1045be168c0dSopenharmony_ci- node->set_name(cnode->name()->str()); 1046be168c0dSopenharmony_ci- node->set_primitive(type); 1047be168c0dSopenharmony_ci- return std::unique_ptr<Node>(node); 1048be168c0dSopenharmony_ci-} 1049be168c0dSopenharmony_ci- 1050be168c0dSopenharmony_ci-Net *Import::ImportMs(std::string file_name) { 1051be168c0dSopenharmony_ci- std::ifstream infile; 1052be168c0dSopenharmony_ci- infile.open(file_name, std::ios::binary | std::ios::in); 1053be168c0dSopenharmony_ci- if (!infile.good()) { 1054be168c0dSopenharmony_ci- MS_LOG(ERROR) << "cannot read " << file_name << std::endl; 1055be168c0dSopenharmony_ci- return nullptr; 1056be168c0dSopenharmony_ci- } 1057be168c0dSopenharmony_ci- infile.seekg(0, std::ios::end); 1058be168c0dSopenharmony_ci- int length = infile.tellg(); 1059be168c0dSopenharmony_ci- infile.seekg(0, std::ios::beg); 1060be168c0dSopenharmony_ci- auto data_ptr = std::make_unique<int8_t[]>(length); 1061be168c0dSopenharmony_ci- auto *data = data_ptr.get(); 1062be168c0dSopenharmony_ci- infile.read(reinterpret_cast<char *>(data), length); 1063be168c0dSopenharmony_ci- infile.close(); 1064be168c0dSopenharmony_ci- flatbuffers::Verifier verifier = flatbuffers::Verifier(reinterpret_cast<const uint8_t *>(data), length); 1065be168c0dSopenharmony_ci- bool res = schema::VerifyMetaGraphBuffer(verifier); 1066be168c0dSopenharmony_ci- if (res != true) { 1067be168c0dSopenharmony_ci- MS_LOG(ERROR) << "fault file: " << file_name << "(" << length << ")\n"; 1068be168c0dSopenharmony_ci- return nullptr; 1069be168c0dSopenharmony_ci- } else { 1070be168c0dSopenharmony_ci- MS_LOG(INFO) << "verify pass file: " << file_name << "(" << length << ")\n"; 1071be168c0dSopenharmony_ci- } 1072be168c0dSopenharmony_ci- buffer_ = data_ptr.get(); 1073be168c0dSopenharmony_ci- auto metaGraph = schema::GetMetaGraph(data_ptr.release()); 1074be168c0dSopenharmony_ci- return ImportMs(metaGraph); 1075be168c0dSopenharmony_ci-} 1076be168c0dSopenharmony_ci- 1077be168c0dSopenharmony_ci-Net *Import::ImportMs(const schema::MetaGraph *metaGraph) { 1078be168c0dSopenharmony_ci- if (metaGraph == nullptr) { 1079be168c0dSopenharmony_ci- MS_LOG(ERROR) << "null input"; 1080be168c0dSopenharmony_ci- return nullptr; 1081be168c0dSopenharmony_ci- } 1082be168c0dSopenharmony_ci- std::string NetName = "Network"; 1083be168c0dSopenharmony_ci- if (metaGraph->name() != nullptr) NetName = metaGraph->name()->str(); 1084be168c0dSopenharmony_ci- auto net = std::make_unique<Net>(NetName); 1085be168c0dSopenharmony_ci- std::unordered_map<int, EXPR *> outputs; 1086be168c0dSopenharmony_ci- // save inputs 1087be168c0dSopenharmony_ci- for (size_t i = 0; i < metaGraph->inputIndex()->size(); i++) { 1088be168c0dSopenharmony_ci- auto tensor_id = metaGraph->inputIndex()->Get(i); 1089be168c0dSopenharmony_ci- const schema::Tensor *tensor = metaGraph->allTensors()->Get(tensor_id); 1090be168c0dSopenharmony_ci- auto input = new (std::nothrow) InputM(tensor); 1091be168c0dSopenharmony_ci- if (input == nullptr) { 1092be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate input"; 1093be168c0dSopenharmony_ci- return nullptr; 1094be168c0dSopenharmony_ci- } 1095be168c0dSopenharmony_ci- auto e = input->expr(); 1096be168c0dSopenharmony_ci- outputs[tensor_id] = e; 1097be168c0dSopenharmony_ci- net->PushInput(e); 1098be168c0dSopenharmony_ci- } 1099be168c0dSopenharmony_ci- for (size_t i = 0; i < metaGraph->nodes()->size(); i++) { 1100be168c0dSopenharmony_ci- auto Cnode = metaGraph->nodes()->Get(i); 1101be168c0dSopenharmony_ci- std::vector<EXPR *> param_tensors; 1102be168c0dSopenharmony_ci- for (size_t j = 0; j < Cnode->inputIndex()->size(); j++) { 1103be168c0dSopenharmony_ci- int tensor_id = Cnode->inputIndex()->Get(j); 1104be168c0dSopenharmony_ci- const schema::Tensor *tensor = metaGraph->allTensors()->Get(tensor_id); 1105be168c0dSopenharmony_ci- auto iter = outputs.find(tensor_id); 1106be168c0dSopenharmony_ci- if (iter == outputs.end()) { 1107be168c0dSopenharmony_ci- // create value node if not exist 1108be168c0dSopenharmony_ci- if (tensor->nodeType() != NodeType::NodeType_CNode) { 1109be168c0dSopenharmony_ci- auto valnode = new (std::nothrow) InputM(tensor); 1110be168c0dSopenharmony_ci- if (valnode == nullptr) { 1111be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate valnode"; 1112be168c0dSopenharmony_ci- return nullptr; 1113be168c0dSopenharmony_ci- } 1114be168c0dSopenharmony_ci- outputs[tensor_id] = valnode->expr(); 1115be168c0dSopenharmony_ci- param_tensors.push_back(valnode->expr()); 1116be168c0dSopenharmony_ci- net->PushOp(valnode); 1117be168c0dSopenharmony_ci- } else { 1118be168c0dSopenharmony_ci- MS_LOG(ERROR) << "did not found input tensor " << tensor_id; 1119be168c0dSopenharmony_ci- return nullptr; 1120be168c0dSopenharmony_ci- } 1121be168c0dSopenharmony_ci- } else { 1122be168c0dSopenharmony_ci- param_tensors.push_back(iter->second); 1123be168c0dSopenharmony_ci- } 1124be168c0dSopenharmony_ci- } 1125be168c0dSopenharmony_ci- // create expression from node // 1126be168c0dSopenharmony_ci- auto node = CreateNode(Cnode); 1127be168c0dSopenharmony_ci- if (node != nullptr) { 1128be168c0dSopenharmony_ci- node->SetOutputs(Cnode->outputIndex()->size()); 1129be168c0dSopenharmony_ci- std::vector<EXPR *> e = (*node)(param_tensors); 1130be168c0dSopenharmony_ci- for (size_t j = 0; j < Cnode->outputIndex()->size(); j++) { 1131be168c0dSopenharmony_ci- int tensor_id = Cnode->outputIndex()->Get(j); 1132be168c0dSopenharmony_ci- outputs[tensor_id] = e.at(j); 1133be168c0dSopenharmony_ci- } 1134be168c0dSopenharmony_ci- } else { 1135be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed to create node " << Cnode->name(); 1136be168c0dSopenharmony_ci- return nullptr; 1137be168c0dSopenharmony_ci- } 1138be168c0dSopenharmony_ci- auto node_ptr = node.release(); 1139be168c0dSopenharmony_ci- net->PushOp(node_ptr); 1140be168c0dSopenharmony_ci- node_ptr->SetLearn(); 1141be168c0dSopenharmony_ci- } 1142be168c0dSopenharmony_ci- for (size_t i = 0; i < metaGraph->outputIndex()->size(); i++) { 1143be168c0dSopenharmony_ci- auto tensor_id = metaGraph->outputIndex()->Get(i); 1144be168c0dSopenharmony_ci- auto iter = outputs.find(tensor_id); 1145be168c0dSopenharmony_ci- if (iter == outputs.end()) { 1146be168c0dSopenharmony_ci- MS_LOG(ERROR) << "could not find source for tensor " << tensor_id; 1147be168c0dSopenharmony_ci- return nullptr; 1148be168c0dSopenharmony_ci- } else { 1149be168c0dSopenharmony_ci- net->PushOutput(iter->second); 1150be168c0dSopenharmony_ci- } 1151be168c0dSopenharmony_ci- } 1152be168c0dSopenharmony_ci- return net.release(); 1153be168c0dSopenharmony_ci-} 1154be168c0dSopenharmony_ci-} // namespace lite 1155be168c0dSopenharmony_ci-} // namespace mindspore 1156be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/import.h b/mindspore/lite/src/expression/import.h 1157be168c0dSopenharmony_cideleted file mode 100644 1158be168c0dSopenharmony_ciindex 3d4f301e..00000000 1159be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/import.h 1160be168c0dSopenharmony_ci+++ /dev/null 1161be168c0dSopenharmony_ci@@ -1,61 +0,0 @@ 1162be168c0dSopenharmony_ci-/** 1163be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 1164be168c0dSopenharmony_ci- * 1165be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 1166be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 1167be168c0dSopenharmony_ci- * You may obtain a copy of the License at 1168be168c0dSopenharmony_ci- * 1169be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 1170be168c0dSopenharmony_ci- * 1171be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 1172be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 1173be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1174be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 1175be168c0dSopenharmony_ci- * limitations under the License. 1176be168c0dSopenharmony_ci- */ 1177be168c0dSopenharmony_ci- 1178be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_ 1179be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_ 1180be168c0dSopenharmony_ci- 1181be168c0dSopenharmony_ci-#include <string> 1182be168c0dSopenharmony_ci-#include <unordered_map> 1183be168c0dSopenharmony_ci-#include <memory> 1184be168c0dSopenharmony_ci-#include "nnacl/op_base.h" 1185be168c0dSopenharmony_ci-#include "src/expression/net.h" 1186be168c0dSopenharmony_ci- 1187be168c0dSopenharmony_ci-namespace mindspore { 1188be168c0dSopenharmony_ci-namespace lite { 1189be168c0dSopenharmony_ci-using import_func = std::function<Node *()>; 1190be168c0dSopenharmony_ci- 1191be168c0dSopenharmony_ci-template <typename T> 1192be168c0dSopenharmony_ci-Node *ReturnNode() { 1193be168c0dSopenharmony_ci- return new (std::nothrow) T(); 1194be168c0dSopenharmony_ci-} 1195be168c0dSopenharmony_ci- 1196be168c0dSopenharmony_ci-class ImportReg { 1197be168c0dSopenharmony_ci- public: 1198be168c0dSopenharmony_ci- explicit ImportReg(mindspore::schema::PrimitiveType type, import_func func) { import_map_[type] = func; } 1199be168c0dSopenharmony_ci- static import_func GetImportFunc(mindspore::schema::PrimitiveType type); 1200be168c0dSopenharmony_ci- 1201be168c0dSopenharmony_ci- private: 1202be168c0dSopenharmony_ci- static std::unordered_map<mindspore::schema::PrimitiveType, import_func> import_map_; 1203be168c0dSopenharmony_ci-}; 1204be168c0dSopenharmony_ci- 1205be168c0dSopenharmony_ci-class Import { 1206be168c0dSopenharmony_ci- private: 1207be168c0dSopenharmony_ci- int8_t *buffer_ = nullptr; 1208be168c0dSopenharmony_ci- OpParameter *GetAttr(const schema::Primitive *prim); 1209be168c0dSopenharmony_ci- std::unique_ptr<Node> CreateNode(const schema::CNode *cnode); 1210be168c0dSopenharmony_ci- 1211be168c0dSopenharmony_ci- public: 1212be168c0dSopenharmony_ci- Net *ImportMs(const schema::MetaGraph *meta_graph); 1213be168c0dSopenharmony_ci- Net *ImportMs(std::string file_name); 1214be168c0dSopenharmony_ci- ~Import() { 1215be168c0dSopenharmony_ci- delete[] buffer_; 1216be168c0dSopenharmony_ci- buffer_ = nullptr; 1217be168c0dSopenharmony_ci- } 1218be168c0dSopenharmony_ci-}; 1219be168c0dSopenharmony_ci-} // namespace lite 1220be168c0dSopenharmony_ci-} // namespace mindspore 1221be168c0dSopenharmony_ci- 1222be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_IMPORT_H_ 1223be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/net.cc b/mindspore/lite/src/expression/net.cc 1224be168c0dSopenharmony_cideleted file mode 100644 1225be168c0dSopenharmony_ciindex 560f8cf6..00000000 1226be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/net.cc 1227be168c0dSopenharmony_ci+++ /dev/null 1228be168c0dSopenharmony_ci@@ -1,268 +0,0 @@ 1229be168c0dSopenharmony_ci-/** 1230be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 1231be168c0dSopenharmony_ci- * 1232be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 1233be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 1234be168c0dSopenharmony_ci- * You may obtain a copy of the License at 1235be168c0dSopenharmony_ci- * 1236be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 1237be168c0dSopenharmony_ci- * 1238be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 1239be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 1240be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1241be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 1242be168c0dSopenharmony_ci- * limitations under the License. 1243be168c0dSopenharmony_ci- */ 1244be168c0dSopenharmony_ci- 1245be168c0dSopenharmony_ci-#include "src/expression/net.h" 1246be168c0dSopenharmony_ci-#include <vector> 1247be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 1248be168c0dSopenharmony_ci-#include "src/expression/ops.h" 1249be168c0dSopenharmony_ci-#include "src/expression/export.h" 1250be168c0dSopenharmony_ci-#include "src/expression/ops/addn.h" 1251be168c0dSopenharmony_ci-#include "src/expression/ops/arithmetic.h" 1252be168c0dSopenharmony_ci-#include "src/common/storage.h" 1253be168c0dSopenharmony_ci-#include "tools/common/meta_graph_serializer.h" 1254be168c0dSopenharmony_ci-namespace mindspore { 1255be168c0dSopenharmony_ci-namespace lite { 1256be168c0dSopenharmony_ci-void Net::update_name(std::string name) { 1257be168c0dSopenharmony_ci- if (!this->name().empty()) 1258be168c0dSopenharmony_ci- Node::update_name(name); 1259be168c0dSopenharmony_ci- else 1260be168c0dSopenharmony_ci- set_name(name); 1261be168c0dSopenharmony_ci- for (auto &itr : ops_) { 1262be168c0dSopenharmony_ci- itr->update_name(name); 1263be168c0dSopenharmony_ci- } 1264be168c0dSopenharmony_ci-} 1265be168c0dSopenharmony_ci- 1266be168c0dSopenharmony_ci-std::vector<EXPR *> Net::operator()(const std::initializer_list<EXPR *> &&inputs) { 1267be168c0dSopenharmony_ci- std::vector<EXPR *> vec = inputs; 1268be168c0dSopenharmony_ci- std::vector<EXPR *> x; 1269be168c0dSopenharmony_ci- if (impl_ == nullptr) { 1270be168c0dSopenharmony_ci- x = construct(inputs); 1271be168c0dSopenharmony_ci- } else { 1272be168c0dSopenharmony_ci- x = impl_->construct(vec); 1273be168c0dSopenharmony_ci- } 1274be168c0dSopenharmony_ci- return x; 1275be168c0dSopenharmony_ci-} 1276be168c0dSopenharmony_ci- 1277be168c0dSopenharmony_ci-std::vector<EXPR *> Net::operator()(const std::vector<EXPR *> &inputs) { 1278be168c0dSopenharmony_ci- std::vector<EXPR *> x; 1279be168c0dSopenharmony_ci- if (impl_ == nullptr) { 1280be168c0dSopenharmony_ci- x = construct(inputs); 1281be168c0dSopenharmony_ci- } else { 1282be168c0dSopenharmony_ci- x = impl_->construct(inputs); 1283be168c0dSopenharmony_ci- } 1284be168c0dSopenharmony_ci- input_ = inputs; 1285be168c0dSopenharmony_ci- output_ = x; 1286be168c0dSopenharmony_ci- real_output_ = x; 1287be168c0dSopenharmony_ci- return x; 1288be168c0dSopenharmony_ci-} 1289be168c0dSopenharmony_ci- 1290be168c0dSopenharmony_ci-std::vector<EXPR *> Net::construct(const std::vector<EXPR *> &inputs) { 1291be168c0dSopenharmony_ci- if (!output_.empty()) { 1292be168c0dSopenharmony_ci- if (input_.size() != inputs.size()) { 1293be168c0dSopenharmony_ci- MS_LOG(ERROR) << "input size mismatch, should be " << input_.size() << " got " << inputs.size(); 1294be168c0dSopenharmony_ci- return {}; 1295be168c0dSopenharmony_ci- } 1296be168c0dSopenharmony_ci- auto in_ptr = inputs; 1297be168c0dSopenharmony_ci- EXPR::Replace(output_, &input_, &in_ptr); 1298be168c0dSopenharmony_ci- } else { 1299be168c0dSopenharmony_ci- MS_LOG(ERROR) << "no network construction function"; 1300be168c0dSopenharmony_ci- } 1301be168c0dSopenharmony_ci- return output_; 1302be168c0dSopenharmony_ci-} 1303be168c0dSopenharmony_ci- 1304be168c0dSopenharmony_ci-void Net::TopoSortUtil(Node *node, std::stack<Node *> *stack) { 1305be168c0dSopenharmony_ci- visited_.insert(node); 1306be168c0dSopenharmony_ci- for (size_t i = 0; i < node->OutputsNum(); i++) { 1307be168c0dSopenharmony_ci- auto expr = node->expr(i); 1308be168c0dSopenharmony_ci- auto itr = outmap_.find(expr); 1309be168c0dSopenharmony_ci- if (itr != outmap_.end()) { 1310be168c0dSopenharmony_ci- for (auto &e : itr->second) 1311be168c0dSopenharmony_ci- if (visited_.find(e->node()) == visited_.end()) { 1312be168c0dSopenharmony_ci- TopoSortUtil(e->node(), stack); 1313be168c0dSopenharmony_ci- } 1314be168c0dSopenharmony_ci- } 1315be168c0dSopenharmony_ci- } 1316be168c0dSopenharmony_ci- stack->push(node); 1317be168c0dSopenharmony_ci-} 1318be168c0dSopenharmony_ci- 1319be168c0dSopenharmony_ci-std::vector<Node *> Net::Sort() { 1320be168c0dSopenharmony_ci- std::stack<Node *> stack; 1321be168c0dSopenharmony_ci- outmap_.clear(); 1322be168c0dSopenharmony_ci- EXPR::CreateOutputMap(output_, &outmap_); 1323be168c0dSopenharmony_ci- for (auto &itr : outmap_) { 1324be168c0dSopenharmony_ci- EXPR *e = itr.first; 1325be168c0dSopenharmony_ci- if (visited_.find(e->node()) == visited_.end()) { 1326be168c0dSopenharmony_ci- TopoSortUtil(e->node(), &stack); 1327be168c0dSopenharmony_ci- } 1328be168c0dSopenharmony_ci- } 1329be168c0dSopenharmony_ci- std::vector<Node *> res; 1330be168c0dSopenharmony_ci- while (stack.empty() == false) { 1331be168c0dSopenharmony_ci- res.push_back(stack.top()); 1332be168c0dSopenharmony_ci- stack.pop(); 1333be168c0dSopenharmony_ci- } 1334be168c0dSopenharmony_ci- visited_.clear(); 1335be168c0dSopenharmony_ci- return res; 1336be168c0dSopenharmony_ci-} 1337be168c0dSopenharmony_ci- 1338be168c0dSopenharmony_ci-std::unique_ptr<schema::MetaGraphT> Net::MakeMs() { 1339be168c0dSopenharmony_ci- auto nodes = Sort(); 1340be168c0dSopenharmony_ci- auto s = new (std::nothrow) ExportSession(outmap_); 1341be168c0dSopenharmony_ci- if (s == nullptr) { 1342be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate export session"; 1343be168c0dSopenharmony_ci- return nullptr; 1344be168c0dSopenharmony_ci- } 1345be168c0dSopenharmony_ci- session_.reset(s); 1346be168c0dSopenharmony_ci- session_->Init(name(), Version()); 1347be168c0dSopenharmony_ci- for (auto node : nodes) { 1348be168c0dSopenharmony_ci- auto res = node->MakeEntry(session_.get()); 1349be168c0dSopenharmony_ci- if (res != RET_OK) { 1350be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed in MakeEntry: " << node->name(); 1351be168c0dSopenharmony_ci- return nullptr; 1352be168c0dSopenharmony_ci- } 1353be168c0dSopenharmony_ci- } 1354be168c0dSopenharmony_ci- session_->SetInputOutput(input_, real_output_); 1355be168c0dSopenharmony_ci- auto res = session_->meta_graph(); 1356be168c0dSopenharmony_ci- return std::unique_ptr<schema::MetaGraphT>(res); 1357be168c0dSopenharmony_ci-} 1358be168c0dSopenharmony_ci- 1359be168c0dSopenharmony_ci-std::unique_ptr<schema::MetaGraphT> Net::MakeMs(const std::string file_name) { 1360be168c0dSopenharmony_ci- auto graph = MakeMs(); 1361be168c0dSopenharmony_ci- Save(*graph, file_name); 1362be168c0dSopenharmony_ci- return graph; 1363be168c0dSopenharmony_ci-} 1364be168c0dSopenharmony_ci- 1365be168c0dSopenharmony_ci-std::set<Node *> Net::trainable_params() { 1366be168c0dSopenharmony_ci- std::set<Node *> res; 1367be168c0dSopenharmony_ci- for (auto &node : ops_) { 1368be168c0dSopenharmony_ci- res.merge(node->trainable_params()); 1369be168c0dSopenharmony_ci- } 1370be168c0dSopenharmony_ci- return res; 1371be168c0dSopenharmony_ci-} 1372be168c0dSopenharmony_ci- 1373be168c0dSopenharmony_ci-int Net::BuildGrad(Node *optimizer) { 1374be168c0dSopenharmony_ci- std::set<Node *> learn = optimizer->trainable_params(); 1375be168c0dSopenharmony_ci- auto NetOrder = Sort(); 1376be168c0dSopenharmony_ci- optimizer_.reset(optimizer); 1377be168c0dSopenharmony_ci- optimizer->AddNetOutput(&output_); 1378be168c0dSopenharmony_ci- std::map<std::pair<EXPR *, EXPR *>, EXPR *> backprop; 1379be168c0dSopenharmony_ci- for (auto itr = NetOrder.rbegin(); itr != NetOrder.rend(); itr++) { 1380be168c0dSopenharmony_ci- Node *node = *itr; 1381be168c0dSopenharmony_ci- EXPR *yt = nullptr; 1382be168c0dSopenharmony_ci- if (node->primitive() == schema::PrimitiveType_NONE) continue; 1383be168c0dSopenharmony_ci- if (outmap_.find(node->expr()) == outmap_.end() || outmap_[node->expr()].size() == 0) { 1384be168c0dSopenharmony_ci- yt = node->expr(); 1385be168c0dSopenharmony_ci- } else { 1386be168c0dSopenharmony_ci- std::vector<EXPR *> add_params; 1387be168c0dSopenharmony_ci- for (auto &output : outmap_[node->expr()]) { 1388be168c0dSopenharmony_ci- auto link = std::make_pair(node->expr(), output); 1389be168c0dSopenharmony_ci- auto grad = backprop[link]; 1390be168c0dSopenharmony_ci- add_params.push_back(grad); 1391be168c0dSopenharmony_ci- } 1392be168c0dSopenharmony_ci- if (add_params.size() == 1) { 1393be168c0dSopenharmony_ci- yt = add_params.front(); 1394be168c0dSopenharmony_ci- } else { 1395be168c0dSopenharmony_ci- auto addn = new (std::nothrow) AddN(0); 1396be168c0dSopenharmony_ci- if (addn == nullptr) { 1397be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate add operator"; 1398be168c0dSopenharmony_ci- return RET_ERROR; 1399be168c0dSopenharmony_ci- } 1400be168c0dSopenharmony_ci- PushOp(addn); 1401be168c0dSopenharmony_ci- addn->update_name(name()); 1402be168c0dSopenharmony_ci- yt = (*addn)(add_params).front(); 1403be168c0dSopenharmony_ci- } 1404be168c0dSopenharmony_ci- } 1405be168c0dSopenharmony_ci- auto inGrads = node->Grad(yt); 1406be168c0dSopenharmony_ci- for (size_t i = 0; i < node->inputs().size(); i++) { 1407be168c0dSopenharmony_ci- EXPR *inGrad{nullptr}; 1408be168c0dSopenharmony_ci- if (i < inGrads.size()) { 1409be168c0dSopenharmony_ci- inGrad = inGrads[i]; 1410be168c0dSopenharmony_ci- } else { 1411be168c0dSopenharmony_ci- inGrad = nullptr; 1412be168c0dSopenharmony_ci- } 1413be168c0dSopenharmony_ci- auto input = node->input(i); 1414be168c0dSopenharmony_ci- if (learn.find(input->node()) != learn.end()) { 1415be168c0dSopenharmony_ci- auto opt = optimizer->Clone(inGrad, input); 1416be168c0dSopenharmony_ci- if (opt.size() == 0) { 1417be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed to create optimizer"; 1418be168c0dSopenharmony_ci- return RET_ERROR; 1419be168c0dSopenharmony_ci- } 1420be168c0dSopenharmony_ci- if (inGrad == nullptr) { 1421be168c0dSopenharmony_ci- MS_LOG(ERROR) << "illegal null value for grad"; 1422be168c0dSopenharmony_ci- return RET_ERROR; 1423be168c0dSopenharmony_ci- } 1424be168c0dSopenharmony_ci- if (opt.size() == 0) { 1425be168c0dSopenharmony_ci- MS_LOG(ERROR) << "optimizer for " << input->node()->name() << " failure"; 1426be168c0dSopenharmony_ci- return RET_ERROR; 1427be168c0dSopenharmony_ci- } 1428be168c0dSopenharmony_ci- auto opt_op = opt.at(0)->node(); 1429be168c0dSopenharmony_ci- PushOp(opt_op); 1430be168c0dSopenharmony_ci- opt_op->update_name(node->name()); 1431be168c0dSopenharmony_ci- output_.push_back(opt.at(0)); 1432be168c0dSopenharmony_ci- } 1433be168c0dSopenharmony_ci- auto link = std::make_pair(input, node->expr()); 1434be168c0dSopenharmony_ci- backprop[link] = inGrad; 1435be168c0dSopenharmony_ci- } 1436be168c0dSopenharmony_ci- } 1437be168c0dSopenharmony_ci- return RET_OK; 1438be168c0dSopenharmony_ci-} 1439be168c0dSopenharmony_ci- 1440be168c0dSopenharmony_ci-std::vector<EXPR *> Net::add(const std::vector<EXPR *> &input) { 1441be168c0dSopenharmony_ci- auto _add = NN::Add(); 1442be168c0dSopenharmony_ci- _add->set_name(name() + "/" + _add->name()); 1443be168c0dSopenharmony_ci- ops_.push_back(_add); 1444be168c0dSopenharmony_ci- return (*_add)(input); 1445be168c0dSopenharmony_ci-} 1446be168c0dSopenharmony_ci- 1447be168c0dSopenharmony_ci-Net *Net::TrainNet(Node *optimizer, Node *loss_fn, const std::vector<EXPR *> &inputs) { 1448be168c0dSopenharmony_ci- auto net = new (std::nothrow) NetWithLoss(this, loss_fn); 1449be168c0dSopenharmony_ci- if (net == nullptr) { 1450be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate loss network"; 1451be168c0dSopenharmony_ci- return nullptr; 1452be168c0dSopenharmony_ci- } 1453be168c0dSopenharmony_ci- return net->TrainNet(optimizer, inputs); 1454be168c0dSopenharmony_ci-} 1455be168c0dSopenharmony_ci- 1456be168c0dSopenharmony_ci-Net *Net::TrainNet(Node *optimizer, const std::vector<EXPR *> &inputs) { 1457be168c0dSopenharmony_ci- auto x = (*this)(inputs); 1458be168c0dSopenharmony_ci- auto res = BuildGrad(optimizer); 1459be168c0dSopenharmony_ci- if (res != RET_OK) { 1460be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Build gradient network failed"; 1461be168c0dSopenharmony_ci- return nullptr; 1462be168c0dSopenharmony_ci- } 1463be168c0dSopenharmony_ci- real_output_ = x; 1464be168c0dSopenharmony_ci- return this; 1465be168c0dSopenharmony_ci-} 1466be168c0dSopenharmony_ci- 1467be168c0dSopenharmony_ci-int Net::Save(const schema::MetaGraphT &graph, std::string file_name) { return Storage::Save(graph, file_name); } 1468be168c0dSopenharmony_ci- 1469be168c0dSopenharmony_ci-const std::vector<int> Net::OutputShape(int idx) { 1470be168c0dSopenharmony_ci- if (static_cast<size_t>(idx) >= real_output_.size()) { 1471be168c0dSopenharmony_ci- MS_LOG(ERROR) << "index (" << idx << ") exceed output size (" << real_output_.size() << ")"; 1472be168c0dSopenharmony_ci- return {}; 1473be168c0dSopenharmony_ci- } 1474be168c0dSopenharmony_ci- return real_output_.at(idx)->dims(); 1475be168c0dSopenharmony_ci-} 1476be168c0dSopenharmony_ci- 1477be168c0dSopenharmony_ci-const std::vector<int> Net::InputShape(int idx) { 1478be168c0dSopenharmony_ci- if (static_cast<size_t>(idx) >= input_.size()) { 1479be168c0dSopenharmony_ci- MS_LOG(ERROR) << "index (" << idx << ") exceed input size (" << input_.size() << ")"; 1480be168c0dSopenharmony_ci- return {}; 1481be168c0dSopenharmony_ci- } 1482be168c0dSopenharmony_ci- return input_.at(idx)->dims(); 1483be168c0dSopenharmony_ci-} 1484be168c0dSopenharmony_ci- 1485be168c0dSopenharmony_ci-Net::~Net() { 1486be168c0dSopenharmony_ci- if (impl_ != nullptr) { 1487be168c0dSopenharmony_ci- impl_->erase_net(); 1488be168c0dSopenharmony_ci- auto pnet = impl_->pnet(); 1489be168c0dSopenharmony_ci- if (pnet != nullptr) { 1490be168c0dSopenharmony_ci- impl_->set_pnet(nullptr); 1491be168c0dSopenharmony_ci- } 1492be168c0dSopenharmony_ci- } 1493be168c0dSopenharmony_ci- impl_ = nullptr; 1494be168c0dSopenharmony_ci-} 1495be168c0dSopenharmony_ci-} // namespace lite 1496be168c0dSopenharmony_ci-} // namespace mindspore 1497be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/net.h b/mindspore/lite/src/expression/net.h 1498be168c0dSopenharmony_cideleted file mode 100644 1499be168c0dSopenharmony_ciindex f5525173..00000000 1500be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/net.h 1501be168c0dSopenharmony_ci+++ /dev/null 1502be168c0dSopenharmony_ci@@ -1,114 +0,0 @@ 1503be168c0dSopenharmony_ci-/** 1504be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 1505be168c0dSopenharmony_ci- * 1506be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 1507be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 1508be168c0dSopenharmony_ci- * You may obtain a copy of the License at 1509be168c0dSopenharmony_ci- * 1510be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 1511be168c0dSopenharmony_ci- * 1512be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 1513be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 1514be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1515be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 1516be168c0dSopenharmony_ci- * limitations under the License. 1517be168c0dSopenharmony_ci- */ 1518be168c0dSopenharmony_ci- 1519be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_NET_H_ 1520be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_NET_H_ 1521be168c0dSopenharmony_ci-#include <stack> 1522be168c0dSopenharmony_ci-#include <memory> 1523be168c0dSopenharmony_ci-#include <set> 1524be168c0dSopenharmony_ci-#include <map> 1525be168c0dSopenharmony_ci-#include <utility> 1526be168c0dSopenharmony_ci-#include <string> 1527be168c0dSopenharmony_ci-#include <unordered_set> 1528be168c0dSopenharmony_ci-#include <list> 1529be168c0dSopenharmony_ci-#include <vector> 1530be168c0dSopenharmony_ci-#include "src/expression/node.h" 1531be168c0dSopenharmony_ci-#include "inner/model_generated.h" 1532be168c0dSopenharmony_ci- 1533be168c0dSopenharmony_ci-namespace mindspore { 1534be168c0dSopenharmony_ci-class Net; 1535be168c0dSopenharmony_ci-class NetImpl; 1536be168c0dSopenharmony_ci-namespace lite { 1537be168c0dSopenharmony_ci-#define REG(_name) Register(_name, #_name) 1538be168c0dSopenharmony_ci- 1539be168c0dSopenharmony_ci-class ExportSession; 1540be168c0dSopenharmony_ci- 1541be168c0dSopenharmony_ci-class Net : public Node { 1542be168c0dSopenharmony_ci- public: 1543be168c0dSopenharmony_ci- Net() = default; 1544be168c0dSopenharmony_ci- virtual ~Net(); 1545be168c0dSopenharmony_ci- explicit Net(std::string name) : Node(name) {} 1546be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override; 1547be168c0dSopenharmony_ci- std::vector<EXPR *> operator()(const std::vector<EXPR *> &inputs) override; 1548be168c0dSopenharmony_ci- std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &&inputs) override; 1549be168c0dSopenharmony_ci- void update_name(std::string name) override; 1550be168c0dSopenharmony_ci- Net *TrainNet(Node *optimizer, Node *loss_fn, const std::vector<EXPR *> &inputs); 1551be168c0dSopenharmony_ci- Net *TrainNet(Node *optimizer, const std::vector<EXPR *> &inputs); 1552be168c0dSopenharmony_ci- void PrintDot() { EXPR::PrintDot(output_); } 1553be168c0dSopenharmony_ci- 1554be168c0dSopenharmony_ci- void PushOutput(EXPR *e) { output_.push_back(e); } 1555be168c0dSopenharmony_ci- void PushInput(EXPR *e) { input_.push_back(e); } 1556be168c0dSopenharmony_ci- void SetRealOutput() { real_output_ = output_; } 1557be168c0dSopenharmony_ci- std::set<Node *> trainable_params() override; 1558be168c0dSopenharmony_ci- std::vector<Node *> Sort(); 1559be168c0dSopenharmony_ci- int BuildGrad(Node *optimizer); 1560be168c0dSopenharmony_ci- int BuildGrad(Node *optimizer, std::set<Node *> learnable); 1561be168c0dSopenharmony_ci- std::unique_ptr<schema::MetaGraphT> MakeMs(); 1562be168c0dSopenharmony_ci- std::unique_ptr<schema::MetaGraphT> MakeMs(std::string file_name); 1563be168c0dSopenharmony_ci- schema::MetaGraph *meta_graph() { return meta_graph_; } 1564be168c0dSopenharmony_ci- int Save(const schema::MetaGraphT &graph, const std::string filename); 1565be168c0dSopenharmony_ci- void set_impl(std::shared_ptr<mindspore::NetImpl> impl) { impl_ = impl; } 1566be168c0dSopenharmony_ci- const std::vector<int> InputShape(int idx); 1567be168c0dSopenharmony_ci- const std::vector<int> OutputShape(int idx); 1568be168c0dSopenharmony_ci- 1569be168c0dSopenharmony_ci- protected: 1570be168c0dSopenharmony_ci- std::vector<EXPR *> add(const std::vector<EXPR *> &input); 1571be168c0dSopenharmony_ci- void Register(Node *node, std::string &&name) { 1572be168c0dSopenharmony_ci- if (node != nullptr) { 1573be168c0dSopenharmony_ci- PushOp(node); 1574be168c0dSopenharmony_ci- node->update_name(name); 1575be168c0dSopenharmony_ci- } 1576be168c0dSopenharmony_ci- } 1577be168c0dSopenharmony_ci- 1578be168c0dSopenharmony_ci- private: 1579be168c0dSopenharmony_ci- friend mindspore::Net; 1580be168c0dSopenharmony_ci- std::unordered_set<Node *> visited_; 1581be168c0dSopenharmony_ci- std::map<EXPR *, std::list<EXPR *>> outmap_; // outputs per expression 1582be168c0dSopenharmony_ci- std::map<EXPR *, std::list<EXPR *>> inmap_; // inputs per expression 1583be168c0dSopenharmony_ci- std::vector<EXPR *> output_; // network output expression 1584be168c0dSopenharmony_ci- std::vector<EXPR *> real_output_; // network output for export 1585be168c0dSopenharmony_ci- std::vector<EXPR *> input_; // network input expression 1586be168c0dSopenharmony_ci- schema::MetaGraph *meta_graph_; // imported meta_graph 1587be168c0dSopenharmony_ci- std::unique_ptr<ExportSession> session_; // export session 1588be168c0dSopenharmony_ci- std::unique_ptr<Node> optimizer_; 1589be168c0dSopenharmony_ci- void TopoSortUtil(Node *v, std::stack<Node *> *stack); 1590be168c0dSopenharmony_ci- void CreateOutputMap(std::vector<EXPR *> vec, std::map<Node *, std::list<Node *>> *outmap); 1591be168c0dSopenharmony_ci- std::shared_ptr<mindspore::NetImpl> impl_; 1592be168c0dSopenharmony_ci-}; 1593be168c0dSopenharmony_ci- 1594be168c0dSopenharmony_ci-class NetWithLoss : public Net { 1595be168c0dSopenharmony_ci- public: 1596be168c0dSopenharmony_ci- NetWithLoss(Net *net, Node *loss) : net_(net), loss_fn_(loss) { 1597be168c0dSopenharmony_ci- REG(net_); 1598be168c0dSopenharmony_ci- REG(loss_fn_); 1599be168c0dSopenharmony_ci- loss_fn_->set_name("_loss_fn"); 1600be168c0dSopenharmony_ci- } 1601be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) { 1602be168c0dSopenharmony_ci- auto input = inputs[0]; 1603be168c0dSopenharmony_ci- auto label = inputs[1]; 1604be168c0dSopenharmony_ci- auto x = (*net_)({input}); 1605be168c0dSopenharmony_ci- x = (*loss_fn_)({x[0], label}); 1606be168c0dSopenharmony_ci- return {x.front()}; 1607be168c0dSopenharmony_ci- } 1608be168c0dSopenharmony_ci- 1609be168c0dSopenharmony_ci- private: 1610be168c0dSopenharmony_ci- Net *net_{nullptr}; 1611be168c0dSopenharmony_ci- Node *loss_fn_{nullptr}; 1612be168c0dSopenharmony_ci-}; 1613be168c0dSopenharmony_ci-} // namespace lite 1614be168c0dSopenharmony_ci-} // namespace mindspore 1615be168c0dSopenharmony_ci- 1616be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_NET_H_ 1617be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/node.cc b/mindspore/lite/src/expression/node.cc 1618be168c0dSopenharmony_cideleted file mode 100644 1619be168c0dSopenharmony_ciindex 022f15d4..00000000 1620be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/node.cc 1621be168c0dSopenharmony_ci+++ /dev/null 1622be168c0dSopenharmony_ci@@ -1,271 +0,0 @@ 1623be168c0dSopenharmony_ci-/** 1624be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 1625be168c0dSopenharmony_ci- * 1626be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 1627be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 1628be168c0dSopenharmony_ci- * You may obtain a copy of the License at 1629be168c0dSopenharmony_ci- * 1630be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 1631be168c0dSopenharmony_ci- * 1632be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 1633be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 1634be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1635be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 1636be168c0dSopenharmony_ci- * limitations under the License. 1637be168c0dSopenharmony_ci- */ 1638be168c0dSopenharmony_ci- 1639be168c0dSopenharmony_ci-#include <algorithm> 1640be168c0dSopenharmony_ci-#include <utility> 1641be168c0dSopenharmony_ci-#include <functional> 1642be168c0dSopenharmony_ci-#include "src/expression/node.h" 1643be168c0dSopenharmony_ci-#include "src/expression/ops.h" 1644be168c0dSopenharmony_ci-#include "src/expression/export.h" 1645be168c0dSopenharmony_ci-#include "src/litert/infer_manager.h" 1646be168c0dSopenharmony_ci-#include "src/common/utils.h" 1647be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 1648be168c0dSopenharmony_ci- 1649be168c0dSopenharmony_ci-namespace mindspore { 1650be168c0dSopenharmony_ci-namespace lite { 1651be168c0dSopenharmony_ci-int Node::name_id; 1652be168c0dSopenharmony_ci- 1653be168c0dSopenharmony_ci-std::vector<EXPR *> Node::construct(const std::vector<EXPR *> &inputs) { 1654be168c0dSopenharmony_ci- if (inputs.size() >= expr()->params().size()) { 1655be168c0dSopenharmony_ci- expr()->set_params(inputs); 1656be168c0dSopenharmony_ci- } else { 1657be168c0dSopenharmony_ci- for (std::size_t i = 0; i < inputs.size(); i++) { 1658be168c0dSopenharmony_ci- expr()->set_params(i, inputs[i]); 1659be168c0dSopenharmony_ci- } 1660be168c0dSopenharmony_ci- } 1661be168c0dSopenharmony_ci- auto ret = InferShape(); 1662be168c0dSopenharmony_ci- if (ret != RET_OK) { 1663be168c0dSopenharmony_ci- MS_LOG(ERROR) << "error infershape for node " << name(); 1664be168c0dSopenharmony_ci- return {}; 1665be168c0dSopenharmony_ci- } 1666be168c0dSopenharmony_ci- std::vector<EXPR *> res(expr_.size()); 1667be168c0dSopenharmony_ci- (void)std::transform(expr_.begin(), expr_.end(), res.begin(), [](const EXPR &e) { return const_cast<EXPR *>(&e); }); 1668be168c0dSopenharmony_ci- return res; 1669be168c0dSopenharmony_ci-} 1670be168c0dSopenharmony_ci- 1671be168c0dSopenharmony_ci-std::vector<EXPR *> Node::Grad(EXPR *expr) { 1672be168c0dSopenharmony_ci- MS_LOG(ERROR) << name() << " (" << schema::EnumNamePrimitiveType(primitive()) << ") does not have grad defined"; 1673be168c0dSopenharmony_ci- return {}; 1674be168c0dSopenharmony_ci-} 1675be168c0dSopenharmony_ci- 1676be168c0dSopenharmony_ci-int Node::CreateTensorFromExpr(const std::vector<EXPR *> &expr, std::vector<Tensor *> *tensors, bool is_input) { 1677be168c0dSopenharmony_ci- MS_ASSERT(tensors != nullptr); 1678be168c0dSopenharmony_ci- int ret = RET_OK; 1679be168c0dSopenharmony_ci- for (auto e : expr) { 1680be168c0dSopenharmony_ci- // Tensor -> TensorC 1681be168c0dSopenharmony_ci- if (is_input && e->node()->primitive() == schema::PrimitiveType_Depend) { 1682be168c0dSopenharmony_ci- continue; 1683be168c0dSopenharmony_ci- } 1684be168c0dSopenharmony_ci- auto type = (e->node()->primitive() != schema::PrimitiveType_NONE) ? Category::VAR : Category::CONST_TENSOR; 1685be168c0dSopenharmony_ci- auto t = std::make_unique<Tensor>(e->data_type(), e->dims(), (mindspore::Format)e->format(), type); 1686be168c0dSopenharmony_ci- if (t == nullptr) { 1687be168c0dSopenharmony_ci- ret = RET_NULL_PTR; 1688be168c0dSopenharmony_ci- break; 1689be168c0dSopenharmony_ci- } 1690be168c0dSopenharmony_ci- // copy data if any 1691be168c0dSopenharmony_ci- if (type == Category::CONST_TENSOR) { 1692be168c0dSopenharmony_ci- void *dst = t->MutableData(); 1693be168c0dSopenharmony_ci- if (dst == nullptr) { 1694be168c0dSopenharmony_ci- ret = RET_NULL_PTR; 1695be168c0dSopenharmony_ci- break; 1696be168c0dSopenharmony_ci- } 1697be168c0dSopenharmony_ci- if (e->node()->data() && (e->node()->data()->data().size() > 0)) { 1698be168c0dSopenharmony_ci- uint8_t *src = e->node()->data()->data().data(); 1699be168c0dSopenharmony_ci- memcpy(dst, src, t->Size()); 1700be168c0dSopenharmony_ci- } 1701be168c0dSopenharmony_ci- } 1702be168c0dSopenharmony_ci- tensors->push_back(t.release()); 1703be168c0dSopenharmony_ci- } 1704be168c0dSopenharmony_ci- return ret; 1705be168c0dSopenharmony_ci-} 1706be168c0dSopenharmony_ci- 1707be168c0dSopenharmony_ci-void Node::FreeAllTensors(std::vector<Tensor *> *tensors) { 1708be168c0dSopenharmony_ci- MS_ASSERT(tensors != nullptr); 1709be168c0dSopenharmony_ci- for (auto &t : *tensors) { 1710be168c0dSopenharmony_ci- delete t; 1711be168c0dSopenharmony_ci- } 1712be168c0dSopenharmony_ci- tensors->clear(); 1713be168c0dSopenharmony_ci-} 1714be168c0dSopenharmony_ci- 1715be168c0dSopenharmony_ci-int Node::InferShape() { 1716be168c0dSopenharmony_ci- auto ret = RET_OK; 1717be168c0dSopenharmony_ci- std::vector<Tensor *> in_tensors; 1718be168c0dSopenharmony_ci- std::vector<Tensor *> out_tensors; 1719be168c0dSopenharmony_ci- // build in \ out tensors 1720be168c0dSopenharmony_ci- ret = CreateTensorFromExpr(expr()->params(), &in_tensors, true); 1721be168c0dSopenharmony_ci- if (ret != RET_OK) { 1722be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Failed in create in tensors"; 1723be168c0dSopenharmony_ci- FreeAllTensors(&in_tensors); 1724be168c0dSopenharmony_ci- return RET_ERROR; 1725be168c0dSopenharmony_ci- } 1726be168c0dSopenharmony_ci- std::vector<EXPR *> expr(expr_.size()); 1727be168c0dSopenharmony_ci- (void)std::transform(expr_.begin(), expr_.end(), expr.begin(), [](const EXPR &e) { return const_cast<EXPR *>(&e); }); 1728be168c0dSopenharmony_ci- ret = CreateTensorFromExpr(expr, &out_tensors); 1729be168c0dSopenharmony_ci- if (ret != RET_OK) { 1730be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Failed in create out tensors"; 1731be168c0dSopenharmony_ci- FreeAllTensors(&in_tensors); 1732be168c0dSopenharmony_ci- FreeAllTensors(&out_tensors); 1733be168c0dSopenharmony_ci- return RET_ERROR; 1734be168c0dSopenharmony_ci- } 1735be168c0dSopenharmony_ci- // Do infer Shape 1736be168c0dSopenharmony_ci- ret = KernelInferShape(in_tensors, out_tensors, OpParam()); 1737be168c0dSopenharmony_ci- if (ret != RET_OK) { 1738be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed in infer shape for " << name(); 1739be168c0dSopenharmony_ci- FreeAllTensors(&in_tensors); 1740be168c0dSopenharmony_ci- FreeAllTensors(&out_tensors); 1741be168c0dSopenharmony_ci- return RET_ERROR; 1742be168c0dSopenharmony_ci- } 1743be168c0dSopenharmony_ci- // copy infer shape into expr 1744be168c0dSopenharmony_ci- for (uint32_t i = 0; i < expr_.size(); i++) { 1745be168c0dSopenharmony_ci- auto e = &expr_.at(i); 1746be168c0dSopenharmony_ci- auto o = out_tensors.at(i); 1747be168c0dSopenharmony_ci- e->set_format((o->format())); 1748be168c0dSopenharmony_ci- e->set_data_type(o->data_type()); 1749be168c0dSopenharmony_ci- e->SetDims(o->shape()); 1750be168c0dSopenharmony_ci- } 1751be168c0dSopenharmony_ci- // cleanup 1752be168c0dSopenharmony_ci- FreeAllTensors(&in_tensors); 1753be168c0dSopenharmony_ci- FreeAllTensors(&out_tensors); 1754be168c0dSopenharmony_ci- 1755be168c0dSopenharmony_ci- return ret; 1756be168c0dSopenharmony_ci-} 1757be168c0dSopenharmony_ci- 1758be168c0dSopenharmony_ci-EXPR *Node::CreateWeights(std::vector<int> dims, TypeId data_type, int format, Param::Mode mode, std::string name) { 1759be168c0dSopenharmony_ci- auto weights = new (std::nothrow) InputM(dims); 1760be168c0dSopenharmony_ci- if (weights == nullptr) { 1761be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate weights"; 1762be168c0dSopenharmony_ci- return nullptr; 1763be168c0dSopenharmony_ci- } 1764be168c0dSopenharmony_ci- weights->set_name(this->name() + "/" + name); 1765be168c0dSopenharmony_ci- int size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>()); 1766be168c0dSopenharmony_ci- weights->data()->SetSize(size); 1767be168c0dSopenharmony_ci- weights->data()->Fill(mode); 1768be168c0dSopenharmony_ci- PushOp(weights); 1769be168c0dSopenharmony_ci- return weights->expr(); 1770be168c0dSopenharmony_ci-} 1771be168c0dSopenharmony_ci- 1772be168c0dSopenharmony_ci-Node *Node::CreateConstTensor(int index, std::vector<int> dims, TypeId data_type, int format, std::string name, 1773be168c0dSopenharmony_ci- const void *data) { 1774be168c0dSopenharmony_ci- auto tensor = NN::Input(dims, data_type, format); 1775be168c0dSopenharmony_ci- int elem_size = DataTypeSize(data_type); 1776be168c0dSopenharmony_ci- tensor->set_name(this->name() + "/" + name); 1777be168c0dSopenharmony_ci- int size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int>()) * elem_size; 1778be168c0dSopenharmony_ci- tensor->data()->SetSize(size); 1779be168c0dSopenharmony_ci- tensor->data()->Copy(reinterpret_cast<const uint8_t *>(data), size); 1780be168c0dSopenharmony_ci- expr()->set_params(index, tensor->expr()); 1781be168c0dSopenharmony_ci- PushOp(tensor); 1782be168c0dSopenharmony_ci- return tensor; 1783be168c0dSopenharmony_ci-} 1784be168c0dSopenharmony_ci- 1785be168c0dSopenharmony_ci-int Node::MakeEntry(ExportSession *session) { 1786be168c0dSopenharmony_ci- std::vector<uint32_t> input_idx; 1787be168c0dSopenharmony_ci- std::vector<uint32_t> output_idx; 1788be168c0dSopenharmony_ci- std::vector<uint8_t> empty; 1789be168c0dSopenharmony_ci- if (primitive() == schema::PrimitiveType_Depend) return RET_OK; 1790be168c0dSopenharmony_ci- // create node input 1791be168c0dSopenharmony_ci- size_t inputs = InputsNum(); 1792be168c0dSopenharmony_ci- for (size_t i = 0; i < inputs; i++) { 1793be168c0dSopenharmony_ci- EXPR *ex = expr()->GetInput(i); 1794be168c0dSopenharmony_ci- if (ex->node()->primitive() == schema::PrimitiveType_Depend) continue; 1795be168c0dSopenharmony_ci- uint32_t id = session->GetOutput(ex); 1796be168c0dSopenharmony_ci- input_idx.push_back(id); 1797be168c0dSopenharmony_ci- } 1798be168c0dSopenharmony_ci- size_t outputs = OutputsNum(); 1799be168c0dSopenharmony_ci- size_t last_id = session->meta_graph()->allTensors.size(); 1800be168c0dSopenharmony_ci- int type = (primitive() == schema::PrimitiveType_NONE) ? static_cast<int>(NodeType_ValueNode) : static_cast<int>(NodeType_CNode); 1801be168c0dSopenharmony_ci- auto data = (type == static_cast<int>(NodeType_ValueNode)) ? this->data()->data() : empty; 1802be168c0dSopenharmony_ci- if (data.empty()) type = NodeType_CNode; // input is Cnode !!? 1803be168c0dSopenharmony_ci- int idx = 0; 1804be168c0dSopenharmony_ci- for (size_t i = 0; i < outputs; i++) { 1805be168c0dSopenharmony_ci- if (session->IsToDependOnly(expr(i))) continue; 1806be168c0dSopenharmony_ci- output_idx.push_back(last_id + idx); 1807be168c0dSopenharmony_ci- session->UpdateOutput(expr(i), last_id + idx); 1808be168c0dSopenharmony_ci- auto odims = dims(i); 1809be168c0dSopenharmony_ci- auto data_type = expr(i)->data_type(); 1810be168c0dSopenharmony_ci- auto format = expr(i)->format(); 1811be168c0dSopenharmony_ci- std::string footer = (i > 0) ? ("-" + std::to_string(i)) : ""; 1812be168c0dSopenharmony_ci- auto otensor = CreateTensor(name() + footer, type, data_type, odims, format, data); 1813be168c0dSopenharmony_ci- std::cout << "tensor -" << last_id + idx << ": " << name() + footer << std::endl; 1814be168c0dSopenharmony_ci- idx++; 1815be168c0dSopenharmony_ci- session->meta_graph()->allTensors.emplace_back(std::move(otensor)); 1816be168c0dSopenharmony_ci- } 1817be168c0dSopenharmony_ci- if (primitive() != schema::PrimitiveType_NONE) { 1818be168c0dSopenharmony_ci- if (output_idx.size() == 0) { 1819be168c0dSopenharmony_ci- return RET_OK; 1820be168c0dSopenharmony_ci- } 1821be168c0dSopenharmony_ci- auto cnode = CreateCNode(input_idx, output_idx); 1822be168c0dSopenharmony_ci- 1823be168c0dSopenharmony_ci- auto ret = UnPopulate(cnode); 1824be168c0dSopenharmony_ci- if (ret != RET_OK) { 1825be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed to populate cnode"; 1826be168c0dSopenharmony_ci- return RET_ERROR; 1827be168c0dSopenharmony_ci- } 1828be168c0dSopenharmony_ci- session->meta_graph()->nodes.emplace_back(std::move(cnode)); 1829be168c0dSopenharmony_ci- } 1830be168c0dSopenharmony_ci- 1831be168c0dSopenharmony_ci- return RET_OK; 1832be168c0dSopenharmony_ci-} 1833be168c0dSopenharmony_ci- 1834be168c0dSopenharmony_ci-std::unique_ptr<schema::CNodeT> Node::CreateCNode(std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex) { 1835be168c0dSopenharmony_ci- auto cnode = std::make_unique<schema::CNodeT>(); 1836be168c0dSopenharmony_ci- cnode->primitive = std::make_unique<schema::PrimitiveT>(); 1837be168c0dSopenharmony_ci- cnode->primitive->value.type = primitive(); 1838be168c0dSopenharmony_ci- cnode->name = name(); 1839be168c0dSopenharmony_ci- cnode->inputIndex = inputIndex; 1840be168c0dSopenharmony_ci- cnode->outputIndex = outputIndex; 1841be168c0dSopenharmony_ci- return cnode; 1842be168c0dSopenharmony_ci-} 1843be168c0dSopenharmony_ci- 1844be168c0dSopenharmony_ci-int Node::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 1845be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Node " << schema::EnumNamePrimitiveType(primitive()) << " cannot be exported"; 1846be168c0dSopenharmony_ci- return RET_ERROR; 1847be168c0dSopenharmony_ci-} 1848be168c0dSopenharmony_ci- 1849be168c0dSopenharmony_ci-std::unique_ptr<mindspore::schema::TensorT> Node::CreateTensor(std::string name, int type, int data_type, 1850be168c0dSopenharmony_ci- const std::vector<int32_t> dims, int format, 1851be168c0dSopenharmony_ci- const std::vector<uint8_t> &data) { 1852be168c0dSopenharmony_ci- auto tensorT = std::make_unique<mindspore::schema::TensorT>(); 1853be168c0dSopenharmony_ci- tensorT->nodeType = type; 1854be168c0dSopenharmony_ci- tensorT->dims = dims; 1855be168c0dSopenharmony_ci- tensorT->format = static_cast<schema::Format>(format); 1856be168c0dSopenharmony_ci- tensorT->name = name; 1857be168c0dSopenharmony_ci- tensorT->refCount = 0; 1858be168c0dSopenharmony_ci- tensorT->offset = 0; 1859be168c0dSopenharmony_ci- tensorT->dataType = data_type; 1860be168c0dSopenharmony_ci- tensorT->data = data; 1861be168c0dSopenharmony_ci- tensorT->enableHuffmanCode = false; 1862be168c0dSopenharmony_ci- if (tensorT->nodeType == mindspore::lite::NodeType_ValueNode) { 1863be168c0dSopenharmony_ci- tensorT->data = data; 1864be168c0dSopenharmony_ci- } 1865be168c0dSopenharmony_ci- return tensorT; 1866be168c0dSopenharmony_ci-} 1867be168c0dSopenharmony_ci- 1868be168c0dSopenharmony_ci-int Node::SetOutputs(int num) { 1869be168c0dSopenharmony_ci- EXPR e(this); 1870be168c0dSopenharmony_ci- e.SetSize(0); 1871be168c0dSopenharmony_ci- for (auto i = expr_.size(); i < static_cast<size_t>(num); i++) { 1872be168c0dSopenharmony_ci- expr_.emplace_back(e); 1873be168c0dSopenharmony_ci- } 1874be168c0dSopenharmony_ci- return RET_OK; 1875be168c0dSopenharmony_ci-} 1876be168c0dSopenharmony_ci- 1877be168c0dSopenharmony_ci-Node::~Node() { 1878be168c0dSopenharmony_ci- for (auto &op : ops_) { 1879be168c0dSopenharmony_ci- delete op; 1880be168c0dSopenharmony_ci- } 1881be168c0dSopenharmony_ci- ops_.clear(); 1882be168c0dSopenharmony_ci- if (impl_ != nullptr) { 1883be168c0dSopenharmony_ci- impl_->set_node(nullptr); 1884be168c0dSopenharmony_ci- auto pnode = impl_->pnode(); 1885be168c0dSopenharmony_ci- if (pnode != nullptr) { 1886be168c0dSopenharmony_ci- impl_->set_pnode(nullptr); 1887be168c0dSopenharmony_ci- delete pnode; 1888be168c0dSopenharmony_ci- } 1889be168c0dSopenharmony_ci- } 1890be168c0dSopenharmony_ci- impl_ = nullptr; 1891be168c0dSopenharmony_ci-} 1892be168c0dSopenharmony_ci-} // namespace lite 1893be168c0dSopenharmony_ci-} // namespace mindspore 1894be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/node.h b/mindspore/lite/src/expression/node.h 1895be168c0dSopenharmony_cideleted file mode 100644 1896be168c0dSopenharmony_ciindex b6a820e0..00000000 1897be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/node.h 1898be168c0dSopenharmony_ci+++ /dev/null 1899be168c0dSopenharmony_ci@@ -1,156 +0,0 @@ 1900be168c0dSopenharmony_ci-/** 1901be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 1902be168c0dSopenharmony_ci- * 1903be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 1904be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 1905be168c0dSopenharmony_ci- * You may obtain a copy of the License at 1906be168c0dSopenharmony_ci- * 1907be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 1908be168c0dSopenharmony_ci- * 1909be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 1910be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 1911be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1912be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 1913be168c0dSopenharmony_ci- * limitations under the License. 1914be168c0dSopenharmony_ci- */ 1915be168c0dSopenharmony_ci- 1916be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_ 1917be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_ 1918be168c0dSopenharmony_ci- 1919be168c0dSopenharmony_ci-#include <stdlib.h> 1920be168c0dSopenharmony_ci-#include <vector> 1921be168c0dSopenharmony_ci-#include <string> 1922be168c0dSopenharmony_ci-#include <iostream> 1923be168c0dSopenharmony_ci-#include <memory> 1924be168c0dSopenharmony_ci-#include <set> 1925be168c0dSopenharmony_ci-#include "src/expression/export.h" 1926be168c0dSopenharmony_ci-#include "inner/model_generated.h" 1927be168c0dSopenharmony_ci-#include "src/expression/param.h" 1928be168c0dSopenharmony_ci-#include "src/expression/expr.h" 1929be168c0dSopenharmony_ci-#include "src/tensor.h" 1930be168c0dSopenharmony_ci-#include "nnacl/op_base.h" 1931be168c0dSopenharmony_ci- 1932be168c0dSopenharmony_ci-namespace mindspore { 1933be168c0dSopenharmony_ci-class NodeImpl; 1934be168c0dSopenharmony_ci-namespace schema { 1935be168c0dSopenharmony_ci-struct TensorT; 1936be168c0dSopenharmony_ci-struct CNodeT; 1937be168c0dSopenharmony_ci-} // namespace schema 1938be168c0dSopenharmony_ci- 1939be168c0dSopenharmony_ci-namespace lite { 1940be168c0dSopenharmony_ci-class Node { 1941be168c0dSopenharmony_ci- public: 1942be168c0dSopenharmony_ci- const std::string kGradName = "Gradients"; 1943be168c0dSopenharmony_ci- explicit Node(const std::string name) : opParam_(nullptr), name_(name) { expr_.emplace_back(this); } 1944be168c0dSopenharmony_ci- virtual ~Node(); 1945be168c0dSopenharmony_ci- Node() : Node("") {} 1946be168c0dSopenharmony_ci- explicit Node(Node *node) : Node(*node) {} 1947be168c0dSopenharmony_ci- EXPR *create(std::string name) { 1948be168c0dSopenharmony_ci- name_ = name; 1949be168c0dSopenharmony_ci- return &expr_[0]; 1950be168c0dSopenharmony_ci- } 1951be168c0dSopenharmony_ci- virtual std::vector<EXPR *> operator()(const std::vector<EXPR *> &inputs) { 1952be168c0dSopenharmony_ci- auto x = construct(inputs); 1953be168c0dSopenharmony_ci- return x; 1954be168c0dSopenharmony_ci- } 1955be168c0dSopenharmony_ci- virtual std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &&inputs) { 1956be168c0dSopenharmony_ci- std::vector<EXPR *> vec = inputs; 1957be168c0dSopenharmony_ci- auto x = construct(vec); 1958be168c0dSopenharmony_ci- return x; 1959be168c0dSopenharmony_ci- } 1960be168c0dSopenharmony_ci- virtual std::vector<EXPR *> operator()(const std::initializer_list<EXPR *> &inputs) { 1961be168c0dSopenharmony_ci- std::vector<EXPR *> vec = inputs; 1962be168c0dSopenharmony_ci- auto x = construct(vec); 1963be168c0dSopenharmony_ci- return x; 1964be168c0dSopenharmony_ci- } 1965be168c0dSopenharmony_ci- void set_primitive(schema::PrimitiveType primitive) { 1966be168c0dSopenharmony_ci- primitive_ = primitive; 1967be168c0dSopenharmony_ci- if (OpParam() != nullptr) opParam_->type_ = primitive_; 1968be168c0dSopenharmony_ci- } 1969be168c0dSopenharmony_ci- schema::PrimitiveType primitive() { return primitive_; } 1970be168c0dSopenharmony_ci- virtual std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs); 1971be168c0dSopenharmony_ci- std::string name() { return name_; } 1972be168c0dSopenharmony_ci- void set_name(std::string name) { name_ = name; } 1973be168c0dSopenharmony_ci- virtual void update_name(std::string name) { set_name(name + "/" + name_); } 1974be168c0dSopenharmony_ci- size_t Load(std::string file_name, size_t offset = 0) { return offset; } 1975be168c0dSopenharmony_ci- OpParameter *OpParam() const { return opParam_.get(); } 1976be168c0dSopenharmony_ci- virtual void Add(Node *node) {} 1977be168c0dSopenharmony_ci- virtual std::vector<EXPR *> Clone(EXPR *grad, EXPR *weight) { return {}; } 1978be168c0dSopenharmony_ci- void SetOpParam(std::shared_ptr<OpParameter> opParam) { opParam_ = opParam; } 1979be168c0dSopenharmony_ci- void SetOpParam(void *opParam) { opParam_.reset(reinterpret_cast<OpParameter *>(opParam), free); } 1980be168c0dSopenharmony_ci- static std::string UniqueName(const std::string &name) { return name + "-" + std::to_string(name_id++); } 1981be168c0dSopenharmony_ci- static std::string UniqueName(std::string &&name) { return name + "-" + std::to_string(name_id++); } 1982be168c0dSopenharmony_ci- template <typename T> 1983be168c0dSopenharmony_ci- int CloneOpParam(std::shared_ptr<OpParameter> opParam) { 1984be168c0dSopenharmony_ci- auto t = reinterpret_cast<T *>(opParam.get()); 1985be168c0dSopenharmony_ci- auto obj = new (std::nothrow) T(*t); // copy content 1986be168c0dSopenharmony_ci- if (obj == nullptr) { 1987be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate obj"; 1988be168c0dSopenharmony_ci- return RET_ERROR; 1989be168c0dSopenharmony_ci- } 1990be168c0dSopenharmony_ci- opParam_.reset(reinterpret_cast<OpParameter *>(obj)); 1991be168c0dSopenharmony_ci- return RET_OK; 1992be168c0dSopenharmony_ci- } 1993be168c0dSopenharmony_ci- template <typename T> 1994be168c0dSopenharmony_ci- int CloneOpParam(OpParameter *opParam) { 1995be168c0dSopenharmony_ci- auto t = reinterpret_cast<T *>(opParam); 1996be168c0dSopenharmony_ci- auto obj = new (std::nothrow) T(*t); // copy content 1997be168c0dSopenharmony_ci- if (obj == nullptr) { 1998be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate obj"; 1999be168c0dSopenharmony_ci- return RET_ERROR; 2000be168c0dSopenharmony_ci- } 2001be168c0dSopenharmony_ci- opParam_.reset(reinterpret_cast<OpParameter *>(obj)); 2002be168c0dSopenharmony_ci- return RET_OK; 2003be168c0dSopenharmony_ci- } 2004be168c0dSopenharmony_ci- virtual Param *weight() { return nullptr; } 2005be168c0dSopenharmony_ci- EXPR *expr(int i) { return &expr_[i]; } 2006be168c0dSopenharmony_ci- EXPR *expr() { return expr(0); } 2007be168c0dSopenharmony_ci- std::vector<EXPR *> inputs() { return expr()[0].params(); } 2008be168c0dSopenharmony_ci- size_t InputsNum() { return expr()[0].params().size(); } 2009be168c0dSopenharmony_ci- size_t OutputsNum() { return expr_.size(); } 2010be168c0dSopenharmony_ci- EXPR *input(int idx) { return expr()[0].params().at(idx); } 2011be168c0dSopenharmony_ci- EXPR *output(int idx) { return expr(idx); } 2012be168c0dSopenharmony_ci- EXPR *CreateWeights(std::vector<int> dims, TypeId data_type, int format, Param::Mode mode, std::string name); 2013be168c0dSopenharmony_ci- Node *CreateConstTensor(int index, std::vector<int> dims, TypeId data_type, int format, std::string name, 2014be168c0dSopenharmony_ci- const void *data); 2015be168c0dSopenharmony_ci- virtual std::vector<EXPR *> Grad(EXPR *expr); 2016be168c0dSopenharmony_ci- virtual Param *data() { return nullptr; } 2017be168c0dSopenharmony_ci- bool IsLearn(Node *node) { return learnable_.find(node) != learnable_.end(); } 2018be168c0dSopenharmony_ci- virtual void SetLearn() {} 2019be168c0dSopenharmony_ci- virtual std::set<Node *> trainable_params() { return learnable_; } 2020be168c0dSopenharmony_ci- std::vector<int> &dims() { return expr()->dims(); } 2021be168c0dSopenharmony_ci- std::vector<int> &dims(int i) { return expr(i)->dims(); } 2022be168c0dSopenharmony_ci- // export 2023be168c0dSopenharmony_ci- int MakeEntry(ExportSession *session); 2024be168c0dSopenharmony_ci- void PushOp(Node *n) { ops_.push_back(n); } 2025be168c0dSopenharmony_ci- virtual void AddNetOutput(std::vector<EXPR *> *output) {} 2026be168c0dSopenharmony_ci- int SetOutputs(int num); 2027be168c0dSopenharmony_ci- std::shared_ptr<OpParameter> opParam_; 2028be168c0dSopenharmony_ci- void set_impl(std::shared_ptr<NodeImpl> impl) { impl_ = impl; } 2029be168c0dSopenharmony_ci- 2030be168c0dSopenharmony_ci- protected: 2031be168c0dSopenharmony_ci- std::vector<EXPR> expr_; // hold outputs 2032be168c0dSopenharmony_ci- std::vector<Node *> ops_; // all nodes or subnets 2033be168c0dSopenharmony_ci- int InferShape(); 2034be168c0dSopenharmony_ci- void AddLearn(Node *node) { learnable_.insert(node); } 2035be168c0dSopenharmony_ci- void AssignLearn(std::set<Node *> &&learn) { learnable_ = learn; } 2036be168c0dSopenharmony_ci- 2037be168c0dSopenharmony_ci- std::unique_ptr<schema::CNodeT> CreateCNode(std::vector<uint32_t> inputIndex, std::vector<uint32_t> outputIndex); 2038be168c0dSopenharmony_ci- virtual int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode); 2039be168c0dSopenharmony_ci- std::unique_ptr<schema::TensorT> CreateTensor(std::string name, int type, int data_type, 2040be168c0dSopenharmony_ci- const std::vector<int32_t> dims, int format, 2041be168c0dSopenharmony_ci- const std::vector<uint8_t> &data); 2042be168c0dSopenharmony_ci- 2043be168c0dSopenharmony_ci- private: 2044be168c0dSopenharmony_ci- int CreateTensorFromExpr(const std::vector<EXPR *> &expr, std::vector<Tensor *> *tensors, bool is_input = false); 2045be168c0dSopenharmony_ci- void FreeAllTensors(std::vector<Tensor *> *tensors); 2046be168c0dSopenharmony_ci- static int name_id; 2047be168c0dSopenharmony_ci- std::set<Node *> learnable_; // set of nodes with learnable parameters 2048be168c0dSopenharmony_ci- std::string name_; 2049be168c0dSopenharmony_ci- schema::PrimitiveType primitive_; 2050be168c0dSopenharmony_ci- std::shared_ptr<NodeImpl> impl_; 2051be168c0dSopenharmony_ci-}; 2052be168c0dSopenharmony_ci-} // namespace lite 2053be168c0dSopenharmony_ci-} // namespace mindspore 2054be168c0dSopenharmony_ci- 2055be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_NODE_H_ 2056be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops.cc b/mindspore/lite/src/expression/ops.cc 2057be168c0dSopenharmony_cideleted file mode 100644 2058be168c0dSopenharmony_ciindex 629fa0ae..00000000 2059be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops.cc 2060be168c0dSopenharmony_ci+++ /dev/null 2061be168c0dSopenharmony_ci@@ -1,66 +0,0 @@ 2062be168c0dSopenharmony_ci- 2063be168c0dSopenharmony_ci-/** 2064be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2065be168c0dSopenharmony_ci- * 2066be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2067be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2068be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2069be168c0dSopenharmony_ci- * 2070be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2071be168c0dSopenharmony_ci- * 2072be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2073be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2074be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2075be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2076be168c0dSopenharmony_ci- * limitations under the License. 2077be168c0dSopenharmony_ci- */ 2078be168c0dSopenharmony_ci- 2079be168c0dSopenharmony_ci-#include <numeric> 2080be168c0dSopenharmony_ci-#include <algorithm> 2081be168c0dSopenharmony_ci-#include "src/expression/ops.h" 2082be168c0dSopenharmony_ci-#include "src/expression/ops_utils.h" 2083be168c0dSopenharmony_ci-#include "src/expression/param.h" 2084be168c0dSopenharmony_ci-#include "include/api/cfg.h" 2085be168c0dSopenharmony_ci-#include "src/expression/sequential.h" 2086be168c0dSopenharmony_ci- 2087be168c0dSopenharmony_ci-namespace mindspore { 2088be168c0dSopenharmony_ci-namespace lite { 2089be168c0dSopenharmony_ci-void InputM::SetUp(const std::vector<int> &dims, TypeId data_type, int fmt) { 2090be168c0dSopenharmony_ci- expr()->SetSize(C0NUM); 2091be168c0dSopenharmony_ci- expr()->SetDims(dims); 2092be168c0dSopenharmony_ci- expr()->set_data_type(data_type); 2093be168c0dSopenharmony_ci- expr()->set_format(fmt); 2094be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_NONE); 2095be168c0dSopenharmony_ci-} 2096be168c0dSopenharmony_ci- 2097be168c0dSopenharmony_ci-InputM::InputM(const std::vector<int> &dims, TypeId data_type, int fmt) : Node() { SetUp(dims, data_type, fmt); } 2098be168c0dSopenharmony_ci- 2099be168c0dSopenharmony_ci-InputM::InputM(const schema::Tensor *tensor) : Node() { 2100be168c0dSopenharmony_ci- std::vector<int> dims(tensor->dims()->size()); 2101be168c0dSopenharmony_ci- (void)std::transform(tensor->dims()->begin(), tensor->dims()->end(), dims.begin(), [](int32_t x) { return x; }); 2102be168c0dSopenharmony_ci- SetUp(dims, static_cast<TypeId>(tensor->dataType()), tensor->format()); 2103be168c0dSopenharmony_ci- if (tensor->name()) set_name(tensor->name()->str()); 2104be168c0dSopenharmony_ci- if (tensor->data() != nullptr) data_.Copy(tensor->data()->data(), tensor->data()->size()); 2105be168c0dSopenharmony_ci-} 2106be168c0dSopenharmony_ci- 2107be168c0dSopenharmony_ci-namespace NN { 2108be168c0dSopenharmony_ci-Node *Input(const std::vector<int> &dims, TypeId data_type, int fmt) { 2109be168c0dSopenharmony_ci- auto i = new (std::nothrow) InputM(dims, data_type, fmt); 2110be168c0dSopenharmony_ci- if (i == nullptr) { 2111be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate input expression "; 2112be168c0dSopenharmony_ci- return nullptr; 2113be168c0dSopenharmony_ci- } 2114be168c0dSopenharmony_ci- return i; 2115be168c0dSopenharmony_ci-} 2116be168c0dSopenharmony_ci- 2117be168c0dSopenharmony_ci-Net *Sequential() { 2118be168c0dSopenharmony_ci- auto s = new (std::nothrow) mindspore::lite::Sequential(); 2119be168c0dSopenharmony_ci- if (s == nullptr) { 2120be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate sequential expression"; 2121be168c0dSopenharmony_ci- return nullptr; 2122be168c0dSopenharmony_ci- } 2123be168c0dSopenharmony_ci- return s; 2124be168c0dSopenharmony_ci-} 2125be168c0dSopenharmony_ci-}; // namespace NN 2126be168c0dSopenharmony_ci-} // namespace lite 2127be168c0dSopenharmony_ci-} // namespace mindspore 2128be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops.h b/mindspore/lite/src/expression/ops.h 2129be168c0dSopenharmony_cideleted file mode 100644 2130be168c0dSopenharmony_ciindex 96a006f4..00000000 2131be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops.h 2132be168c0dSopenharmony_ci+++ /dev/null 2133be168c0dSopenharmony_ci@@ -1,69 +0,0 @@ 2134be168c0dSopenharmony_ci-/** 2135be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2136be168c0dSopenharmony_ci- * 2137be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2138be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2139be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2140be168c0dSopenharmony_ci- * 2141be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2142be168c0dSopenharmony_ci- * 2143be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2144be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2145be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2146be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2147be168c0dSopenharmony_ci- * limitations under the License. 2148be168c0dSopenharmony_ci- */ 2149be168c0dSopenharmony_ci- 2150be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_ 2151be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_ 2152be168c0dSopenharmony_ci- 2153be168c0dSopenharmony_ci-#include <vector> 2154be168c0dSopenharmony_ci-#include <string> 2155be168c0dSopenharmony_ci-#include <set> 2156be168c0dSopenharmony_ci-#include "include/api/net.h" 2157be168c0dSopenharmony_ci-#include "src/expression/cfg.h" 2158be168c0dSopenharmony_ci-#include "src/expression/net.h" 2159be168c0dSopenharmony_ci-#include "inner/model_generated.h" 2160be168c0dSopenharmony_ci- 2161be168c0dSopenharmony_ci-namespace mindspore { 2162be168c0dSopenharmony_ci-namespace lite { 2163be168c0dSopenharmony_ci-class InputM : public Node { 2164be168c0dSopenharmony_ci- public: 2165be168c0dSopenharmony_ci- explicit InputM(const schema::Tensor *tensor); 2166be168c0dSopenharmony_ci- explicit InputM(const std::vector<int> &dims, TypeId data_type = kNumberTypeFloat32, int fmt = NHWC); 2167be168c0dSopenharmony_ci- Param *data() override { return &data_; } 2168be168c0dSopenharmony_ci- 2169be168c0dSopenharmony_ci- private: 2170be168c0dSopenharmony_ci- void SetUp(const std::vector<int> &dims, TypeId data_type, int fmt); 2171be168c0dSopenharmony_ci- Param data_; 2172be168c0dSopenharmony_ci-}; 2173be168c0dSopenharmony_ci-namespace NN { 2174be168c0dSopenharmony_ci-Node *Conv2D(const ConvConfig &cfg); 2175be168c0dSopenharmony_ci-Node *Relu(); 2176be168c0dSopenharmony_ci-Node *Dense(const DenseConfig &cfg); 2177be168c0dSopenharmony_ci-Node *Flatten(); 2178be168c0dSopenharmony_ci-Node *Input(const std::vector<int> &dims, TypeId data_type = kNumberTypeFloat32, int fmt = NHWC); 2179be168c0dSopenharmony_ci-Node *Add(); 2180be168c0dSopenharmony_ci-Node *Sub(); 2181be168c0dSopenharmony_ci-Node *Div(); 2182be168c0dSopenharmony_ci-Node *Mul(); 2183be168c0dSopenharmony_ci-Node *Neg(); 2184be168c0dSopenharmony_ci-Node *SoftmaxCrossEntropy(); 2185be168c0dSopenharmony_ci-Net *Sequential(); 2186be168c0dSopenharmony_ci-Node *Adam(std::set<Node *> &&learn, const AdamConfig &cfg); 2187be168c0dSopenharmony_ci- 2188be168c0dSopenharmony_ci-Node *Softmax(int axis = -1); 2189be168c0dSopenharmony_ci-Node *BatchNorm2D(int outp, float momentum = 0.1, float epsilon = 1e-5f); 2190be168c0dSopenharmony_ci-Node *Sigmoid(); 2191be168c0dSopenharmony_ci-Node *DropOut(float ration = 0.5); 2192be168c0dSopenharmony_ci-Node *ReLU6(); 2193be168c0dSopenharmony_ci-Node *Reshape(const std::vector<int> &shape); 2194be168c0dSopenharmony_ci-Node *ReduceMean(bool keep_dims, const std::vector<int> &dims); 2195be168c0dSopenharmony_ci-Node *ReduceSum(bool keep_dims, const std::vector<int> &dims); 2196be168c0dSopenharmony_ci-Node *Tile(const std::vector<int> &multiples); 2197be168c0dSopenharmony_ci-Node *MaxPool2D(const PoolingConfig &cfg); 2198be168c0dSopenharmony_ci-Node *AvgPool2D(const PoolingConfig &cfg); 2199be168c0dSopenharmony_ci-} // namespace NN 2200be168c0dSopenharmony_ci-} // namespace lite 2201be168c0dSopenharmony_ci-} // namespace mindspore 2202be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_H_ 2203be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/activation.cc b/mindspore/lite/src/expression/ops/activation.cc 2204be168c0dSopenharmony_cideleted file mode 100644 2205be168c0dSopenharmony_ciindex 3429a003..00000000 2206be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/activation.cc 2207be168c0dSopenharmony_ci+++ /dev/null 2208be168c0dSopenharmony_ci@@ -1,133 +0,0 @@ 2209be168c0dSopenharmony_ci-/** 2210be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2211be168c0dSopenharmony_ci- * 2212be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2213be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2214be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2215be168c0dSopenharmony_ci- * 2216be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2217be168c0dSopenharmony_ci- * 2218be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2219be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2220be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2221be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2222be168c0dSopenharmony_ci- * limitations under the License. 2223be168c0dSopenharmony_ci- */ 2224be168c0dSopenharmony_ci-#include "src/expression/ops/activation.h" 2225be168c0dSopenharmony_ci-#include "nnacl/fp32/activation_fp32.h" 2226be168c0dSopenharmony_ci-#include "src/expression/import.h" 2227be168c0dSopenharmony_ci-#include "src/expression/ops.h" 2228be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 2229be168c0dSopenharmony_ci- 2230be168c0dSopenharmony_ci-namespace mindspore { 2231be168c0dSopenharmony_ci-namespace lite { 2232be168c0dSopenharmony_ci-ActM::ActM(schema::ActivationType type) : Node() { 2233be168c0dSopenharmony_ci- auto op_param = malloc(sizeof(ActivationParameter)); 2234be168c0dSopenharmony_ci- if (op_param == nullptr) { 2235be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ActivationParameter"; 2236be168c0dSopenharmony_ci- return; 2237be168c0dSopenharmony_ci- } 2238be168c0dSopenharmony_ci- SetOpParam(op_param); 2239be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Activation); 2240be168c0dSopenharmony_ci- ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(opParam_.get()); 2241be168c0dSopenharmony_ci- act_param->type_ = type; 2242be168c0dSopenharmony_ci- act_param->alpha_ = 0.f; 2243be168c0dSopenharmony_ci- act_param->min_val_ = 0.f; 2244be168c0dSopenharmony_ci- act_param->max_val_ = 0.f; 2245be168c0dSopenharmony_ci-} 2246be168c0dSopenharmony_ci- 2247be168c0dSopenharmony_ci-std::vector<EXPR *> ActM::Grad(EXPR *yt) { 2248be168c0dSopenharmony_ci- auto actGrad = new (std::nothrow) ActGradM(this); 2249be168c0dSopenharmony_ci- if (actGrad == nullptr) { 2250be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate activation grad node"; 2251be168c0dSopenharmony_ci- return {}; 2252be168c0dSopenharmony_ci- } 2253be168c0dSopenharmony_ci- PushOp(actGrad); 2254be168c0dSopenharmony_ci- auto param = reinterpret_cast<ActivationParameter *>(actGrad->OpParam()); 2255be168c0dSopenharmony_ci- EXPR *ag = nullptr; 2256be168c0dSopenharmony_ci- actGrad->expr()->SetSize(C2NUM); 2257be168c0dSopenharmony_ci- if ((param->type_ == schema::ActivationType_SIGMOID) || (param->type_ == schema::ActivationType_TANH)) { 2258be168c0dSopenharmony_ci- ag = (*actGrad)({output(0), yt}).front(); 2259be168c0dSopenharmony_ci- } else if ((param->type_ == schema::ActivationType_HSWISH) || (param->type_ == schema::ActivationType_HSIGMOID) || 2260be168c0dSopenharmony_ci- (param->type_ == schema::ActivationType_RELU6)) { 2261be168c0dSopenharmony_ci- ag = (*actGrad)({yt, input(0)}).front(); 2262be168c0dSopenharmony_ci- } else if (param->type_ == schema::ActivationType_GELU) { 2263be168c0dSopenharmony_ci- actGrad->expr()->SetSize(C3NUM); 2264be168c0dSopenharmony_ci- ag = (*actGrad)({yt, input(0), output(0)}).front(); 2265be168c0dSopenharmony_ci- } else { 2266be168c0dSopenharmony_ci- ag = (*actGrad)({yt, output(0)}).front(); 2267be168c0dSopenharmony_ci- } 2268be168c0dSopenharmony_ci- std::vector<EXPR *> res = {ag}; 2269be168c0dSopenharmony_ci- return res; 2270be168c0dSopenharmony_ci-} 2271be168c0dSopenharmony_ci- 2272be168c0dSopenharmony_ci-int ActM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 2273be168c0dSopenharmony_ci- auto act_param = reinterpret_cast<const ActivationParameter *>(OpParam()); 2274be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::ActivationT; 2275be168c0dSopenharmony_ci- if (prim == nullptr) { 2276be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate activation primitive"; 2277be168c0dSopenharmony_ci- return RET_ERROR; 2278be168c0dSopenharmony_ci- } 2279be168c0dSopenharmony_ci- prim->activation_type = static_cast<decltype(prim->activation_type)>(act_param->type_); 2280be168c0dSopenharmony_ci- prim->alpha = act_param->alpha_; 2281be168c0dSopenharmony_ci- prim->min_val = act_param->min_val_; 2282be168c0dSopenharmony_ci- prim->max_val = act_param->max_val_; 2283be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 2284be168c0dSopenharmony_ci- return RET_OK; 2285be168c0dSopenharmony_ci-} 2286be168c0dSopenharmony_ci- 2287be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_Activation, ReturnNode<ActM>); 2288be168c0dSopenharmony_ci- 2289be168c0dSopenharmony_ci-ActGradM::ActGradM(Node *node) { 2290be168c0dSopenharmony_ci- CloneOpParam<ActivationParameter>(node->OpParam()); 2291be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_ActivationGrad); 2292be168c0dSopenharmony_ci- set_name(node->name() + "/" + kGradName + "/actGrad"); 2293be168c0dSopenharmony_ci-} 2294be168c0dSopenharmony_ci- 2295be168c0dSopenharmony_ci-std::vector<EXPR *> ActGradM::Grad(EXPR *yt) { return {}; } 2296be168c0dSopenharmony_ci- 2297be168c0dSopenharmony_ci-int ActGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 2298be168c0dSopenharmony_ci- auto act_param = reinterpret_cast<const ActivationParameter *>(OpParam()); 2299be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::ActivationGradT; 2300be168c0dSopenharmony_ci- if (prim == nullptr) { 2301be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate activation grad primitive"; 2302be168c0dSopenharmony_ci- return RET_ERROR; 2303be168c0dSopenharmony_ci- } 2304be168c0dSopenharmony_ci- prim->activation_type = static_cast<decltype(prim->activation_type)>(act_param->type_); 2305be168c0dSopenharmony_ci- prim->alpha = act_param->alpha_; 2306be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 2307be168c0dSopenharmony_ci- return RET_OK; 2308be168c0dSopenharmony_ci-} 2309be168c0dSopenharmony_ci- 2310be168c0dSopenharmony_ci-static ImportReg regGrad(schema::PrimitiveType_ActivationGrad, ReturnNode<ActGradM>); 2311be168c0dSopenharmony_ci-namespace NN { 2312be168c0dSopenharmony_ci-Node *ReLU6() { 2313be168c0dSopenharmony_ci- auto r = new (std::nothrow) ActM(schema::ActivationType_RELU6); 2314be168c0dSopenharmony_ci- if (r == nullptr) { 2315be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate relu6"; 2316be168c0dSopenharmony_ci- return nullptr; 2317be168c0dSopenharmony_ci- } 2318be168c0dSopenharmony_ci- r->set_name(Node::UniqueName("ReLU6")); 2319be168c0dSopenharmony_ci- return r; 2320be168c0dSopenharmony_ci-} 2321be168c0dSopenharmony_ci-Node *Sigmoid() { 2322be168c0dSopenharmony_ci- auto s = new (std::nothrow) ActM(schema::ActivationType_SIGMOID); 2323be168c0dSopenharmony_ci- if (s == nullptr) { 2324be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate sigmoid"; 2325be168c0dSopenharmony_ci- return nullptr; 2326be168c0dSopenharmony_ci- } 2327be168c0dSopenharmony_ci- s->set_name(Node::UniqueName("Sigmoid")); 2328be168c0dSopenharmony_ci- return s; 2329be168c0dSopenharmony_ci-} 2330be168c0dSopenharmony_ci-Node *Relu() { 2331be168c0dSopenharmony_ci- auto r = new (std::nothrow) ActM(schema::ActivationType_RELU); 2332be168c0dSopenharmony_ci- if (r == nullptr) { 2333be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate relu"; 2334be168c0dSopenharmony_ci- return nullptr; 2335be168c0dSopenharmony_ci- } 2336be168c0dSopenharmony_ci- r->set_name(r->UniqueName("Relu")); 2337be168c0dSopenharmony_ci- return r; 2338be168c0dSopenharmony_ci-} 2339be168c0dSopenharmony_ci-} // namespace NN 2340be168c0dSopenharmony_ci-} // namespace lite 2341be168c0dSopenharmony_ci-} // namespace mindspore 2342be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/activation.h b/mindspore/lite/src/expression/ops/activation.h 2343be168c0dSopenharmony_cideleted file mode 100644 2344be168c0dSopenharmony_ciindex 14271c09..00000000 2345be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/activation.h 2346be168c0dSopenharmony_ci+++ /dev/null 2347be168c0dSopenharmony_ci@@ -1,44 +0,0 @@ 2348be168c0dSopenharmony_ci-/** 2349be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2350be168c0dSopenharmony_ci- * 2351be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2352be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2353be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2354be168c0dSopenharmony_ci- * 2355be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2356be168c0dSopenharmony_ci- * 2357be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2358be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2359be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2360be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2361be168c0dSopenharmony_ci- * limitations under the License. 2362be168c0dSopenharmony_ci- */ 2363be168c0dSopenharmony_ci- 2364be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_ 2365be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_ 2366be168c0dSopenharmony_ci- 2367be168c0dSopenharmony_ci-#include <vector> 2368be168c0dSopenharmony_ci-#include <memory> 2369be168c0dSopenharmony_ci-#include "src/expression/net.h" 2370be168c0dSopenharmony_ci-#include "inner/model_generated.h" 2371be168c0dSopenharmony_ci- 2372be168c0dSopenharmony_ci-namespace mindspore { 2373be168c0dSopenharmony_ci-namespace lite { 2374be168c0dSopenharmony_ci-class ActM : public Node { 2375be168c0dSopenharmony_ci- public: 2376be168c0dSopenharmony_ci- ActM() = default; 2377be168c0dSopenharmony_ci- explicit ActM(schema::ActivationType type); 2378be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 2379be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 2380be168c0dSopenharmony_ci-}; 2381be168c0dSopenharmony_ci- 2382be168c0dSopenharmony_ci-class ActGradM : public Node { 2383be168c0dSopenharmony_ci- public: 2384be168c0dSopenharmony_ci- ActGradM() : Node() {} // for Import 2385be168c0dSopenharmony_ci- explicit ActGradM(Node *act); // for Grad 2386be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 2387be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 2388be168c0dSopenharmony_ci-}; 2389be168c0dSopenharmony_ci-} // namespace lite 2390be168c0dSopenharmony_ci-} // namespace mindspore 2391be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ACTIVATION_H_ 2392be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/adam.cc b/mindspore/lite/src/expression/ops/adam.cc 2393be168c0dSopenharmony_cideleted file mode 100644 2394be168c0dSopenharmony_ciindex 7fcfdc74..00000000 2395be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/adam.cc 2396be168c0dSopenharmony_ci+++ /dev/null 2397be168c0dSopenharmony_ci@@ -1,142 +0,0 @@ 2398be168c0dSopenharmony_ci-/** 2399be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2400be168c0dSopenharmony_ci- * 2401be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2402be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2403be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2404be168c0dSopenharmony_ci- * 2405be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2406be168c0dSopenharmony_ci- * 2407be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2408be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2409be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2410be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2411be168c0dSopenharmony_ci- * limitations under the License. 2412be168c0dSopenharmony_ci- */ 2413be168c0dSopenharmony_ci- 2414be168c0dSopenharmony_ci-#include "src/expression/ops/adam.h" 2415be168c0dSopenharmony_ci-#include <memory> 2416be168c0dSopenharmony_ci-#include <set> 2417be168c0dSopenharmony_ci-#include <utility> 2418be168c0dSopenharmony_ci-#include "src/expression/ops.h" 2419be168c0dSopenharmony_ci-#include "src/expression/ops/assign.h" 2420be168c0dSopenharmony_ci-#include "src/expression/ops/arithmetic.h" 2421be168c0dSopenharmony_ci-#include "nnacl/fp32_grad/optimizer.h" 2422be168c0dSopenharmony_ci-#include "include/api/net.h" 2423be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 2424be168c0dSopenharmony_ci- 2425be168c0dSopenharmony_ci-namespace mindspore { 2426be168c0dSopenharmony_ci-namespace NN { 2427be168c0dSopenharmony_ci-Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg) { 2428be168c0dSopenharmony_ci- auto lite_node = lite::NN::Adam(std::move(learn->set_), cfg); 2429be168c0dSopenharmony_ci- return NodeImpl::Connect(lite_node); 2430be168c0dSopenharmony_ci-} 2431be168c0dSopenharmony_ci-} // namespace NN 2432be168c0dSopenharmony_ci- 2433be168c0dSopenharmony_ci-namespace lite { 2434be168c0dSopenharmony_ci-std::vector<EXPR *> AdamM::Clone(EXPR *grad, EXPR *weight) { 2435be168c0dSopenharmony_ci- auto adam = new (std::nothrow) AdamM(); 2436be168c0dSopenharmony_ci- if (adam == nullptr) { 2437be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate adam"; 2438be168c0dSopenharmony_ci- return {}; 2439be168c0dSopenharmony_ci- } 2440be168c0dSopenharmony_ci- adam->set_name("optimizer-Adam"); 2441be168c0dSopenharmony_ci- adam->CloneOpParam<AdamParameter>(OpParam()); 2442be168c0dSopenharmony_ci- adam->update_name(weight->node()->name()); 2443be168c0dSopenharmony_ci- adam->set_primitive(primitive()); 2444be168c0dSopenharmony_ci- adam->expr()->SetSize(C10NUM); 2445be168c0dSopenharmony_ci- // setup weight and momentum 2446be168c0dSopenharmony_ci- adam->expr()->set_params(C0NUM, weight); 2447be168c0dSopenharmony_ci- auto dims = grad->dims(); 2448be168c0dSopenharmony_ci- auto m = adam->CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::ZEROS, "m"); 2449be168c0dSopenharmony_ci- adam->expr()->set_params(C1NUM, m); 2450be168c0dSopenharmony_ci- auto v = adam->CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::ZEROS, "v"); 2451be168c0dSopenharmony_ci- adam->expr()->set_params(C2NUM, v); 2452be168c0dSopenharmony_ci- // copy parameters 2453be168c0dSopenharmony_ci- for (int i = C3NUM; i < C9NUM; i++) { 2454be168c0dSopenharmony_ci- adam->expr()->set_params(i, this->input(i)); 2455be168c0dSopenharmony_ci- } 2456be168c0dSopenharmony_ci- adam->expr()->set_params(C9NUM, grad); 2457be168c0dSopenharmony_ci- return (*adam)(adam->inputs()); 2458be168c0dSopenharmony_ci-} 2459be168c0dSopenharmony_ci- 2460be168c0dSopenharmony_ci-AdamM::AdamM(std::set<Node *> &&learn, const AdamConfig &cfg) { 2461be168c0dSopenharmony_ci- auto op_param = reinterpret_cast<AdamParameter *>(malloc(sizeof(AdamParameter))); 2462be168c0dSopenharmony_ci- if (op_param == nullptr) { 2463be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ActivationParameter"; 2464be168c0dSopenharmony_ci- return; 2465be168c0dSopenharmony_ci- } 2466be168c0dSopenharmony_ci- AssignLearn(std::move(learn)); 2467be168c0dSopenharmony_ci- memset(op_param, 0, sizeof(AdamParameter)); 2468be168c0dSopenharmony_ci- op_param->use_nesterov_ = cfg.use_nesterov_; 2469be168c0dSopenharmony_ci- SetOpParam(op_param); 2470be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Adam); 2471be168c0dSopenharmony_ci- set_name("optimizer-Adam"); 2472be168c0dSopenharmony_ci- // Adam Network 2473be168c0dSopenharmony_ci- expr()->SetSize(C10NUM); 2474be168c0dSopenharmony_ci- auto assign1 = new (std::nothrow) AssignM(0); 2475be168c0dSopenharmony_ci- if (assign1 == nullptr) { 2476be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate assign"; 2477be168c0dSopenharmony_ci- return; 2478be168c0dSopenharmony_ci- } 2479be168c0dSopenharmony_ci- PushOp(assign1); 2480be168c0dSopenharmony_ci- auto assign2 = new (std::nothrow) AssignM(0); 2481be168c0dSopenharmony_ci- if (assign2 == nullptr) { 2482be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate assign"; 2483be168c0dSopenharmony_ci- return; 2484be168c0dSopenharmony_ci- } 2485be168c0dSopenharmony_ci- PushOp(assign2); 2486be168c0dSopenharmony_ci- auto mul1 = NN::Mul(); 2487be168c0dSopenharmony_ci- if (mul1 == nullptr) { 2488be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate mul"; 2489be168c0dSopenharmony_ci- return; 2490be168c0dSopenharmony_ci- } 2491be168c0dSopenharmony_ci- PushOp(mul1); 2492be168c0dSopenharmony_ci- auto mul2 = NN::Mul(); 2493be168c0dSopenharmony_ci- if (mul2 == nullptr) { 2494be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate mul"; 2495be168c0dSopenharmony_ci- return; 2496be168c0dSopenharmony_ci- } 2497be168c0dSopenharmony_ci- PushOp(mul2); 2498be168c0dSopenharmony_ci- auto tmp = 1.0f; 2499be168c0dSopenharmony_ci- mul1->CreateConstTensor(C0NUM, {1}, kNumberTypeFloat32, KHWC, "beta1-power", &tmp); 2500be168c0dSopenharmony_ci- mul1->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "beta1-data", &cfg.beta1_); 2501be168c0dSopenharmony_ci- auto o1 = (*mul1)({}); 2502be168c0dSopenharmony_ci- assign1_ = (*assign1)({mul1->input(0), o1.front()}).front(); 2503be168c0dSopenharmony_ci- mul2->CreateConstTensor(C0NUM, {1}, kNumberTypeFloat32, KHWC, "beta2-power", &tmp); 2504be168c0dSopenharmony_ci- mul2->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "beta2-data", &cfg.beta2_); 2505be168c0dSopenharmony_ci- auto o2 = (*mul2)({}); 2506be168c0dSopenharmony_ci- assign2_ = (*assign2)({mul2->input(0), o2.front()}).front(); 2507be168c0dSopenharmony_ci- expr()->set_params(C3NUM, o1.front()); 2508be168c0dSopenharmony_ci- expr()->set_params(C4NUM, o2.front()); 2509be168c0dSopenharmony_ci- CreateConstTensor(C5NUM, {1}, kNumberTypeFloat32, KHWC, "learning-rate", &cfg.learning_rate_); 2510be168c0dSopenharmony_ci- CreateConstTensor(C6NUM, {1}, kNumberTypeFloat32, KHWC, "beta1", &cfg.beta1_); 2511be168c0dSopenharmony_ci- CreateConstTensor(C7NUM, {1}, kNumberTypeFloat32, KHWC, "beta2", &cfg.beta2_); 2512be168c0dSopenharmony_ci- CreateConstTensor(C8NUM, {1}, kNumberTypeFloat32, KHWC, "epsilon", &cfg.eps_); 2513be168c0dSopenharmony_ci-} 2514be168c0dSopenharmony_ci- 2515be168c0dSopenharmony_ci-int AdamM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 2516be168c0dSopenharmony_ci- auto param = reinterpret_cast<const AdamParameter *>(OpParam()); 2517be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::AdamT; 2518be168c0dSopenharmony_ci- if (prim == nullptr) { 2519be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate " << cnode->name; 2520be168c0dSopenharmony_ci- return RET_ERROR; 2521be168c0dSopenharmony_ci- } 2522be168c0dSopenharmony_ci- prim->use_nesterov = param->use_nesterov_; 2523be168c0dSopenharmony_ci- prim->use_locking = false; 2524be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 2525be168c0dSopenharmony_ci- return RET_OK; 2526be168c0dSopenharmony_ci-} 2527be168c0dSopenharmony_ci- 2528be168c0dSopenharmony_ci-namespace NN { 2529be168c0dSopenharmony_ci-Node *Adam(std::set<Node *> &&learn, const AdamConfig &cfg) { 2530be168c0dSopenharmony_ci- auto a = new (std::nothrow) AdamM(std::move(learn), cfg); 2531be168c0dSopenharmony_ci- if (a == nullptr) { 2532be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate adam"; 2533be168c0dSopenharmony_ci- return nullptr; 2534be168c0dSopenharmony_ci- } 2535be168c0dSopenharmony_ci- return a; 2536be168c0dSopenharmony_ci-} 2537be168c0dSopenharmony_ci-} // namespace NN 2538be168c0dSopenharmony_ci-} // namespace lite 2539be168c0dSopenharmony_ci-} // namespace mindspore 2540be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/adam.h b/mindspore/lite/src/expression/ops/adam.h 2541be168c0dSopenharmony_cideleted file mode 100644 2542be168c0dSopenharmony_ciindex 58e44def..00000000 2543be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/adam.h 2544be168c0dSopenharmony_ci+++ /dev/null 2545be168c0dSopenharmony_ci@@ -1,46 +0,0 @@ 2546be168c0dSopenharmony_ci-/** 2547be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2548be168c0dSopenharmony_ci- * 2549be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2550be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2551be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2552be168c0dSopenharmony_ci- * 2553be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2554be168c0dSopenharmony_ci- * 2555be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2556be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2557be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2558be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2559be168c0dSopenharmony_ci- * limitations under the License. 2560be168c0dSopenharmony_ci- */ 2561be168c0dSopenharmony_ci- 2562be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_ 2563be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_ 2564be168c0dSopenharmony_ci- 2565be168c0dSopenharmony_ci-#include <vector> 2566be168c0dSopenharmony_ci-#include <set> 2567be168c0dSopenharmony_ci-#include <memory> 2568be168c0dSopenharmony_ci-#include "include/api/net.h" 2569be168c0dSopenharmony_ci-#include "src/expression/net.h" 2570be168c0dSopenharmony_ci-#include "inner/model_generated.h" 2571be168c0dSopenharmony_ci- 2572be168c0dSopenharmony_ci-namespace mindspore { 2573be168c0dSopenharmony_ci-namespace lite { 2574be168c0dSopenharmony_ci-class AdamM : public Node { 2575be168c0dSopenharmony_ci- public: 2576be168c0dSopenharmony_ci- AdamM() = default; 2577be168c0dSopenharmony_ci- AdamM(std::set<Node *> &&learn, const AdamConfig &cfg); 2578be168c0dSopenharmony_ci- std::vector<EXPR *> Clone(EXPR *grad, EXPR *weight) override; 2579be168c0dSopenharmony_ci- void AddNetOutput(std::vector<EXPR *> *output) override { 2580be168c0dSopenharmony_ci- output->push_back(assign1_); 2581be168c0dSopenharmony_ci- output->push_back(assign2_); 2582be168c0dSopenharmony_ci- } 2583be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 2584be168c0dSopenharmony_ci- 2585be168c0dSopenharmony_ci- private: 2586be168c0dSopenharmony_ci- EXPR *assign1_{nullptr}; 2587be168c0dSopenharmony_ci- EXPR *assign2_{nullptr}; 2588be168c0dSopenharmony_ci-}; 2589be168c0dSopenharmony_ci-} // namespace lite 2590be168c0dSopenharmony_ci-} // namespace mindspore 2591be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADAM_H_ 2592be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/addn.cc b/mindspore/lite/src/expression/ops/addn.cc 2593be168c0dSopenharmony_cideleted file mode 100644 2594be168c0dSopenharmony_ciindex bd614f1d..00000000 2595be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/addn.cc 2596be168c0dSopenharmony_ci+++ /dev/null 2597be168c0dSopenharmony_ci@@ -1,42 +0,0 @@ 2598be168c0dSopenharmony_ci-/** 2599be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2600be168c0dSopenharmony_ci- * 2601be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2602be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2603be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2604be168c0dSopenharmony_ci- * 2605be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2606be168c0dSopenharmony_ci- * 2607be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2608be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2609be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2610be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2611be168c0dSopenharmony_ci- * limitations under the License. 2612be168c0dSopenharmony_ci- */ 2613be168c0dSopenharmony_ci- 2614be168c0dSopenharmony_ci-#include "src/expression/ops/addn.h" 2615be168c0dSopenharmony_ci- 2616be168c0dSopenharmony_ci-namespace mindspore { 2617be168c0dSopenharmony_ci-namespace lite { 2618be168c0dSopenharmony_ci-AddN::AddN(int dummy) { 2619be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(OpParameter)); 2620be168c0dSopenharmony_ci- if (op_param == nullptr) { 2621be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate parameter "; 2622be168c0dSopenharmony_ci- return; 2623be168c0dSopenharmony_ci- } 2624be168c0dSopenharmony_ci- set_name(UniqueName("addN")); 2625be168c0dSopenharmony_ci- SetOpParam(op_param); 2626be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_AddN); 2627be168c0dSopenharmony_ci-} 2628be168c0dSopenharmony_ci- 2629be168c0dSopenharmony_ci-int AddN::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 2630be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::AddNT; 2631be168c0dSopenharmony_ci- if (prim == nullptr) { 2632be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 2633be168c0dSopenharmony_ci- return RET_ERROR; 2634be168c0dSopenharmony_ci- } 2635be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 2636be168c0dSopenharmony_ci- return RET_OK; 2637be168c0dSopenharmony_ci-} 2638be168c0dSopenharmony_ci-} // namespace lite 2639be168c0dSopenharmony_ci-} // namespace mindspore 2640be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/addn.h b/mindspore/lite/src/expression/ops/addn.h 2641be168c0dSopenharmony_cideleted file mode 100644 2642be168c0dSopenharmony_ciindex 3ed96319..00000000 2643be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/addn.h 2644be168c0dSopenharmony_ci+++ /dev/null 2645be168c0dSopenharmony_ci@@ -1,34 +0,0 @@ 2646be168c0dSopenharmony_ci-/** 2647be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2648be168c0dSopenharmony_ci- * 2649be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2650be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2651be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2652be168c0dSopenharmony_ci- * 2653be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2654be168c0dSopenharmony_ci- * 2655be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2656be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2657be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2658be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2659be168c0dSopenharmony_ci- * limitations under the License. 2660be168c0dSopenharmony_ci- */ 2661be168c0dSopenharmony_ci- 2662be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_ 2663be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_ 2664be168c0dSopenharmony_ci- 2665be168c0dSopenharmony_ci-#include <memory> 2666be168c0dSopenharmony_ci-#include "src/expression/node.h" 2667be168c0dSopenharmony_ci-#include "inner/model_generated.h" 2668be168c0dSopenharmony_ci- 2669be168c0dSopenharmony_ci-namespace mindspore { 2670be168c0dSopenharmony_ci-namespace lite { 2671be168c0dSopenharmony_ci-class AddN : public Node { 2672be168c0dSopenharmony_ci- public: 2673be168c0dSopenharmony_ci- explicit AddN(int dummy); 2674be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 2675be168c0dSopenharmony_ci-}; 2676be168c0dSopenharmony_ci-} // namespace lite 2677be168c0dSopenharmony_ci-} // namespace mindspore 2678be168c0dSopenharmony_ci- 2679be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ADDN_H_ 2680be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/arithmetic.cc b/mindspore/lite/src/expression/ops/arithmetic.cc 2681be168c0dSopenharmony_cideleted file mode 100644 2682be168c0dSopenharmony_ciindex 89af36d6..00000000 2683be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/arithmetic.cc 2684be168c0dSopenharmony_ci+++ /dev/null 2685be168c0dSopenharmony_ci@@ -1,223 +0,0 @@ 2686be168c0dSopenharmony_ci-/** 2687be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2688be168c0dSopenharmony_ci- * 2689be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2690be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2691be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2692be168c0dSopenharmony_ci- * 2693be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2694be168c0dSopenharmony_ci- * 2695be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2696be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2697be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2698be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2699be168c0dSopenharmony_ci- * limitations under the License. 2700be168c0dSopenharmony_ci- */ 2701be168c0dSopenharmony_ci- 2702be168c0dSopenharmony_ci-#include "src/expression/ops/arithmetic.h" 2703be168c0dSopenharmony_ci-#include <memory> 2704be168c0dSopenharmony_ci-#include "src/expression/ops/reduce.h" 2705be168c0dSopenharmony_ci-#include "src/expression/ops/reshape.h" 2706be168c0dSopenharmony_ci-#include "src/expression/ops_utils.h" 2707be168c0dSopenharmony_ci-#include "src/expression/ops/arithmetic_self.h" 2708be168c0dSopenharmony_ci-#include "src/expression/ops.h" 2709be168c0dSopenharmony_ci-#include "nnacl/arithmetic_parameter.h" 2710be168c0dSopenharmony_ci-#include "src/expression/import.h" 2711be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 2712be168c0dSopenharmony_ci- 2713be168c0dSopenharmony_ci-namespace mindspore { 2714be168c0dSopenharmony_ci-namespace lite { 2715be168c0dSopenharmony_ci-// Common Arithmetic Functionality 2716be168c0dSopenharmony_ci-ArithmeticM::ArithmeticM(schema::PrimitiveType type) : Node() { 2717be168c0dSopenharmony_ci- auto op_param = malloc(sizeof(ArithmeticParameter)); 2718be168c0dSopenharmony_ci- if (op_param == nullptr) { 2719be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ActivationParameter"; 2720be168c0dSopenharmony_ci- return; 2721be168c0dSopenharmony_ci- } 2722be168c0dSopenharmony_ci- SetOpParam(op_param); 2723be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 2724be168c0dSopenharmony_ci- set_primitive(type); 2725be168c0dSopenharmony_ci-} 2726be168c0dSopenharmony_ci- 2727be168c0dSopenharmony_ci-std::vector<EXPR *> ArithmeticM::binop_grad_common(EXPR *x, EXPR *y, EXPR *dx, EXPR *dy) { 2728be168c0dSopenharmony_ci- auto shape_of_x = x->dims(); 2729be168c0dSopenharmony_ci- auto shape_of_y = y->dims(); 2730be168c0dSopenharmony_ci- auto reduce_dx = dx; 2731be168c0dSopenharmony_ci- auto reduce_dy = dy; 2732be168c0dSopenharmony_ci- auto rx = (BroadcastGradientArgs(shape_of_x, shape_of_y))(); 2733be168c0dSopenharmony_ci- if (rx[0].size()) { 2734be168c0dSopenharmony_ci- auto reduce_sum = NN::ReduceSum(false, rx[0]); 2735be168c0dSopenharmony_ci- PushOp(reduce_sum); 2736be168c0dSopenharmony_ci- reduce_dx = (*reduce_sum)({reduce_dx}).front(); 2737be168c0dSopenharmony_ci- auto reshape = NN::Reshape(shape_of_x); 2738be168c0dSopenharmony_ci- PushOp(reshape); 2739be168c0dSopenharmony_ci- reduce_dx = (*reshape)({reduce_dx}).front(); 2740be168c0dSopenharmony_ci- } 2741be168c0dSopenharmony_ci- if (rx[1].size()) { 2742be168c0dSopenharmony_ci- auto reduce_sum = NN::ReduceSum(false, rx[1]); 2743be168c0dSopenharmony_ci- PushOp(reduce_sum); 2744be168c0dSopenharmony_ci- reduce_dy = (*reduce_sum)({reduce_dy}).front(); 2745be168c0dSopenharmony_ci- auto reshape = NN::Reshape(shape_of_y); 2746be168c0dSopenharmony_ci- PushOp(reshape); 2747be168c0dSopenharmony_ci- reduce_dy = (*reshape)({reduce_dy}).front(); 2748be168c0dSopenharmony_ci- } 2749be168c0dSopenharmony_ci- std::vector<EXPR *> out = {reduce_dx, reduce_dy}; 2750be168c0dSopenharmony_ci- return out; 2751be168c0dSopenharmony_ci-} 2752be168c0dSopenharmony_ci- 2753be168c0dSopenharmony_ci-// Add Op 2754be168c0dSopenharmony_ci-AddM::AddM(int dummy) : ArithmeticM(schema::PrimitiveType_AddFusion) { set_name(UniqueName("Add")); } 2755be168c0dSopenharmony_ci- 2756be168c0dSopenharmony_ci-std::vector<EXPR *> AddM::Grad(EXPR *yt) { return binop_grad_common(input(0), input(1), yt, yt); } 2757be168c0dSopenharmony_ci- 2758be168c0dSopenharmony_ci-int AddM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 2759be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::AddFusionT; 2760be168c0dSopenharmony_ci- if (prim == nullptr) { 2761be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate prim"; 2762be168c0dSopenharmony_ci- return RET_ERROR; 2763be168c0dSopenharmony_ci- } 2764be168c0dSopenharmony_ci- prim->activation_type = schema::ActivationType_NO_ACTIVATION; 2765be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 2766be168c0dSopenharmony_ci- return RET_OK; 2767be168c0dSopenharmony_ci-} 2768be168c0dSopenharmony_ci- 2769be168c0dSopenharmony_ci-static ImportReg AddReg(schema::PrimitiveType_AddFusion, ReturnNode<AddM>); 2770be168c0dSopenharmony_ci- 2771be168c0dSopenharmony_ci-// Div op 2772be168c0dSopenharmony_ci-DivM::DivM(int dummy) : ArithmeticM(schema::PrimitiveType_RealDiv) { set_name(UniqueName("RealDiv")); } 2773be168c0dSopenharmony_ci-std::vector<EXPR *> DivM::Grad(EXPR *yt) { 2774be168c0dSopenharmony_ci- auto x = input(0); 2775be168c0dSopenharmony_ci- auto y = input(1); 2776be168c0dSopenharmony_ci- auto o = output(0); 2777be168c0dSopenharmony_ci- auto div_op = NN::Div(); 2778be168c0dSopenharmony_ci- if (div_op == nullptr) { 2779be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate div_op"; 2780be168c0dSopenharmony_ci- return {}; 2781be168c0dSopenharmony_ci- } 2782be168c0dSopenharmony_ci- PushOp(div_op); 2783be168c0dSopenharmony_ci- auto neg_op = NN::Neg(); 2784be168c0dSopenharmony_ci- if (neg_op == nullptr) { 2785be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate neg_op"; 2786be168c0dSopenharmony_ci- return {}; 2787be168c0dSopenharmony_ci- } 2788be168c0dSopenharmony_ci- PushOp(neg_op); 2789be168c0dSopenharmony_ci- auto mul_op = NN::Mul(); 2790be168c0dSopenharmony_ci- if (mul_op == nullptr) { 2791be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate mul_op"; 2792be168c0dSopenharmony_ci- return {}; 2793be168c0dSopenharmony_ci- } 2794be168c0dSopenharmony_ci- PushOp(mul_op); 2795be168c0dSopenharmony_ci- auto bc_x = (*div_op)({yt, y}).front(); 2796be168c0dSopenharmony_ci- auto bc_y = (*neg_op)((*mul_op)({bc_x, o})).front(); 2797be168c0dSopenharmony_ci- return binop_grad_common(x, y, bc_x, bc_y); 2798be168c0dSopenharmony_ci-} 2799be168c0dSopenharmony_ci-int DivM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 2800be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::RealDivT; 2801be168c0dSopenharmony_ci- if (prim == nullptr) { 2802be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 2803be168c0dSopenharmony_ci- return RET_ERROR; 2804be168c0dSopenharmony_ci- } 2805be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 2806be168c0dSopenharmony_ci- return RET_OK; 2807be168c0dSopenharmony_ci-} 2808be168c0dSopenharmony_ci-static ImportReg DivReg(schema::PrimitiveType_DivFusion, ReturnNode<DivM>); 2809be168c0dSopenharmony_ci- 2810be168c0dSopenharmony_ci-// Mul op 2811be168c0dSopenharmony_ci-MulM::MulM(int dummy) : ArithmeticM(schema::PrimitiveType_MulFusion) { set_name(UniqueName("Mul")); } 2812be168c0dSopenharmony_ci- 2813be168c0dSopenharmony_ci-std::vector<EXPR *> MulM::Grad(EXPR *yt) { 2814be168c0dSopenharmony_ci- auto mul_dx = NN::Mul(); 2815be168c0dSopenharmony_ci- if (mul_dx == nullptr) { 2816be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate mul dx"; 2817be168c0dSopenharmony_ci- return {}; 2818be168c0dSopenharmony_ci- } 2819be168c0dSopenharmony_ci- PushOp(mul_dx); 2820be168c0dSopenharmony_ci- auto mul_dy = NN::Mul(); 2821be168c0dSopenharmony_ci- if (mul_dy == nullptr) { 2822be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate mul_dy"; 2823be168c0dSopenharmony_ci- return {}; 2824be168c0dSopenharmony_ci- } 2825be168c0dSopenharmony_ci- PushOp(mul_dy); 2826be168c0dSopenharmony_ci- auto x = input(0); 2827be168c0dSopenharmony_ci- auto y = input(1); 2828be168c0dSopenharmony_ci- auto bc_dx = (*mul_dx)({y, yt}).front(); 2829be168c0dSopenharmony_ci- auto bc_dy = (*mul_dy)({x, yt}).front(); 2830be168c0dSopenharmony_ci- return binop_grad_common(x, y, bc_dx, bc_dy); 2831be168c0dSopenharmony_ci-} 2832be168c0dSopenharmony_ci- 2833be168c0dSopenharmony_ci-int MulM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 2834be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::MulFusionT; 2835be168c0dSopenharmony_ci- if (prim == nullptr) { 2836be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate prim"; 2837be168c0dSopenharmony_ci- return RET_ERROR; 2838be168c0dSopenharmony_ci- } 2839be168c0dSopenharmony_ci- prim->activation_type = schema::ActivationType_NO_ACTIVATION; 2840be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 2841be168c0dSopenharmony_ci- return RET_OK; 2842be168c0dSopenharmony_ci-} 2843be168c0dSopenharmony_ci-static ImportReg MulReg(schema::PrimitiveType_MulFusion, ReturnNode<MulM>); 2844be168c0dSopenharmony_ci- 2845be168c0dSopenharmony_ci-// Sub op 2846be168c0dSopenharmony_ci-SubM::SubM(int dummy) : ArithmeticM(schema::PrimitiveType_SubFusion) { set_name(UniqueName("Sub")); } 2847be168c0dSopenharmony_ci- 2848be168c0dSopenharmony_ci-std::vector<EXPR *> SubM::Grad(EXPR *yt) { 2849be168c0dSopenharmony_ci- auto x = input(0); 2850be168c0dSopenharmony_ci- auto y = input(1); 2851be168c0dSopenharmony_ci- auto neg = NN::Neg(); 2852be168c0dSopenharmony_ci- if (neg == nullptr) { 2853be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate neg"; 2854be168c0dSopenharmony_ci- return {}; 2855be168c0dSopenharmony_ci- } 2856be168c0dSopenharmony_ci- PushOp(neg); 2857be168c0dSopenharmony_ci- auto neg_grad = (*neg)({yt}).front(); 2858be168c0dSopenharmony_ci- return binop_grad_common(x, y, yt, neg_grad); 2859be168c0dSopenharmony_ci-} 2860be168c0dSopenharmony_ci-int SubM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 2861be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::SubFusionT; 2862be168c0dSopenharmony_ci- if (prim == nullptr) { 2863be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate prim"; 2864be168c0dSopenharmony_ci- return RET_ERROR; 2865be168c0dSopenharmony_ci- } 2866be168c0dSopenharmony_ci- prim->activation_type = schema::ActivationType_NO_ACTIVATION; 2867be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 2868be168c0dSopenharmony_ci- return RET_OK; 2869be168c0dSopenharmony_ci-} 2870be168c0dSopenharmony_ci-static ImportReg SubReg(schema::PrimitiveType_SubFusion, ReturnNode<SubM>); 2871be168c0dSopenharmony_ci- 2872be168c0dSopenharmony_ci-namespace NN { 2873be168c0dSopenharmony_ci-Node *Add() { 2874be168c0dSopenharmony_ci- auto a = new (std::nothrow) AddM(0); 2875be168c0dSopenharmony_ci- if (a == nullptr) { 2876be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate a"; 2877be168c0dSopenharmony_ci- return nullptr; 2878be168c0dSopenharmony_ci- } 2879be168c0dSopenharmony_ci- return a; 2880be168c0dSopenharmony_ci-} 2881be168c0dSopenharmony_ci-Node *Sub() { 2882be168c0dSopenharmony_ci- auto a = new (std::nothrow) SubM(0); 2883be168c0dSopenharmony_ci- if (a == nullptr) { 2884be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate a"; 2885be168c0dSopenharmony_ci- return nullptr; 2886be168c0dSopenharmony_ci- } 2887be168c0dSopenharmony_ci- return a; 2888be168c0dSopenharmony_ci-} 2889be168c0dSopenharmony_ci- 2890be168c0dSopenharmony_ci-Node *Mul() { 2891be168c0dSopenharmony_ci- auto a = new (std::nothrow) MulM(0); 2892be168c0dSopenharmony_ci- if (a == nullptr) { 2893be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate a"; 2894be168c0dSopenharmony_ci- return nullptr; 2895be168c0dSopenharmony_ci- } 2896be168c0dSopenharmony_ci- return a; 2897be168c0dSopenharmony_ci-} 2898be168c0dSopenharmony_ci-Node *Div() { 2899be168c0dSopenharmony_ci- auto a = new (std::nothrow) DivM(0); 2900be168c0dSopenharmony_ci- if (a == nullptr) { 2901be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate a"; 2902be168c0dSopenharmony_ci- return nullptr; 2903be168c0dSopenharmony_ci- } 2904be168c0dSopenharmony_ci- return a; 2905be168c0dSopenharmony_ci-} 2906be168c0dSopenharmony_ci-} // namespace NN 2907be168c0dSopenharmony_ci-} // namespace lite 2908be168c0dSopenharmony_ci-} // namespace mindspore 2909be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/arithmetic.h b/mindspore/lite/src/expression/ops/arithmetic.h 2910be168c0dSopenharmony_cideleted file mode 100644 2911be168c0dSopenharmony_ciindex b1509245..00000000 2912be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/arithmetic.h 2913be168c0dSopenharmony_ci+++ /dev/null 2914be168c0dSopenharmony_ci@@ -1,75 +0,0 @@ 2915be168c0dSopenharmony_ci-/** 2916be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2917be168c0dSopenharmony_ci- * 2918be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 2919be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 2920be168c0dSopenharmony_ci- * You may obtain a copy of the License at 2921be168c0dSopenharmony_ci- * 2922be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 2923be168c0dSopenharmony_ci- * 2924be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 2925be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 2926be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 2927be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 2928be168c0dSopenharmony_ci- * limitations under the License. 2929be168c0dSopenharmony_ci- */ 2930be168c0dSopenharmony_ci- 2931be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_ 2932be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_ 2933be168c0dSopenharmony_ci- 2934be168c0dSopenharmony_ci-#include <vector> 2935be168c0dSopenharmony_ci-#include <memory> 2936be168c0dSopenharmony_ci-#include "src/expression/net.h" 2937be168c0dSopenharmony_ci-#include "inner/model_generated.h" 2938be168c0dSopenharmony_ci- 2939be168c0dSopenharmony_ci-namespace mindspore { 2940be168c0dSopenharmony_ci-namespace lite { 2941be168c0dSopenharmony_ci-class ArithmeticM : public Node { 2942be168c0dSopenharmony_ci- public: 2943be168c0dSopenharmony_ci- ArithmeticM() = default; 2944be168c0dSopenharmony_ci- explicit ArithmeticM(schema::PrimitiveType type); 2945be168c0dSopenharmony_ci- 2946be168c0dSopenharmony_ci- protected: 2947be168c0dSopenharmony_ci- std::vector<EXPR *> binop_grad_common(EXPR *x, EXPR *y, EXPR *dx, EXPR *dy); 2948be168c0dSopenharmony_ci-}; 2949be168c0dSopenharmony_ci- 2950be168c0dSopenharmony_ci-class AddM : public ArithmeticM { 2951be168c0dSopenharmony_ci- public: 2952be168c0dSopenharmony_ci- AddM() = default; 2953be168c0dSopenharmony_ci- explicit AddM(int dummy); 2954be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 2955be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 2956be168c0dSopenharmony_ci-}; 2957be168c0dSopenharmony_ci- 2958be168c0dSopenharmony_ci-class DivM : public ArithmeticM { 2959be168c0dSopenharmony_ci- public: 2960be168c0dSopenharmony_ci- DivM() = default; 2961be168c0dSopenharmony_ci- explicit DivM(int dummy); 2962be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 2963be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 2964be168c0dSopenharmony_ci-}; 2965be168c0dSopenharmony_ci- 2966be168c0dSopenharmony_ci-class MulM : public ArithmeticM { 2967be168c0dSopenharmony_ci- public: 2968be168c0dSopenharmony_ci- MulM() = default; 2969be168c0dSopenharmony_ci- explicit MulM(int dummy); 2970be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 2971be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 2972be168c0dSopenharmony_ci-}; 2973be168c0dSopenharmony_ci- 2974be168c0dSopenharmony_ci-class SubM : public ArithmeticM { 2975be168c0dSopenharmony_ci- public: 2976be168c0dSopenharmony_ci- SubM() = default; 2977be168c0dSopenharmony_ci- explicit SubM(int dummy); 2978be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 2979be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 2980be168c0dSopenharmony_ci-}; 2981be168c0dSopenharmony_ci-namespace NN { 2982be168c0dSopenharmony_ci-Node *Add(); 2983be168c0dSopenharmony_ci-Node *Sub(); 2984be168c0dSopenharmony_ci-Node *Mul(); 2985be168c0dSopenharmony_ci-Node *Div(); 2986be168c0dSopenharmony_ci-} // namespace NN 2987be168c0dSopenharmony_ci-} // namespace lite 2988be168c0dSopenharmony_ci-} // namespace mindspore 2989be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_H_ 2990be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/arithmetic_self.cc b/mindspore/lite/src/expression/ops/arithmetic_self.cc 2991be168c0dSopenharmony_cideleted file mode 100644 2992be168c0dSopenharmony_ciindex e5a84f75..00000000 2993be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/arithmetic_self.cc 2994be168c0dSopenharmony_ci+++ /dev/null 2995be168c0dSopenharmony_ci@@ -1,72 +0,0 @@ 2996be168c0dSopenharmony_ci-/** 2997be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 2998be168c0dSopenharmony_ci- * 2999be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3000be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3001be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3002be168c0dSopenharmony_ci- * 3003be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3004be168c0dSopenharmony_ci- * 3005be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3006be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3007be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3008be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3009be168c0dSopenharmony_ci- * limitations under the License. 3010be168c0dSopenharmony_ci- */ 3011be168c0dSopenharmony_ci- 3012be168c0dSopenharmony_ci-#include "src/expression/ops/arithmetic_self.h" 3013be168c0dSopenharmony_ci-#include <memory> 3014be168c0dSopenharmony_ci-#include "src/expression/ops_utils.h" 3015be168c0dSopenharmony_ci-#include "src/expression/ops.h" 3016be168c0dSopenharmony_ci-#include "nnacl/arithmetic_self_parameter.h" 3017be168c0dSopenharmony_ci-#include "src/expression/import.h" 3018be168c0dSopenharmony_ci- 3019be168c0dSopenharmony_ci-namespace mindspore { 3020be168c0dSopenharmony_ci-namespace lite { 3021be168c0dSopenharmony_ci-// Common Arithmetic Self Functionality 3022be168c0dSopenharmony_ci-ArithmeticSelfM::ArithmeticSelfM(schema::PrimitiveType type) : Node() { 3023be168c0dSopenharmony_ci- auto op_param = malloc(sizeof(ArithmeticSelfParameter)); 3024be168c0dSopenharmony_ci- if (op_param == nullptr) { 3025be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ArithmeticSelfParameter"; 3026be168c0dSopenharmony_ci- return; 3027be168c0dSopenharmony_ci- } 3028be168c0dSopenharmony_ci- memset(op_param, 0, sizeof(ArithmeticSelfParameter)); 3029be168c0dSopenharmony_ci- SetOpParam(op_param); 3030be168c0dSopenharmony_ci- expr()->SetSize(C1NUM); 3031be168c0dSopenharmony_ci- set_primitive(type); 3032be168c0dSopenharmony_ci-} 3033be168c0dSopenharmony_ci- 3034be168c0dSopenharmony_ci-// NEG OP 3035be168c0dSopenharmony_ci-NegM::NegM(int dummy) : ArithmeticSelfM(schema::PrimitiveType_NegGrad) { set_name(UniqueName("Neg")); } 3036be168c0dSopenharmony_ci- 3037be168c0dSopenharmony_ci-std::vector<EXPR *> NegM::Grad(EXPR *yt) { 3038be168c0dSopenharmony_ci- auto grad_neg = new (std::nothrow) NegM(0); 3039be168c0dSopenharmony_ci- if (grad_neg == nullptr) { 3040be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate neg gradient"; 3041be168c0dSopenharmony_ci- return {}; 3042be168c0dSopenharmony_ci- } 3043be168c0dSopenharmony_ci- return (*grad_neg)({yt}); 3044be168c0dSopenharmony_ci-} 3045be168c0dSopenharmony_ci- 3046be168c0dSopenharmony_ci-int NegM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3047be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::NegT; 3048be168c0dSopenharmony_ci- if (prim == nullptr) { 3049be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 3050be168c0dSopenharmony_ci- return RET_ERROR; 3051be168c0dSopenharmony_ci- } 3052be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3053be168c0dSopenharmony_ci- return RET_OK; 3054be168c0dSopenharmony_ci-} 3055be168c0dSopenharmony_ci- 3056be168c0dSopenharmony_ci-namespace NN { 3057be168c0dSopenharmony_ci-Node *Neg() { 3058be168c0dSopenharmony_ci- auto a = new (std::nothrow) NegM(0); 3059be168c0dSopenharmony_ci- if (a == nullptr) { 3060be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate neg node"; 3061be168c0dSopenharmony_ci- return nullptr; 3062be168c0dSopenharmony_ci- } 3063be168c0dSopenharmony_ci- return a; 3064be168c0dSopenharmony_ci-} 3065be168c0dSopenharmony_ci-} // namespace NN 3066be168c0dSopenharmony_ci-} // namespace lite 3067be168c0dSopenharmony_ci-} // namespace mindspore 3068be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/arithmetic_self.h b/mindspore/lite/src/expression/ops/arithmetic_self.h 3069be168c0dSopenharmony_cideleted file mode 100644 3070be168c0dSopenharmony_ciindex e64ba024..00000000 3071be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/arithmetic_self.h 3072be168c0dSopenharmony_ci+++ /dev/null 3073be168c0dSopenharmony_ci@@ -1,46 +0,0 @@ 3074be168c0dSopenharmony_ci-/** 3075be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3076be168c0dSopenharmony_ci- * 3077be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3078be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3079be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3080be168c0dSopenharmony_ci- * 3081be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3082be168c0dSopenharmony_ci- * 3083be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3084be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3085be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3086be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3087be168c0dSopenharmony_ci- * limitations under the License. 3088be168c0dSopenharmony_ci- */ 3089be168c0dSopenharmony_ci- 3090be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_ 3091be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_ 3092be168c0dSopenharmony_ci- 3093be168c0dSopenharmony_ci-#include <vector> 3094be168c0dSopenharmony_ci-#include <memory> 3095be168c0dSopenharmony_ci-#include "src/expression/net.h" 3096be168c0dSopenharmony_ci-#include "inner/model_generated.h" 3097be168c0dSopenharmony_ci- 3098be168c0dSopenharmony_ci-namespace mindspore { 3099be168c0dSopenharmony_ci-namespace lite { 3100be168c0dSopenharmony_ci-class ArithmeticSelfM : public Node { 3101be168c0dSopenharmony_ci- public: 3102be168c0dSopenharmony_ci- explicit ArithmeticSelfM(schema::PrimitiveType type); 3103be168c0dSopenharmony_ci- 3104be168c0dSopenharmony_ci- protected: 3105be168c0dSopenharmony_ci-}; 3106be168c0dSopenharmony_ci- 3107be168c0dSopenharmony_ci-class NegM : public ArithmeticSelfM { 3108be168c0dSopenharmony_ci- public: 3109be168c0dSopenharmony_ci- explicit NegM(int dummy); 3110be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 3111be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3112be168c0dSopenharmony_ci-}; 3113be168c0dSopenharmony_ci- 3114be168c0dSopenharmony_ci-namespace NN { 3115be168c0dSopenharmony_ci-Node *Neg(); 3116be168c0dSopenharmony_ci-} 3117be168c0dSopenharmony_ci-} // namespace lite 3118be168c0dSopenharmony_ci-} // namespace mindspore 3119be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ARITHMETIC_SELF_H_ 3120be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/assign.cc b/mindspore/lite/src/expression/ops/assign.cc 3121be168c0dSopenharmony_cideleted file mode 100644 3122be168c0dSopenharmony_ciindex acf10950..00000000 3123be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/assign.cc 3124be168c0dSopenharmony_ci+++ /dev/null 3125be168c0dSopenharmony_ci@@ -1,60 +0,0 @@ 3126be168c0dSopenharmony_ci-/** 3127be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3128be168c0dSopenharmony_ci- * 3129be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3130be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3131be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3132be168c0dSopenharmony_ci- * 3133be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3134be168c0dSopenharmony_ci- * 3135be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3136be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3137be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3138be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3139be168c0dSopenharmony_ci- * limitations under the License. 3140be168c0dSopenharmony_ci- */ 3141be168c0dSopenharmony_ci- 3142be168c0dSopenharmony_ci-#include "src/expression/ops/assign.h" 3143be168c0dSopenharmony_ci-#include <memory> 3144be168c0dSopenharmony_ci-#include "nnacl/reshape_parameter.h" 3145be168c0dSopenharmony_ci-#include "src/expression/import.h" 3146be168c0dSopenharmony_ci- 3147be168c0dSopenharmony_ci-namespace mindspore { 3148be168c0dSopenharmony_ci-namespace lite { 3149be168c0dSopenharmony_ci-AssignM::AssignM(int dummy) { 3150be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(OpParameter)); 3151be168c0dSopenharmony_ci- if (op_param == nullptr) { 3152be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ReshapeParameter"; 3153be168c0dSopenharmony_ci- return; 3154be168c0dSopenharmony_ci- } 3155be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 3156be168c0dSopenharmony_ci- SetOpParam(op_param); 3157be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Assign); 3158be168c0dSopenharmony_ci- set_name(UniqueName("Assign")); 3159be168c0dSopenharmony_ci-} 3160be168c0dSopenharmony_ci- 3161be168c0dSopenharmony_ci-int AssignM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3162be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::AssignT; 3163be168c0dSopenharmony_ci- if (prim == nullptr) { 3164be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 3165be168c0dSopenharmony_ci- return RET_ERROR; 3166be168c0dSopenharmony_ci- } 3167be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3168be168c0dSopenharmony_ci- return RET_OK; 3169be168c0dSopenharmony_ci-} 3170be168c0dSopenharmony_ci- 3171be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_Reshape, ReturnNode<AssignM>); 3172be168c0dSopenharmony_ci- 3173be168c0dSopenharmony_ci-namespace NN { 3174be168c0dSopenharmony_ci-Node *Assign() { 3175be168c0dSopenharmony_ci- auto node = new (std::nothrow) AssignM(0); 3176be168c0dSopenharmony_ci- if (node == nullptr) { 3177be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate node"; 3178be168c0dSopenharmony_ci- return nullptr; 3179be168c0dSopenharmony_ci- } 3180be168c0dSopenharmony_ci- node->set_name(Node::UniqueName("Assign")); 3181be168c0dSopenharmony_ci- return node; 3182be168c0dSopenharmony_ci-} 3183be168c0dSopenharmony_ci-} // namespace NN 3184be168c0dSopenharmony_ci-} // namespace lite 3185be168c0dSopenharmony_ci-} // namespace mindspore 3186be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/assign.h b/mindspore/lite/src/expression/ops/assign.h 3187be168c0dSopenharmony_cideleted file mode 100644 3188be168c0dSopenharmony_ciindex 0dfd2c67..00000000 3189be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/assign.h 3190be168c0dSopenharmony_ci+++ /dev/null 3191be168c0dSopenharmony_ci@@ -1,35 +0,0 @@ 3192be168c0dSopenharmony_ci-/** 3193be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3194be168c0dSopenharmony_ci- * 3195be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3196be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3197be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3198be168c0dSopenharmony_ci- * 3199be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3200be168c0dSopenharmony_ci- * 3201be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3202be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3203be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3204be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3205be168c0dSopenharmony_ci- * limitations under the License. 3206be168c0dSopenharmony_ci- */ 3207be168c0dSopenharmony_ci- 3208be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_ 3209be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_ 3210be168c0dSopenharmony_ci- 3211be168c0dSopenharmony_ci-#include <vector> 3212be168c0dSopenharmony_ci-#include <memory> 3213be168c0dSopenharmony_ci-#include "src/expression/net.h" 3214be168c0dSopenharmony_ci-#include "inner/model_generated.h" 3215be168c0dSopenharmony_ci- 3216be168c0dSopenharmony_ci-namespace mindspore { 3217be168c0dSopenharmony_ci-namespace lite { 3218be168c0dSopenharmony_ci-class AssignM : public Node { 3219be168c0dSopenharmony_ci- public: 3220be168c0dSopenharmony_ci- AssignM() = default; 3221be168c0dSopenharmony_ci- explicit AssignM(int dummy); 3222be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3223be168c0dSopenharmony_ci-}; 3224be168c0dSopenharmony_ci-} // namespace lite 3225be168c0dSopenharmony_ci-} // namespace mindspore 3226be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_ASSIGN_H_ 3227be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/batchnorm.cc b/mindspore/lite/src/expression/ops/batchnorm.cc 3228be168c0dSopenharmony_cideleted file mode 100644 3229be168c0dSopenharmony_ciindex ec2bc7b5..00000000 3230be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/batchnorm.cc 3231be168c0dSopenharmony_ci+++ /dev/null 3232be168c0dSopenharmony_ci@@ -1,135 +0,0 @@ 3233be168c0dSopenharmony_ci-/** 3234be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3235be168c0dSopenharmony_ci- * 3236be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3237be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3238be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3239be168c0dSopenharmony_ci- * 3240be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3241be168c0dSopenharmony_ci- * 3242be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3243be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3244be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3245be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3246be168c0dSopenharmony_ci- * limitations under the License. 3247be168c0dSopenharmony_ci- */ 3248be168c0dSopenharmony_ci- 3249be168c0dSopenharmony_ci-#include "src/expression/ops/batchnorm.h" 3250be168c0dSopenharmony_ci-#include <memory> 3251be168c0dSopenharmony_ci-#include "nnacl/batchnorm_parameter.h" 3252be168c0dSopenharmony_ci-#include "nnacl/fp32_grad/batch_norm_grad.h" 3253be168c0dSopenharmony_ci-#include "src/expression/import.h" 3254be168c0dSopenharmony_ci-#include "src/expression/ops.h" 3255be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 3256be168c0dSopenharmony_ci- 3257be168c0dSopenharmony_ci-namespace mindspore { 3258be168c0dSopenharmony_ci-namespace lite { 3259be168c0dSopenharmony_ci-BatchNorm2dM::BatchNorm2dM(int outp, float momentum, float epsilon) { 3260be168c0dSopenharmony_ci- constexpr int bn_inputs = 5; 3261be168c0dSopenharmony_ci- constexpr int bn_outputs = 5; 3262be168c0dSopenharmony_ci- 3263be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(BatchNormParameter)); 3264be168c0dSopenharmony_ci- if (op_param == nullptr) { 3265be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate BatchNormParameter"; 3266be168c0dSopenharmony_ci- return; 3267be168c0dSopenharmony_ci- } 3268be168c0dSopenharmony_ci- expr()->SetSize(bn_inputs); 3269be168c0dSopenharmony_ci- set_name(UniqueName("BatchNorm2D")); 3270be168c0dSopenharmony_ci- auto bn_param = reinterpret_cast<BatchNormParameter *>(op_param); 3271be168c0dSopenharmony_ci- bn_param->channel_ = outp; 3272be168c0dSopenharmony_ci- bn_param->momentum_ = momentum; 3273be168c0dSopenharmony_ci- bn_param->epsilon_ = epsilon; 3274be168c0dSopenharmony_ci- SetOpParam(op_param); 3275be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_FusedBatchNorm); 3276be168c0dSopenharmony_ci- std::vector<int> dims = {outp}; 3277be168c0dSopenharmony_ci- auto scale = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ONES, "scale"); 3278be168c0dSopenharmony_ci- expr()->set_params(C1NUM, scale); 3279be168c0dSopenharmony_ci- auto offset = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "offset"); 3280be168c0dSopenharmony_ci- expr()->set_params(C2NUM, offset); 3281be168c0dSopenharmony_ci- auto mean = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "mean"); 3282be168c0dSopenharmony_ci- expr()->set_params(C3NUM, mean); 3283be168c0dSopenharmony_ci- auto var = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::ONES, "var"); 3284be168c0dSopenharmony_ci- expr()->set_params(C4NUM, var); 3285be168c0dSopenharmony_ci- SetOutputs(bn_outputs); 3286be168c0dSopenharmony_ci- SetLearn(); 3287be168c0dSopenharmony_ci-} 3288be168c0dSopenharmony_ci- 3289be168c0dSopenharmony_ci-BatchNorm2dGradM::BatchNorm2dGradM(BatchNorm2dM *bn_node) : Node() { 3290be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(BNGradParameter)); 3291be168c0dSopenharmony_ci- if (op_param == nullptr) { 3292be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate BNGradParameter"; 3293be168c0dSopenharmony_ci- return; 3294be168c0dSopenharmony_ci- } 3295be168c0dSopenharmony_ci- expr()->SetSize(C6NUM); 3296be168c0dSopenharmony_ci- set_name(bn_node->name() + "/" + kGradName + "/bnGrad"); 3297be168c0dSopenharmony_ci- auto bn_grad_param = reinterpret_cast<BNGradParameter *>(op_param); 3298be168c0dSopenharmony_ci- auto bn_param = reinterpret_cast<BatchNormParameter *>(bn_node->OpParam()); 3299be168c0dSopenharmony_ci- bn_param->is_training_ = true; 3300be168c0dSopenharmony_ci- bn_grad_param->epsilon_ = bn_param->epsilon_; 3301be168c0dSopenharmony_ci- bn_grad_param->is_training_ = true; 3302be168c0dSopenharmony_ci- SetOpParam(op_param); 3303be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_BatchNormGrad); 3304be168c0dSopenharmony_ci- EXPR e(this); 3305be168c0dSopenharmony_ci- e.SetSize(0); 3306be168c0dSopenharmony_ci- // Dgamma 3307be168c0dSopenharmony_ci- expr_.emplace_back(e); 3308be168c0dSopenharmony_ci- // Doffset 3309be168c0dSopenharmony_ci- expr_.emplace_back(e); 3310be168c0dSopenharmony_ci-} 3311be168c0dSopenharmony_ci- 3312be168c0dSopenharmony_ci-std::vector<EXPR *> BatchNorm2dM::Grad(EXPR *yt) { 3313be168c0dSopenharmony_ci- auto bn_grad_node = new (std::nothrow) BatchNorm2dGradM(this); 3314be168c0dSopenharmony_ci- if (bn_grad_node == nullptr) { 3315be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate batchnorm grad"; 3316be168c0dSopenharmony_ci- return {}; 3317be168c0dSopenharmony_ci- } 3318be168c0dSopenharmony_ci- PushOp(bn_grad_node); 3319be168c0dSopenharmony_ci- auto bn_grad = (*bn_grad_node)({yt, input(0), output(1), output(3), output(4), output(2)}); 3320be168c0dSopenharmony_ci- return bn_grad; 3321be168c0dSopenharmony_ci-} 3322be168c0dSopenharmony_ci- 3323be168c0dSopenharmony_ci-int BatchNorm2dM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3324be168c0dSopenharmony_ci- auto bn_param = reinterpret_cast<const BatchNormParameter *>(OpParam()); 3325be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::FusedBatchNormT; 3326be168c0dSopenharmony_ci- if (prim == nullptr) { 3327be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 3328be168c0dSopenharmony_ci- return RET_ERROR; 3329be168c0dSopenharmony_ci- } 3330be168c0dSopenharmony_ci- prim->epsilon = bn_param->epsilon_; 3331be168c0dSopenharmony_ci- prim->momentum = bn_param->momentum_; 3332be168c0dSopenharmony_ci- prim->mode = (bn_param->is_training_ == false) ? 0 : 1; 3333be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3334be168c0dSopenharmony_ci- return RET_OK; 3335be168c0dSopenharmony_ci-} 3336be168c0dSopenharmony_ci- 3337be168c0dSopenharmony_ci-void BatchNorm2dM::SetLearn() { 3338be168c0dSopenharmony_ci- AddLearn(input(C1NUM)->node()); 3339be168c0dSopenharmony_ci- AddLearn(input(C2NUM)->node()); 3340be168c0dSopenharmony_ci-} 3341be168c0dSopenharmony_ci- 3342be168c0dSopenharmony_ci-int BatchNorm2dGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3343be168c0dSopenharmony_ci- auto param = reinterpret_cast<const BNGradParameter *>(OpParam()); 3344be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::BatchNormGradT; 3345be168c0dSopenharmony_ci- if (prim == nullptr) { 3346be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 3347be168c0dSopenharmony_ci- return RET_ERROR; 3348be168c0dSopenharmony_ci- } 3349be168c0dSopenharmony_ci- prim->epsilon = param->epsilon_; 3350be168c0dSopenharmony_ci- prim->is_training = param->is_training_; 3351be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3352be168c0dSopenharmony_ci- return RET_OK; 3353be168c0dSopenharmony_ci-} 3354be168c0dSopenharmony_ci- 3355be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_FusedBatchNorm, ReturnNode<BatchNorm2dM>); 3356be168c0dSopenharmony_ci-namespace NN { 3357be168c0dSopenharmony_ci-Node *BatchNorm2D(int outp, float momentum, float epsilon) { 3358be168c0dSopenharmony_ci- auto node = new (std::nothrow) BatchNorm2dM(outp, momentum, epsilon); 3359be168c0dSopenharmony_ci- if (node == nullptr) { 3360be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate node"; 3361be168c0dSopenharmony_ci- return nullptr; 3362be168c0dSopenharmony_ci- } 3363be168c0dSopenharmony_ci- return node; 3364be168c0dSopenharmony_ci-} 3365be168c0dSopenharmony_ci-} // namespace NN 3366be168c0dSopenharmony_ci-} // namespace lite 3367be168c0dSopenharmony_ci-} // namespace mindspore 3368be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/batchnorm.h b/mindspore/lite/src/expression/ops/batchnorm.h 3369be168c0dSopenharmony_cideleted file mode 100644 3370be168c0dSopenharmony_ciindex 2891f35a..00000000 3371be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/batchnorm.h 3372be168c0dSopenharmony_ci+++ /dev/null 3373be168c0dSopenharmony_ci@@ -1,43 +0,0 @@ 3374be168c0dSopenharmony_ci-/** 3375be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3376be168c0dSopenharmony_ci- * 3377be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3378be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3379be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3380be168c0dSopenharmony_ci- * 3381be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3382be168c0dSopenharmony_ci- * 3383be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3384be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3385be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3386be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3387be168c0dSopenharmony_ci- * limitations under the License. 3388be168c0dSopenharmony_ci- */ 3389be168c0dSopenharmony_ci- 3390be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_ 3391be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_ 3392be168c0dSopenharmony_ci- 3393be168c0dSopenharmony_ci-#include <vector> 3394be168c0dSopenharmony_ci-#include <memory> 3395be168c0dSopenharmony_ci-#include "src/expression/net.h" 3396be168c0dSopenharmony_ci-#include "inner/model_generated.h" 3397be168c0dSopenharmony_ci- 3398be168c0dSopenharmony_ci-namespace mindspore { 3399be168c0dSopenharmony_ci-namespace lite { 3400be168c0dSopenharmony_ci-class BatchNorm2dM : public Node { 3401be168c0dSopenharmony_ci- public: 3402be168c0dSopenharmony_ci- BatchNorm2dM() = default; 3403be168c0dSopenharmony_ci- BatchNorm2dM(int outp, float momentum, float epsilon); 3404be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 3405be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3406be168c0dSopenharmony_ci- void SetLearn() override; 3407be168c0dSopenharmony_ci-}; 3408be168c0dSopenharmony_ci- 3409be168c0dSopenharmony_ci-class BatchNorm2dGradM : public Node { 3410be168c0dSopenharmony_ci- public: 3411be168c0dSopenharmony_ci- explicit BatchNorm2dGradM(BatchNorm2dM *bn_node); 3412be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3413be168c0dSopenharmony_ci-}; 3414be168c0dSopenharmony_ci-} // namespace lite 3415be168c0dSopenharmony_ci-} // namespace mindspore 3416be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_BATCHNORM_H_ 3417be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/biasadd.cc b/mindspore/lite/src/expression/ops/biasadd.cc 3418be168c0dSopenharmony_cideleted file mode 100644 3419be168c0dSopenharmony_ciindex c6088584..00000000 3420be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/biasadd.cc 3421be168c0dSopenharmony_ci+++ /dev/null 3422be168c0dSopenharmony_ci@@ -1,93 +0,0 @@ 3423be168c0dSopenharmony_ci-/** 3424be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3425be168c0dSopenharmony_ci- * 3426be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3427be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3428be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3429be168c0dSopenharmony_ci- * 3430be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3431be168c0dSopenharmony_ci- * 3432be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3433be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3434be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3435be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3436be168c0dSopenharmony_ci- * limitations under the License. 3437be168c0dSopenharmony_ci- */ 3438be168c0dSopenharmony_ci- 3439be168c0dSopenharmony_ci-#include "src/expression/ops/biasadd.h" 3440be168c0dSopenharmony_ci-#include "src/expression/ops/transpose.h" 3441be168c0dSopenharmony_ci-#include "nnacl/arithmetic_parameter.h" 3442be168c0dSopenharmony_ci-#include "src/expression/import.h" 3443be168c0dSopenharmony_ci- 3444be168c0dSopenharmony_ci-namespace mindspore { 3445be168c0dSopenharmony_ci-namespace lite { 3446be168c0dSopenharmony_ci-BiasAddM::BiasAddM(Format data_format) { 3447be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(ArithmeticParameter)); 3448be168c0dSopenharmony_ci- if (op_param == nullptr) { 3449be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ConvParameter"; 3450be168c0dSopenharmony_ci- return; 3451be168c0dSopenharmony_ci- } 3452be168c0dSopenharmony_ci- auto bias_param = reinterpret_cast<ArithmeticParameter *>(op_param); 3453be168c0dSopenharmony_ci- SetOpParam(bias_param); 3454be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_BiasAdd); 3455be168c0dSopenharmony_ci-} 3456be168c0dSopenharmony_ci- 3457be168c0dSopenharmony_ci-std::vector<EXPR *> BiasAddM::construct(const std::vector<EXPR *> &inputs) { 3458be168c0dSopenharmony_ci- auto x = Node::construct(inputs); 3459be168c0dSopenharmony_ci- AddLearn(inputs.at(C1NUM)->node()); 3460be168c0dSopenharmony_ci- return x; 3461be168c0dSopenharmony_ci-} 3462be168c0dSopenharmony_ci- 3463be168c0dSopenharmony_ci-void BiasAddM::SetLearn() { AddLearn(input(C1NUM)->node()); } 3464be168c0dSopenharmony_ci- 3465be168c0dSopenharmony_ci-std::vector<EXPR *> BiasAddM::Grad(EXPR *yt) { 3466be168c0dSopenharmony_ci- auto in = yt; 3467be168c0dSopenharmony_ci- if (yt->format() != NHWC && yt->dims().size() == C4NUM) { 3468be168c0dSopenharmony_ci- in = TransposeM::TransposeCHW2HWC(yt); 3469be168c0dSopenharmony_ci- in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name()); 3470be168c0dSopenharmony_ci- PushOp(in->node()); 3471be168c0dSopenharmony_ci- } 3472be168c0dSopenharmony_ci- auto grad_node = new (std::nothrow) BiasAddGradM(*this); 3473be168c0dSopenharmony_ci- if (grad_node == nullptr) { 3474be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannon allocate Bias Grad"; 3475be168c0dSopenharmony_ci- return {}; 3476be168c0dSopenharmony_ci- } 3477be168c0dSopenharmony_ci- PushOp(grad_node); 3478be168c0dSopenharmony_ci- auto bias_grad = (*grad_node)({in}); 3479be168c0dSopenharmony_ci- return {in, bias_grad.front()}; 3480be168c0dSopenharmony_ci-} 3481be168c0dSopenharmony_ci-int BiasAddM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3482be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::BiasAddT; 3483be168c0dSopenharmony_ci- if (prim == nullptr) { 3484be168c0dSopenharmony_ci- MS_LOG(ERROR) << "cannot allocate prim"; 3485be168c0dSopenharmony_ci- return RET_ERROR; 3486be168c0dSopenharmony_ci- } 3487be168c0dSopenharmony_ci- prim->format = static_cast<schema::Format>(KHWC); 3488be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3489be168c0dSopenharmony_ci- return RET_OK; 3490be168c0dSopenharmony_ci-} 3491be168c0dSopenharmony_ci- 3492be168c0dSopenharmony_ci-BiasAddGradM::BiasAddGradM(const BiasAddM &bias) { 3493be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(OpParameter)); 3494be168c0dSopenharmony_ci- if (op_param == nullptr) { 3495be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate op_param"; 3496be168c0dSopenharmony_ci- } 3497be168c0dSopenharmony_ci- SetOpParam(op_param); 3498be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_BiasAddGrad); 3499be168c0dSopenharmony_ci-} 3500be168c0dSopenharmony_ci- 3501be168c0dSopenharmony_ci-int BiasAddGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3502be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::BiasAddGradT; 3503be168c0dSopenharmony_ci- if (prim == nullptr) { 3504be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 3505be168c0dSopenharmony_ci- return RET_ERROR; 3506be168c0dSopenharmony_ci- } 3507be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3508be168c0dSopenharmony_ci- return RET_OK; 3509be168c0dSopenharmony_ci-} 3510be168c0dSopenharmony_ci- 3511be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_BiasAdd, ReturnNode<BiasAddM>); 3512be168c0dSopenharmony_ci- 3513be168c0dSopenharmony_ci-namespace NN {} 3514be168c0dSopenharmony_ci-} // namespace lite 3515be168c0dSopenharmony_ci-} // namespace mindspore 3516be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/biasadd.h b/mindspore/lite/src/expression/ops/biasadd.h 3517be168c0dSopenharmony_cideleted file mode 100644 3518be168c0dSopenharmony_ciindex bb23e0ff..00000000 3519be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/biasadd.h 3520be168c0dSopenharmony_ci+++ /dev/null 3521be168c0dSopenharmony_ci@@ -1,44 +0,0 @@ 3522be168c0dSopenharmony_ci-/** 3523be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3524be168c0dSopenharmony_ci- * 3525be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3526be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3527be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3528be168c0dSopenharmony_ci- * 3529be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3530be168c0dSopenharmony_ci- * 3531be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3532be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3533be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3534be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3535be168c0dSopenharmony_ci- * limitations under the License. 3536be168c0dSopenharmony_ci- */ 3537be168c0dSopenharmony_ci- 3538be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_ 3539be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_ 3540be168c0dSopenharmony_ci- 3541be168c0dSopenharmony_ci-#include <vector> 3542be168c0dSopenharmony_ci-#include <memory> 3543be168c0dSopenharmony_ci-#include "src/expression/net.h" 3544be168c0dSopenharmony_ci-#include "inner/model_generated.h" 3545be168c0dSopenharmony_ci- 3546be168c0dSopenharmony_ci-namespace mindspore { 3547be168c0dSopenharmony_ci-namespace lite { 3548be168c0dSopenharmony_ci-class BiasAddM : public Node { 3549be168c0dSopenharmony_ci- public: 3550be168c0dSopenharmony_ci- BiasAddM() = default; 3551be168c0dSopenharmony_ci- explicit BiasAddM(Format data_format); 3552be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override; 3553be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *yt) override; 3554be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3555be168c0dSopenharmony_ci- void SetLearn() override; 3556be168c0dSopenharmony_ci-}; 3557be168c0dSopenharmony_ci- 3558be168c0dSopenharmony_ci-class BiasAddGradM : public Node { 3559be168c0dSopenharmony_ci- public: 3560be168c0dSopenharmony_ci- explicit BiasAddGradM(const BiasAddM &bias); 3561be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3562be168c0dSopenharmony_ci-}; 3563be168c0dSopenharmony_ci-} // namespace lite 3564be168c0dSopenharmony_ci-} // namespace mindspore 3565be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_BIASADD_H_ 3566be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/conv.cc b/mindspore/lite/src/expression/ops/conv.cc 3567be168c0dSopenharmony_cideleted file mode 100644 3568be168c0dSopenharmony_ciindex 669fd2b9..00000000 3569be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/conv.cc 3570be168c0dSopenharmony_ci+++ /dev/null 3571be168c0dSopenharmony_ci@@ -1,241 +0,0 @@ 3572be168c0dSopenharmony_ci-/** 3573be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3574be168c0dSopenharmony_ci- * 3575be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3576be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3577be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3578be168c0dSopenharmony_ci- * 3579be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3580be168c0dSopenharmony_ci- * 3581be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3582be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3583be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3584be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3585be168c0dSopenharmony_ci- * limitations under the License. 3586be168c0dSopenharmony_ci- */ 3587be168c0dSopenharmony_ci- 3588be168c0dSopenharmony_ci-#include "src/expression/ops/conv.h" 3589be168c0dSopenharmony_ci-#include <memory> 3590be168c0dSopenharmony_ci-#include "src/expression/ops/biasadd.h" 3591be168c0dSopenharmony_ci-#include "src/expression/ops/depend.h" 3592be168c0dSopenharmony_ci-#include "src/expression/ops/transpose.h" 3593be168c0dSopenharmony_ci-#include "nnacl/conv_parameter.h" 3594be168c0dSopenharmony_ci-#include "inner/model_generated.h" 3595be168c0dSopenharmony_ci-#include "src/expression/import.h" 3596be168c0dSopenharmony_ci-#include "src/expression/ops.h" 3597be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 3598be168c0dSopenharmony_ci- 3599be168c0dSopenharmony_ci-namespace mindspore { 3600be168c0dSopenharmony_ci-namespace lite { 3601be168c0dSopenharmony_ci-ConvM::ConvM(const ConvConfig &cfg) : Node() { 3602be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(ConvParameter)); 3603be168c0dSopenharmony_ci- if (op_param == nullptr) { 3604be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ConvParameter"; 3605be168c0dSopenharmony_ci- return; 3606be168c0dSopenharmony_ci- } 3607be168c0dSopenharmony_ci- SetOpParam(op_param); 3608be168c0dSopenharmony_ci- ConvParameter *conv_param = reinterpret_cast<ConvParameter *>(OpParam()); 3609be168c0dSopenharmony_ci- conv_param->input_channel_ = cfg.in_channel_; 3610be168c0dSopenharmony_ci- conv_param->output_channel_ = cfg.out_channel_; 3611be168c0dSopenharmony_ci- conv_param->kernel_h_ = cfg.kernel_size_[0]; 3612be168c0dSopenharmony_ci- conv_param->kernel_w_ = cfg.kernel_size_[1]; 3613be168c0dSopenharmony_ci- conv_param->stride_h_ = cfg.stride_[0]; 3614be168c0dSopenharmony_ci- conv_param->stride_w_ = cfg.stride_[1]; 3615be168c0dSopenharmony_ci- auto pad_mode = GetMode(cfg.pad_mode_); 3616be168c0dSopenharmony_ci- if (pad_mode == -1) { 3617be168c0dSopenharmony_ci- MS_LOG(ERROR) << "bad pad mode"; 3618be168c0dSopenharmony_ci- return; 3619be168c0dSopenharmony_ci- } 3620be168c0dSopenharmony_ci- conv_param->pad_mode_ = static_cast<PadType>(pad_mode); 3621be168c0dSopenharmony_ci- conv_param->pad_u_ = cfg.padding_[C0NUM]; 3622be168c0dSopenharmony_ci- conv_param->pad_d_ = cfg.padding_[C1NUM]; 3623be168c0dSopenharmony_ci- conv_param->pad_l_ = cfg.padding_[C2NUM]; 3624be168c0dSopenharmony_ci- conv_param->pad_r_ = cfg.padding_[C3NUM]; 3625be168c0dSopenharmony_ci- conv_param->dilation_h_ = cfg.dilation_[C0NUM]; 3626be168c0dSopenharmony_ci- conv_param->dilation_w_ = cfg.dilation_[C1NUM]; 3627be168c0dSopenharmony_ci- conv_param->group_ = cfg.group_; 3628be168c0dSopenharmony_ci- conv_param->out_format_ = NHWC; 3629be168c0dSopenharmony_ci- conv_param->act_type_ = ActType_No; 3630be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 3631be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Conv2DFusion); 3632be168c0dSopenharmony_ci- set_name(UniqueName("Conv")); 3633be168c0dSopenharmony_ci- Param::Mode mode = Param::String2Enum(cfg.weight_init_); 3634be168c0dSopenharmony_ci- std::vector<int> dims = {conv_param->output_channel_, conv_param->kernel_h_, conv_param->kernel_w_, 3635be168c0dSopenharmony_ci- conv_param->input_channel_ / conv_param->group_}; 3636be168c0dSopenharmony_ci- auto w = CreateWeights(dims, kNumberTypeFloat32, KHWC, mode, "weights"); 3637be168c0dSopenharmony_ci- expr()->set_params(C1NUM, w); 3638be168c0dSopenharmony_ci- if (cfg.has_bias) { 3639be168c0dSopenharmony_ci- bias_ = new (std::nothrow) BiasAddM(KHWC); 3640be168c0dSopenharmony_ci- if (bias_ == nullptr) { 3641be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate bias"; 3642be168c0dSopenharmony_ci- return; 3643be168c0dSopenharmony_ci- } 3644be168c0dSopenharmony_ci- bias_->update_name(name()); 3645be168c0dSopenharmony_ci- std::vector<int> dim_bias = {conv_param->output_channel_}; 3646be168c0dSopenharmony_ci- wbias_ = CreateWeights(dim_bias, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "weights"); 3647be168c0dSopenharmony_ci- AddLearn(wbias_->node()); 3648be168c0dSopenharmony_ci- PushOp(bias_); 3649be168c0dSopenharmony_ci- } 3650be168c0dSopenharmony_ci- SetLearn(); 3651be168c0dSopenharmony_ci-} 3652be168c0dSopenharmony_ci- 3653be168c0dSopenharmony_ci-std::vector<EXPR *> ConvM::construct(const std::vector<EXPR *> &inputs) { 3654be168c0dSopenharmony_ci- auto in = inputs; 3655be168c0dSopenharmony_ci- auto x = in.front(); 3656be168c0dSopenharmony_ci- if (x->format() != NHWC && x->dims().size() == C4NUM) { 3657be168c0dSopenharmony_ci- x = TransposeM::TransposeCHW2HWC(x); 3658be168c0dSopenharmony_ci- x->node()->set_name(name() + "/" + x->node()->name()); 3659be168c0dSopenharmony_ci- PushOp(x->node()); 3660be168c0dSopenharmony_ci- in.at(0) = x; 3661be168c0dSopenharmony_ci- } 3662be168c0dSopenharmony_ci- auto y = Node::construct(in); 3663be168c0dSopenharmony_ci- if (bias_ != nullptr) { 3664be168c0dSopenharmony_ci- y = (*bias_)({y.front(), wbias_}); 3665be168c0dSopenharmony_ci- } 3666be168c0dSopenharmony_ci- return y; 3667be168c0dSopenharmony_ci-} 3668be168c0dSopenharmony_ci- 3669be168c0dSopenharmony_ci-void ConvM::SetLearn() { AddLearn(input(C1NUM)->node()); } 3670be168c0dSopenharmony_ci- 3671be168c0dSopenharmony_ci-int ConvM::GetMode(std::string mode) { 3672be168c0dSopenharmony_ci- const std::vector<std::string> list = {"pad", "same", "valid"}; 3673be168c0dSopenharmony_ci- auto itr = std::find(list.begin(), list.end(), mode); 3674be168c0dSopenharmony_ci- if (itr == list.end()) { 3675be168c0dSopenharmony_ci- MS_LOG(ERROR) << "illegal mode" << mode; 3676be168c0dSopenharmony_ci- return -1; 3677be168c0dSopenharmony_ci- } 3678be168c0dSopenharmony_ci- return std::distance(list.begin(), itr); 3679be168c0dSopenharmony_ci-} 3680be168c0dSopenharmony_ci- 3681be168c0dSopenharmony_ci-std::vector<EXPR *> ConvM::Grad(EXPR *yt) { 3682be168c0dSopenharmony_ci- // Generate Input Grad 3683be168c0dSopenharmony_ci- EXPR *in = yt; 3684be168c0dSopenharmony_ci- if (yt->format() != NHWC && yt->dims().size() == C4NUM) { 3685be168c0dSopenharmony_ci- in = TransposeM::TransposeCHW2HWC(yt); 3686be168c0dSopenharmony_ci- in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name()); 3687be168c0dSopenharmony_ci- PushOp(in->node()); 3688be168c0dSopenharmony_ci- } 3689be168c0dSopenharmony_ci- auto inGrad = new (std::nothrow) ConvInputGradM(this); 3690be168c0dSopenharmony_ci- if (inGrad == nullptr) { 3691be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate convolution input grad"; 3692be168c0dSopenharmony_ci- return {}; 3693be168c0dSopenharmony_ci- } 3694be168c0dSopenharmony_ci- PushOp(inGrad); 3695be168c0dSopenharmony_ci- auto ig = (*inGrad)({in, input(1), inGrad->input(2)}); 3696be168c0dSopenharmony_ci- // Execution Control Flow ! 3697be168c0dSopenharmony_ci- auto depend = NN::Depend(); 3698be168c0dSopenharmony_ci- if (depend == nullptr) { 3699be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate depend"; 3700be168c0dSopenharmony_ci- return {}; 3701be168c0dSopenharmony_ci- } 3702be168c0dSopenharmony_ci- PushOp(depend); 3703be168c0dSopenharmony_ci- depend->update_name(name()); 3704be168c0dSopenharmony_ci- auto de = (*depend)({inGrad->expr()}); 3705be168c0dSopenharmony_ci- // Generate Filter Grad 3706be168c0dSopenharmony_ci- auto filterGrad = new (std::nothrow) ConvFilterGradM(this); 3707be168c0dSopenharmony_ci- if (filterGrad == nullptr) { 3708be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate convolution filter grad"; 3709be168c0dSopenharmony_ci- return {}; 3710be168c0dSopenharmony_ci- } 3711be168c0dSopenharmony_ci- PushOp(filterGrad); 3712be168c0dSopenharmony_ci- filterGrad->update_name(name()); 3713be168c0dSopenharmony_ci- auto fg = (*filterGrad)({in, input(0), filterGrad->input(2), de[0]}); 3714be168c0dSopenharmony_ci- std::vector<EXPR *> res = {ig[0], fg[0]}; 3715be168c0dSopenharmony_ci- return res; 3716be168c0dSopenharmony_ci-} 3717be168c0dSopenharmony_ci- 3718be168c0dSopenharmony_ci-int ConvM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3719be168c0dSopenharmony_ci- auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam()); 3720be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::Conv2DFusionT; 3721be168c0dSopenharmony_ci- if (prim == nullptr) { 3722be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 3723be168c0dSopenharmony_ci- return RET_ERROR; 3724be168c0dSopenharmony_ci- } 3725be168c0dSopenharmony_ci- prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_); 3726be168c0dSopenharmony_ci- prim->format = static_cast<schema::Format>(conv_param->out_format_); 3727be168c0dSopenharmony_ci- prim->stride = {conv_param->stride_h_, conv_param->stride_w_}; 3728be168c0dSopenharmony_ci- prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_}; 3729be168c0dSopenharmony_ci- prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_}; 3730be168c0dSopenharmony_ci- prim->out_channel = conv_param->output_channel_; 3731be168c0dSopenharmony_ci- prim->in_channel = conv_param->input_channel_; 3732be168c0dSopenharmony_ci- prim->group = conv_param->group_; 3733be168c0dSopenharmony_ci- prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_); 3734be168c0dSopenharmony_ci- prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_}; 3735be168c0dSopenharmony_ci- prim->mode = 1; 3736be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3737be168c0dSopenharmony_ci- return RET_OK; 3738be168c0dSopenharmony_ci-} 3739be168c0dSopenharmony_ci- 3740be168c0dSopenharmony_ci-ConvInputGradM::ConvInputGradM(ConvM *conv_node) : Node() { 3741be168c0dSopenharmony_ci- CloneOpParam<ConvParameter>(conv_node->OpParam()); 3742be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Conv2DBackpropInputFusion); 3743be168c0dSopenharmony_ci- set_name(kGradName + "/conv2DBackpropInput"); 3744be168c0dSopenharmony_ci- expr()->SetSize(C3NUM); 3745be168c0dSopenharmony_ci- auto const x = conv_node->input(0); 3746be168c0dSopenharmony_ci- CreateConstTensor(C2NUM, {static_cast<int32_t>(x->dims().size())}, kNumberTypeInt32, KHWC, "shape", x->dims().data()); 3747be168c0dSopenharmony_ci-} 3748be168c0dSopenharmony_ci- 3749be168c0dSopenharmony_ci-int ConvInputGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3750be168c0dSopenharmony_ci- auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam()); 3751be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::Conv2DBackpropInputFusionT; 3752be168c0dSopenharmony_ci- if (prim == nullptr) { 3753be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 3754be168c0dSopenharmony_ci- return RET_ERROR; 3755be168c0dSopenharmony_ci- } 3756be168c0dSopenharmony_ci- prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_); 3757be168c0dSopenharmony_ci- prim->format = static_cast<schema::Format>(conv_param->out_format_); 3758be168c0dSopenharmony_ci- prim->stride = {conv_param->stride_h_, conv_param->stride_w_}; 3759be168c0dSopenharmony_ci- prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_}; 3760be168c0dSopenharmony_ci- prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_}; 3761be168c0dSopenharmony_ci- prim->out_channel = conv_param->output_channel_; 3762be168c0dSopenharmony_ci- prim->in_channel = conv_param->input_channel_; 3763be168c0dSopenharmony_ci- prim->group = conv_param->group_; 3764be168c0dSopenharmony_ci- prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_); 3765be168c0dSopenharmony_ci- prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_}; 3766be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3767be168c0dSopenharmony_ci- return RET_OK; 3768be168c0dSopenharmony_ci-} 3769be168c0dSopenharmony_ci- 3770be168c0dSopenharmony_ci-ConvFilterGradM::ConvFilterGradM(ConvM *conv_node) : Node() { 3771be168c0dSopenharmony_ci- CloneOpParam<ConvParameter>(conv_node->OpParam()); 3772be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Conv2DBackpropFilterFusion); 3773be168c0dSopenharmony_ci- set_name(kGradName + "/conv2DBackpropFilter"); 3774be168c0dSopenharmony_ci- expr()->SetSize(C4NUM); 3775be168c0dSopenharmony_ci- auto w = conv_node->input(1); 3776be168c0dSopenharmony_ci- CreateConstTensor(C2NUM, {static_cast<int32_t>(w->dims().size())}, kNumberTypeInt32, KHWC, "shape", w->dims().data()); 3777be168c0dSopenharmony_ci-} 3778be168c0dSopenharmony_ci-int ConvFilterGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 3779be168c0dSopenharmony_ci- auto conv_param = reinterpret_cast<const ConvParameter *>(OpParam()); 3780be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::Conv2DBackpropFilterFusionT; 3781be168c0dSopenharmony_ci- if (prim == nullptr) { 3782be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 3783be168c0dSopenharmony_ci- return RET_ERROR; 3784be168c0dSopenharmony_ci- } 3785be168c0dSopenharmony_ci- prim->activation_type = static_cast<schema::ActivationType>(conv_param->act_type_); 3786be168c0dSopenharmony_ci- prim->format = static_cast<schema::Format>(conv_param->out_format_); 3787be168c0dSopenharmony_ci- prim->stride = {conv_param->stride_h_, conv_param->stride_w_}; 3788be168c0dSopenharmony_ci- prim->kernel_size = {conv_param->kernel_h_, conv_param->kernel_w_}; 3789be168c0dSopenharmony_ci- prim->dilation = {conv_param->dilation_h_, conv_param->dilation_w_}; 3790be168c0dSopenharmony_ci- prim->out_channel = conv_param->output_channel_; 3791be168c0dSopenharmony_ci- prim->in_channel = conv_param->input_channel_; 3792be168c0dSopenharmony_ci- prim->group = conv_param->group_; 3793be168c0dSopenharmony_ci- prim->pad_mode = static_cast<schema::PadMode>(conv_param->pad_mode_); 3794be168c0dSopenharmony_ci- prim->pad_list = {conv_param->pad_u_, conv_param->pad_d_, conv_param->pad_l_, conv_param->pad_r_}; 3795be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 3796be168c0dSopenharmony_ci- return RET_OK; 3797be168c0dSopenharmony_ci-} 3798be168c0dSopenharmony_ci- 3799be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_Conv2DFusion, ReturnNode<ConvM>); 3800be168c0dSopenharmony_ci- 3801be168c0dSopenharmony_ci-namespace NN { 3802be168c0dSopenharmony_ci-Node *Conv2D(const ConvConfig &cfg) { 3803be168c0dSopenharmony_ci- auto c = new (std::nothrow) ConvM(cfg); 3804be168c0dSopenharmony_ci- if (c == nullptr) { 3805be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate Convolution object"; 3806be168c0dSopenharmony_ci- return nullptr; 3807be168c0dSopenharmony_ci- } 3808be168c0dSopenharmony_ci- return c; 3809be168c0dSopenharmony_ci-} 3810be168c0dSopenharmony_ci-} // namespace NN 3811be168c0dSopenharmony_ci-} // namespace lite 3812be168c0dSopenharmony_ci-} // namespace mindspore 3813be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/conv.h b/mindspore/lite/src/expression/ops/conv.h 3814be168c0dSopenharmony_cideleted file mode 100644 3815be168c0dSopenharmony_ciindex 32fc6632..00000000 3816be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/conv.h 3817be168c0dSopenharmony_ci+++ /dev/null 3818be168c0dSopenharmony_ci@@ -1,58 +0,0 @@ 3819be168c0dSopenharmony_ci-/** 3820be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3821be168c0dSopenharmony_ci- * 3822be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3823be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3824be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3825be168c0dSopenharmony_ci- * 3826be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3827be168c0dSopenharmony_ci- * 3828be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3829be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3830be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3831be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3832be168c0dSopenharmony_ci- * limitations under the License. 3833be168c0dSopenharmony_ci- */ 3834be168c0dSopenharmony_ci- 3835be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_ 3836be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_ 3837be168c0dSopenharmony_ci- 3838be168c0dSopenharmony_ci-#include <vector> 3839be168c0dSopenharmony_ci-#include <memory> 3840be168c0dSopenharmony_ci-#include <string> 3841be168c0dSopenharmony_ci-#include "src/expression/cfg.h" 3842be168c0dSopenharmony_ci-#include "src/expression/node.h" 3843be168c0dSopenharmony_ci- 3844be168c0dSopenharmony_ci-namespace mindspore { 3845be168c0dSopenharmony_ci-namespace lite { 3846be168c0dSopenharmony_ci-class ConvM : public Node { 3847be168c0dSopenharmony_ci- public: 3848be168c0dSopenharmony_ci- ConvM() = default; 3849be168c0dSopenharmony_ci- explicit ConvM(const ConvConfig &cfg); 3850be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override; 3851be168c0dSopenharmony_ci- Param *weight() override { return input(1)->node()->data(); } 3852be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 3853be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3854be168c0dSopenharmony_ci- void SetLearn() override; 3855be168c0dSopenharmony_ci- 3856be168c0dSopenharmony_ci- private: 3857be168c0dSopenharmony_ci- int GetMode(std::string mode); 3858be168c0dSopenharmony_ci- Node *bias_ = nullptr; 3859be168c0dSopenharmony_ci- EXPR *wbias_ = nullptr; 3860be168c0dSopenharmony_ci-}; 3861be168c0dSopenharmony_ci- 3862be168c0dSopenharmony_ci-class ConvInputGradM : public Node { 3863be168c0dSopenharmony_ci- public: 3864be168c0dSopenharmony_ci- explicit ConvInputGradM(ConvM *conv_node); 3865be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3866be168c0dSopenharmony_ci-}; 3867be168c0dSopenharmony_ci- 3868be168c0dSopenharmony_ci-class ConvFilterGradM : public Node { 3869be168c0dSopenharmony_ci- public: 3870be168c0dSopenharmony_ci- explicit ConvFilterGradM(ConvM *conv_node); 3871be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 3872be168c0dSopenharmony_ci-}; 3873be168c0dSopenharmony_ci-} // namespace lite 3874be168c0dSopenharmony_ci-} // namespace mindspore 3875be168c0dSopenharmony_ci- 3876be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_CONV_H_ 3877be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/dense.cc b/mindspore/lite/src/expression/ops/dense.cc 3878be168c0dSopenharmony_cideleted file mode 100644 3879be168c0dSopenharmony_ciindex 9d1df46f..00000000 3880be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/dense.cc 3881be168c0dSopenharmony_ci+++ /dev/null 3882be168c0dSopenharmony_ci@@ -1,151 +0,0 @@ 3883be168c0dSopenharmony_ci-/** 3884be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 3885be168c0dSopenharmony_ci- * 3886be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 3887be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 3888be168c0dSopenharmony_ci- * You may obtain a copy of the License at 3889be168c0dSopenharmony_ci- * 3890be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 3891be168c0dSopenharmony_ci- * 3892be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 3893be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 3894be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 3895be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 3896be168c0dSopenharmony_ci- * limitations under the License. 3897be168c0dSopenharmony_ci- */ 3898be168c0dSopenharmony_ci- 3899be168c0dSopenharmony_ci-#include "src/expression/ops/dense.h" 3900be168c0dSopenharmony_ci-#include <memory> 3901be168c0dSopenharmony_ci-#include "include/api/cfg.h" 3902be168c0dSopenharmony_ci-#include "src/expression/ops/biasadd.h" 3903be168c0dSopenharmony_ci-#include "src/expression/ops/depend.h" 3904be168c0dSopenharmony_ci-#include "src/expression/ops.h" 3905be168c0dSopenharmony_ci-#include "nnacl/matmul_parameter.h" 3906be168c0dSopenharmony_ci-#include "src/expression/import.h" 3907be168c0dSopenharmony_ci-#include "inner/model_generated.h" 3908be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 3909be168c0dSopenharmony_ci- 3910be168c0dSopenharmony_ci-namespace mindspore { 3911be168c0dSopenharmony_ci-namespace lite { 3912be168c0dSopenharmony_ci-DenseM::DenseM(const DenseConfig &cfg) : Node() { 3913be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(MatMulParameter)); 3914be168c0dSopenharmony_ci- if (op_param == nullptr) { 3915be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate MatMulParameter"; 3916be168c0dSopenharmony_ci- return; 3917be168c0dSopenharmony_ci- } 3918be168c0dSopenharmony_ci- set_name(UniqueName("Dense")); 3919be168c0dSopenharmony_ci- SetOpParam(op_param); 3920be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 3921be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_MatMulFusion); 3922be168c0dSopenharmony_ci- auto param = reinterpret_cast<MatMulParameter *>(opParam_.get()); 3923be168c0dSopenharmony_ci- param->row_ = cfg.out_channels_; 3924be168c0dSopenharmony_ci- param->col_ = cfg.in_channels_; 3925be168c0dSopenharmony_ci- param->a_transpose_ = false; 3926be168c0dSopenharmony_ci- param->b_transpose_ = true; 3927be168c0dSopenharmony_ci- std::vector<int> dims = {param->row_, param->col_}; 3928be168c0dSopenharmony_ci- auto w = Node::CreateWeights(dims, kNumberTypeFloat32, KHWC, Param::Mode::NORMAL, "weights"); 3929be168c0dSopenharmony_ci- expr()->set_params(C1NUM, w); 3930be168c0dSopenharmony_ci- if (cfg.has_bias_) { 3931be168c0dSopenharmony_ci- wbias_ = CreateWeights({cfg.out_channels_}, kNumberTypeFloat32, KHWC, Param::Mode::ZEROS, "bias_weights"); 3932be168c0dSopenharmony_ci- bias_ = new (std::nothrow) BiasAddM(KHWC); 3933be168c0dSopenharmony_ci- if (bias_ == nullptr) { 3934be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate bias"; 3935be168c0dSopenharmony_ci- return; 3936be168c0dSopenharmony_ci- } 3937be168c0dSopenharmony_ci- bias_->update_name(name()); 3938be168c0dSopenharmony_ci- AddLearn(wbias_->node()); 3939be168c0dSopenharmony_ci- PushOp(bias_); 3940be168c0dSopenharmony_ci- } 3941be168c0dSopenharmony_ci- SetLearn(); 3942be168c0dSopenharmony_ci-} 3943be168c0dSopenharmony_ci- 3944be168c0dSopenharmony_ci-std::vector<EXPR *> DenseM::construct(const std::vector<EXPR *> &inputs) { 3945be168c0dSopenharmony_ci- auto x = Node::construct(inputs); 3946be168c0dSopenharmony_ci- if (bias_ != nullptr) { 3947be168c0dSopenharmony_ci- x = (*bias_)({x.front(), wbias_}); 3948be168c0dSopenharmony_ci- } 3949be168c0dSopenharmony_ci- return x; 3950be168c0dSopenharmony_ci-} 3951be168c0dSopenharmony_ci- 3952be168c0dSopenharmony_ci-std::vector<EXPR *> DenseM::Grad(EXPR *yt) { 3953be168c0dSopenharmony_ci- auto src_param = reinterpret_cast<MatMulParameter *>(opParam_.get()); 3954be168c0dSopenharmony_ci- bool ta = src_param->a_transpose_; 3955be168c0dSopenharmony_ci- bool tb = src_param->b_transpose_; 3956be168c0dSopenharmony_ci- 3957be168c0dSopenharmony_ci- // dx grad op 3958be168c0dSopenharmony_ci- auto dxGrad = new (std::nothrow) DenseM(); 3959be168c0dSopenharmony_ci- if (dxGrad == nullptr) { 3960be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate dxGrad "; 3961be168c0dSopenharmony_ci- return {}; 3962be168c0dSopenharmony_ci- } 3963be168c0dSopenharmony_ci- PushOp(dxGrad); 3964be168c0dSopenharmony_ci- dxGrad->CloneOpParam<MatMulParameter>(opParam_); 3965be168c0dSopenharmony_ci- dxGrad->set_primitive(schema::PrimitiveType_MatMulFusion); 3966be168c0dSopenharmony_ci- auto dxGradParam = reinterpret_cast<MatMulParameter *>(dxGrad->OpParam()); 3967be168c0dSopenharmony_ci- dxGradParam->a_transpose_ = (ta && tb); 3968be168c0dSopenharmony_ci- dxGradParam->b_transpose_ = (ta || !tb); 3969be168c0dSopenharmony_ci- dxGrad->set_name(name() + kGradName + "/dxGrad"); 3970be168c0dSopenharmony_ci- EXPR *dx = nullptr; 3971be168c0dSopenharmony_ci- if (ta) { 3972be168c0dSopenharmony_ci- dx = (*dxGrad)({input(1), yt}).front(); 3973be168c0dSopenharmony_ci- } else { 3974be168c0dSopenharmony_ci- dx = (*dxGrad)({yt, input(1)}).front(); 3975be168c0dSopenharmony_ci- } 3976be168c0dSopenharmony_ci- // Control execution flow 3977be168c0dSopenharmony_ci- auto depend = NN::Depend(); 3978be168c0dSopenharmony_ci- if (depend == nullptr) { 3979be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate depend "; 3980be168c0dSopenharmony_ci- return {}; 3981be168c0dSopenharmony_ci- } 3982be168c0dSopenharmony_ci- PushOp(depend); 3983be168c0dSopenharmony_ci- auto de = (*depend)({dxGrad->expr()}).front(); 3984be168c0dSopenharmony_ci- 3985be168c0dSopenharmony_ci- // dw grad op 3986be168c0dSopenharmony_ci- auto dwGrad = new (std::nothrow) DenseM(); 3987be168c0dSopenharmony_ci- if (dwGrad == nullptr) { 3988be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate dwGrad "; 3989be168c0dSopenharmony_ci- return {}; 3990be168c0dSopenharmony_ci- } 3991be168c0dSopenharmony_ci- PushOp(dwGrad); 3992be168c0dSopenharmony_ci- dwGrad->CloneOpParam<MatMulParameter>(opParam_); 3993be168c0dSopenharmony_ci- dwGrad->set_primitive(schema::PrimitiveType_MatMulFusion); 3994be168c0dSopenharmony_ci- auto dwGradParam = reinterpret_cast<MatMulParameter *>(dwGrad->OpParam()); 3995be168c0dSopenharmony_ci- dwGradParam->a_transpose_ = (!ta || tb); 3996be168c0dSopenharmony_ci- dwGradParam->b_transpose_ = ta && tb; 3997be168c0dSopenharmony_ci- dwGrad->set_name(name() + kGradName + "/dwGrad"); 3998be168c0dSopenharmony_ci- EXPR *dw = nullptr; 3999be168c0dSopenharmony_ci- if (tb) { 4000be168c0dSopenharmony_ci- dw = (*dwGrad)({yt, input(0), de}).front(); 4001be168c0dSopenharmony_ci- } else { 4002be168c0dSopenharmony_ci- dw = (*dwGrad)({input(0), yt, de}).front(); 4003be168c0dSopenharmony_ci- } 4004be168c0dSopenharmony_ci- return {dx, dw}; 4005be168c0dSopenharmony_ci-} 4006be168c0dSopenharmony_ci-int DenseM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4007be168c0dSopenharmony_ci- auto dense_param = reinterpret_cast<const MatMulParameter *>(OpParam()); 4008be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::MatMulFusionT; 4009be168c0dSopenharmony_ci- if (prim == nullptr) { 4010be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 4011be168c0dSopenharmony_ci- return RET_ERROR; 4012be168c0dSopenharmony_ci- } 4013be168c0dSopenharmony_ci- prim->transpose_a = dense_param->a_transpose_; 4014be168c0dSopenharmony_ci- prim->transpose_b = dense_param->b_transpose_; 4015be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 4016be168c0dSopenharmony_ci- return RET_OK; 4017be168c0dSopenharmony_ci-} 4018be168c0dSopenharmony_ci- 4019be168c0dSopenharmony_ci-void DenseM::SetLearn() { AddLearn(input(C1NUM)->node()); } 4020be168c0dSopenharmony_ci- 4021be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_MatMulFusion, ReturnNode<DenseM>); 4022be168c0dSopenharmony_ci- 4023be168c0dSopenharmony_ci-namespace NN { 4024be168c0dSopenharmony_ci-Node *Dense(const DenseConfig &cfg) { 4025be168c0dSopenharmony_ci- auto l = new (std::nothrow) DenseM(cfg); 4026be168c0dSopenharmony_ci- if (l == nullptr) { 4027be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate Dense object"; 4028be168c0dSopenharmony_ci- } 4029be168c0dSopenharmony_ci- return l; 4030be168c0dSopenharmony_ci-} 4031be168c0dSopenharmony_ci-} // namespace NN 4032be168c0dSopenharmony_ci-} // namespace lite 4033be168c0dSopenharmony_ci-} // namespace mindspore 4034be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/dense.h b/mindspore/lite/src/expression/ops/dense.h 4035be168c0dSopenharmony_cideleted file mode 100644 4036be168c0dSopenharmony_ciindex 10734336..00000000 4037be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/dense.h 4038be168c0dSopenharmony_ci+++ /dev/null 4039be168c0dSopenharmony_ci@@ -1,44 +0,0 @@ 4040be168c0dSopenharmony_ci-/** 4041be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4042be168c0dSopenharmony_ci- * 4043be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4044be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4045be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4046be168c0dSopenharmony_ci- * 4047be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4048be168c0dSopenharmony_ci- * 4049be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4050be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4051be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4052be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4053be168c0dSopenharmony_ci- * limitations under the License. 4054be168c0dSopenharmony_ci- */ 4055be168c0dSopenharmony_ci- 4056be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_ 4057be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_ 4058be168c0dSopenharmony_ci- 4059be168c0dSopenharmony_ci-#include <vector> 4060be168c0dSopenharmony_ci-#include <memory> 4061be168c0dSopenharmony_ci-#include "src/expression/node.h" 4062be168c0dSopenharmony_ci-#include "src/expression/cfg.h" 4063be168c0dSopenharmony_ci-#include "inner/model_generated.h" 4064be168c0dSopenharmony_ci- 4065be168c0dSopenharmony_ci-namespace mindspore { 4066be168c0dSopenharmony_ci-namespace lite { 4067be168c0dSopenharmony_ci-class DenseM : public Node { 4068be168c0dSopenharmony_ci- public: 4069be168c0dSopenharmony_ci- DenseM() = default; 4070be168c0dSopenharmony_ci- explicit DenseM(const DenseConfig &cfg); 4071be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override; 4072be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 4073be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4074be168c0dSopenharmony_ci- void SetLearn() override; 4075be168c0dSopenharmony_ci- 4076be168c0dSopenharmony_ci- private: 4077be168c0dSopenharmony_ci- Param *weight() override { return input(1)->node()->data(); } 4078be168c0dSopenharmony_ci- Node *bias_{nullptr}; 4079be168c0dSopenharmony_ci- EXPR *wbias_{nullptr}; 4080be168c0dSopenharmony_ci-}; 4081be168c0dSopenharmony_ci-} // namespace lite 4082be168c0dSopenharmony_ci-} // namespace mindspore 4083be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DENSE_H_ 4084be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/depend.cc b/mindspore/lite/src/expression/ops/depend.cc 4085be168c0dSopenharmony_cideleted file mode 100644 4086be168c0dSopenharmony_ciindex c6aee153..00000000 4087be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/depend.cc 4088be168c0dSopenharmony_ci+++ /dev/null 4089be168c0dSopenharmony_ci@@ -1,43 +0,0 @@ 4090be168c0dSopenharmony_ci-/** 4091be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4092be168c0dSopenharmony_ci- * 4093be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4094be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4095be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4096be168c0dSopenharmony_ci- * 4097be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4098be168c0dSopenharmony_ci- * 4099be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4100be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4101be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4102be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4103be168c0dSopenharmony_ci- * limitations under the License. 4104be168c0dSopenharmony_ci- */ 4105be168c0dSopenharmony_ci- 4106be168c0dSopenharmony_ci-#include "src/expression/ops/depend.h" 4107be168c0dSopenharmony_ci-#include "inner/model_generated.h" 4108be168c0dSopenharmony_ci- 4109be168c0dSopenharmony_ci-namespace mindspore { 4110be168c0dSopenharmony_ci-namespace lite { 4111be168c0dSopenharmony_ci-DependM::DependM() : Node() { 4112be168c0dSopenharmony_ci- auto param = calloc(1, sizeof(OpParameter)); 4113be168c0dSopenharmony_ci- if (param == nullptr) { 4114be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate parameter"; 4115be168c0dSopenharmony_ci- return; 4116be168c0dSopenharmony_ci- } 4117be168c0dSopenharmony_ci- SetOpParam(param); 4118be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Depend); 4119be168c0dSopenharmony_ci- set_name(UniqueName("Depend")); 4120be168c0dSopenharmony_ci-} 4121be168c0dSopenharmony_ci-namespace NN { 4122be168c0dSopenharmony_ci-Node *Depend() { 4123be168c0dSopenharmony_ci- auto d = new (std::nothrow) DependM(); 4124be168c0dSopenharmony_ci- if (d == nullptr) { 4125be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate depend object"; 4126be168c0dSopenharmony_ci- return nullptr; 4127be168c0dSopenharmony_ci- } 4128be168c0dSopenharmony_ci- return d; 4129be168c0dSopenharmony_ci-} 4130be168c0dSopenharmony_ci-} // namespace NN 4131be168c0dSopenharmony_ci-} // namespace lite 4132be168c0dSopenharmony_ci-} // namespace mindspore 4133be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/dropout.cc b/mindspore/lite/src/expression/ops/dropout.cc 4134be168c0dSopenharmony_cideleted file mode 100644 4135be168c0dSopenharmony_ciindex e49bf07a..00000000 4136be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/dropout.cc 4137be168c0dSopenharmony_ci+++ /dev/null 4138be168c0dSopenharmony_ci@@ -1,91 +0,0 @@ 4139be168c0dSopenharmony_ci-/** 4140be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4141be168c0dSopenharmony_ci- * 4142be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4143be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4144be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4145be168c0dSopenharmony_ci- * 4146be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4147be168c0dSopenharmony_ci- * 4148be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4149be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4150be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4151be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4152be168c0dSopenharmony_ci- * limitations under the License. 4153be168c0dSopenharmony_ci- */ 4154be168c0dSopenharmony_ci- 4155be168c0dSopenharmony_ci-#include "src/expression/ops/dropout.h" 4156be168c0dSopenharmony_ci-#include <vector> 4157be168c0dSopenharmony_ci-#include "nnacl/fp32_grad/dropout_parameter.h" 4158be168c0dSopenharmony_ci-#include "inner/model_generated.h" 4159be168c0dSopenharmony_ci-#include "src/expression/import.h" 4160be168c0dSopenharmony_ci-#include "src/expression/ops.h" 4161be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 4162be168c0dSopenharmony_ci- 4163be168c0dSopenharmony_ci-namespace mindspore { 4164be168c0dSopenharmony_ci-namespace lite { 4165be168c0dSopenharmony_ci-DropOutM::DropOutM(float ratio) { 4166be168c0dSopenharmony_ci- auto param = reinterpret_cast<DropoutParameter *>(calloc(1, sizeof(DropoutParameter))); 4167be168c0dSopenharmony_ci- if (param == nullptr) { 4168be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate parameter"; 4169be168c0dSopenharmony_ci- return; 4170be168c0dSopenharmony_ci- } 4171be168c0dSopenharmony_ci- param->ratio_ = ratio; 4172be168c0dSopenharmony_ci- SetOpParam(param); 4173be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Dropout); 4174be168c0dSopenharmony_ci- set_name(UniqueName("DropOut")); 4175be168c0dSopenharmony_ci-} 4176be168c0dSopenharmony_ci- 4177be168c0dSopenharmony_ci-std::vector<EXPR *> DropOutM::Grad(EXPR *yt) { 4178be168c0dSopenharmony_ci- auto inGrad = new (std::nothrow) DropOutGradM(this); 4179be168c0dSopenharmony_ci- if (inGrad == nullptr) { 4180be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate drop grad"; 4181be168c0dSopenharmony_ci- return {}; 4182be168c0dSopenharmony_ci- } 4183be168c0dSopenharmony_ci- return (*inGrad)({yt, expr()}); 4184be168c0dSopenharmony_ci-} 4185be168c0dSopenharmony_ci- 4186be168c0dSopenharmony_ci-int DropOutM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4187be168c0dSopenharmony_ci- auto param = reinterpret_cast<const DropoutParameter *>(OpParam()); 4188be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::DropoutT; 4189be168c0dSopenharmony_ci- if (prim == nullptr) { 4190be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 4191be168c0dSopenharmony_ci- return RET_ERROR; 4192be168c0dSopenharmony_ci- } 4193be168c0dSopenharmony_ci- prim->keep_prob = param->ratio_; 4194be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 4195be168c0dSopenharmony_ci- return RET_OK; 4196be168c0dSopenharmony_ci-} 4197be168c0dSopenharmony_ci- 4198be168c0dSopenharmony_ci-DropOutGradM::DropOutGradM(DropOutM *node) { 4199be168c0dSopenharmony_ci- CloneOpParam<DropoutParameter>(node->OpParam()); 4200be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_DropoutGrad); 4201be168c0dSopenharmony_ci- set_name(kGradName + "/DropOutGrad"); 4202be168c0dSopenharmony_ci-} 4203be168c0dSopenharmony_ci- 4204be168c0dSopenharmony_ci-int DropOutGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4205be168c0dSopenharmony_ci- auto param = reinterpret_cast<const DropoutParameter *>(OpParam()); 4206be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::DropoutGradT; 4207be168c0dSopenharmony_ci- if (prim == nullptr) { 4208be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 4209be168c0dSopenharmony_ci- return RET_ERROR; 4210be168c0dSopenharmony_ci- } 4211be168c0dSopenharmony_ci- prim->keep_prob = param->ratio_; 4212be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 4213be168c0dSopenharmony_ci- return RET_OK; 4214be168c0dSopenharmony_ci-} 4215be168c0dSopenharmony_ci- 4216be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_Dropout, ReturnNode<DropOutM>); 4217be168c0dSopenharmony_ci- 4218be168c0dSopenharmony_ci-namespace NN { 4219be168c0dSopenharmony_ci-Node *DropOut(float ratio) { 4220be168c0dSopenharmony_ci- auto node = new (std::nothrow) DropOutM(ratio); 4221be168c0dSopenharmony_ci- if (node == nullptr) { 4222be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate dropout node"; 4223be168c0dSopenharmony_ci- return nullptr; 4224be168c0dSopenharmony_ci- } 4225be168c0dSopenharmony_ci- return node; 4226be168c0dSopenharmony_ci-} 4227be168c0dSopenharmony_ci-} // namespace NN 4228be168c0dSopenharmony_ci-} // namespace lite 4229be168c0dSopenharmony_ci-} // namespace mindspore 4230be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/dropout.h b/mindspore/lite/src/expression/ops/dropout.h 4231be168c0dSopenharmony_cideleted file mode 100644 4232be168c0dSopenharmony_ciindex dce87f18..00000000 4233be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/dropout.h 4234be168c0dSopenharmony_ci+++ /dev/null 4235be168c0dSopenharmony_ci@@ -1,42 +0,0 @@ 4236be168c0dSopenharmony_ci-/** 4237be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4238be168c0dSopenharmony_ci- * 4239be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4240be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4241be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4242be168c0dSopenharmony_ci- * 4243be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4244be168c0dSopenharmony_ci- * 4245be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4246be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4247be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4248be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4249be168c0dSopenharmony_ci- * limitations under the License. 4250be168c0dSopenharmony_ci- */ 4251be168c0dSopenharmony_ci- 4252be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_ 4253be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_ 4254be168c0dSopenharmony_ci- 4255be168c0dSopenharmony_ci-#include <vector> 4256be168c0dSopenharmony_ci-#include <memory> 4257be168c0dSopenharmony_ci-#include "src/expression/node.h" 4258be168c0dSopenharmony_ci- 4259be168c0dSopenharmony_ci-namespace mindspore { 4260be168c0dSopenharmony_ci-namespace lite { 4261be168c0dSopenharmony_ci-class DropOutM : public Node { 4262be168c0dSopenharmony_ci- public: 4263be168c0dSopenharmony_ci- DropOutM() = default; 4264be168c0dSopenharmony_ci- explicit DropOutM(float ratio); 4265be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 4266be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4267be168c0dSopenharmony_ci-}; 4268be168c0dSopenharmony_ci- 4269be168c0dSopenharmony_ci-class DropOutGradM : public Node { 4270be168c0dSopenharmony_ci- public: 4271be168c0dSopenharmony_ci- explicit DropOutGradM(DropOutM *node); 4272be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4273be168c0dSopenharmony_ci-}; 4274be168c0dSopenharmony_ci-} // namespace lite 4275be168c0dSopenharmony_ci-} // namespace mindspore 4276be168c0dSopenharmony_ci- 4277be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DROPOUT_H_ 4278be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/flatten.cc b/mindspore/lite/src/expression/ops/flatten.cc 4279be168c0dSopenharmony_cideleted file mode 100644 4280be168c0dSopenharmony_ciindex 76564186..00000000 4281be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/flatten.cc 4282be168c0dSopenharmony_ci+++ /dev/null 4283be168c0dSopenharmony_ci@@ -1,71 +0,0 @@ 4284be168c0dSopenharmony_ci-/** 4285be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4286be168c0dSopenharmony_ci- * 4287be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4288be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4289be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4290be168c0dSopenharmony_ci- * 4291be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4292be168c0dSopenharmony_ci- * 4293be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4294be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4295be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4296be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4297be168c0dSopenharmony_ci- * limitations under the License. 4298be168c0dSopenharmony_ci- */ 4299be168c0dSopenharmony_ci- 4300be168c0dSopenharmony_ci-#include "src/expression/ops/flatten.h" 4301be168c0dSopenharmony_ci-#include <vector> 4302be168c0dSopenharmony_ci-#include "inner/model_generated.h" 4303be168c0dSopenharmony_ci-#include "src/expression/import.h" 4304be168c0dSopenharmony_ci-#include "src/expression/ops.h" 4305be168c0dSopenharmony_ci- 4306be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 4307be168c0dSopenharmony_ci-#include "nnacl/op_base.h" 4308be168c0dSopenharmony_ci- 4309be168c0dSopenharmony_ci-namespace mindspore { 4310be168c0dSopenharmony_ci-namespace lite { 4311be168c0dSopenharmony_ci-FlattenM::FlattenM(int dummy) { 4312be168c0dSopenharmony_ci- auto param = reinterpret_cast<OpParameter *>(calloc(C1NUM, sizeof(OpParameter))); 4313be168c0dSopenharmony_ci- if (param == nullptr) { 4314be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate parameter"; 4315be168c0dSopenharmony_ci- return; 4316be168c0dSopenharmony_ci- } 4317be168c0dSopenharmony_ci- SetOpParam(param); 4318be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Flatten); 4319be168c0dSopenharmony_ci- set_name(UniqueName("Flatten")); 4320be168c0dSopenharmony_ci-} 4321be168c0dSopenharmony_ci- 4322be168c0dSopenharmony_ci-std::vector<EXPR *> FlattenM::construct(const std::vector<EXPR *> &inputs) { 4323be168c0dSopenharmony_ci- auto in = inputs; 4324be168c0dSopenharmony_ci- auto y = Node::construct(in); 4325be168c0dSopenharmony_ci- return y; 4326be168c0dSopenharmony_ci-} 4327be168c0dSopenharmony_ci- 4328be168c0dSopenharmony_ci-std::vector<EXPR *> FlattenM::Grad(EXPR *yt) { 4329be168c0dSopenharmony_ci- auto shape_of_x = input(0)->dims(); 4330be168c0dSopenharmony_ci- auto reshape = NN::Reshape(shape_of_x); 4331be168c0dSopenharmony_ci- PushOp(reshape); 4332be168c0dSopenharmony_ci- return (*reshape)({yt}); 4333be168c0dSopenharmony_ci-} 4334be168c0dSopenharmony_ci- 4335be168c0dSopenharmony_ci-int FlattenM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4336be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::DropoutT; 4337be168c0dSopenharmony_ci- if (prim == nullptr) { 4338be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 4339be168c0dSopenharmony_ci- return RET_ERROR; 4340be168c0dSopenharmony_ci- } 4341be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 4342be168c0dSopenharmony_ci- return RET_OK; 4343be168c0dSopenharmony_ci-} 4344be168c0dSopenharmony_ci- 4345be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_Flatten, ReturnNode<FlattenM>); 4346be168c0dSopenharmony_ci- 4347be168c0dSopenharmony_ci-namespace NN { 4348be168c0dSopenharmony_ci-Node *Flatten() { 4349be168c0dSopenharmony_ci- auto node = new (std::nothrow) FlattenM(0); 4350be168c0dSopenharmony_ci- return node; 4351be168c0dSopenharmony_ci-} 4352be168c0dSopenharmony_ci-} // namespace NN 4353be168c0dSopenharmony_ci-} // namespace lite 4354be168c0dSopenharmony_ci-} // namespace mindspore 4355be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/flatten.h b/mindspore/lite/src/expression/ops/flatten.h 4356be168c0dSopenharmony_cideleted file mode 100644 4357be168c0dSopenharmony_ciindex 0be7d5bb..00000000 4358be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/flatten.h 4359be168c0dSopenharmony_ci+++ /dev/null 4360be168c0dSopenharmony_ci@@ -1,36 +0,0 @@ 4361be168c0dSopenharmony_ci-/** 4362be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4363be168c0dSopenharmony_ci- * 4364be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4365be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4366be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4367be168c0dSopenharmony_ci- * 4368be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4369be168c0dSopenharmony_ci- * 4370be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4371be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4372be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4373be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4374be168c0dSopenharmony_ci- * limitations under the License. 4375be168c0dSopenharmony_ci- */ 4376be168c0dSopenharmony_ci- 4377be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_ 4378be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_ 4379be168c0dSopenharmony_ci- 4380be168c0dSopenharmony_ci-#include <memory> 4381be168c0dSopenharmony_ci-#include <vector> 4382be168c0dSopenharmony_ci-#include "src/expression/node.h" 4383be168c0dSopenharmony_ci- 4384be168c0dSopenharmony_ci-namespace mindspore { 4385be168c0dSopenharmony_ci-namespace lite { 4386be168c0dSopenharmony_ci-class FlattenM : public Node { 4387be168c0dSopenharmony_ci- public: 4388be168c0dSopenharmony_ci- FlattenM() = default; 4389be168c0dSopenharmony_ci- explicit FlattenM(int dummy); 4390be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 4391be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4392be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override; 4393be168c0dSopenharmony_ci-}; 4394be168c0dSopenharmony_ci-} // namespace lite 4395be168c0dSopenharmony_ci-} // namespace mindspore 4396be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_FLATTEN_H_ 4397be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/pooling.cc b/mindspore/lite/src/expression/ops/pooling.cc 4398be168c0dSopenharmony_cideleted file mode 100644 4399be168c0dSopenharmony_ciindex 6efbc863..00000000 4400be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/pooling.cc 4401be168c0dSopenharmony_ci+++ /dev/null 4402be168c0dSopenharmony_ci@@ -1,215 +0,0 @@ 4403be168c0dSopenharmony_ci-/** 4404be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4405be168c0dSopenharmony_ci- * 4406be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4407be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4408be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4409be168c0dSopenharmony_ci- * 4410be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4411be168c0dSopenharmony_ci- * 4412be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4413be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4414be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4415be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4416be168c0dSopenharmony_ci- * limitations under the License. 4417be168c0dSopenharmony_ci- */ 4418be168c0dSopenharmony_ci- 4419be168c0dSopenharmony_ci-#include "src/expression/ops/pooling.h" 4420be168c0dSopenharmony_ci-#include "src/expression/ops.h" 4421be168c0dSopenharmony_ci-#include "src/expression/import.h" 4422be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 4423be168c0dSopenharmony_ci-#include "src/expression/ops/transpose.h" 4424be168c0dSopenharmony_ci- 4425be168c0dSopenharmony_ci-namespace mindspore { 4426be168c0dSopenharmony_ci-namespace lite { 4427be168c0dSopenharmony_ci-PoolingM::PoolingM(const PoolingConfig &cfg) : Node() { 4428be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(PoolingParameter)); 4429be168c0dSopenharmony_ci- if (op_param == nullptr) { 4430be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate PoolingParameter"; 4431be168c0dSopenharmony_ci- return; 4432be168c0dSopenharmony_ci- } 4433be168c0dSopenharmony_ci- SetOpParam(op_param); 4434be168c0dSopenharmony_ci- PoolingParameter *pool_param = reinterpret_cast<PoolingParameter *>(OpParam()); 4435be168c0dSopenharmony_ci- 4436be168c0dSopenharmony_ci- pool_param->window_h_ = cfg.kernel_size_[0]; 4437be168c0dSopenharmony_ci- pool_param->window_w_ = cfg.kernel_size_[1]; 4438be168c0dSopenharmony_ci- pool_param->stride_h_ = cfg.stride_[0]; 4439be168c0dSopenharmony_ci- pool_param->stride_w_ = cfg.stride_[1]; 4440be168c0dSopenharmony_ci- auto pad_mode = GetMode(cfg.pad_mode_); 4441be168c0dSopenharmony_ci- if (pad_mode == -1) { 4442be168c0dSopenharmony_ci- MS_LOG(ERROR) << "bad pad mode"; 4443be168c0dSopenharmony_ci- return; 4444be168c0dSopenharmony_ci- } 4445be168c0dSopenharmony_ci- pool_param->pad_mode_ = static_cast<PadType>(pad_mode + Pad_pad); 4446be168c0dSopenharmony_ci- pool_param->round_type_ = RoundType_Floor; 4447be168c0dSopenharmony_ci- pool_param->act_type_ = ActType_No; 4448be168c0dSopenharmony_ci-} 4449be168c0dSopenharmony_ci- 4450be168c0dSopenharmony_ci-std::vector<EXPR *> PoolingM::construct(const std::vector<EXPR *> &inputs) { 4451be168c0dSopenharmony_ci- auto in = inputs; 4452be168c0dSopenharmony_ci- auto x = in.front(); 4453be168c0dSopenharmony_ci- if (x->format() != NHWC && x->dims().size() == C4NUM) { 4454be168c0dSopenharmony_ci- x = TransposeM::TransposeCHW2HWC(x); 4455be168c0dSopenharmony_ci- x->node()->set_name(name() + "/" + x->node()->name()); 4456be168c0dSopenharmony_ci- PushOp(x->node()); 4457be168c0dSopenharmony_ci- in.at(0) = x; 4458be168c0dSopenharmony_ci- } 4459be168c0dSopenharmony_ci- auto y = Node::construct(in); 4460be168c0dSopenharmony_ci- return y; 4461be168c0dSopenharmony_ci-} 4462be168c0dSopenharmony_ci- 4463be168c0dSopenharmony_ci-int PoolingM::GetMode(std::string mode) { 4464be168c0dSopenharmony_ci- const std::vector<std::string> list = {"same", "valid"}; 4465be168c0dSopenharmony_ci- auto itr = std::find(list.begin(), list.end(), mode); 4466be168c0dSopenharmony_ci- if (itr == list.end()) { 4467be168c0dSopenharmony_ci- MS_LOG(ERROR) << "illegal mode" << mode; 4468be168c0dSopenharmony_ci- return -1; 4469be168c0dSopenharmony_ci- } 4470be168c0dSopenharmony_ci- return std::distance(list.begin(), itr); 4471be168c0dSopenharmony_ci-} 4472be168c0dSopenharmony_ci- 4473be168c0dSopenharmony_ci-void PoolingM::UpdateRoundMode(const PoolingParameter *param, schema::RoundMode *round_mode) { 4474be168c0dSopenharmony_ci- switch (param->round_type_) { 4475be168c0dSopenharmony_ci- case RoundType_Floor: 4476be168c0dSopenharmony_ci- *round_mode = schema::RoundMode_FLOOR; 4477be168c0dSopenharmony_ci- break; 4478be168c0dSopenharmony_ci- case RoundType_Ceil: 4479be168c0dSopenharmony_ci- *round_mode = schema::RoundMode_CEIL; 4480be168c0dSopenharmony_ci- break; 4481be168c0dSopenharmony_ci- default: 4482be168c0dSopenharmony_ci- *round_mode = schema::RoundMode_FLOOR; 4483be168c0dSopenharmony_ci- break; 4484be168c0dSopenharmony_ci- } 4485be168c0dSopenharmony_ci-} 4486be168c0dSopenharmony_ci- 4487be168c0dSopenharmony_ci-template <typename T> 4488be168c0dSopenharmony_ci-int PoolingM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4489be168c0dSopenharmony_ci- auto param = reinterpret_cast<const PoolingParameter *>(OpParam()); 4490be168c0dSopenharmony_ci- auto prim = new (std::nothrow) T; 4491be168c0dSopenharmony_ci- if (prim == nullptr) { 4492be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 4493be168c0dSopenharmony_ci- return RET_ERROR; 4494be168c0dSopenharmony_ci- } 4495be168c0dSopenharmony_ci- prim->kernel_size = {param->window_h_, param->window_w_}; 4496be168c0dSopenharmony_ci- prim->strides = {param->stride_h_, param->stride_w_}; 4497be168c0dSopenharmony_ci- prim->pad = {param->pad_u_, param->pad_d_, param->pad_l_, param->pad_r_}; 4498be168c0dSopenharmony_ci- prim->pad_mode = static_cast<schema::PadMode>(param->pad_mode_); 4499be168c0dSopenharmony_ci- UpdateRoundMode(param, &prim->round_mode); 4500be168c0dSopenharmony_ci- prim->global = param->global_; 4501be168c0dSopenharmony_ci- prim->activation_type = schema::ActivationType_NO_ACTIVATION; 4502be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 4503be168c0dSopenharmony_ci- return RET_OK; 4504be168c0dSopenharmony_ci-} 4505be168c0dSopenharmony_ci- 4506be168c0dSopenharmony_ci-template <typename T> 4507be168c0dSopenharmony_ci-int PoolingM::UnPopulateGrad(const std::unique_ptr<schema::CNodeT> &cnode) { 4508be168c0dSopenharmony_ci- auto param = reinterpret_cast<const PoolingParameter *>(OpParam()); 4509be168c0dSopenharmony_ci- auto prim = new (std::nothrow) T; 4510be168c0dSopenharmony_ci- if (prim == nullptr) { 4511be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 4512be168c0dSopenharmony_ci- return RET_ERROR; 4513be168c0dSopenharmony_ci- } 4514be168c0dSopenharmony_ci- prim->kernel_size = {param->window_h_, param->window_w_}; 4515be168c0dSopenharmony_ci- prim->strides = {param->stride_h_, param->stride_w_}; 4516be168c0dSopenharmony_ci- prim->pad_mode = static_cast<schema::PadMode>(param->pad_mode_); 4517be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 4518be168c0dSopenharmony_ci- return RET_OK; 4519be168c0dSopenharmony_ci-} 4520be168c0dSopenharmony_ci- 4521be168c0dSopenharmony_ci-// Max pooling Definition 4522be168c0dSopenharmony_ci-MaxPoolM::MaxPoolM(const PoolingConfig &cfg) : PoolingM(cfg) { 4523be168c0dSopenharmony_ci- auto param = reinterpret_cast<PoolingParameter *>(OpParam()); 4524be168c0dSopenharmony_ci- param->pool_mode_ = PoolMode_MaxPool; 4525be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_MaxPoolFusion); 4526be168c0dSopenharmony_ci- set_name(UniqueName("MaxPool")); 4527be168c0dSopenharmony_ci-} 4528be168c0dSopenharmony_ci- 4529be168c0dSopenharmony_ci-int MaxPoolM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4530be168c0dSopenharmony_ci- return PoolingM::UnPopulate<schema::MaxPoolFusionT>(cnode); 4531be168c0dSopenharmony_ci-} 4532be168c0dSopenharmony_ci- 4533be168c0dSopenharmony_ci-std::vector<EXPR *> MaxPoolM::Grad(EXPR *yt) { 4534be168c0dSopenharmony_ci- auto in = yt; 4535be168c0dSopenharmony_ci- if (yt->format() != NHWC && yt->dims().size() == C4NUM) { 4536be168c0dSopenharmony_ci- in = TransposeM::TransposeCHW2HWC(yt); 4537be168c0dSopenharmony_ci- in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name()); 4538be168c0dSopenharmony_ci- PushOp(in->node()); 4539be168c0dSopenharmony_ci- } 4540be168c0dSopenharmony_ci- auto pool_grad = new (std::nothrow) MaxPoolGradM(this); 4541be168c0dSopenharmony_ci- PushOp(pool_grad); 4542be168c0dSopenharmony_ci- return (*pool_grad)({input(0), output(0), in}); 4543be168c0dSopenharmony_ci-} 4544be168c0dSopenharmony_ci- 4545be168c0dSopenharmony_ci-static ImportReg maxPoolReg(schema::PrimitiveType_MaxPoolFusion, ReturnNode<MaxPoolM>); 4546be168c0dSopenharmony_ci- 4547be168c0dSopenharmony_ci-// Avg pooling Definition 4548be168c0dSopenharmony_ci-AvgPoolM::AvgPoolM(const PoolingConfig &cfg) : PoolingM(cfg) { 4549be168c0dSopenharmony_ci- auto param = reinterpret_cast<PoolingParameter *>(OpParam()); 4550be168c0dSopenharmony_ci- param->pool_mode_ = PoolMode_AvgPool; 4551be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_AvgPoolFusion); 4552be168c0dSopenharmony_ci- set_name(UniqueName("AvgPool")); 4553be168c0dSopenharmony_ci-} 4554be168c0dSopenharmony_ci- 4555be168c0dSopenharmony_ci-int AvgPoolM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4556be168c0dSopenharmony_ci- return PoolingM::UnPopulate<schema::AvgPoolFusionT>(cnode); 4557be168c0dSopenharmony_ci-} 4558be168c0dSopenharmony_ci- 4559be168c0dSopenharmony_ci-std::vector<EXPR *> AvgPoolM::Grad(EXPR *yt) { 4560be168c0dSopenharmony_ci- auto in = yt; 4561be168c0dSopenharmony_ci- if (yt->format() != NHWC && yt->dims().size() == C4NUM) { 4562be168c0dSopenharmony_ci- in = TransposeM::TransposeCHW2HWC(yt); 4563be168c0dSopenharmony_ci- in->node()->set_name(kGradName + "/" + name() + "/" + in->node()->name()); 4564be168c0dSopenharmony_ci- PushOp(in->node()); 4565be168c0dSopenharmony_ci- } 4566be168c0dSopenharmony_ci- auto pool_grad = new (std::nothrow) AvgPoolGradM(this); 4567be168c0dSopenharmony_ci- PushOp(pool_grad); 4568be168c0dSopenharmony_ci- return (*pool_grad)({input(0), output(0), in}); 4569be168c0dSopenharmony_ci-} 4570be168c0dSopenharmony_ci- 4571be168c0dSopenharmony_ci-static ImportReg avgPoolReg(schema::PrimitiveType_AvgPoolFusion, ReturnNode<AvgPoolM>); 4572be168c0dSopenharmony_ci- 4573be168c0dSopenharmony_ci-// Max Pool Grad Definition 4574be168c0dSopenharmony_ci-MaxPoolGradM::MaxPoolGradM(MaxPoolM *node) { 4575be168c0dSopenharmony_ci- Node(); 4576be168c0dSopenharmony_ci- CloneOpParam<PoolingParameter>(node->OpParam()); 4577be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_MaxPoolGrad); 4578be168c0dSopenharmony_ci- set_name(kGradName + "/" + node->name() + "/MaxPoolGrad"); 4579be168c0dSopenharmony_ci-} 4580be168c0dSopenharmony_ci- 4581be168c0dSopenharmony_ci-int MaxPoolGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4582be168c0dSopenharmony_ci- return PoolingM::UnPopulateGrad<schema::MaxPoolGradT>(cnode); 4583be168c0dSopenharmony_ci-} 4584be168c0dSopenharmony_ci- 4585be168c0dSopenharmony_ci-// Avg Pool Grad Definition 4586be168c0dSopenharmony_ci-AvgPoolGradM::AvgPoolGradM(AvgPoolM *node) { 4587be168c0dSopenharmony_ci- Node(); 4588be168c0dSopenharmony_ci- CloneOpParam<PoolingParameter>(node->OpParam()); 4589be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_AvgPoolGrad); 4590be168c0dSopenharmony_ci- set_name(kGradName + "/" + node->name() + "/AvgPoolGrad"); 4591be168c0dSopenharmony_ci-} 4592be168c0dSopenharmony_ci- 4593be168c0dSopenharmony_ci-int AvgPoolGradM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4594be168c0dSopenharmony_ci- return PoolingM::UnPopulateGrad<schema::AvgPoolGradT>(cnode); 4595be168c0dSopenharmony_ci-} 4596be168c0dSopenharmony_ci- 4597be168c0dSopenharmony_ci-namespace NN { 4598be168c0dSopenharmony_ci-Node *MaxPool2D(const PoolingConfig &cfg) { 4599be168c0dSopenharmony_ci- auto c = new (std::nothrow) MaxPoolM(cfg); 4600be168c0dSopenharmony_ci- if (c == nullptr) { 4601be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate max pool object"; 4602be168c0dSopenharmony_ci- return nullptr; 4603be168c0dSopenharmony_ci- } 4604be168c0dSopenharmony_ci- return c; 4605be168c0dSopenharmony_ci-} 4606be168c0dSopenharmony_ci- 4607be168c0dSopenharmony_ci-Node *AvgPool2D(const PoolingConfig &cfg) { 4608be168c0dSopenharmony_ci- auto c = new (std::nothrow) AvgPoolM(cfg); 4609be168c0dSopenharmony_ci- if (c == nullptr) { 4610be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate average pool object"; 4611be168c0dSopenharmony_ci- return nullptr; 4612be168c0dSopenharmony_ci- } 4613be168c0dSopenharmony_ci- return c; 4614be168c0dSopenharmony_ci-} 4615be168c0dSopenharmony_ci-} // namespace NN 4616be168c0dSopenharmony_ci-} // namespace lite 4617be168c0dSopenharmony_ci-} // namespace mindspore 4618be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/pooling.h b/mindspore/lite/src/expression/ops/pooling.h 4619be168c0dSopenharmony_cideleted file mode 100644 4620be168c0dSopenharmony_ciindex 881996a3..00000000 4621be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/pooling.h 4622be168c0dSopenharmony_ci+++ /dev/null 4623be168c0dSopenharmony_ci@@ -1,74 +0,0 @@ 4624be168c0dSopenharmony_ci-/** 4625be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4626be168c0dSopenharmony_ci- * 4627be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4628be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4629be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4630be168c0dSopenharmony_ci- * 4631be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4632be168c0dSopenharmony_ci- * 4633be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4634be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4635be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4636be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4637be168c0dSopenharmony_ci- * limitations under the License. 4638be168c0dSopenharmony_ci- */ 4639be168c0dSopenharmony_ci- 4640be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_ 4641be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_ 4642be168c0dSopenharmony_ci- 4643be168c0dSopenharmony_ci-#include <vector> 4644be168c0dSopenharmony_ci-#include <string> 4645be168c0dSopenharmony_ci-#include <memory> 4646be168c0dSopenharmony_ci-#include "src/expression/node.h" 4647be168c0dSopenharmony_ci-#include "inner/model_generated.h" 4648be168c0dSopenharmony_ci-#include "src/expression/cfg.h" 4649be168c0dSopenharmony_ci-#include "nnacl/pooling_parameter.h" 4650be168c0dSopenharmony_ci- 4651be168c0dSopenharmony_ci-namespace mindspore { 4652be168c0dSopenharmony_ci-namespace lite { 4653be168c0dSopenharmony_ci-class PoolingM : public Node { 4654be168c0dSopenharmony_ci- public: 4655be168c0dSopenharmony_ci- PoolingM() = default; 4656be168c0dSopenharmony_ci- explicit PoolingM(const PoolingConfig &cfg); 4657be168c0dSopenharmony_ci- template <typename T> 4658be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode); 4659be168c0dSopenharmony_ci- template <typename T> 4660be168c0dSopenharmony_ci- int UnPopulateGrad(const std::unique_ptr<schema::CNodeT> &cnode); 4661be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs); 4662be168c0dSopenharmony_ci- 4663be168c0dSopenharmony_ci- private: 4664be168c0dSopenharmony_ci- void UpdateRoundMode(const PoolingParameter *param, enum schema::RoundMode *round_mode); 4665be168c0dSopenharmony_ci- int GetMode(std::string mode); 4666be168c0dSopenharmony_ci-}; 4667be168c0dSopenharmony_ci- 4668be168c0dSopenharmony_ci-class MaxPoolM : public PoolingM { 4669be168c0dSopenharmony_ci- public: 4670be168c0dSopenharmony_ci- MaxPoolM() = default; 4671be168c0dSopenharmony_ci- explicit MaxPoolM(const PoolingConfig &cfg); 4672be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 4673be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4674be168c0dSopenharmony_ci-}; 4675be168c0dSopenharmony_ci- 4676be168c0dSopenharmony_ci-class AvgPoolM : public PoolingM { 4677be168c0dSopenharmony_ci- public: 4678be168c0dSopenharmony_ci- AvgPoolM() = default; 4679be168c0dSopenharmony_ci- explicit AvgPoolM(const PoolingConfig &cfg); 4680be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 4681be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4682be168c0dSopenharmony_ci-}; 4683be168c0dSopenharmony_ci- 4684be168c0dSopenharmony_ci-class MaxPoolGradM : public PoolingM { 4685be168c0dSopenharmony_ci- public: 4686be168c0dSopenharmony_ci- explicit MaxPoolGradM(MaxPoolM *node); 4687be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4688be168c0dSopenharmony_ci-}; 4689be168c0dSopenharmony_ci- 4690be168c0dSopenharmony_ci-class AvgPoolGradM : public PoolingM { 4691be168c0dSopenharmony_ci- public: 4692be168c0dSopenharmony_ci- explicit AvgPoolGradM(AvgPoolM *node); 4693be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4694be168c0dSopenharmony_ci-}; 4695be168c0dSopenharmony_ci-} // namespace lite 4696be168c0dSopenharmony_ci-} // namespace mindspore 4697be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_POOLING_H_ 4698be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/reduce.cc b/mindspore/lite/src/expression/ops/reduce.cc 4699be168c0dSopenharmony_cideleted file mode 100644 4700be168c0dSopenharmony_ciindex 055d2026..00000000 4701be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/reduce.cc 4702be168c0dSopenharmony_ci+++ /dev/null 4703be168c0dSopenharmony_ci@@ -1,126 +0,0 @@ 4704be168c0dSopenharmony_ci-/** 4705be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4706be168c0dSopenharmony_ci- * 4707be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4708be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4709be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4710be168c0dSopenharmony_ci- * 4711be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4712be168c0dSopenharmony_ci- * 4713be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4714be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4715be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4716be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4717be168c0dSopenharmony_ci- * limitations under the License. 4718be168c0dSopenharmony_ci- */ 4719be168c0dSopenharmony_ci- 4720be168c0dSopenharmony_ci-#include "src/expression/ops/reduce.h" 4721be168c0dSopenharmony_ci-#include <functional> 4722be168c0dSopenharmony_ci-#include "src/expression/ops/tile.h" 4723be168c0dSopenharmony_ci-#include "src/expression/ops/reshape.h" 4724be168c0dSopenharmony_ci-#include "src/expression/ops/arithmetic.h" 4725be168c0dSopenharmony_ci-#include "src/expression/ops.h" 4726be168c0dSopenharmony_ci-#include "src/expression/ops_utils.h" 4727be168c0dSopenharmony_ci-#include "src/expression/import.h" 4728be168c0dSopenharmony_ci-#include "nnacl/reduce_parameter.h" 4729be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 4730be168c0dSopenharmony_ci- 4731be168c0dSopenharmony_ci-namespace mindspore { 4732be168c0dSopenharmony_ci-namespace lite { 4733be168c0dSopenharmony_ci-ReduceM::ReduceM(schema::ReduceMode mode, bool keep_dims, const std::vector<int> &axis) : Node() { 4734be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 4735be168c0dSopenharmony_ci- ReduceParameter *param = reinterpret_cast<ReduceParameter *>(calloc(1, sizeof(ReduceParameter))); 4736be168c0dSopenharmony_ci- if (param == nullptr) { 4737be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate parameter"; 4738be168c0dSopenharmony_ci- return; 4739be168c0dSopenharmony_ci- } 4740be168c0dSopenharmony_ci- param->mode_ = mode; 4741be168c0dSopenharmony_ci- param->keep_dims_ = keep_dims; 4742be168c0dSopenharmony_ci- param->reduce_to_end_ = false; 4743be168c0dSopenharmony_ci- param->coeff = 1.f; 4744be168c0dSopenharmony_ci- SetOpParam(param); 4745be168c0dSopenharmony_ci- set_name(UniqueName("Reduce")); 4746be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_ReduceFusion); 4747be168c0dSopenharmony_ci- Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(axis.size())}, kNumberTypeInt32, KHWC, "axis", axis.data()); 4748be168c0dSopenharmony_ci-} 4749be168c0dSopenharmony_ci- 4750be168c0dSopenharmony_ci-int ReduceM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4751be168c0dSopenharmony_ci- auto reduce_param = reinterpret_cast<const ReduceParameter *>(OpParam()); 4752be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::ReduceFusionT; 4753be168c0dSopenharmony_ci- if (prim == nullptr) { 4754be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 4755be168c0dSopenharmony_ci- return RET_ERROR; 4756be168c0dSopenharmony_ci- } 4757be168c0dSopenharmony_ci- prim->keep_dims = reduce_param->keep_dims_; 4758be168c0dSopenharmony_ci- prim->mode = static_cast<schema::ReduceMode>(reduce_param->mode_); 4759be168c0dSopenharmony_ci- prim->coeff = reduce_param->coeff; 4760be168c0dSopenharmony_ci- prim->reduce_to_end = reduce_param->reduce_to_end_; 4761be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 4762be168c0dSopenharmony_ci- return RET_OK; 4763be168c0dSopenharmony_ci-} 4764be168c0dSopenharmony_ci- 4765be168c0dSopenharmony_ci-std::vector<EXPR *> ReduceM::Grad(EXPR *yt) { 4766be168c0dSopenharmony_ci- auto shape_of_x = input(0)->dims(); 4767be168c0dSopenharmony_ci- std::vector<int> shape_of_axis; 4768be168c0dSopenharmony_ci- 4769be168c0dSopenharmony_ci- auto data = input(1)->node()->data()->data().data(); 4770be168c0dSopenharmony_ci- int size = input(1)->dims().at(0); 4771be168c0dSopenharmony_ci- auto int_data = reinterpret_cast<int *>(data); 4772be168c0dSopenharmony_ci- for (int i = 0; i < size; i++) { 4773be168c0dSopenharmony_ci- shape_of_axis.push_back(int_data[i]); 4774be168c0dSopenharmony_ci- } 4775be168c0dSopenharmony_ci- 4776be168c0dSopenharmony_ci- // assume no dynamic shape 4777be168c0dSopenharmony_ci- ShapeReduce reduce_shape; 4778be168c0dSopenharmony_ci- auto output_shape_kept_dims = ShapeReduce()(shape_of_x, shape_of_axis); 4779be168c0dSopenharmony_ci- auto tile_scaling = VectorDiv()(shape_of_x, output_shape_kept_dims); 4780be168c0dSopenharmony_ci- auto reshape = NN::Reshape(output_shape_kept_dims); 4781be168c0dSopenharmony_ci- PushOp(reshape); 4782be168c0dSopenharmony_ci- reshape->set_name(name() + "/reshape"); 4783be168c0dSopenharmony_ci- auto g = (*reshape)({yt}).front(); 4784be168c0dSopenharmony_ci- auto tile = NN::Tile(tile_scaling); 4785be168c0dSopenharmony_ci- PushOp(tile); 4786be168c0dSopenharmony_ci- tile->set_name(name() + "/tile"); 4787be168c0dSopenharmony_ci- auto sum_grad = (*tile)({g}).front(); 4788be168c0dSopenharmony_ci- auto reduce_param = reinterpret_cast<const ReduceParameter *>(OpParam()); 4789be168c0dSopenharmony_ci- if (reduce_param->mode_ == schema::ReduceMode_ReduceSum) { 4790be168c0dSopenharmony_ci- return {sum_grad}; 4791be168c0dSopenharmony_ci- } else if (reduce_param->mode_ == schema::ReduceMode_ReduceMean) { 4792be168c0dSopenharmony_ci- auto shape_of_y = output(0)->dims(); 4793be168c0dSopenharmony_ci- auto shape_x_mul = std::accumulate(shape_of_x.begin(), shape_of_x.end(), 1, std::multiplies<int>()); 4794be168c0dSopenharmony_ci- auto shape_y_mul = std::accumulate(shape_of_y.begin(), shape_of_y.end(), 1, std::multiplies<int>()); 4795be168c0dSopenharmony_ci- auto div_shape = static_cast<float>(shape_x_mul) / static_cast<float>(shape_y_mul); 4796be168c0dSopenharmony_ci- auto div_op = NN::Div(); 4797be168c0dSopenharmony_ci- PushOp(div_op); 4798be168c0dSopenharmony_ci- auto d = div_op->CreateConstTensor(C1NUM, {1}, kNumberTypeFloat32, KHWC, "div_shape", &div_shape); 4799be168c0dSopenharmony_ci- auto dx = (*div_op)({sum_grad, d->expr()}); 4800be168c0dSopenharmony_ci- return dx; 4801be168c0dSopenharmony_ci- } else { 4802be168c0dSopenharmony_ci- return {}; 4803be168c0dSopenharmony_ci- } 4804be168c0dSopenharmony_ci-} 4805be168c0dSopenharmony_ci- 4806be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_ReduceFusion, ReturnNode<ReduceM>); 4807be168c0dSopenharmony_ci- 4808be168c0dSopenharmony_ci-namespace NN { 4809be168c0dSopenharmony_ci-Node *ReduceSum(bool keep_dims, const std::vector<int> &axis) { 4810be168c0dSopenharmony_ci- auto node = new (std::nothrow) ReduceM(schema::ReduceMode_ReduceSum, keep_dims, axis); 4811be168c0dSopenharmony_ci- if (node == nullptr) { 4812be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate reduce sum node"; 4813be168c0dSopenharmony_ci- return nullptr; 4814be168c0dSopenharmony_ci- } 4815be168c0dSopenharmony_ci- node->set_name(Node::UniqueName("ReduceSum")); 4816be168c0dSopenharmony_ci- return node; 4817be168c0dSopenharmony_ci-} 4818be168c0dSopenharmony_ci-Node *ReduceMean(bool keep_dims, const std::vector<int> &axis) { 4819be168c0dSopenharmony_ci- auto node = new (std::nothrow) ReduceM(schema::ReduceMode_ReduceMean, keep_dims, axis); 4820be168c0dSopenharmony_ci- if (node == nullptr) { 4821be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate reduce mean node"; 4822be168c0dSopenharmony_ci- return nullptr; 4823be168c0dSopenharmony_ci- } 4824be168c0dSopenharmony_ci- node->set_name(Node::UniqueName("ReduceMean")); 4825be168c0dSopenharmony_ci- return node; 4826be168c0dSopenharmony_ci-} 4827be168c0dSopenharmony_ci-} // namespace NN 4828be168c0dSopenharmony_ci-} // namespace lite 4829be168c0dSopenharmony_ci-} // namespace mindspore 4830be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/reduce.h b/mindspore/lite/src/expression/ops/reduce.h 4831be168c0dSopenharmony_cideleted file mode 100644 4832be168c0dSopenharmony_ciindex 1ca0e921..00000000 4833be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/reduce.h 4834be168c0dSopenharmony_ci+++ /dev/null 4835be168c0dSopenharmony_ci@@ -1,42 +0,0 @@ 4836be168c0dSopenharmony_ci-/** 4837be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4838be168c0dSopenharmony_ci- * 4839be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4840be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4841be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4842be168c0dSopenharmony_ci- * 4843be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4844be168c0dSopenharmony_ci- * 4845be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4846be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4847be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4848be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4849be168c0dSopenharmony_ci- * limitations under the License. 4850be168c0dSopenharmony_ci- */ 4851be168c0dSopenharmony_ci- 4852be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_ 4853be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_ 4854be168c0dSopenharmony_ci- 4855be168c0dSopenharmony_ci-#include <vector> 4856be168c0dSopenharmony_ci-#include <memory> 4857be168c0dSopenharmony_ci-#include "src/expression/node.h" 4858be168c0dSopenharmony_ci-#include "inner/model_generated.h" 4859be168c0dSopenharmony_ci- 4860be168c0dSopenharmony_ci-namespace mindspore { 4861be168c0dSopenharmony_ci-namespace lite { 4862be168c0dSopenharmony_ci-class ReduceM : public Node { 4863be168c0dSopenharmony_ci- public: 4864be168c0dSopenharmony_ci- ReduceM() = default; 4865be168c0dSopenharmony_ci- ReduceM(schema::ReduceMode mode, bool keep_dims, const std::vector<int> &axis); 4866be168c0dSopenharmony_ci- Param *weight() override { return input(1)->node()->data(); } 4867be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4868be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 4869be168c0dSopenharmony_ci-}; 4870be168c0dSopenharmony_ci- 4871be168c0dSopenharmony_ci-namespace NN { 4872be168c0dSopenharmony_ci-Node *ReduceMean(bool keep_dims, const std::vector<int> &axis); 4873be168c0dSopenharmony_ci-Node *ReduceSum(bool keep_dims, const std::vector<int> &axis); 4874be168c0dSopenharmony_ci-} // namespace NN 4875be168c0dSopenharmony_ci-} // namespace lite 4876be168c0dSopenharmony_ci-} // namespace mindspore 4877be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_REDUCE_H_ 4878be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/reshape.cc b/mindspore/lite/src/expression/ops/reshape.cc 4879be168c0dSopenharmony_cideleted file mode 100644 4880be168c0dSopenharmony_ciindex d6cc0433..00000000 4881be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/reshape.cc 4882be168c0dSopenharmony_ci+++ /dev/null 4883be168c0dSopenharmony_ci@@ -1,74 +0,0 @@ 4884be168c0dSopenharmony_ci-/** 4885be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4886be168c0dSopenharmony_ci- * 4887be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4888be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4889be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4890be168c0dSopenharmony_ci- * 4891be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4892be168c0dSopenharmony_ci- * 4893be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4894be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4895be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4896be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4897be168c0dSopenharmony_ci- * limitations under the License. 4898be168c0dSopenharmony_ci- */ 4899be168c0dSopenharmony_ci- 4900be168c0dSopenharmony_ci-#include "src/expression/ops/reshape.h" 4901be168c0dSopenharmony_ci-#include "src/expression/ops.h" 4902be168c0dSopenharmony_ci-#include "nnacl/reshape_parameter.h" 4903be168c0dSopenharmony_ci-#include "src/expression/import.h" 4904be168c0dSopenharmony_ci- 4905be168c0dSopenharmony_ci-namespace mindspore { 4906be168c0dSopenharmony_ci-namespace lite { 4907be168c0dSopenharmony_ci-ReshapeM::ReshapeM(const std::vector<int> &shape) : Node() { 4908be168c0dSopenharmony_ci- auto op_param = calloc(1, sizeof(ReshapeParameter)); 4909be168c0dSopenharmony_ci- if (op_param == nullptr) { 4910be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ReshapeParameter"; 4911be168c0dSopenharmony_ci- return; 4912be168c0dSopenharmony_ci- } 4913be168c0dSopenharmony_ci- set_name(UniqueName("Reshape")); 4914be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 4915be168c0dSopenharmony_ci- SetOpParam(op_param); 4916be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Reshape); 4917be168c0dSopenharmony_ci- 4918be168c0dSopenharmony_ci- ReshapeParameter *reshape_param = reinterpret_cast<ReshapeParameter *>(opParam_.get()); 4919be168c0dSopenharmony_ci- reshape_param->shape_dim_ = shape.size(); 4920be168c0dSopenharmony_ci- for (int i = 0; i < reshape_param->shape_dim_; i++) { 4921be168c0dSopenharmony_ci- reshape_param->shape_[i] = shape.at(i); 4922be168c0dSopenharmony_ci- } 4923be168c0dSopenharmony_ci- Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(shape.size())}, kNumberTypeInt32, KHWC, "shape", shape.data()); 4924be168c0dSopenharmony_ci-} 4925be168c0dSopenharmony_ci- 4926be168c0dSopenharmony_ci-int ReshapeM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 4927be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::ReshapeT; 4928be168c0dSopenharmony_ci- if (prim == nullptr) { 4929be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 4930be168c0dSopenharmony_ci- return RET_ERROR; 4931be168c0dSopenharmony_ci- } 4932be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 4933be168c0dSopenharmony_ci- return RET_OK; 4934be168c0dSopenharmony_ci-} 4935be168c0dSopenharmony_ci- 4936be168c0dSopenharmony_ci-std::vector<EXPR *> ReshapeM::Grad(EXPR *yt) { 4937be168c0dSopenharmony_ci- auto shape_of_x = input(0)->dims(); 4938be168c0dSopenharmony_ci- auto reshape = NN::Reshape(shape_of_x); 4939be168c0dSopenharmony_ci- PushOp(reshape); 4940be168c0dSopenharmony_ci- return (*reshape)({yt}); 4941be168c0dSopenharmony_ci-} 4942be168c0dSopenharmony_ci- 4943be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_Reshape, ReturnNode<ReshapeM>); 4944be168c0dSopenharmony_ci- 4945be168c0dSopenharmony_ci-namespace NN { 4946be168c0dSopenharmony_ci-Node *Reshape(const std::vector<int> &shape) { 4947be168c0dSopenharmony_ci- auto node = new (std::nothrow) ReshapeM(shape); 4948be168c0dSopenharmony_ci- if (node == nullptr) { 4949be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate reshape node"; 4950be168c0dSopenharmony_ci- return nullptr; 4951be168c0dSopenharmony_ci- } 4952be168c0dSopenharmony_ci- node->set_name(Node::UniqueName("Reshape")); 4953be168c0dSopenharmony_ci- return node; 4954be168c0dSopenharmony_ci-} 4955be168c0dSopenharmony_ci-} // namespace NN 4956be168c0dSopenharmony_ci-} // namespace lite 4957be168c0dSopenharmony_ci-} // namespace mindspore 4958be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/reshape.h b/mindspore/lite/src/expression/ops/reshape.h 4959be168c0dSopenharmony_cideleted file mode 100644 4960be168c0dSopenharmony_ciindex a7c15377..00000000 4961be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/reshape.h 4962be168c0dSopenharmony_ci+++ /dev/null 4963be168c0dSopenharmony_ci@@ -1,37 +0,0 @@ 4964be168c0dSopenharmony_ci-/** 4965be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 4966be168c0dSopenharmony_ci- * 4967be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 4968be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 4969be168c0dSopenharmony_ci- * You may obtain a copy of the License at 4970be168c0dSopenharmony_ci- * 4971be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 4972be168c0dSopenharmony_ci- * 4973be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 4974be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 4975be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4976be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 4977be168c0dSopenharmony_ci- * limitations under the License. 4978be168c0dSopenharmony_ci- */ 4979be168c0dSopenharmony_ci- 4980be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_ 4981be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_ 4982be168c0dSopenharmony_ci- 4983be168c0dSopenharmony_ci-#include <vector> 4984be168c0dSopenharmony_ci-#include <memory> 4985be168c0dSopenharmony_ci- 4986be168c0dSopenharmony_ci-#include "src/expression/node.h" 4987be168c0dSopenharmony_ci-#include "inner/model_generated.h" 4988be168c0dSopenharmony_ci- 4989be168c0dSopenharmony_ci-namespace mindspore { 4990be168c0dSopenharmony_ci-namespace lite { 4991be168c0dSopenharmony_ci-class ReshapeM : public Node { 4992be168c0dSopenharmony_ci- public: 4993be168c0dSopenharmony_ci- ReshapeM() = default; 4994be168c0dSopenharmony_ci- explicit ReshapeM(const std::vector<int> &shape); 4995be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 4996be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 4997be168c0dSopenharmony_ci-}; 4998be168c0dSopenharmony_ci-} // namespace lite 4999be168c0dSopenharmony_ci-} // namespace mindspore 5000be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_RESHAPE_H_ 5001be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/softmax.cc b/mindspore/lite/src/expression/ops/softmax.cc 5002be168c0dSopenharmony_cideleted file mode 100644 5003be168c0dSopenharmony_ciindex 29ab5ee1..00000000 5004be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/softmax.cc 5005be168c0dSopenharmony_ci+++ /dev/null 5006be168c0dSopenharmony_ci@@ -1,119 +0,0 @@ 5007be168c0dSopenharmony_ci-/** 5008be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5009be168c0dSopenharmony_ci- * 5010be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5011be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5012be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5013be168c0dSopenharmony_ci- * 5014be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5015be168c0dSopenharmony_ci- * 5016be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5017be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5018be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5019be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5020be168c0dSopenharmony_ci- * limitations under the License. 5021be168c0dSopenharmony_ci- */ 5022be168c0dSopenharmony_ci- 5023be168c0dSopenharmony_ci-#include "src/expression/ops/softmax.h" 5024be168c0dSopenharmony_ci-#include "nnacl/softmax_parameter.h" 5025be168c0dSopenharmony_ci-#include "inner/model_generated.h" 5026be168c0dSopenharmony_ci-#include "src/expression/import.h" 5027be168c0dSopenharmony_ci-#include "src/expression/ops/reshape.h" 5028be168c0dSopenharmony_ci-#include "src/expression/ops/reduce.h" 5029be168c0dSopenharmony_ci-#include "src/expression/ops/arithmetic.h" 5030be168c0dSopenharmony_ci-#include "src/expression/ops/transpose.h" 5031be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 5032be168c0dSopenharmony_ci-#include "src/expression/ops.h" 5033be168c0dSopenharmony_ci- 5034be168c0dSopenharmony_ci-namespace mindspore { 5035be168c0dSopenharmony_ci-namespace lite { 5036be168c0dSopenharmony_ci-SoftmaxM::SoftmaxM(int axis) { 5037be168c0dSopenharmony_ci- auto param = reinterpret_cast<SoftmaxParameter *>(calloc(1, sizeof(SoftmaxParameter))); 5038be168c0dSopenharmony_ci- if (param == nullptr) { 5039be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate parameter"; 5040be168c0dSopenharmony_ci- return; 5041be168c0dSopenharmony_ci- } 5042be168c0dSopenharmony_ci- param->axis_ = axis; 5043be168c0dSopenharmony_ci- SetOpParam(param); 5044be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Softmax); 5045be168c0dSopenharmony_ci- set_name(UniqueName("Softmax")); 5046be168c0dSopenharmony_ci-} 5047be168c0dSopenharmony_ci- 5048be168c0dSopenharmony_ci-std::vector<int> SoftmaxM::getTransposeAxis(const std::vector<int> &shape, int axis) { 5049be168c0dSopenharmony_ci- int rank = shape.size(); 5050be168c0dSopenharmony_ci- if (axis < 0) { 5051be168c0dSopenharmony_ci- axis += rank; 5052be168c0dSopenharmony_ci- } 5053be168c0dSopenharmony_ci- std::vector<int> reverse_axis(rank); 5054be168c0dSopenharmony_ci- std::iota(reverse_axis.begin(), reverse_axis.end(), 0); 5055be168c0dSopenharmony_ci- reverse_axis.at(axis) = rank - 1; 5056be168c0dSopenharmony_ci- reverse_axis.at(rank - 1) = axis; 5057be168c0dSopenharmony_ci- return reverse_axis; 5058be168c0dSopenharmony_ci-} 5059be168c0dSopenharmony_ci- 5060be168c0dSopenharmony_ci-std::vector<EXPR *> SoftmaxM::Grad(EXPR *yt) { 5061be168c0dSopenharmony_ci- auto x = input(0); 5062be168c0dSopenharmony_ci- auto out = output(0); 5063be168c0dSopenharmony_ci- auto shape_of_x = x->dims(); 5064be168c0dSopenharmony_ci- auto param = reinterpret_cast<const SoftmaxParameter *>(OpParam()); 5065be168c0dSopenharmony_ci- auto reverse_axis = getTransposeAxis(shape_of_x, param->axis_); 5066be168c0dSopenharmony_ci- 5067be168c0dSopenharmony_ci- auto transpose_out = NN::Transpose(reverse_axis); 5068be168c0dSopenharmony_ci- transpose_out->set_name(kGradName + "/" + name() + "/" + transpose_out->name() + "/out/"); 5069be168c0dSopenharmony_ci- PushOp(transpose_out); 5070be168c0dSopenharmony_ci- auto y_trn = (*transpose_out)({out}).front(); 5071be168c0dSopenharmony_ci- 5072be168c0dSopenharmony_ci- auto transpose_dout = NN::Transpose(reverse_axis); 5073be168c0dSopenharmony_ci- transpose_dout->set_name(kGradName + "/" + name() + "/" + transpose_dout->name() + "/dout/"); 5074be168c0dSopenharmony_ci- PushOp(transpose_dout); 5075be168c0dSopenharmony_ci- auto yt_trn = (*transpose_dout)({yt}).front(); 5076be168c0dSopenharmony_ci- 5077be168c0dSopenharmony_ci- auto mul0 = NN::Mul(); 5078be168c0dSopenharmony_ci- mul0->set_name(kGradName + "/" + name() + "/" + mul0->name() + "0"); 5079be168c0dSopenharmony_ci- PushOp(mul0); 5080be168c0dSopenharmony_ci- auto tmp0 = (*mul0)({y_trn, yt_trn}).front(); 5081be168c0dSopenharmony_ci- 5082be168c0dSopenharmony_ci- auto sum_func = NN::ReduceSum(true, {-1}); 5083be168c0dSopenharmony_ci- sum_func->set_name(kGradName + "/" + name() + "/" + sum_func->name()); 5084be168c0dSopenharmony_ci- PushOp(sum_func); 5085be168c0dSopenharmony_ci- auto tmp1 = (*sum_func)({tmp0}).front(); 5086be168c0dSopenharmony_ci- 5087be168c0dSopenharmony_ci- auto sub = NN::Sub(); 5088be168c0dSopenharmony_ci- sub->set_name(kGradName + "/" + name() + "/" + sub->name()); 5089be168c0dSopenharmony_ci- PushOp(sub); 5090be168c0dSopenharmony_ci- auto tmp2 = (*sub)({yt_trn, tmp1}).front(); 5091be168c0dSopenharmony_ci- 5092be168c0dSopenharmony_ci- auto mul1 = NN::Mul(); 5093be168c0dSopenharmony_ci- mul1->set_name(kGradName + "/" + name() + "/" + mul1->name() + "1"); 5094be168c0dSopenharmony_ci- PushOp(mul1); 5095be168c0dSopenharmony_ci- auto tmp3 = (*mul1)({y_trn, tmp2}); 5096be168c0dSopenharmony_ci- 5097be168c0dSopenharmony_ci- auto transpose_dx = NN::Transpose(reverse_axis); 5098be168c0dSopenharmony_ci- transpose_dx->set_name(kGradName + "/" + name() + "/" + transpose_dx->name() + "/dx"); 5099be168c0dSopenharmony_ci- PushOp(transpose_dx); 5100be168c0dSopenharmony_ci- auto dx = (*transpose_dx)({tmp3}); 5101be168c0dSopenharmony_ci- return dx; 5102be168c0dSopenharmony_ci-} 5103be168c0dSopenharmony_ci- 5104be168c0dSopenharmony_ci-int SoftmaxM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 5105be168c0dSopenharmony_ci- auto param = reinterpret_cast<const SoftmaxParameter *>(OpParam()); 5106be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::SoftmaxT; 5107be168c0dSopenharmony_ci- if (prim == nullptr) { 5108be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 5109be168c0dSopenharmony_ci- return RET_ERROR; 5110be168c0dSopenharmony_ci- } 5111be168c0dSopenharmony_ci- prim->axis.push_back(param->axis_); 5112be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 5113be168c0dSopenharmony_ci- return RET_OK; 5114be168c0dSopenharmony_ci-} 5115be168c0dSopenharmony_ci- 5116be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_Softmax, ReturnNode<SoftmaxM>); 5117be168c0dSopenharmony_ci- 5118be168c0dSopenharmony_ci-namespace NN { 5119be168c0dSopenharmony_ci-Node *Softmax(int axis) { 5120be168c0dSopenharmony_ci- auto node = new (std::nothrow) SoftmaxM(axis); 5121be168c0dSopenharmony_ci- return node; 5122be168c0dSopenharmony_ci-} 5123be168c0dSopenharmony_ci-} // namespace NN 5124be168c0dSopenharmony_ci-} // namespace lite 5125be168c0dSopenharmony_ci-} // namespace mindspore 5126be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/softmax.h b/mindspore/lite/src/expression/ops/softmax.h 5127be168c0dSopenharmony_cideleted file mode 100644 5128be168c0dSopenharmony_ciindex 7347984e..00000000 5129be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/softmax.h 5130be168c0dSopenharmony_ci+++ /dev/null 5131be168c0dSopenharmony_ci@@ -1,39 +0,0 @@ 5132be168c0dSopenharmony_ci-/** 5133be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5134be168c0dSopenharmony_ci- * 5135be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5136be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5137be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5138be168c0dSopenharmony_ci- * 5139be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5140be168c0dSopenharmony_ci- * 5141be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5142be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5143be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5144be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5145be168c0dSopenharmony_ci- * limitations under the License. 5146be168c0dSopenharmony_ci- */ 5147be168c0dSopenharmony_ci- 5148be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_ 5149be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_ 5150be168c0dSopenharmony_ci- 5151be168c0dSopenharmony_ci-#include <vector> 5152be168c0dSopenharmony_ci-#include <memory> 5153be168c0dSopenharmony_ci-#include "src/expression/node.h" 5154be168c0dSopenharmony_ci- 5155be168c0dSopenharmony_ci-namespace mindspore { 5156be168c0dSopenharmony_ci-namespace lite { 5157be168c0dSopenharmony_ci-class SoftmaxM : public Node { 5158be168c0dSopenharmony_ci- public: 5159be168c0dSopenharmony_ci- SoftmaxM() = default; 5160be168c0dSopenharmony_ci- explicit SoftmaxM(int axis); 5161be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 5162be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 5163be168c0dSopenharmony_ci- 5164be168c0dSopenharmony_ci- private: 5165be168c0dSopenharmony_ci- std::vector<int> getTransposeAxis(const std::vector<int> &shape, int axis); 5166be168c0dSopenharmony_ci-}; 5167be168c0dSopenharmony_ci-} // namespace lite 5168be168c0dSopenharmony_ci-} // namespace mindspore 5169be168c0dSopenharmony_ci- 5170be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAX_H_ 5171be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/softmaxCE.cc b/mindspore/lite/src/expression/ops/softmaxCE.cc 5172be168c0dSopenharmony_cideleted file mode 100644 5173be168c0dSopenharmony_ciindex f4ffce87..00000000 5174be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/softmaxCE.cc 5175be168c0dSopenharmony_ci+++ /dev/null 5176be168c0dSopenharmony_ci@@ -1,93 +0,0 @@ 5177be168c0dSopenharmony_ci-/** 5178be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5179be168c0dSopenharmony_ci- * 5180be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5181be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5182be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5183be168c0dSopenharmony_ci- * 5184be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5185be168c0dSopenharmony_ci- * 5186be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5187be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5188be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5189be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5190be168c0dSopenharmony_ci- * limitations under the License. 5191be168c0dSopenharmony_ci- */ 5192be168c0dSopenharmony_ci- 5193be168c0dSopenharmony_ci-#include "src/expression/ops/softmaxCE.h" 5194be168c0dSopenharmony_ci-#include "include/api/net.h" 5195be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 5196be168c0dSopenharmony_ci-#include "src/expression/ops/reduce.h" 5197be168c0dSopenharmony_ci-namespace mindspore { 5198be168c0dSopenharmony_ci-namespace NN { 5199be168c0dSopenharmony_ci-Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg) { 5200be168c0dSopenharmony_ci- auto lite_node = lite::NN::SoftmaxCrossEntropy(cfg); 5201be168c0dSopenharmony_ci- return NodeImpl::Connect(lite_node); 5202be168c0dSopenharmony_ci-} 5203be168c0dSopenharmony_ci-} // namespace NN 5204be168c0dSopenharmony_ci- 5205be168c0dSopenharmony_ci-namespace lite { 5206be168c0dSopenharmony_ci-SoftmaxCrossEntropyM::SoftmaxCrossEntropyM() { 5207be168c0dSopenharmony_ci- auto param = calloc(1, sizeof(OpParameter)); 5208be168c0dSopenharmony_ci- if (param == nullptr) { 5209be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate parameter"; 5210be168c0dSopenharmony_ci- return; 5211be168c0dSopenharmony_ci- } 5212be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 5213be168c0dSopenharmony_ci- SetOpParam(param); 5214be168c0dSopenharmony_ci- set_name("SoftmaxCrossEntropy"); 5215be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_SoftmaxCrossEntropyWithLogits); 5216be168c0dSopenharmony_ci- EXPR e(this); 5217be168c0dSopenharmony_ci- e.SetSize(0); 5218be168c0dSopenharmony_ci- expr_.emplace_back(e); 5219be168c0dSopenharmony_ci-} 5220be168c0dSopenharmony_ci- 5221be168c0dSopenharmony_ci-Node *SoftmaxCrossEntropyM::GetReductionNode(const std::string &mode, const std::vector<int> &axis) { 5222be168c0dSopenharmony_ci- if (mode == "mean") { 5223be168c0dSopenharmony_ci- return NN::ReduceMean(false, axis); 5224be168c0dSopenharmony_ci- } else if (mode == "sum") { 5225be168c0dSopenharmony_ci- return NN::ReduceSum(false, axis); 5226be168c0dSopenharmony_ci- } else { 5227be168c0dSopenharmony_ci- return nullptr; 5228be168c0dSopenharmony_ci- } 5229be168c0dSopenharmony_ci-} 5230be168c0dSopenharmony_ci- 5231be168c0dSopenharmony_ci-SoftmaxCrossEntropyM::SoftmaxCrossEntropyM(const SoftMaxCrossEntropyCfg &cfg) : SoftmaxCrossEntropyM() { 5232be168c0dSopenharmony_ci- std::vector<int> axis = {0}; 5233be168c0dSopenharmony_ci- reduce_ = GetReductionNode(cfg.reduction, axis); 5234be168c0dSopenharmony_ci- if (reduce_ != nullptr) { 5235be168c0dSopenharmony_ci- PushOp(reduce_); 5236be168c0dSopenharmony_ci- } 5237be168c0dSopenharmony_ci-} 5238be168c0dSopenharmony_ci- 5239be168c0dSopenharmony_ci-std::vector<EXPR *> SoftmaxCrossEntropyM::construct(const std::vector<EXPR *> &inputs) { 5240be168c0dSopenharmony_ci- auto y = Node::construct(inputs); 5241be168c0dSopenharmony_ci- if (reduce_ != nullptr) { 5242be168c0dSopenharmony_ci- y = (*reduce_)({y.front()}); 5243be168c0dSopenharmony_ci- } 5244be168c0dSopenharmony_ci- return y; 5245be168c0dSopenharmony_ci-} 5246be168c0dSopenharmony_ci- 5247be168c0dSopenharmony_ci-std::vector<EXPR *> SoftmaxCrossEntropyM::Grad(EXPR *expr) { return {this->expr(1)}; } 5248be168c0dSopenharmony_ci- 5249be168c0dSopenharmony_ci-int SoftmaxCrossEntropyM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 5250be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::SoftmaxCrossEntropyWithLogitsT; 5251be168c0dSopenharmony_ci- if (prim == nullptr) { 5252be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 5253be168c0dSopenharmony_ci- return RET_ERROR; 5254be168c0dSopenharmony_ci- } 5255be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 5256be168c0dSopenharmony_ci- return RET_OK; 5257be168c0dSopenharmony_ci-} 5258be168c0dSopenharmony_ci- 5259be168c0dSopenharmony_ci-namespace NN { 5260be168c0dSopenharmony_ci-Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg) { 5261be168c0dSopenharmony_ci- auto s = new (std::nothrow) SoftmaxCrossEntropyM(cfg); 5262be168c0dSopenharmony_ci- if (s == nullptr) { 5263be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate softmax node"; 5264be168c0dSopenharmony_ci- } 5265be168c0dSopenharmony_ci- return s; 5266be168c0dSopenharmony_ci-} 5267be168c0dSopenharmony_ci-} // namespace NN 5268be168c0dSopenharmony_ci-} // namespace lite 5269be168c0dSopenharmony_ci-} // namespace mindspore 5270be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/softmaxCE.h b/mindspore/lite/src/expression/ops/softmaxCE.h 5271be168c0dSopenharmony_cideleted file mode 100644 5272be168c0dSopenharmony_ciindex 1c0c516d..00000000 5273be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/softmaxCE.h 5274be168c0dSopenharmony_ci+++ /dev/null 5275be168c0dSopenharmony_ci@@ -1,47 +0,0 @@ 5276be168c0dSopenharmony_ci-/** 5277be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5278be168c0dSopenharmony_ci- * 5279be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5280be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5281be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5282be168c0dSopenharmony_ci- * 5283be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5284be168c0dSopenharmony_ci- * 5285be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5286be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5287be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5288be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5289be168c0dSopenharmony_ci- * limitations under the License. 5290be168c0dSopenharmony_ci- */ 5291be168c0dSopenharmony_ci- 5292be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_ 5293be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_ 5294be168c0dSopenharmony_ci- 5295be168c0dSopenharmony_ci-#include <vector> 5296be168c0dSopenharmony_ci-#include <memory> 5297be168c0dSopenharmony_ci-#include <string> 5298be168c0dSopenharmony_ci-#include "src/expression/node.h" 5299be168c0dSopenharmony_ci-#include "inner/model_generated.h" 5300be168c0dSopenharmony_ci-#include "include/api/net.h" 5301be168c0dSopenharmony_ci- 5302be168c0dSopenharmony_ci-namespace mindspore { 5303be168c0dSopenharmony_ci-namespace lite { 5304be168c0dSopenharmony_ci-class SoftmaxCrossEntropyM : public Node { 5305be168c0dSopenharmony_ci- public: 5306be168c0dSopenharmony_ci- SoftmaxCrossEntropyM(); 5307be168c0dSopenharmony_ci- explicit SoftmaxCrossEntropyM(const SoftMaxCrossEntropyCfg &cfg); 5308be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *expr) override; 5309be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 5310be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override; 5311be168c0dSopenharmony_ci- 5312be168c0dSopenharmony_ci- private: 5313be168c0dSopenharmony_ci- Node *GetReductionNode(const std::string &mode, const std::vector<int> &axis); 5314be168c0dSopenharmony_ci- Node *reduce_ = nullptr; 5315be168c0dSopenharmony_ci-}; 5316be168c0dSopenharmony_ci- 5317be168c0dSopenharmony_ci-namespace NN { 5318be168c0dSopenharmony_ci-Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg); 5319be168c0dSopenharmony_ci-} // namespace NN 5320be168c0dSopenharmony_ci-} // namespace lite 5321be168c0dSopenharmony_ci-} // namespace mindspore 5322be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_SOFTMAXCE_H_ 5323be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/tile.cc b/mindspore/lite/src/expression/ops/tile.cc 5324be168c0dSopenharmony_cideleted file mode 100644 5325be168c0dSopenharmony_ciindex 4008da12..00000000 5326be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/tile.cc 5327be168c0dSopenharmony_ci+++ /dev/null 5328be168c0dSopenharmony_ci@@ -1,62 +0,0 @@ 5329be168c0dSopenharmony_ci-/** 5330be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5331be168c0dSopenharmony_ci- * 5332be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5333be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5334be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5335be168c0dSopenharmony_ci- * 5336be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5337be168c0dSopenharmony_ci- * 5338be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5339be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5340be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5341be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5342be168c0dSopenharmony_ci- * limitations under the License. 5343be168c0dSopenharmony_ci- */ 5344be168c0dSopenharmony_ci- 5345be168c0dSopenharmony_ci-#include "src/expression/ops/tile.h" 5346be168c0dSopenharmony_ci-#include <memory> 5347be168c0dSopenharmony_ci-#include "src/expression/ops.h" 5348be168c0dSopenharmony_ci-#include "nnacl/base/tile_base.h" 5349be168c0dSopenharmony_ci- 5350be168c0dSopenharmony_ci-namespace mindspore { 5351be168c0dSopenharmony_ci-namespace lite { 5352be168c0dSopenharmony_ci-TileM::TileM(const std::vector<int> &multiples) : Node() { 5353be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 5354be168c0dSopenharmony_ci- TileParameter *param = reinterpret_cast<TileParameter *>(calloc(1, sizeof(TileParameter))); 5355be168c0dSopenharmony_ci- if (param == nullptr) { 5356be168c0dSopenharmony_ci- MS_LOG(ERROR) << " cannot allocate ConvParameter"; 5357be168c0dSopenharmony_ci- return; 5358be168c0dSopenharmony_ci- } 5359be168c0dSopenharmony_ci- SetOpParam(param); 5360be168c0dSopenharmony_ci- set_name(UniqueName("Tile")); 5361be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_TileFusion); 5362be168c0dSopenharmony_ci- Node::CreateConstTensor(C1NUM, {static_cast<int32_t>(multiples.size())}, kNumberTypeInt32, KHWC, "axis", 5363be168c0dSopenharmony_ci- multiples.data()); 5364be168c0dSopenharmony_ci-} 5365be168c0dSopenharmony_ci- 5366be168c0dSopenharmony_ci-int TileM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 5367be168c0dSopenharmony_ci- auto tile_param = reinterpret_cast<const TileParameter *>(OpParam()); 5368be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::TileFusionT; 5369be168c0dSopenharmony_ci- if (prim == nullptr) { 5370be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 5371be168c0dSopenharmony_ci- return RET_ERROR; 5372be168c0dSopenharmony_ci- } 5373be168c0dSopenharmony_ci- for (size_t i = 0; i < tile_param->dims_size_; i++) { 5374be168c0dSopenharmony_ci- prim->dims.push_back(tile_param->dims_[i]); 5375be168c0dSopenharmony_ci- } 5376be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 5377be168c0dSopenharmony_ci- return RET_OK; 5378be168c0dSopenharmony_ci-} 5379be168c0dSopenharmony_ci- 5380be168c0dSopenharmony_ci-namespace NN { 5381be168c0dSopenharmony_ci-Node *Tile(const std::vector<int> &multiples) { 5382be168c0dSopenharmony_ci- auto node = new (std::nothrow) TileM(multiples); 5383be168c0dSopenharmony_ci- if (node == nullptr) { 5384be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate tile node"; 5385be168c0dSopenharmony_ci- } 5386be168c0dSopenharmony_ci- return node; 5387be168c0dSopenharmony_ci-} 5388be168c0dSopenharmony_ci-} // namespace NN 5389be168c0dSopenharmony_ci-} // namespace lite 5390be168c0dSopenharmony_ci-} // namespace mindspore 5391be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/tile.h b/mindspore/lite/src/expression/ops/tile.h 5392be168c0dSopenharmony_cideleted file mode 100644 5393be168c0dSopenharmony_ciindex 0d793c0e..00000000 5394be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/tile.h 5395be168c0dSopenharmony_ci+++ /dev/null 5396be168c0dSopenharmony_ci@@ -1,40 +0,0 @@ 5397be168c0dSopenharmony_ci-/** 5398be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5399be168c0dSopenharmony_ci- * 5400be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5401be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5402be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5403be168c0dSopenharmony_ci- * 5404be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5405be168c0dSopenharmony_ci- * 5406be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5407be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5408be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5409be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5410be168c0dSopenharmony_ci- * limitations under the License. 5411be168c0dSopenharmony_ci- */ 5412be168c0dSopenharmony_ci- 5413be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_ 5414be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_ 5415be168c0dSopenharmony_ci- 5416be168c0dSopenharmony_ci-#include <vector> 5417be168c0dSopenharmony_ci-#include <memory> 5418be168c0dSopenharmony_ci-#include "src/expression/node.h" 5419be168c0dSopenharmony_ci-#include "inner/model_generated.h" 5420be168c0dSopenharmony_ci- 5421be168c0dSopenharmony_ci-namespace mindspore { 5422be168c0dSopenharmony_ci-namespace lite { 5423be168c0dSopenharmony_ci-class TileM : public Node { 5424be168c0dSopenharmony_ci- public: 5425be168c0dSopenharmony_ci- TileM() = default; 5426be168c0dSopenharmony_ci- explicit TileM(const std::vector<int> &multiples); 5427be168c0dSopenharmony_ci- Param *weight() override { return input(1)->node()->data(); } 5428be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 5429be168c0dSopenharmony_ci-}; 5430be168c0dSopenharmony_ci- 5431be168c0dSopenharmony_ci-namespace NN { 5432be168c0dSopenharmony_ci-Node *Tile(const std::vector<int> &multiples); 5433be168c0dSopenharmony_ci-} 5434be168c0dSopenharmony_ci-} // namespace lite 5435be168c0dSopenharmony_ci-} // namespace mindspore 5436be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_TILE_H_ 5437be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/transpose.cc b/mindspore/lite/src/expression/ops/transpose.cc 5438be168c0dSopenharmony_cideleted file mode 100644 5439be168c0dSopenharmony_ciindex fbe2b14d..00000000 5440be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/transpose.cc 5441be168c0dSopenharmony_ci+++ /dev/null 5442be168c0dSopenharmony_ci@@ -1,88 +0,0 @@ 5443be168c0dSopenharmony_ci-/** 5444be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5445be168c0dSopenharmony_ci- * 5446be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5447be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5448be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5449be168c0dSopenharmony_ci- * 5450be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5451be168c0dSopenharmony_ci- * 5452be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5453be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5454be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5455be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5456be168c0dSopenharmony_ci- * limitations under the License. 5457be168c0dSopenharmony_ci- */ 5458be168c0dSopenharmony_ci- 5459be168c0dSopenharmony_ci-#include "src/expression/ops/transpose.h" 5460be168c0dSopenharmony_ci-#include <memory> 5461be168c0dSopenharmony_ci-#include "nnacl/transpose_parameter.h" 5462be168c0dSopenharmony_ci-#include "inner/model_generated.h" 5463be168c0dSopenharmony_ci-#include "src/expression/import.h" 5464be168c0dSopenharmony_ci- 5465be168c0dSopenharmony_ci-namespace mindspore { 5466be168c0dSopenharmony_ci-namespace lite { 5467be168c0dSopenharmony_ci-TransposeM::TransposeM(const std::vector<int> &vector) { 5468be168c0dSopenharmony_ci- auto param = calloc(1, sizeof(TransposeParameter)); 5469be168c0dSopenharmony_ci- if (param == nullptr) { 5470be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate transpose parameter"; 5471be168c0dSopenharmony_ci- return; 5472be168c0dSopenharmony_ci- } 5473be168c0dSopenharmony_ci- SetOpParam(param); 5474be168c0dSopenharmony_ci- expr()->SetSize(C2NUM); 5475be168c0dSopenharmony_ci- set_primitive(schema::PrimitiveType_Transpose); 5476be168c0dSopenharmony_ci- std::vector<int> dims = {static_cast<int>(vector.size())}; 5477be168c0dSopenharmony_ci- set_name(UniqueName("Transpose")); 5478be168c0dSopenharmony_ci- CreateConstTensor(C1NUM, dims, kNumberTypeInt32, KHWC, "axis", vector.data()); 5479be168c0dSopenharmony_ci-} 5480be168c0dSopenharmony_ci- 5481be168c0dSopenharmony_ci-std::vector<int> TransposeM::Invert(const std::vector<int> &vector) { 5482be168c0dSopenharmony_ci- std::vector<int> res; 5483be168c0dSopenharmony_ci- for (size_t i = 0; i < vector.size(); i++) { 5484be168c0dSopenharmony_ci- int idx = static_cast<int>(i); 5485be168c0dSopenharmony_ci- auto val = std::find_if(vector.begin(), vector.end(), [idx](int x) { return (x == idx) ? true : false; }); 5486be168c0dSopenharmony_ci- if (val == vector.end()) { 5487be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Wrong index for " << idx; 5488be168c0dSopenharmony_ci- return {}; 5489be168c0dSopenharmony_ci- } 5490be168c0dSopenharmony_ci- res.push_back(std::distance(vector.begin(), val)); 5491be168c0dSopenharmony_ci- } 5492be168c0dSopenharmony_ci- return res; 5493be168c0dSopenharmony_ci-} 5494be168c0dSopenharmony_ci- 5495be168c0dSopenharmony_ci-std::vector<EXPR *> TransposeM::Grad(EXPR *yt) { 5496be168c0dSopenharmony_ci- auto tensor = input(1)->node(); 5497be168c0dSopenharmony_ci- auto data = tensor->data(); 5498be168c0dSopenharmony_ci- auto vec = data->Extract<int>(); 5499be168c0dSopenharmony_ci- auto invert = Invert(vec); 5500be168c0dSopenharmony_ci- auto tran = new (std::nothrow) TransposeM(invert); 5501be168c0dSopenharmony_ci- if (tran == nullptr) { 5502be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate transpose grad"; 5503be168c0dSopenharmony_ci- return {}; 5504be168c0dSopenharmony_ci- } 5505be168c0dSopenharmony_ci- tran->set_name(kGradName + "/" + name() + "/" + tran->name()); 5506be168c0dSopenharmony_ci- PushOp(tran); 5507be168c0dSopenharmony_ci- auto grad = (*tran)({yt}); 5508be168c0dSopenharmony_ci- return grad; 5509be168c0dSopenharmony_ci-} 5510be168c0dSopenharmony_ci- 5511be168c0dSopenharmony_ci-int TransposeM::UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) { 5512be168c0dSopenharmony_ci- auto prim = new (std::nothrow) schema::TransposeT; 5513be168c0dSopenharmony_ci- if (prim == nullptr) { 5514be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate primitive"; 5515be168c0dSopenharmony_ci- return RET_ERROR; 5516be168c0dSopenharmony_ci- } 5517be168c0dSopenharmony_ci- cnode->primitive->value.value = prim; 5518be168c0dSopenharmony_ci- return RET_OK; 5519be168c0dSopenharmony_ci-} 5520be168c0dSopenharmony_ci- 5521be168c0dSopenharmony_ci-static ImportReg reg(schema::PrimitiveType_Transpose, ReturnNode<TransposeM>); 5522be168c0dSopenharmony_ci- 5523be168c0dSopenharmony_ci-namespace NN { 5524be168c0dSopenharmony_ci-Node *Transpose(const std::vector<int> &permute) { 5525be168c0dSopenharmony_ci- auto node = new (std::nothrow) TransposeM(permute); 5526be168c0dSopenharmony_ci- return node; 5527be168c0dSopenharmony_ci-} 5528be168c0dSopenharmony_ci-} // namespace NN 5529be168c0dSopenharmony_ci-} // namespace lite 5530be168c0dSopenharmony_ci-} // namespace mindspore 5531be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/transpose.h b/mindspore/lite/src/expression/ops/transpose.h 5532be168c0dSopenharmony_cideleted file mode 100644 5533be168c0dSopenharmony_ciindex d4c9a1c1..00000000 5534be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/transpose.h 5535be168c0dSopenharmony_ci+++ /dev/null 5536be168c0dSopenharmony_ci@@ -1,59 +0,0 @@ 5537be168c0dSopenharmony_ci-/** 5538be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5539be168c0dSopenharmony_ci- * 5540be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5541be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5542be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5543be168c0dSopenharmony_ci- * 5544be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5545be168c0dSopenharmony_ci- * 5546be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5547be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5548be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5549be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5550be168c0dSopenharmony_ci- * limitations under the License. 5551be168c0dSopenharmony_ci- */ 5552be168c0dSopenharmony_ci- 5553be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_ 5554be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_ 5555be168c0dSopenharmony_ci- 5556be168c0dSopenharmony_ci-#include <vector> 5557be168c0dSopenharmony_ci-#include <memory> 5558be168c0dSopenharmony_ci-#include "src/expression/node.h" 5559be168c0dSopenharmony_ci-#include "inner/model_generated.h" 5560be168c0dSopenharmony_ci- 5561be168c0dSopenharmony_ci-namespace mindspore { 5562be168c0dSopenharmony_ci-namespace lite { 5563be168c0dSopenharmony_ci-class TransposeM : public Node { 5564be168c0dSopenharmony_ci- public: 5565be168c0dSopenharmony_ci- TransposeM() = default; 5566be168c0dSopenharmony_ci- explicit TransposeM(const std::vector<int> &vector); 5567be168c0dSopenharmony_ci- static EXPR *TransposeCHW2HWC(EXPR *in) { 5568be168c0dSopenharmony_ci- std::vector<int> res = {0, 2, 3, 1}; 5569be168c0dSopenharmony_ci- auto trans = new (std::nothrow) TransposeM(res); 5570be168c0dSopenharmony_ci- if (trans == nullptr) { 5571be168c0dSopenharmony_ci- return nullptr; 5572be168c0dSopenharmony_ci- } 5573be168c0dSopenharmony_ci- return (*trans)({in}).front(); 5574be168c0dSopenharmony_ci- } 5575be168c0dSopenharmony_ci- static EXPR *TransposeHWC2CHW(EXPR *in) { 5576be168c0dSopenharmony_ci- std::vector<int> res = {0, 3, 1, 2}; 5577be168c0dSopenharmony_ci- auto trans = new (std::nothrow) TransposeM(res); 5578be168c0dSopenharmony_ci- if (trans == nullptr) { 5579be168c0dSopenharmony_ci- return nullptr; 5580be168c0dSopenharmony_ci- } 5581be168c0dSopenharmony_ci- return (*trans)({in}).front(); 5582be168c0dSopenharmony_ci- } 5583be168c0dSopenharmony_ci- int UnPopulate(const std::unique_ptr<schema::CNodeT> &cnode) override; 5584be168c0dSopenharmony_ci- std::vector<EXPR *> Grad(EXPR *yt) override; 5585be168c0dSopenharmony_ci- 5586be168c0dSopenharmony_ci- private: 5587be168c0dSopenharmony_ci- std::vector<int> Invert(const std::vector<int> &vec); 5588be168c0dSopenharmony_ci-}; 5589be168c0dSopenharmony_ci- 5590be168c0dSopenharmony_ci-namespace NN { 5591be168c0dSopenharmony_ci-Node *Transpose(const std::vector<int> &permute); 5592be168c0dSopenharmony_ci-} // namespace NN 5593be168c0dSopenharmony_ci-} // namespace lite 5594be168c0dSopenharmony_ci-} // namespace mindspore 5595be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_TRANSPOSE_H_ 5596be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops_utils.cc b/mindspore/lite/src/expression/ops_utils.cc 5597be168c0dSopenharmony_cideleted file mode 100644 5598be168c0dSopenharmony_ciindex e63b903d..00000000 5599be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops_utils.cc 5600be168c0dSopenharmony_ci+++ /dev/null 5601be168c0dSopenharmony_ci@@ -1,275 +0,0 @@ 5602be168c0dSopenharmony_ci-/** 5603be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5604be168c0dSopenharmony_ci- * 5605be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5606be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5607be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5608be168c0dSopenharmony_ci- * 5609be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5610be168c0dSopenharmony_ci- * 5611be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5612be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5613be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5614be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5615be168c0dSopenharmony_ci- * limitations under the License. 5616be168c0dSopenharmony_ci- */ 5617be168c0dSopenharmony_ci- 5618be168c0dSopenharmony_ci-#include "src/expression/ops_utils.h" 5619be168c0dSopenharmony_ci-#include <set> 5620be168c0dSopenharmony_ci-#include <algorithm> 5621be168c0dSopenharmony_ci- 5622be168c0dSopenharmony_ci-namespace mindspore { 5623be168c0dSopenharmony_ci-namespace lite { 5624be168c0dSopenharmony_ci-enum class State { 5625be168c0dSopenharmony_ci- SAME, 5626be168c0dSopenharmony_ci- X_ONE, 5627be168c0dSopenharmony_ci- Y_ONE, 5628be168c0dSopenharmony_ci-}; 5629be168c0dSopenharmony_ci- 5630be168c0dSopenharmony_ci-bool CompareShape(const std::vector<int> &x_shape, const std::vector<int> &y_shape) { 5631be168c0dSopenharmony_ci- if (x_shape.size() != y_shape.size()) { 5632be168c0dSopenharmony_ci- return false; 5633be168c0dSopenharmony_ci- } 5634be168c0dSopenharmony_ci- 5635be168c0dSopenharmony_ci- for (size_t i = 0; i < x_shape.size(); ++i) { 5636be168c0dSopenharmony_ci- if (x_shape.at(i) != y_shape.at(i)) { 5637be168c0dSopenharmony_ci- return false; 5638be168c0dSopenharmony_ci- } 5639be168c0dSopenharmony_ci- } 5640be168c0dSopenharmony_ci- 5641be168c0dSopenharmony_ci- return true; 5642be168c0dSopenharmony_ci-} 5643be168c0dSopenharmony_ci- 5644be168c0dSopenharmony_ci-void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y, 5645be168c0dSopenharmony_ci- std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) { 5646be168c0dSopenharmony_ci- MS_ASSERT(grad_x_reduce_idx != nullptr); 5647be168c0dSopenharmony_ci- MS_ASSERT(grad_y_reduce_idy != nullptr); 5648be168c0dSopenharmony_ci- const size_t n = reverse_x.size(); 5649be168c0dSopenharmony_ci- if (reverse_y.size() < n) { 5650be168c0dSopenharmony_ci- MS_LOG_ERROR << "The size of reverse_y is less than the size of reverse_x."; 5651be168c0dSopenharmony_ci- } 5652be168c0dSopenharmony_ci- for (size_t i = 0; i < n; ++i) { 5653be168c0dSopenharmony_ci- State curr = State::SAME; 5654be168c0dSopenharmony_ci- const int x_i = reverse_x[i]; 5655be168c0dSopenharmony_ci- const int y_i = reverse_y[i]; 5656be168c0dSopenharmony_ci- const int reduce_idx = (n - 1 - i); 5657be168c0dSopenharmony_ci- if (x_i == y_i) { 5658be168c0dSopenharmony_ci- curr = State::SAME; 5659be168c0dSopenharmony_ci- } else if (x_i == 1) { 5660be168c0dSopenharmony_ci- grad_x_reduce_idx->push_back(reduce_idx); 5661be168c0dSopenharmony_ci- curr = State::X_ONE; 5662be168c0dSopenharmony_ci- } else if (y_i == 1) { 5663be168c0dSopenharmony_ci- grad_y_reduce_idy->push_back(reduce_idx); 5664be168c0dSopenharmony_ci- curr = State::Y_ONE; 5665be168c0dSopenharmony_ci- } else { 5666be168c0dSopenharmony_ci- MS_LOG_ERROR << "not compatible shape input for BroadcastGradientArgs"; 5667be168c0dSopenharmony_ci- } 5668be168c0dSopenharmony_ci- if (curr == State::SAME && x_i == 1) { 5669be168c0dSopenharmony_ci- grad_x_reduce_idx->push_back(reduce_idx); 5670be168c0dSopenharmony_ci- grad_y_reduce_idy->push_back(reduce_idx); 5671be168c0dSopenharmony_ci- continue; 5672be168c0dSopenharmony_ci- } 5673be168c0dSopenharmony_ci- } 5674be168c0dSopenharmony_ci- 5675be168c0dSopenharmony_ci- std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); 5676be168c0dSopenharmony_ci- std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); 5677be168c0dSopenharmony_ci-} 5678be168c0dSopenharmony_ci- 5679be168c0dSopenharmony_ci-std::vector<std::vector<int>> BroadcastGradientArgs::operator()() { 5680be168c0dSopenharmony_ci- std::vector<std::vector<int>> input_dim(kInNum); 5681be168c0dSopenharmony_ci- input_dim[0] = dim0_; 5682be168c0dSopenharmony_ci- input_dim[1] = dim1_; 5683be168c0dSopenharmony_ci- auto same_shape = CompareShape(dim0_, dim1_); 5684be168c0dSopenharmony_ci- if (same_shape) { 5685be168c0dSopenharmony_ci- return {{}, {}}; 5686be168c0dSopenharmony_ci- } 5687be168c0dSopenharmony_ci- 5688be168c0dSopenharmony_ci- std::vector<int> reverse_x; 5689be168c0dSopenharmony_ci- std::vector<int> reverse_y; 5690be168c0dSopenharmony_ci- 5691be168c0dSopenharmony_ci- (void)std::transform(dim0_.rbegin(), dim0_.rend(), std::back_inserter(reverse_x), [](const int &v) { return v; }); 5692be168c0dSopenharmony_ci- (void)std::transform(dim1_.rbegin(), dim1_.rend(), std::back_inserter(reverse_y), [](const int &v) { return v; }); 5693be168c0dSopenharmony_ci- 5694be168c0dSopenharmony_ci- if (reverse_x.size() > reverse_y.size()) { 5695be168c0dSopenharmony_ci- reverse_y.resize(reverse_x.size(), 1); 5696be168c0dSopenharmony_ci- } else { 5697be168c0dSopenharmony_ci- reverse_x.resize(reverse_y.size(), 1); 5698be168c0dSopenharmony_ci- } 5699be168c0dSopenharmony_ci- 5700be168c0dSopenharmony_ci- std::vector<int> grad_x_reduce_idx; 5701be168c0dSopenharmony_ci- std::vector<int> grad_y_reduce_idy; 5702be168c0dSopenharmony_ci- ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); 5703be168c0dSopenharmony_ci- return {grad_x_reduce_idx, grad_y_reduce_idy}; 5704be168c0dSopenharmony_ci-} 5705be168c0dSopenharmony_ci- 5706be168c0dSopenharmony_ci-void DynamicBroadcastGradientArgs::AddElementToGradReduceIdx(std::vector<std::vector<int>> *grad_reduce_idx, 5707be168c0dSopenharmony_ci- std::vector<bool> current_is_one, bool none_is_one, 5708be168c0dSopenharmony_ci- const size_t largest_rank, size_t j) { 5709be168c0dSopenharmony_ci- for (size_t i = 0; i < kInNum; ++i) { 5710be168c0dSopenharmony_ci- if (current_is_one[i] && !none_is_one) { 5711be168c0dSopenharmony_ci- (void)(*grad_reduce_idx)[i].emplace_back(largest_rank - 1 - j); 5712be168c0dSopenharmony_ci- } 5713be168c0dSopenharmony_ci- } 5714be168c0dSopenharmony_ci-} 5715be168c0dSopenharmony_ci- 5716be168c0dSopenharmony_ci-void DynamicBroadcastGradientArgs::UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one) { 5717be168c0dSopenharmony_ci- for (size_t i = 0; i < kInNum; ++i) { 5718be168c0dSopenharmony_ci- (*prev_is_one)[i] = current_is_one[i]; 5719be168c0dSopenharmony_ci- } 5720be168c0dSopenharmony_ci-} 5721be168c0dSopenharmony_ci- 5722be168c0dSopenharmony_ci-std::vector<std::vector<int>> DynamicBroadcastGradientArgs::GetGradientIndices( 5723be168c0dSopenharmony_ci- const std::vector<std::vector<int>> &reverse_shape, const size_t largest_rank) { 5724be168c0dSopenharmony_ci- std::vector<std::vector<int>> grad_reduce_idx(kInNum); 5725be168c0dSopenharmony_ci- // indices of j-th component of each input. 5726be168c0dSopenharmony_ci- std::vector<bool> prev_is_one(kInNum); 5727be168c0dSopenharmony_ci- std::vector<bool> current_is_one(kInNum); 5728be168c0dSopenharmony_ci- for (size_t i = 0; i < kInNum; ++i) { 5729be168c0dSopenharmony_ci- prev_is_one[i] = false; 5730be168c0dSopenharmony_ci- current_is_one[i] = false; 5731be168c0dSopenharmony_ci- } 5732be168c0dSopenharmony_ci- 5733be168c0dSopenharmony_ci- bool set_one = false; 5734be168c0dSopenharmony_ci- for (size_t j = 0; j < largest_rank; ++j) { 5735be168c0dSopenharmony_ci- int output_dim = -1; 5736be168c0dSopenharmony_ci- bool output_dim_set = false; 5737be168c0dSopenharmony_ci- bool none_is_one = true; 5738be168c0dSopenharmony_ci- // Find which indices are 1. 5739be168c0dSopenharmony_ci- for (size_t i = 0; i < kInNum; ++i) { 5740be168c0dSopenharmony_ci- if (reverse_shape[i][j] == 1) { 5741be168c0dSopenharmony_ci- current_is_one[i] = true; 5742be168c0dSopenharmony_ci- none_is_one = false; 5743be168c0dSopenharmony_ci- } else { 5744be168c0dSopenharmony_ci- current_is_one[i] = false; 5745be168c0dSopenharmony_ci- if (!output_dim_set || reverse_shape[i][j] == static_cast<int>(output_dim)) { 5746be168c0dSopenharmony_ci- output_dim = reverse_shape[i][j]; 5747be168c0dSopenharmony_ci- output_dim_set = true; 5748be168c0dSopenharmony_ci- } else { 5749be168c0dSopenharmony_ci- std::cout << "Input[0] and input[1] Cannot broadcast!"; 5750be168c0dSopenharmony_ci- } 5751be168c0dSopenharmony_ci- } 5752be168c0dSopenharmony_ci- } 5753be168c0dSopenharmony_ci- // All dimensions are 1. 5754be168c0dSopenharmony_ci- if (!output_dim_set) { 5755be168c0dSopenharmony_ci- for (size_t i = 0; i < kInNum; ++i) { 5756be168c0dSopenharmony_ci- (void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j); 5757be168c0dSopenharmony_ci- } 5758be168c0dSopenharmony_ci- continue; 5759be168c0dSopenharmony_ci- } else if (std::equal(current_is_one.begin(), current_is_one.end(), prev_is_one.begin()) && set_one) { 5760be168c0dSopenharmony_ci- AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j); 5761be168c0dSopenharmony_ci- } else { 5762be168c0dSopenharmony_ci- AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j); 5763be168c0dSopenharmony_ci- } 5764be168c0dSopenharmony_ci- set_one = true; 5765be168c0dSopenharmony_ci- UpdatePreIsOne(&prev_is_one, current_is_one); 5766be168c0dSopenharmony_ci- } 5767be168c0dSopenharmony_ci- return grad_reduce_idx; 5768be168c0dSopenharmony_ci-} 5769be168c0dSopenharmony_ci- 5770be168c0dSopenharmony_ci-std::vector<std::vector<int>> DynamicBroadcastGradientArgs::CalculateOutput(const std::vector<std::vector<int>> &x) { 5771be168c0dSopenharmony_ci- std::vector<std::vector<int>> grad_reduce_idx(kInNum); 5772be168c0dSopenharmony_ci- bool all_equal = true; 5773be168c0dSopenharmony_ci- size_t largest_rank = 0; 5774be168c0dSopenharmony_ci- for (size_t i = 0; i < kInNum; ++i) { 5775be168c0dSopenharmony_ci- if (x[i] != x[0]) { 5776be168c0dSopenharmony_ci- all_equal = false; 5777be168c0dSopenharmony_ci- } 5778be168c0dSopenharmony_ci- if (x[i].size() > largest_rank) { 5779be168c0dSopenharmony_ci- largest_rank = x[i].size(); 5780be168c0dSopenharmony_ci- } 5781be168c0dSopenharmony_ci- } 5782be168c0dSopenharmony_ci- if (all_equal) { 5783be168c0dSopenharmony_ci- return grad_reduce_idx; 5784be168c0dSopenharmony_ci- } 5785be168c0dSopenharmony_ci- 5786be168c0dSopenharmony_ci- // Reverse input the shapes 5787be168c0dSopenharmony_ci- std::vector<std::vector<int>> reverse_shape(kInNum); 5788be168c0dSopenharmony_ci- for (size_t i = 0; i < kInNum; ++i) { 5789be168c0dSopenharmony_ci- reverse_shape[i] = x[i]; 5790be168c0dSopenharmony_ci- std::reverse(reverse_shape[i].begin(), reverse_shape[i].end()); 5791be168c0dSopenharmony_ci- } 5792be168c0dSopenharmony_ci- 5793be168c0dSopenharmony_ci- // 1-extend and align all vectors. 5794be168c0dSopenharmony_ci- for (size_t i = 0; i < kInNum; ++i) { 5795be168c0dSopenharmony_ci- if (reverse_shape[i].size() < largest_rank) { 5796be168c0dSopenharmony_ci- reverse_shape[i].resize(largest_rank, 1); 5797be168c0dSopenharmony_ci- } 5798be168c0dSopenharmony_ci- } 5799be168c0dSopenharmony_ci- grad_reduce_idx = GetGradientIndices(reverse_shape, largest_rank); 5800be168c0dSopenharmony_ci- return grad_reduce_idx; 5801be168c0dSopenharmony_ci-} 5802be168c0dSopenharmony_ci- 5803be168c0dSopenharmony_ci-std::vector<std::vector<int>> DynamicBroadcastGradientArgs::SetOutputValue( 5804be168c0dSopenharmony_ci- const std::vector<std::vector<int>> &grad_reduce_idx, const std::vector<std::vector<int>> &input_dim) { 5805be168c0dSopenharmony_ci- std::vector<std::vector<int>> output(kInNum); 5806be168c0dSopenharmony_ci- for (size_t index = 0; index < kInNum; ++index) { 5807be168c0dSopenharmony_ci- auto idx_num = grad_reduce_idx[index].size(); 5808be168c0dSopenharmony_ci- for (size_t k = 0; k < idx_num; ++k) { 5809be168c0dSopenharmony_ci- output[index].push_back(grad_reduce_idx[index][idx_num - 1 - k]); 5810be168c0dSopenharmony_ci- } 5811be168c0dSopenharmony_ci- if (idx_num == 0) { 5812be168c0dSopenharmony_ci- auto input_num = input_dim[index].size(); 5813be168c0dSopenharmony_ci- for (size_t k = 0; k < input_num; ++k) { 5814be168c0dSopenharmony_ci- output[index].push_back(k); 5815be168c0dSopenharmony_ci- } 5816be168c0dSopenharmony_ci- } 5817be168c0dSopenharmony_ci- } 5818be168c0dSopenharmony_ci- return output; 5819be168c0dSopenharmony_ci-} 5820be168c0dSopenharmony_ci- 5821be168c0dSopenharmony_ci-std::vector<std::vector<int>> DynamicBroadcastGradientArgs::operator()() { 5822be168c0dSopenharmony_ci- std::vector<std::vector<int>> input_dim(kInNum); 5823be168c0dSopenharmony_ci- input_dim[0] = dim0_; 5824be168c0dSopenharmony_ci- input_dim[1] = dim1_; 5825be168c0dSopenharmony_ci- auto grad_reduce_idx = CalculateOutput(input_dim); 5826be168c0dSopenharmony_ci- auto output = SetOutputValue(grad_reduce_idx, input_dim); 5827be168c0dSopenharmony_ci- return output; 5828be168c0dSopenharmony_ci-} 5829be168c0dSopenharmony_ci- 5830be168c0dSopenharmony_ci-std::vector<int> VectorDiv::operator()(const std::vector<int> &x, const std::vector<int> &d) { 5831be168c0dSopenharmony_ci- if (d.size() != x.size()) { 5832be168c0dSopenharmony_ci- MS_LOG(ERROR) << "x and divider must have same size"; 5833be168c0dSopenharmony_ci- return {}; 5834be168c0dSopenharmony_ci- } 5835be168c0dSopenharmony_ci- std::vector<int> res; 5836be168c0dSopenharmony_ci- for (size_t i = 0; i < d.size(); i++) { 5837be168c0dSopenharmony_ci- auto x_value = x.at(i); 5838be168c0dSopenharmony_ci- auto d_value = d.at(i); 5839be168c0dSopenharmony_ci- if (d_value == 0) { 5840be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Divisor is zero"; 5841be168c0dSopenharmony_ci- return {}; 5842be168c0dSopenharmony_ci- } 5843be168c0dSopenharmony_ci- if ((x_value % d_value) != 0) { 5844be168c0dSopenharmony_ci- MS_LOG(ERROR) << "x and d and not dividable"; 5845be168c0dSopenharmony_ci- } 5846be168c0dSopenharmony_ci- auto r = x_value / d_value; 5847be168c0dSopenharmony_ci- res.push_back(r); 5848be168c0dSopenharmony_ci- } 5849be168c0dSopenharmony_ci- return res; 5850be168c0dSopenharmony_ci-} 5851be168c0dSopenharmony_ci- 5852be168c0dSopenharmony_ci-std::vector<int> ShapeReduce::operator()(const std::vector<int> &x_shape, const std::vector<int> &axis) { 5853be168c0dSopenharmony_ci- int x_rank = x_shape.size(); 5854be168c0dSopenharmony_ci- std::set<int> axis_set; 5855be168c0dSopenharmony_ci- 5856be168c0dSopenharmony_ci- auto min = -x_rank; 5857be168c0dSopenharmony_ci- auto max = x_rank - 1; 5858be168c0dSopenharmony_ci- for (auto &elem : axis) { 5859be168c0dSopenharmony_ci- if (elem > max || elem < min) { 5860be168c0dSopenharmony_ci- MS_LOG(ERROR) << "illegal axis value"; 5861be168c0dSopenharmony_ci- return {}; 5862be168c0dSopenharmony_ci- } 5863be168c0dSopenharmony_ci- axis_set.insert(elem); 5864be168c0dSopenharmony_ci- } 5865be168c0dSopenharmony_ci- std::vector<int> res; 5866be168c0dSopenharmony_ci- for (int i = 0; i < x_rank; i++) { 5867be168c0dSopenharmony_ci- if (axis_set.count(i) || axis_set.count(i - x_rank)) { 5868be168c0dSopenharmony_ci- res.push_back(1); 5869be168c0dSopenharmony_ci- } else { 5870be168c0dSopenharmony_ci- res.push_back(x_shape.at(i)); 5871be168c0dSopenharmony_ci- } 5872be168c0dSopenharmony_ci- } 5873be168c0dSopenharmony_ci- return res; 5874be168c0dSopenharmony_ci-} 5875be168c0dSopenharmony_ci-} // namespace lite 5876be168c0dSopenharmony_ci-} // namespace mindspore 5877be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops_utils.h b/mindspore/lite/src/expression/ops_utils.h 5878be168c0dSopenharmony_cideleted file mode 100644 5879be168c0dSopenharmony_ciindex 6c62de11..00000000 5880be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops_utils.h 5881be168c0dSopenharmony_ci+++ /dev/null 5882be168c0dSopenharmony_ci@@ -1,69 +0,0 @@ 5883be168c0dSopenharmony_ci-/** 5884be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5885be168c0dSopenharmony_ci- * 5886be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5887be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5888be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5889be168c0dSopenharmony_ci- * 5890be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5891be168c0dSopenharmony_ci- * 5892be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5893be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5894be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5895be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5896be168c0dSopenharmony_ci- * limitations under the License. 5897be168c0dSopenharmony_ci- */ 5898be168c0dSopenharmony_ci- 5899be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_ 5900be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_ 5901be168c0dSopenharmony_ci- 5902be168c0dSopenharmony_ci-#include "include/api/cfg.h" 5903be168c0dSopenharmony_ci-#include "src/expression/net.h" 5904be168c0dSopenharmony_ci-#include "vector" 5905be168c0dSopenharmony_ci- 5906be168c0dSopenharmony_ci-namespace mindspore { 5907be168c0dSopenharmony_ci-namespace lite { 5908be168c0dSopenharmony_ci-class BroadcastGradientArgs { 5909be168c0dSopenharmony_ci- public: 5910be168c0dSopenharmony_ci- BroadcastGradientArgs(const std::vector<int> &dim0, const std::vector<int> &dim1) : dim0_(dim0), dim1_(dim1) {} 5911be168c0dSopenharmony_ci- std::vector<std::vector<int>> operator()(); 5912be168c0dSopenharmony_ci- 5913be168c0dSopenharmony_ci- private: 5914be168c0dSopenharmony_ci- static const int kInNum = 2; 5915be168c0dSopenharmony_ci- const std::vector<int> &dim0_; 5916be168c0dSopenharmony_ci- const std::vector<int> &dim1_; 5917be168c0dSopenharmony_ci-}; 5918be168c0dSopenharmony_ci- 5919be168c0dSopenharmony_ci-class DynamicBroadcastGradientArgs { 5920be168c0dSopenharmony_ci- public: 5921be168c0dSopenharmony_ci- DynamicBroadcastGradientArgs(const std::vector<int> &dim0, const std::vector<int> &dim1) : dim0_(dim0), dim1_(dim1) {} 5922be168c0dSopenharmony_ci- std::vector<std::vector<int>> operator()(); 5923be168c0dSopenharmony_ci- 5924be168c0dSopenharmony_ci- private: 5925be168c0dSopenharmony_ci- void AddElementToGradReduceIdx(std::vector<std::vector<int>> *grad_reduce_idx, std::vector<bool> current_is_one, 5926be168c0dSopenharmony_ci- bool none_is_one, const size_t largest_rank, size_t j); 5927be168c0dSopenharmony_ci- void UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one); 5928be168c0dSopenharmony_ci- std::vector<std::vector<int>> GetGradientIndices(const std::vector<std::vector<int>> &reverse_shape, 5929be168c0dSopenharmony_ci- const size_t largest_rank); 5930be168c0dSopenharmony_ci- std::vector<std::vector<int>> CalculateOutput(const std::vector<std::vector<int>> &x); 5931be168c0dSopenharmony_ci- std::vector<std::vector<int>> SetOutputValue(const std::vector<std::vector<int>> &grad_reduce_idx, 5932be168c0dSopenharmony_ci- const std::vector<std::vector<int>> &input_dim); 5933be168c0dSopenharmony_ci- static const int kInNum = 2; 5934be168c0dSopenharmony_ci- const std::vector<int> &dim0_; 5935be168c0dSopenharmony_ci- const std::vector<int> &dim1_; 5936be168c0dSopenharmony_ci-}; 5937be168c0dSopenharmony_ci- 5938be168c0dSopenharmony_ci-class VectorDiv { 5939be168c0dSopenharmony_ci- public: 5940be168c0dSopenharmony_ci- VectorDiv() {} 5941be168c0dSopenharmony_ci- std::vector<int> operator()(const std::vector<int> &x, const std::vector<int> &d); 5942be168c0dSopenharmony_ci-}; 5943be168c0dSopenharmony_ci- 5944be168c0dSopenharmony_ci-class ShapeReduce { 5945be168c0dSopenharmony_ci- public: 5946be168c0dSopenharmony_ci- ShapeReduce() {} 5947be168c0dSopenharmony_ci- std::vector<int> operator()(const std::vector<int> &x_shape, const std::vector<int> &axis); 5948be168c0dSopenharmony_ci-}; 5949be168c0dSopenharmony_ci-} // namespace lite 5950be168c0dSopenharmony_ci-} // namespace mindspore 5951be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_UTILS_H_ 5952be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/param.cc b/mindspore/lite/src/expression/param.cc 5953be168c0dSopenharmony_cideleted file mode 100644 5954be168c0dSopenharmony_ciindex 284cf141..00000000 5955be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/param.cc 5956be168c0dSopenharmony_ci+++ /dev/null 5957be168c0dSopenharmony_ci@@ -1,70 +0,0 @@ 5958be168c0dSopenharmony_ci-/** 5959be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 5960be168c0dSopenharmony_ci- * 5961be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 5962be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 5963be168c0dSopenharmony_ci- * You may obtain a copy of the License at 5964be168c0dSopenharmony_ci- * 5965be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 5966be168c0dSopenharmony_ci- * 5967be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 5968be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 5969be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5970be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 5971be168c0dSopenharmony_ci- * limitations under the License. 5972be168c0dSopenharmony_ci- */ 5973be168c0dSopenharmony_ci- 5974be168c0dSopenharmony_ci-#include "src/expression/param.h" 5975be168c0dSopenharmony_ci-#include <random> 5976be168c0dSopenharmony_ci-#include <algorithm> 5977be168c0dSopenharmony_ci-#include <string> 5978be168c0dSopenharmony_ci-#include "include/errorcode.h" 5979be168c0dSopenharmony_ci- 5980be168c0dSopenharmony_ci-using mindspore::lite::RET_ERROR; 5981be168c0dSopenharmony_ci-using mindspore::lite::RET_OK; 5982be168c0dSopenharmony_ci- 5983be168c0dSopenharmony_ci-constexpr float kZero = 0.0f; 5984be168c0dSopenharmony_ci-constexpr float kOne = 1.0f; 5985be168c0dSopenharmony_ci- 5986be168c0dSopenharmony_ci-namespace mindspore { 5987be168c0dSopenharmony_ci-namespace lite { 5988be168c0dSopenharmony_ci-int Param::Fill(Mode mode) { 5989be168c0dSopenharmony_ci- std::default_random_engine engine{static_cast<unsigned int>(0)}; 5990be168c0dSopenharmony_ci- std::vector<float> data(size_); 5991be168c0dSopenharmony_ci- switch (mode) { 5992be168c0dSopenharmony_ci- case NORMAL: { 5993be168c0dSopenharmony_ci- constexpr float scale = 0.01; 5994be168c0dSopenharmony_ci- std::normal_distribution<float> n{0, 1}; 5995be168c0dSopenharmony_ci- std::generate_n(data.begin(), size_, [&]() { return n(engine); }); 5996be168c0dSopenharmony_ci- (void)std::transform(data.begin(), data.end(), data.begin(), [=](float x) { return x * scale; }); 5997be168c0dSopenharmony_ci- break; 5998be168c0dSopenharmony_ci- } 5999be168c0dSopenharmony_ci- case UNIFORM: { 6000be168c0dSopenharmony_ci- constexpr float scale = 0.07; 6001be168c0dSopenharmony_ci- std::uniform_real_distribution<float> u{-1.0, 1.0}; 6002be168c0dSopenharmony_ci- std::generate_n(data.begin(), size_, [&]() { return u(engine) * scale; }); 6003be168c0dSopenharmony_ci- break; 6004be168c0dSopenharmony_ci- } 6005be168c0dSopenharmony_ci- case ZEROS: 6006be168c0dSopenharmony_ci- std::fill_n(data.begin(), size_, kZero); 6007be168c0dSopenharmony_ci- break; 6008be168c0dSopenharmony_ci- case ONES: 6009be168c0dSopenharmony_ci- std::fill_n(data.begin(), size_, kOne); 6010be168c0dSopenharmony_ci- break; 6011be168c0dSopenharmony_ci- case NOT_SUPPORTED: 6012be168c0dSopenharmony_ci- return RET_ERROR; 6013be168c0dSopenharmony_ci- } 6014be168c0dSopenharmony_ci- Copy(data); 6015be168c0dSopenharmony_ci- return RET_OK; 6016be168c0dSopenharmony_ci-} 6017be168c0dSopenharmony_ci- 6018be168c0dSopenharmony_ci-Param::Mode Param::String2Enum(std::string mode) { 6019be168c0dSopenharmony_ci- (void)std::transform(mode.begin(), mode.end(), mode.begin(), ::tolower); 6020be168c0dSopenharmony_ci- if (mode == "normal") return NORMAL; 6021be168c0dSopenharmony_ci- if (mode == "uniform") return UNIFORM; 6022be168c0dSopenharmony_ci- if (mode == "ones") return ONES; 6023be168c0dSopenharmony_ci- if (mode == "zeors") return ZEROS; 6024be168c0dSopenharmony_ci- return NOT_SUPPORTED; 6025be168c0dSopenharmony_ci-} 6026be168c0dSopenharmony_ci-} // namespace lite 6027be168c0dSopenharmony_ci-} // namespace mindspore 6028be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/param.h b/mindspore/lite/src/expression/param.h 6029be168c0dSopenharmony_cideleted file mode 100644 6030be168c0dSopenharmony_ciindex 201e69fc..00000000 6031be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/param.h 6032be168c0dSopenharmony_ci+++ /dev/null 6033be168c0dSopenharmony_ci@@ -1,60 +0,0 @@ 6034be168c0dSopenharmony_ci-/** 6035be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6036be168c0dSopenharmony_ci- * 6037be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6038be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6039be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6040be168c0dSopenharmony_ci- * 6041be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6042be168c0dSopenharmony_ci- * 6043be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6044be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6045be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6046be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6047be168c0dSopenharmony_ci- * limitations under the License. 6048be168c0dSopenharmony_ci- */ 6049be168c0dSopenharmony_ci- 6050be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_ 6051be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_ 6052be168c0dSopenharmony_ci- 6053be168c0dSopenharmony_ci-#include <vector> 6054be168c0dSopenharmony_ci-#include <iostream> 6055be168c0dSopenharmony_ci-#include <fstream> 6056be168c0dSopenharmony_ci-#include <string> 6057be168c0dSopenharmony_ci- 6058be168c0dSopenharmony_ci-namespace mindspore { 6059be168c0dSopenharmony_ci-namespace lite { 6060be168c0dSopenharmony_ci-class Param { 6061be168c0dSopenharmony_ci- public: 6062be168c0dSopenharmony_ci- enum Mode { NORMAL, UNIFORM, ONES, ZEROS, NOT_SUPPORTED }; 6063be168c0dSopenharmony_ci- int Fill(Mode type); 6064be168c0dSopenharmony_ci- static Mode String2Enum(std::string); 6065be168c0dSopenharmony_ci- std::vector<uint8_t> &data() { return data_; } 6066be168c0dSopenharmony_ci- size_t Load(std::string file_name, size_t offset = 0) { return data_.size() * sizeof(float); } 6067be168c0dSopenharmony_ci- size_t Load(std::ifstream &s, int offset = 0) { return data_.size() * sizeof(float); } 6068be168c0dSopenharmony_ci- void SetSize(size_t size) { size_ = size; } 6069be168c0dSopenharmony_ci- template <typename T> 6070be168c0dSopenharmony_ci- void Copy(const T *data, size_t size) { 6071be168c0dSopenharmony_ci- auto cast_data = reinterpret_cast<const uint8_t *>(data); 6072be168c0dSopenharmony_ci- data_ = decltype(data_)(cast_data, cast_data + size * sizeof(T) / sizeof(uint8_t)); 6073be168c0dSopenharmony_ci- } 6074be168c0dSopenharmony_ci- template <typename T> 6075be168c0dSopenharmony_ci- void Copy(const std::vector<T> data) { 6076be168c0dSopenharmony_ci- Copy<T>(data.data(), data.size()); 6077be168c0dSopenharmony_ci- } 6078be168c0dSopenharmony_ci- 6079be168c0dSopenharmony_ci- template <typename T> 6080be168c0dSopenharmony_ci- std::vector<T> Extract() { 6081be168c0dSopenharmony_ci- T *num = reinterpret_cast<T *>(data_.data()); 6082be168c0dSopenharmony_ci- std::vector<T> res(num, num + data_.size() / sizeof(T)); 6083be168c0dSopenharmony_ci- return res; 6084be168c0dSopenharmony_ci- } 6085be168c0dSopenharmony_ci- 6086be168c0dSopenharmony_ci- private: 6087be168c0dSopenharmony_ci- size_t size_; 6088be168c0dSopenharmony_ci- std::vector<uint8_t> data_; 6089be168c0dSopenharmony_ci-}; 6090be168c0dSopenharmony_ci-} // namespace lite 6091be168c0dSopenharmony_ci-} // namespace mindspore 6092be168c0dSopenharmony_ci- 6093be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_PARAM_H_ 6094be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/sequential.cc b/mindspore/lite/src/expression/sequential.cc 6095be168c0dSopenharmony_cideleted file mode 100644 6096be168c0dSopenharmony_ciindex 5f3a8a76..00000000 6097be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/sequential.cc 6098be168c0dSopenharmony_ci+++ /dev/null 6099be168c0dSopenharmony_ci@@ -1,30 +0,0 @@ 6100be168c0dSopenharmony_ci-/** 6101be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6102be168c0dSopenharmony_ci- * 6103be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6104be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6105be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6106be168c0dSopenharmony_ci- * 6107be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6108be168c0dSopenharmony_ci- * 6109be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6110be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6111be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6112be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6113be168c0dSopenharmony_ci- * limitations under the License. 6114be168c0dSopenharmony_ci- */ 6115be168c0dSopenharmony_ci-#include "src/expression/sequential.h" 6116be168c0dSopenharmony_ci- 6117be168c0dSopenharmony_ci-namespace mindspore { 6118be168c0dSopenharmony_ci-namespace lite { 6119be168c0dSopenharmony_ci-void Sequential::Add(Node *node) { PushOp(node); } 6120be168c0dSopenharmony_ci- 6121be168c0dSopenharmony_ci-std::vector<EXPR *> Sequential::construct(const std::vector<EXPR *> &inputs) { 6122be168c0dSopenharmony_ci- auto x = inputs; 6123be168c0dSopenharmony_ci- for (auto &node : ops_) { 6124be168c0dSopenharmony_ci- x = (*node)({x.front()}); 6125be168c0dSopenharmony_ci- } 6126be168c0dSopenharmony_ci- return x; 6127be168c0dSopenharmony_ci-} 6128be168c0dSopenharmony_ci-} // namespace lite 6129be168c0dSopenharmony_ci-} // namespace mindspore 6130be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/sequential.h b/mindspore/lite/src/expression/sequential.h 6131be168c0dSopenharmony_cideleted file mode 100644 6132be168c0dSopenharmony_ciindex 9b1a69e5..00000000 6133be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/sequential.h 6134be168c0dSopenharmony_ci+++ /dev/null 6135be168c0dSopenharmony_ci@@ -1,32 +0,0 @@ 6136be168c0dSopenharmony_ci-/** 6137be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6138be168c0dSopenharmony_ci- * 6139be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6140be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6141be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6142be168c0dSopenharmony_ci- * 6143be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6144be168c0dSopenharmony_ci- * 6145be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6146be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6147be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6148be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6149be168c0dSopenharmony_ci- * limitations under the License. 6150be168c0dSopenharmony_ci- */ 6151be168c0dSopenharmony_ci- 6152be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_ 6153be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_ 6154be168c0dSopenharmony_ci-#include <vector> 6155be168c0dSopenharmony_ci-#include "src/expression/net.h" 6156be168c0dSopenharmony_ci- 6157be168c0dSopenharmony_ci-namespace mindspore { 6158be168c0dSopenharmony_ci-namespace lite { 6159be168c0dSopenharmony_ci-class Sequential : public Net { 6160be168c0dSopenharmony_ci- public: 6161be168c0dSopenharmony_ci- std::vector<EXPR *> construct(const std::vector<EXPR *> &inputs) override; 6162be168c0dSopenharmony_ci- void Add(Node *node) override; 6163be168c0dSopenharmony_ci-}; 6164be168c0dSopenharmony_ci-} // namespace lite 6165be168c0dSopenharmony_ci-} // namespace mindspore 6166be168c0dSopenharmony_ci- 6167be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_SEQUENTIAL_H_ 6168be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/expression/net.cc b/mindspore/lite/src/litert/cxx_api/expression/net.cc 6169be168c0dSopenharmony_cideleted file mode 100644 6170be168c0dSopenharmony_ciindex fd590e94..00000000 6171be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/expression/net.cc 6172be168c0dSopenharmony_ci+++ /dev/null 6173be168c0dSopenharmony_ci@@ -1,145 +0,0 @@ 6174be168c0dSopenharmony_ci-/** 6175be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6176be168c0dSopenharmony_ci- * 6177be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6178be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6179be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6180be168c0dSopenharmony_ci- * 6181be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6182be168c0dSopenharmony_ci- * 6183be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6184be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6185be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6186be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6187be168c0dSopenharmony_ci- * limitations under the License. 6188be168c0dSopenharmony_ci- */ 6189be168c0dSopenharmony_ci- 6190be168c0dSopenharmony_ci-#include "include/api/net.h" 6191be168c0dSopenharmony_ci-#include "include/api/status.h" 6192be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 6193be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 6194be168c0dSopenharmony_ci-#include "src/expression/ops.h" 6195be168c0dSopenharmony_ci-#include "src/expression/cfg.h" 6196be168c0dSopenharmony_ci- 6197be168c0dSopenharmony_ci-namespace mindspore { 6198be168c0dSopenharmony_ci-uint32_t Node::type() { return kNodeType; } 6199be168c0dSopenharmony_ci- 6200be168c0dSopenharmony_ci-std::vector<Expr *> Node::operator()(const std::vector<Expr *> &inputs) { 6201be168c0dSopenharmony_ci- auto in = Expr::convert(inputs); 6202be168c0dSopenharmony_ci- if (impl_ == nullptr) { 6203be168c0dSopenharmony_ci- MS_LOG(ERROR) << "empty implementation"; 6204be168c0dSopenharmony_ci- return {}; 6205be168c0dSopenharmony_ci- } 6206be168c0dSopenharmony_ci- if (impl_->node() == nullptr) { 6207be168c0dSopenharmony_ci- MS_LOG(ERROR) << "expression node is not attached"; 6208be168c0dSopenharmony_ci- return {}; 6209be168c0dSopenharmony_ci- } 6210be168c0dSopenharmony_ci- auto out = impl_->node()->construct(in); 6211be168c0dSopenharmony_ci- return Expr::convert(out); 6212be168c0dSopenharmony_ci-} 6213be168c0dSopenharmony_ci- 6214be168c0dSopenharmony_ci-Expr *Node::Create(std::string name) { 6215be168c0dSopenharmony_ci- auto expr = impl_->node()->create(name); 6216be168c0dSopenharmony_ci- return reinterpret_cast<Expr *>(expr); 6217be168c0dSopenharmony_ci-} 6218be168c0dSopenharmony_ci- 6219be168c0dSopenharmony_ci-Node::Node() { 6220be168c0dSopenharmony_ci- auto impl = std::make_shared<NodeImpl>(); 6221be168c0dSopenharmony_ci- impl_ = impl; 6222be168c0dSopenharmony_ci- impl_->set_pnode(this); 6223be168c0dSopenharmony_ci-} 6224be168c0dSopenharmony_ci- 6225be168c0dSopenharmony_ci-Node::~Node() { 6226be168c0dSopenharmony_ci- impl_->set_pnode(nullptr); 6227be168c0dSopenharmony_ci- auto node = impl_->node(); 6228be168c0dSopenharmony_ci- if (node != nullptr) { 6229be168c0dSopenharmony_ci- impl_->set_node(nullptr); 6230be168c0dSopenharmony_ci- delete node; 6231be168c0dSopenharmony_ci- } 6232be168c0dSopenharmony_ci-} 6233be168c0dSopenharmony_ci- 6234be168c0dSopenharmony_ci-Net::Net(std::string name) { 6235be168c0dSopenharmony_ci- auto impl = std::make_shared<NetImpl>(); 6236be168c0dSopenharmony_ci- if (impl == nullptr) { 6237be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate network implementation"; 6238be168c0dSopenharmony_ci- return; 6239be168c0dSopenharmony_ci- } 6240be168c0dSopenharmony_ci- impl_ = impl; 6241be168c0dSopenharmony_ci- impl_->set_pnet(std::shared_ptr<Net>(this)); 6242be168c0dSopenharmony_ci- auto netl = new (std::nothrow) lite::Net(name); 6243be168c0dSopenharmony_ci- if (netl == nullptr) { 6244be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate network lite"; 6245be168c0dSopenharmony_ci- return; 6246be168c0dSopenharmony_ci- } 6247be168c0dSopenharmony_ci- netl->set_impl(impl); 6248be168c0dSopenharmony_ci- impl_->set_net(netl); 6249be168c0dSopenharmony_ci-} 6250be168c0dSopenharmony_ci- 6251be168c0dSopenharmony_ci-Net::Net() : Net("") {} 6252be168c0dSopenharmony_ci- 6253be168c0dSopenharmony_ci-Net::Net(const Graph &g) { 6254be168c0dSopenharmony_ci- auto net = NetImpl::GetNet(g); 6255be168c0dSopenharmony_ci- impl_ = net->impl_; 6256be168c0dSopenharmony_ci-} 6257be168c0dSopenharmony_ci- 6258be168c0dSopenharmony_ci-void Net::Add(NetBase *element) { MS_LOG(WARNING) << "Only sequential can add element"; } 6259be168c0dSopenharmony_ci- 6260be168c0dSopenharmony_ci-uint32_t Net::type() { return kNetType; } 6261be168c0dSopenharmony_ci- 6262be168c0dSopenharmony_ci-std::vector<Expr *> Net::construct(const std::vector<Expr *> &inputs) { 6263be168c0dSopenharmony_ci- auto in = Expr::convert(inputs); 6264be168c0dSopenharmony_ci- auto out = impl_->net()->construct(in); 6265be168c0dSopenharmony_ci- return Expr::convert(out); 6266be168c0dSopenharmony_ci-} 6267be168c0dSopenharmony_ci- 6268be168c0dSopenharmony_ci-std::vector<Expr *> Net::operator()(const std::vector<Expr *> &inputs) { 6269be168c0dSopenharmony_ci- auto in = Expr::convert(inputs); 6270be168c0dSopenharmony_ci- auto x = construct(inputs); 6271be168c0dSopenharmony_ci- impl_->net()->input_ = in; 6272be168c0dSopenharmony_ci- auto out = Expr::convert(x); 6273be168c0dSopenharmony_ci- impl_->net()->output_ = out; 6274be168c0dSopenharmony_ci- impl_->net()->real_output_ = out; 6275be168c0dSopenharmony_ci- return x; 6276be168c0dSopenharmony_ci-} 6277be168c0dSopenharmony_ci-void Net::Register(Net *net, std::string &&name) { 6278be168c0dSopenharmony_ci- if (net != nullptr) { 6279be168c0dSopenharmony_ci- auto net_lite = net->impl_->net(); 6280be168c0dSopenharmony_ci- impl_->net()->Register(net_lite, std::move(name)); 6281be168c0dSopenharmony_ci- } 6282be168c0dSopenharmony_ci-} 6283be168c0dSopenharmony_ci- 6284be168c0dSopenharmony_ci-void Net::Register(Node *node, std::string &&name) { 6285be168c0dSopenharmony_ci- if (node != nullptr) { 6286be168c0dSopenharmony_ci- auto impl = NodeImpl::GetImpl(node); 6287be168c0dSopenharmony_ci- if (impl == nullptr) { 6288be168c0dSopenharmony_ci- MS_LOG(ERROR) << "missing implementation"; 6289be168c0dSopenharmony_ci- return; 6290be168c0dSopenharmony_ci- } 6291be168c0dSopenharmony_ci- auto node_lite = impl->node(); 6292be168c0dSopenharmony_ci- impl_->net()->Register(node_lite, std::move(name)); 6293be168c0dSopenharmony_ci- } 6294be168c0dSopenharmony_ci-} 6295be168c0dSopenharmony_ci- 6296be168c0dSopenharmony_ci-std::shared_ptr<NodeSet> Net::trainable_params() { 6297be168c0dSopenharmony_ci- auto node_set = std::make_shared<NodeSet>(); 6298be168c0dSopenharmony_ci- if (node_set == nullptr) { 6299be168c0dSopenharmony_ci- MS_LOG(ERROR) << "new NodeSet failed."; 6300be168c0dSopenharmony_ci- return nullptr; 6301be168c0dSopenharmony_ci- } 6302be168c0dSopenharmony_ci- node_set->set_ = impl_->net()->trainable_params(); 6303be168c0dSopenharmony_ci- return node_set; 6304be168c0dSopenharmony_ci-} 6305be168c0dSopenharmony_ci- 6306be168c0dSopenharmony_ci-const std::vector<int> Net::InputShape(int idx) { return impl_->InputShape(idx); } 6307be168c0dSopenharmony_ci-const std::vector<int> Net::OutputShape(int idx) { return impl_->OutputShape(idx); } 6308be168c0dSopenharmony_ci- 6309be168c0dSopenharmony_ci-Net::~Net() { 6310be168c0dSopenharmony_ci- if (impl_ != nullptr) { 6311be168c0dSopenharmony_ci- if ((impl_->pnet() == nullptr) || (impl_->pnet() == this)) { 6312be168c0dSopenharmony_ci- impl_->set_pnet(nullptr); 6313be168c0dSopenharmony_ci- impl_->set_net(nullptr); 6314be168c0dSopenharmony_ci- impl_.reset(); 6315be168c0dSopenharmony_ci- } 6316be168c0dSopenharmony_ci- } 6317be168c0dSopenharmony_ci-} 6318be168c0dSopenharmony_ci-} // namespace mindspore 6319be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/expression/net_impl.cc b/mindspore/lite/src/litert/cxx_api/expression/net_impl.cc 6320be168c0dSopenharmony_cideleted file mode 100644 6321be168c0dSopenharmony_ciindex 5487d8c7..00000000 6322be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/expression/net_impl.cc 6323be168c0dSopenharmony_ci+++ /dev/null 6324be168c0dSopenharmony_ci@@ -1,220 +0,0 @@ 6325be168c0dSopenharmony_ci-/** 6326be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6327be168c0dSopenharmony_ci- * 6328be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6329be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6330be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6331be168c0dSopenharmony_ci- * 6332be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6333be168c0dSopenharmony_ci- * 6334be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6335be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6336be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6337be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6338be168c0dSopenharmony_ci- * limitations under the License. 6339be168c0dSopenharmony_ci- */ 6340be168c0dSopenharmony_ci- 6341be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 6342be168c0dSopenharmony_ci-#include <vector> 6343be168c0dSopenharmony_ci-#include <utility> 6344be168c0dSopenharmony_ci-#include "include/api/serialization.h" 6345be168c0dSopenharmony_ci-#include "src/expression/import.h" 6346be168c0dSopenharmony_ci-#include "src/expression/ops.h" 6347be168c0dSopenharmony_ci-#include "src/litert/cxx_api/model/model_impl.h" 6348be168c0dSopenharmony_ci- 6349be168c0dSopenharmony_ci-namespace { 6350be168c0dSopenharmony_ci-constexpr size_t kFlatbuffersBuilderInitSize = 1024; 6351be168c0dSopenharmony_ci-}; 6352be168c0dSopenharmony_ci- 6353be168c0dSopenharmony_ci-namespace mindspore { 6354be168c0dSopenharmony_ci-Sequential::Sequential() {} 6355be168c0dSopenharmony_ci- 6356be168c0dSopenharmony_ci-lite::Node *Sequential::GetNode(NetBase *element) { 6357be168c0dSopenharmony_ci- lite::Node *lite_node = nullptr; 6358be168c0dSopenharmony_ci- switch (element->type()) { 6359be168c0dSopenharmony_ci- case kNodeType: { 6360be168c0dSopenharmony_ci- Node *node = reinterpret_cast<Node *>(element); 6361be168c0dSopenharmony_ci- auto impl = NodeImpl::GetImpl(node); 6362be168c0dSopenharmony_ci- if (impl == nullptr) { 6363be168c0dSopenharmony_ci- MS_LOG(ERROR) << "cannot find node implement"; 6364be168c0dSopenharmony_ci- return nullptr; 6365be168c0dSopenharmony_ci- } 6366be168c0dSopenharmony_ci- lite_node = impl->node(); 6367be168c0dSopenharmony_ci- break; 6368be168c0dSopenharmony_ci- } 6369be168c0dSopenharmony_ci- case kNetType: { 6370be168c0dSopenharmony_ci- auto net = reinterpret_cast<Net *>(element); 6371be168c0dSopenharmony_ci- auto impl = NetImpl::GetImpl(net); 6372be168c0dSopenharmony_ci- if (impl == nullptr) { 6373be168c0dSopenharmony_ci- MS_LOG(ERROR) << "cannot find node implement"; 6374be168c0dSopenharmony_ci- return nullptr; 6375be168c0dSopenharmony_ci- } 6376be168c0dSopenharmony_ci- lite_node = impl->net(); 6377be168c0dSopenharmony_ci- break; 6378be168c0dSopenharmony_ci- } 6379be168c0dSopenharmony_ci- } 6380be168c0dSopenharmony_ci- return lite_node; 6381be168c0dSopenharmony_ci-} 6382be168c0dSopenharmony_ci- 6383be168c0dSopenharmony_ci-void Sequential::Add(NetBase *element) { 6384be168c0dSopenharmony_ci- lite::Node *node = GetNode(element); 6385be168c0dSopenharmony_ci- auto impl = NetImpl::GetImpl(this); 6386be168c0dSopenharmony_ci- if (impl == nullptr) { 6387be168c0dSopenharmony_ci- MS_LOG(ERROR) << "No implementation"; 6388be168c0dSopenharmony_ci- return; 6389be168c0dSopenharmony_ci- } 6390be168c0dSopenharmony_ci- impl->net()->Add(node); 6391be168c0dSopenharmony_ci-} 6392be168c0dSopenharmony_ci- 6393be168c0dSopenharmony_ci-NetWithLoss::NetWithLoss(Net *net, Node *loss) : net_(net), loss_fn_(loss) { 6394be168c0dSopenharmony_ci- REG(net_); 6395be168c0dSopenharmony_ci- Register(loss_fn_, "_loss_fn"); 6396be168c0dSopenharmony_ci-} 6397be168c0dSopenharmony_ci- 6398be168c0dSopenharmony_ci-std::vector<Expr *> NetWithLoss::construct(const std::vector<Expr *> &inputs) { 6399be168c0dSopenharmony_ci- if (inputs.size() != C2NUM) { 6400be168c0dSopenharmony_ci- MS_LOG(ERROR) << "need 2 inputs for loss"; 6401be168c0dSopenharmony_ci- return {}; 6402be168c0dSopenharmony_ci- } 6403be168c0dSopenharmony_ci- auto input = inputs[FIRST_INPUT]; 6404be168c0dSopenharmony_ci- auto label = inputs[SECOND_INPUT]; 6405be168c0dSopenharmony_ci- auto x = (*net_)({input}); 6406be168c0dSopenharmony_ci- x = (*loss_fn_)({x[FIRST_INPUT], label}); 6407be168c0dSopenharmony_ci- return x; 6408be168c0dSopenharmony_ci-} 6409be168c0dSopenharmony_ci- 6410be168c0dSopenharmony_ci-NetImpl::NetImpl(std::shared_ptr<Net> p) { pnet_ = p; } 6411be168c0dSopenharmony_ci- 6412be168c0dSopenharmony_ci-NetImpl::NetImpl(Graph *g) { pnet_ = g->net_data_->net(); } 6413be168c0dSopenharmony_ci- 6414be168c0dSopenharmony_ci-std::vector<lite::EXPR *> NetImpl::construct(const std::vector<lite::EXPR *> &inputs) { 6415be168c0dSopenharmony_ci- auto in = Expr::convert(inputs); 6416be168c0dSopenharmony_ci- auto out = pnet_->construct(in); 6417be168c0dSopenharmony_ci- return Expr::convert(out); 6418be168c0dSopenharmony_ci-} 6419be168c0dSopenharmony_ci- 6420be168c0dSopenharmony_ci-Net *NetImpl::Connect(std::shared_ptr<Net> net, lite::Net *lnet) { 6421be168c0dSopenharmony_ci- auto impl = GetImpl(net.get()); 6422be168c0dSopenharmony_ci- if (impl == nullptr) { 6423be168c0dSopenharmony_ci- MS_LOG(ERROR) << "missing implementation"; 6424be168c0dSopenharmony_ci- return nullptr; 6425be168c0dSopenharmony_ci- } 6426be168c0dSopenharmony_ci- impl->set_pnet(net); 6427be168c0dSopenharmony_ci- lnet->set_impl(impl); 6428be168c0dSopenharmony_ci- impl->set_net(lnet); 6429be168c0dSopenharmony_ci- return net.get(); 6430be168c0dSopenharmony_ci-} 6431be168c0dSopenharmony_ci- 6432be168c0dSopenharmony_ci-Status NetImpl::Import(const char *model_buf, Graph *graph) { 6433be168c0dSopenharmony_ci- auto mg = schema::GetMetaGraph(model_buf); 6434be168c0dSopenharmony_ci- auto net = new (std::nothrow) Net(); 6435be168c0dSopenharmony_ci- if (net == nullptr) { 6436be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate network"; 6437be168c0dSopenharmony_ci- return kLiteMemoryFailed; 6438be168c0dSopenharmony_ci- } 6439be168c0dSopenharmony_ci- lite::Import import; 6440be168c0dSopenharmony_ci- auto lite_net = import.ImportMs(mg); 6441be168c0dSopenharmony_ci- if (lite_net == nullptr) { 6442be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed to import net"; 6443be168c0dSopenharmony_ci- return kLiteMemoryFailed; 6444be168c0dSopenharmony_ci- } 6445be168c0dSopenharmony_ci- lite_net->SetRealOutput(); 6446be168c0dSopenharmony_ci- Connect(net->shared_from_this(), lite_net); 6447be168c0dSopenharmony_ci- *graph = Graph(net); 6448be168c0dSopenharmony_ci- return kSuccess; 6449be168c0dSopenharmony_ci-} 6450be168c0dSopenharmony_ci- 6451be168c0dSopenharmony_ci-Status NetImpl::TrainNet(Node *optimizer, const std::vector<Expr *> &inputs) { 6452be168c0dSopenharmony_ci- auto impl = NodeImpl::GetImpl(optimizer); 6453be168c0dSopenharmony_ci- if (impl == nullptr) { 6454be168c0dSopenharmony_ci- MS_LOG(ERROR) << "missing implementation "; 6455be168c0dSopenharmony_ci- return kLiteNullptr; 6456be168c0dSopenharmony_ci- } 6457be168c0dSopenharmony_ci- auto opt = impl->node(); 6458be168c0dSopenharmony_ci- auto in = Expr::convert(inputs); 6459be168c0dSopenharmony_ci- auto ret_net = net()->TrainNet(opt, in); 6460be168c0dSopenharmony_ci- if (ret_net == nullptr) { 6461be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed to train network"; 6462be168c0dSopenharmony_ci- return kLiteNullptr; 6463be168c0dSopenharmony_ci- } 6464be168c0dSopenharmony_ci- return kSuccess; 6465be168c0dSopenharmony_ci-} 6466be168c0dSopenharmony_ci- 6467be168c0dSopenharmony_ci-std::unique_ptr<Graph> NetImpl::MakeMs() { 6468be168c0dSopenharmony_ci- auto mgraph = std::make_unique<Graph>(Graph::Type::kExecutableGraph); 6469be168c0dSopenharmony_ci- if (mgraph == nullptr) { 6470be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate graph"; 6471be168c0dSopenharmony_ci- return nullptr; 6472be168c0dSopenharmony_ci- } 6473be168c0dSopenharmony_ci- auto trained_graph = net()->MakeMs(); 6474be168c0dSopenharmony_ci- if (trained_graph == nullptr) { 6475be168c0dSopenharmony_ci- MS_LOG(ERROR) << "cannot create flat buffer"; 6476be168c0dSopenharmony_ci- return nullptr; 6477be168c0dSopenharmony_ci- } 6478be168c0dSopenharmony_ci- flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize); 6479be168c0dSopenharmony_ci- auto offset = schema::MetaGraph::Pack(builder, trained_graph.get()); 6480be168c0dSopenharmony_ci- builder.Finish(offset); 6481be168c0dSopenharmony_ci- schema::FinishMetaGraphBuffer(builder, offset); 6482be168c0dSopenharmony_ci- auto buffer = builder.GetBufferPointer(); 6483be168c0dSopenharmony_ci- size_t size = builder.GetSize(); 6484be168c0dSopenharmony_ci- auto status = Serialization::Load(buffer, size, mindspore::kMindIR, mgraph.get()); 6485be168c0dSopenharmony_ci- if (status != kSuccess) { 6486be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed to load flatbuffer to graph"; 6487be168c0dSopenharmony_ci- return nullptr; 6488be168c0dSopenharmony_ci- } 6489be168c0dSopenharmony_ci- return mgraph; 6490be168c0dSopenharmony_ci-} 6491be168c0dSopenharmony_ci- 6492be168c0dSopenharmony_ci-const std::vector<int> NetImpl::InputShape(int idx) { return net_->InputShape(idx); } 6493be168c0dSopenharmony_ci- 6494be168c0dSopenharmony_ci-const std::vector<int> NetImpl::OutputShape(int idx) { return net_->OutputShape(idx); } 6495be168c0dSopenharmony_ci- 6496be168c0dSopenharmony_ci-void NetImpl::ReplaceNet(Graph *g, std::shared_ptr<Net> n) { g->net_data_->net().swap(n); } 6497be168c0dSopenharmony_ci- 6498be168c0dSopenharmony_ci-ExpressionLoader expression_registrator = CreateExpressionLoader(NetImpl::Import); 6499be168c0dSopenharmony_ci- 6500be168c0dSopenharmony_ci-namespace NN { 6501be168c0dSopenharmony_ci-Net *Sequential() { 6502be168c0dSopenharmony_ci- auto net = new (std::nothrow) mindspore::Sequential(); 6503be168c0dSopenharmony_ci- if (net == nullptr) { 6504be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate "; 6505be168c0dSopenharmony_ci- return nullptr; 6506be168c0dSopenharmony_ci- } 6507be168c0dSopenharmony_ci- auto netl = lite::NN::Sequential(); 6508be168c0dSopenharmony_ci- return NetImpl::Connect(net->shared_from_this(), netl); 6509be168c0dSopenharmony_ci-} 6510be168c0dSopenharmony_ci- 6511be168c0dSopenharmony_ci-Net *NetWithLoss(Net *net, Node *loss) { 6512be168c0dSopenharmony_ci- auto loss_net = new (std::nothrow) mindspore::NetWithLoss(net, loss); 6513be168c0dSopenharmony_ci- if (net == nullptr) { 6514be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate loss net"; 6515be168c0dSopenharmony_ci- return nullptr; 6516be168c0dSopenharmony_ci- } 6517be168c0dSopenharmony_ci- return loss_net; 6518be168c0dSopenharmony_ci-} 6519be168c0dSopenharmony_ci- 6520be168c0dSopenharmony_ci-Graph *GraphWithLoss(Graph *graph, Node *loss) { 6521be168c0dSopenharmony_ci- auto net = NetImpl::GetNet(*graph); 6522be168c0dSopenharmony_ci- if (net == nullptr) { 6523be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate network"; 6524be168c0dSopenharmony_ci- return nullptr; 6525be168c0dSopenharmony_ci- } 6526be168c0dSopenharmony_ci- auto loss_net = NetWithLoss(net.get(), loss); 6527be168c0dSopenharmony_ci- if (loss_net == nullptr) { 6528be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate network"; 6529be168c0dSopenharmony_ci- return nullptr; 6530be168c0dSopenharmony_ci- } 6531be168c0dSopenharmony_ci- NetImpl::ReplaceNet(graph, loss_net->shared_from_this()); 6532be168c0dSopenharmony_ci- return graph; 6533be168c0dSopenharmony_ci-} 6534be168c0dSopenharmony_ci- 6535be168c0dSopenharmony_ci-Net *NetWithLoss(Graph *g, Node *loss) { 6536be168c0dSopenharmony_ci- auto net = new (std::nothrow) Net(*g); 6537be168c0dSopenharmony_ci- if (net == nullptr) { 6538be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate net"; 6539be168c0dSopenharmony_ci- return nullptr; 6540be168c0dSopenharmony_ci- } 6541be168c0dSopenharmony_ci- return NetWithLoss(net, loss); 6542be168c0dSopenharmony_ci-} 6543be168c0dSopenharmony_ci-} // namespace NN 6544be168c0dSopenharmony_ci-} // namespace mindspore 6545be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/expression/net_impl.h b/mindspore/lite/src/litert/cxx_api/expression/net_impl.h 6546be168c0dSopenharmony_cideleted file mode 100644 6547be168c0dSopenharmony_ciindex 682ba0b2..00000000 6548be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/expression/net_impl.h 6549be168c0dSopenharmony_ci+++ /dev/null 6550be168c0dSopenharmony_ci@@ -1,95 +0,0 @@ 6551be168c0dSopenharmony_ci-/** 6552be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6553be168c0dSopenharmony_ci- * 6554be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6555be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6556be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6557be168c0dSopenharmony_ci- * 6558be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6559be168c0dSopenharmony_ci- * 6560be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6561be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6562be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6563be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6564be168c0dSopenharmony_ci- * limitations under the License. 6565be168c0dSopenharmony_ci- */ 6566be168c0dSopenharmony_ci- 6567be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NET_IMPL_H_ 6568be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NET_IMPL_H_ 6569be168c0dSopenharmony_ci- 6570be168c0dSopenharmony_ci-#include <algorithm> 6571be168c0dSopenharmony_ci-#include <set> 6572be168c0dSopenharmony_ci-#include <memory> 6573be168c0dSopenharmony_ci-#include <vector> 6574be168c0dSopenharmony_ci-#include <utility> 6575be168c0dSopenharmony_ci-#include "include/api/cfg.h" 6576be168c0dSopenharmony_ci-#include "include/api/data_type.h" 6577be168c0dSopenharmony_ci-#include "include/api/graph.h" 6578be168c0dSopenharmony_ci-#include "include/api/status.h" 6579be168c0dSopenharmony_ci-#include "include/api/net.h" 6580be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 6581be168c0dSopenharmony_ci-#include "src/litert/cxx_api/graph/net_data.h" 6582be168c0dSopenharmony_ci-#include "src/expression/net.h" 6583be168c0dSopenharmony_ci-#include "src/expression/ops.h" 6584be168c0dSopenharmony_ci- 6585be168c0dSopenharmony_ci-namespace mindspore { 6586be168c0dSopenharmony_ci-constexpr uint32_t kNodeType = 1; 6587be168c0dSopenharmony_ci-constexpr uint32_t kNetType = 2; 6588be168c0dSopenharmony_ci-class Sequential : public Net { 6589be168c0dSopenharmony_ci- public: 6590be168c0dSopenharmony_ci- Sequential(); 6591be168c0dSopenharmony_ci- void Add(NetBase *n) override; 6592be168c0dSopenharmony_ci- 6593be168c0dSopenharmony_ci- private: 6594be168c0dSopenharmony_ci- std::vector<NetBase *> ops_; 6595be168c0dSopenharmony_ci- lite::Node *GetNode(NetBase *element); 6596be168c0dSopenharmony_ci-}; 6597be168c0dSopenharmony_ci- 6598be168c0dSopenharmony_ci-class NetWithLoss : public Net { 6599be168c0dSopenharmony_ci- public: 6600be168c0dSopenharmony_ci- NetWithLoss(Net *net, Node *loss); 6601be168c0dSopenharmony_ci- std::vector<Expr *> construct(const std::vector<Expr *> &inputs) override; 6602be168c0dSopenharmony_ci- 6603be168c0dSopenharmony_ci- private: 6604be168c0dSopenharmony_ci- Net *net_{nullptr}; 6605be168c0dSopenharmony_ci- Node *loss_fn_{nullptr}; 6606be168c0dSopenharmony_ci-}; 6607be168c0dSopenharmony_ci- 6608be168c0dSopenharmony_ci-class MS_API NetImpl { 6609be168c0dSopenharmony_ci- public: 6610be168c0dSopenharmony_ci- virtual ~NetImpl() {} 6611be168c0dSopenharmony_ci- explicit NetImpl(std::shared_ptr<Net> p); 6612be168c0dSopenharmony_ci- explicit NetImpl(Graph *g); 6613be168c0dSopenharmony_ci- NetImpl() = default; 6614be168c0dSopenharmony_ci- void set_net(lite::Net *net) { 6615be168c0dSopenharmony_ci- if (net_ != nullptr) { 6616be168c0dSopenharmony_ci- net_->set_impl(nullptr); 6617be168c0dSopenharmony_ci- delete net_; 6618be168c0dSopenharmony_ci- } 6619be168c0dSopenharmony_ci- net_ = net; 6620be168c0dSopenharmony_ci- } 6621be168c0dSopenharmony_ci- void erase_net() { net_ = nullptr; } 6622be168c0dSopenharmony_ci- void set_pnet(std::shared_ptr<Net> net) { pnet_ = net; } 6623be168c0dSopenharmony_ci- Net *pnet() { return pnet_.get(); } 6624be168c0dSopenharmony_ci- lite::Net *net() { return net_; } 6625be168c0dSopenharmony_ci- 6626be168c0dSopenharmony_ci- std::vector<lite::EXPR *> construct(const std::vector<lite::EXPR *> &inputs); 6627be168c0dSopenharmony_ci- static std::shared_ptr<mindspore::NetImpl> &GetImpl(Net *net) { return net->impl_; } 6628be168c0dSopenharmony_ci- static Net *Connect(std::shared_ptr<Net> net, lite::Net *lnet); 6629be168c0dSopenharmony_ci- static std::shared_ptr<Net> &GetNet(const Graph &g) { return g.net_data_->net(); } 6630be168c0dSopenharmony_ci- static void SetNet(Graph *g, std::shared_ptr<Net> n) { g->net_data_->set_net(n); } 6631be168c0dSopenharmony_ci- static void ReplaceNet(Graph *g, std::shared_ptr<Net> n); 6632be168c0dSopenharmony_ci- static Status Import(const char *model_buf, Graph *graph); 6633be168c0dSopenharmony_ci- Status TrainNet(Node *optimizer, const std::vector<Expr *> &inputs); 6634be168c0dSopenharmony_ci- const std::vector<int> InputShape(int idx); 6635be168c0dSopenharmony_ci- const std::vector<int> OutputShape(int idx); 6636be168c0dSopenharmony_ci- std::unique_ptr<Graph> MakeMs(); 6637be168c0dSopenharmony_ci- void Release() { pnet_.reset(); } 6638be168c0dSopenharmony_ci- 6639be168c0dSopenharmony_ci- private: 6640be168c0dSopenharmony_ci- std::shared_ptr<Net> pnet_; 6641be168c0dSopenharmony_ci- lite::Net *net_ = nullptr; 6642be168c0dSopenharmony_ci-}; 6643be168c0dSopenharmony_ci-} // namespace mindspore 6644be168c0dSopenharmony_ci- 6645be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NET_IMPL_H_ 6646be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/expression/node_impl.cc b/mindspore/lite/src/litert/cxx_api/expression/node_impl.cc 6647be168c0dSopenharmony_cideleted file mode 100644 6648be168c0dSopenharmony_ciindex df10a7d0..00000000 6649be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/expression/node_impl.cc 6650be168c0dSopenharmony_ci+++ /dev/null 6651be168c0dSopenharmony_ci@@ -1,50 +0,0 @@ 6652be168c0dSopenharmony_ci-/** 6653be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6654be168c0dSopenharmony_ci- * 6655be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6656be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6657be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6658be168c0dSopenharmony_ci- * 6659be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6660be168c0dSopenharmony_ci- * 6661be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6662be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6663be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6664be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6665be168c0dSopenharmony_ci- * limitations under the License. 6666be168c0dSopenharmony_ci- */ 6667be168c0dSopenharmony_ci- 6668be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 6669be168c0dSopenharmony_ci-#include <vector> 6670be168c0dSopenharmony_ci-#include "include/api/net.h" 6671be168c0dSopenharmony_ci-#include "src/expression/ops.h" 6672be168c0dSopenharmony_ci- 6673be168c0dSopenharmony_ci-namespace mindspore { 6674be168c0dSopenharmony_ci-Node *NodeImpl::Connect(lite::Node *lnode) { 6675be168c0dSopenharmony_ci- auto node = std::make_unique<Node>(); 6676be168c0dSopenharmony_ci- if (node == nullptr) { 6677be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Cannot allocate node"; 6678be168c0dSopenharmony_ci- return nullptr; 6679be168c0dSopenharmony_ci- } 6680be168c0dSopenharmony_ci- if (lnode == nullptr) { 6681be168c0dSopenharmony_ci- MS_LOG(ERROR) << "lite node is null"; 6682be168c0dSopenharmony_ci- return nullptr; 6683be168c0dSopenharmony_ci- } 6684be168c0dSopenharmony_ci- auto pnode = node.release(); 6685be168c0dSopenharmony_ci- auto impl = GetImpl(pnode); 6686be168c0dSopenharmony_ci- if (impl == nullptr) { 6687be168c0dSopenharmony_ci- MS_LOG(ERROR) << "missing implementation"; 6688be168c0dSopenharmony_ci- return nullptr; 6689be168c0dSopenharmony_ci- } 6690be168c0dSopenharmony_ci- impl->set_node(lnode); 6691be168c0dSopenharmony_ci- lnode->set_impl(impl); 6692be168c0dSopenharmony_ci- return pnode; 6693be168c0dSopenharmony_ci-} 6694be168c0dSopenharmony_ci-namespace NN { 6695be168c0dSopenharmony_ci-std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type, int fmt) { 6696be168c0dSopenharmony_ci- auto type = static_cast<TypeId>(data_type); 6697be168c0dSopenharmony_ci- auto lite_node = lite::NN::Input(dims, type, fmt); 6698be168c0dSopenharmony_ci- return std::unique_ptr<Node>(NodeImpl::Connect(lite_node)); 6699be168c0dSopenharmony_ci-} 6700be168c0dSopenharmony_ci-} // namespace NN 6701be168c0dSopenharmony_ci-} // namespace mindspore 6702be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/expression/node_impl.h b/mindspore/lite/src/litert/cxx_api/expression/node_impl.h 6703be168c0dSopenharmony_cideleted file mode 100644 6704be168c0dSopenharmony_ciindex a1ff4530..00000000 6705be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/expression/node_impl.h 6706be168c0dSopenharmony_ci+++ /dev/null 6707be168c0dSopenharmony_ci@@ -1,71 +0,0 @@ 6708be168c0dSopenharmony_ci-/** 6709be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6710be168c0dSopenharmony_ci- * 6711be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6712be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6713be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6714be168c0dSopenharmony_ci- * 6715be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6716be168c0dSopenharmony_ci- * 6717be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6718be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6719be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6720be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6721be168c0dSopenharmony_ci- * limitations under the License. 6722be168c0dSopenharmony_ci- */ 6723be168c0dSopenharmony_ci- 6724be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NODE_IMPL_H_ 6725be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NODE_IMPL_H_ 6726be168c0dSopenharmony_ci- 6727be168c0dSopenharmony_ci-#include <algorithm> 6728be168c0dSopenharmony_ci-#include <set> 6729be168c0dSopenharmony_ci-#include <memory> 6730be168c0dSopenharmony_ci-#include <vector> 6731be168c0dSopenharmony_ci-#include "include/api/net.h" 6732be168c0dSopenharmony_ci-#include "include/api/cfg.h" 6733be168c0dSopenharmony_ci-#include "include/api/data_type.h" 6734be168c0dSopenharmony_ci-#include "src/expression/node.h" 6735be168c0dSopenharmony_ci-#include "src/expression/expr.h" 6736be168c0dSopenharmony_ci- 6737be168c0dSopenharmony_ci-namespace mindspore { 6738be168c0dSopenharmony_ci-using lite::EXPR; 6739be168c0dSopenharmony_ci-class NodeSet { 6740be168c0dSopenharmony_ci- public: 6741be168c0dSopenharmony_ci- std::set<lite::Node *> set_; 6742be168c0dSopenharmony_ci-}; 6743be168c0dSopenharmony_ci- 6744be168c0dSopenharmony_ci-class Expr : public EXPR { 6745be168c0dSopenharmony_ci- public: 6746be168c0dSopenharmony_ci- static std::vector<EXPR *> convert(const std::vector<Expr *> &input) { 6747be168c0dSopenharmony_ci- std::vector<EXPR *> vec(input.size()); 6748be168c0dSopenharmony_ci- (void)std::transform(input.begin(), input.end(), vec.begin(), [](Expr *e) { return reinterpret_cast<EXPR *>(e); }); 6749be168c0dSopenharmony_ci- return vec; 6750be168c0dSopenharmony_ci- } 6751be168c0dSopenharmony_ci- static std::vector<Expr *> convert(const std::vector<EXPR *> &input) { 6752be168c0dSopenharmony_ci- std::vector<Expr *> vec(input.size()); 6753be168c0dSopenharmony_ci- (void)std::transform(input.begin(), input.end(), vec.begin(), [](EXPR *e) { return reinterpret_cast<Expr *>(e); }); 6754be168c0dSopenharmony_ci- return vec; 6755be168c0dSopenharmony_ci- } 6756be168c0dSopenharmony_ci-}; 6757be168c0dSopenharmony_ci- 6758be168c0dSopenharmony_ci-class MS_API NodeImpl { 6759be168c0dSopenharmony_ci- public: 6760be168c0dSopenharmony_ci- std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) { 6761be168c0dSopenharmony_ci- auto in = Expr::convert(inputs); 6762be168c0dSopenharmony_ci- auto out = (*node_)(in); 6763be168c0dSopenharmony_ci- return Expr::convert(out); 6764be168c0dSopenharmony_ci- } 6765be168c0dSopenharmony_ci- lite::Node *node() { return node_; } 6766be168c0dSopenharmony_ci- void set_node(lite::Node *node) { node_ = node; } 6767be168c0dSopenharmony_ci- void set_pnode(Node *node) { pnode_ = node; } 6768be168c0dSopenharmony_ci- Node *pnode() { return pnode_; } 6769be168c0dSopenharmony_ci- static Node *Connect(lite::Node *lnode); 6770be168c0dSopenharmony_ci- static std::shared_ptr<NodeImpl> &GetImpl(Node *node) { return node->impl_; } 6771be168c0dSopenharmony_ci- 6772be168c0dSopenharmony_ci- private: 6773be168c0dSopenharmony_ci- Node *pnode_{nullptr}; 6774be168c0dSopenharmony_ci- lite::Node *node_{nullptr}; 6775be168c0dSopenharmony_ci-}; 6776be168c0dSopenharmony_ci-} // namespace mindspore 6777be168c0dSopenharmony_ci- 6778be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_EXPRESSION_NODE_IMPL_H_ 6779be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/graph/graph.cc b/mindspore/lite/src/litert/cxx_api/graph/graph.cc 6780be168c0dSopenharmony_ciindex 35912580..5c1567cc 100644 6781be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/graph/graph.cc 6782be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/graph/graph.cc 6783be168c0dSopenharmony_ci@@ -16,9 +16,7 @@ 6784be168c0dSopenharmony_ci 6785be168c0dSopenharmony_ci #include "include/api/graph.h" 6786be168c0dSopenharmony_ci #include "include/api/cell.h" 6787be168c0dSopenharmony_ci-#include "include/api/net.h" 6788be168c0dSopenharmony_ci #include "src/litert/cxx_api/graph/graph_data.h" 6789be168c0dSopenharmony_ci-#include "src/litert/cxx_api/graph/net_data.h" 6790be168c0dSopenharmony_ci 6791be168c0dSopenharmony_ci namespace mindspore { 6792be168c0dSopenharmony_ci Graph::Graph() : graph_data_(nullptr) {} 6793be168c0dSopenharmony_ci@@ -27,15 +25,8 @@ Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_d 6794be168c0dSopenharmony_ci 6795be168c0dSopenharmony_ci Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {} 6796be168c0dSopenharmony_ci 6797be168c0dSopenharmony_ci-Graph::Graph(Graph::Type type) : graph_type_(type) {} 6798be168c0dSopenharmony_ci- 6799be168c0dSopenharmony_ci Graph::~Graph() {} 6800be168c0dSopenharmony_ci 6801be168c0dSopenharmony_ci-Graph::Graph(Net *net) : graph_type_(kExpressionGraph) { 6802be168c0dSopenharmony_ci- auto shared = std::make_shared<NetData>(net->shared_from_this()); 6803be168c0dSopenharmony_ci- net_data_ = shared; 6804be168c0dSopenharmony_ci-} 6805be168c0dSopenharmony_ci- 6806be168c0dSopenharmony_ci Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {} 6807be168c0dSopenharmony_ci 6808be168c0dSopenharmony_ci bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; } 6809be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/graph/net_data.cc b/mindspore/lite/src/litert/cxx_api/graph/net_data.cc 6810be168c0dSopenharmony_cideleted file mode 100644 6811be168c0dSopenharmony_ciindex 2de4d875..00000000 6812be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/graph/net_data.cc 6813be168c0dSopenharmony_ci+++ /dev/null 6814be168c0dSopenharmony_ci@@ -1,21 +0,0 @@ 6815be168c0dSopenharmony_ci-/** 6816be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6817be168c0dSopenharmony_ci- * 6818be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6819be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6820be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6821be168c0dSopenharmony_ci- * 6822be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6823be168c0dSopenharmony_ci- * 6824be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6825be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6826be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6827be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6828be168c0dSopenharmony_ci- * limitations under the License. 6829be168c0dSopenharmony_ci- */ 6830be168c0dSopenharmony_ci-#include "src/litert/cxx_api/graph/net_data.h" 6831be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 6832be168c0dSopenharmony_ci- 6833be168c0dSopenharmony_ci-namespace mindspore { 6834be168c0dSopenharmony_ci-NetData::~NetData() { net_->impl_->Release(); } 6835be168c0dSopenharmony_ci-} // namespace mindspore 6836be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/graph/net_data.h b/mindspore/lite/src/litert/cxx_api/graph/net_data.h 6837be168c0dSopenharmony_cideleted file mode 100644 6838be168c0dSopenharmony_ciindex 15393d49..00000000 6839be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/graph/net_data.h 6840be168c0dSopenharmony_ci+++ /dev/null 6841be168c0dSopenharmony_ci@@ -1,35 +0,0 @@ 6842be168c0dSopenharmony_ci-/** 6843be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 6844be168c0dSopenharmony_ci- * 6845be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 6846be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 6847be168c0dSopenharmony_ci- * You may obtain a copy of the License at 6848be168c0dSopenharmony_ci- * 6849be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 6850be168c0dSopenharmony_ci- * 6851be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 6852be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 6853be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6854be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 6855be168c0dSopenharmony_ci- * limitations under the License. 6856be168c0dSopenharmony_ci- */ 6857be168c0dSopenharmony_ci- 6858be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_RUNTIME_CXX_API_GRAPH_NET_DATA_H_ 6859be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_RUNTIME_CXX_API_GRAPH_NET_DATA_H_ 6860be168c0dSopenharmony_ci- 6861be168c0dSopenharmony_ci-#include <memory> 6862be168c0dSopenharmony_ci-#include "include/api/net.h" 6863be168c0dSopenharmony_ci- 6864be168c0dSopenharmony_ci-namespace mindspore { 6865be168c0dSopenharmony_ci-class NetData { 6866be168c0dSopenharmony_ci- public: 6867be168c0dSopenharmony_ci- explicit NetData(const std::shared_ptr<Net> &net) : net_(net) {} 6868be168c0dSopenharmony_ci- virtual ~NetData(); 6869be168c0dSopenharmony_ci- void set_net(std::shared_ptr<Net> net) { net_ = net; } 6870be168c0dSopenharmony_ci- std::shared_ptr<Net> &net() { return net_; } 6871be168c0dSopenharmony_ci- 6872be168c0dSopenharmony_ci- private: 6873be168c0dSopenharmony_ci- std::shared_ptr<Net> net_; 6874be168c0dSopenharmony_ci-}; 6875be168c0dSopenharmony_ci-} // namespace mindspore 6876be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_RUNTIME_CXX_API_GRAPH_NET_DATA_H_ 6877be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/model/model.cc b/mindspore/lite/src/litert/cxx_api/model/model.cc 6878be168c0dSopenharmony_ciindex 081b4dfd..7bbaca5c 100644 6879be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/model/model.cc 6880be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/model/model.cc 6881be168c0dSopenharmony_ci@@ -30,7 +30,6 @@ 6882be168c0dSopenharmony_ci #if defined(ENABLE_PRE_INFERENCE) && defined(__linux__) && !defined(Debug) 6883be168c0dSopenharmony_ci #include "src/common/thread_utils.h" 6884be168c0dSopenharmony_ci #endif 6885be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 6886be168c0dSopenharmony_ci #include "src/litert/cxx_api/callback/callback_adapter.h" 6887be168c0dSopenharmony_ci #include "src/litert/cxx_api/callback/callback_impl.h" 6888be168c0dSopenharmony_ci #include "src/litert/cxx_api/model/model_impl.h" 6889be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc 6890be168c0dSopenharmony_ciindex 77ee95ab..78b1ca67 100644 6891be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc 6892be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc 6893be168c0dSopenharmony_ci@@ -74,14 +74,6 @@ CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProt 6894be168c0dSopenharmony_ci return proto_; 6895be168c0dSopenharmony_ci } 6896be168c0dSopenharmony_ci 6897be168c0dSopenharmony_ci-ExpressionLoader CreateExpressionLoader(const ExpressionLoader &loader) { 6898be168c0dSopenharmony_ci- static ExpressionLoader loader_ = nullptr; 6899be168c0dSopenharmony_ci- if (loader != nullptr) { 6900be168c0dSopenharmony_ci- loader_ = loader; 6901be168c0dSopenharmony_ci- } 6902be168c0dSopenharmony_ci- return loader_; 6903be168c0dSopenharmony_ci-} 6904be168c0dSopenharmony_ci- 6905be168c0dSopenharmony_ci #if defined(ENABLE_PRE_INFERENCE) && defined(__linux__) && !defined(Debug) 6906be168c0dSopenharmony_ci Status ModelImpl::BuildAndRun(const void *model_data, size_t data_size, ModelType model_type, 6907be168c0dSopenharmony_ci const std::shared_ptr<Context> &model_context) { 6908be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/model/model_impl.h b/mindspore/lite/src/litert/cxx_api/model/model_impl.h 6909be168c0dSopenharmony_ciindex 19433cce..8e11ee55 100644 6910be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/model/model_impl.h 6911be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/model/model_impl.h 6912be168c0dSopenharmony_ci@@ -49,9 +49,6 @@ typedef std::shared_ptr<lite::LiteSession>(CreateTrainSessionProto)(std::shared_ 6913be168c0dSopenharmony_ci const std::shared_ptr<lite::InnerContext> &context); 6914be168c0dSopenharmony_ci MS_API CreateTrainSessionProto *CreateTrainSessionCallbackHolder(CreateTrainSessionProto *proto = nullptr); 6915be168c0dSopenharmony_ci 6916be168c0dSopenharmony_ci-using ExpressionLoader = std::function<Status(const char *, Graph *)>; 6917be168c0dSopenharmony_ci-MS_API ExpressionLoader CreateExpressionLoader(const ExpressionLoader &loader = nullptr); 6918be168c0dSopenharmony_ci- 6919be168c0dSopenharmony_ci namespace session { 6920be168c0dSopenharmony_ci class Metrics; 6921be168c0dSopenharmony_ci class TrainLoopCallBack; 6922be168c0dSopenharmony_ci@@ -106,7 +103,6 @@ class ModelImpl { 6923be168c0dSopenharmony_ci 6924be168c0dSopenharmony_ci static bool CheckModelSupport(const std::string &device_type, ModelType model_type); 6925be168c0dSopenharmony_ci bool IsTrainModel(); 6926be168c0dSopenharmony_ci- std::unique_ptr<Graph> BuildTrain(Node *optimizer, std::vector<Expr *> inputs); 6927be168c0dSopenharmony_ci Status SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum); 6928be168c0dSopenharmony_ci Status SetLearningRate(float learning_rate); 6929be168c0dSopenharmony_ci float GetLearningRate(); 6930be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/serialization.cc b/mindspore/lite/src/litert/cxx_api/serialization.cc 6931be168c0dSopenharmony_ciindex 08dfaf61..1bc33a69 100644 6932be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/serialization.cc 6933be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/serialization.cc 6934be168c0dSopenharmony_ci@@ -20,7 +20,6 @@ 6935be168c0dSopenharmony_ci #include "include/api/graph.h" 6936be168c0dSopenharmony_ci #include "include/api/types.h" 6937be168c0dSopenharmony_ci #include "include/model.h" 6938be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 6939be168c0dSopenharmony_ci #include "src/litert/cxx_api/graph/graph_data.h" 6940be168c0dSopenharmony_ci #include "src/litert/cxx_api/model/model_impl.h" 6941be168c0dSopenharmony_ci #include "src/litert/cxx_api/converters.h" 6942be168c0dSopenharmony_ci@@ -28,8 +27,6 @@ 6943be168c0dSopenharmony_ci #include "src/litert/lite_session.h" 6944be168c0dSopenharmony_ci 6945be168c0dSopenharmony_ci namespace mindspore { 6946be168c0dSopenharmony_ci-std::function<int(void *)> ExpressionCallback; 6947be168c0dSopenharmony_ci- 6948be168c0dSopenharmony_ci Key::Key(const char *dec_key, size_t key_len) { 6949be168c0dSopenharmony_ci len = 0; 6950be168c0dSopenharmony_ci if (key_len >= max_key_len) { 6951be168c0dSopenharmony_ci@@ -121,31 +118,19 @@ Status Serialization::Load(const std::vector<char> &file, ModelType model_type, 6952be168c0dSopenharmony_ci MS_LOG(ERROR) << "Read model file failed"; 6953be168c0dSopenharmony_ci return kLiteNullptr; 6954be168c0dSopenharmony_ci } 6955be168c0dSopenharmony_ci- if (graph->IsExecutable()) { 6956be168c0dSopenharmony_ci- auto model = 6957be168c0dSopenharmony_ci- std::shared_ptr<lite::Model>(lite::ImportFromBuffer(static_cast<const char *>(model_buf), model_size, true)); 6958be168c0dSopenharmony_ci- if (model == nullptr) { 6959be168c0dSopenharmony_ci- MS_LOG(ERROR) << "New model failed."; 6960be168c0dSopenharmony_ci- return kLiteNullptr; 6961be168c0dSopenharmony_ci- } 6962be168c0dSopenharmony_ci- auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model)); 6963be168c0dSopenharmony_ci- if (graph_data == nullptr) { 6964be168c0dSopenharmony_ci- MS_LOG(ERROR) << "New graph data failed."; 6965be168c0dSopenharmony_ci- return kLiteMemoryFailed; 6966be168c0dSopenharmony_ci- } 6967be168c0dSopenharmony_ci- *graph = Graph(graph_data); 6968be168c0dSopenharmony_ci- return kSuccess; 6969be168c0dSopenharmony_ci- } else { 6970be168c0dSopenharmony_ci- auto loader = CreateExpressionLoader(); 6971be168c0dSopenharmony_ci- if (loader == nullptr) { 6972be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported Feature."; 6973be168c0dSopenharmony_ci- delete[] model_buf; 6974be168c0dSopenharmony_ci- return kLiteError; 6975be168c0dSopenharmony_ci- } 6976be168c0dSopenharmony_ci- (void)loader(model_buf, graph); 6977be168c0dSopenharmony_ci- delete[] model_buf; 6978be168c0dSopenharmony_ci- return kSuccess; 6979be168c0dSopenharmony_ci+ auto model = 6980be168c0dSopenharmony_ci+ std::shared_ptr<lite::Model>(lite::ImportFromBuffer(static_cast<const char *>(model_buf), model_size, true)); 6981be168c0dSopenharmony_ci+ if (model == nullptr) { 6982be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "New model failed."; 6983be168c0dSopenharmony_ci+ return kLiteNullptr; 6984be168c0dSopenharmony_ci } 6985be168c0dSopenharmony_ci+ auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model)); 6986be168c0dSopenharmony_ci+ if (graph_data == nullptr) { 6987be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "New graph data failed."; 6988be168c0dSopenharmony_ci+ return kLiteMemoryFailed; 6989be168c0dSopenharmony_ci+ } 6990be168c0dSopenharmony_ci+ *graph = Graph(graph_data); 6991be168c0dSopenharmony_ci+ return kSuccess; 6992be168c0dSopenharmony_ci } 6993be168c0dSopenharmony_ci 6994be168c0dSopenharmony_ci Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type, 6995be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/train/model.cc b/mindspore/lite/src/litert/cxx_api/train/model.cc 6996be168c0dSopenharmony_ciindex 40525d9d..2ac44ada 100644 6997be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/train/model.cc 6998be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/train/model.cc 6999be168c0dSopenharmony_ci@@ -15,7 +15,6 @@ 7000be168c0dSopenharmony_ci */ 7001be168c0dSopenharmony_ci 7002be168c0dSopenharmony_ci #include "include/api/model.h" 7003be168c0dSopenharmony_ci-#include "include/api/net.h" 7004be168c0dSopenharmony_ci #include "include/api/callback/callback.h" 7005be168c0dSopenharmony_ci #include "include/api/dual_abi_helper.h" 7006be168c0dSopenharmony_ci #include "src/litert/cxx_api/model/model_impl.h" 7007be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/train/model_build.cc b/mindspore/lite/src/litert/cxx_api/train/model_build.cc 7008be168c0dSopenharmony_ciindex c2f0161b..6ec79777 100644 7009be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/train/model_build.cc 7010be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/train/model_build.cc 7011be168c0dSopenharmony_ci@@ -18,34 +18,6 @@ 7012be168c0dSopenharmony_ci #include "src/common/log_adapter.h" 7013be168c0dSopenharmony_ci #include "src/litert/cxx_api/model/model_impl.h" 7014be168c0dSopenharmony_ci namespace mindspore { 7015be168c0dSopenharmony_ci-Status Model::Build(GraphCell lossGraphCell, Node *optimizer, std::vector<Expr *> inputs, 7016be168c0dSopenharmony_ci- const std::shared_ptr<Context> &model_context, const std::shared_ptr<TrainCfg> &train_cfg) { 7017be168c0dSopenharmony_ci- std::stringstream err_msg; 7018be168c0dSopenharmony_ci- if (impl_ == nullptr) { 7019be168c0dSopenharmony_ci- impl_ = std::make_shared<ModelImpl>(); 7020be168c0dSopenharmony_ci- if (impl_ == nullptr) { 7021be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Model implement is null."; 7022be168c0dSopenharmony_ci- return kLiteFileError; 7023be168c0dSopenharmony_ci- } 7024be168c0dSopenharmony_ci- } 7025be168c0dSopenharmony_ci- auto lossGraph = lossGraphCell.GetGraph(); 7026be168c0dSopenharmony_ci- if (lossGraph == nullptr) { 7027be168c0dSopenharmony_ci- err_msg << "Invalid null graph"; 7028be168c0dSopenharmony_ci- MS_LOG(ERROR) << err_msg.str(); 7029be168c0dSopenharmony_ci- return Status(kLiteNullptr, err_msg.str()); 7030be168c0dSopenharmony_ci- } 7031be168c0dSopenharmony_ci- impl_->SetContext(model_context); 7032be168c0dSopenharmony_ci- impl_->SetConfig(train_cfg); 7033be168c0dSopenharmony_ci- impl_->SetGraph(lossGraph); 7034be168c0dSopenharmony_ci- auto graph = impl_->BuildTrain(optimizer, inputs); 7035be168c0dSopenharmony_ci- auto status = Build(GraphCell(*graph), model_context, train_cfg); 7036be168c0dSopenharmony_ci- if (status != mindspore::kSuccess) { 7037be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Error " << status << " during model build"; 7038be168c0dSopenharmony_ci- return status; 7039be168c0dSopenharmony_ci- } 7040be168c0dSopenharmony_ci- return kSuccess; // status 7041be168c0dSopenharmony_ci-} 7042be168c0dSopenharmony_ci- 7043be168c0dSopenharmony_ci Status Model::BuildTransferLearning(GraphCell backbone, GraphCell head, const std::shared_ptr<Context> &context, 7044be168c0dSopenharmony_ci const std::shared_ptr<TrainCfg> &train_cfg) { 7045be168c0dSopenharmony_ci std::stringstream err_msg; 7046be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/train/model_build_impl.cc b/mindspore/lite/src/litert/cxx_api/train/model_build_impl.cc 7047be168c0dSopenharmony_ciindex 28328944..ef561708 100644 7048be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/train/model_build_impl.cc 7049be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/train/model_build_impl.cc 7050be168c0dSopenharmony_ci@@ -18,35 +18,7 @@ 7051be168c0dSopenharmony_ci #include "include/train/train_cfg.h" 7052be168c0dSopenharmony_ci #include "src/litert/cxx_api/converters.h" 7053be168c0dSopenharmony_ci #include "src/train/transfer_session.h" 7054be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/node_impl.h" 7055be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 7056be168c0dSopenharmony_ci namespace mindspore { 7057be168c0dSopenharmony_ci-std::unique_ptr<Graph> ModelImpl::BuildTrain(Node *optimizer, std::vector<Expr *> inputs) { 7058be168c0dSopenharmony_ci- auto opt_impl = NodeImpl::GetImpl(optimizer); 7059be168c0dSopenharmony_ci- if (opt_impl == nullptr) { 7060be168c0dSopenharmony_ci- MS_LOG(ERROR) << "missing optimizer node implementation"; 7061be168c0dSopenharmony_ci- return nullptr; 7062be168c0dSopenharmony_ci- } 7063be168c0dSopenharmony_ci- auto opt = opt_impl->node(); 7064be168c0dSopenharmony_ci- auto in = Expr::convert(inputs); 7065be168c0dSopenharmony_ci- auto net_impl = NetImpl::GetImpl(graph_->net_data_->net().get()); 7066be168c0dSopenharmony_ci- if (net_impl == nullptr) { 7067be168c0dSopenharmony_ci- MS_LOG(ERROR) << "missing net implementation"; 7068be168c0dSopenharmony_ci- return nullptr; 7069be168c0dSopenharmony_ci- } 7070be168c0dSopenharmony_ci- auto trained_net = net_impl->net()->TrainNet(opt, in); 7071be168c0dSopenharmony_ci- if (trained_net == nullptr) { 7072be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed to train network"; 7073be168c0dSopenharmony_ci- return nullptr; 7074be168c0dSopenharmony_ci- } 7075be168c0dSopenharmony_ci- auto mgraph = net_impl->MakeMs(); 7076be168c0dSopenharmony_ci- if (mgraph == nullptr) { 7077be168c0dSopenharmony_ci- MS_LOG(ERROR) << "failed to create graph"; 7078be168c0dSopenharmony_ci- return nullptr; 7079be168c0dSopenharmony_ci- } 7080be168c0dSopenharmony_ci- return mgraph; 7081be168c0dSopenharmony_ci-} 7082be168c0dSopenharmony_ci- 7083be168c0dSopenharmony_ci Status ModelImpl::BuildTransferLearning(const std::shared_ptr<Graph> &backbone, const std::shared_ptr<Graph> &head) { 7084be168c0dSopenharmony_ci const auto b_graph_data = backbone->graph_data_; 7085be168c0dSopenharmony_ci const auto h_graph_data = head->graph_data_; 7086be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/train/model_impl.cc b/mindspore/lite/src/litert/cxx_api/train/model_impl.cc 7087be168c0dSopenharmony_ciindex e2fe7a82..726b0585 100644 7088be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/train/model_impl.cc 7089be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/train/model_impl.cc 7090be168c0dSopenharmony_ci@@ -20,7 +20,6 @@ 7091be168c0dSopenharmony_ci #include "include/api/serialization.h" 7092be168c0dSopenharmony_ci #include "include/api/callback/callback.h" 7093be168c0dSopenharmony_ci #include "include/api/metrics/metrics.h" 7094be168c0dSopenharmony_ci-#include "src/litert/cxx_api/expression/net_impl.h" 7095be168c0dSopenharmony_ci #include "src/litert/cxx_api/converters.h" 7096be168c0dSopenharmony_ci #include "src/litert/cxx_api/metrics/metrics_adapter.h" 7097be168c0dSopenharmony_ci #include "src/litert/cxx_api/metrics/metrics_impl.h" 7098be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/graph_fusion.cc b/mindspore/lite/src/train/graph_fusion.cc 7099be168c0dSopenharmony_ciindex 7982f818..980b2baa 100644 7100be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/graph_fusion.cc 7101be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/graph_fusion.cc 7102be168c0dSopenharmony_ci@@ -27,6 +27,7 @@ 7103be168c0dSopenharmony_ci #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" 7104be168c0dSopenharmony_ci #include "src/train/optimizer/fusion/matmul_add_fusion_pass.h" 7105be168c0dSopenharmony_ci #include "src/train/optimizer/fusion/matmul_matmul_add_fusion_pass.h" 7106be168c0dSopenharmony_ci+#include "src/train/optimizer/fusion/remove_redundant_tensor.h" 7107be168c0dSopenharmony_ci 7108be168c0dSopenharmony_ci namespace mindspore { 7109be168c0dSopenharmony_ci namespace lite { 7110be168c0dSopenharmony_ci@@ -64,6 +65,12 @@ STATUS GraphFusion::Run(schema::MetaGraphT *graph) { 7111be168c0dSopenharmony_ci MS_LOG(ERROR) << "graph fusion failed."; 7112be168c0dSopenharmony_ci return RET_ERROR; 7113be168c0dSopenharmony_ci } 7114be168c0dSopenharmony_ci+ auto opt_tensor = new (std::nothrow) RemoveRedundantTensor(); 7115be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(opt_tensor != nullptr, RET_NULL_PTR, "Create RemoveRedundantTensor failed."); 7116be168c0dSopenharmony_ci+ if (opt_tensor->Run(graph) != RET_OK) { 7117be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Do RemoveRedundantTensor failed."; 7118be168c0dSopenharmony_ci+ return RET_ERROR; 7119be168c0dSopenharmony_ci+ } 7120be168c0dSopenharmony_ci return RET_OK; 7121be168c0dSopenharmony_ci } 7122be168c0dSopenharmony_ci } // namespace lite 7123be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/optimizer/fusion/remove_redundant_tensor.cc b/mindspore/lite/src/train/optimizer/fusion/remove_redundant_tensor.cc 7124be168c0dSopenharmony_cinew file mode 100644 7125be168c0dSopenharmony_ciindex 00000000..e74e78e2 7126be168c0dSopenharmony_ci--- /dev/null 7127be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/optimizer/fusion/remove_redundant_tensor.cc 7128be168c0dSopenharmony_ci@@ -0,0 +1,89 @@ 7129be168c0dSopenharmony_ci+/** 7130be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 7131be168c0dSopenharmony_ci+ * 7132be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 7133be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 7134be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 7135be168c0dSopenharmony_ci+ * 7136be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 7137be168c0dSopenharmony_ci+ * 7138be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 7139be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 7140be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7141be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 7142be168c0dSopenharmony_ci+ * limitations under the License. 7143be168c0dSopenharmony_ci+ */ 7144be168c0dSopenharmony_ci+ 7145be168c0dSopenharmony_ci+#include "src/train/optimizer/fusion/remove_redundant_tensor.h" 7146be168c0dSopenharmony_ci+#include <map> 7147be168c0dSopenharmony_ci+#include "src/common/log_adapter.h" 7148be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 7149be168c0dSopenharmony_ci+ 7150be168c0dSopenharmony_ci+namespace mindspore { 7151be168c0dSopenharmony_ci+namespace lite { 7152be168c0dSopenharmony_ci+STATUS RemoveRedundantTensor::Run(schema::MetaGraphT *graph) { 7153be168c0dSopenharmony_ci+ if (graph == nullptr) { 7154be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "The graph is a nullptr."; 7155be168c0dSopenharmony_ci+ return RET_NULL_PTR; 7156be168c0dSopenharmony_ci+ } 7157be168c0dSopenharmony_ci+ std::map<uint32_t, uint32_t> index_map; 7158be168c0dSopenharmony_ci+ uint32_t index = 0; 7159be168c0dSopenharmony_ci+ auto graph_input_index = graph->inputIndex; 7160be168c0dSopenharmony_ci+ graph->inputIndex.clear(); 7161be168c0dSopenharmony_ci+ for (auto input_index : graph_input_index) { 7162be168c0dSopenharmony_ci+ if (index_map.find(input_index) == index_map.end()) { 7163be168c0dSopenharmony_ci+ index_map[input_index] = index; 7164be168c0dSopenharmony_ci+ ++index; 7165be168c0dSopenharmony_ci+ } 7166be168c0dSopenharmony_ci+ graph->inputIndex.push_back(index_map[input_index]); 7167be168c0dSopenharmony_ci+ } 7168be168c0dSopenharmony_ci+ for (auto &node : graph->nodes) { 7169be168c0dSopenharmony_ci+ auto node_in_index = node->inputIndex; 7170be168c0dSopenharmony_ci+ node->inputIndex.clear(); 7171be168c0dSopenharmony_ci+ for (auto in_index : node_in_index) { 7172be168c0dSopenharmony_ci+ if (index_map.find(in_index) == index_map.end()) { 7173be168c0dSopenharmony_ci+ index_map[in_index] = index; 7174be168c0dSopenharmony_ci+ ++index; 7175be168c0dSopenharmony_ci+ } 7176be168c0dSopenharmony_ci+ node->inputIndex.push_back(index_map[in_index]); 7177be168c0dSopenharmony_ci+ } 7178be168c0dSopenharmony_ci+ auto node_out_index = node->outputIndex; 7179be168c0dSopenharmony_ci+ node->outputIndex.clear(); 7180be168c0dSopenharmony_ci+ for (auto out_index : node_out_index) { 7181be168c0dSopenharmony_ci+ if (index_map.find(out_index) == index_map.end()) { 7182be168c0dSopenharmony_ci+ index_map[out_index] = index; 7183be168c0dSopenharmony_ci+ ++index; 7184be168c0dSopenharmony_ci+ } 7185be168c0dSopenharmony_ci+ node->outputIndex.push_back(index_map[out_index]); 7186be168c0dSopenharmony_ci+ } 7187be168c0dSopenharmony_ci+ } 7188be168c0dSopenharmony_ci+ auto graph_output_index = graph->outputIndex; 7189be168c0dSopenharmony_ci+ graph->outputIndex.clear(); 7190be168c0dSopenharmony_ci+ for (auto output_index : graph_output_index) { 7191be168c0dSopenharmony_ci+ if (index_map.find(output_index) == index_map.end()) { 7192be168c0dSopenharmony_ci+ index_map[output_index] = index; 7193be168c0dSopenharmony_ci+ ++index; 7194be168c0dSopenharmony_ci+ } 7195be168c0dSopenharmony_ci+ graph->outputIndex.push_back(index_map[output_index]); 7196be168c0dSopenharmony_ci+ } 7197be168c0dSopenharmony_ci+ std::vector<std::unique_ptr<mindspore::schema::TensorT>> old_tensors; 7198be168c0dSopenharmony_ci+ old_tensors.swap(graph->allTensors); 7199be168c0dSopenharmony_ci+ graph->allTensors.resize(index_map.size()); 7200be168c0dSopenharmony_ci+ for (size_t i = 0; i < old_tensors.size(); ++i) { 7201be168c0dSopenharmony_ci+ if (index_map.find(i) == index_map.end()) { 7202be168c0dSopenharmony_ci+ continue; 7203be168c0dSopenharmony_ci+ } 7204be168c0dSopenharmony_ci+ graph->allTensors[index_map[i]].swap(old_tensors[i]); 7205be168c0dSopenharmony_ci+ } 7206be168c0dSopenharmony_ci+ if (!graph->subGraph.empty()) { 7207be168c0dSopenharmony_ci+ graph->subGraph[0]->inputIndices = graph->inputIndex; 7208be168c0dSopenharmony_ci+ graph->subGraph[0]->outputIndices = graph->outputIndex; 7209be168c0dSopenharmony_ci+ graph->subGraph[0]->tensorIndices = {}; 7210be168c0dSopenharmony_ci+ for (uint32_t i = 0; i < index; ++i) { 7211be168c0dSopenharmony_ci+ graph->subGraph[0]->tensorIndices.push_back(i); 7212be168c0dSopenharmony_ci+ } 7213be168c0dSopenharmony_ci+ } 7214be168c0dSopenharmony_ci+ return RET_OK; 7215be168c0dSopenharmony_ci+} 7216be168c0dSopenharmony_ci+} // namespace lite 7217be168c0dSopenharmony_ci+} // namespace mindspore 7218be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/expression/ops/depend.h b/mindspore/lite/src/train/optimizer/fusion/remove_redundant_tensor.h 7219be168c0dSopenharmony_cisimilarity index 55% 7220be168c0dSopenharmony_cirename from mindspore/lite/src/expression/ops/depend.h 7221be168c0dSopenharmony_cirename to mindspore/lite/src/train/optimizer/fusion/remove_redundant_tensor.h 7222be168c0dSopenharmony_ciindex 0995e664..8da58dc5 100644 7223be168c0dSopenharmony_ci--- a/mindspore/lite/src/expression/ops/depend.h 7224be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/optimizer/fusion/remove_redundant_tensor.h 7225be168c0dSopenharmony_ci@@ -1,5 +1,5 @@ 7226be168c0dSopenharmony_ci /** 7227be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 7228be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 7229be168c0dSopenharmony_ci * 7230be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License"); 7231be168c0dSopenharmony_ci * you may not use this file except in compliance with the License. 7232be168c0dSopenharmony_ci@@ -14,24 +14,23 @@ 7233be168c0dSopenharmony_ci * limitations under the License. 7234be168c0dSopenharmony_ci */ 7235be168c0dSopenharmony_ci 7236be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_ 7237be168c0dSopenharmony_ci-#define MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_ 7238be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_REMOVE_REDUNDANT_TENSOR_H_ 7239be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_REMOVE_REDUNDANT_TENSOR_H_ 7240be168c0dSopenharmony_ci 7241be168c0dSopenharmony_ci-#include <vector> 7242be168c0dSopenharmony_ci-#include "src/expression/node.h" 7243be168c0dSopenharmony_ci-#include "src/expression/ops.h" 7244be168c0dSopenharmony_ci-#include "inner/model_generated.h" 7245be168c0dSopenharmony_ci+#include "include/errorcode.h" 7246be168c0dSopenharmony_ci+#include "schema/inner/model_generated.h" 7247be168c0dSopenharmony_ci 7248be168c0dSopenharmony_ci namespace mindspore { 7249be168c0dSopenharmony_ci namespace lite { 7250be168c0dSopenharmony_ci-class DependM : public Node { 7251be168c0dSopenharmony_ci+class RemoveRedundantTensor { 7252be168c0dSopenharmony_ci public: 7253be168c0dSopenharmony_ci- DependM(); 7254be168c0dSopenharmony_ci-}; 7255be168c0dSopenharmony_ci+ RemoveRedundantTensor() = default; 7256be168c0dSopenharmony_ci+ 7257be168c0dSopenharmony_ci+ ~RemoveRedundantTensor() = default; 7258be168c0dSopenharmony_ci 7259be168c0dSopenharmony_ci-namespace NN { 7260be168c0dSopenharmony_ci-Node *Depend(); 7261be168c0dSopenharmony_ci+ STATUS Run(schema::MetaGraphT *graph); 7262be168c0dSopenharmony_ci }; 7263be168c0dSopenharmony_ci } // namespace lite 7264be168c0dSopenharmony_ci } // namespace mindspore 7265be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_SRC_EXPRESSION_OPS_DEPEND_H_ 7266be168c0dSopenharmony_ci+ 7267be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_FUSION_REMOVE_REDUNDANT_TENSOR_H_ 7268be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc 7269be168c0dSopenharmony_ciindex c123cba8..fa7625cb 100644 7270be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/train_session.cc 7271be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/train_session.cc 7272be168c0dSopenharmony_ci@@ -1248,38 +1248,53 @@ int TrainSession::FindExportKernels(std::vector<kernel::KernelExec *> *export_ke 7273be168c0dSopenharmony_ci } 7274be168c0dSopenharmony_ci 7275be168c0dSopenharmony_ci template <typename DestType> 7276be168c0dSopenharmony_ci-int TrainSession::ExportByDifferentType(DestType destination, ModelType model_type, QuantizationType quant_type, 7277be168c0dSopenharmony_ci- bool orig_train_state, std::vector<std::string> output_tensor_name) { 7278be168c0dSopenharmony_ci+int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, 7279be168c0dSopenharmony_ci+ FormatType format, std::vector<std::string> out_put_tensor_name) { 7280be168c0dSopenharmony_ci+ if constexpr (std::is_same_v<DestType, const std::string &>) { 7281be168c0dSopenharmony_ci+ MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty"); 7282be168c0dSopenharmony_ci+ struct stat path_type; 7283be168c0dSopenharmony_ci+ if (stat(destination.c_str(), &path_type) == RET_OK) { 7284be168c0dSopenharmony_ci+ if (path_type.st_mode & S_IFDIR) { 7285be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Destination must be path, now is a directory"; 7286be168c0dSopenharmony_ci+ return RET_ERROR; 7287be168c0dSopenharmony_ci+ } 7288be168c0dSopenharmony_ci+ } 7289be168c0dSopenharmony_ci+ } else if constexpr (std::is_same_v<DestType, Buffer *>) { 7290be168c0dSopenharmony_ci+ MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr"); 7291be168c0dSopenharmony_ci+ } else { 7292be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported destination."; 7293be168c0dSopenharmony_ci+ return RET_ERROR; 7294be168c0dSopenharmony_ci+ } 7295be168c0dSopenharmony_ci+ MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR, 7296be168c0dSopenharmony_ci+ "Export model type parameter error"); 7297be168c0dSopenharmony_ci+ MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR, 7298be168c0dSopenharmony_ci+ "Export quant type parameter error"); 7299be168c0dSopenharmony_ci+ MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "File name cannot be empty"); 7300be168c0dSopenharmony_ci+ 7301be168c0dSopenharmony_ci+ bool orig_train_state = IsTrain(); 7302be168c0dSopenharmony_ci TrainExport texport(destination); 7303be168c0dSopenharmony_ci int status = texport.ExportInit(model_.get()->graph_.name_, model_.get()->graph_.version_); 7304be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to init export"); 7305be168c0dSopenharmony_ci- if (!output_tensor_name.empty() && model_type == MT_INFERENCE) { 7306be168c0dSopenharmony_ci+ 7307be168c0dSopenharmony_ci+ if (!out_put_tensor_name.empty() && model_type == MT_INFERENCE) { 7308be168c0dSopenharmony_ci std::vector<kernel::KernelExec *> export_kernels = {}; 7309be168c0dSopenharmony_ci- status = FindExportKernels(&export_kernels, output_tensor_name, const_fold_kernels_); 7310be168c0dSopenharmony_ci+ status = FindExportKernels(&export_kernels, out_put_tensor_name, const_fold_kernels_); 7311be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "FindExportKernels failed."); 7312be168c0dSopenharmony_ci status = 7313be168c0dSopenharmony_ci- texport.ExportNet(export_kernels, tensors_, const_output_tensors_, output_tensor_name, model_.get(), quant_type); 7314be168c0dSopenharmony_ci+ texport.ExportNet(export_kernels, tensors_, const_output_tensors_, out_put_tensor_name, model_.get(), quant_type); 7315be168c0dSopenharmony_ci } else { 7316be168c0dSopenharmony_ci- if (!output_tensor_name.empty() && model_type == MT_TRAIN) { 7317be168c0dSopenharmony_ci- MS_LOG(WARNING) << "Train model does not support to export selected output tensor, and all of the train kernels " 7318be168c0dSopenharmony_ci- "tensors will be exported"; 7319be168c0dSopenharmony_ci- } 7320be168c0dSopenharmony_ci if ((!model_buff_changed_) && (quant_type == QT_NONE) && (model_type == MT_TRAIN) && 7321be168c0dSopenharmony_ci std::all_of(model_->graph_.all_nodes_.begin(), model_->graph_.all_nodes_.end(), [](const LiteGraph::Node *n) { 7322be168c0dSopenharmony_ci return n->quant_type_ == schema::QuantType::QuantType_QUANT_NONE; 7323be168c0dSopenharmony_ci })) { 7324be168c0dSopenharmony_ci status = texport.SaveModel(model_.get(), destination); 7325be168c0dSopenharmony_ci- TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Failed to save model"); 7326be168c0dSopenharmony_ci- if (orig_train_state) { 7327be168c0dSopenharmony_ci- status = Train(); 7328be168c0dSopenharmony_ci- TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Train failed."); 7329be168c0dSopenharmony_ci- } 7330be168c0dSopenharmony_ci+ if (orig_train_state) Train(); 7331be168c0dSopenharmony_ci return status; 7332be168c0dSopenharmony_ci } else { 7333be168c0dSopenharmony_ci if (quant_type == QT_NONE) { 7334be168c0dSopenharmony_ci status = texport.ExportNet( 7335be168c0dSopenharmony_ci- (model_type == MT_TRAIN) ? train_kernels_ : const_fold_kernels_, tensors_, const_output_tensors_, 7336be168c0dSopenharmony_ci- (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, model_.get(), quant_type); 7337be168c0dSopenharmony_ci+ (model_type == MT_TRAIN) ? train_kernels_ : const_fold_kernels_, tensors_, const_output_tensors_, 7338be168c0dSopenharmony_ci+ (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, model_.get(), quant_type); 7339be168c0dSopenharmony_ci } else { 7340be168c0dSopenharmony_ci status = texport.ExportNet((model_type == MT_TRAIN) ? train_kernels_ : inference_kernels_, tensors_, {}, 7341be168c0dSopenharmony_ci (model_type == MT_TRAIN) ? train_output_tensor_names_ : eval_output_tensor_names_, 7342be168c0dSopenharmony_ci@@ -1288,6 +1303,7 @@ int TrainSession::ExportByDifferentType(DestType destination, ModelType model_ty 7343be168c0dSopenharmony_ci } 7344be168c0dSopenharmony_ci } 7345be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Fail to export Network."); 7346be168c0dSopenharmony_ci+ 7347be168c0dSopenharmony_ci if (model_type == MT_INFERENCE) { 7348be168c0dSopenharmony_ci status = texport.TrainModelDrop(); 7349be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "TrainModelDrop failed."); 7350be168c0dSopenharmony_ci@@ -1304,46 +1320,8 @@ int TrainSession::ExportByDifferentType(DestType destination, ModelType model_ty 7351be168c0dSopenharmony_ci status = texport.SaveToBuffer(); 7352be168c0dSopenharmony_ci TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "fail to save to model buffer."); 7353be168c0dSopenharmony_ci } 7354be168c0dSopenharmony_ci- return RET_OK; 7355be168c0dSopenharmony_ci-} 7356be168c0dSopenharmony_ci- 7357be168c0dSopenharmony_ci-template <typename DestType> 7358be168c0dSopenharmony_ci-int TrainSession::ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, 7359be168c0dSopenharmony_ci- FormatType format, std::vector<std::string> out_put_tensor_name) { 7360be168c0dSopenharmony_ci- if constexpr (std::is_same_v<DestType, const std::string &>) { 7361be168c0dSopenharmony_ci- MS_CHECK_FALSE_MSG(destination.empty(), RET_ERROR, "File name cannot be empty"); 7362be168c0dSopenharmony_ci- struct stat path_type; 7363be168c0dSopenharmony_ci- if (stat(destination.c_str(), &path_type) == RET_OK) { 7364be168c0dSopenharmony_ci- if (path_type.st_mode & S_IFDIR) { 7365be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Destination must be path, now is a directory"; 7366be168c0dSopenharmony_ci- return RET_ERROR; 7367be168c0dSopenharmony_ci- } 7368be168c0dSopenharmony_ci- } 7369be168c0dSopenharmony_ci- } else if constexpr (std::is_same_v<DestType, Buffer *>) { 7370be168c0dSopenharmony_ci- MS_CHECK_FALSE_MSG(destination == nullptr, RET_ERROR, "model buffer cannot be nullptr"); 7371be168c0dSopenharmony_ci- } else { 7372be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported destination."; 7373be168c0dSopenharmony_ci- return RET_ERROR; 7374be168c0dSopenharmony_ci- } 7375be168c0dSopenharmony_ci- MS_CHECK_FALSE_MSG(model_type > mindspore::lite::MT_INFERENCE || model_type < mindspore::lite::MT_TRAIN, RET_ERROR, 7376be168c0dSopenharmony_ci- "Export model type parameter error"); 7377be168c0dSopenharmony_ci- MS_CHECK_FALSE_MSG(quant_type < mindspore::lite::QT_DEFAULT || quant_type > mindspore::lite::QT_WEIGHT, RET_ERROR, 7378be168c0dSopenharmony_ci- "Export quant type parameter error"); 7379be168c0dSopenharmony_ci- MS_CHECK_FALSE_MSG(format != FT_FLATBUFFERS, RET_ERROR, "File name cannot be empty"); 7380be168c0dSopenharmony_ci- 7381be168c0dSopenharmony_ci- bool orig_train_state = IsTrain(); 7382be168c0dSopenharmony_ci- int status = Eval(); 7383be168c0dSopenharmony_ci- TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Eval failed"); 7384be168c0dSopenharmony_ci- status = ExportByDifferentType<DestType>(destination, model_type, quant_type, orig_train_state, out_put_tensor_name); 7385be168c0dSopenharmony_ci- if (status != RET_OK) { 7386be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Fail to export by different type"; 7387be168c0dSopenharmony_ci- return status; 7388be168c0dSopenharmony_ci- } 7389be168c0dSopenharmony_ci- if (orig_train_state) { 7390be168c0dSopenharmony_ci- status = Train(); 7391be168c0dSopenharmony_ci- TRAIN_SESSION_CHECK_FALSE_MSG(status != RET_OK, status, "Train failed"); 7392be168c0dSopenharmony_ci- } 7393be168c0dSopenharmony_ci- return RET_OK; 7394be168c0dSopenharmony_ci+ if (orig_train_state) Train(); 7395be168c0dSopenharmony_ci+ return status; 7396be168c0dSopenharmony_ci } 7397be168c0dSopenharmony_ci 7398be168c0dSopenharmony_ci int TrainSession::Export(const std::string &file_name, ModelType model_type, QuantizationType quant_type, 7399be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h 7400be168c0dSopenharmony_ciindex 0bd14b21..0a65f64e 100644 7401be168c0dSopenharmony_ci--- a/mindspore/lite/src/train/train_session.h 7402be168c0dSopenharmony_ci+++ b/mindspore/lite/src/train/train_session.h 7403be168c0dSopenharmony_ci@@ -175,9 +175,6 @@ class TrainSession : virtual public lite::LiteSession { 7404be168c0dSopenharmony_ci const std::unordered_map<lite::Tensor *, size_t> &offset_map, 7405be168c0dSopenharmony_ci std::unordered_map<lite::Tensor *, int> *ref_count, uint32_t input_idx); 7406be168c0dSopenharmony_ci template <typename DestType> 7407be168c0dSopenharmony_ci- int ExportByDifferentType(DestType destination, ModelType model_type, QuantizationType quant_type, 7408be168c0dSopenharmony_ci- bool orig_train_state, std::vector<std::string> output_tensor_name = {}); 7409be168c0dSopenharmony_ci- template <typename DestType> 7410be168c0dSopenharmony_ci int ExportInner(DestType destination, ModelType model_type, QuantizationType quant_type, FormatType, 7411be168c0dSopenharmony_ci std::vector<std::string> out_put_tensor_name = {}); 7412be168c0dSopenharmony_ci std::map<Tensor *, Tensor *> restored_origin_tensors_; 7413be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/config_level0/models_ms_train.cfg b/mindspore/lite/test/config_level0/models_ms_train.cfg 7414be168c0dSopenharmony_ciindex 645a31c4..7a6b9702 100644 7415be168c0dSopenharmony_ci--- a/mindspore/lite/test/config_level0/models_ms_train.cfg 7416be168c0dSopenharmony_ci+++ b/mindspore/lite/test/config_level0/models_ms_train.cfg 7417be168c0dSopenharmony_ci@@ -51,7 +51,5 @@ vae 7418be168c0dSopenharmony_ci unified_api code_example 7419be168c0dSopenharmony_ci train_lenet code_example 7420be168c0dSopenharmony_ci train_lenet_java code_example 7421be168c0dSopenharmony_ci-lenet expression 7422be168c0dSopenharmony_ci-mobilenetv2 expression noarm32 7423be168c0dSopenharmony_ci # LAST 7424be168c0dSopenharmony_ci #test_resize inputShapes 16,10,10,1:16,10,10,1 0.5 7425be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/CMakeLists.txt b/mindspore/lite/tools/benchmark_train/CMakeLists.txt 7426be168c0dSopenharmony_ciindex 1b9fc347..3c92af7f 100644 7427be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/CMakeLists.txt 7428be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/CMakeLists.txt 7429be168c0dSopenharmony_ci@@ -17,13 +17,6 @@ set(TEST_SRC 7430be168c0dSopenharmony_ci # add static securec link library 7431be168c0dSopenharmony_ci include(${TOP_DIR}/cmake/dependency_securec.cmake) 7432be168c0dSopenharmony_ci 7433be168c0dSopenharmony_ci-if(MSLITE_MINDDATA_IMPLEMENT STREQUAL "full") 7434be168c0dSopenharmony_ci- set(TEST_SRC 7435be168c0dSopenharmony_ci- ${TEST_SRC} 7436be168c0dSopenharmony_ci- ${CMAKE_CURRENT_SOURCE_DIR}/net_runner.cc 7437be168c0dSopenharmony_ci- ) 7438be168c0dSopenharmony_ci-endif() 7439be168c0dSopenharmony_ci- 7440be168c0dSopenharmony_ci add_executable(benchmark_train 7441be168c0dSopenharmony_ci ${TEST_SRC} 7442be168c0dSopenharmony_ci ${COMMON_SRC}) 7443be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_runner.cc b/mindspore/lite/tools/benchmark_train/net_runner.cc 7444be168c0dSopenharmony_cideleted file mode 100644 7445be168c0dSopenharmony_ciindex edf3e964..00000000 7446be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_runner.cc 7447be168c0dSopenharmony_ci+++ /dev/null 7448be168c0dSopenharmony_ci@@ -1,371 +0,0 @@ 7449be168c0dSopenharmony_ci-/** 7450be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 7451be168c0dSopenharmony_ci- * 7452be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 7453be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 7454be168c0dSopenharmony_ci- * You may obtain a copy of the License at 7455be168c0dSopenharmony_ci- * 7456be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 7457be168c0dSopenharmony_ci- * 7458be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 7459be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 7460be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7461be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 7462be168c0dSopenharmony_ci- * limitations under the License. 7463be168c0dSopenharmony_ci- */ 7464be168c0dSopenharmony_ci- 7465be168c0dSopenharmony_ci-#include "tools/benchmark_train/net_runner.h" 7466be168c0dSopenharmony_ci-#include "tools/benchmark_train/net_train_base.h" 7467be168c0dSopenharmony_ci-#include <getopt.h> 7468be168c0dSopenharmony_ci-#include <malloc.h> 7469be168c0dSopenharmony_ci-#include <cmath> 7470be168c0dSopenharmony_ci-#include <cstdio> 7471be168c0dSopenharmony_ci-#include <cstring> 7472be168c0dSopenharmony_ci-#include <iostream> 7473be168c0dSopenharmony_ci-#include <fstream> 7474be168c0dSopenharmony_ci-#include <utility> 7475be168c0dSopenharmony_ci-#include <chrono> 7476be168c0dSopenharmony_ci-#include "include/api/types.h" 7477be168c0dSopenharmony_ci-#include "include/api/context.h" 7478be168c0dSopenharmony_ci-#include "include/api/serialization.h" 7479be168c0dSopenharmony_ci-#include "include/api/callback/loss_monitor.h" 7480be168c0dSopenharmony_ci-#include "include/api/metrics/accuracy.h" 7481be168c0dSopenharmony_ci-#include "include/api/callback/ckpt_saver.h" 7482be168c0dSopenharmony_ci-#include "include/api/callback/train_accuracy.h" 7483be168c0dSopenharmony_ci-#include "include/api/callback/lr_scheduler.h" 7484be168c0dSopenharmony_ci-#include "include/dataset/datasets.h" 7485be168c0dSopenharmony_ci-#include "include/dataset/vision_lite.h" 7486be168c0dSopenharmony_ci-#include "include/dataset/transforms.h" 7487be168c0dSopenharmony_ci-#include "include/api/cfg.h" 7488be168c0dSopenharmony_ci-#include "include/api/net.h" 7489be168c0dSopenharmony_ci- 7490be168c0dSopenharmony_ci-using mindspore::AccuracyMetrics; 7491be168c0dSopenharmony_ci-using mindspore::Model; 7492be168c0dSopenharmony_ci-using mindspore::TrainAccuracy; 7493be168c0dSopenharmony_ci-using mindspore::TrainCallBack; 7494be168c0dSopenharmony_ci-using mindspore::TrainCallBackData; 7495be168c0dSopenharmony_ci-using mindspore::dataset::Dataset; 7496be168c0dSopenharmony_ci-using mindspore::dataset::Mnist; 7497be168c0dSopenharmony_ci-using mindspore::dataset::SequentialSampler; 7498be168c0dSopenharmony_ci-using mindspore::dataset::TensorOperation; 7499be168c0dSopenharmony_ci-using mindspore::dataset::transforms::TypeCast; 7500be168c0dSopenharmony_ci-using mindspore::dataset::vision::Normalize; 7501be168c0dSopenharmony_ci-using mindspore::dataset::vision::Resize; 7502be168c0dSopenharmony_ci- 7503be168c0dSopenharmony_ci-constexpr int kNCHWCDim = 2; 7504be168c0dSopenharmony_ci-constexpr int kPrintTimes = 100; 7505be168c0dSopenharmony_ci-constexpr float kBetta1 = 0.9f; 7506be168c0dSopenharmony_ci-constexpr float kBetta2 = 0.999f; 7507be168c0dSopenharmony_ci- 7508be168c0dSopenharmony_ci-class Rescaler : public mindspore::TrainCallBack { 7509be168c0dSopenharmony_ci- public: 7510be168c0dSopenharmony_ci- explicit Rescaler(float scale) : scale_(scale) { 7511be168c0dSopenharmony_ci- if (scale_ == 0) { 7512be168c0dSopenharmony_ci- scale_ = 1.0; 7513be168c0dSopenharmony_ci- } 7514be168c0dSopenharmony_ci- } 7515be168c0dSopenharmony_ci- ~Rescaler() override = default; 7516be168c0dSopenharmony_ci- void StepBegin(const mindspore::TrainCallBackData &cb_data) override { 7517be168c0dSopenharmony_ci- auto inputs = cb_data.model_->GetInputs(); 7518be168c0dSopenharmony_ci- auto *input_data = reinterpret_cast<float *>(inputs.at(0).MutableData()); 7519be168c0dSopenharmony_ci- for (int k = 0; k < inputs.at(0).ElementNum(); k++) input_data[k] /= scale_; 7520be168c0dSopenharmony_ci- } 7521be168c0dSopenharmony_ci- 7522be168c0dSopenharmony_ci- private: 7523be168c0dSopenharmony_ci- float scale_ = 1.0; 7524be168c0dSopenharmony_ci-}; 7525be168c0dSopenharmony_ci- 7526be168c0dSopenharmony_ci-/* This is an example of a user defined Callback to measure memory and latency of execution */ 7527be168c0dSopenharmony_ci-class Measurement : public mindspore::TrainCallBack { 7528be168c0dSopenharmony_ci- public: 7529be168c0dSopenharmony_ci- explicit Measurement(unsigned int epochs) 7530be168c0dSopenharmony_ci- : time_avg_(std::chrono::duration<double, std::milli>(0)), epochs_(epochs) {} 7531be168c0dSopenharmony_ci- ~Measurement() override = default; 7532be168c0dSopenharmony_ci- void EpochBegin(const mindspore::TrainCallBackData &cb_data) override { 7533be168c0dSopenharmony_ci- start_time_ = std::chrono::high_resolution_clock::now(); 7534be168c0dSopenharmony_ci- } 7535be168c0dSopenharmony_ci- mindspore::CallbackRetValue EpochEnd(const mindspore::TrainCallBackData &cb_data) override { 7536be168c0dSopenharmony_ci- end_time_ = std::chrono::high_resolution_clock::now(); 7537be168c0dSopenharmony_ci- auto time = std::chrono::duration<double, std::milli>(end_time_ - start_time_); 7538be168c0dSopenharmony_ci- time_avg_ += time; 7539be168c0dSopenharmony_ci- return mindspore::kContinue; 7540be168c0dSopenharmony_ci- } 7541be168c0dSopenharmony_ci- void End(const mindspore::TrainCallBackData &cb_data) override { 7542be168c0dSopenharmony_ci- if (epochs_ > 0) { 7543be168c0dSopenharmony_ci- std::cout << "AvgRunTime: " << time_avg_.count() / epochs_ << " ms" << std::endl; 7544be168c0dSopenharmony_ci- } 7545be168c0dSopenharmony_ci- 7546be168c0dSopenharmony_ci- struct mallinfo info = mallinfo(); 7547be168c0dSopenharmony_ci- std::cout << "Total allocation: " << info.arena + info.hblkhd << std::endl; 7548be168c0dSopenharmony_ci- } 7549be168c0dSopenharmony_ci- 7550be168c0dSopenharmony_ci- private: 7551be168c0dSopenharmony_ci- std::chrono::time_point<std::chrono::high_resolution_clock> start_time_; 7552be168c0dSopenharmony_ci- std::chrono::time_point<std::chrono::high_resolution_clock> end_time_; 7553be168c0dSopenharmony_ci- std::chrono::duration<double, std::milli> time_avg_; 7554be168c0dSopenharmony_ci- unsigned int epochs_; 7555be168c0dSopenharmony_ci-}; 7556be168c0dSopenharmony_ci- 7557be168c0dSopenharmony_ci-NetRunner::~NetRunner() { 7558be168c0dSopenharmony_ci- if (model_ != nullptr) { 7559be168c0dSopenharmony_ci- delete model_; 7560be168c0dSopenharmony_ci- } 7561be168c0dSopenharmony_ci- if (graph_ != nullptr) { 7562be168c0dSopenharmony_ci- delete graph_; 7563be168c0dSopenharmony_ci- } 7564be168c0dSopenharmony_ci-} 7565be168c0dSopenharmony_ci- 7566be168c0dSopenharmony_ci-mindspore::Status NetRunner::InitAndFigureInputs() { 7567be168c0dSopenharmony_ci- auto context = std::make_shared<mindspore::Context>(); 7568be168c0dSopenharmony_ci- auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>(); 7569be168c0dSopenharmony_ci- cpu_context->SetEnableFP16(enable_fp16_); 7570be168c0dSopenharmony_ci- context->MutableDeviceInfo().push_back(cpu_context); 7571be168c0dSopenharmony_ci- 7572be168c0dSopenharmony_ci- graph_ = new (std::nothrow) mindspore::Graph(mindspore::Graph::Type::kExpressionGraph); 7573be168c0dSopenharmony_ci- if (graph_ == nullptr) { 7574be168c0dSopenharmony_ci- std::cout << "Cannot allocate graph" << std::endl; 7575be168c0dSopenharmony_ci- return mindspore::kLiteMemoryFailed; 7576be168c0dSopenharmony_ci- } 7577be168c0dSopenharmony_ci- auto status = mindspore::Serialization::Load(ms_file_, mindspore::kMindIR, graph_); 7578be168c0dSopenharmony_ci- if (status != mindspore::kSuccess) { 7579be168c0dSopenharmony_ci- std::cout << "Error " << status << " during serialization of graph " << ms_file_; 7580be168c0dSopenharmony_ci- return status; 7581be168c0dSopenharmony_ci- } 7582be168c0dSopenharmony_ci- auto net = std::make_unique<mindspore::Net>(*graph_); 7583be168c0dSopenharmony_ci- auto input_shape = net->InputShape(0); 7584be168c0dSopenharmony_ci- auto label_shape = net->OutputShape(0); 7585be168c0dSopenharmony_ci- auto inputM = mindspore::NN::Input(input_shape); 7586be168c0dSopenharmony_ci- auto labelM = mindspore::NN::Input(label_shape); 7587be168c0dSopenharmony_ci- auto label = labelM->Create("label"); 7588be168c0dSopenharmony_ci- auto input = inputM->Create("input"); 7589be168c0dSopenharmony_ci- 7590be168c0dSopenharmony_ci- auto cfg = std::make_shared<mindspore::TrainCfg>(); 7591be168c0dSopenharmony_ci- if (enable_fp16_) { 7592be168c0dSopenharmony_ci- cfg.get()->optimization_level_ = mindspore::kO2; 7593be168c0dSopenharmony_ci- } 7594be168c0dSopenharmony_ci- 7595be168c0dSopenharmony_ci- model_ = new (std::nothrow) mindspore::Model(); 7596be168c0dSopenharmony_ci- if (model_ == nullptr) { 7597be168c0dSopenharmony_ci- std::cout << "model allocation failed" << std::endl; 7598be168c0dSopenharmony_ci- return mindspore::kLiteMemoryFailed; 7599be168c0dSopenharmony_ci- } 7600be168c0dSopenharmony_ci- mindspore::SoftMaxCrossEntropyCfg softmax_ce_cfg; 7601be168c0dSopenharmony_ci- softmax_ce_cfg.reduction = "none"; 7602be168c0dSopenharmony_ci- auto netWithLoss = mindspore::NN::GraphWithLoss(graph_, mindspore::NN::SoftmaxCrossEntropy(softmax_ce_cfg)); 7603be168c0dSopenharmony_ci- mindspore::AdamConfig AdamCfg; 7604be168c0dSopenharmony_ci- AdamCfg.beta1_ = kBetta1; 7605be168c0dSopenharmony_ci- AdamCfg.beta2_ = kBetta2; 7606be168c0dSopenharmony_ci- AdamCfg.eps_ = 1e-8; 7607be168c0dSopenharmony_ci- AdamCfg.learning_rate_ = 1e-2; 7608be168c0dSopenharmony_ci- auto optimizer = mindspore::NN::Adam(net->trainable_params(), AdamCfg); 7609be168c0dSopenharmony_ci- status = model_->Build(mindspore::GraphCell(*netWithLoss), optimizer, {input, label}, context, cfg); 7610be168c0dSopenharmony_ci- if (status != mindspore::kSuccess) { 7611be168c0dSopenharmony_ci- std::cout << "Error " << status << " during build of model " << ms_file_ << std::endl; 7612be168c0dSopenharmony_ci- return status; 7613be168c0dSopenharmony_ci- } 7614be168c0dSopenharmony_ci- delete graph_; 7615be168c0dSopenharmony_ci- graph_ = nullptr; 7616be168c0dSopenharmony_ci- auto inputs = model_->GetInputs(); 7617be168c0dSopenharmony_ci- if (inputs.size() < 1) { 7618be168c0dSopenharmony_ci- return mindspore::kLiteError; 7619be168c0dSopenharmony_ci- } 7620be168c0dSopenharmony_ci- auto nhwc_input_dims = inputs.at(0).Shape(); 7621be168c0dSopenharmony_ci- batch_size_ = nhwc_input_dims.at(0); 7622be168c0dSopenharmony_ci- h_ = nhwc_input_dims.at(1); 7623be168c0dSopenharmony_ci- w_ = nhwc_input_dims.at(kNCHWCDim); 7624be168c0dSopenharmony_ci- return mindspore::kSuccess; 7625be168c0dSopenharmony_ci-} 7626be168c0dSopenharmony_ci- 7627be168c0dSopenharmony_ci-int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) { 7628be168c0dSopenharmony_ci- std::cout << "================ Comparing Forward Output data ================" << std::endl; 7629be168c0dSopenharmony_ci- float total_bias = 0; 7630be168c0dSopenharmony_ci- int total_size = 0; 7631be168c0dSopenharmony_ci- bool has_error = false; 7632be168c0dSopenharmony_ci- int i = 1; 7633be168c0dSopenharmony_ci- for (auto &tensor : outputs) { 7634be168c0dSopenharmony_ci- std::cout << "output is tensor " << tensor.Name() << "\n"; 7635be168c0dSopenharmony_ci- auto output = tensor.Data(); 7636be168c0dSopenharmony_ci- size_t size; 7637be168c0dSopenharmony_ci- std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin"; 7638be168c0dSopenharmony_ci- auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(output_file.c_str(), &size)); 7639be168c0dSopenharmony_ci- if (bin_buf == nullptr) { 7640be168c0dSopenharmony_ci- MS_LOG(ERROR) << "ReadFile return nullptr"; 7641be168c0dSopenharmony_ci- std::cout << "ReadFile return nullptr" << std::endl; 7642be168c0dSopenharmony_ci- return mindspore::kLiteNullptr; 7643be168c0dSopenharmony_ci- } 7644be168c0dSopenharmony_ci- if (size != tensor.DataSize()) { 7645be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize() 7646be168c0dSopenharmony_ci- << ", read size: " << size; 7647be168c0dSopenharmony_ci- std::cout << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize() 7648be168c0dSopenharmony_ci- << ", read size: " << size << std::endl; 7649be168c0dSopenharmony_ci- return mindspore::kLiteError; 7650be168c0dSopenharmony_ci- } 7651be168c0dSopenharmony_ci- float bias = mindspore::lite::NetTrainBase::CompareData<float>(bin_buf.get(), tensor.ElementNum(), 7652be168c0dSopenharmony_ci- reinterpret_cast<const float *>(output.get())); 7653be168c0dSopenharmony_ci- if (bias >= 0) { 7654be168c0dSopenharmony_ci- total_bias += bias; 7655be168c0dSopenharmony_ci- total_size++; 7656be168c0dSopenharmony_ci- } else { 7657be168c0dSopenharmony_ci- has_error = true; 7658be168c0dSopenharmony_ci- break; 7659be168c0dSopenharmony_ci- } 7660be168c0dSopenharmony_ci- i++; 7661be168c0dSopenharmony_ci- } 7662be168c0dSopenharmony_ci- 7663be168c0dSopenharmony_ci- if (!has_error) { 7664be168c0dSopenharmony_ci- float mean_bias; 7665be168c0dSopenharmony_ci- if (total_size != 0) { 7666be168c0dSopenharmony_ci- mean_bias = total_bias / total_size * kPrintTimes; 7667be168c0dSopenharmony_ci- } else { 7668be168c0dSopenharmony_ci- mean_bias = 0; 7669be168c0dSopenharmony_ci- } 7670be168c0dSopenharmony_ci- 7671be168c0dSopenharmony_ci- std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" 7672be168c0dSopenharmony_ci- << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl; 7673be168c0dSopenharmony_ci- std::cout << "=======================================================" << std::endl << std::endl; 7674be168c0dSopenharmony_ci- 7675be168c0dSopenharmony_ci- if (mean_bias > this->flags_->accuracy_threshold_) { 7676be168c0dSopenharmony_ci- MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%"; 7677be168c0dSopenharmony_ci- std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl; 7678be168c0dSopenharmony_ci- return mindspore::kLiteError; 7679be168c0dSopenharmony_ci- } else { 7680be168c0dSopenharmony_ci- return mindspore::kSuccess; 7681be168c0dSopenharmony_ci- } 7682be168c0dSopenharmony_ci- } else { 7683be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Error in CompareData"; 7684be168c0dSopenharmony_ci- std::cout << "Error in CompareData" << std::endl; 7685be168c0dSopenharmony_ci- std::cout << "=======================================================" << std::endl << std::endl; 7686be168c0dSopenharmony_ci- return mindspore::kSuccess; 7687be168c0dSopenharmony_ci- } 7688be168c0dSopenharmony_ci-} 7689be168c0dSopenharmony_ci- 7690be168c0dSopenharmony_ci-void NetRunner::CheckSum(const mindspore::MSTensor &tensor, std::string node_type, int id, std::string in_out) { 7691be168c0dSopenharmony_ci- constexpr int kPrintLen = 4; 7692be168c0dSopenharmony_ci- int tensor_size = tensor.ElementNum(); 7693be168c0dSopenharmony_ci- const void *data = tensor.Data().get(); 7694be168c0dSopenharmony_ci- const float *fdata = reinterpret_cast<const float *>(data); 7695be168c0dSopenharmony_ci- mindspore::DataType type = tensor.DataType(); 7696be168c0dSopenharmony_ci- std::cout << node_type << " " << in_out << id << std::endl; 7697be168c0dSopenharmony_ci- std::cout << "tensor name: " << tensor.Name() << std::endl; 7698be168c0dSopenharmony_ci- if ((tensor_size) == 0 || (data == nullptr)) { 7699be168c0dSopenharmony_ci- std::cout << "Empty tensor" << std::endl; 7700be168c0dSopenharmony_ci- return; 7701be168c0dSopenharmony_ci- } 7702be168c0dSopenharmony_ci- switch (type) { 7703be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeFloat32: 7704be168c0dSopenharmony_ci- std::cout << "sum=" << mindspore::lite::TensorSum<float>(data, tensor_size) << std::endl; 7705be168c0dSopenharmony_ci- std::cout << "data: "; 7706be168c0dSopenharmony_ci- for (int i = 0; i <= kPrintLen && i < tensor_size; i++) { 7707be168c0dSopenharmony_ci- std::cout << static_cast<float>(fdata[i]) << ", "; 7708be168c0dSopenharmony_ci- } 7709be168c0dSopenharmony_ci- std::cout << std::endl; 7710be168c0dSopenharmony_ci- break; 7711be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeInt32: 7712be168c0dSopenharmony_ci- std::cout << "sum=" << mindspore::lite::TensorSum<int>(data, tensor_size) << std::endl; 7713be168c0dSopenharmony_ci- break; 7714be168c0dSopenharmony_ci- default: 7715be168c0dSopenharmony_ci- std::cout << "unsupported type:" << static_cast<int>(type) << std::endl; 7716be168c0dSopenharmony_ci- break; 7717be168c0dSopenharmony_ci- } 7718be168c0dSopenharmony_ci-} 7719be168c0dSopenharmony_ci- 7720be168c0dSopenharmony_ci-int NetRunner::InitCallbackParameter() { 7721be168c0dSopenharmony_ci- // after callback 7722be168c0dSopenharmony_ci- after_call_back_ = [&](const std::vector<mindspore::MSTensor> &after_inputs, 7723be168c0dSopenharmony_ci- const std::vector<mindspore::MSTensor> &after_outputs, 7724be168c0dSopenharmony_ci- const mindspore::MSCallBackParam &call_param) { 7725be168c0dSopenharmony_ci- if (after_inputs.empty()) { 7726be168c0dSopenharmony_ci- MS_LOG(INFO) << "The num of after inputs is empty"; 7727be168c0dSopenharmony_ci- } 7728be168c0dSopenharmony_ci- if (after_outputs.empty()) { 7729be168c0dSopenharmony_ci- MS_LOG(INFO) << "The num of after outputs is empty"; 7730be168c0dSopenharmony_ci- } 7731be168c0dSopenharmony_ci- if (flags_->layer_checksum_) { 7732be168c0dSopenharmony_ci- for (size_t i = 0; i < after_inputs.size(); i++) { 7733be168c0dSopenharmony_ci- CheckSum(after_inputs.at(i), call_param.node_type, i, "in"); 7734be168c0dSopenharmony_ci- } 7735be168c0dSopenharmony_ci- for (size_t i = 0; i < after_outputs.size(); i++) { 7736be168c0dSopenharmony_ci- CheckSum(after_outputs.at(i), call_param.node_type, i, "out"); 7737be168c0dSopenharmony_ci- } 7738be168c0dSopenharmony_ci- std::cout << std::endl; 7739be168c0dSopenharmony_ci- } 7740be168c0dSopenharmony_ci- return true; 7741be168c0dSopenharmony_ci- }; 7742be168c0dSopenharmony_ci- return false; 7743be168c0dSopenharmony_ci-} 7744be168c0dSopenharmony_ci- 7745be168c0dSopenharmony_ci-int NetRunner::RunOnce() { 7746be168c0dSopenharmony_ci- auto inputs = model_->GetInputs(); 7747be168c0dSopenharmony_ci- std::vector<mindspore::MSTensor> output; 7748be168c0dSopenharmony_ci- auto status = LoadInput(&inputs); 7749be168c0dSopenharmony_ci- if (status != mindspore::kSuccess) { 7750be168c0dSopenharmony_ci- std::cout << "cannot load data"; 7751be168c0dSopenharmony_ci- return status; 7752be168c0dSopenharmony_ci- } 7753be168c0dSopenharmony_ci- model_->SetTrainMode(true); 7754be168c0dSopenharmony_ci- model_->RunStep(nullptr, nullptr); 7755be168c0dSopenharmony_ci- model_->SetTrainMode(false); 7756be168c0dSopenharmony_ci- model_->Predict(inputs, &output, nullptr, nullptr); 7757be168c0dSopenharmony_ci- return CompareOutput(output); 7758be168c0dSopenharmony_ci-} 7759be168c0dSopenharmony_ci- 7760be168c0dSopenharmony_ci-int NetRunner::LoadInput(std::vector<mindspore::MSTensor> *ms_inputs) { 7761be168c0dSopenharmony_ci- auto status = ReadInputFile(ms_inputs); 7762be168c0dSopenharmony_ci- if (status != mindspore::kSuccess) { 7763be168c0dSopenharmony_ci- std::cout << "Read Input File error, " << status << std::endl; 7764be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Read Input File error, " << status; 7765be168c0dSopenharmony_ci- return status; 7766be168c0dSopenharmony_ci- } 7767be168c0dSopenharmony_ci- return mindspore::kSuccess; 7768be168c0dSopenharmony_ci-} 7769be168c0dSopenharmony_ci- 7770be168c0dSopenharmony_ci-int NetRunner::ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs) { 7771be168c0dSopenharmony_ci- if (ms_inputs->empty()) { 7772be168c0dSopenharmony_ci- std::cout << "no inputs to input" << std::endl; 7773be168c0dSopenharmony_ci- return mindspore::kLiteError; 7774be168c0dSopenharmony_ci- } 7775be168c0dSopenharmony_ci- for (size_t i = 0; i < ms_inputs->size(); i++) { 7776be168c0dSopenharmony_ci- auto cur_tensor = ms_inputs->at(i); 7777be168c0dSopenharmony_ci- if (cur_tensor == nullptr) { 7778be168c0dSopenharmony_ci- std::cout << "empty tensor " << i << std::endl; 7779be168c0dSopenharmony_ci- MS_LOG(ERROR) << "empty tensor " << i; 7780be168c0dSopenharmony_ci- } 7781be168c0dSopenharmony_ci- size_t size; 7782be168c0dSopenharmony_ci- std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin"; 7783be168c0dSopenharmony_ci- auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(file_name.c_str(), &size)); 7784be168c0dSopenharmony_ci- if (bin_buf == nullptr) { 7785be168c0dSopenharmony_ci- MS_LOG(ERROR) << "ReadFile return nullptr"; 7786be168c0dSopenharmony_ci- std::cout << "ReadFile return nullptr" << std::endl; 7787be168c0dSopenharmony_ci- return mindspore::kLiteNullptr; 7788be168c0dSopenharmony_ci- } 7789be168c0dSopenharmony_ci- auto tensor_data_size = cur_tensor.DataSize(); 7790be168c0dSopenharmony_ci- if (size != tensor_data_size) { 7791be168c0dSopenharmony_ci- std::cout << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size 7792be168c0dSopenharmony_ci- << " ,file_name: " << file_name.c_str() << std::endl; 7793be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Input binary file size error, required: " << tensor_data_size << ", in fact: " << size 7794be168c0dSopenharmony_ci- << " ,file_name: " << file_name.c_str(); 7795be168c0dSopenharmony_ci- return mindspore::kLiteError; 7796be168c0dSopenharmony_ci- } 7797be168c0dSopenharmony_ci- auto input_data = cur_tensor.MutableData(); 7798be168c0dSopenharmony_ci- memcpy(input_data, bin_buf.get(), tensor_data_size); 7799be168c0dSopenharmony_ci- } 7800be168c0dSopenharmony_ci- return mindspore::kSuccess; 7801be168c0dSopenharmony_ci-} 7802be168c0dSopenharmony_ci- 7803be168c0dSopenharmony_ci-int NetRunner::Main() { 7804be168c0dSopenharmony_ci- ms_file_ = flags_->model_file_; 7805be168c0dSopenharmony_ci- InitCallbackParameter(); 7806be168c0dSopenharmony_ci- auto status = InitAndFigureInputs(); 7807be168c0dSopenharmony_ci- if (status != mindspore::kSuccess) { 7808be168c0dSopenharmony_ci- std::cout << "failed to initialize network" << std::endl; 7809be168c0dSopenharmony_ci- return status.StatusCode(); 7810be168c0dSopenharmony_ci- } 7811be168c0dSopenharmony_ci- return RunOnce(); 7812be168c0dSopenharmony_ci-} 7813be168c0dSopenharmony_ci- 7814be168c0dSopenharmony_ci-int CallBack(mindspore::lite::NetTrainFlags *flags) { 7815be168c0dSopenharmony_ci- NetRunner nr(flags); 7816be168c0dSopenharmony_ci- return nr.Main(); 7817be168c0dSopenharmony_ci-} 7818be168c0dSopenharmony_ci- 7819be168c0dSopenharmony_ci-int init = mindspore::lite::NetTrainBase::SetNr(CallBack); 7820be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_runner.h b/mindspore/lite/tools/benchmark_train/net_runner.h 7821be168c0dSopenharmony_cideleted file mode 100644 7822be168c0dSopenharmony_ciindex 243b94ef..00000000 7823be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_runner.h 7824be168c0dSopenharmony_ci+++ /dev/null 7825be168c0dSopenharmony_ci@@ -1,81 +0,0 @@ 7826be168c0dSopenharmony_ci-/** 7827be168c0dSopenharmony_ci- * Copyright 2022 Huawei Technologies Co., Ltd 7828be168c0dSopenharmony_ci- * 7829be168c0dSopenharmony_ci- * Licensed under the Apache License, Version 2.0 (the "License"); 7830be168c0dSopenharmony_ci- * you may not use this file except in compliance with the License. 7831be168c0dSopenharmony_ci- * You may obtain a copy of the License at 7832be168c0dSopenharmony_ci- * 7833be168c0dSopenharmony_ci- * http://www.apache.org/licenses/LICENSE-2.0 7834be168c0dSopenharmony_ci- * 7835be168c0dSopenharmony_ci- * Unless required by applicable law or agreed to in writing, software 7836be168c0dSopenharmony_ci- * distributed under the License is distributed on an "AS IS" BASIS, 7837be168c0dSopenharmony_ci- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7838be168c0dSopenharmony_ci- * See the License for the specific language governing permissions and 7839be168c0dSopenharmony_ci- * limitations under the License. 7840be168c0dSopenharmony_ci- */ 7841be168c0dSopenharmony_ci- 7842be168c0dSopenharmony_ci-#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_ 7843be168c0dSopenharmony_ci-#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_ 7844be168c0dSopenharmony_ci- 7845be168c0dSopenharmony_ci-#include <tuple> 7846be168c0dSopenharmony_ci-#include <iomanip> 7847be168c0dSopenharmony_ci-#include <map> 7848be168c0dSopenharmony_ci-#include <vector> 7849be168c0dSopenharmony_ci-#include <memory> 7850be168c0dSopenharmony_ci-#include <string> 7851be168c0dSopenharmony_ci-#include "include/api/model.h" 7852be168c0dSopenharmony_ci-#include "include/api/graph.h" 7853be168c0dSopenharmony_ci-#include "include/api/types.h" 7854be168c0dSopenharmony_ci-#include "include/api/status.h" 7855be168c0dSopenharmony_ci-#include "include/api/metrics/accuracy.h" 7856be168c0dSopenharmony_ci-#include "include/dataset/datasets.h" 7857be168c0dSopenharmony_ci- 7858be168c0dSopenharmony_ci-using mindspore::AccuracyMetrics; 7859be168c0dSopenharmony_ci-using mindspore::dataset::Dataset; 7860be168c0dSopenharmony_ci- 7861be168c0dSopenharmony_ci-namespace mindspore::lite { 7862be168c0dSopenharmony_ci-class NetTrainFlags; 7863be168c0dSopenharmony_ci-} 7864be168c0dSopenharmony_ci- 7865be168c0dSopenharmony_ci-class NetRunner { 7866be168c0dSopenharmony_ci- public: 7867be168c0dSopenharmony_ci- int Main(); 7868be168c0dSopenharmony_ci- explicit NetRunner(mindspore::lite::NetTrainFlags *flags) : flags_(flags) {} 7869be168c0dSopenharmony_ci- bool ReadArgs(int argc, int8_t *argv[]); 7870be168c0dSopenharmony_ci- virtual ~NetRunner(); 7871be168c0dSopenharmony_ci- 7872be168c0dSopenharmony_ci- private: 7873be168c0dSopenharmony_ci- void Usage(); 7874be168c0dSopenharmony_ci- mindspore::Status InitAndFigureInputs(); 7875be168c0dSopenharmony_ci- void CheckSum(const mindspore::MSTensor &tensor, std::string node_type, int id, std::string in_out); 7876be168c0dSopenharmony_ci- int InitCallbackParameter(); 7877be168c0dSopenharmony_ci- int TrainLoop(); 7878be168c0dSopenharmony_ci- float CalculateAccuracy(int max_tests = 0); 7879be168c0dSopenharmony_ci- float GetLoss() const; 7880be168c0dSopenharmony_ci- int RunOnce(); 7881be168c0dSopenharmony_ci- int CompareOutput(const std::vector<mindspore::MSTensor> &outputs); 7882be168c0dSopenharmony_ci- int LoadInput(std::vector<mindspore::MSTensor> *ms_inputs); 7883be168c0dSopenharmony_ci- int ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs); 7884be168c0dSopenharmony_ci- 7885be168c0dSopenharmony_ci- mindspore::Model *model_ = nullptr; 7886be168c0dSopenharmony_ci- mindspore::Graph *graph_ = nullptr; 7887be168c0dSopenharmony_ci- 7888be168c0dSopenharmony_ci- std::shared_ptr<Dataset> train_ds_; 7889be168c0dSopenharmony_ci- std::shared_ptr<Dataset> test_ds_; 7890be168c0dSopenharmony_ci- std::shared_ptr<AccuracyMetrics> acc_metrics_; 7891be168c0dSopenharmony_ci- 7892be168c0dSopenharmony_ci- std::string ms_file_ = ""; 7893be168c0dSopenharmony_ci- std::string data_dir_ = ""; 7894be168c0dSopenharmony_ci- unsigned int epochs_ = 10; 7895be168c0dSopenharmony_ci- bool verbose_ = false; 7896be168c0dSopenharmony_ci- bool enable_fp16_ = false; 7897be168c0dSopenharmony_ci- int virtual_batch_ = -1; 7898be168c0dSopenharmony_ci- int save_checkpoint_ = 0; 7899be168c0dSopenharmony_ci- int batch_size_ = 32; 7900be168c0dSopenharmony_ci- int h_ = 32; 7901be168c0dSopenharmony_ci- int w_ = 32; 7902be168c0dSopenharmony_ci- mindspore::lite::NetTrainFlags *flags_{nullptr}; 7903be168c0dSopenharmony_ci- mindspore::MSKernelCallBack after_call_back_; 7904be168c0dSopenharmony_ci-}; 7905be168c0dSopenharmony_ci- 7906be168c0dSopenharmony_ci-#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_RUNNER_H_ 7907be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc 7908be168c0dSopenharmony_ciindex 514bba53..dd7b22a9 100644 7909be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_train.cc 7910be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train.cc 7911be168c0dSopenharmony_ci@@ -24,7 +24,6 @@ 7912be168c0dSopenharmony_ci #ifdef ENABLE_NEON 7913be168c0dSopenharmony_ci #include <arm_neon.h> 7914be168c0dSopenharmony_ci #endif 7915be168c0dSopenharmony_ci-#include "tools/benchmark_train/net_runner.h" 7916be168c0dSopenharmony_ci #include "src/common/common.h" 7917be168c0dSopenharmony_ci #include "include/api/serialization.h" 7918be168c0dSopenharmony_ci #include "securec/include/securec.h" 7919be168c0dSopenharmony_ci-- 7920be168c0dSopenharmony_ci2.25.1 7921be168c0dSopenharmony_ci 7922