#include <array>
#include <cstdio>
#include <cstring>
#include <string>

#include <gtest/gtest.h>

#include <c10/util/Logging.h>
#include "c10/core/CPUAllocator.h"
#include "c10/util/irange.h"
#include "caffe2/serialize/inline_container.h"

namespace caffe2 {
namespace serialize {
namespace {

// Writes two records ("key1", "key2") with PyTorchStreamWriter, reads the
// archive back with PyTorchStreamReader and checks the allocating, in-place
// and chunked getRecord() overloads as well as the offset alignment of each
// record inside the archive.
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
  constexpr int64_t kFieldAlignment = 64L;

  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // Copy callback shared by the chunked getRecord() calls below.
  const auto copy_chunk = [](void* dst, const void* src, size_t n) {
    memcpy(dst, src, n);
  };

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // Buffer passed as the `buf` argument of the chunked getRecord() overload.
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  const std::string the_file = oss.str();
  const char* file_name = "output.zip";
  // The archive is binary data: open in binary mode so no newline translation
  // is applied to the bytes on platforms that distinguish text mode.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);
  ASSERT_TRUE(reader.hasRecord("key1"));
  ASSERT_TRUE(reader.hasRecord("key2"));
  ASSERT_FALSE(reader.hasRecord("key2000"));
  at::DataPtr data_ptr;
  int64_t size = 0;
  std::tie(data_ptr, size) = reader.getRecord("key1");
  const size_t off1 = reader.getRecordOffset("key1");
  ASSERT_EQ(size, data1.size());
  ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
  // The bytes at the reported offset of the raw archive must be the payload.
  ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
  ASSERT_EQ(off1 % kFieldAlignment, 0);
  // inplace getRecord() test
  std::vector<uint8_t> dst(size);
  size_t ret = reader.getRecord("key1", dst.data(), size);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
  // chunked getRecord() test
  ret = reader.getRecord("key1", dst.data(), size, 3, buf.data(), copy_chunk);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);

  std::tie(data_ptr, size) = reader.getRecord("key2");
  const size_t off2 = reader.getRecordOffset("key2");
  ASSERT_EQ(off2 % kFieldAlignment, 0);

  ASSERT_EQ(size, data2.size());
  ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
  // inplace getRecord() test
  dst.resize(size);
  ret = reader.getRecord("key2", dst.data(), size);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
  // chunked getRecord() test
  ret = reader.getRecord("key2", dst.data(), size, 3, buf.data(), copy_chunk);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
  // clean up
  remove(file_name);
}

// Reads records back through the getRecord() overloads that accept a list of
// additional readers, after setting the additional-reader size threshold to 0.
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data1.size())) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  for (auto i : c10::irange(data2.size())) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  const std::string the_file = oss.str();
  const char* file_name = "output.zip";
  // The archive is binary data: open in binary mode so no newline translation
  // is applied to the bytes on platforms that distinguish text mode.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  // read records through pytorchStreamReader
  std::istringstream iss(the_file);
  PyTorchStreamReader reader(&iss);
  reader.setAdditionalReaderSizeThreshold(0);
  // before testing, sanity check: record the sizes reported by the plain
  // getRecord() overload so the other overloads can be compared against them
  int64_t size1 = 0;
  int64_t size2 = 0;
  int64_t ret = 0;
  at::DataPtr data_ptr;
  std::tie(data_ptr, size1) = reader.getRecord("key1");
  std::tie(data_ptr, size2) = reader.getRecord("key2");

  // Test getRecord(name, additional_readers)
  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
  for (int i = 0; i < 10; ++i) {
    // NOTE(review): `additionalReader` is never populated inside this loop, so
    // every iteration passes an empty reader list -- confirm whether growing
    // the list here (as the in-place loop below does) was intended.
    std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
    ASSERT_EQ(ret, size1);
    ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);

    std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
  }

  // Inplace multi-threading getRecord(name, dst, n, additional_readers) test
  additionalReader.clear();
  std::vector<uint8_t> dst1(size1);
  std::vector<uint8_t> dst2(size2);
  for (int i = 0; i < 10; ++i) {
    // Test various sizes of read threads: one more reader per iteration.
    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));

    ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
    ASSERT_EQ(ret, size1);
    ASSERT_EQ(memcmp(dst1.data(), data1.data(), size1), 0);

    ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(dst2.data(), data2.data(), size2), 0);
  }
  // clean up
  remove(file_name);
}

// Checks that every getRecord() overload throws c10::Error for a record name
// that was never written, and that the reader still answers hasRecord()
// afterwards.
TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;

  // Buffer passed as the `buf` argument of the chunked getRecord() overload.
  // Sized like in the other tests so a valid (non-null) pointer is passed.
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  const std::string the_file = oss.str();
  const char* file_name = "output2.zip";
  // The archive is binary data: open in binary mode so no newline translation
  // is applied to the bytes on platforms that distinguish text mode.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);
  // "key3" was never written, so each overload is expected to throw.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(reader.getRecord("key3"), c10::Error);
  std::vector<uint8_t> dst(data1.size());
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(reader.getRecord("key3", dst.data(), data1.size()), c10::Error);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(
      reader.getRecord(
          "key3",
          dst.data(),
          data1.size(),
          3,
          buf.data(),
          [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); }),
      c10::Error);

  // Reader should still work after throwing
  EXPECT_TRUE(reader.hasRecord("key1"));
  // clean up
  remove(file_name);
}

// Writes two "*.debug_pkl" records, then calls setShouldLoadDebugSymbol(false)
// on the reader and checks that hasRecord() reports the record as absent and
// that every getRecord() overload returns a size of 0 for it.
TEST(PytorchStreamWriterAndReader, SkipDebugRecords) {
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // Buffer passed as the `buf` argument of the chunked getRecord() overload.
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  writer.writeRecord("key2.debug_pkl", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1.debug_pkl"), 1);
  ASSERT_EQ(written_records.count("key2.debug_pkl"), 1);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  const std::string the_file = oss.str();
  const char* file_name = "output3.zip";
  // The archive is binary data: open in binary mode so no newline translation
  // is applied to the bytes on platforms that distinguish text mode.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);

  reader.setShouldLoadDebugSymbol(false);
  EXPECT_FALSE(reader.hasRecord("key1.debug_pkl"));
  at::DataPtr ptr;
  size_t size = 0;
  std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
  EXPECT_EQ(size, 0);
  std::vector<uint8_t> dst(data1.size());
  size_t ret = reader.getRecord("key1.debug_pkl", dst.data(), data1.size());
  EXPECT_EQ(ret, 0);
  ret = reader.getRecord(
      "key1.debug_pkl",
      dst.data(),
      data1.size(),
      3,
      buf.data(),
      [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
  EXPECT_EQ(ret, 0);
  // clean up
  remove(file_name);
}

// Checks that the reader reports the serialization id produced by the writer,
// and that a second writer given the same record produces the same id.
TEST(PytorchStreamWriterAndReader, ValidSerializationId) {
  std::ostringstream oss;
  // Sink shared by both writers: appends every chunk to `oss`.
  const auto append_to_stream = [&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  };

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> payload;
  for (size_t idx = 0; idx < payload.size(); ++idx) {
    payload[idx] = payload.size() - idx;
  }

  // First pass: write the record and remember the resulting id.
  PyTorchStreamWriter first_writer(append_to_stream);
  first_writer.writeRecord("key1.debug_pkl", payload.data(), payload.size());
  first_writer.writeEndOfFile();
  const auto first_id = first_writer.serializationId();

  // The reader parses a snapshot of what the first writer emitted.
  std::istringstream iss(oss.str());
  PyTorchStreamReader reader(&iss);
  EXPECT_EQ(reader.serializationId(), first_id);

  // Second pass: same record through a fresh writer.
  PyTorchStreamWriter second_writer(append_to_stream);
  second_writer.writeRecord("key1.debug_pkl", payload.data(), payload.size());
  second_writer.writeEndOfFile();
  const auto second_id = second_writer.serializationId();

  EXPECT_EQ(first_id, second_id);
}

// Attempts to write a record named kSerializationIdRecordName by hand and
// checks that it is not listed among the written records until
// writeEndOfFile() runs, and that the reader reports the writer's own id.
TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  const std::string dup_serialization_id = "dup-serialization-id";
  writer.writeRecord(
      kSerializationIdRecordName,
      dup_serialization_id.c_str(),
      dup_serialization_id.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  // The manual write above must not have been recorded.
  ASSERT_EQ(written_records.size(), 0);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
  const auto writer_serialization_id = writer.serializationId();

  const std::string the_file = oss.str();
  const char* file_name = "output4.zip";
  // The archive is binary data: open in binary mode so no newline translation
  // is applied to the bytes on platforms that distinguish text mode.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);

  EXPECT_EQ(reader.serializationId(), writer_serialization_id);
  // clean up
  remove(file_name);
}

// Installs an API-usage metadata logger, writes and reads an empty archive,
// and checks the metadata entries captured for the writer and the reader.
TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
  // Installs a no-op metadata logger when the test scope exits. The logger set
  // below captures `logs` by reference, and a failing ASSERT_* returns from
  // the test body early, so the reset must not depend on reaching the last
  // statement of the test.
  struct LoggerResetGuard {
    ~LoggerResetGuard() {
      SetAPIUsageMetadataLogger(
          [](const std::string& /*context*/,
             const std::map<std::string, std::string>& /*metadata_map*/) {});
    }
  };

  std::map<std::string, std::map<std::string, std::string>> logs;
  // Declared after `logs` so it is destroyed (and the logger reset) first.
  LoggerResetGuard reset_logger_on_exit;

  SetAPIUsageMetadataLogger(
      [&](const std::string& context,
          const std::map<std::string, std::string>& metadata_map) {
        logs.insert({context, metadata_map});
      });
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  writer.writeEndOfFile();

  std::istringstream iss(oss.str());
  // read records through readers
  PyTorchStreamReader reader(&iss);

  // One entry is expected from the writer and one from the reader.
  ASSERT_EQ(logs.size(), 2);
  std::map<std::string, std::map<std::string, std::string>> expected_logs = {
      {"pytorch.stream.writer.metadata",
       {{"serialization_id", writer.serializationId()},
        {"file_name", "archive"},
        {"file_size", str(oss.str().length())}}},
      {"pytorch.stream.reader.metadata",
       {{"serialization_id", writer.serializationId()},
        {"file_name", "archive"},
        {"file_size", str(iss.str().length())}}}};
  ASSERT_EQ(expected_logs, logs);
}

class TestAllocator : public at::Allocator {
 public:

  explicit TestAllocator(at::Allocator* allocator): baseAllocator_(allocator) {}
  at::DataPtr allocate(size_t nbytes) override {
  allocatedBytes_ += nbytes;
  return baseAllocator_->allocate(nbytes);
  }
  at::DeleterFnPtr raw_deleter() const override {
    return baseAllocator_->raw_deleter();
  }
  void copy_data(void* dest, const void* src, std::size_t count) const override {
    default_copy_data(dest, src, count);
  }
  size_t getAllocatedBytes() {
    return allocatedBytes_;
  }
 private:
  at::Allocator* baseAllocator_;
  size_t allocatedBytes_{0};
};

// Checks which allocator serves getRecord(): the allocator passed explicitly
// when one is given, otherwise the allocator registered as the CPU allocator.
TEST(PyTorchStreamWriterAndReader, SaveAndLoadWithAllocator) {
  // Re-registers the previously active CPU allocator when the test scope
  // exits. `baseAllocator` below lives on this test's stack, so leaving it
  // registered would leave the global CPU allocator pointing at a destroyed
  // object once the test returns (including early returns from ASSERT_*).
  struct CPUAllocatorRestorer {
    CPUAllocatorRestorer(at::Allocator* previous, uint8_t priority)
        : previous_(previous), priority_(priority) {}
    ~CPUAllocatorRestorer() {
      c10::SetCPUAllocator(previous_, priority_);
    }
    at::Allocator* previous_;
    uint8_t priority_;
  };

  // create two test allocators, one is supposed to be the default allocator
  // the other one is only used when user specifies it
  auto defaultAllocator = at::GetCPUAllocator();
  TestAllocator overrideAllocator(defaultAllocator);
  TestAllocator baseAllocator(defaultAllocator);
  c10::SetCPUAllocator(&baseAllocator, 10 /* priority */);
  // Same priority as the registration above so the restore takes effect.
  CPUAllocatorRestorer restore_allocator(defaultAllocator, 10 /* priority */);

  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  const size_t kBytes1 = 127;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, kBytes1> data1;
  // Inplace memory buffer
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  const size_t kBytes2 = 64;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, kBytes2> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  const std::string the_file = oss.str();
  const char* file_name = "output.zip";
  // The archive is binary data: open in binary mode so no newline translation
  // is applied to the bytes on platforms that distinguish text mode.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);
  ASSERT_TRUE(reader.hasRecord("key1"));
  ASSERT_TRUE(reader.hasRecord("key2"));
  ASSERT_FALSE(reader.hasRecord("key2000"));
  // get the bytes allocated before read
  const auto allocBytes = baseAllocator.getAllocatedBytes();
  at::DataPtr data_ptr;
  int64_t size = 0;
  // allocated with override allocator
  std::tie(data_ptr, size) = reader.getRecord("key1", &overrideAllocator);
  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes);
  // allocate with base allocator
  std::tie(data_ptr, size) = reader.getRecord("key1");
  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);

  std::tie(data_ptr, size) = reader.getRecord("key2", &overrideAllocator);
  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1 + kBytes2);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);
  std::tie(data_ptr, size) = reader.getRecord("key2");
  EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1 + kBytes2);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1 + kBytes2);
  // passing the base allocator explicitly is counted on the base allocator
  std::tie(data_ptr, size) = reader.getRecord("key2", &baseAllocator);
  EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1 + 2 * kBytes2);
  // clean up
  remove(file_name);
}


// Same scenario as LoadWithMultiThreads, additionally tracking which
// allocator serves the allocating getRecord(name, additional_readers[,
// allocator]) overloads.
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreadsWithAllocator) {
  // Re-registers the previously active CPU allocator when the test scope
  // exits. `baseAllocator` below lives on this test's stack, so leaving it
  // registered would leave the global CPU allocator pointing at a destroyed
  // object once the test returns (including early returns from ASSERT_*).
  struct CPUAllocatorRestorer {
    CPUAllocatorRestorer(at::Allocator* previous, uint8_t priority)
        : previous_(previous), priority_(priority) {}
    ~CPUAllocatorRestorer() {
      c10::SetCPUAllocator(previous_, priority_);
    }
    at::Allocator* previous_;
    uint8_t priority_;
  };

  auto defaultAllocator = at::GetCPUAllocator();
  TestAllocator overrideAllocator(defaultAllocator);
  TestAllocator baseAllocator(defaultAllocator);
  c10::SetCPUAllocator(&baseAllocator, 10 /* priority */);
  // Same priority as the registration above so the restore takes effect.
  CPUAllocatorRestorer restore_allocator(defaultAllocator, 10 /* priority */);

  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  const size_t kBytes1 = 127;
  const size_t kBytes2 = 64;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, kBytes1> data1;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, kBytes2> data2;
  for (auto i : c10::irange(data1.size())) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  for (auto i : c10::irange(data2.size())) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  const std::string the_file = oss.str();
  const char* file_name = "output.zip";
  // The archive is binary data: open in binary mode so no newline translation
  // is applied to the bytes on platforms that distinguish text mode.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  // read records through pytorchStreamReader
  std::istringstream iss(the_file);
  PyTorchStreamReader reader(&iss);
  reader.setAdditionalReaderSizeThreshold(0);
  // before testing, sanity check: record the sizes reported by the plain
  // getRecord() overload so the other overloads can be compared against them
  int64_t size1 = 0;
  int64_t size2 = 0;
  int64_t ret = 0;
  at::DataPtr data_ptr;
  std::tie(data_ptr, size1) = reader.getRecord("key1");
  std::tie(data_ptr, size2) = reader.getRecord("key2");

  // Test getRecord(name, additional_readers)
  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
  size_t allocatedBytes = 0;
  auto baseAllocBytes = baseAllocator.getAllocatedBytes();
  for (int i = 0; i < 10; ++i) {
    // NOTE(review): `additionalReader` is never populated inside this loop, so
    // every iteration passes an empty reader list -- confirm whether growing
    // the list here (as the in-place loop below does) was intended.
    // "key1" is read with the explicit override allocator: only its counter
    // is expected to grow.
    std::tie(data_ptr, ret) =
        reader.getRecord("key1", additionalReader, &overrideAllocator);
    ASSERT_EQ(ret, size1);
    allocatedBytes += size1;
    EXPECT_EQ(overrideAllocator.getAllocatedBytes(), allocatedBytes);
    EXPECT_EQ(baseAllocator.getAllocatedBytes(), baseAllocBytes);
    ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);

    // "key2" is read without an explicit allocator: only the registered base
    // allocator's counter is expected to grow.
    baseAllocBytes += size2;
    std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
    EXPECT_EQ(overrideAllocator.getAllocatedBytes(), allocatedBytes);
    EXPECT_EQ(baseAllocator.getAllocatedBytes(), baseAllocBytes);
  }

  // Inplace multi-threading getRecord(name, dst, n, additional_readers) test
  additionalReader.clear();
  std::vector<uint8_t> dst1(size1);
  std::vector<uint8_t> dst2(size2);
  for (int i = 0; i < 10; ++i) {
    // Test various sizes of read threads: one more reader per iteration.
    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));

    ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
    ASSERT_EQ(ret, size1);
    // Verify the bytes too, as LoadWithMultiThreads does.
    ASSERT_EQ(memcmp(dst1.data(), data1.data(), size1), 0);

    ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(dst2.data(), data2.data(), size2), 0);
  }
  // clean up
  remove(file_name);
}

// Value-parameterized fixture; the ChunkRead test uses the int64_t parameter
// as its chunk size.
class ChunkRecordIteratorTest : public ::testing::TestWithParam<int64_t> {};

// Chunk sizes exercised: 100, 150 and 1010.
INSTANTIATE_TEST_SUITE_P(
    ChunkRecordIteratorTestGroup,
    ChunkRecordIteratorTest,
    ::testing::Values(100, 150, 1010));

// Stores one 1000-byte record filled with 1s in a zip file on disk, reads it
// back through the chunk reader iterator using the parameterized chunk size,
// and checks the bytes of every chunk and the total number of bytes read.
TEST_P(ChunkRecordIteratorTest, ChunkRead) {
  const auto chunk_bytes = GetParam();
  const std::string key = "key1";
  const size_t payload_bytes = 1000;
  const std::string zip_path =
      "output_chunk_" + std::to_string(chunk_bytes) + ".zip";
  const char* zip_path_cstr = zip_path.c_str();

  // Serialize the record into an in-memory stream first.
  std::ostringstream archive(std::ios::binary);
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    archive.write(static_cast<const char*>(b), n);
    return archive ? n : 0;
  });

  std::vector<uint8_t> payload(payload_bytes, 1);
  writer.writeRecord(key, payload.data(), payload_bytes);
  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 1);
  ASSERT_EQ(written_records.count(key), 1);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  // Dump the archive to disk so the reader can open it by file name.
  const std::string archive_bytes = archive.str();
  std::ofstream zip_out(zip_path_cstr, std::ios::binary);
  zip_out.write(archive_bytes.c_str(), archive_bytes.size());
  zip_out.close();
  LOG(INFO) << "Finished saving tensor into zip file " << zip_path_cstr;

  LOG(INFO) << "Testing chunk size " << chunk_bytes;
  PyTorchStreamReader reader(zip_path_cstr);
  ASSERT_TRUE(reader.hasRecord(key));
  auto chunks = reader.createChunkReaderIter(key, payload_bytes, chunk_bytes);
  std::vector<uint8_t> chunk(chunk_bytes);
  size_t bytes_seen = 0;
  // next() returning 0 ends the loop.
  for (auto got = chunks.next(chunk.data()); got;
       got = chunks.next(chunk.data())) {
    const auto ones = std::vector<uint8_t>(got, 1);
    ASSERT_EQ(memcmp(ones.data(), chunk.data(), got), 0);
    bytes_seen += got;
  }
  ASSERT_EQ(bytes_seen, payload_bytes);
  // clean up
  remove(zip_path_cstr);
}

} // namespace
} // namespace serialize
} // namespace caffe2
