1From f0daa7ef13e1741f8bcd1dfad7517a4a8ae4a209 Mon Sep 17 00:00:00 2001
2From: xuanyue <xuanyue@huawei.com>
3Date: Thu, 21 Mar 2024 19:38:34 +0800
4Subject: [PATCH] DynamicQuant strategy optimization
5
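Generalize dynamic quantization from a single prefer_axis (plus transpose flag)
to a prefer_axes vector: add the prefer_axes attribute to the DynamicQuant
schema and operator, carry axis_num_/prefer_axes_ in DynamicQuantParameter,
and rework the int8 kernel to compute per-segment min/max (transposing the
input when the preferred axes are not leading) before quantizing and
transposing back. MatmulDynamicBaseInt8 additionally accepts per-batch-channel
quant params, and the converter sets the full prefer_axes list and skips
per-channel dynamic quant when both matmul inputs are non-constant.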
6---
7 .../kernel/nnacl/dynamic_quant_parameter.h    |   7 +-
8 mindspore/core/ops/dynamic_quant.cc           |  12 +
9 mindspore/core/ops/dynamic_quant.h            |  10 +
10 mindspore/core/ops/op_name.h                  |   1 +
11 mindspore/lite/schema/inner/ops_generated.h   |  53 +++-
12 mindspore/lite/schema/ops.fbs                 |   1 +
13 mindspore/lite/schema/ops_generated.h         |  34 +-
14 mindspore/lite/src/common/ops/ops_def.cc      |   1 +
15 .../ops/populate/dynamic_quant_populate.cc    |  24 +-
16 .../litert/kernel/cpu/int8/dynamic_quant.cc   | 299 +++++++++++-------
17 .../litert/kernel/cpu/int8/dynamic_quant.h    |  59 ++--
18 .../cpu/int8/matmul_dynamic_base_int8.cc      |  43 ++-
19 .../cpu/int8/matmul_dynamic_base_int8.h       |   7 +-
20 .../quantizer/insert_quant_node_manager.cc    |  27 +-
21 .../quantizer/insert_quant_node_manager.h     |   5 +-
22 15 files changed, 395 insertions(+), 188 deletions(-)
23
24diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h
25index aaabe041..1fc166cb 100644
26--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h
27+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/dynamic_quant_parameter.h
28@@ -21,10 +21,9 @@
29 typedef struct DynamicQuantParameter {
30   OpParameter op_parameter_;
31   bool symmetric_;
32-  int64_t dst_type_;
33-  bool activation_perchannel_;
34-  int64_t prefer_axis_;
35-  bool transpose_;
36+  int dst_type_;
37+  int axis_num_;
38+  int prefer_axes_[MAX_SHAPE_SIZE];
39 } DynamicQuantParameter;
40 
41 #endif  // NNACL_DYNAMIC_QUANT_PARAMETER_H_
42diff --git a/mindspore/core/ops/dynamic_quant.cc b/mindspore/core/ops/dynamic_quant.cc
43index 63ea0be5..1949f809 100644
44--- a/mindspore/core/ops/dynamic_quant.cc
45+++ b/mindspore/core/ops/dynamic_quant.cc
46@@ -48,6 +48,18 @@ bool DynamicQuant::get_transpose() const {
47   auto value_ptr = this->GetAttr(kTrans);
48   return GetValue<bool>(value_ptr);
49 }
50+
51+void DynamicQuant::set_prefer_axes(const std::vector<int> &prefer_axes) {
52+  (void)AddAttr(kPreferAxes, api::MakeValue(prefer_axes));
53+}
54+
55+std::vector<int> DynamicQuant::get_prefer_axes() const {
56+  auto value_ptr = GetAttr(kPreferAxes);
57+  auto tmp = GetValue<std::vector<int64_t>>(value_ptr);
58+  std::vector<int> res(tmp.begin(), tmp.end());
59+  return res;
60+}
61+
62 void DynamicQuant::Init(const bool symmetric, const int64_t dst_type) {
63   this->set_symmetric(symmetric);
64   this->set_dst_type(dst_type);
65diff --git a/mindspore/core/ops/dynamic_quant.h b/mindspore/core/ops/dynamic_quant.h
66index 4cb446c3..963dfb37 100644
67--- a/mindspore/core/ops/dynamic_quant.h
68+++ b/mindspore/core/ops/dynamic_quant.h
69@@ -91,6 +91,16 @@ class MIND_API DynamicQuant : public BaseOperator {
70   ///
71   /// \return Whether transpose matrix.
72   bool get_transpose() const;
73+
74+  /// \brief Method to set the prefer_axes attribute.
75+  ///
76+  /// \param[in] prefer_axes Define the preferred axes.
77+  void set_prefer_axes(const std::vector<int> &prefer_axes);
78+
79+  /// \brief Method to get the prefer_axes attribute.
80+  ///
81+  /// \return the preferred axes.
82+  std::vector<int> get_prefer_axes() const;
83 };
84 MIND_API abstract::AbstractBasePtr DynamicQuantInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
85                                                      const std::vector<abstract::AbstractBasePtr> &input_args);
86diff --git a/mindspore/core/ops/op_name.h b/mindspore/core/ops/op_name.h
87index ad9066e7..1282e6ea 100644
88--- a/mindspore/core/ops/op_name.h
89+++ b/mindspore/core/ops/op_name.h
90@@ -410,6 +410,7 @@ constexpr auto KCurrChunkIndex = "curr_chunk_index";
91 constexpr auto KCurrBitCount = "curr_bit_count";
92 constexpr auto KTableLog = "table_log";
93 constexpr auto kIgnoreIndex = "ignore_index";
94+constexpr auto kPreferAxes = "prefer_axes";
95 
96 constexpr size_t kInputIndex0 = 0;
97 constexpr size_t kInputIndex1 = 1;
98diff --git a/mindspore/lite/schema/inner/ops_generated.h b/mindspore/lite/schema/inner/ops_generated.h
99index 6c861aa5..b595f4b2 100644
100--- a/mindspore/lite/schema/inner/ops_generated.h
101+++ b/mindspore/lite/schema/inner/ops_generated.h
102@@ -19790,6 +19790,7 @@ struct DynamicQuantT : public flatbuffers::NativeTable {
103   bool activation_channel = false;
104   int64_t prefer_axis = 0;
105   bool transpose = false;
106+  std::vector<int32_t> prefer_axes{};
107 };
108 
109 struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
110@@ -19803,7 +19804,8 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
111     VT_DST_TYPE = 6,
112     VT_ACTIVATION_CHANNEL = 8,
113     VT_PREFER_AXIS = 10,
114-    VT_TRANSPOSE = 12
115+    VT_TRANSPOSE = 12,
116+    VT_PREFER_AXES = 14
117   };
118   bool symmetric() const {
119     return GetField<uint8_t>(VT_SYMMETRIC, 0) != 0;
120@@ -19835,6 +19837,12 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
121   bool mutate_transpose(bool _transpose) {
122     return SetField<uint8_t>(VT_TRANSPOSE, static_cast<uint8_t>(_transpose), 0);
123   }
124+  const flatbuffers::Vector<int32_t> *prefer_axes() const {
125+    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PREFER_AXES);
126+  }
127+  flatbuffers::Vector<int32_t> *mutable_prefer_axes() {
128+    return GetPointer<flatbuffers::Vector<int32_t> *>(VT_PREFER_AXES);
129+  }
130   bool Verify(flatbuffers::Verifier &verifier) const {
131     return VerifyTableStart(verifier) &&
132            VerifyField<uint8_t>(verifier, VT_SYMMETRIC) &&
133@@ -19842,6 +19850,8 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
134            VerifyField<uint8_t>(verifier, VT_ACTIVATION_CHANNEL) &&
135            VerifyField<int64_t>(verifier, VT_PREFER_AXIS) &&
136            VerifyField<uint8_t>(verifier, VT_TRANSPOSE) &&
137+           VerifyOffset(verifier, VT_PREFER_AXES) &&
138+           verifier.VerifyVector(prefer_axes()) &&
139            verifier.EndTable();
140   }
141   DynamicQuantT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
142@@ -19868,6 +19878,9 @@ struct DynamicQuantBuilder {
143   void add_transpose(bool transpose) {
144     fbb_.AddElement<uint8_t>(DynamicQuant::VT_TRANSPOSE, static_cast<uint8_t>(transpose), 0);
145   }
146+  void add_prefer_axes(flatbuffers::Offset<flatbuffers::Vector<int32_t>> prefer_axes) {
147+    fbb_.AddOffset(DynamicQuant::VT_PREFER_AXES, prefer_axes);
148+  }
149   explicit DynamicQuantBuilder(flatbuffers::FlatBufferBuilder &_fbb)
150         : fbb_(_fbb) {
151     start_ = fbb_.StartTable();
152@@ -19885,16 +19898,37 @@ inline flatbuffers::Offset<DynamicQuant> CreateDynamicQuant(
153     int64_t dst_type = 32LL,
154     bool activation_channel = false,
155     int64_t prefer_axis = 0,
156-    bool transpose = false) {
157+    bool transpose = false,
158+    flatbuffers::Offset<flatbuffers::Vector<int32_t>> prefer_axes = 0) {
159   DynamicQuantBuilder builder_(_fbb);
160   builder_.add_prefer_axis(prefer_axis);
161   builder_.add_dst_type(dst_type);
162+  builder_.add_prefer_axes(prefer_axes);
163   builder_.add_transpose(transpose);
164   builder_.add_activation_channel(activation_channel);
165   builder_.add_symmetric(symmetric);
166   return builder_.Finish();
167 }
168 
169+inline flatbuffers::Offset<DynamicQuant> CreateDynamicQuantDirect(
170+    flatbuffers::FlatBufferBuilder &_fbb,
171+    bool symmetric = false,
172+    int64_t dst_type = 32LL,
173+    bool activation_channel = false,
174+    int64_t prefer_axis = 0,
175+    bool transpose = false,
176+    const std::vector<int32_t> *prefer_axes = nullptr) {
177+  auto prefer_axes__ = prefer_axes ? _fbb.CreateVector<int32_t>(*prefer_axes) : 0;
178+  return mindspore::schema::CreateDynamicQuant(
179+      _fbb,
180+      symmetric,
181+      dst_type,
182+      activation_channel,
183+      prefer_axis,
184+      transpose,
185+      prefer_axes__);
186+}
187+
188 flatbuffers::Offset<DynamicQuant> CreateDynamicQuant(flatbuffers::FlatBufferBuilder &_fbb, const DynamicQuantT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
189 
190 struct LSTMGradDataT : public flatbuffers::NativeTable {
191@@ -26903,6 +26937,7 @@ inline void DynamicQuant::UnPackTo(DynamicQuantT *_o, const flatbuffers::resolve
192   { auto _e = activation_channel(); _o->activation_channel = _e; }
193   { auto _e = prefer_axis(); _o->prefer_axis = _e; }
194   { auto _e = transpose(); _o->transpose = _e; }
195+  { auto _e = prefer_axes(); if (_e) { _o->prefer_axes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->prefer_axes[_i] = _e->Get(_i); } } }
196 }
197 
198 inline flatbuffers::Offset<DynamicQuant> DynamicQuant::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DynamicQuantT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
199@@ -26918,13 +26953,15 @@ inline flatbuffers::Offset<DynamicQuant> CreateDynamicQuant(flatbuffers::FlatBuf
200   auto _activation_channel = _o->activation_channel;
201   auto _prefer_axis = _o->prefer_axis;
202   auto _transpose = _o->transpose;
203+  auto _prefer_axes = _o->prefer_axes.size() ? _fbb.CreateVector(_o->prefer_axes) : 0;
204   return mindspore::schema::CreateDynamicQuant(
205       _fbb,
206       _symmetric,
207       _dst_type,
208       _activation_channel,
209       _prefer_axis,
210-      _transpose);
211+      _transpose,
212+      _prefer_axes);
213 }
214 
215 inline LSTMGradDataT *LSTMGradData::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
216@@ -33509,10 +33546,11 @@ inline const flatbuffers::TypeTable *LSTMTypeTable() {
217     { flatbuffers::ET_LONG, 0, -1 },
218     { flatbuffers::ET_FLOAT, 0, -1 },
219     { flatbuffers::ET_FLOAT, 0, -1 },
220-    { flatbuffers::ET_FLOAT, 0, -1 }
221+    { flatbuffers::ET_FLOAT, 0, -1 },
222+    { flatbuffers::ET_LONG, 0, -1 }
223   };
224   static const flatbuffers::TypeTable tt = {
225-    flatbuffers::ST_TABLE, 9, type_codes, nullptr, nullptr, nullptr, nullptr
226+    flatbuffers::ST_TABLE, 10, type_codes, nullptr, nullptr, nullptr, nullptr
227   };
228   return &tt;
229 }
230@@ -34744,10 +34782,11 @@ inline const flatbuffers::TypeTable *DynamicQuantTypeTable() {
231     { flatbuffers::ET_LONG, 0, -1 },
232     { flatbuffers::ET_BOOL, 0, -1 },
233     { flatbuffers::ET_LONG, 0, -1 },
234-    { flatbuffers::ET_BOOL, 0, -1 }
235+    { flatbuffers::ET_BOOL, 0, -1 },
236+    { flatbuffers::ET_INT, 1, -1 }
237   };
238   static const flatbuffers::TypeTable tt = {
239-    flatbuffers::ST_TABLE, 5, type_codes, nullptr, nullptr, nullptr, nullptr
240+    flatbuffers::ST_TABLE, 6, type_codes, nullptr, nullptr, nullptr, nullptr
241   };
242   return &tt;
243 }
244diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
245index 920c0d31..153a21d0 100644
246--- a/mindspore/lite/schema/ops.fbs
247+++ b/mindspore/lite/schema/ops.fbs
248@@ -1250,6 +1250,7 @@ table DynamicQuant {
249     activation_channel: bool = false;
250     prefer_axis: long = 0;
251     transpose: bool = false;
252+    prefer_axes: [int];
253 }
254 
255 table LSTMGradData {
256diff --git a/mindspore/lite/schema/ops_generated.h b/mindspore/lite/schema/ops_generated.h
257index 8d387e9d..d2d89bff 100644
258--- a/mindspore/lite/schema/ops_generated.h
259+++ b/mindspore/lite/schema/ops_generated.h
260@@ -13118,7 +13118,8 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
261     VT_DST_TYPE = 6,
262     VT_ACTIVATION_CHANNEL = 8,
263     VT_PREFER_AXIS = 10,
264-    VT_TRANSPOSE = 12
265+    VT_TRANSPOSE = 12,
266+    VT_PREFER_AXES = 14
267   };
268   bool symmetric() const {
269     return GetField<uint8_t>(VT_SYMMETRIC, 0) != 0;
270@@ -13135,6 +13136,9 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
271   bool transpose() const {
272     return GetField<uint8_t>(VT_TRANSPOSE, 0) != 0;
273   }
274+  const flatbuffers::Vector<int32_t> *prefer_axes() const {
275+    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PREFER_AXES);
276+  }
277   bool Verify(flatbuffers::Verifier &verifier) const {
278     return VerifyTableStart(verifier) &&
279            VerifyField<uint8_t>(verifier, VT_SYMMETRIC) &&
280@@ -13142,6 +13146,8 @@ struct DynamicQuant FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
281            VerifyField<uint8_t>(verifier, VT_ACTIVATION_CHANNEL) &&
282            VerifyField<int64_t>(verifier, VT_PREFER_AXIS) &&
283            VerifyField<uint8_t>(verifier, VT_TRANSPOSE) &&
284+           VerifyOffset(verifier, VT_PREFER_AXES) &&
285+           verifier.VerifyVector(prefer_axes()) &&
286            verifier.EndTable();
287   }
288 };
289@@ -13165,6 +13171,9 @@ struct DynamicQuantBuilder {
290   void add_transpose(bool transpose) {
291     fbb_.AddElement<uint8_t>(DynamicQuant::VT_TRANSPOSE, static_cast<uint8_t>(transpose), 0);
292   }
293+  void add_prefer_axes(flatbuffers::Offset<flatbuffers::Vector<int32_t>> prefer_axes) {
294+    fbb_.AddOffset(DynamicQuant::VT_PREFER_AXES, prefer_axes);
295+  }
296   explicit DynamicQuantBuilder(flatbuffers::FlatBufferBuilder &_fbb)
297         : fbb_(_fbb) {
298     start_ = fbb_.StartTable();
299@@ -13182,16 +13191,37 @@ inline flatbuffers::Offset<DynamicQuant> CreateDynamicQuant(
300     int64_t dst_type = 32LL,
301     bool activation_channel = false,
302     int64_t prefer_axis = 0,
303-    bool transpose = false) {
304+    bool transpose = false,
305+    flatbuffers::Offset<flatbuffers::Vector<int32_t>> prefer_axes = 0) {
306   DynamicQuantBuilder builder_(_fbb);
307   builder_.add_prefer_axis(prefer_axis);
308   builder_.add_dst_type(dst_type);
309+  builder_.add_prefer_axes(prefer_axes);
310   builder_.add_transpose(transpose);
311   builder_.add_activation_channel(activation_channel);
312   builder_.add_symmetric(symmetric);
313   return builder_.Finish();
314 }
315 
316+inline flatbuffers::Offset<DynamicQuant> CreateDynamicQuantDirect(
317+    flatbuffers::FlatBufferBuilder &_fbb,
318+    bool symmetric = false,
319+    int64_t dst_type = 32LL,
320+    bool activation_channel = false,
321+    int64_t prefer_axis = 0,
322+    bool transpose = false,
323+    const std::vector<int32_t> *prefer_axes = nullptr) {
324+  auto prefer_axes__ = prefer_axes ? _fbb.CreateVector<int32_t>(*prefer_axes) : 0;
325+  return mindspore::schema::CreateDynamicQuant(
326+      _fbb,
327+      symmetric,
328+      dst_type,
329+      activation_channel,
330+      prefer_axis,
331+      transpose,
332+      prefer_axes__);
333+}
334+
335 struct LSTMGradData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
336   typedef LSTMGradDataBuilder Builder;
337   enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
338diff --git a/mindspore/lite/src/common/ops/ops_def.cc b/mindspore/lite/src/common/ops/ops_def.cc
339index baa2497a..1e973362 100644
340--- a/mindspore/lite/src/common/ops/ops_def.cc
341+++ b/mindspore/lite/src/common/ops/ops_def.cc
342@@ -1254,6 +1254,7 @@ OP_ATTR_WITH_VALUE(dst_type, long, 32)
343 OP_ATTR_WITH_VALUE(activation_channel, bool, false)
344 OP_ATTR_WITH_VALUE(prefer_axis, long, 0)
345 OP_ATTR_WITH_VALUE(transpose, bool, false)
346+OP_ATTR(prefer_axes, [int])
347 OP_SCHEMA_DEF_END(DynamicQuant)
348 
349 OP_SCHEMA_DEF(LSTMGradData)
350diff --git a/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc b/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc
351index 3566f082..8e393320 100644
352--- a/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc
353+++ b/mindspore/lite/src/common/ops/populate/dynamic_quant_populate.cc
354@@ -36,11 +36,27 @@ OpParameter *PopulateDynamicQuantParameter(const void *prim) {
355   memset(param, 0, sizeof(DynamicQuantParameter));
356 
357   param->op_parameter_.type_ = primitive->value_type();
358-  param->dst_type_ = value->dst_type();
359+  param->dst_type_ = static_cast<int>(value->dst_type());
360   param->symmetric_ = value->symmetric();
361-  param->activation_perchannel_ = value->activation_channel();
362-  param->prefer_axis_ = value->prefer_axis();
363-  param->transpose_ = value->transpose();
364+  auto prefer_axes = value->prefer_axes();
365+  if (prefer_axes != nullptr) {
366+    param->axis_num_ = static_cast<int>(prefer_axes->size());
367+    if (param->axis_num_ > MAX_SHAPE_SIZE) {
368+      MS_LOG(ERROR) << "The number of prefer axes for dynamic quant exceeds MAX_SHAPE_SIZE.";
369+      free(param);
370+      return nullptr;
371+    }
372+    for (int i = 0; i < param->axis_num_; ++i) {
373+      param->prefer_axes_[i] = prefer_axes->Get(i);
374+    }
375+    return reinterpret_cast<OpParameter *>(param);
376+  }
377+  auto activation_channel = value->activation_channel();
378+  if (!activation_channel) {
379+    return reinterpret_cast<OpParameter *>(param);
380+  }
381+  param->axis_num_ = 1;
382+  param->prefer_axes_[0] = static_cast<int>(value->prefer_axis());
383   return reinterpret_cast<OpParameter *>(param);
384 }
385 REG_POPULATE(PrimitiveType_DynamicQuant, PopulateDynamicQuantParameter, SCHEMA_CUR);
386diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.cc b/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.cc
387index e9404ef2..acc43c97 100644
388--- a/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.cc
389+++ b/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.cc
390@@ -14,14 +14,16 @@
391  * limitations under the License.
392  */
393 #include "src/litert/kernel/cpu/int8/dynamic_quant.h"
394+#include <set>
395 #include <vector>
396 #include <algorithm>
397 #include "src/litert/kernel_registry.h"
398 #include "schema/model_generated.h"
399 #include "include/errorcode.h"
400-#include "nnacl/dynamic_quant_parameter.h"
401 #include "nnacl/int8/dynamic_quant_int8.h"
402 #include "nnacl/int8/quant_dtype_cast_int8.h"
403+#include "nnacl/fp32/transpose_fp32.h"
404+#include "nnacl/int8/transpose_int8.h"
405 
406 using mindspore::kernel::KERNEL_ARCH;
407 using mindspore::lite::KernelRegistrar;
408@@ -44,19 +46,10 @@ int DynamicQuantCPUKernel::Prepare() {
409   CHECK_NULL_RETURN(in_tensor);
410   auto out_tensor = out_tensors_.front();
411   CHECK_NULL_RETURN(out_tensor);
412-  auto param = reinterpret_cast<DynamicQuantParameter *>(op_parameter_);
413-  CHECK_NULL_RETURN(param);
414-  src_dtype_ = in_tensor->data_type();
415-  dst_dtype_ = param->dst_type_;
416-  symmetric_ = param->symmetric_;
417-  activation_perchannel_ = param->activation_perchannel_;
418-  prefer_axis_ = param->prefer_axis_;
419-  transpose_ = param->transpose_;
420-  if (out_tensor->data_type() != dst_dtype_) {
421-    MS_LOG(ERROR) << "param data type and tensor data type do not match.";
422-    return RET_ERROR;
423-  }
424-
425+  param_ = reinterpret_cast<DynamicQuantParameter *>(op_parameter_);
426+  CHECK_NULL_RETURN(param_);
427+  MS_CHECK_TRUE_MSG(param_->dst_type_ == out_tensor->data_type(), lite::RET_ERROR,
428+                    "param data type and tensor data type do not match.");
429   if (!InferShapeDone()) {
430     return RET_OK;
431   }
432@@ -65,71 +58,86 @@ int DynamicQuantCPUKernel::Prepare() {
433 
434 int DynamicQuantCPUKernel::ReSize() {
435   auto in_tensor = in_tensors_.front();
436-  num_unit_ = static_cast<int>(in_tensor->ElementsNum());
437-  if (num_unit_ < kMinNums) {
438-    thread_n_num_ = 1;
439+  auto ele_num = static_cast<int>(in_tensor->ElementsNum());
440+  auto shape = in_tensor->shape();
441+  int segment_num = 1;
442+  if (param_->axis_num_ == 0) {
443+    segment_num = MSMIN(kBucketNums, ele_num / kMinNums);
444   } else {
445-    thread_n_num_ = MSMIN(thread_num_, num_unit_);
446-    // Limit for 8 thread
447-    thread_n_num_ = MSMIN(thread_n_num_, kBucketNums);
448+    std::set<int> prefer_axes;
449+    for (int i = 0; i < param_->axis_num_; ++i) {
450+      int axis = param_->prefer_axes_[i] < 0 ? param_->prefer_axes_[i] + static_cast<int>(shape.size())
451+                                             : param_->prefer_axes_[i];
452+      MS_CHECK_TRUE_MSG(axis >= 0 && axis < static_cast<int>(shape.size()), lite::RET_ERROR,
453+                        "The prefer axis is out of range.");
454+      if (prefer_axes.find(axis) != prefer_axes.end()) {
455+        continue;
456+      }
457+      segment_num *= shape[axis];
458+      (void)prefer_axes.insert(axis);
459+    }
460+    pre_perm_.resize(shape.size());
461+    post_perm_.resize(shape.size());
462+    int pre_point0 = 0;
463+    int pre_point1 = param_->axis_num_;
464+    for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
465+      if (prefer_axes.find(i) != prefer_axes.end()) {
466+        pre_perm_[pre_point0] = i;
467+        post_perm_[i] = pre_point0;
468+        ++pre_point0;
469+      } else {
470+        pre_perm_[pre_point1] = i;
471+        post_perm_[i] = pre_point1;
472+        ++pre_point1;
473+      }
474+    }
475   }
476-
477-  int min_max_array_size = 0;
478-  if (activation_perchannel_) {
479-    auto dims = in_tensor->shape();
480-    prefer_axis_ = (prefer_axis_ < 0) ? prefer_axis_ + dims.size() : prefer_axis_;
481-    channel_num_ = dims[prefer_axis_];
482-    MS_CHECK_GT(channel_num_, 0, RET_ERROR);
483-    scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
484-    MS_CHECK_TRUE_MSG(scale_ != nullptr, RET_ERROR, "Malloc scale_ failed.");
485-    zero_point_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
486-    MS_CHECK_TRUE_MSG(zero_point_ != nullptr, RET_ERROR, "Malloc zero_point_ failed.");
487-    size_t last_axis = dims.size() - 1;
488-    row_length_ = dims[last_axis];
489-    channel_length_ = num_unit_ / channel_num_;
490-    thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
491-    if (!transpose_ && channel_length_ > thread_n_stride_) {
492-      thread_n_num_ = 1;
493+  need_transpose_ = false;
494+  for (size_t i = 0; i < pre_perm_.size(); ++i) {
495+    if (pre_perm_[i] != static_cast<int>(i)) {
496+      need_transpose_ = true;
497     }
498-    min_max_array_size = channel_num_;
499-  } else {
500-    min_max_array_size = kBucketNums;
501   }
502-  real_min_ = reinterpret_cast<float *>(malloc(min_max_array_size * sizeof(float)));
503-  real_max_ = reinterpret_cast<float *>(malloc(min_max_array_size * sizeof(float)));
504-  if (real_min_ == nullptr || real_max_ == nullptr) {
505-    return RET_NULL_PTR;
506+  if (segment_num <= 0) {
507+    segment_num = 1;
508   }
509-  for (int i = 0; i < min_max_array_size; ++i) {
510+  real_min_.resize(segment_num);
511+  real_max_.resize(segment_num);
512+  scale_.resize(segment_num);
513+  zero_point_.resize(segment_num);
514+  for (int i = 0; i < segment_num; ++i) {
515     real_min_[i] = FLT_MAX;
516     real_max_[i] = -FLT_MAX;
517   }
518-  MS_CHECK_GT(thread_n_num_, 0, RET_ERROR);
519-  thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
520+  thread_num_ = MSMIN(segment_num, op_parameter_->thread_num_);
521+  unit_num_ = UP_DIV(ele_num, segment_num);
522+  unit_segment_num_ = UP_DIV(segment_num, thread_num_);
523   return RET_OK;
524 }
525 
526 int DynamicQuantCPUKernel::CalculateMinMax(int task_id) {
527-  int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
528-  if (num_unit_thread <= 0) {
529-    return RET_OK;
530-  }
531-  int thread_offset = task_id * thread_n_stride_;
532-  float *data = float32_ptr_ + thread_offset;
533-  if (activation_perchannel_) {
534-    if (transpose_) {
535-      MS_LOG(INFO) << "attribute transpose is true.";
536-      CalculateChannelColMinMax(data, num_unit_thread, real_min_, real_max_, row_length_);
537-    } else {
538-      int channel_offset = task_id * thread_n_stride_ / channel_length_;
539-      float *real_min = real_min_ + channel_offset;
540-      float *real_max = real_max_ + channel_offset;
541-      CalculateChannelRowMinMax(data, num_unit_thread, real_min, real_max, row_length_);
542+  int task_unit = unit_segment_num_ * unit_num_;
543+  int offset = task_id * task_unit;
544+  int ele_num = static_cast<int>(in_tensors_.front()->ElementsNum());
545+  int remain = ele_num - offset;
546+  if (task_unit <= remain) {
547+    for (int i = 0; i < unit_segment_num_; ++i) {
548+      CalculateMinMaxFp32(float32_ptr_ + offset + i * unit_num_, unit_num_, &real_min_[task_id * unit_segment_num_ + i],
549+                          &real_max_[task_id * unit_segment_num_ + i]);
550     }
551   } else {
552-    float *real_min = real_min_ + task_id;
553-    float *real_max = real_max_ + task_id;
554-    CalculateMinMaxFp32(data, num_unit_thread, real_min, real_max);
555+    int segment_num = remain / unit_num_;
556+    int remain_ele_num = remain - segment_num * unit_num_;
557+    for (int i = 0; i < segment_num; ++i) {
558+      CalculateMinMaxFp32(float32_ptr_ + offset + i * unit_num_, unit_num_, &real_min_[task_id * unit_segment_num_ + i],
559+                          &real_max_[task_id * unit_segment_num_ + i]);
560+    }
561+    if (remain_ele_num == 0) {
562+      return RET_OK;
563+    }
564+    CalculateMinMaxFp32(float32_ptr_ + offset + segment_num * unit_num_, remain_ele_num,
565+                        &real_min_[task_id * unit_segment_num_ + segment_num],
566+                        &real_max_[task_id * unit_segment_num_ + segment_num]);
567   }
568   return RET_OK;
569 }
570@@ -148,7 +156,7 @@ int CalculateMinMaxRun(void *cdata, int task_id, float, float) {
571 void DynamicQuantCPUKernel::CalculatePerlayerScaleZp() {
572   float real_min = FLT_MAX;
573   float real_max = -FLT_MAX;
574-  for (int i = 0; i < kBucketNums; i++) {
575+  for (size_t i = 0; i < real_min_.size(); ++i) {
576     real_min = (real_min_[i] < real_min) ? real_min_[i] : real_min;
577     real_max = (real_max_[i] > real_max) ? real_max_[i] : real_max;
578   }
579@@ -158,7 +166,7 @@ void DynamicQuantCPUKernel::CalculatePerlayerScaleZp() {
580   int zp = 0;
581   constexpr int kQSymmetricRange = 255;
582   constexpr int kQAsymmetricRange = 254;
583-  if (!symmetric_) {
584+  if (!param_->symmetric_) {
585     auto range = real_max - real_min;
586     if (range <= 0) {
587       range = kDefaultRange;
588@@ -175,12 +183,11 @@ void DynamicQuantCPUKernel::CalculatePerlayerScaleZp() {
589   quant_parm.bitNum = k8Bit;
590   quant_parm.inited = true;
591   this->out_tensors_.front()->set_quant_params({quant_parm});
592-  return;
593 }
594 
595 void DynamicQuantCPUKernel::CalculatePerChannelScaleZp() {
596   std::vector<lite::LiteQuantParam> quant_params;
597-  for (int i = 0; i < channel_num_; ++i) {
598+  for (size_t i = 0; i < real_min_.size(); ++i) {
599     float real_min = real_min_[i];
600     float real_max = real_max_[i];
601 
602@@ -189,7 +196,7 @@ void DynamicQuantCPUKernel::CalculatePerChannelScaleZp() {
603     int zp = 0;
604     constexpr int kQSymmetricRange = 255;
605     constexpr int kQAsymmetricRange = 254;
606-    if (!symmetric_) {
607+    if (!param_->symmetric_) {
608       auto range = real_max - real_min;
609       if (range <= 0) {
610         range = kDefaultRange;
611@@ -208,40 +215,34 @@ void DynamicQuantCPUKernel::CalculatePerChannelScaleZp() {
612     quant_params.push_back(quant_parm);
613   }
614   this->out_tensors_.front()->set_quant_params(quant_params);
615-  return;
616 }
617+
618 int DynamicQuantCPUKernel::QuantData(int task_id) {
619-  int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
620-  MS_CHECK_GT(num_unit_thread, 0, RET_ERROR);
621-  TypeId data_type = out_tensors_.front()->data_type();
622-  if (data_type != TypeId::kNumberTypeInt8) {
623-    MS_LOG(ERROR) << "Data type not supported:" << data_type;
624-    return RET_PARAM_INVALID;
625-  }
626-  int thread_offset = task_id * thread_n_stride_;
627-  int ret;
628-  if (activation_perchannel_) {
629-    MS_CHECK_EQ(out_tensors_.front()->quant_params().size(), static_cast<size_t>(channel_num_), RET_ERROR);
630-    for (int i = 0; i < channel_num_; i++) {
631-      auto quant_arg = out_tensors_.front()->quant_params().at(i);
632-      scale_[i] = quant_arg.scale;
633-      zero_point_[i] = quant_arg.zeroPoint;
634-    }
635-    if (transpose_) {
636-      ret = DoChannelColFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, scale_, zero_point_,
637-                                   num_unit_thread, row_length_, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
638-    } else {
639-      ret = DoChannelRowFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, scale_, zero_point_,
640-                                   num_unit_thread, row_length_, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
641-    }
642-  } else {
643+  int task_unit = unit_segment_num_ * unit_num_;
644+  int offset = task_id * task_unit;
645+  int ele_num = static_cast<int>(in_tensors_.front()->ElementsNum());
646+  int remain = ele_num - offset;
647+  task_unit = MSMIN(task_unit, remain);
648+  if (param_->axis_num_ == 0) {  // per-tensor
649     auto quant_arg = out_tensors_.front()->quant_params().front();
650-    ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
651-                               quant_arg.zeroPoint, num_unit_thread, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
652+    auto ret = DoQuantizeFp32ToInt8(float32_ptr_ + offset, int8_ptr_ + offset, quant_arg.scale, quant_arg.zeroPoint,
653+                                    task_unit, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
654+    if (ret != RET_OK) {
655+      MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
656+      return RET_ERROR;
657+    }
658+    return RET_OK;
659   }
660-  if (ret != RET_OK) {
661-    MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
662-    return RET_ERROR;
663+  int segment_num = task_unit / unit_num_;
664+  for (int i = 0; i < segment_num; ++i) {
665+    auto quant_arg = out_tensors_.front()->quant_params()[task_id * unit_segment_num_ + i];
666+    auto ret =
667+      DoQuantizeFp32ToInt8(float32_ptr_ + offset + i * unit_num_, int8_ptr_ + offset + i * unit_num_, quant_arg.scale,
668+                           quant_arg.zeroPoint, unit_num_, (int32_t)INT8_MIN, (int32_t)INT8_MAX);
669+    if (ret != RET_OK) {
670+      MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
671+      return RET_ERROR;
672+    }
673   }
674   return RET_OK;
675 }
676@@ -257,26 +258,110 @@ int QuantDataRun(void *cdata, int task_id, float, float) {
677   return RET_OK;
678 }
679 
680+int DynamicQuantCPUKernel::MallocTmpBuffer() {
681+  auto in_size = in_tensors_.front()->Size();
682+  auto out_size = out_tensors_.front()->Size();
683+  if (ms_context_ != nullptr && ms_context_->allocator != nullptr) {
684+    int8_ptr_ = static_cast<int8_t *>(ms_context_->allocator->Malloc(in_size + out_size));
685+  } else {
686+    int8_ptr_ = static_cast<int8_t *>(malloc(in_size + out_size));
687+  }
688+  MS_CHECK_TRUE_MSG(int8_ptr_ != nullptr, lite::RET_NULL_PTR, "DynamicQuant malloc tmp buffer failed.");
689+  float32_ptr_ = reinterpret_cast<float *>(int8_ptr_ + out_size);
690+  return lite::RET_OK;
691+}
692+
693+void DynamicQuantCPUKernel::FreeTmpBuffer() {
694+  if (need_transpose_) {
695+    if (int8_ptr_ != nullptr) {
696+      if (ms_context_ != nullptr && ms_context_->allocator != nullptr) {
697+        ms_context_->allocator->Free(int8_ptr_);
698+      } else {
699+        free(int8_ptr_);
700+      }
701+    }
702+  }
703+  int8_ptr_ = nullptr;
704+  float32_ptr_ = nullptr;
705+}
706+
707 int DynamicQuantCPUKernel::Run() {
708-  int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data());
709-  float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data());
710-  CHECK_NULL_RETURN(int8_ptr_);
711-  CHECK_NULL_RETURN(float32_ptr_);
712-  auto ret = ParallelLaunch(this->ms_context_, CalculateMinMaxRun, this, thread_n_num_);
713+  std::vector<int> transpose_shape;
714+  if (need_transpose_) {
715+    auto shape = in_tensors_.front()->shape();
716+    transpose_shape.resize(shape.size());
717+    for (size_t i = 0; i < shape.size(); ++i) {
718+      transpose_shape[i] = shape[pre_perm_[i]];
719+    }
720+    if (MallocTmpBuffer() != lite::RET_OK) {
721+      MS_LOG(ERROR) << "DynamicQuant MallocTmpBuffer failed.";
722+      return lite::RET_NULL_PTR;
723+    }
724+    std::vector<int> strides(shape.size(), 1);
725+    std::vector<int> out_strides(shape.size(), 1);
726+    for (int i = static_cast<int>(shape.size()) - C2NUM; i >= 0; i--) {
727+      strides[i] = shape[i + 1] * strides[i + 1];
728+      out_strides[i] = transpose_shape[i + 1] * out_strides[i + 1];
729+    }
730+    if (shape.size() <= C6NUM) {
731+      (void)DoTransposeFp32(in_tensors_.front()->data(), float32_ptr_, transpose_shape.data(), pre_perm_.data(),
732+                            strides.data(), out_strides.data(), in_tensors_.front()->Size(), shape.size());
733+    } else {
734+      TransposeDimsFp32(in_tensors_.front()->data(), float32_ptr_, transpose_shape.data(), pre_perm_.data(),
735+                        strides.data(), out_strides.data(), shape.size(), 0, 1);
736+    }
737+  } else {
738+    int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->data());
739+    float32_ptr_ = reinterpret_cast<float *>(in_tensors_[0]->data());
740+  }
741+  if (int8_ptr_ == nullptr || float32_ptr_ == nullptr) {
742+    FreeTmpBuffer();
743+    MS_LOG(ERROR) << "DynamicQuant input or output data is nullptr.";
744+    return lite::RET_NULL_PTR;
745+  }
746+  auto ret = ParallelLaunch(this->ms_context_, CalculateMinMaxRun, this, thread_num_);
747   if (ret != RET_OK) {
748+    FreeTmpBuffer();
749     MS_LOG(ERROR) << "Run error error_code[" << ret << "]";
750     return RET_ERROR;
751   }
752-  if (activation_perchannel_) {
753+  if (param_->axis_num_ != 0) {
754     CalculatePerChannelScaleZp();
755   } else {
756     CalculatePerlayerScaleZp();
757   }
758-  ret = ParallelLaunch(this->ms_context_, QuantDataRun, this, thread_n_num_);
759+  ret = ParallelLaunch(this->ms_context_, QuantDataRun, this, thread_num_);
760   if (ret != RET_OK) {
761+    FreeTmpBuffer();
762     MS_LOG(ERROR) << "Run error error_code[" << ret << "]";
763     return RET_ERROR;
764   }
765+  if (need_transpose_) {
766+    auto out_shape = out_tensors_.front()->shape();
767+    TransposeParameter trans_parameter;
768+    (void)memset(&trans_parameter, 0, sizeof(TransposeParameter));
769+    trans_parameter.op_parameter_.thread_num_ = 1;
770+    trans_parameter.num_axes_ = static_cast<int>(out_shape.size());
771+    trans_parameter.data_num_ = out_tensors_[0]->ElementsNum();
772+    trans_parameter.perm_size_ = post_perm_.size();
773+    int last_index = static_cast<int>(out_shape.size()) - 1;
774+    trans_parameter.perm_[last_index] = post_perm_[last_index];
775+    trans_parameter.strides_[last_index] = 1;
776+    trans_parameter.out_strides_[last_index] = 1;
777+    for (int i = last_index - 1; i >= 0; i--) {
778+      trans_parameter.perm_[i] = post_perm_[i];
779+      trans_parameter.strides_[i] = transpose_shape[i + 1] * trans_parameter.strides_[i + 1];
780+      trans_parameter.out_strides_[i] = out_shape[i + 1] * trans_parameter.out_strides_[i + 1];
781+    }
782+    if (out_shape.size() <= C6NUM) {
783+      (void)DoTransposeInt8(int8_ptr_, reinterpret_cast<int8_t *>(out_tensors_[0]->data()), out_shape.data(),
784+                            &trans_parameter);
785+    } else {
786+      TransposeDimsInt8(int8_ptr_, reinterpret_cast<int8_t *>(out_tensors_[0]->data()), out_shape.data(),
787+                        &trans_parameter, 0, 1);
788+    }
789+  }
790+  FreeTmpBuffer();
791   return RET_OK;
792 }
793 
794diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.h b/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.h
795index ca84f088..023f1fab 100644
796--- a/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.h
797+++ b/mindspore/lite/src/litert/kernel/cpu/int8/dynamic_quant.h
798@@ -21,31 +21,15 @@
799 #include <cfloat>
800 #include <map>
801 #include "src/litert/lite_kernel.h"
802+#include "nnacl/dynamic_quant_parameter.h"
803 
804 namespace mindspore::kernel {
805 class DynamicQuantCPUKernel : public LiteKernel {
806  public:
807   DynamicQuantCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
808                         const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
809-      : LiteKernel(parameter, inputs, outputs, ctx), thread_num_(ctx->thread_num_) {}
810-  ~DynamicQuantCPUKernel() override {
811-    if (real_min_ != nullptr) {
812-      free(real_min_);
813-      real_min_ = nullptr;
814-    }
815-    if (real_max_ != nullptr) {
816-      free(real_max_);
817-      real_max_ = nullptr;
818-    }
819-    if (scale_ != nullptr) {
820-      free(scale_);
821-      scale_ = nullptr;
822-    }
823-    if (zero_point_ != nullptr) {
824-      free(zero_point_);
825-      zero_point_ = nullptr;
826-    }
827-  };
828+      : LiteKernel(parameter, inputs, outputs, ctx) {}
829+  ~DynamicQuantCPUKernel() override = default;
830 
831   int Prepare() override;
832   int ReSize() override;
833@@ -57,28 +41,21 @@ class DynamicQuantCPUKernel : public LiteKernel {
834  private:
835   void CalculatePerlayerScaleZp();
836   void CalculatePerChannelScaleZp();
837-
838- private:
839-  int thread_num_;
840-  int thread_n_num_{0};
841-  int thread_n_stride_{0};
842-  int num_unit_{0};
843-  int8_t *int8_ptr_ = nullptr;
844-  float *float32_ptr_ = nullptr;
845-  float *real_min_ = nullptr;
846-  float *real_max_ = nullptr;
847-  float *scale_ = nullptr;
848-  int32_t *zero_point_ = nullptr;
849-
850-  int32_t src_dtype_{0};
851-  int32_t dst_dtype_{0};
852-  bool symmetric_ = false;
853-  bool activation_perchannel_ = false;
854-  bool transpose_ = false;
855-  int32_t prefer_axis_{-1};
856-  int32_t channel_num_{0};
857-  int32_t channel_length_{0};
858-  int32_t row_length_{0};
859+  int MallocTmpBuffer();
860+  void FreeTmpBuffer();
861+
862+  DynamicQuantParameter *param_{nullptr};
863+  std::vector<float> real_min_;
864+  std::vector<float> real_max_;
865+  std::vector<float> scale_;
866+  std::vector<float> zero_point_;
867+  std::vector<int> pre_perm_;
868+  std::vector<int> post_perm_;
869+  int8_t *int8_ptr_{nullptr};
870+  float *float32_ptr_{nullptr};
871+  int unit_num_{0};
872+  int unit_segment_num_{0};
873+  bool need_transpose_{false};
874 };
875 }  // namespace mindspore::kernel
876 
877diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc b/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc
878index adae37aa..bab1f730 100644
879--- a/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc
880+++ b/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.cc
881@@ -54,12 +54,12 @@ void MatmulDynamicBaseInt8CPUKernel::FreeQuantParam() {
882 }
883 
884 int MatmulDynamicBaseInt8CPUKernel::MallocQuantParam() {
885-  quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
886+  quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulDynamicQuantParameter)));
887   if (quant_param_ == nullptr) {
888     MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!";
889     return RET_ERROR;
890   }
891-  memset(quant_param_, 0, sizeof(MatmulQuantParameter));
892+  (void)memset(quant_param_, 0, sizeof(MatmulDynamicQuantParameter));
893   return RET_OK;
894 }
895 
896@@ -80,9 +80,16 @@ int MatmulDynamicBaseInt8CPUKernel::InitFilterQuantParam() {
897     MS_LOG(ERROR) << weight_tensor->tensor_name() << " dims < 2.";
898     return RET_ERROR;
899   }
900-  int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
901   filter_per_channel_ = (weight_quant_params.size() > 1);
902-  auto channel_num = filter_per_channel_ ? col : 1;
903+  filter_per_batch_channel_ = false;
904+  int channel_num = 1;
905+  if (filter_per_channel_) {
906+    channel_num = param_->col_;
907+    if (weight_quant_params.size() > static_cast<size_t>(channel_num)) {
908+      filter_per_batch_channel_ = true;
909+      channel_num = in_tensors_.at(kWeightIndex)->ElementsNum() / param_->deep_;
910+    }
911+  }
912   if (static_cast<int>(weight_quant_params.size()) != channel_num) {
913     MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size()
914                   << " != channel_num:" << channel_num;
915@@ -90,10 +97,10 @@ int MatmulDynamicBaseInt8CPUKernel::InitFilterQuantParam() {
916   }
917   quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num * sizeof(float)));
918   CHECK_NULL_RETURN(quant_param_->filter_scale_);
919-  memset(quant_param_->filter_scale_, 0, sizeof(channel_num));
920+  (void)memset(quant_param_->filter_scale_, 0, channel_num * sizeof(float));
921   quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num * sizeof(int32_t)));
922   CHECK_NULL_RETURN(quant_param_->filter_zp_);
923-  memset(quant_param_->filter_zp_, 0, sizeof(channel_num));
924+  (void)memset(quant_param_->filter_zp_, 0, channel_num * sizeof(int32_t));
925 
926   for (int i = 0; i < channel_num; i++) {
927     quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
928@@ -143,7 +150,15 @@ int MatmulDynamicBaseInt8CPUKernel::InitInputQuantParam(std::vector<float> *scal
929     return RET_ERROR;
930   }
931   input_per_channel_ = (in_quant_params.size() > 1);
932-  auto channel_num = input_per_channel_ ? param_->row_ : 1;
933+  input_per_batch_channel_ = false;
934+  int channel_num = 1;
935+  if (input_per_channel_) {
936+    channel_num = param_->row_;
937+    if (in_quant_params.size() > static_cast<size_t>(channel_num)) {
938+      input_per_batch_channel_ = true;
939+      channel_num = in_tensors_.at(kInputIndex)->ElementsNum() / param_->deep_;
940+    }
941+  }
942   if (static_cast<int>(in_quant_params.size()) != channel_num) {
943     MS_LOG(ERROR) << in_tensors_.at(kInputIndex)->tensor_name() << " quant params size:" << in_quant_params.size()
944                   << " != channel_num:" << channel_num;
945@@ -199,7 +214,7 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixABuffer() {
946     return lite::RET_NULL_PTR;
947   }
948   input_sums_ = reinterpret_cast<int *>(pack_a_ptr_ + pack_a_size);
949-  memset(pack_a_ptr_, 0, pack_a_size + sum_a_size);
950+  (void)memset(pack_a_ptr_, 0, pack_a_size + sum_a_size);
951   return RET_OK;
952 }
953 
954@@ -240,8 +255,8 @@ int MatmulDynamicBaseInt8CPUKernel::InitMatrixBBuffer() {
955     FreeTmpBuffer();
956     return RET_ERROR;
957   }
958-  memset(pack_b_ptr_, 0, b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
959-  memset(weight_sums_, 0, b_batch_ * param_->col_align_ * sizeof(int));
960+  (void)memset(pack_b_ptr_, 0, b_batch_ * param_->col_align_ * param_->deep_align_ * sizeof(int8_t));
961+  (void)memset(weight_sums_, 0, b_batch_ * param_->col_align_ * sizeof(int));
962   return RET_OK;
963 }
964 
965@@ -258,7 +273,7 @@ int MatmulDynamicBaseInt8CPUKernel::CopyBias() {
966       FreeTmpBuffer();
967       return RET_MEMORY_FAILED;
968     }
969-    memcpy(bias_ptr_, bias_tensor->data(), bias_tensor->Size());
970+    (void)memcpy(bias_ptr_, bias_tensor->data(), bias_tensor->Size());
971   } else {
972     bias_ptr_ = nullptr;
973   }
974@@ -352,6 +367,8 @@ int MatmulDynamicBaseInt8CPUKernel::ReSize() {
975 int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector<int> &a_shape_const,
976                                                         const std::vector<int> &b_shape_const, MatMulParameter *params,
977                                                         std::vector<int> *a_offsets, std::vector<int> *b_offsets) {
978+  CHECK_NULL_RETURN(a_offsets);
979+  CHECK_NULL_RETURN(b_offsets);
980   std::vector<int> a_shape = a_shape_const;
981   if (a_shape.size() < kNCHWDimNumber) {
982     size_t add_nums = kNCHWDimNumber - a_shape.size();
983@@ -370,8 +387,8 @@ int MatmulDynamicBaseInt8CPUKernel::InitBroadcastParams(const std::vector<int> &
984   int batch_sizes[MAX_SHAPE_SIZE] = {0};
985   int a_batch_sizes[MAX_SHAPE_SIZE] = {0};
986   int b_batch_sizes[MAX_SHAPE_SIZE] = {0};
987-  for (int i = a_shape.size() - kCHWDimNumber; i >= 0; --i) {
988-    if (static_cast<int>(a_shape.size() - kCHWDimNumber) == i) {
989+  for (int i = static_cast<int>(a_shape.size()) - kCHWDimNumber; i >= 0; --i) {
990+    if (static_cast<int>(a_shape.size()) - kCHWDimNumber == i) {
991       batch_sizes[i] = std::max(a_shape[i], b_shape[i]);
992       a_batch_sizes[i] = a_shape[i];
993       b_batch_sizes[i] = b_shape[i];
994diff --git a/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h b/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h
995index 3fc20d80..858affc8 100644
996--- a/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h
997+++ b/mindspore/lite/src/litert/kernel/cpu/int8/matmul_dynamic_base_int8.h
998@@ -58,6 +58,8 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
999   int b_batch_ = 1;
1000   std::vector<int> a_offset_;
1001   std::vector<int> b_offset_;
1002+  int a_quant_offset_ = 0;
1003+  int b_quant_offset_ = 0;
1004   typedef void (*PackFunc)(const int8_t *src, int8_t *dst, int row, int col);
1005   virtual void InitParameter() = 0;
1006   int TransferA();
1007@@ -69,14 +71,15 @@ class MatmulDynamicBaseInt8CPUKernel : public LiteKernel {
1008   int InitMatrixABuffer();
1009   void FreeMatrixABuffer();
1010 
1011- protected:
1012   MatMulParameter *param_ = nullptr;
1013   MatmulDynamicQuantParameter *quant_param_ = nullptr;
1014   int8_t *pack_a_ptr_ = nullptr;
1015   int8_t *pack_b_ptr_ = nullptr;
1016 
1017   bool input_per_channel_ = false;
1018-  bool filter_per_channel_ = true;
1019+  bool input_per_batch_channel_ = false;
1020+  bool filter_per_channel_ = false;
1021+  bool filter_per_batch_channel_ = false;
1022   int8_t *batch_input_ptr_ = nullptr;
1023   int8_t *batch_weight_ptr_ = nullptr;
1024   int8_t *batch_a_ptr_ = nullptr;
1025diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
1026index 721a1a8c..03113eaa 100644
1027--- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
1028+++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc
1029@@ -102,7 +102,7 @@ int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &grap
1030   bool symmetric = activation_channel ? true : false;
1031   primitive->set_symmetric(symmetric);
1032   primitive->set_activation_channel(activation_channel);
1033-  if (activation_channel && SetPreferAxis(cnode, index, primitive) != RET_OK) {
1034+  if (activation_channel && SetPreferAxes(cnode, index, primitive) != RET_OK) {
1035     MS_LOG(ERROR) << "Set prefer axis failed, " << cnode->fullname_with_scope();
1036     return RET_ERROR;
1037   }
1038@@ -127,18 +127,25 @@ int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &grap
1039   return RET_OK;
1040 }
1041 
1042-int InsertQuantNodeManager::SetPreferAxis(const CNodePtr &cnode, size_t index,
1043+int InsertQuantNodeManager::SetPreferAxes(const CNodePtr &cnode, size_t index,
1044                                           const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive) {
1045   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
1046   if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul) {
1047     auto matmul_prim = api::MakeShared<ops::MatMul>(primitive);
1048     CHECK_NULL_RETURN(matmul_prim);
1049+    auto shape = opt::GetAnfNodeOutputShape(cnode->input(index), 0);
1050+    std::vector<int> prefer_axes;
1051+    for (int i = 0; i < static_cast<int>(shape.size()) - C2NUM; ++i) {
1052+      prefer_axes.push_back(i);
1053+    }
1054     // For MatMul A
1055     if (index == kInputIndex + kPrimOffset) {
1056       if (matmul_prim->GetAttr(ops::kTransposeA) != nullptr && matmul_prim->get_transpose_a()) {
1057+        prefer_axes.push_back(kLastFisrtIndex);
1058         dynamic_primitive->set_prefer_axis(kLastFisrtIndex);
1059         dynamic_primitive->set_transpose(true);
1060       } else {
1061+        prefer_axes.push_back(kLastSecondIndex);
1062         dynamic_primitive->set_prefer_axis(kLastSecondIndex);
1063         dynamic_primitive->set_transpose(false);
1064       }
1065@@ -146,13 +153,16 @@ int InsertQuantNodeManager::SetPreferAxis(const CNodePtr &cnode, size_t index,
1066     // For MatMul B
1067     if (index == kWeightIndex + kPrimOffset) {
1068       if (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b()) {
1069+        prefer_axes.push_back(kLastSecondIndex);
1070         dynamic_primitive->set_prefer_axis(kLastSecondIndex);
1071         dynamic_primitive->set_transpose(true);
1072       } else {
1073+        prefer_axes.push_back(kLastFisrtIndex);
1074         dynamic_primitive->set_prefer_axis(kLastFisrtIndex);
1075         dynamic_primitive->set_transpose(false);
1076       }
1077     }
1078+    dynamic_primitive->set_prefer_axes(prefer_axes);
1079   } else {
1080     MS_LOG(WARNING) << "cnode don't need prefer axis, cnode name: " << cnode->fullname_with_scope();
1081   }
1082@@ -167,13 +177,17 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
1083     return RET_ERROR;
1084   }
1085   auto input = cnode->input(kInputIndex + kPrimOffset);
1086+  auto weight = cnode->input(kWeightIndex + kPrimOffset);
1087+  if (activation_channel && (input->isa<mindspore::CNode>() || IsGraphInput(input)) &&
1088+      (weight->isa<mindspore::CNode>() || IsGraphInput(weight))) {
1089+    return RET_NOT_SUPPORT;
1090+  }
1091   if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
1092     auto ret = InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimOffset, activation_channel);
1093     if (ret != RET_OK) {
1094       MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
1095     }
1096   }
1097-  auto weight = cnode->input(kWeightIndex + kPrimOffset);
1098   if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
1099     auto ret = InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimOffset, activation_channel);
1100     if (ret != RET_OK) {
1101@@ -218,6 +232,9 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
1102       continue;
1103     }
1104     ret = NewDynamicQuantNode(graph, cnode, activation_channel);
1105+    if (ret == RET_NOT_SUPPORT) {
1106+      continue;
1107+    }
1108     if (ret != RET_OK) {
1109       MS_LOG(ERROR) << "node:" << op_name << " new dynamic quant node failed.";
1110       return ret;
1111@@ -684,7 +701,7 @@ int InsertQuantNodeManager::InsertQuantDtypeCastFlyNode(const FuncGraphPtr &func
1112 
1113 int InsertQuantNodeManager::CalculateScaleZPNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
1114                                                  size_t input_index, ParameterPtr *scales_node, ParameterPtr *zps_node,
1115-                                                 TypeId src_dtype, TypeId dst_dtype, int axis) {
1116+                                                 TypeId dst_dtype, int axis) {
1117   CHECK_NULL_RETURN(scales_node);
1118   CHECK_NULL_RETURN(zps_node);
1119   auto input_node = cnode->input(input_index);
1120@@ -785,7 +802,7 @@ int InsertQuantNodeManager::InsertAscendAntiQuantNode(const FuncGraphPtr &func_g
1121   CHECK_NULL_RETURN(cast_cnode);
1122   ParameterPtr scales_node;
1123   ParameterPtr zps_node;
1124-  auto ret = CalculateScaleZPNode(func_graph, cnode, input_index, &scales_node, &zps_node, src_dtype, dst_dtype, axis);
1125+  auto ret = CalculateScaleZPNode(func_graph, cnode, input_index, &scales_node, &zps_node, dst_dtype, axis);
1126   if (ret != RET_OK) {
1127     MS_LOG(ERROR) << "Fail to Remove node: " << input_node->fullname_with_scope() << " quant param";
1128     return RET_ERROR;
1129diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
1130index a46e8c68..6f328485 100644
1131--- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
1132+++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h
1133@@ -75,13 +75,12 @@ class InsertQuantNodeManager {
1134   int MarkDynamicQuantize(const CNodePtr &cnode);
1135 
1136   int CalculateScaleZPNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
1137-                           ParameterPtr *scales_node, ParameterPtr *zps_node, TypeId src_dtype, TypeId dst_dtype,
1138-                           int axis);
1139+                           ParameterPtr *scales_node, ParameterPtr *zps_node, TypeId dst_dtype, int axis);
1140 
1141   int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index,
1142                                   bool activation_channel = true);
1143 
1144-  int SetPreferAxis(const CNodePtr &cnode, size_t index, const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive);
1145+  int SetPreferAxes(const CNodePtr &cnode, size_t index, const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive);
1146 
1147   int SetCastNodeAbstract(const CNodePtr &cnode, const AnfNodePtr &input_node, const CNodePtr &cast_cnode);
1148 
1149-- 
11502.25.1
1151
1152