美文网首页PyTorch阅读笔记
PyTorch-CKPT源码阅读记录

PyTorch-CKPT源码阅读记录

作者: CPinging | 来源:发表于2021-03-12 15:13 被阅读0次

    记录一下CKPT部分的阅读,为后期研究做铺垫。

    image.png

    首先在主函数中定位save函数的位置,进入:

    image.png

    这里初始化了一个OutputArchive类的archive对象,而该对象可以理解为一个装检查点文件的容器。位置的话在:

    image.png 其中定义了一些写入方法,其中比较常用的是位于 image.png
    // Persist the archive's module to disk: delegates straight to the
    // TorchScript export path, which owns the actual zip/pickle format.
    void OutputArchive::save_to(const std::string& filename) {
      jit::ExportModule(module_, filename);
    }
    

    通过传入文件名来保存数据。

    // Serialize `module` into the file `filename`.
    // extra_files: user-provided filename -> content entries saved alongside the model.
    // bytecode_format: when true, additionally emits mobile bytecode.
    // save_mobile_debug_info: when true, includes mobile debug information.
    void ExportModule(
        const Module& module,
        const std::string& filename,
        const ExtraFilesMap& extra_files,
        bool bytecode_format,
        bool save_mobile_debug_info) {
      // Construct the serializer; its constructor opens the output file (see setup()).
      ScriptModuleSerializer serializer(filename);
      serializer.serialize(
          module, extra_files, bytecode_format, save_mobile_debug_info);
    }
    

    1 writer_成员变量

    该函数中首先初始化ScriptModuleSerializer serializer(filename);然后调用serialize

    ScriptModuleSerializer类主要是包括caffe2::serialize::PyTorchStreamWriter writer_;私有成员变量。

      // Zip-archive writer that performs all file output for the checkpoint.
      caffe2::serialize::PyTorchStreamWriter writer_;
      // Tensor constants collected from the code; written as the "constants" archive.
      std::vector<at::IValue> constant_table_;
      // Named types already serialized, to avoid emitting a type twice.
      std::unordered_set<c10::NamedTypePtr> converted_types_;
      // Tracks dependencies between classes so code files are written in order.
      PrintDepsTable class_deps_;
      // Produces unique, stable names for serialized types.
      TypeNameUniquer type_name_uniquer_;
    

    并且实现了序列化函数serialize

    首先该成员变量比较重要,是一个写入流的类:

    caffe2::serialize::PyTorchStreamWriter
    
    // Write-side of PyTorch's zip-based serialization container.
    // Each writeRecord() call adds one named entry to the archive; the archive
    // is backed either by an std::ofstream (filename ctor) or by a caller-supplied
    // write callback (function ctor).
    class CAFFE2_API PyTorchStreamWriter final {
     public:
      // Open an archive that writes to the file `archive_name`.
      explicit PyTorchStreamWriter(std::string archive_name);
      // Open an archive that forwards all bytes to `writer_func`
      // (returns the number of bytes accepted).
      explicit PyTorchStreamWriter(
          const std::function<size_t(const void*, size_t)>& writer_func);
    
      // Raise the minimum file-format version recorded in the archive.
      void setMinVersion(const uint64_t version);
    
      // Add one named record (`name` -> `size` bytes at `data`) to the archive;
      // `compress` selects zip deflate for that entry.
      void writeRecord(
          const std::string& name,
          const void* data,
          size_t size,
          bool compress = false);
      // Finalize the zip central directory; no writes are allowed afterwards.
      void writeEndOfFile();
    
      // True once writeEndOfFile() has completed.
      bool finalized() const {
        return finalized_;
      }
    
      const std::string& archiveName() {
        return archive_name_;
      }
    
      ~PyTorchStreamWriter();
    
     private:
      // Shared constructor body: initializes miniz state and the write callback.
      void setup(const std::string& file_name);
      // Check miniz/stream state and throw with context on error.
      void valid(const char* what, const char* info = "");
      size_t current_pos_ = 0;
      // miniz zip-archive handle (zeroed in setup()).
      std::unique_ptr<mz_zip_archive> ar_;
      std::string archive_name_;
      // Cached "<archive_name>/" prefix used when naming records.
      std::string archive_name_plus_slash_;
      std::string padding_;
      // Only used by the filename constructor; the function ctor leaves it closed.
      std::ofstream file_stream_;
      // Sink for all bytes produced by miniz (set in setup() if not user-supplied).
      std::function<size_t(const void*, size_t)> writer_func_;
      uint64_t version_ = kProducedFileFormatVersion;
      bool finalized_ = false;
      bool err_seen_ = false;
      // miniz C callback: trampolines through pOpaque (== this) to writer_func_.
      friend size_t ostream_write_func(
          void* pOpaque,
          uint64_t file_ofs,
          const void* pBuf,
          size_t n);
    };
    

    该对象在实例化的时候会调用构造函数explicit PyTorchStreamWriter(std::string archive_name);因为我们传入的是一个string字符串,所以只能匹配到这个,另一个是传入一个func。

    explicit PyTorchStreamWriter(std::string archive_name);

    该函数在对应的 .cpp 实现文件中被展开如下:

    image.png

    即调用了setup函数并传入文件名。

    // Shared constructor body: zero-initializes the miniz archive state,
    // installs the byte sink, and initializes the zip writer.
    void PyTorchStreamWriter::setup(const string& file_name) {
      ar_ = std::make_unique<mz_zip_archive>();
      // miniz expects a zeroed mz_zip_archive before init.
      memset(ar_.get(), 0, sizeof(mz_zip_archive));
      archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().
    
      if (archive_name_.size() == 0) {
        CAFFE_THROW("invalid file name: ", file_name);
      }
      // writer_func_ is already set when the callback constructor was used;
      // otherwise open the target file and wrap it in a write lambda.
      if (!writer_func_) {
        // Open the output file (truncate, binary mode).
        file_stream_.open(
            file_name,
            std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
        valid("opening archive ", file_name.c_str());
        TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
        // Sink: write nbytes to the stream; report 0 on stream failure.
        writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
          file_stream_.write(static_cast<const char*>(buf), nbytes);
          return !file_stream_ ? 0 : nbytes;
        };
      }
    
      // Route miniz output through ostream_write_func -> writer_func_.
      ar_->m_pIO_opaque = this;
      ar_->m_pWrite = ostream_write_func;
    
      // ZIP64 so archives larger than 4GB are supported.
      mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
      valid("initializing archive ", file_name.c_str());
    }
    

    而该函数包括了ofstream流对文件的open以及write。

    2 serialize序列化

    // Same ExportModule as above, repeated here to focus on the serialize() call.
    void ExportModule(
        const Module& module,
        const std::string& filename,
        const ExtraFilesMap& extra_files,
        bool bytecode_format,
        bool save_mobile_debug_info) {
      // Initialize the serializer (already covered above).
      ScriptModuleSerializer serializer(filename);
    // Now focus on this part: the actual serialization pass.
      serializer.serialize(
          module, extra_files, bytecode_format, save_mobile_debug_info);
    }
    

    上面我们走了一条路是初始化writer_,下面我们看序列化函数:

    
      // Drive the full serialization of one module: extra files, model data,
      // code, tensor constants, and (optionally) mobile bytecode.
      void serialize(
          const Module& module,
         // ExtraFilesMap maps filename -> content (std::string -> std::string).
          const ExtraFilesMap& extra_files,
          bool bytecode_format,
          bool save_mobile_debug_info) {
        C10_LOG_API_USAGE_ONCE("torch.script.save");
        writeExtraFiles(module, extra_files);
        // Serialize the model object
        writeArchive("data", module._ivalue());
        // Then we serialize all code info.
        writeCode(module.type());
        // The tensor constants from the code are written to a separate archive
        // so loading the code does not depend on loading the data
        std::vector<IValue> ivalue_constants(
            constant_table_.begin(), constant_table_.end());
        writeArchive("constants", c10::ivalue::Tuple::create(ivalue_constants));
        if (bytecode_format) {
          writeByteCode(module, save_mobile_debug_info);
          writeMobileMetadata(module, extra_files);
        }
    
        // Acquires and sets minimum (dynamic) version
        // (file_streams_ is a member not shown in this excerpt — presumably
        //  the per-file code streams written by writeCode; confirm in source)
        for (auto& item : file_streams_) {
          writer_.setMinVersion(item.value().minVersion());
        }
      }
    

    传入需要保存的模块、文件map以及两个bool标志(这里的话不需要传入文件名,因为文件名在初始化writer的时候使用)。

    其中比较关键的函数为:writeArchive("data", module._ivalue());

      // Pickle `value` into "<archive_name>.pkl" and write each referenced
      // tensor's raw storage as a separate record "<archive_name>/<index>".
      void writeArchive(const std::string& archive_name, const IValue& value) {
        // Accumulates the pickle byte stream produced by the writer lambda below.
        std::vector<char> data;
        // Vector to capture the run-time class types during pickling the IValues
        std::vector<c10::ClassTypePtr> memoizedClassTypes;
        Pickler data_pickle(
            // Writer callback: append each flushed chunk to `data`.
            [&](const char* buf, size_t size) {
              data.insert(data.end(), buf, buf + size);
            },
            // No shared tensor table: tensors are emitted as separate records.
            nullptr,
            // Type renamer: assign a unique stable name per class type.
            [&](const c10::ClassTypePtr& t) {
              return type_name_uniquer_.getUniqueName(t);
            },
            &memoizedClassTypes);
        data_pickle.protocol();
        data_pickle.pushIValue(value);
        // stop() flushes any buffered pickle bytes into `data`.
        data_pickle.stop();
        size_t i = 0;
        std::string prefix = archive_name + "/";
        // One record per tensor storage, named "<archive_name>/0", "/1", ...
        for (const auto& td : data_pickle.tensorData()) {
          WriteableTensorData writable_td = getWriteableTensorData(td);
          std::string fname = prefix + c10::to_string(i++);
          writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
        }
        // The pickle stream itself goes into "<archive_name>.pkl".
        std::string fname = archive_name + ".pkl";
        writer_.writeRecord(fname, data.data(), data.size());
    
        // serialize all the captured run-time class types
        for (const c10::ClassTypePtr& wroteType : memoizedClassTypes) {
          convertNamedType(wroteType);
        }
      }
    

    使用Pickler将数据序列化为一种特定的形式,然后使用writeRecord循环写入到文件中

    该函数首先初始化了Pickler data_pickle,这个类用来序列化obj数据,且初始化的时候需要传入四个参数,如下。

    image.png

    之后一键三连:

    // Emit the pickle PROTOCOL opcode.
        data_pickle.protocol();
        data_pickle.pushIValue(value);
    // stop() includes the flush() operation — detailed below.
        data_pickle.stop();
    
    • 1 data_pickle.pushIValue(value)函数:

    执行流程为:data_pickle.pushIValue(ivalue) -> pushIValueImpl(ivalue) -> pushTensor(ivalue) -> pushLiteralTensor(ivalue) -> pushStorageOfTensor(tensor) -> tensor_data_.push_back(tensor)

    于是我们的最终目的就是为了赋值tensor_data_这个vector私有成员变量,而这个变量在后面会用到。

    • 2 data_pickle.stop()函数:

    data_pickle.stop() -> flush() -> flushNonEmpty() -> writer_(buffer_.data(), bufferPos_)

    这里的这个writer_是在初始化data_pickle的时候就赋值的

      // Pickler constructor: stores the four collaborators used during pickling.
      // writer: sink for the serialized byte stream (called on flush).
      // tensor_table: optional shared tensor table (nullptr here).
      // type_renamer: maps a class type to its serialized qualified name.
      // memoized_class_types: out-param collecting class types seen while pickling.
      Pickler(
          std::function<void(const char*, size_t)> writer,
          std::vector<at::Tensor>* tensor_table,
          std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer,
          std::vector<c10::ClassTypePtr>* memoized_class_types)
          : writer_(std::move(writer)),
            tensor_table_(tensor_table),
            type_renamer_(std::move(type_renamer)),
            memoized_class_types_(memoized_class_types) {}
    

    而这个function就是:

    // The writer function passed to Pickler: appends each chunk to `data`.
    [&](const char* buf, size_t size) {
              data.insert(data.end(), buf, buf + size);
            },
    

    当处理完成上述内容后,代码进入关键步骤:

    // For each pickled tensor: get a CPU-readable view of its storage and
    // write it as a separate record named "<prefix><index>".
    for (const auto& td : data_pickle.tensorData()) {
          WriteableTensorData writable_td = getWriteableTensorData(td);
          std::string fname = prefix + c10::to_string(i++);
          writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
        }
    

    循环data_pickle.tensorData()中所有的tensor,这里data_pickle.tensorData()代表上文中提到的tensor_data_。这个tensor里面存了要保存的数据(比如net网络)。

    • 1 首先进入函数getWriteableTensorData(td)
    // Produce a view of `tensor`'s storage that is safe to write to disk.
    // For CUDA tensors (when to_cpu is set) the whole storage is copied to
    // a CPU tensor first; otherwise the original tensor/storage is used as-is.
    WriteableTensorData getWriteableTensorData(
        const at::Tensor& tensor,
        bool to_cpu) {
      WriteableTensorData result;
      result.tensor_ = tensor;
      // Size is the full storage size in bytes, not just this view of it.
      result.size_ = tensor.storage().nbytes();
      // TODO HIP support
      if (tensor.storage().device_type() == DeviceType::CUDA && to_cpu) {
        // NB: This new tensor is created to support cuda tensors.
        // Storages can be mutated when converting tensors from cuda to cpu,
        // and we need a cpu tensor to copy data from.
        // The temporary views the ENTIRE storage as a flat 1-D tensor
        // (offset 0, stride 1) so every byte of the storage is copied.
        result.tensor_ =
            at::empty({0}, tensor.options())
                .set_(
                    tensor.storage(),
                    /* storage_offset = */ 0,
                    /* size = */
                    {static_cast<int64_t>(
                        tensor.storage().nbytes() / tensor.element_size())},
                    /* stride = */ {1})
                .cpu();
        TORCH_CHECK(
            result.tensor_.storage().nbytes() == result.size_,
            "Storage tensor size did not match record size");
      }
      return result;
    }
    

    这里会判断该数据是否在GPU中,如果在的话就转移。

    具体的函数为:

        // Same GPU->CPU copy as above, on one line: empty tensor -> view whole
        // storage as flat 1-D -> move to CPU.
        result.tensor_ = at::empty({0}, tensor.options()).set_(tensor.storage(), 0,{static_cast<int64_t>(tensor.storage().nbytes() / tensor.element_size())},{1}).cpu();
    

    这里首先调用at::empty

    // Create an uninitialized tensor with the given dims on the options' device.
    // NOTE(review): given the TODO below, this appears to be the caffe2 Tensor
    // version rather than at::empty itself — confirm against the source tree.
    Tensor empty(at::IntArrayRef dims, at::TensorOptions options) {
      // TODO: merge this with at::empty after Tensor is merged
      auto tensor = Tensor(dims, options.device());
      // Allocates (or re-types) the underlying storage for the requested dtype.
      tensor.raw_mutable_data(options.dtype());
      return tensor;
    }
    

    创建一个空Tensor,该tensor的options是原tensor.options()

    然后调用set函数,初始化Tensor。

    之后调用.cpu()

    // Convenience wrapper: to() with the CPU device, synchronous, no forced copy
    // (returns *this unchanged if already on CPU).
    Tensor Tensor::cpu() const {
      return to(options().device(DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false);
    }
    

    传入CPU的标记

    // Device/dtype conversion entry point: normalizes the target device
    // (fills in a concrete device index) and forwards to to_impl with the
    // merged TensorOptions.
    Tensor to(const Tensor& self, Device device, ScalarType dtype, bool non_blocking, bool copy, c10::optional<c10::MemoryFormat> optional_memory_format) {
      device = ensure_has_index(device);
      return to_impl(
          self,
          self.options().device(device).dtype(dtype).memory_format(optional_memory_format),
          non_blocking,
          copy);
    }
    

    这个to函数走到最后是copy函数:

    // In-place copy of `src` into `self`. Named-tensor handling wraps the real
    // work: compute broadcast output names first, suppress names during the raw
    // copy, then propagate the names back onto the result.
    Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
      auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
      {
        NoNamesGuard guard;
        copy_impl(self, src, non_blocking);
      }
      namedinference::propagate_names_if_nonempty(self, maybe_outnames);
      return self;
    }
    

    之后会走到copy_impl代码:

    这里会调用DEFINE_DISPATCH(copy_stub);

    这里会向两个方向走,第一个是GPU的转移:

    image.png

    关键函数如下:

      // Device<->host memcpy on the current CUDA stream.
      CUDAStream stream = getCurrentCUDAStream();
    
      if (non_blocking) {
        // Async copy; record an event on the host-side pointer so the caching
        // host allocator knows when the pinned memory is safe to reuse.
        AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
        void* ptr = (dst_device == kCPU ? dst : src);
        AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(ptr, stream));
      } else {
    #if HIP_VERSION >= 301
        // ROCm >= 3.1 has a combined copy+sync API.
        AT_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
    #else
        // Blocking path: issue the async copy, then synchronize the stream so
        // the data is guaranteed present when this function returns.
        AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
        AT_CUDA_CHECK(cudaStreamSynchronize(stream));
    #endif
      }
    

    相关文章

      网友评论

        本文标题:PyTorch-CKPT源码阅读记录

        本文链接:https://www.haomeiwen.com/subject/mtcsqltx.html