/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common.h"
#include "gtest/gtest.h"

#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/core/tensor.h"

#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
#include "utils/log_adapter.h"

#include <vector>
#include <unordered_set>

using namespace mindspore::dataset;

class MindDataTestDistributedSampler : public UT::Common {
 public:
  /// Minimal RandomAccessOp stand-in that only carries a fixed row count, so a sampler
  /// can be handshaken in these tests without a real leaf dataset op.
  class DummyRandomAccessOp : public RandomAccessOp {
   public:
    /// @param num_rows Row count stored in the op for the sampler to read during the handshake.
    /// Marked explicit so an integer is never silently converted into an op.
    explicit DummyRandomAccessOp(uint64_t num_rows) {
      // num_rows_ is a protected member of the RandomAccessOp base class, so it cannot be set from
      // this class's member-initializer list; assign it in the body instead.
      // GetNumRowsInDataset does not need an override, the default from base class is fine.
      num_rows_ = num_rows;
    }
  };
};

/// Feature: DistributedSampler
/// Description: Test DistributedSampler with num_shards=2 and shard_id=0 over 7 rows, without shuffling
/// Expectation: The first call yields 4 sample ids and the following call returns an EOE row
TEST_F(MindDataTestDistributedSampler, TestTwoShardsOne) {
  // Used both as the fake dataset's row count and as the sampler's num_samples.
  uint64_t num_samples = 7;

  // Arguments: (num_shards=2, shard_id=0, shuffle=false, num_samples, seed=0, offset=-1, even_dist=false).
  // NOTE(review): argument names follow the DistributedSamplerRT constructor declaration; confirm if it changes.
  DistributedSamplerRT m_sampler(2, 0, false, num_samples, 0, -1, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  // Check the handshake's Status instead of discarding it, so a setup failure is reported here.
  ASSERT_EQ(m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  // Flatten every tensor in the row into a single list of sample ids.
  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      out.push_back(*it);
    }
  }

  // Shard 0 is expected to receive 4 of the 7 rows; the unsigned literal avoids a sign-compare with size().
  ASSERT_EQ(4U, out.size());

  // The next call is expected to produce the end-of-epoch marker.
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_EQ(row.eoe(), true);
}

/// Feature: DistributedSampler
/// Description: Test DistributedSampler with num_shards=2 and shard_id=1 over 7 rows, without shuffling
/// Expectation: The first call yields 3 sample ids and the following call returns an EOE row
TEST_F(MindDataTestDistributedSampler, TestTwoShardsTwo) {
  // Used both as the fake dataset's row count and as the sampler's num_samples.
  uint64_t num_samples = 7;

  // Arguments: (num_shards=2, shard_id=1, shuffle=false, num_samples, seed=0, offset=-1, even_dist=false).
  // NOTE(review): argument names follow the DistributedSamplerRT constructor declaration; confirm if it changes.
  DistributedSamplerRT m_sampler(2, 1, false, num_samples, 0, -1, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  // Check the handshake's Status instead of discarding it, so a setup failure is reported here.
  ASSERT_EQ(m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  // Flatten every tensor in the row into a single list of sample ids.
  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      out.push_back(*it);
    }
  }

  // Shard 1 is expected to receive 3 of the 7 rows; the unsigned literal avoids a sign-compare with size().
  ASSERT_EQ(3U, out.size());

  // The next call is expected to produce the end-of-epoch marker.
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_EQ(row.eoe(), true);
}

/// Feature: DistributedSampler
/// Description: Test DistributedSampler with num_shards=3 and shard_id=2 over only 2 rows, without shuffling
/// Expectation: The first call succeeds but yields no sample ids, and the following call returns an EOE row
TEST_F(MindDataTestDistributedSampler, TestThreeShards) {
  // Used both as the fake dataset's row count and as the sampler's num_samples.
  uint64_t num_samples = 2;

  // Arguments: (num_shards=3, shard_id=2, shuffle=false, num_samples, seed=0, offset=-1, even_dist=false).
  // NOTE(review): argument names follow the DistributedSamplerRT constructor declaration; confirm if it changes.
  DistributedSamplerRT m_sampler(3, 2, false, num_samples, 0, -1, false);
  DummyRandomAccessOp dummyRandomAccessOp(num_samples);
  // Check the handshake's Status instead of discarding it, so a setup failure is reported here.
  ASSERT_EQ(m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp), Status::OK());

  TensorRow row;
  std::vector<uint64_t> out;
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());

  // Flatten every tensor in the row into a single list of sample ids.
  for (const auto &t : row) {
    for (auto it = t->begin<uint64_t>(); it != t->end<uint64_t>(); ++it) {
      out.push_back(*it);
    }
  }

  // With more shards than rows, this shard is expected to receive no sample ids at all.
  ASSERT_TRUE(out.empty());

  // The next call is expected to produce the end-of-epoch marker.
  ASSERT_EQ(m_sampler.GetNextSample(&row), Status::OK());
  ASSERT_EQ(row.eoe(), true);
}
