1be168c0dSopenharmony_ciFrom a303e237bf5506d75b98703d442f01e18fb2c820 Mon Sep 17 00:00:00 2001
2be168c0dSopenharmony_ciFrom: zhangyanhui <zhangyanhui17@huawei.com>
3be168c0dSopenharmony_ciDate: Mon, 8 Jul 2024 15:44:46 +0800
4be168c0dSopenharmony_ciSubject: [PATCH] ConstantOfShape and StridedSlice kernel support bool type
5be168c0dSopenharmony_ci
6be168c0dSopenharmony_ci---
7be168c0dSopenharmony_ci .../device/cpu/kernel/nnacl/constant_of_shape_parameter.h  | 1 +
8be168c0dSopenharmony_ci .../device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h  | 7 +++++++
9be168c0dSopenharmony_ci .../plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c  | 1 +
10be168c0dSopenharmony_ci .../ops/operator_populate/constant_of_shape_populate.cc    | 3 +++
11be168c0dSopenharmony_ci .../src/common/ops/populate/constant_of_shape_populate.cc  | 3 +++
12be168c0dSopenharmony_ci .../lite/src/litert/kernel/cpu/base/constant_of_shape.cc   | 5 +++++
13be168c0dSopenharmony_ci .../lite/tools/converter/parser/onnx/onnx_node_parser.cc   | 6 ++++++
14be168c0dSopenharmony_ci 7 files changed, 26 insertions(+)
15be168c0dSopenharmony_ci
16be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h
17be168c0dSopenharmony_ciindex f108ea98..d75edb6f 100644
18be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h
19be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h
20be168c0dSopenharmony_ci@@ -23,6 +23,7 @@ typedef struct ConstantOfShapeParameter {
21be168c0dSopenharmony_ci   union value_ {
22be168c0dSopenharmony_ci     float f32_value_;
23be168c0dSopenharmony_ci     int32_t int32_value_;
24be168c0dSopenharmony_ci+    bool bool_value_;
25be168c0dSopenharmony_ci   } value_;
26be168c0dSopenharmony_ci   int data_type_;
27be168c0dSopenharmony_ci   int element_size_;
28be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h
29be168c0dSopenharmony_ciindex 6c607cf5..c884d031 100644
30be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h
31be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h
32be168c0dSopenharmony_ci@@ -38,6 +38,13 @@ inline int ConstantOfShapeFp32(float *output, int start, int end, float value) {
33be168c0dSopenharmony_ci   return NNACL_OK;
34be168c0dSopenharmony_ci }
35be168c0dSopenharmony_ci 
36be168c0dSopenharmony_ci+inline int ConstantOfShapeBool(bool *output, int start, int end, bool value) {
37be168c0dSopenharmony_ci+  for (int i = start; i < end; i++) {
38be168c0dSopenharmony_ci+    output[i] = value;
39be168c0dSopenharmony_ci+  }
40be168c0dSopenharmony_ci+  return NNACL_OK;
41be168c0dSopenharmony_ci+}
42be168c0dSopenharmony_ci+
43be168c0dSopenharmony_ci #ifdef __cplusplus
44be168c0dSopenharmony_ci }
45be168c0dSopenharmony_ci #endif
46be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c
47be168c0dSopenharmony_ciindex 1460c2cc..714bcaef 100644
48be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c
49be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c
50be168c0dSopenharmony_ci@@ -275,3 +275,4 @@ REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat16, CreateStridedSlice
51be168c0dSopenharmony_ci REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt64, CreateStridedSlice)
52be168c0dSopenharmony_ci REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt32, CreateStridedSlice)
53be168c0dSopenharmony_ci REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt8, CreateStridedSlice)
54be168c0dSopenharmony_ci+REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeBool, CreateStridedSlice)
55be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
56be168c0dSopenharmony_ciindex 3552b5f9..743f42f5 100644
57be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
58be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc
59be168c0dSopenharmony_ci@@ -42,6 +42,9 @@ OpParameter *PopulateConstantOfShapeOpParameter(const BaseOperatorPtr &base_oper
60be168c0dSopenharmony_ci     case kNumberTypeInt32:
61be168c0dSopenharmony_ci       param->value_.int32_value_ = static_cast<int32_t>(value[0]);
62be168c0dSopenharmony_ci       break;
63be168c0dSopenharmony_ci+    case kNumberTypeBool:
64be168c0dSopenharmony_ci+      param->value_.bool_value_ = static_cast<bool>(value[0]);
65be168c0dSopenharmony_ci+      break;
66be168c0dSopenharmony_ci     default:
67be168c0dSopenharmony_ci       MS_LOG(ERROR) << "The value of constant of shape is invalid";
68be168c0dSopenharmony_ci       free(param);
69be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc
70be168c0dSopenharmony_ciindex 56263d13..d8fd6473 100644
71be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc
72be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc
73be168c0dSopenharmony_ci@@ -48,6 +48,9 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) {
74be168c0dSopenharmony_ci     case kNumberTypeInt32:
75be168c0dSopenharmony_ci       param->value_.int32_value_ = static_cast<int32_t>(val[0]);
76be168c0dSopenharmony_ci       break;
77be168c0dSopenharmony_ci+    case kNumberTypeBool:
78be168c0dSopenharmony_ci+      param->value_.bool_value_ = static_cast<bool>(val[0]);
79be168c0dSopenharmony_ci+      break;
80be168c0dSopenharmony_ci     default:
81be168c0dSopenharmony_ci       MS_LOG(ERROR) << "The value of constant of shape is invalid";
82be168c0dSopenharmony_ci       free(param);
83be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc
84be168c0dSopenharmony_ciindex d8d24146..94f4a490 100644
85be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc
86be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc
87be168c0dSopenharmony_ci@@ -53,6 +53,10 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
88be168c0dSopenharmony_ci       ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride,
89be168c0dSopenharmony_ci                            param_->value_.int32_value_);
90be168c0dSopenharmony_ci       break;
91be168c0dSopenharmony_ci+    case kNumberTypeBool:
92be168c0dSopenharmony_ci+      ConstantOfShapeBool(reinterpret_cast<bool *>(output_ptr_), start, start + current_stride,
93be168c0dSopenharmony_ci+                           param_->value_.bool_value_);
94be168c0dSopenharmony_ci+      break;
95be168c0dSopenharmony_ci #ifdef ENABLE_FP16
96be168c0dSopenharmony_ci     case kNumberTypeFloat16:
97be168c0dSopenharmony_ci       ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride,
98be168c0dSopenharmony_ci@@ -100,4 +104,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCr
99be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
100be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
101be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
102be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
103be168c0dSopenharmony_ci }  // namespace mindspore::kernel
104be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
105be168c0dSopenharmony_ciindex 39197be6..4d11561e 100644
106be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
107be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc
108be168c0dSopenharmony_ci@@ -223,6 +223,12 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso
109be168c0dSopenharmony_ci         value->push_back(static_cast<float>(reinterpret_cast<const float16 *>(onnx_tensor.raw_data().data())[i]));
110be168c0dSopenharmony_ci       }
111be168c0dSopenharmony_ci       break;
112be168c0dSopenharmony_ci+    case onnx::TensorProto_DataType_BOOL:
113be168c0dSopenharmony_ci+      *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_BOOL);
114be168c0dSopenharmony_ci+      for (size_t i = 0; i < data_count; i++) {
115be168c0dSopenharmony_ci+        value->push_back(static_cast<float>(reinterpret_cast<const bool *>(onnx_tensor.raw_data().data())[i]));
116be168c0dSopenharmony_ci+      }
117be168c0dSopenharmony_ci+      break;
118be168c0dSopenharmony_ci     default:
119be168c0dSopenharmony_ci       MS_LOG(ERROR) << "The data type is not supported.";
120be168c0dSopenharmony_ci       return RET_ERROR;
121be168c0dSopenharmony_ci-- 
122be168c0dSopenharmony_ci2.25.1
123be168c0dSopenharmony_ci
124