1be168c0dSopenharmony_ciFrom a303e237bf5506d75b98703d442f01e18fb2c820 Mon Sep 17 00:00:00 2001 2be168c0dSopenharmony_ciFrom: zhangyanhui <zhangyanhui17@huawei.com> 3be168c0dSopenharmony_ciDate: Mon, 8 Jul 2024 15:44:46 +0800 4be168c0dSopenharmony_ciSubject: [PATCH] ConstantOfShape and StridedSlice kernel support bool type 5be168c0dSopenharmony_ci 6be168c0dSopenharmony_ci--- 7be168c0dSopenharmony_ci .../device/cpu/kernel/nnacl/constant_of_shape_parameter.h | 1 + 8be168c0dSopenharmony_ci .../device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h | 7 +++++++ 9be168c0dSopenharmony_ci .../plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c | 1 + 10be168c0dSopenharmony_ci .../ops/operator_populate/constant_of_shape_populate.cc | 3 +++ 11be168c0dSopenharmony_ci .../src/common/ops/populate/constant_of_shape_populate.cc | 3 +++ 12be168c0dSopenharmony_ci .../lite/src/litert/kernel/cpu/base/constant_of_shape.cc | 5 +++++ 13be168c0dSopenharmony_ci .../lite/tools/converter/parser/onnx/onnx_node_parser.cc | 6 ++++++ 14be168c0dSopenharmony_ci 7 files changed, 26 insertions(+) 15be168c0dSopenharmony_ci 16be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h 17be168c0dSopenharmony_ciindex f108ea98..d75edb6f 100644 18be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h 19be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/constant_of_shape_parameter.h 20be168c0dSopenharmony_ci@@ -23,6 +23,7 @@ typedef struct ConstantOfShapeParameter { 21be168c0dSopenharmony_ci union value_ { 22be168c0dSopenharmony_ci float f32_value_; 23be168c0dSopenharmony_ci int32_t int32_value_; 24be168c0dSopenharmony_ci+ bool bool_value_; 25be168c0dSopenharmony_ci } value_; 26be168c0dSopenharmony_ci int data_type_; 27be168c0dSopenharmony_ci int element_size_; 28be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h 29be168c0dSopenharmony_ciindex 6c607cf5..c884d031 100644 30be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h 31be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp32/constant_of_shape_fp32.h 32be168c0dSopenharmony_ci@@ -38,6 +38,13 @@ inline int ConstantOfShapeFp32(float *output, int start, int end, float value) { 33be168c0dSopenharmony_ci return NNACL_OK; 34be168c0dSopenharmony_ci } 35be168c0dSopenharmony_ci 36be168c0dSopenharmony_ci+inline int ConstantOfShapeBool(bool *output, int start, int end, bool value) { 37be168c0dSopenharmony_ci+ for (int i = start; i < end; i++) { 38be168c0dSopenharmony_ci+ output[i] = value; 39be168c0dSopenharmony_ci+ } 40be168c0dSopenharmony_ci+ return NNACL_OK; 41be168c0dSopenharmony_ci+} 42be168c0dSopenharmony_ci+ 43be168c0dSopenharmony_ci #ifdef __cplusplus 44be168c0dSopenharmony_ci } 45be168c0dSopenharmony_ci #endif 46be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c 47be168c0dSopenharmony_ciindex 1460c2cc..714bcaef 100644 48be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c 49be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/kernel/strided_slice.c 50be168c0dSopenharmony_ci@@ -275,3 +275,4 @@ REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeFloat16, CreateStridedSlice 51be168c0dSopenharmony_ci REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt64, CreateStridedSlice) 52be168c0dSopenharmony_ci REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt32, CreateStridedSlice) 53be168c0dSopenharmony_ci REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeInt8, CreateStridedSlice) 54be168c0dSopenharmony_ci+REG_KERNEL_CREATOR(PrimType_StridedSlice, kNumberTypeBool, CreateStridedSlice) 55be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc 56be168c0dSopenharmony_ciindex 3552b5f9..743f42f5 100644 57be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc 58be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/operator_populate/constant_of_shape_populate.cc 59be168c0dSopenharmony_ci@@ -42,6 +42,9 @@ OpParameter *PopulateConstantOfShapeOpParameter(const BaseOperatorPtr &base_oper 60be168c0dSopenharmony_ci case kNumberTypeInt32: 61be168c0dSopenharmony_ci param->value_.int32_value_ = static_cast<int32_t>(value[0]); 62be168c0dSopenharmony_ci break; 63be168c0dSopenharmony_ci+ case kNumberTypeBool: 64be168c0dSopenharmony_ci+ param->value_.bool_value_ = static_cast<bool>(value[0]); 65be168c0dSopenharmony_ci+ break; 66be168c0dSopenharmony_ci default: 67be168c0dSopenharmony_ci MS_LOG(ERROR) << "The value of constant of shape is invalid"; 68be168c0dSopenharmony_ci free(param); 69be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc 70be168c0dSopenharmony_ciindex 56263d13..d8fd6473 100644 71be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc 72be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/populate/constant_of_shape_populate.cc 73be168c0dSopenharmony_ci@@ -48,6 +48,9 @@ OpParameter *PopulateConstantOfShapeParameter(const void *prim) { 74be168c0dSopenharmony_ci case kNumberTypeInt32: 75be168c0dSopenharmony_ci param->value_.int32_value_ = static_cast<int32_t>(val[0]); 76be168c0dSopenharmony_ci break; 77be168c0dSopenharmony_ci+ case kNumberTypeBool: 78be168c0dSopenharmony_ci+ param->value_.bool_value_ = static_cast<bool>(val[0]); 79be168c0dSopenharmony_ci+ break; 80be168c0dSopenharmony_ci default: 81be168c0dSopenharmony_ci MS_LOG(ERROR) << "The value of constant of shape is invalid"; 82be168c0dSopenharmony_ci free(param); 83be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc 84be168c0dSopenharmony_ciindex d8d24146..94f4a490 100644 85be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc 86be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/constant_of_shape.cc 87be168c0dSopenharmony_ci@@ -53,6 +53,10 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) { 88be168c0dSopenharmony_ci ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride, 89be168c0dSopenharmony_ci param_->value_.int32_value_); 90be168c0dSopenharmony_ci break; 91be168c0dSopenharmony_ci+ case kNumberTypeBool: 92be168c0dSopenharmony_ci+ ConstantOfShapeBool(reinterpret_cast<bool *>(output_ptr_), start, start + current_stride, 93be168c0dSopenharmony_ci+ param_->value_.bool_value_); 94be168c0dSopenharmony_ci+ break; 95be168c0dSopenharmony_ci #ifdef ENABLE_FP16 96be168c0dSopenharmony_ci case kNumberTypeFloat16: 97be168c0dSopenharmony_ci ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride, 98be168c0dSopenharmony_ci@@ -100,4 +104,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCr 99be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) 100be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) 101be168c0dSopenharmony_ci REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) 102be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>) 103be168c0dSopenharmony_ci } // namespace mindspore::kernel 104be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 105be168c0dSopenharmony_ciindex 39197be6..4d11561e 100644 106be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 107be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.cc 108be168c0dSopenharmony_ci@@ -223,6 +223,12 @@ STATUS OnnxNodeParser::GetTensorDataFromOnnx(const onnx::TensorProto &onnx_tenso 109be168c0dSopenharmony_ci value->push_back(static_cast<float>(reinterpret_cast<const float16 *>(onnx_tensor.raw_data().data())[i])); 110be168c0dSopenharmony_ci } 111be168c0dSopenharmony_ci break; 112be168c0dSopenharmony_ci+ case onnx::TensorProto_DataType_BOOL: 113be168c0dSopenharmony_ci+ *type = GetDataTypeFromOnnx(onnx::TensorProto_DataType_BOOL); 114be168c0dSopenharmony_ci+ for (size_t i = 0; i < data_count; i++) { 115be168c0dSopenharmony_ci+ value->push_back(static_cast<float>(reinterpret_cast<const bool *>(onnx_tensor.raw_data().data())[i])); 116be168c0dSopenharmony_ci+ } 117be168c0dSopenharmony_ci+ break; 118be168c0dSopenharmony_ci default: 119be168c0dSopenharmony_ci MS_LOG(ERROR) << "The data type is not supported."; 120be168c0dSopenharmony_ci return RET_ERROR; 121be168c0dSopenharmony_ci-- 122be168c0dSopenharmony_ci2.25.1 123be168c0dSopenharmony_ci 124