/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/ascend/optimizer/ir_fission/single_batch_norm_fission.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "include/common/debug/anf_ir_dump.h"

namespace mindspore {
namespace opt {
/// Test fixture for the SingleBatchNormFission pass.
///
/// Each test pulls its input ("before") and, where applicable, expected ("after") graphs by name from the
/// Python module `gtest_input.pre_activate.single_batch_norm_fission_test` through `get_py_fun_`.
class TestHWSingleBatchNormFission : public BackendCommon {
 public:
  TestHWSingleBatchNormFission() = default;
  ~TestHWSingleBatchNormFission() override = default;

  // Fetcher bound to the test-input module; the second argument (`true`) is passed through unchanged and its
  // meaning is defined by UT::PyFuncGraphFetcher.
  UT::PyFuncGraphFetcher get_py_fun_{"gtest_input.pre_activate.single_batch_norm_fission_test", true};
};

// Runs SingleBatchNormFission on the "before" graph of `test_single_batch_norm_fission` and checks that the
// result matches that function's "after" graph.
TEST_F(TestHWSingleBatchNormFission, test_fission) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_single_batch_norm_fission", "before");
  // Fatal check: everything below dereferences `g`, so a missing graph must stop the test instead of crashing.
  ASSERT_NE(g, nullptr);

  // Input 0 is a 4-D float32 tensor; the remaining four inputs share one 1-D float32 abstract of length 64.
  std::vector<int64_t> shp_x{32, 64, 112, 112};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  std::vector<int64_t> shp_y{64};
  auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
  AbstractBasePtrList args_spec_list{x_abstract, y_abstract, y_abstract, y_abstract, y_abstract};

  // Enable MindRT for this test only. The guard restores the previous value on every exit path (normal return,
  // a failed ASSERT_*, or an exception from graph construction/optimization) so the global context cannot
  // leak into later tests.
  struct MindRTGuard {
    bool saved;
    ~MindRTGuard() { MsContext::GetInstance()->set_param<bool>(MS_CTX_ENABLE_MINDRT, saved); }
  } mindrt_guard{MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)};
  MsContext::GetInstance()->set_param<bool>(MS_CTX_ENABLE_MINDRT, true);

  auto kg = GetKernelGraph(g, args_spec_list);
  ASSERT_NE(kg, nullptr);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::SingleBatchNormFission>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);

  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_single_batch_norm_fission", "after");
  ASSERT_NE(g_after, nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

// Runs SingleBatchNormFission on the "before" graph of `test_single_batch_norm_fission` and checks that the
// optimized graph still equals a copy of the kernel graph taken before optimization.
//
// NOTE(review): the Python graph and the input shapes below are identical to those in test_fission, which
// expects the pass to rewrite the graph. Confirm which condition is supposed to suppress the fission here.
// Also confirm that KernelGraph's copy constructor produces a snapshot independent of `kg`; otherwise the
// final comparison cannot detect an in-place rewrite.
TEST_F(TestHWSingleBatchNormFission, test_no_fission) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_single_batch_norm_fission", "before");
  // Fatal check: everything below dereferences `g`, so a missing graph must stop the test instead of crashing.
  ASSERT_NE(g, nullptr);

  // Input 0 is a 4-D float32 tensor; the remaining four inputs share one 1-D float32 abstract of length 64.
  std::vector<int64_t> shp_x{32, 64, 112, 112};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
  std::vector<int64_t> shp_y{64};
  auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
  AbstractBasePtrList args_spec_list{x_abstract, y_abstract, y_abstract, y_abstract, y_abstract};

  // Enable MindRT for this test only. The guard restores the previous value on every exit path (normal return,
  // a failed ASSERT_*, or an exception from graph construction/optimization) so the global context cannot
  // leak into later tests.
  struct MindRTGuard {
    bool saved;
    ~MindRTGuard() { MsContext::GetInstance()->set_param<bool>(MS_CTX_ENABLE_MINDRT, saved); }
  } mindrt_guard{MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)};
  MsContext::GetInstance()->set_param<bool>(MS_CTX_ENABLE_MINDRT, true);

  auto kg = GetKernelGraph(g, args_spec_list);
  ASSERT_NE(kg, nullptr);
  // Copy of the kernel graph before the pass runs, used as the expected result.
  auto origin_graph = std::make_shared<session::KernelGraph>(*kg);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::SingleBatchNormFission>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);

  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
}  // namespace opt
}  // namespace mindspore
