Pytorch中torch.save时产生Unknown err

作者: 野风君 | 来源:发表于2018-08-28 22:00 被阅读5次

最近在用Pytorch的时候,出现了一个错误,error traceback如下:

Traceback (most recent call last):
File "main-train.py", line 250, in main()
File "main-train.py", line 232, in main
m.state_dict(), '../exp/{}/{}/model_{}.pth'.format(opt.dataset, opt.expID, opt.epoch))
File "/usr/local/lib/python3.5/dist-packages/torch/serialization.py", line 135, in save
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "/usr/local/lib/python3.5/dist-packages/torch/serialization.py", line 117, in _with_file_like
return body(f)
File "/usr/local/lib/python3.5/dist-packages/torch/serialization.py", line 135, in 
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "/usr/local/lib/python3.5/dist-packages/torch/serialization.py", line 204, in _save
serialized_storages[key]._write_file(f)
RuntimeError: Unknown error -1

如这个错误栈所示,这是在调用torch.save保存模型的时候发生的,具体的错误发生位置在torch/serialization.py中,而且在Pytorch 0.3.1Pytorch 0.4我都遇到了这个问题,下面贴出serialization.py对应函数的代码(从Pytorch0.3.0中文文档中复制黏贴,对于0.3.10.4版本,代码类似):

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """将一个对象保存到一个磁盘文件中.

    另见: :ref:`recommend-saving-models`

    参数:
        obj: 要保存的对象
        f: 类文件对象 (必须实现返回文件描述符的 fileno 方法) 或包含文件名的字符串
        pickle_module: 用于 pickling 元数据和对象的模块
        pickle_protocol: 可以指定来覆盖默认协议
    """
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))

def _save(obj, f, pickle_module, pickle_protocol):
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)
        elif torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            root, offset = obj._root_storage()
            root_key = str(root._cdata)
            location = location_tag(obj)
            serialized_storages[root_key] = root
            is_view = obj._cdata != root._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage',
                    storage_type,
                    root_key,
                    location,
                    root.size(),
                    view_metadata)

        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f)

对于这个错误,可能有多个原因,两类较为常见的是:

  1. 保存模型时,目标目录下的磁盘空间不够了,类似的issue已经被多人报告过(如Pytorch Forum上的讨论),对于这种情况的解决比较简单,把模型保存到其他存储空间的路径就可以了。
  2. 因为Pytorch目前版本中一个较为隐蔽的BUG,相关的说明见Github上的issue#8477),原因较为复杂,issue的报告人给出了相关的推断:

I added some code in the source code to print the error and signal, then reproduced and got the bold lines.Seems like save operation happened right after validation finished, exit of worker subprocesses caused the SIGCHLD, then interrupted "write" system call.

简单来说,就是因为在一个epoch的训练结束后的validation结束时,对应的worker子进程会结束,抛出来一个SIGCHLD信号,然后这个信号有时候会调动父进程把当前进行“write”操作的动作终止掉,也就是终止了保存模型时对硬盘写入的操作。

对于这种情况,当然可以通过比较复杂的手段来解决,但是最简单的方法就是在torch.save之前加入一定的等待时间,避免对应的SIGCHLD操作造成影响,比如插入:

import time
time.sleep(10)

非常不幸的,上述两种情况我都遇到过,第一种情况很快解决了,第二种困扰了我很久。总的来说,这是一个Pytorch当前版本的BUG,而且在产生错误的时候没有给出友好的错误信息,前述两种情况对应的产生原因实际上是不一样的,但是都没有预先设置的友好的错误处理机制处理,才产生了“unkown error”这样的信息。所以,这个报错之下,也很有可能会有别的错误原因存在,也欢迎大家的分享。

PS. Pytorch的团队已经注意到了这个问题,并且将这个issue的处理添加到了TODO中,预计在之后的版本会得到处理。

相关文章

网友评论

    本文标题:Pytorch中torch.save时产生Unknown err

    本文链接:https://www.haomeiwen.com/subject/xskcwftx.html