/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/converter/parser/tf/tf_argmin_parser.h"
#include <cstdint>
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "ops/fusion/arg_min_fusion.h"

namespace mindspore {
namespace lite {
// Converts a TensorFlow "ArgMin" NodeDef into an ArgMinFusion primitive.
//
// TF ArgMin takes two data inputs: the tensor to reduce (index 0) and the
// "dimension" tensor (index 1). Only input 0 is forwarded to the converted
// graph; the axis is read here from the "value" attr of the node feeding
// input 1 and stored on the primitive.
//
// @param tf_op        the TF ArgMin node being converted.
// @param tf_node_map  TF graph nodes keyed by name; used to find the axis node.
// @param inputs       receives the data input (input 0) via AddOpInput.
// @param output_size  set to 1.
// @return the converted primitive, or nullptr on any failure.
PrimitiveCPtr TFArgMinParser::Parse(const tensorflow::NodeDef &tf_op,
                                    const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                    std::vector<std::string> *inputs, int *output_size) {
  MS_CHECK_TRUE_RET(inputs != nullptr && output_size != nullptr, nullptr);
  auto prim = std::make_unique<ops::ArgMinFusion>();
  MS_CHECK_TRUE_RET(prim != nullptr, nullptr);

  // Read the axis from input 1 rather than the last input: a NodeDef may list
  // "^name" control-dependency inputs after its data inputs, and an empty input
  // list would make "input_size() - 1" out of range.
  constexpr int kAxisInputIndex = 1;
  if (tf_op.input_size() <= kAxisInputIndex) {
    MS_LOG(ERROR) << "ArgMin node " << tf_op.name() << " should have at least 2 inputs, but got "
                  << tf_op.input_size();
    return nullptr;
  }
  const auto &tf_axis_input_name = tf_op.input(kAxisInputIndex);
  auto axis_iter = tf_node_map.find(tf_axis_input_name);
  if (axis_iter == tf_node_map.end()) {
    MS_LOG(ERROR) << "not find axis node.";
    return nullptr;
  }
  const auto *axis_node = axis_iter->second;
  MS_CHECK_TRUE_RET(axis_node != nullptr, nullptr);

  tensorflow::AttrValue attr_value;
  if (!TensorFlowUtils::FindAttrValue(*axis_node, "value", &attr_value)) {
    MS_LOG(ERROR) << "The attr value should be specified.";
    return nullptr;
  }
  // The dimension tensor may be int32 (values in int_val) or int64 (values in
  // int64_val); indexing an empty repeated field is out of range, so check first.
  const auto &axis_tensor = attr_value.tensor();
  int64_t axis = 0;
  if (axis_tensor.int_val_size() > 0) {
    axis = axis_tensor.int_val(0);
  } else if (axis_tensor.int64_val_size() > 0) {
    axis = axis_tensor.int64_val(0);
  } else {
    MS_LOG(ERROR) << "ArgMin node " << tf_op.name() << " has an axis tensor without int_val/int64_val data.";
    return nullptr;
  }
  prim->set_axis(axis);
  prim->set_out_max_value(false);
  prim->set_top_k(1);

  *output_size = 1;
  if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
    MS_LOG(ERROR) << "add op input failed";
    return nullptr;
  }

  return prim->GetPrim();
}
// Namespace-scope registrar object for TF nodes of op type "ArgMin"; its constructor runs during
// static initialization. Presumably it hands the heap-allocated parser to the TF node parser
// registry, which takes ownership of it (NOTE(review): confirm in tf_node_parser_registry.h).
TFNodeRegistrar g_tfArgMinParser("ArgMin", new TFArgMinParser());
}  // namespace lite
}  // namespace mindspore
