美文网首页
NeRF源码解析

NeRF源码解析

作者: 小黄不头秃 | 来源:发表于2024-01-23 09:25 被阅读0次

    NeRF的源码结构如代码中的"源码结构.md"所示。

    1、NeRF训练参数介绍

    # 配置文件的路径
        parser.add_argument('--config', is_config_file=True,help='config file path')
        # 本次实验的名称,作为log中文件夹的名字
        parser.add_argument("--expname", type=str,help='experiment name')
        # 输出目录
        parser.add_argument("--basedir", type=str, default='./logs/',help='where to store ckpts and logs')
        # 指定数据集的目录
        parser.add_argument("--datadir", type=str, default='./data/llff/fern',help='input data directory')
        # training options
        # 全连接的层数
        parser.add_argument("--netdepth", type=int, default=8,help='layers in network')
        # 网络宽度
        parser.add_argument("--netwidth", type=int, default=256,help='channels per layer')
        # 精细网络的全连接层数
        # 默认精细网络的深度和宽度与粗糙网络是相同的
        parser.add_argument("--netdepth_fine", type=int, default=8,help='layers in fine network')
        parser.add_argument("--netwidth_fine", type=int, default=256,help='channels per layer in fine network')
        # 这里的batch size,指的是光线的数量,像素点的数量
        # N_rand 配置文件中是1024,光线的数量
        # 32*32*4=4096
        # 800*800/4096=156 400*400/1024=156
        parser.add_argument("--N_rand", type=int, default=32 * 32 * 4,help='batch size (number of random rays per gradient step)')
        # 学习率
        parser.add_argument("--lrate", type=float, default=5e-4,help='learning rate')
        # 学习率衰减
        parser.add_argument("--lrate_decay", type=int, default=250,help='exponential learning rate decay (in 1000 steps)')
        # 如果上述的N_rand > chunk就会分批处理
        parser.add_argument("--chunk", type=int, default=1024 * 32,help='number of rays processed in parallel, decrease if running out of memory')
        # 神经网络中处理的点的数量
        parser.add_argument("--netchunk", type=int, default=1024 * 64,help='number of pts sent through network in parallel, decrease if running out of memory')
        # rendering options 粗网络渲染时的采样点数量
        parser.add_argument("--N_samples", type=int, default=64,help='number of coarse samples per ray')
        # 精细网络采样点数量
        parser.add_argument("--N_importance", type=int, default=0,help='number of additional fine samples per ray')
        # 在采样点附近是否加入随机扰动
        parser.add_argument("--perturb", type=float, default=1.,help='set to 0\. for no jitter, 1\. for jitter')
        # L=10
        parser.add_argument("--multires", type=int, default=10,help='log2 of max freq for positional encoding (3D location)')
        # L=4
        parser.add_argument("--multires_views", type=int, default=4,help='log2 of max freq for positional encoding (2D direction)'
    

    2、数据集加载

    Blender结构的数据集的文件结构长这个样子→ :

    json文件内容如下(相机坐标系转世界坐标系):

    还有llff格式的数据,在该文件夹下,首先是一个images文件夹,并在该数据集的根目录下会有llff的相机位姿文件,里面是一个数组(20, 17),前面15维是位姿信息,后面2维是边界。其中处理过后,将其差分为poses的形状为(20, 3,5),bds(2, 20). 并且将相机坐标系转换为世界坐标系。

    首先先处理数据集,大部分数据集都是blender的数据结构,该代码中提供了四种不同的数据集加载方式。在此函数中,主要包含了读取文件,图像归一化,提取位置信息,划分数据集、验证集、测试集,计算焦距,渲染的位置(这个是为了测试)。详细见代码。

    3、位置编码介绍 pe

    位置编码对应的代码位于“./run_nerf.py”的"create_nerf"中,利用"get_embedder"函数,此函数会返回一个lambda函数和神经网络的输入通道数量。

    接下来就是生成NeRF神经网络了,并且会构建神经网络优化器。神经网络的结构如图所示:

    4、将像素坐标系转换成世界坐标系的方法 get_rays

    这部分主要是获得光线的原点和方向,还有坐标系的转换,然后随机选取N_rand条光线

    然后,这里是将相机坐标系转换为世界坐标系,然后获得每一个光线的原点和方向。

    def get_rays_np(H, W, K, c2w):
        # 与上面的方法相似,这个是使用的numpy,上面是使用的torch
        i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
        dirs = np.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -np.ones_like(i)], -1)
        # Rotate ray directions from camera frame to the world frame
        rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3],-1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
        # Translate camera frame's origin to the world frame. It is the origin of all rays.
        rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
        return rays_o, rays_d 
    
    

    5、主循环介绍

    首选是一个与提渲染相关的"render"函数,这里面主要是对光线进行采样点的选择。也就是r(t) = o + td,采样过后经过神经网络计算,获得不透明度和三个颜色值。

    6、体渲染 raw2outputs

    神经网络的结果经过体渲染公式进行输出。

    7、分层采样 sample_pdf

    上面是粗糙网络的采样,后面就是通过精细网络进行分层采样。

    源码解析视频讲解:NeRF源码解析_哔哩哔哩_bilibili

    git代码仓库:https://github.com/xunull/read-nerf-pytorch

    原始代码:https://github.com/yenchenlin/nerf-pytorch

    官方代码:https://github.com/bmild/nerf

    旋转矩阵: https://blog.csdn.net/csxiaoshui/article/details/65446125

    相机标定: https://blog.csdn.net/Kalenee/article/details/99207102

    相关文章

      网友评论

          本文标题:NeRF源码解析

          本文链接:https://www.haomeiwen.com/subject/zpoyodtx.html