您现在的位置是:首页 >技术交流 >解读nerf_pytorch中的get_rays和get_rays_np函数网站首页技术交流
解读nerf_pytorch中的get_rays和get_rays_np函数
source code from yenchenlin:
https://github.com/yenchenlin/nerf-pytorch
作者对于numpy的各种操作出神入化,其精炼程度令人叹为观止。本文总结其中两个函数的物理模型意义与(尤其是)矩阵计算意义,作为学习记录。
物理模型意义
详见:
https://zhuanlan.zhihu.com/p/593204605/
中的《3D空间射线怎么构造》。
矩阵计算意义(重点)
比较get_rays和get_rays_np可以发现,前者是在pytorch中、后者实在numpy中的同一操作(所以后者函数名以“np”结尾)。因此我们选择其中一个进行研究即可(get_rays):
def get_rays(H, W, K, c2w):
i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t()
j = j.t()
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,-1].expand(rays_d.shape)
return rays_o, rays_d
接下来进行我学习花费良久的逐行解释——
输入参数
调用该函数的函数有相关注解:
H: int. Height of image in pixels.
W: int. Width of image in pixels.
c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
K则是一个(3x3)矩阵。
第一行
i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))
作者给了一句尾注:
pytorch’s meshgrid has indexing=‘ij’
torch.linspace(0,W-1,W)
的意思,从0到W-1取一共W个点,弄成一个行向量。同理,
torch.linspace(0,H-1,H)
从0到W-1取一共W个点,弄成一个行向量。然后把两个放入torch的meshgrid,就可以得到一个以第一个参数为列而重复的矩阵,以及一个以第二个参数为行而重复的矩阵。注意,这一点和numpy的meshgrid是恰恰相反的(无语)。所以这就解释了第二行和第三行(numpy的计算相对更加符合思考的惯用形式,相对):
i = i.t()
j = j.t()
另外,和numpy的meshgrid相比,torch的meshgrid自带默认的类似前者的“indexing=‘xy’”的功能。
第四行
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
这里,K[0][2])/K[0][0]
只是一个标量,而i是一个(H,W)的矩阵,那这样就意味着广播的介入,因此i-K[0][2])/K[0][0]
就是一个(H,W)的矩阵:它意思是每个像素点的横坐标都根据https://zhuanlan.zhihu.com/p/593204605/
中的《3D空间射线怎么构造》的公式计算好了。显然,每个像素点的纵坐标也同样通过一个(H,W)的矩阵 -(j-K[1][2])/K[1][1]
得到了。类似地,z坐标的情况则是 -torch.ones_like(i)
。
好,那么torch.stack在这里是要做什么呢?观察其axis参数为-1
。我参考了https://blog.csdn.net/weixin_44201525/article/details/109769214的讲法,特别是:
axis为0,表示它堆叠方向为第0维,堆叠的内容为数组第0维的数据。前面说了第0维是相对于堆叠的数组而言的,而这里数组的第0维其实就是整个3×4的数组(其中第1维为行,第2维为某一行中的一个值,这里有一个层层深入的感觉),所以就是以整个3×4的数组为堆叠内容在第0维上进行堆叠,等到的结果就是一个3×3×4的新数组。再通俗一点,就是将a,b,c分别作为堆叠内容进行堆叠得到3×3×4的输出。
以及
和刚才的解释一样,axis为1表示堆叠的方向为3×4数组的第1维(行),堆叠内容也为3×4数组的第1维的数据。而3×4的数组的第1维就是它的行,以数组a为例,它的堆叠数据分别是[0 1 2 3],[ 4 5 6 7],[ 8 9 10 11]。
意思就是说,根据层层深入的思想,axis=-1,也就是最后一个维度,那么就可以理解为,你通过这个维度之前的一层层维度,深入到了这个维度,然后开始堆叠。
所以我们看,这里dirs
就是一个像素点一个像素点地“堆叠”,其中每个像素点的信息就是它的xyz坐标。
第五、六行
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
这里确实十分令人头疼!
如前所述,dirs的维度在上一行应该是(W,H,3),其中最后的3“遍历”每个点的x、y、z。现在,
dirs[...,np.newaxis,:]
可以得到一个(W, H, 1, 3)的矩阵。那么它和c2w[:3,:3]
的关系是啥?阅读https://blog.csdn.net/qq_51352578/article/details/125074264 学习numpy的广播机制,可以知道,它这里插入一个新的1维度,可以让逐点乘法*
得以完成。但是这也不是乱加的axis。我们要问,这个操作的物理意义是啥?
事实上,答案和上一行的解读类似。就是说,你插入了一个newaxis,那么广播的时候你就自己复制了之前维度的东西。在咱的场景里,这个之前的维度不是别的,正是一个个像素点!事实上,c2w[:3,:3] 即3列分别表达关于x轴、y轴、z轴的信息
(参见 c2w矩阵的值直接描述了相机坐标系的朝向和原点 )。这里的*
运算可理解为:
> (。。。) 点 点 点 * c2w(3,3)
> 口 口 口
> 口 口 口
> 口 口 口
然后sum就是按列求和(其中同一个点被案列复制了三遍,这就是加了个newaxis的效果!,有转置的特性)。这也符合作者注释里面说的:
dot product, equals to: [c2w.dot(dir) for dir in dirs]
即每个dir是锁定了横坐标的点坐标数据,然后被c2w左乘。
第七、八行
Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,-1].expand(rays_d.shape)
参看expand
也就是指定维度的一种广播
第九行
return rays_o, rays_d
不难知道此时返回的两个都是维度为(H,W,3)