1be168c0dSopenharmony_ciFrom 551353a01bc269ac3a509b916bfe28499d1db714 Mon Sep 17 00:00:00 2001
2be168c0dSopenharmony_ciFrom: chengfeng27 <chengfeng27@huawei.com>
3be168c0dSopenharmony_ciDate: Thu, 1 Aug 2024 20:49:49 +0800
4be168c0dSopenharmony_ciSubject: [PATCH] remove recursive lock, which conflict with ffrt
5be168c0dSopenharmony_ci
6be168c0dSopenharmony_ci---
7be168c0dSopenharmony_ci .../src/litert/cxx_api/model/model_impl.cc    | 26 -------------------
8be168c0dSopenharmony_ci .../src/litert/cxx_api/model/model_impl.h     |  1 -
9be168c0dSopenharmony_ci 2 files changed, 27 deletions(-)
10be168c0dSopenharmony_ci
11be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
12be168c0dSopenharmony_ciindex cacbf86e..6a73a927 100644
13be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
14be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
15be168c0dSopenharmony_ci@@ -161,7 +161,6 @@ bool ModelImpl::IsEnablePreInference() {
16be168c0dSopenharmony_ci #endif
17be168c0dSopenharmony_ci Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
18be168c0dSopenharmony_ci                         const std::shared_ptr<Context> &ms_context) {
19be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
20be168c0dSopenharmony_ci   if (session_ != nullptr) {
21be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has been called Build";
22be168c0dSopenharmony_ci     return kLiteModelRebuild;
23be168c0dSopenharmony_ci@@ -207,7 +206,6 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
24be168c0dSopenharmony_ci 
25be168c0dSopenharmony_ci Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
26be168c0dSopenharmony_ci                         const std::shared_ptr<Context> &ms_context) {
27be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
28be168c0dSopenharmony_ci   if (session_ != nullptr) {
29be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has been called Build";
30be168c0dSopenharmony_ci     return kLiteModelRebuild;
31be168c0dSopenharmony_ci@@ -243,7 +241,6 @@ Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
32be168c0dSopenharmony_ci }
33be168c0dSopenharmony_ci 
34be168c0dSopenharmony_ci Status ModelImpl::Build() {
35be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
36be168c0dSopenharmony_ci   if (session_ != nullptr) {
37be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has been called Build";
38be168c0dSopenharmony_ci     return kLiteModelRebuild;
39be168c0dSopenharmony_ci@@ -357,7 +354,6 @@ Status ModelImpl::RunGraph(const MSKernelCallBack &before, const MSKernelCallBac
40be168c0dSopenharmony_ci bool ModelImpl::IsTrainModel() { return (graph_ && graph_->graph_data_ && graph_->graph_data_->IsTrainModel()); }
41be168c0dSopenharmony_ci 
42be168c0dSopenharmony_ci Status ModelImpl::LoadConfig(const std::string &config_path) {
43be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
44be168c0dSopenharmony_ci   if (session_ != nullptr) {
45be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has been called Build, please call LoadConfig before Build.";
46be168c0dSopenharmony_ci     return kLiteError;
47be168c0dSopenharmony_ci@@ -380,7 +376,6 @@ Status ModelImpl::LoadConfig(const std::string &config_path) {
48be168c0dSopenharmony_ci }
49be168c0dSopenharmony_ci 
50be168c0dSopenharmony_ci Status ModelImpl::UpdateConfig(const std::string &section, const std::pair<std::string, std::string> &config) {
51be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
52be168c0dSopenharmony_ci   auto iter = config_info_.find(section);
53be168c0dSopenharmony_ci   if (iter == config_info_.end()) {
54be168c0dSopenharmony_ci     if (config_info_.size() >= kMaxSectionNum) {
55be168c0dSopenharmony_ci@@ -400,7 +395,6 @@ Status ModelImpl::UpdateConfig(const std::string &section, const std::pair<std::
56be168c0dSopenharmony_ci 
57be168c0dSopenharmony_ci Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
58be168c0dSopenharmony_ci                           const MSKernelCallBack &before, const MSKernelCallBack &after) {
59be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
60be168c0dSopenharmony_ci   if (session_ == nullptr) {
61be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
62be168c0dSopenharmony_ci     return kLiteNullptr;
63be168c0dSopenharmony_ci@@ -565,7 +559,6 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
64be168c0dSopenharmony_ci }
65be168c0dSopenharmony_ci 
66be168c0dSopenharmony_ci Status ModelImpl::Predict(const MSKernelCallBack &before, const MSKernelCallBack &after) {
67be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
68be168c0dSopenharmony_ci   if (session_ == nullptr) {
69be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
70be168c0dSopenharmony_ci     return kLiteNullptr;
71be168c0dSopenharmony_ci@@ -592,7 +585,6 @@ Status ModelImpl::Predict(const MSKernelCallBack &before, const MSKernelCallBack
72be168c0dSopenharmony_ci }
73be168c0dSopenharmony_ci 
74be168c0dSopenharmony_ci std::vector<MSTensor> ModelImpl::GetInputs() {
75be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
76be168c0dSopenharmony_ci   if (session_ == nullptr) {
77be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
78be168c0dSopenharmony_ci     return {};
79be168c0dSopenharmony_ci@@ -617,7 +609,6 @@ std::vector<MSTensor> ModelImpl::GetInputs() {
80be168c0dSopenharmony_ci }
81be168c0dSopenharmony_ci 
82be168c0dSopenharmony_ci std::vector<MSTensor> ModelImpl::GetOutputs() {
83be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
84be168c0dSopenharmony_ci   if (session_ == nullptr) {
85be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
86be168c0dSopenharmony_ci     return {};
87be168c0dSopenharmony_ci@@ -655,7 +646,6 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
88be168c0dSopenharmony_ci }
89be168c0dSopenharmony_ci 
90be168c0dSopenharmony_ci std::vector<MSTensor> ModelImpl::GetGradients() const {
91be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
92be168c0dSopenharmony_ci   if (session_ == nullptr) {
93be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
94be168c0dSopenharmony_ci     return {};
95be168c0dSopenharmony_ci@@ -670,7 +660,6 @@ std::vector<MSTensor> ModelImpl::GetGradients() const {
96be168c0dSopenharmony_ci }
97be168c0dSopenharmony_ci 
98be168c0dSopenharmony_ci Status ModelImpl::ApplyGradients(const std::vector<MSTensor> &gradients) {
99be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
100be168c0dSopenharmony_ci   if (session_ == nullptr) {
101be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
102be168c0dSopenharmony_ci     return kLiteNullptr;
103be168c0dSopenharmony_ci@@ -699,7 +688,6 @@ Status ModelImpl::ApplyGradients(const std::vector<MSTensor> &gradients) {
104be168c0dSopenharmony_ci }
105be168c0dSopenharmony_ci 
106be168c0dSopenharmony_ci std::vector<MSTensor> ModelImpl::GetFeatureMaps() const {
107be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
108be168c0dSopenharmony_ci   if (session_ == nullptr) {
109be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
110be168c0dSopenharmony_ci     return {};
111be168c0dSopenharmony_ci@@ -714,7 +702,6 @@ std::vector<MSTensor> ModelImpl::GetFeatureMaps() const {
112be168c0dSopenharmony_ci }
113be168c0dSopenharmony_ci 
114be168c0dSopenharmony_ci std::vector<MSTensor> ModelImpl::GetTrainableParams() const {
115be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
116be168c0dSopenharmony_ci   if (session_ == nullptr) {
117be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
118be168c0dSopenharmony_ci     return {};
119be168c0dSopenharmony_ci@@ -729,7 +716,6 @@ std::vector<MSTensor> ModelImpl::GetTrainableParams() const {
120be168c0dSopenharmony_ci }
121be168c0dSopenharmony_ci 
122be168c0dSopenharmony_ci Status ModelImpl::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
123be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
124be168c0dSopenharmony_ci   if (session_ == nullptr) {
125be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
126be168c0dSopenharmony_ci     return kLiteNullptr;
127be168c0dSopenharmony_ci@@ -758,7 +744,6 @@ Status ModelImpl::UpdateFeatureMaps(const std::vector<MSTensor> &new_weights) {
128be168c0dSopenharmony_ci }
129be168c0dSopenharmony_ci 
130be168c0dSopenharmony_ci std::vector<MSTensor> ModelImpl::GetOptimizerParams() const {
131be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
132be168c0dSopenharmony_ci   if (session_ == nullptr) {
133be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
134be168c0dSopenharmony_ci     return {};
135be168c0dSopenharmony_ci@@ -773,7 +758,6 @@ std::vector<MSTensor> ModelImpl::GetOptimizerParams() const {
136be168c0dSopenharmony_ci }
137be168c0dSopenharmony_ci 
138be168c0dSopenharmony_ci Status ModelImpl::SetOptimizerParams(const std::vector<MSTensor> &params) {
139be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
140be168c0dSopenharmony_ci   if (session_ == nullptr) {
141be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
142be168c0dSopenharmony_ci     return kLiteNullptr;
143be168c0dSopenharmony_ci@@ -802,7 +786,6 @@ Status ModelImpl::SetOptimizerParams(const std::vector<MSTensor> &params) {
144be168c0dSopenharmony_ci }
145be168c0dSopenharmony_ci 
146be168c0dSopenharmony_ci MSTensor ModelImpl::GetInputByTensorName(const std::string &name) {
147be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
148be168c0dSopenharmony_ci   if (session_ == nullptr) {
149be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
150be168c0dSopenharmony_ci     return MSTensor(nullptr);
151be168c0dSopenharmony_ci@@ -822,7 +805,6 @@ MSTensor ModelImpl::GetInputByTensorName(const std::string &name) {
152be168c0dSopenharmony_ci }
153be168c0dSopenharmony_ci 
154be168c0dSopenharmony_ci std::vector<std::string> ModelImpl::GetOutputTensorNames() {
155be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
156be168c0dSopenharmony_ci   if (session_ == nullptr) {
157be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
158be168c0dSopenharmony_ci     return {};
159be168c0dSopenharmony_ci@@ -831,7 +813,6 @@ std::vector<std::string> ModelImpl::GetOutputTensorNames() {
160be168c0dSopenharmony_ci }
161be168c0dSopenharmony_ci 
162be168c0dSopenharmony_ci MSTensor ModelImpl::GetOutputByTensorName(const std::string &name) {
163be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
164be168c0dSopenharmony_ci   if (session_ == nullptr) {
165be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
166be168c0dSopenharmony_ci     return MSTensor(nullptr);
167be168c0dSopenharmony_ci@@ -851,7 +832,6 @@ MSTensor ModelImpl::GetOutputByTensorName(const std::string &name) {
168be168c0dSopenharmony_ci }
169be168c0dSopenharmony_ci 
170be168c0dSopenharmony_ci std::vector<MSTensor> ModelImpl::GetOutputsByNodeName(const std::string &name) {
171be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
172be168c0dSopenharmony_ci   if (session_ == nullptr) {
173be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
174be168c0dSopenharmony_ci     return {};
175be168c0dSopenharmony_ci@@ -881,7 +861,6 @@ std::vector<MSTensor> ModelImpl::GetOutputsByNodeName(const std::string &name) {
176be168c0dSopenharmony_ci 
177be168c0dSopenharmony_ci Status ModelImpl::BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
178be168c0dSopenharmony_ci                                         std::map<std::string, unsigned int> *outputGLTexture) {
179be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
180be168c0dSopenharmony_ci   MS_LOG(INFO) << "Bind GLTexture2D to Input MsTensors and Output MsTensors";
181be168c0dSopenharmony_ci   if (session_ == nullptr) {
182be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
183be168c0dSopenharmony_ci@@ -896,7 +875,6 @@ Status ModelImpl::BindGLTexture2DMemory(const std::map<std::string, unsigned int
184be168c0dSopenharmony_ci }
185be168c0dSopenharmony_ci 
186be168c0dSopenharmony_ci Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
187be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
188be168c0dSopenharmony_ci   if (session_ == nullptr) {
189be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
190be168c0dSopenharmony_ci     return kLiteNullptr;
191be168c0dSopenharmony_ci@@ -950,7 +928,6 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
192be168c0dSopenharmony_ci }
193be168c0dSopenharmony_ci 
194be168c0dSopenharmony_ci Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) {
195be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
196be168c0dSopenharmony_ci   if (session_ == nullptr) {
197be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
198be168c0dSopenharmony_ci     return kLiteNullptr;
199be168c0dSopenharmony_ci@@ -982,7 +959,6 @@ Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) {
200be168c0dSopenharmony_ci }
201be168c0dSopenharmony_ci 
202be168c0dSopenharmony_ci Status ModelImpl::SetupVirtualBatch(int virtual_batch_multiplier, float lr, float momentum) {
203be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
204be168c0dSopenharmony_ci   if (session_ == nullptr) {
205be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
206be168c0dSopenharmony_ci     return kLiteNullptr;
207be168c0dSopenharmony_ci@@ -992,7 +968,6 @@ Status ModelImpl::SetupVirtualBatch(int virtual_batch_multiplier, float lr, floa
208be168c0dSopenharmony_ci }
209be168c0dSopenharmony_ci 
210be168c0dSopenharmony_ci Status ModelImpl::SetLearningRate(float learning_rate) {
211be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
212be168c0dSopenharmony_ci   if (session_ == nullptr) {
213be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Model has not been called Build, or Model Build has failed";
214be168c0dSopenharmony_ci     return kLiteNullptr;
215be168c0dSopenharmony_ci@@ -1002,7 +977,6 @@ Status ModelImpl::SetLearningRate(float learning_rate) {
216be168c0dSopenharmony_ci }
217be168c0dSopenharmony_ci 
218be168c0dSopenharmony_ci float ModelImpl::GetLearningRate() {
219be168c0dSopenharmony_ci-  std::lock_guard<std::recursive_mutex> lock(mutex_);
220be168c0dSopenharmony_ci   if (session_ == nullptr) {
221be168c0dSopenharmony_ci     MS_LOG(WARNING) << "Model has not been called Build, or Model Build has failed";
222be168c0dSopenharmony_ci     return 0.0;
223be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/model/model_impl.h b/mindspore/lite/src/litert/cxx_api/model/model_impl.h
224be168c0dSopenharmony_ciindex 8e11ee55..17cafba8 100644
225be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/model/model_impl.h
226be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/model/model_impl.h
227be168c0dSopenharmony_ci@@ -131,7 +131,6 @@ class ModelImpl {
228be168c0dSopenharmony_ci   std::shared_ptr<Context> context_ = nullptr;
229be168c0dSopenharmony_ci   std::shared_ptr<TrainCfg> cfg_ = nullptr;
230be168c0dSopenharmony_ci   std::vector<Metrics *> metrics_;
231be168c0dSopenharmony_ci-  mutable std::recursive_mutex mutex_;
232be168c0dSopenharmony_ci   void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
233be168c0dSopenharmony_ci   void SetContext(const std::shared_ptr<Context> &context) { context_ = context; }
234be168c0dSopenharmony_ci   void SetConfig(const std::shared_ptr<TrainCfg> cfg) { cfg_ = cfg; }
235be168c0dSopenharmony_ci-- 
236be168c0dSopenharmony_ci2.17.1
237be168c0dSopenharmony_ci
238