Lines Matching full:param
210 + /// \param[in] device_id The device id.
220 + /// \param[in] performance_mode The performance mode.
230 + /// \param[in] priority The priority.
240 + /// \param[in] is_fp16 Enable float16 inference or not.
250 + /// \param[in] extension array.
278 +/// \param[out] num Number of NNRT device description.
285 +/// \param[in] descs NNRT device description array.
286 +/// \param[in] index Element index.
293 +/// \param[out] num Number of NNRT device description.
300 +/// \param[in] desc NNRT device description array.
305 +/// \param[in] desc pointer to the NNRT device description instance.
312 +/// \param[in] desc pointer to the NNRT device description instance.
319 +/// \param[in] desc pointer to the NNRT device description instance.
326 +/// \param[in] name NNRt device name.
333 +/// \param[in] name NNRt device type.
340 +/// \param[in] device_info Device info object handle.
341 +/// \param[in] device_id NNRT device id.
346 +/// \param[in] device_info Device info object handle.
353 +/// \param[in] device_info Device info object handle.
354 +/// \param[in] device_id NNRT performance mode.
359 +/// \param[in] device_info Device info object handle.
366 +/// \param[in] device_info Device info object handle.
367 +/// \param[in] device_id NNRT priority.
372 +/// \param[in] device_info Device info object handle.
379 +/// \param[in] device_info Device info object handle.
380 +/// \param[in] name The content of key as a C string.
381 +/// \param[in] value The pointer to the value, which is a byte array.
382 +/// \param[in] value_size The size of the value, which is a byte array.
413 +/// \param[in] train_cfg TrainCfg object handle.
418 +/// \param[in] train_cfg TrainCfg object handle.
419 +/// \param[in] num The num of loss_name.
426 +/// \param[in] train_cfg TrainCfg object handle.
427 +/// \param[in] loss_name define part of the name that identify a loss kernel.
428 +/// \param[in] num The num of loss_name.
433 +/// \param[in] train_cfg TrainCfg object handle.
440 +/// \param[in] train_cfg TrainCfg object handle.
441 +/// \param[in] level The optimization level of train_cfg.
446 +/// \param[in] model Model object handle.
447 +/// \param[in] model_data Define the buffer read from a model file.
448 +/// \param[in] data_size Define bytes number of model file buffer.
449 +/// \param[in] model_type Define The type of model file.
450 +/// \param[in] model_context Define the context used to store options during execution.
451 +/// \param[in] train_cfg Define the config used by training.
460 +/// \param[in] model Model object handle.
461 +/// \param[in] model_path Define the model path.
462 +/// \param[in] model_type Define The type of model file.
463 +/// \param[in] model_context Define the context used to store options during execution.
464 +/// \param[in] train_cfg Define the config used by training.
473 +/// \param[in] model Model object handle.
474 +/// \param[in] before CallBack before predict.
475 +/// \param[in] after CallBack after predict.
483 +/// \param[in] learning_rate to set.
495 +/// \param[in] model Model object handle.
502 +/// \param[in] new_weights A vector new weights.
509 +/// \param[in] model Model object handle.
516 +/// \param[in] model Model object handle.
517 +/// \param[in] train True means model runs in Train Mode, otherwise Eval Mode.
524 +/// \param[in] model Model object handle.
525 +/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable.
526 +/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration.
527 +/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration.
535 +/// \param[in] model The model data.
536 +/// \param[in] model_type The model file type.
537 +/// \param[in] model_file The exported model file.
538 +/// \param[in] quantization_type The quantification type.
539 +/// \param[in] export_inference_only Whether to export a reasoning only model.
540 +/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as
542 +/// \param[in] num The number of output_tensor_name.
551 +/// \param[in] model The model data.
552 +/// \param[in] model_type The model file type.
553 +/// \param[in] model_data The exported model buffer.
554 +/// \param[in] data_size The exported model buffer size.
555 +/// \param[in] quantization_type The quantification type.
556 +/// \param[in] export_inference_only Whether to export a reasoning only model.
557 +/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as
559 +/// \param[in] num The number of output_tensor_name.
568 +/// \param[in] model The model data.
569 +/// \param[in] model_type The model file type.
570 +/// \param[in] weight_file The path of exported weight file.
571 +/// \param[in] is_inference Whether to export weights from a reasoning model. Currently, only support this is `true`.
572 +/// \param[in] enable_fp16 Float-weight is whether to be saved in float16 format.
573 +/// \param[in] changeable_weights_name The set the name of these weight tensors, whose shape is changeable.
574 +/// \param[in] num The number of changeable_weights_name.
598 /// \param[in] data A pointer to the data of the tensor.
607 +/// \param[in] tensor Tensor object handle.
608 +/// \param[in] data A pointer to the user data buffer.
609 +/// \param[in] data the byte size of the user data buffer.
616 /// \param[in] tensor Tensor object handle.
698 +/// \param[out] num Number of NNRT device description.
705 +/// \param[in] descs NNRT device description array.
706 +/// \param[in] index Element index.
713 +/// \param[in] desc NNRT device description array.
718 +/// \param[in] desc pointer to the NNRT device description instance.
725 +/// \param[in] desc pointer to the NNRT device description instance.
732 +/// \param[in] desc pointer to the NNRT device description instance.
739 +/// \param[in] name NNRt device name.
746 +/// \param[in] name NNRt device type.
753 +/// \param[in] device_info Device info object handle.
754 +/// \param[in] device_id NNRT device id.
759 +/// \param[in] device_info Device info object handle.
766 +/// \param[in] device_info Device info object handle.
767 +/// \param[in] device_id NNRT performance mode.
772 +/// \param[in] device_info Device info object handle.
779 +/// \param[in] device_info Device info object handle.
780 +/// \param[in] device_id NNRT priority.
785 +/// \param[in] device_info Device info object handle.
792 +/// \param[in] device_info Device info object handle.
793 +/// \param[in] name The content of key as a C string.
794 +/// \param[in] value The pointer to the value, which is a byte array.
795 +/// \param[in] value_size The size of the value, which is a byte array.
824 +/// \param[in] tensor Tensor object handle.
825 +/// \param[in] data A pointer to the user data buffer.
826 +/// \param[in] data the byte size of the user data buffer.
1065 +int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
1067 + if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) {
1070 + if (param->op_parameter.thread_num_ == 0) {
1073 + int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_);
1075 + int end = MSMIN(begin + unit_per_thread, param->num_unit);
1080 + const float *update_data = update_fp32 + i * param->unit_size;
1084 + SIMD_RUN_NO_SCALAR(ScatterNDMaxFp32, j, update_data, param->unit_size, output_data);
1085 + for (; j < param->unit_size; j++) {
1098 int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
1101 +int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
2347 return reinterpret_cast<OpParameter *>(param);
2351 + auto *param = static_cast<CustomIsInfParameter *>(malloc(sizeof(CustomIsInfParameter)));
2352 + if (param == nullptr) {
2356 + memset(param, 0, sizeof(CustomIsInfParameter));
2357 + param->op_parameter_.type_ = PrimType_Inner_CustomIsInf;
2358 + return reinterpret_cast<OpParameter *>(param);
2362 + auto *param = static_cast<CustomTensorScatterMaxParameter *>(malloc(sizeof(CustomTensorScatterMaxParameter)));
2363 + if (param == nullptr) {
2367 + memset(param, 0, sizeof(CustomTensorScatterMaxParameter));
2368 + param->op_parameter_.type_ = PrimType_Inner_CustomTensorScatterMax;
2369 + return reinterpret_cast<OpParameter *>(param);
2373 + auto *param = static_cast<CustomMaskedFillParameter *>(malloc(sizeof(CustomMaskedFillParameter)));
2374 + if (param == nullptr) {
2378 + memset(param, 0, sizeof(CustomMaskedFillParameter));
2379 + param->op_parameter_.type_ = PrimType_Inner_CustomMaskedFill;
2380 + return reinterpret_cast<OpParameter *>(param);
2391 + auto *param = static_cast<CustomParameter *>(malloc(sizeof(CustomParameter)));
2392 + if (param == nullptr) {
2396 + memset(param, 0, sizeof(CustomParameter));
2397 + param->op_parameter_.type_ = PrimType_Inner_ThirdPartyModel;
2399 + param->attr_data[0] = static_cast<char *>(const_cast<void *>(prim));
2400 + return reinterpret_cast<OpParameter *>(param);
2444 MS_LOG(ERROR) << "param is nullptr.";
2455 MS_LOG(ERROR) << "param is nullptr.";
2466 MS_LOG(ERROR) << "param is nullptr.";
2477 MS_LOG(ERROR) << "param is nullptr.";
2499 MS_LOG(ERROR) << "param is nullptr.";
2521 MS_LOG(ERROR) << "param is nullptr.";
2532 MS_LOG(ERROR) << "param is nullptr.";
2543 MS_LOG(ERROR) << "param is nullptr.";
2586 MS_LOG(ERROR) << "param is nullptr.";
2597 MS_LOG(ERROR) << "param is nullptr.";
2614 MS_LOG(ERROR) << "param is nullptr.";
2625 MS_LOG(ERROR) << "param is nullptr.";
2642 MS_LOG(ERROR) << "param is nullptr.";
2653 MS_LOG(ERROR) << "param is nullptr.";
2674 MS_LOG(ERROR) << "param is nullptr.";
2698 MS_LOG(ERROR) << "param is nullptr.";
2712 MS_LOG(ERROR) << "param is nullptr.";
3166 - MS_LOG(ERROR) << "param is nullptr.";
3399 - MS_LOG(ERROR) << "param is nullptr.";
3404 - MS_LOG(ERROR) << "param is invalid.";
3422 - MS_LOG(ERROR) << "param is nullptr.";
3427 - MS_LOG(ERROR) << "param is invalid.";
3445 - MS_LOG(ERROR) << "param is nullptr.";
3476 - MS_LOG(ERROR) << "param is nullptr.";
3517 - MS_LOG(ERROR) << "param is nullptr.";
3526 - MS_LOG(ERROR) << "param is nullptr.";
3535 - MS_LOG(ERROR) << "param is nullptr.";
3544 - MS_LOG(ERROR) << "param is nullptr.";
3885 MS_LOG(ERROR) << "param is nullptr.";
3910 MS_LOG(ERROR) << "param is nullptr.";
3921 MS_LOG(ERROR) << "param is nullptr.";
3932 MS_LOG(ERROR) << "param is nullptr.";
3941 MS_LOG(ERROR) << "param is nullptr.";
3950 MS_LOG(ERROR) << "param is nullptr.";
3959 MS_LOG(ERROR) << "param is nullptr.";
3968 MS_LOG(ERROR) << "param is nullptr.";
3977 MS_LOG(ERROR) << "param is nullptr.";
3988 MS_LOG(ERROR) << "param is nullptr.";
3998 + MS_LOG(ERROR) << "param is nullptr.";
4016 MS_LOG(ERROR) << "param is nullptr.";
4025 MS_LOG(ERROR) << "param is nullptr.";
4034 MS_LOG(ERROR) << "param is nullptr.";
4043 MS_LOG(ERROR) << "param is nullptr.";
6858 -int NetTrain::SetNr(std::function<int(NetTrainFlags *)> param) {
6859 - nr_cb_ = param;
7425 - static int SetNr(std::function<int(NetTrainFlags *)> param);
7638 +int NetTrainBase::SetNr(std::function<int(NetTrainFlags *)> param) {
7639 + nr_cb_ = param;
8169 + static int SetNr(std::function<int(NetTrainFlags *)> param);
9407 + MS_LOG(ERROR) << "Only support fixed shapes in third party param";
9564 +int ThirdPartyParamParser::Parse(const ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param) {
9565 + MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
9567 + auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes));
9569 + MS_LOG(ERROR) << "Parse input shapes of third party param failed";
9573 + ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes));
9575 + MS_LOG(ERROR) << "Parse input dtypes of third party param failed";
9579 + auto input_shape_num = param->input_shapes.size();
9580 + auto input_dtype_num = param->input_dtypes.size();
9587 + ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats));
9589 + MS_LOG(ERROR) << "Parse input formats of third party param failed";
9594 + ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names));
9596 + MS_LOG(ERROR) << "Parse input names of third party param failed";
9600 + ret = DoParseShape(param_string.output_shapes, &(param->output_shapes));
9602 + MS_LOG(ERROR) << "Parse output shaped of third party param failed";
9606 + ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes));
9608 + MS_LOG(ERROR) << "Parse output dtypes of third party param failed";
9612 + auto output_shape_num = param->output_shapes.size();
9613 + auto output_dtype_num = param->output_dtypes.size();
9620 + ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats));
9622 + MS_LOG(ERROR) << "Parse output formats of third party param failed";
9627 + ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names));
9629 + MS_LOG(ERROR) << "Parse output names of third party param failed";
9633 + ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters));
9635 + MS_LOG(ERROR) << "Parse extended parameter of third party param failed";
9678 + static int Parse(const lite::ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param);
9706 MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
9710 + ¶m->thirdPartyModelParam);
9712 + MS_LOG(ERROR) << "Parse third party param failed.";
9715 ret = InitExtendedIntegrationInfo(param, *config_parser);
9720 int CheckFmkType(const std::shared_ptr<ConverterPara> ¶m) {
9721 if (param != nullptr) {
9725 - if (std::find(valid_values.begin(), valid_values.end(), param->fmk_type) == valid_values.end()) {
9728 - << ", but got " << param->fmk_type;
9731 - if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) {
9740 + if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) {
9743 + << ", but got " << param->fmk_type;
9746 + if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) {
9757 converter_parameters.save_type = param->save_type;
9758 converter_parameters.model_file = param->model_file;
9759 converter_parameters.weight_file = param->weight_file;
9760 + converter_parameters.attrs.emplace("config_file", param->config_file);
9764 @@ -447,11 +448,13 @@ STATUS ConverterFuncGraph::Optimize(const std::shared_ptr<ConverterPara> ¶m,
9769 - status = funcgraph_transform.Transform(func_graph, param);
9773 + if (param->fmk_type != converter::FmkType::kFmkTypeThirdParty) {
9775 + status = funcgraph_transform.Transform(func_graph, param);
9782 status = UnifyFuncGraphOutputFormat(param, func_graph);
9844 @@ -76,11 +76,55 @@ int QuantTransform(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGrap
9878 int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> ¶m) {
9879 MS_ASSERT(param != nullptr);
9882 + if (param->fmk_type == converter::kFmkTypeThirdParty) {
9886 + auto ret = FillGraphOutputShape(graph_defT_, param->thirdPartyModelParam.output_shapes);
9893 + FillGraphInputAndOutputFormats(graph_defT_, *param);
9989 + MS_LOG(ERROR) << "Parse third party model param failed.";