1be168c0dSopenharmony_ciFrom a4c343574d6d6998a6f1b95f436401c8eb8a2c90 Mon Sep 17 00:00:00 2001
2be168c0dSopenharmony_ciFrom: zhangyanhui <zhangyanhui17@huawei.com>
3be168c0dSopenharmony_ciDate: Mon, 1 Jul 2024 21:12:15 +0800
4be168c0dSopenharmony_ciSubject: [PATCH] auto-apply 0015-bugfix-for-cpu-kernel.patch
5be168c0dSopenharmony_ci
6be168c0dSopenharmony_ci---
7be168c0dSopenharmony_ci .../cpu/kernel/nnacl/infer/where_infer.c      | 66 ++++++-------
8be168c0dSopenharmony_ci .../device/cpu/kernel/nnacl/kernel/clip.c     |  2 +
9be168c0dSopenharmony_ci .../src/litert/kernel/cpu/fp32/prelu_fp32.cc  | 12 +--
10be168c0dSopenharmony_ci .../src/litert/kernel/cpu/fp32/where_fp32.cc  | 96 ++++++++++++++++---
11be168c0dSopenharmony_ci .../src/litert/kernel/cpu/fp32/where_fp32.h   |  2 +
12be168c0dSopenharmony_ci .../lite/tools/optimizer/fusion/glu_fusion.h  |  4 +-
13be168c0dSopenharmony_ci 6 files changed, 124 insertions(+), 58 deletions(-)
14be168c0dSopenharmony_ci
15be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/where_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/where_infer.c
16be168c0dSopenharmony_ciindex f6d4e1b2..c714627a 100644
17be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/where_infer.c
18be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/where_infer.c
19be168c0dSopenharmony_ci@@ -17,18 +17,19 @@
20be168c0dSopenharmony_ci #include "nnacl/infer/where_infer.h"
21be168c0dSopenharmony_ci #include "nnacl/infer/infer_register.h"
22be168c0dSopenharmony_ci #include "nnacl/tensor_c_utils.h"
23be168c0dSopenharmony_ci+#include "nnacl/infer/broadcast_to_infer.h"
24be168c0dSopenharmony_ci 
25be168c0dSopenharmony_ci-static size_t GetAxisout(const TensorC *input0, const TensorC *input1, const TensorC *input2, size_t index) {
26be168c0dSopenharmony_ci-  if (input0->shape_[index] == input1->shape_[index] && input0->shape_[index] != input2->shape_[index]) {
27be168c0dSopenharmony_ci-    return index;
28be168c0dSopenharmony_ci+int WhereBroadCastInferShape(const int input_shape0_size, const int input_shape1_size, const int *input_shape0,
29be168c0dSopenharmony_ci+                             const int *input_shape1, int *ndim, int *in_shape0, int *in_shape1, int *out_shape,
30be168c0dSopenharmony_ci+                             bool *has_broad_cast) {
31be168c0dSopenharmony_ci+  if (input_shape0_size > MAX_SHAPE_SIZE || input_shape1_size > MAX_SHAPE_SIZE) {
32be168c0dSopenharmony_ci+    return NNACL_ERR;
33be168c0dSopenharmony_ci   }
34be168c0dSopenharmony_ci-  if (input0->shape_[index] == input2->shape_[index] && input0->shape_[index] != input1->shape_[index]) {
35be168c0dSopenharmony_ci-    return index;
36be168c0dSopenharmony_ci-  }
37be168c0dSopenharmony_ci-  if (input1->shape_[index] == input2->shape_[index] && input0->shape_[index] != input1->shape_[index]) {
38be168c0dSopenharmony_ci-    return index;
39be168c0dSopenharmony_ci+  MakeUpInputShapes(input_shape0_size, input_shape1_size, input_shape0, input_shape1, ndim, in_shape0, in_shape1);
40be168c0dSopenharmony_ci+  if (*ndim >= MAX_SHAPE_SIZE) {
41be168c0dSopenharmony_ci+    return NNACL_INFER_INVALID;
42be168c0dSopenharmony_ci   }
43be168c0dSopenharmony_ci-  return MAX_SHAPE_SIZE + 1;
44be168c0dSopenharmony_ci+  return BroadCastOutputShape(in_shape0, in_shape1, *ndim, out_shape, has_broad_cast);
45be168c0dSopenharmony_ci }
46be168c0dSopenharmony_ci 
47be168c0dSopenharmony_ci int WhereInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
48be168c0dSopenharmony_ci@@ -59,35 +60,28 @@ int WhereInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **
49be168c0dSopenharmony_ci   if (!InferFlag(inputs, inputs_size)) {
50be168c0dSopenharmony_ci     return NNACL_INFER_INVALID;
51be168c0dSopenharmony_ci   }
52be168c0dSopenharmony_ci-
53be168c0dSopenharmony_ci-  int num = GetElementNum(input0);
54be168c0dSopenharmony_ci-  int num1 = GetElementNum(input1);
55be168c0dSopenharmony_ci-  int num2 = GetElementNum(input2);
56be168c0dSopenharmony_ci-  int nummax = num > num1 ? num : (num1 > num2 ? num1 : num2);
57be168c0dSopenharmony_ci-  size_t min_input_shape_size = input1->shape_size_ < input2->shape_size_ ? input1->shape_size_ : input2->shape_size_;
58be168c0dSopenharmony_ci-  size_t axisout = MAX_SHAPE_SIZE + 1;
59be168c0dSopenharmony_ci-  size_t temp = 0;
60be168c0dSopenharmony_ci-  for (size_t j = 0; j < input0->shape_size_; j++) {
61be168c0dSopenharmony_ci-    if (j >= MAX_SHAPE_SIZE) {
62be168c0dSopenharmony_ci-      return NNACL_ERR;
63be168c0dSopenharmony_ci-    }
64be168c0dSopenharmony_ci-    if (j < min_input_shape_size) {
65be168c0dSopenharmony_ci-      axisout = GetAxisout(input0, input1, input2, j);
66be168c0dSopenharmony_ci-      if (axisout != MAX_SHAPE_SIZE + 1) {
67be168c0dSopenharmony_ci-        break;
68be168c0dSopenharmony_ci-      }
69be168c0dSopenharmony_ci-    }
70be168c0dSopenharmony_ci-    temp += 1;
71be168c0dSopenharmony_ci-    if (temp == input0->shape_size_) {
72be168c0dSopenharmony_ci-      SetShapeTensor(output, input);
73be168c0dSopenharmony_ci-      return NNACL_OK;
74be168c0dSopenharmony_ci-    }
75be168c0dSopenharmony_ci+  int in_shape0[MAX_SHAPE_SIZE] = {0};
76be168c0dSopenharmony_ci+  int in_shape1[MAX_SHAPE_SIZE] = {0};
77be168c0dSopenharmony_ci+  int in_shape2[MAX_SHAPE_SIZE] = {0};
78be168c0dSopenharmony_ci+  int output_shape[MAX_SHAPE_SIZE] = {0};
79be168c0dSopenharmony_ci+  size_t input_shape0_size = input0->shape_size_;
80be168c0dSopenharmony_ci+  size_t input_shape1_size = input1->shape_size_;
81be168c0dSopenharmony_ci+  size_t input_shape2_size = input2->shape_size_;
82be168c0dSopenharmony_ci+  const int *input_shape0 = input0->shape_;
83be168c0dSopenharmony_ci+  const int *input_shape1 = input1->shape_;
84be168c0dSopenharmony_ci+  const int *input_shape2 = input2->shape_;
85be168c0dSopenharmony_ci+  int ndim = (int)input_shape0_size;
86be168c0dSopenharmony_ci+  bool has_broad_cast_1 = false;
87be168c0dSopenharmony_ci+  bool has_broad_cast_2 = false;
88be168c0dSopenharmony_ci+  if (WhereBroadCastInferShape(input_shape0_size, input_shape1_size, input_shape0, input_shape1, &ndim, in_shape0,
89be168c0dSopenharmony_ci+                               in_shape1, output_shape, &has_broad_cast_1) != NNACL_OK) {
90be168c0dSopenharmony_ci+    return NNACL_ERR;
91be168c0dSopenharmony_ci   }
92be168c0dSopenharmony_ci-
93be168c0dSopenharmony_ci-  ShapeSet(output->shape_, &output->shape_size_, input0->shape_, input0->shape_size_);
94be168c0dSopenharmony_ci-  if (axisout != MAX_SHAPE_SIZE + 1) {
95be168c0dSopenharmony_ci-    output->shape_[axisout] = nummax;
96be168c0dSopenharmony_ci+  if (WhereBroadCastInferShape(ndim, input_shape2_size, output_shape, input_shape2, &ndim, in_shape0, in_shape2,
97be168c0dSopenharmony_ci+                               output_shape, &has_broad_cast_2) != NNACL_OK) {
98be168c0dSopenharmony_ci+    return NNACL_ERR;
99be168c0dSopenharmony_ci   }
100be168c0dSopenharmony_ci+  ShapeSet(output->shape_, &output->shape_size_, output_shape, ndim);
101be168c0dSopenharmony_ci   return NNACL_OK;
102be168c0dSopenharmony_ci }
103be168c0dSopenharmony_ci 
104be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/clip.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/clip.c
105be168c0dSopenharmony_ciindex ece0eff0..ae8ac5d8 100644
106be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/clip.c
107be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/clip.c
108be168c0dSopenharmony_ci@@ -81,6 +81,8 @@ int ClipCompute(struct KernelBase *self) {
109be168c0dSopenharmony_ci   NNACL_CHECK_NULL_RETURN_ERR(clip);
110be168c0dSopenharmony_ci   ClipParameter *param = (ClipParameter *)clip->base_.param_;
111be168c0dSopenharmony_ci   NNACL_CHECK_NULL_RETURN_ERR(param);
112be168c0dSopenharmony_ci+  clip->min_val_ = param->min_val_;
113be168c0dSopenharmony_ci+  clip->max_val_ = param->max_val_;
114be168c0dSopenharmony_ci 
115be168c0dSopenharmony_ci   int ret = NNACL_OK;
116be168c0dSopenharmony_ci   if (clip->base_.in_size_ > ONE_TENSOR) {
117be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/prelu_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/prelu_fp32.cc
118be168c0dSopenharmony_ciindex cae491f5..74639503 100644
119be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/fp32/prelu_fp32.cc
120be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/prelu_fp32.cc
121be168c0dSopenharmony_ci@@ -44,12 +44,6 @@ int PReluCPUKernel::Prepare() {
122be168c0dSopenharmony_ci   CHECK_NULL_RETURN(in_tensors_[kInputIndex]);
123be168c0dSopenharmony_ci   CHECK_NULL_RETURN(in_tensors_[kSlopeIndex]);
124be168c0dSopenharmony_ci   CHECK_NULL_RETURN(out_tensors_[kOutputIndex]);
125be168c0dSopenharmony_ci-  auto slope_shapes = in_tensors_[C1NUM]->ElementsNum();
126be168c0dSopenharmony_ci-  auto input_channel = in_tensors_[C0NUM]->Channel();
127be168c0dSopenharmony_ci-  if ((slope_shapes != C1NUM) && (slope_shapes != input_channel)) {
128be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "slope_shapes: " << slope_shapes << " is not equal to 1 or input_channel: " << input_channel;
129be168c0dSopenharmony_ci-    return lite::RET_ERROR;
130be168c0dSopenharmony_ci-  }
131be168c0dSopenharmony_ci   if (in_tensors_[1]->ElementsNum() == 1) {
132be168c0dSopenharmony_ci     param_->channelShared = true;
133be168c0dSopenharmony_ci   } else {
134be168c0dSopenharmony_ci@@ -83,6 +77,12 @@ int PReluCPUKernel::DoExcute(int task_id) const {
135be168c0dSopenharmony_ci }
136be168c0dSopenharmony_ci 
137be168c0dSopenharmony_ci int PReluCPUKernel::ReSize() {
138be168c0dSopenharmony_ci+  auto slope_shapes = in_tensors_[C1NUM]->ElementsNum();
139be168c0dSopenharmony_ci+  auto input_channel = in_tensors_[C0NUM]->Channel();
140be168c0dSopenharmony_ci+  if ((slope_shapes != C1NUM) && (slope_shapes != input_channel)) {
141be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "slope_shapes: " << slope_shapes << " is not equal to 1 or input_channel: " << input_channel;
142be168c0dSopenharmony_ci+    return lite::RET_ERROR;
143be168c0dSopenharmony_ci+  }
144be168c0dSopenharmony_ci   auto &input = in_tensors_[kInputIndex];
145be168c0dSopenharmony_ci   param_->input_num_ = input->ElementsNum();
146be168c0dSopenharmony_ci   CHECK_NOT_EQUAL_RETURN(out_tensors_.front()->ElementsNum(), param_->input_num_);
147be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/where_fp32.cc b/mindspore/lite/src/litert/kernel/cpu/fp32/where_fp32.cc
148be168c0dSopenharmony_ciindex d7c987e3..a73fda7c 100644
149be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/fp32/where_fp32.cc
150be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/where_fp32.cc
151be168c0dSopenharmony_ci@@ -20,6 +20,7 @@
152be168c0dSopenharmony_ci #include "src/litert/kernel_registry.h"
153be168c0dSopenharmony_ci #include "include/errorcode.h"
154be168c0dSopenharmony_ci #include "nnacl/common_func.h"
155be168c0dSopenharmony_ci+#include "nnacl/base/broadcast_to.h"
156be168c0dSopenharmony_ci 
157be168c0dSopenharmony_ci using mindspore::kernel::KERNEL_ARCH;
158be168c0dSopenharmony_ci using mindspore::lite::KernelRegistrar;
159be168c0dSopenharmony_ci@@ -153,36 +154,58 @@ int WhereCPUKernel::RunWithSingleInput() {
160be168c0dSopenharmony_ci }
161be168c0dSopenharmony_ci 
162be168c0dSopenharmony_ci int WhereCPUKernel::RunWithTripleInputs() {
163be168c0dSopenharmony_ci-  auto condition = in_tensors_.at(0);
164be168c0dSopenharmony_ci+  TensorC *condition = in_tensors_.at(0)->ConvertToTensorC();
165be168c0dSopenharmony_ci   CHECK_NULL_RETURN(condition);
166be168c0dSopenharmony_ci-  auto x = in_tensors_.at(1);
167be168c0dSopenharmony_ci+  TensorC *x = in_tensors_.at(1)->ConvertToTensorC();
168be168c0dSopenharmony_ci   CHECK_NULL_RETURN(x);
169be168c0dSopenharmony_ci-  auto y = in_tensors_.at(C2NUM);
170be168c0dSopenharmony_ci+  TensorC *y = in_tensors_.at(C2NUM)->ConvertToTensorC();
171be168c0dSopenharmony_ci   CHECK_NULL_RETURN(y);
172be168c0dSopenharmony_ci-  int condition_nums = condition->ElementsNum();
173be168c0dSopenharmony_ci-  int x_num = x->ElementsNum();
174be168c0dSopenharmony_ci-  int y_num = y->ElementsNum();
175be168c0dSopenharmony_ci-  int out_num = out_tensors_.front()->ElementsNum();
176be168c0dSopenharmony_ci+  TensorC *output = out_tensors_.at(0)->ConvertToTensorC();
177be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(output);
178be168c0dSopenharmony_ci+  int condition_nums = GetElementNum(condition);
179be168c0dSopenharmony_ci+  int x_num = GetElementNum(x);
180be168c0dSopenharmony_ci+  int y_num = GetElementNum(y);
181be168c0dSopenharmony_ci+  int out_num = GetElementNum(output);
182be168c0dSopenharmony_ci 
183be168c0dSopenharmony_ci-  condition_ = reinterpret_cast<bool *>(condition->data());
184be168c0dSopenharmony_ci+  condition_ = reinterpret_cast<bool *>(condition->data_);
185be168c0dSopenharmony_ci   CHECK_NULL_RETURN(condition_);
186be168c0dSopenharmony_ci-  x_ = x->data();
187be168c0dSopenharmony_ci+  x_ = x->data_;
188be168c0dSopenharmony_ci   CHECK_NULL_RETURN(x_);
189be168c0dSopenharmony_ci-  y_ = y->data();
190be168c0dSopenharmony_ci+  y_ = y->data_;
191be168c0dSopenharmony_ci   CHECK_NULL_RETURN(y_);
192be168c0dSopenharmony_ci-  output_data_ = out_tensors_.at(0)->data();
193be168c0dSopenharmony_ci+  output_data_ = output->data_;
194be168c0dSopenharmony_ci   int num_max = condition_nums > x_num ? condition_nums : (x_num > y_num ? x_num : y_num);
195be168c0dSopenharmony_ci   where_param_->condition_num_ = condition_nums;
196be168c0dSopenharmony_ci   where_param_->x_num_ = x_num;
197be168c0dSopenharmony_ci   where_param_->y_num_ = y_num;
198be168c0dSopenharmony_ci   where_param_->max_num_ = num_max;
199be168c0dSopenharmony_ci-
200be168c0dSopenharmony_ci+  void *condition_broadcast_buf = nullptr;
201be168c0dSopenharmony_ci+  void *x_broadcast_buf = nullptr;
202be168c0dSopenharmony_ci+  void *y_broadcast_buf = nullptr;
203be168c0dSopenharmony_ci   CHECK_LESS_RETURN(out_num, num_max);
204be168c0dSopenharmony_ci 
205be168c0dSopenharmony_ci   if (((condition_nums != 1) && (condition_nums != num_max)) || ((x_num != 1) && (x_num != num_max)) ||
206be168c0dSopenharmony_ci       ((y_num != 1) && (y_num != num_max))) {
207be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "The length of three inputs are not equal to 1 or length of output, which is unacceptable";
208be168c0dSopenharmony_ci-    return RET_ERROR;
209be168c0dSopenharmony_ci+    if (condition_nums != GetElementNum(y)) {
210be168c0dSopenharmony_ci+      int ret =
211be168c0dSopenharmony_ci+        BroadcastForInput(condition, x, y, &condition_broadcast_buf, &x_broadcast_buf, &y_broadcast_buf, output);
212be168c0dSopenharmony_ci+      if (ret != RET_OK) {
213be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "BroadcastForInput failed.";
214be168c0dSopenharmony_ci+        return RET_ERROR;
215be168c0dSopenharmony_ci+      }
216be168c0dSopenharmony_ci+      int max_num = GetElementNum(output);
217be168c0dSopenharmony_ci+      condition_ = reinterpret_cast<bool *>(condition_broadcast_buf);
218be168c0dSopenharmony_ci+      x_ = x_broadcast_buf;
219be168c0dSopenharmony_ci+      y_ = y_broadcast_buf;
220be168c0dSopenharmony_ci+      output_data_ = output->data_;
221be168c0dSopenharmony_ci+      where_param_->condition_num_ = max_num;
222be168c0dSopenharmony_ci+      where_param_->x_num_ = max_num;
223be168c0dSopenharmony_ci+      where_param_->y_num_ = max_num;
224be168c0dSopenharmony_ci+      where_param_->max_num_ = max_num;
225be168c0dSopenharmony_ci+    } else {
226be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "The length of three inputs are not equal to 1 or length of output, which is unacceptable";
227be168c0dSopenharmony_ci+      return RET_ERROR;
228be168c0dSopenharmony_ci+    }
229be168c0dSopenharmony_ci   }
230be168c0dSopenharmony_ci   if (num_max <= 0) {
231be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Error, inputs' length are zero !!!";
232be168c0dSopenharmony_ci@@ -193,6 +216,9 @@ int WhereCPUKernel::RunWithTripleInputs() {
233be168c0dSopenharmony_ci     MS_LOG(ERROR) << "WhereDwRun error: error_code[" << ret << "]";
234be168c0dSopenharmony_ci     return RET_ERROR;
235be168c0dSopenharmony_ci   }
236be168c0dSopenharmony_ci+  ms_context_->allocator->Free(condition_broadcast_buf);
237be168c0dSopenharmony_ci+  ms_context_->allocator->Free(x_broadcast_buf);
238be168c0dSopenharmony_ci+  ms_context_->allocator->Free(y_broadcast_buf);
239be168c0dSopenharmony_ci   return RET_OK;
240be168c0dSopenharmony_ci }
241be168c0dSopenharmony_ci 
242be168c0dSopenharmony_ci@@ -214,6 +240,48 @@ int WhereCPUKernel::Run() {
243be168c0dSopenharmony_ci   return ret;
244be168c0dSopenharmony_ci }
245be168c0dSopenharmony_ci 
246be168c0dSopenharmony_ci+int WhereCPUKernel::BroadcastForInput(TensorC *condition, TensorC *x, TensorC *y, void **condition_broadcast_buf,
247be168c0dSopenharmony_ci+                                      void **x_broadcast_buf, void **y_broadcast_buf, TensorC *output) {
248be168c0dSopenharmony_ci+  size_t broad_cast_buf_size = GetSize(output);
249be168c0dSopenharmony_ci+  BroadcastShapeInfo condition_info;
250be168c0dSopenharmony_ci+  condition_info.input_shape_size_ = condition->shape_size_;
251be168c0dSopenharmony_ci+  condition_info.output_shape_size_ = output->shape_size_;
252be168c0dSopenharmony_ci+  (void)memcpy(condition_info.input_shape_, condition->shape_, condition->shape_size_ * sizeof(int));
253be168c0dSopenharmony_ci+  (void)memcpy(condition_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
254be168c0dSopenharmony_ci+  BroadcastShapeInfo x_info;
255be168c0dSopenharmony_ci+  x_info.input_shape_size_ = x->shape_size_;
256be168c0dSopenharmony_ci+  x_info.output_shape_size_ = output->shape_size_;
257be168c0dSopenharmony_ci+  (void)memcpy(x_info.input_shape_, x->shape_, x->shape_size_ * sizeof(int));
258be168c0dSopenharmony_ci+  (void)memcpy(x_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
259be168c0dSopenharmony_ci+  BroadcastShapeInfo y_info;
260be168c0dSopenharmony_ci+  y_info.input_shape_size_ = y->shape_size_;
261be168c0dSopenharmony_ci+  y_info.output_shape_size_ = output->shape_size_;
262be168c0dSopenharmony_ci+  (void)memcpy(y_info.input_shape_, y->shape_, y->shape_size_ * sizeof(int));
263be168c0dSopenharmony_ci+  (void)memcpy(y_info.output_shape_, output->shape_, output->shape_size_ * sizeof(int));
264be168c0dSopenharmony_ci+
265be168c0dSopenharmony_ci+  *condition_broadcast_buf = ms_context_->allocator->Malloc(broad_cast_buf_size);
266be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(*condition_broadcast_buf);
267be168c0dSopenharmony_ci+  BroadcastToSize8(condition->data_, &condition_info, *condition_broadcast_buf);
268be168c0dSopenharmony_ci+
269be168c0dSopenharmony_ci+  *x_broadcast_buf = ms_context_->allocator->Malloc(broad_cast_buf_size);
270be168c0dSopenharmony_ci+  if (*x_broadcast_buf == nullptr) {
271be168c0dSopenharmony_ci+    ms_context_->allocator->Free(*condition_broadcast_buf);
272be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc x_broadcast_buf error";
273be168c0dSopenharmony_ci+    return RET_ERROR;
274be168c0dSopenharmony_ci+  }
275be168c0dSopenharmony_ci+  BroadcastToSize32(x->data_, &x_info, *x_broadcast_buf);
276be168c0dSopenharmony_ci+
277be168c0dSopenharmony_ci+  *y_broadcast_buf = ms_context_->allocator->Malloc(broad_cast_buf_size);
278be168c0dSopenharmony_ci+  if (*y_broadcast_buf == nullptr) {
279be168c0dSopenharmony_ci+    ms_context_->allocator->Free(*condition_broadcast_buf);
280be168c0dSopenharmony_ci+    ms_context_->allocator->Free(*x_broadcast_buf);
281be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc y_broadcast_buf error";
282be168c0dSopenharmony_ci+    return RET_ERROR;
283be168c0dSopenharmony_ci+  }
284be168c0dSopenharmony_ci+  BroadcastToSize32(y->data_, &y_info, *y_broadcast_buf);
285be168c0dSopenharmony_ci+  return RET_OK;
286be168c0dSopenharmony_ci+}
287be168c0dSopenharmony_ci+
288be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Where, LiteKernelCreator<WhereCPUKernel>)
289be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Where, LiteKernelCreator<WhereCPUKernel>)
290be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Where, LiteKernelCreator<WhereCPUKernel>)
291be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/fp32/where_fp32.h b/mindspore/lite/src/litert/kernel/cpu/fp32/where_fp32.h
292be168c0dSopenharmony_ciindex 0d785732..ae6e3eba 100644
293be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/fp32/where_fp32.h
294be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/fp32/where_fp32.h
295be168c0dSopenharmony_ci@@ -51,6 +51,8 @@ class WhereCPUKernel : public LiteKernel {
296be168c0dSopenharmony_ci  private:
297be168c0dSopenharmony_ci   int RunWithSingleInput();
298be168c0dSopenharmony_ci   int RunWithTripleInputs();
299be168c0dSopenharmony_ci+  int BroadcastForInput(TensorC *condition, TensorC *x, TensorC *y, void **condition_broadcast_buf,
300be168c0dSopenharmony_ci+                        void **x_broadcast_buf, void **y_broadcast_buf, TensorC *output);
301be168c0dSopenharmony_ci };
302be168c0dSopenharmony_ci }  // namespace mindspore::kernel
303be168c0dSopenharmony_ci #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CPU_FP32_WHERE_FP32_H_
304be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/optimizer/fusion/glu_fusion.h b/mindspore/lite/tools/optimizer/fusion/glu_fusion.h
305be168c0dSopenharmony_ciindex 5e6a7e79..513a49d9 100644
306be168c0dSopenharmony_ci--- a/mindspore/lite/tools/optimizer/fusion/glu_fusion.h
307be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/optimizer/fusion/glu_fusion.h
308be168c0dSopenharmony_ci@@ -1,5 +1,5 @@
309be168c0dSopenharmony_ci /**
310be168c0dSopenharmony_ci- * Copyright 2021 Huawei Technologies Co., Ltd
311be168c0dSopenharmony_ci+ * Copyright 2021~2024 Huawei Technologies Co., Ltd
312be168c0dSopenharmony_ci  *
313be168c0dSopenharmony_ci  * Licensed under the Apache License, Version 2.0 (the "License");
314be168c0dSopenharmony_ci  * you may not use this file except in compliance with the License.
315be168c0dSopenharmony_ci@@ -26,7 +26,7 @@ namespace mindspore {
316be168c0dSopenharmony_ci namespace opt {
317be168c0dSopenharmony_ci class GLUFusion : public LitePatternProcessPass {
318be168c0dSopenharmony_ci  public:
319be168c0dSopenharmony_ci-  explicit GLUFusion(const std::string &name = "glu_fusion", bool multigraph = true)
320be168c0dSopenharmony_ci+  explicit GLUFusion(const std::string &name = "GLUFusion", bool multigraph = true)
321be168c0dSopenharmony_ci       : LitePatternProcessPass(name, multigraph) {}
322be168c0dSopenharmony_ci 
323be168c0dSopenharmony_ci   ~GLUFusion() override = default;
324be168c0dSopenharmony_ci-- 
325be168c0dSopenharmony_ci2.25.1
326be168c0dSopenharmony_ci
327