/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_EXTENDRT_EXECUTION_PLAN_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_EXECUTION_PLAN_H_

#include <memory>
#include <utility>
#include <vector>
#include <unordered_map>

#include "infer/execution_plan.h"
#include "src/executor/sub_graph_kernel.h"

namespace mindspore::infer {
class ExecutionPlan : public abstract::ExecutionPlan {
 public:
  ExecutionPlan() = default;
  ~ExecutionPlan() override;

  std::vector<std::shared_ptr<abstract::ExecutionFlow>> GetExecutionFLows() override { return {}; }
  std::vector<abstract::Kernel *> GetKernels() { return kernels_; }

  void SetExecutionFlows(std::vector<std::shared_ptr<abstract::ExecutionFlow>> execution_flows) override {}
  void SetKernels(std::vector<abstract::Kernel *> kernels) { this->kernels_ = std::move(kernels); }

  void AddExecutionFlow(std::shared_ptr<abstract::ExecutionFlow> execution_flow) override {}
  void AddKernel(abstract::Kernel *kernel) override { this->kernels_.emplace_back(kernel); }

  FuncGraphPtr GetFuncGraph() override { return func_graph_; }

  void SetFuncGraph(FuncGraphPtr func_graph) override { func_graph_ = func_graph; }

  std::vector<abstract::Tensor *> GetInputs() override { return inputs_; }

  void SetInputs(const std::vector<abstract::Tensor *> &inputs) override { inputs_ = inputs; }

  std::vector<abstract::Tensor *> GetOutputs() override { return outputs_; }

  void SetOutputs(const std::vector<abstract::Tensor *> &outputs) override { outputs_ = outputs; }

  std::shared_ptr<abstract::Context> GetContext() override { return context_; }

  void SetContext(std::shared_ptr<abstract::Context> context) override { context_ = context; }

  const abstract::KernelCallBack &GetKernelBeforeCallBack() override { return before_; }

  void SetKernelBeforeCallBack(const abstract::KernelCallBack &callback) override { before_ = callback; }

  const abstract::KernelCallBack &GetKernelAfterCallBack() override { return after_; }

  void SetKernelAfterCallBack(const abstract::KernelCallBack &callback) override { after_ = callback; }

  void SetInputsMap(std::unordered_map<abstract::Tensor *, abstract::Tensor *> *input_isolate_map) {
    FreeInputIsolateMap();
    input_isolate_map_ = input_isolate_map;
    own_input_isolate_map_ = false;
  }

  std::unordered_map<abstract::Tensor *, abstract::Tensor *> *GetInputsMap() {
    if (input_isolate_map_ == nullptr) {
      input_isolate_map_ = new std::unordered_map<abstract::Tensor *, abstract::Tensor *>();
      this->own_input_isolate_map_ = true;
    }
    return input_isolate_map_;
  }

  void SetOutputsMap(std::unordered_map<abstract::Tensor *, abstract::Tensor *> *output_isolate_map) {
    FreeOutputIsolateMap();
    output_isolate_map_ = output_isolate_map;
    own_output_isolate_map_ = false;
  }

  std::unordered_map<abstract::Tensor *, abstract::Tensor *> *GetOutputsMap() {
    if (output_isolate_map_ == nullptr) {
      output_isolate_map_ = new std::unordered_map<abstract::Tensor *, abstract::Tensor *>();
      this->own_output_isolate_map_ = true;
    }
    return output_isolate_map_;
  }

  std::vector<abstract::Kernel *> ToKernelList() override;

  bool PrepareKernels();

 private:
  bool MallocTensorData(abstract::Kernel *kernel);
  void FreeInputIsolateMap() {
    if (this->own_input_isolate_map_) {
      delete this->input_isolate_map_;
      this->own_input_isolate_map_ = false;
    }
  }
  void FreeOutputIsolateMap() {
    if (this->own_output_isolate_map_) {
      delete this->output_isolate_map_;
      this->own_output_isolate_map_ = false;
    }
  }

 private:
  std::vector<abstract::Kernel *> kernels_;
  FuncGraphPtr func_graph_;
  std::vector<abstract::Tensor *> inputs_;
  std::vector<abstract::Tensor *> outputs_;
  std::shared_ptr<abstract::Context> context_;
  abstract::KernelCallBack before_;
  abstract::KernelCallBack after_;
  std::unordered_map<abstract::Tensor *, abstract::Tensor *> *input_isolate_map_ = nullptr;
  std::unordered_map<abstract::Tensor *, abstract::Tensor *> *output_isolate_map_ = nullptr;
  bool own_input_isolate_map_{false};
  bool own_output_isolate_map_{false};
  std::vector<abstract::Kernel *> kernel_list_;
};
}  // namespace mindspore::infer

#endif  // MINDSPORE_LITE_SRC_EXTENDRT_EXECUTION_PLAN_H_
