/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/kernel/cpu/int8/leaky_relu_int8.h"
#include "src/litert/kernel_registry.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LeakyRelu;

namespace mindspore::kernel {
namespace {
/**
 * Task callback handed to ParallelLaunch by LeakyReluInt8CPUKernel::Run().
 *
 * @param cdata   The LeakyReluInt8CPUKernel instance that launched the job (passed as `this` in Run()).
 * @param task_id Index of the task; forwarded unchanged to DoExecute().
 * @return RET_OK on success, RET_ERROR when cdata is null, otherwise the error code from DoExecute().
 * The two trailing float arguments are part of the callback signature and are unused here.
 */
int LeakyReluInt8Run(void *cdata, int task_id, float, float) {
  // static_cast is the dedicated conversion from void* back to the original object pointer.
  auto *kernel = static_cast<LeakyReluInt8CPUKernel *>(cdata);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "input cdata is nullptr!";
    return RET_ERROR;
  }
  const int status = kernel->DoExecute(task_id);
  if (status == RET_OK) {
    return RET_OK;
  }
  MS_LOG(ERROR) << "LeakyReluInt8Run task_id " << task_id << " failed.";
  return status;
}
}  // namespace

int LeakyReluInt8CPUKernel::Prepare() {
  CHECK_LESS_RETURN(in_tensors_.size(), C1NUM);
  CHECK_LESS_RETURN(out_tensors_.size(), C1NUM);
  CHECK_NULL_RETURN(in_tensors_[0]);
  CHECK_NULL_RETURN(out_tensors_[0]);
  if (in_tensors_[0]->data_type() != mindspore::kNumberTypeInt8 ||
      out_tensors_[0]->data_type() != mindspore::kNumberTypeInt8) {
    MS_LOG(ERROR) << "Datatype error, input0 data_type is " << in_tensors_[0]->data_type() << ", output data_type is "
                  << out_tensors_[0]->data_type();
    return RET_ERROR;
  }
  quant_prelu_parm_.thread_num_ = op_parameter_->thread_num_;
  quant_prelu_parm_.slope_ = reinterpret_cast<ActivationParameter *>(op_parameter_)->alpha_;

  auto *input_tensor = in_tensors_.at(kInputIndex);
  auto in_quant_args = input_tensor->quant_params();
  MS_CHECK_TRUE_MSG(!in_quant_args.empty(), RET_ERROR, "Input quant param cannot be empty.");
  quant_prelu_parm_.in_args_.scale_ = static_cast<float>(in_quant_args.front().scale);
  quant_prelu_parm_.in_args_.zp_ = in_quant_args.front().zeroPoint;

  auto *out_tensor = out_tensors_.at(kOutputIndex);
  auto out_quant_args = out_tensor->quant_params();
  MS_CHECK_TRUE_MSG(!out_quant_args.empty(), RET_ERROR, "Output quant param cannot be empty.");
  quant_prelu_parm_.out_args_.scale_ = static_cast<float>(out_quant_args.front().scale);
  quant_prelu_parm_.out_args_.zp_ = out_quant_args.front().zeroPoint;

  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

// The kernel owns no resources beyond its members, so the compiler-generated destructor suffices.
LeakyReluInt8CPUKernel::~LeakyReluInt8CPUKernel() = default;

/**
 * Refreshes the shape-dependent fields of quant_prelu_parm_ from the current input tensor.
 *
 * Records the input rank (input_dim_) and the total element count (element_num).
 *
 * @return RET_OK on success; RET_ERROR if the input has no elements.
 */
int LeakyReluInt8CPUKernel::ReSize() {
  auto *input_tensor = in_tensors_.at(kInputIndex);
  quant_prelu_parm_.input_dim_ = input_tensor->shape().size();
  // An empty input has nothing to process; reject it before recording the count.
  MS_CHECK_GT(input_tensor->ElementsNum(), 0, RET_ERROR);
  quant_prelu_parm_.element_num = input_tensor->ElementsNum();
  return RET_OK;
}

int LeakyReluInt8CPUKernel::Run() {
  auto ret = ParallelLaunch(this->ms_context_, LeakyReluInt8Run, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "RunPreluParam failed. errorcode: ";
  }
  return RET_OK;
}

int LeakyReluInt8CPUKernel::DoExecute(int task_id) {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto out_tensor = out_tensors_.at(kOutputIndex);
  int8_t *input_data = reinterpret_cast<int8_t *>(input_tensor->data());
  int8_t *output_data = reinterpret_cast<int8_t *>(out_tensor->data());
  MS_ASSERT(input_data);
  MS_ASSERT(output_data);
  auto ret = DoLeakReluInt8(input_data, output_data, &quant_prelu_parm_, task_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "DoLeakReluInt8 failed";
    return RET_ERROR;
  }
  return RET_OK;
}

// Registers LeakyReluInt8CPUKernel as the CPU implementation of LeakyRelu for int8 data,
// created through the generic LiteKernelCreator factory.
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LeakyRelu, LiteKernelCreator<LeakyReluInt8CPUKernel>)
}  // namespace mindspore::kernel
