/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/operator/composite/list_operation.h"

#include <memory>
#include <string>

#include "abstract/param_validator.h"
#include "frontend/optimizer/opt.h"
#include "include/common/pybind_api/api_register.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "pipeline/jit/fallback.h"
#include "utils/ms_context.h"

namespace mindspore {
// namespace to support composite operators definition
namespace prim {
// Builds the graph for `list.append(x)` in functional (non in-place) form.
//
// args_list must hold exactly two abstracts: the list (must be an AbstractList) and the element to append.
// The generated graph takes two parameters and outputs a new MakeList node whose inputs are every original element,
// read with ListGetItem in index order, followed by the appended element.
FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  constexpr size_t kAppendArgCount = 2;
  abstract::CheckArgsSize("ListAppend", args_list, kAppendArgCount);

  const abstract::AbstractListPtr source_list = dyn_cast<abstract::AbstractList>(args_list[0]);
  MS_EXCEPTION_IF_NULL(source_list);
  const size_t source_size = source_list->size();

  FuncGraphPtr graph = std::make_shared<FuncGraph>();
  graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  graph->debug_info()->set_name("append");
  const AnfNodePtr list_param = graph->add_parameter();
  const AnfNodePtr item_param = graph->add_parameter();

  // MakeList inputs: the primitive first, then each copied element, then the new item.
  std::vector<AnfNodePtr> make_list_inputs{NewValueNode(prim::kPrimMakeList)};
  for (size_t idx = 0; idx < source_size; ++idx) {
    make_list_inputs.push_back(
      graph->NewCNode({NewValueNode(prim::kPrimListGetItem), list_param, NewValueNode(SizeToLong(idx))}));
  }
  make_list_inputs.push_back(item_param);

  const auto appended_list = graph->NewCNode(make_list_inputs);
  graph->set_output(appended_list);
  return graph;
}

// Maps a Python-style `list.insert` index onto an insert position in [0, len].
// Negative indices count from the end, and indices beyond either end clamp to that end.
// Examples for len == 3: 5 -> 3, 1 -> 1, -1 -> 2, -7 -> 0.
static int64_t NormalizeListInsertPosition(int64_t index, int64_t len) {
  if (index < 0) {
    // Callers pass a list size as len (len >= 0), so this addition cannot overflow.
    const int64_t from_end = index + len;
    return from_end < 0 ? 0 : from_end;
  }
  return index > len ? len : index;
}

// Builds the graph for `list.insert(index, obj)`.
//
// args_list holds, in order: the list (must be an AbstractList), the insert index, and the object to insert.
// The generated graph takes the same three inputs as parameters.
// - In-place path: taken when fallback list support is enabled, the list has a backing Python object, and this is not
//   inside a Vmap rule. The call is lowered to the side-effecting ListInplaceInsert primitive, and the index is not
//   inspected here.
// - Functional path: the index must be a constant int64. A new list is built from the original elements, with `obj`
//   placed at the normalized position.
// Raises TypeError on the functional path if the index is not an integer constant.
FuncGraphPtr ListInsert::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  constexpr size_t list_insert_args_size = 3;
  abstract::CheckArgsSize("ListInsert", args_list, list_insert_args_size);
  const AbstractBasePtr &list_arg = args_list[0];
  const AbstractBasePtr &index_arg = args_list[1];

  abstract::AbstractListPtr arg0_list = dyn_cast<abstract::AbstractList>(list_arg);
  MS_EXCEPTION_IF_NULL(arg0_list);
  const int64_t len = SizeToLong(arg0_list->size());
  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("insert");
  AnfNodePtr arg0_node = ret->add_parameter();
  AnfNodePtr insert_index_node = ret->add_parameter();
  AnfNodePtr insert_obj_node = ret->add_parameter();
  // List inplace operation is not supported when:
  // 1. The python object of the list is not found.
  // 2. The list operation is generated by Vmap; this will be enabled after a Vmap function for list inplace
  //    operations is provided.
  if (fallback::EnableFallbackList() && fallback::HasPySeqObject(arg0_list) &&
      scope_name().find("VmapRule") == std::string::npos) {
    MS_LOG(DEBUG) << "Enable inplace operation, convert list insert to InplaceListInsert ops.";
    AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplaceInsert), arg0_node, insert_index_node,
                                          insert_obj_node};
    auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
    list_inplace_node->set_has_side_effect_node(true);
    ret->set_output(list_inplace_node);
    ret->set_has_side_effect_node(true);
    return ret;
  }

  // The functional path unrolls the list at compile time, so the index must be a known integer constant.
  MS_EXCEPTION_IF_NULL(index_arg);
  auto index_value_ptr = index_arg->BuildValue();
  MS_EXCEPTION_IF_NULL(index_value_ptr);
  if (!utils::isa<int64_t>(index_value_ptr)) {
    MS_EXCEPTION(TypeError) << "Integer argument expected, but got " << index_value_ptr->type_name()
                            << " type value: " << index_value_ptr->ToString();
  }
  const int64_t insert_position = NormalizeListInsertPosition(GetValue<int64_t>(index_value_ptr), len);

  // MakeList inputs: elements [0, pos), then the new object, then elements [pos, len).
  std::vector<AnfNodePtr> elems{NewValueNode(prim::kPrimMakeList)};
  for (int64_t i = 0; i < insert_position; ++i) {
    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)}));
  }
  elems.push_back(insert_obj_node);
  for (int64_t i = insert_position; i < len; ++i) {
    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)}));
  }
  ret->set_output(ret->NewCNode(elems));
  return ret;
}

// Builds the graph for `list.pop(index)`. The graph output is a (list, popped_element) tuple.
//
// args_list holds the list (must be an AbstractList) and the pop index. The index must be a constant int64 in
// [-len, len); negative values count from the end.
// - In-place path: taken when fallback list support is enabled, the list has a backing Python object, and this is not
//   inside a Vmap rule. The list part of the output comes from the side-effecting ListInplacePop primitive.
// - Functional path: a new list without the popped element is built.
// Raises TypeError if the index is not an integer constant, and IndexError if it is out of range.
FuncGraphPtr ListPop::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  constexpr size_t list_pop_args_size = 2;
  abstract::CheckArgsSize("ListPop", args_list, list_pop_args_size);
  abstract::AbstractListPtr list_input = dyn_cast<abstract::AbstractList>(args_list[0]);
  const AbstractBasePtr &pop_index = args_list[1];
  MS_EXCEPTION_IF_NULL(list_input);
  MS_EXCEPTION_IF_NULL(pop_index);
  const int64_t len = SizeToLong(list_input->size());
  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("pop");
  AnfNodePtr arg0_node = ret->add_parameter();
  AnfNodePtr pop_index_node = ret->add_parameter();

  // The index is validated up front because both lowerings below need the concrete pop position.
  auto pop_index_value = pop_index->BuildValue();
  // Guard before the TypeError path below dereferences the value to describe it.
  MS_EXCEPTION_IF_NULL(pop_index_value);
  if (!utils::isa<int64_t>(pop_index_value)) {
    MS_EXCEPTION(TypeError) << "Integer argument expected, but got " << pop_index_value->type_name()
                            << " type value: " << pop_index_value->ToString();
  }
  const int64_t index_value = GetValue<int64_t>(pop_index_value);
  if (index_value >= len || index_value < -len) {
    MS_EXCEPTION(IndexError) << "The pop index out of range.";
  }
  const int64_t pop_position = (index_value >= 0) ? index_value : (len + index_value);

  // List inplace operation is not supported when:
  // 1. The python object of the list is not found.
  // 2. The list operation is generated by Vmap; this will be enabled after a Vmap function for list inplace
  //    operations is provided.
  if (fallback::EnableFallbackList() && fallback::HasPySeqObject(list_input) &&
      scope_name().find("VmapRule") == std::string::npos) {
    MS_LOG(DEBUG) << "Enable inplace operation, convert list pop to InplaceListPop ops.";
    AnfNodePtrList list_inplace_inputs = {NewValueNode(prim::kPrimListInplacePop), arg0_node, pop_index_node};
    auto list_inplace_node = ret->NewCNodeInOrder(list_inplace_inputs);
    list_inplace_node->set_has_side_effect_node(true);
    // NOTE(review): pop_node is created with NewCNode (not in order) and reads arg0_node, presumably yielding the
    // element from the pre-pop list. Confirm this ordering is guaranteed relative to the in-place pop.
    auto pop_node = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(pop_position)});
    auto out_node = ret->NewCNode({NewValueNode(prim::kPrimMakeTuple), list_inplace_node, pop_node});
    ret->set_output(out_node);
    ret->set_has_side_effect_node(true);
    return ret;
  }

  // MakeList inputs: every element except the one at pop_position, in original order.
  std::vector<AnfNodePtr> elems{NewValueNode(prim::kPrimMakeList)};
  for (int64_t i = 0; i < pop_position; ++i) {
    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)}));
  }
  auto pop_node = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(pop_position)});
  for (int64_t i = pop_position + 1; i < len; ++i) {
    elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(i)}));
  }

  auto new_list = ret->NewCNode(elems);
  auto out = ret->NewCNode({NewValueNode(prim::kPrimMakeTuple), new_list, pop_node});
  ret->set_output(out);
  return ret;
}

// Builds the graph for `list.clear()`.
//
// args_list must hold exactly one abstract (the list). The generated graph takes that list as its single parameter
// but never reads it; its output is always a constant empty ValueList.
FuncGraphPtr ListClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  abstract::CheckArgsSize("ListClear", args_list, 1);

  FuncGraphPtr graph = std::make_shared<FuncGraph>();
  graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  graph->debug_info()->set_name("clear");
  // The parameter exists only so the graph's arity matches the call site.
  (void)graph->add_parameter();

  const auto empty_value_list = std::make_shared<ValueList>(std::vector<ValuePtr>{});
  graph->set_output(NewValueNode(empty_value_list));
  return graph;
}

// Builds the graph for `list.extend(iterable)`.
//
// args_list holds the list being extended (must be an AbstractList) and the extension operand.
// - In-place path: taken when fallback list support is enabled, the list has a backing Python object, and this is not
//   inside a Vmap rule. The call becomes the side-effecting ListInplaceExtend primitive.
// - Functional path: both operands are unpacked element by element via AddNodeToElems into one new MakeList node.
FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  constexpr size_t list_extend_args_size = 2;
  abstract::CheckArgsSize("ListExtend", args_list, list_extend_args_size);

  FuncGraphPtr graph = std::make_shared<FuncGraph>();
  graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  graph->debug_info()->set_name("extend");

  const AbstractBasePtr &extension_abs = args_list[1];
  auto target_list = dyn_cast<abstract::AbstractList>(args_list[0]);
  MS_EXCEPTION_IF_NULL(target_list);

  // The in-place form is unavailable when the list has no backing Python object, or when the call comes from a Vmap
  // rule (pending a Vmap function for list inplace operations).
  const bool use_inplace = fallback::EnableFallbackList() && fallback::HasPySeqObject(target_list) &&
                           scope_name().find("VmapRule") == std::string::npos;
  if (use_inplace) {
    MS_LOG(DEBUG) << "Enable inplace operation, convert list extend to InplaceListExtend ops.";
    const AnfNodePtr target_param = graph->add_parameter();
    const AnfNodePtr extension_param = graph->add_parameter();
    AnfNodePtrList inplace_inputs = {NewValueNode(prim::kPrimListInplaceExtend), target_param, extension_param};
    auto inplace_node = graph->NewCNodeInOrder(inplace_inputs);
    inplace_node->set_has_side_effect_node(true);
    graph->set_output(inplace_node);
    graph->set_has_side_effect_node(true);
    return graph;
  }

  // Each AddNodeToElems call adds one graph parameter and appends that operand's elements.
  std::vector<AnfNodePtr> make_list_inputs{NewValueNode(prim::kPrimMakeList)};
  AddNodeToElems(target_list, graph, &make_list_inputs);
  AddNodeToElems(extension_abs, graph, &make_list_inputs);

  graph->set_output(graph->NewCNode(make_list_inputs));
  return graph;
}

void ListExtend::AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems) {
  AnfNodePtr arg_node = ret->add_parameter();
  if (arg->isa<abstract::AbstractList>()) {
    auto arg_list = dyn_cast<abstract::AbstractList>(arg);
    if (arg_list->dynamic_len()) {
      MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length list.";
    }
    int64_t len = SizeToLong(arg_list->size());
    for (int64_t i = 0; i < len; ++i) {
      auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
      elems->push_back(value);
    }
    return;
  }
  if (arg->isa<abstract::AbstractTuple>()) {
    auto arg_tuple = dyn_cast<abstract::AbstractTuple>(arg);
    if (arg_tuple->dynamic_len()) {
      MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length tuple.";
    }
    int64_t len = SizeToLong(arg_tuple->size());
    for (int64_t i = 0; i < len; ++i) {
      auto value = ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), arg_node, NewValueNode(i)});
      elems->push_back(value);
    }
    return;
  }
  if (arg->isa<abstract::AbstractTensor>()) {
    auto abs_tensor = dyn_cast<abstract::AbstractTensor>(arg);
    auto shape_ptr = abs_tensor->BuildShape();
    MS_EXCEPTION_IF_NULL(shape_ptr);
    auto tensor_shape = shape_ptr->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(tensor_shape);
    auto shape = tensor_shape->shape();
    if (shape.empty()) {
      MS_LOG(EXCEPTION) << "ListExtend does not support scalar tensor.";
    }
    if (shape[0] < 0) {
      MS_LOG(EXCEPTION) << "ListExtend does not support the tensor whose shapes has an uncertain 0th dimension.";
    }
    int64_t len = shape[0];

    std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
    ValuePtr op = prim::GetPythonOps("getitem", module_name);
    for (int64_t i = 0; i < len; ++i) {
      auto value = ret->NewCNode({NewValueNode(op), arg_node, NewValueNode(i)});
      elems->push_back(value);
    }
    return;
  }
  MS_LOG(EXCEPTION) << "ListExtend supports list, tuple and Tensor, but got " << arg->ToString();
}

// Builds the graph for `list.reverse()`.
//
// args_list must hold exactly one abstract, the list (must be an AbstractList).
// - In-place path: taken when fallback list support is enabled, the list has a backing Python object, and this is not
//   inside a Vmap rule. The call becomes the side-effecting ListInplaceReverse primitive.
// - Functional path: the output is a new list of the input's elements, read with ListGetItem from the last index
//   down to 0.
FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  abstract::CheckArgsSize("ListReverse", args_list, 1);
  abstract::AbstractListPtr source_list = dyn_cast<abstract::AbstractList>(args_list[0]);
  MS_EXCEPTION_IF_NULL(source_list);
  const int64_t count = SizeToLong(source_list->size());

  FuncGraphPtr graph = std::make_shared<FuncGraph>();
  graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  graph->debug_info()->set_name("reverse");
  const AnfNodePtr list_param = graph->add_parameter();

  // The in-place form is unavailable when the list has no backing Python object, or when the call comes from a Vmap
  // rule (pending a Vmap function for list inplace operations).
  const bool use_inplace = fallback::EnableFallbackList() && fallback::HasPySeqObject(source_list) &&
                           scope_name().find("VmapRule") == std::string::npos;
  if (use_inplace) {
    MS_LOG(DEBUG) << "Enable inplace operation, convert list reverse to InplaceListReverse ops.";
    AnfNodePtrList inplace_inputs = {NewValueNode(prim::kPrimListInplaceReverse), list_param};
    auto inplace_node = graph->NewCNodeInOrder(inplace_inputs);
    inplace_node->set_has_side_effect_node(true);
    graph->set_output(inplace_node);
    graph->set_has_side_effect_node(true);
    return graph;
  }

  std::vector<AnfNodePtr> make_list_inputs{NewValueNode(prim::kPrimMakeList)};
  for (int64_t offset = 1; offset <= count; ++offset) {
    // offset 1 selects the last element, offset `count` selects element 0.
    const int64_t src_index = count - offset;
    make_list_inputs.push_back(
      graph->NewCNode({NewValueNode(prim::kPrimListGetItem), list_param, NewValueNode(src_index)}));
  }

  graph->set_output(graph->NewCNode(make_list_inputs));
  return graph;
}
}  // namespace prim
}  // namespace mindspore
