NeRF源码解析——PyTorch3D版
文章目录
项目目录
train_nerf.py
- 构建模型
# Initialize the Radiance Field model (coarse + fine NeRF renderer).
# Ray-sampling options come from `cfg.raysampler`, network options from
# `cfg.implicit_function`.
raysampler_cfg = cfg.raysampler
implicit_cfg = cfg.implicit_function
model = RadianceFieldRenderer(
    image_size=cfg.data.image_size,
    # Points sampled per ray; the fine pass reuses the coarse setting.
    n_pts_per_ray=raysampler_cfg.n_pts_per_ray,
    n_pts_per_ray_fine=raysampler_cfg.n_pts_per_ray,
    # Rays sampled per training image and the near/far depth bounds.
    n_rays_per_image=raysampler_cfg.n_rays_per_image,
    min_depth=raysampler_cfg.min_depth,
    max_depth=raysampler_cfg.max_depth,
    # Whether ray-point depths are stratified at train / test time.
    stratified=raysampler_cfg.stratified,
    stratified_test=raysampler_cfg.stratified_test,
    # Rays per chunk when rendering full images in evaluation mode.
    chunk_size_test=raysampler_cfg.chunk_size_test,
    # Positional-encoding frequencies for point coordinates / ray directions.
    n_harmonic_functions_xyz=implicit_cfg.n_harmonic_functions_xyz,
    n_harmonic_functions_dir=implicit_cfg.n_harmonic_functions_dir,
    # Hidden widths of the xyz MLP and of the direction (color) branch.
    n_hidden_neurons_xyz=implicit_cfg.n_hidden_neurons_xyz,
    n_hidden_neurons_dir=implicit_cfg.n_hidden_neurons_dir,
    n_layers_xyz=implicit_cfg.n_layers_xyz,
    # Std of the noise added to the raw densities (training only).
    density_noise_std=implicit_cfg.density_noise_std,
    visualization=cfg.visualization.visdom,
)
- resume
# Defaults for a fresh (non-resumed) run. They are overwritten below when a
# checkpoint is resumed; without them the later `optimizer_state_dict is not
# None`, `stats is None` and `last_epoch=start_epoch - 1` uses raise NameError.
stats = None
optimizer_state_dict = None
start_epoch = 0

# Resolve the checkpoint path against the directory the script was launched
# from (`hydra.utils.get_original_cwd()`).
checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
if len(cfg.checkpoint_path) > 0:
    # Make the root of the experiment directory.
    checkpoint_dir = os.path.split(checkpoint_path)[0]
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Resume training if requested.
    if cfg.resume and os.path.isfile(checkpoint_path):
        print(f"Resuming from checkpoint {checkpoint_path}.")
        # NOTE(security): torch.load and pickle.loads can execute arbitrary
        # code while deserializing; only resume from checkpoints you trust.
        loaded_data = torch.load(checkpoint_path)
        # Restore the model weights, the training statistics and the
        # optimizer state saved in the checkpoint.
        model.load_state_dict(loaded_data["model"])
        stats = pickle.loads(loaded_data["stats"])
        print(f" => resuming from epoch {stats.epoch}.")
        optimizer_state_dict = loaded_data["optimizer"]
        start_epoch = stats.epoch
- 初始化optimizer
# Adam over every renderer parameter; the base learning rate is taken from
# the config (the LR scheduler later rescales it per epoch).
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.optimizer.lr)
- 加载optimizer state dict
# Restore the optimizer state when resuming from a checkpoint.
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
# NOTE(review): torch.optim optimizers are not known to read `last_epoch`;
# the resumed epoch reaches the LR schedule through `last_epoch=start_epoch - 1`
# passed to LambdaLR below — confirm whether this assignment is needed.
optimizer.last_epoch = start_epoch
- 初始化stats 对象
# Start a fresh statistics tracker unless one was restored from a checkpoint.
if stats is None:
    tracked_metrics = [
        "loss",
        "mse_coarse",
        "mse_fine",
        "psnr_coarse",
        "psnr_fine",
        "sec/it",
    ]
    stats = Stats(tracked_metrics)
- 定义学习率
# Exponential learning-rate decay:
#   current_lr = base_lr * gamma ** (epoch / step_size)
def lr_lambda(epoch):
    """Return the factor that LambdaLR multiplies the base learning rate by at `epoch`."""
    gamma = cfg.optimizer.lr_scheduler_gamma
    step_size = cfg.optimizer.lr_scheduler_step_size
    return gamma ** (epoch / step_size)
- learning rate scheduling
# Scale the base LR by lr_lambda(epoch). `last_epoch` is derived from
# `start_epoch` so that a resumed run does not restart the schedule from 0.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lr_lambda,
    last_epoch=start_epoch - 1,
    verbose=False,
)
- 加载训练和验证数据
# Load the training/validation data; the third (test) split is discarded.
train_dataset, val_dataset, _ = get_nerf_datasets(
    dataset_name=cfg.data.dataset_name, image_size=cfg.data.image_size
)
- 设定训练验证dataloader
# Both loaders yield one sample at a time, in the main process, collated with
# `trivial_collate`.
loader_kwargs = {"batch_size": 1, "num_workers": 0, "collate_fn": trivial_collate}
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, **loader_kwargs
)
# The validation dataloader is just an endless stream of random samples:
# `max_epochs` draws with replacement from the validation set.
val_sampler = torch.utils.data.RandomSampler(
    val_dataset, replacement=True, num_samples=cfg.optimizer.max_epochs
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, sampler=val_sampler, **loader_kwargs
)
- 模型训练。
nerf/datasets.py
get_nerf_datasets()
def get_nerf_datasets(
    dataset_name: str,  # name of the scene, e.g. 'lego' or 'fern'
    image_size: Tuple[int, int],  # target image size (height, width)
    data_root: str = DEFAULT_DATA_ROOT,  # local directory holding the data files
    autodownload: bool = True,  # download missing data files automatically
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Load the train / validation / test datasets of a NeRF scene.

    Args:
        dataset_name: Name of the scene; ``<dataset_name>.pth`` (cameras and
            splits) and ``<dataset_name>.png`` (images) are read from
            ``data_root``.
        image_size: Target ``(height, width)``. Images are bilinearly resized
            to it when needed; height and width must scale by the same factor.
        data_root: Local directory containing the dataset files.
        autodownload: If True, call ``download_data`` when either file is
            missing.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``. Every item is a dict
        with keys ``"image"``, ``"camera"`` and ``"camera_idx"``.

    Raises:
        ValueError: If ``image_size`` would scale height and width by
            different factors.
    """
    cameras_path = os.path.join(data_root, dataset_name + ".pth")
    image_path = cameras_path.replace(".pth", ".png")

    if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
        # Automatically download the data files if missing.
        download_data((dataset_name,), data_root=data_root)

    train_data = torch.load(cameras_path)
    n_cameras = train_data["cameras"]["R"].shape[0]

    # The dataset image is very large, so temporarily lift PIL's
    # MAX_IMAGE_PIXELS limit; restore it even if reading the image fails.
    _image_max_image_pixels = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = None
    try:
        with Image.open(image_path) as pil_image:
            images = torch.FloatTensor(np.array(pil_image)) / 255.0
    finally:
        Image.MAX_IMAGE_PIXELS = _image_max_image_pixels
    # All views are stored stacked along the height axis of one image: split
    # it into one image per camera and keep only RGB -> (n_cameras, H, W, 3).
    images = torch.stack(torch.chunk(images, n_cameras, dim=0))[..., :3]

    # Scale factor of each of (height, width). Resizing is only allowed when
    # both are scaled by (almost) the same factor, otherwise the images would
    # be distorted.
    scale_factors = [s_new / s for s, s_new in zip(images.shape[1:3], image_size)]
    if abs(scale_factors[0] - scale_factors[1]) > 1e-3:
        raise ValueError(
            "Non-isotropic scaling is not allowed. Consider changing the 'image_size' argument."
        )
    scale_factor = sum(scale_factors) * 0.5

    if scale_factor != 1.0:
        print(f"Rescaling dataset (factor={scale_factor})")
        # interpolate expects channels-first (N, C, H, W): convert there and back.
        images = torch.nn.functional.interpolate(
            images.permute(0, 3, 1, 2),
            size=tuple(image_size),
            mode="bilinear",
        ).permute(0, 2, 3, 1)

    # NOTE(review): `cameras`, `train_idx`, `val_idx` and `test_idx` are used
    # below but not defined in this excerpt; presumably they are built from
    # `train_data` in the full source — confirm before running.
    train_dataset, val_dataset, test_dataset = [
        ListDataset(
            [
                {"image": images[i], "camera": cameras[i], "camera_idx": int(i)}
                for i in idx
            ]
        )
        for idx in [train_idx, val_idx, test_idx]
    ]

    return train_dataset, val_dataset, test_dataset
nerf/nerf_render.py
该文件实现了 RadianceFieldRenderer 类,继承自 torch.nn.Module。
rendering 的前向传播过程如下:
init()构造函数
参数:
image_size: Tuple[int, int], # height x width
n_pts_per_ray: int, # number of points sampled along each ray in the coarse rendering pass
n_pts_per_ray_fine: int, # number of points sampled along each ray in the fine rendering pass
n_rays_per_image: int, # number of rays sampled per image during training
min_depth: float, # minimum depth of the sampled ray points for the coarse pass
max_depth: float, # maximum depth of the sampled ray points for the coarse pass
stratified: bool, # if True, stratify the depths of each ray's points during training
stratified_test: bool, # if True, stratify the depths of each ray's points during testing
chunk_size_test: int, # number of rays per chunk when rendering full images; it only determines the chunk count in evaluation mode (see forward()), not in training
n_harmonic_functions_xyz: int = 6, # number of harmonic functions (the L of the positional encoding) for 3D point coordinates
n_harmonic_functions_dir: int = 4, # number of harmonic functions (the L of the positional encoding) for ray directions
n_hidden_neurons_xyz: int = 256, # hidden units of the fully-connected layers of the MLP that takes 3D point positions and outputs intermediate features for the occupancy field
n_hidden_neurons_dir: int = 128, # hidden units of the layer that takes the intermediate features and the ray direction and outputs the radiance (per-point color)
n_layers_xyz: int = 8, # number of layers of the MLP that outputs the occupancy (density) field
append_xyz: Tuple[int] = (5,), # indices of the skip layers of the occupancy MLP; the MLP input is concatenated to the input of each of these layers
density_noise_std: float = 0.0, # standard deviation of the random normal noise added to the occupancy MLP output; only active when self.training is True
visualization: bool = False, # whether to store extra outputs for visualization
属性:
self._renderer = torch.nn.ModuleDict() # renderers, keyed by pass ("coarse" / "fine")
self._implicit_function = torch.nn.ModuleDict() # radiance-field networks, keyed by pass ("coarse" / "fine")
self._density_noise_std = density_noise_std # applied only in training mode
self._chunk_size_test = chunk_size_test # rays per chunk for full-image rendering
self._image_size = image_size # (height, width) of full renders
self.visualization = visualization # store extra outputs for visualization
定义raymarcher:
# Init the EA raymarcher used by both passes.
raymarcher = EmissionAbsorptionNeRFRaymarcher() # defined in raymarcher.py; produces the rendered features and the per-point weights
获取图像尺寸
# Parse out image dimensions: image_size is (height, width).
image_height, image_width = image_size
**确定raysampler方式:**
# Build the raysampler for each rendering pass (the loop body continues with
# the renderer and network instantiation).
for render_pass in ("coarse", "fine"):
if render_pass == "coarse":
# Initialize the coarse raysampler.
# It samples n_pts_per_ray points per ray between min_depth and max_depth.
raysampler = NeRFRaysampler( # defined in raysampler.py
n_pts_per_ray=n_pts_per_ray,
min_depth=min_depth,
max_depth=max_depth,
stratified=stratified,
stratified_test=stratified_test,
n_rays_per_image=n_rays_per_image,
image_height=image_height,
image_width=image_width,
)
elif render_pass == "fine":
# Initialize the fine raysampler.
# NOTE(review): it takes no depth bounds; presumably it resamples along the
# coarse rays using the coarse weights (ray_weights) — confirm in raysampler.py.
raysampler = ProbabilisticRaysampler(# defined in raysampler.py
n_pts_per_ray=n_pts_per_ray_fine,
stratified=stratified,
stratified_test=stratified_test,
)
初始化render
# Initialize the fine/coarse renderer: the pass-specific raysampler combined
# with the shared raymarcher.
self._renderer[render_pass] = ImplicitRenderer( # render_pass is "coarse" or "fine"
raysampler=raysampler,
raymarcher=raymarcher,
)
实例化网络
# Instantiate the fine/coarse NeuralRadianceField module (one network per pass).
self._implicit_function[render_pass] = NeuralRadianceField( # defined in implicit_function.py
n_harmonic_functions_xyz=n_harmonic_functions_xyz,
n_harmonic_functions_dir=n_harmonic_functions_dir,
n_hidden_neurons_xyz=n_hidden_neurons_xyz,
n_hidden_neurons_dir=n_hidden_neurons_dir,
n_layers_xyz=n_layers_xyz,
append_xyz=append_xyz,
)
precache_rays()
参数:
cache_cameras: List[CamerasBase], # list of n cameras whose rays are pre-cached
cache_camera_hashes: List[str], # list of n unique identifiers, one per camera
# Only the coarse raysampler pre-caches rays (excerpt; the call continues in the source).
self._renderer["coarse"].raysampler.precache_rays(
cache_cameras,
cache_camera_hashes,
_process_ray_chunk()
参数:
camera_hash: Optional[str], # pre-cached camera 的唯一标识符。
camera: CamerasBase, # 一批待渲染场景的相机(cameras)。
image: torch.Tensor, # ground truth 图像,shape 为 (batch_size, ·, ·, 3)。
chunk_idx: int, # 当前射线块(chunk)的索引。
返回: dict
out: `dict` containing the outputs of the rendering:
`rgb_coarse`: The result of the coarse rendering pass.
`rgb_fine`: The result of the fine rendering pass.
`rgb_gt`: The corresponding ground-truth RGB values.
通过self._renderer来获取 渲染后的rgb,权重等:
# The coarse pass has no earlier pass to condition on, so its input ray
# bundle and weights start as None; without this the first iteration reads
# undefined names.
coarse_ray_bundle = None
coarse_weights = None

# First evaluate the coarse rendering pass, then the fine one.
for renderer_pass in ("coarse", "fine"):
    # ray_bundle_out is the ray bundle this pass rendered; the coarse one is
    # kept below so the fine pass can importance-sample along the same rays.
    (rgb, weights), ray_bundle_out = self._renderer[renderer_pass](
        cameras=camera,
        volumetric_function=self._implicit_function[renderer_pass],
        chunksize=self._chunk_size_test,
        chunk_idx=chunk_idx,
        # Density noise is only applied in training mode.
        density_noise_std=(self._density_noise_std if self.training else 0.0),
        input_ray_bundle=coarse_ray_bundle,
        ray_weights=coarse_weights,
        camera_hash=camera_hash,
    )

    if renderer_pass == "coarse":
        rgb_coarse = rgb
        # Store the weights and the rays of the first rendering pass
        # for the ensuing importance ray-sampling of the fine render.
        coarse_ray_bundle = ray_bundle_out
        coarse_weights = weights
        if image is not None:
            # Sample the ground truth images at the xy locations of the
            # rendering ray pixels (helper provided by utils.py).
            rgb_gt = sample_images_at_mc_locs(
                image[..., :3][None],
                ray_bundle_out.xys,
            )
        else:
            rgb_gt = None

    elif renderer_pass == "fine":
        rgb_fine = rgb

    else:
        raise ValueError(f"No such rendering pass {renderer_pass}")
forward()
根据输入 camera 的视角,执行 coarse 和 fine 的渲染过程。
渲染结果会和 ground truth 做比较,
计算峰值信噪比(PSNR)和均方误差(MSE)。
参数
camera_hash: Optional[str], # unique identifier of a pre-cached camera (may be None)
camera: CamerasBase, # batch of cameras from whose viewpoints the scene is rendered
image: torch.Tensor, # ground truth images
返回:Tuple[dict, dict]
out:
dict containing the outputs of the rendering:
rgb_coarse: The result of the coarse rendering pass.
rgb_fine: The result of the fine rendering pass.
rgb_gt: The corresponding ground-truth RGB values.
rgb 的 shape 取决于是否处于 training 模式。
在 training 模式下,三个 rgb tensor 的 shape 都是 (batch_size, n_rays_per_image, 3);
非 training 模式下,三个 rgb tensor 的 shape 都是 (batch_size, image_size[0], image_size[1], 3)。
metrics: dict
包含fine、coarse 和真实值对比的 误差metrics。
mse_coarse : coarse 和 真实值的MSE
mse_fine: fine 和 真实值的MSE
psnr_coarse: coarse 和 真实值的峰值信噪比
psnr_fine: fine 和 真实值的峰值信噪比
如果是测试的话,获取chunks的数量
# Evaluation renders the full image in chunks of `_chunk_size_test` rays;
# training renders a single Monte-Carlo batch of random rays.
n_chunks = (
    1
    if self.training
    else self._renderer["coarse"].raysampler.get_n_chunks(
        self._chunk_size_test,
        camera.R.shape[0],
    )
)
计算每个chunk的输出:
# Process the chunks of rays, keeping the outputs in chunk order.
chunk_outputs = []
for chunk_idx in range(n_chunks):
    chunk_outputs.append(
        self._process_ray_chunk(camera_hash, camera, image, chunk_idx)
    )
对于完整的渲染通道,连接输出块,并重塑为图像大小。
if self.training:
    # Training renders a single chunk of random rays: use its outputs as is.
    out = chunk_outputs[0]
else:
    # For a full render pass concatenate the output chunks along dim=1 and
    # reshape to (batch, height, width, 3). Entries that are None (rgb_gt
    # when no ground-truth image was given) stay None.
    out = {}
    for key in ("rgb_fine", "rgb_coarse", "rgb_gt"):
        if chunk_outputs[0][key] is None:
            out[key] = None
        else:
            out[key] = torch.cat(
                [chunk_out[key] for chunk_out in chunk_outputs],
                dim=1,
            ).view(-1, *self._image_size, 3)
计算误差指标(metrics)
# Calc the error metrics: mse_coarse, psnr_coarse, mse_fine, psnr_fine.
# They are only computed when a ground-truth image is given; otherwise
# `metrics` stays empty.
metrics = {}
if image is not None:
for render_pass in ("coarse", "fine"):
for metric_name, metric_fun in zip(
("mse", "psnr"), (calc_mse, calc_psnr)
):
# Compare only the RGB channels of the render and the ground truth.
metrics[f"{metric_name}_{render_pass}"] = metric_fun(
out["rgb_" + render_pass][..., :3],
out["rgb_gt"][..., :3],
)
return out, metrics
visualize_nerf_outputs()
可视化RadianceFieldRenderer 的输出
参数
nerf_out: dict, # outputs of the rendering
output_cache: List, # list of outputs of several training render passes
viz: Visdom, # a Visdom connection object
visdom_env: str # name of the visdom environment used for visualization
展示训练的图片
# Show the training images: stack the cached "image" entries and lay them
# out side by side along dim=1 of each image (the width axis, assuming
# (H, W, 3) images — TODO confirm), then convert to channels-first.
stacked_images = torch.stack([entry["image"] for entry in output_cache])
image_strip = torch.cat(list(stacked_images), dim=1)
viz.image(
    image_strip.permute(2, 0, 1),
    env=visdom_env,
    win="images",
    opts={"title": "train_images"},
)
将coarse、fine 和真实值一起展示
# Show the coarse and fine renders together with the ground truth images,
# concatenated horizontally as "coarse | fine | target".
panels = []
for imvar in ("rgb_coarse", "rgb_fine", "rgb_gt"):
    # First element of the batch, channels-first, on CPU, clamped to [0, 1].
    first_image = nerf_out[imvar][0]
    panels.append(first_image.permute(2, 0, 1).detach().cpu().clamp(0.0, 1.0))
ims_full = torch.cat(panels, dim=2)
viz.image(
    ims_full,
    env=visdom_env,
    win="images_full",
    opts={"title": "coarse | fine | target"},
)
制作训练摄像机及其发射光线的 3D 图。
# Make a 3D plot of the training cameras and of the points along their
# coarse rays.
camera_trace = {}
for ci, cached in enumerate(output_cache):
    camera_trace[f"camera_{ci:03d}"] = cached["camera"].cpu()

ray_pts_trace = {}
for ci, cached in enumerate(output_cache):
    ray_points = ray_bundle_to_ray_points(cached["coarse_ray_bundle"]).detach().cpu()
    # Flatten all ray points of this pass into a single point cloud.
    ray_pts_trace[f"ray_pts_{ci:03d}"] = Pointclouds(ray_points.view(1, -1, 3))

scene = {"training_scene": {**camera_trace, **ray_pts_trace}}
plotly_plot = plot_scene(
    scene,
    pointcloud_max_points=5000,
    pointcloud_marker_size=1,
    camera_scale=0.3,
)
viz.plotlyplot(plotly_plot, env=visdom_env, win="scenes")
implicit_function.py —— class NeuralRadianceField()
**参数**
n_harmonic_functions_xyz: int = 6, # number of harmonic functions used to embed 3D point positions
n_harmonic_functions_dir: int = 4, # number of harmonic functions used to embed ray directions
n_hidden_neurons_xyz: int = 256, # hidden units of the position MLP
n_hidden_neurons_dir: int = 128, # hidden units of the direction (color) branch
n_layers_xyz: int = 8, # number of layers of the position MLP
append_xyz: Tuple[int] = (5,), # indices of the skip layers of the occupancy MLP
use_multiple_streams: bool = True, # whether densities and colors should be computed on separate CUDA streams
**kwargs,
构造函数
# Harmonic (positional) encoders for 3D points and for ray directions.
self.harmonic_embedding_xyz = HarmonicEmbedding(n_harmonic_functions_xyz)
self.harmonic_embedding_dir = HarmonicEmbedding(n_harmonic_functions_dir)
# Embedding width: 2 features per harmonic function (presumably sin and cos)
# for each of the 3 coordinates, plus the 3 raw input coordinates.
embedding_dim_xyz = n_harmonic_functions_xyz * 2 * 3 + 3
embedding_dim_dir = n_harmonic_functions_dir * 2 * 3 + 3
# Position MLP. Positional arguments: n_layers, input_dim, output_dim,
# skip_dim, hidden_dim. The xyz embedding is both the input and the skip
# tensor appended at the layers listed in append_xyz.
self.mlp_xyz = MLPWithInputSkips( # multilayer perceptron
n_layers_xyz,
embedding_dim_xyz,
n_hidden_neurons_xyz,
embedding_dim_xyz,
n_hidden_neurons_xyz,
input_skips=append_xyz,
)
self.intermediate_linear = torch.nn.Linear( # intermediate layer feeding the color branch
n_hidden_neurons_xyz, n_hidden_neurons_xyz
)
_xavier_init(self.intermediate_linear) # Xavier weight initialization of the linear layer
# Density head: one raw density value per point.
self.density_layer = torch.nn.Linear(n_hidden_neurons_xyz, 1)
_xavier_init(self.density_layer)
# Zero the bias of the density layer to avoid
# a completely transparent initialization.
self.density_layer.bias.data[:] = 0.0 # fixme: Sometimes this is not enough
# Color head: (intermediate features, direction embedding) -> hidden units
# -> RGB squashed into (0, 1) by the Sigmoid.
self.color_layer = torch.nn.Sequential(
LinearWithRepeat(
n_hidden_neurons_xyz + embedding_dim_dir, n_hidden_neurons_dir
),
torch.nn.ReLU(True),
torch.nn.Linear(n_hidden_neurons_dir, 3),
torch.nn.Sigmoid(),
)
# Whether _get_densities_and_colors may use a separate CUDA stream.
self.use_multiple_streams = use_multiple_streams
_get_densities()
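此函数接收 self.mlp_xyz 预测的 features,并使用 self.density_layer 将它们转换为 raw_densities。
随后 raw_densities 会按深度步长 $\delta$(相邻采样点之间的深度差)重新加权,并映射到 [0, 1] 范围:
$\sigma = 1 - \exp\left(-\delta \cdot \mathrm{ReLU}(\sigma_{\text{raw}})\right)$,
即 1 减去(按步长加权后的)raw_densities 的负指数。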
参数
features: torch.Tensor, # output features of the MLP (self.mlp_xyz)
depth_values: torch.Tensor, # depths (ray_bundle.lengths) of the sampled points along each ray
density_noise_std: float, # std of the Gaussian noise added to the raw densities (0 disables it)
返回 : torch.tensor
内容:
# Raw (unbounded) densities predicted from the MLP features.
raw_densities = self.density_layer(features) # one value per point
# deltas: distance between consecutive sample depths along each ray. The
# last sample gets 1e10, i.e. it is treated as (effectively) infinitely long.
# The trailing [..., None] adds a size-1 last dim matching raw_densities.
deltas = torch.cat(
(
depth_values[..., 1:] - depth_values[..., :-1],
1e10 * torch.ones_like(depth_values[..., :1]),
),
dim=-1,
)[..., None]
# Optional Gaussian noise on the raw densities (a regularizer).
if density_noise_std > 0.0:
raw_densities = (
raw_densities + torch.randn_like(raw_densities) * density_noise_std
)
# Final density (opacity of each ray segment): 1 - exp(-delta * relu(raw)),
# which lies in [0, 1] assuming depth_values increase along each ray.
densities = 1 - (-deltas * torch.relu(raw_densities)).exp()
return densities
_get_colors()
预测每个点的rgb颜色
参数
def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor) -> torch.Tensor:
    """
    Predict an RGB color for every ray point.

    The ray directions are normalized to unit L2 norm and harmonically
    embedded. The embedding is then passed, together with the features
    projected by `intermediate_linear`, to the color head.
    """
    unit_directions = torch.nn.functional.normalize(rays_directions, dim=-1)
    direction_embedding = self.harmonic_embedding_dir(unit_directions)
    projected_features = self.intermediate_linear(features)
    return self.color_layer((projected_features, direction_embedding))
_get_densities_and_colors()
对于一个batch而言,做整体的计算
参数
features: torch.Tensor, # output features of the position MLP
ray_bundle: RayBundle, # sampled rays; its lengths and directions are used
density_noise_std: float # std of the noise added to the raw densities
) -> Tuple[torch.Tensor, torch.Tensor]:
两个输出:一个是 rays_densities,shape 为 (minibatch, ..., num_points_per_ray, 1),用于表示每个射线点的不透明度;
另一个是 rays_colors,shape 为 (minibatch, ..., num_points_per_ray, 3),用于表示每个射线点的颜色。
# With CUDA features, part of the work is issued on a separate CUDA stream
# (other_stream). That stream first waits for work already queued on the
# current stream, and the current stream waits for other_stream before the
# results are returned.
# NOTE(review): indentation is lost in this excerpt, so it is not visible
# which statements sit inside the `with` block; presumably densities run on
# other_stream while colors run on the current stream — confirm.
if self.use_multiple_streams and features.is_cuda:
current_stream = torch.cuda.current_stream(features.device)
other_stream = torch.cuda.Stream(features.device)
other_stream.wait_stream(current_stream)
with torch.cuda.stream(other_stream):
rays_densities = self._get_densities(
features, ray_bundle.lengths, density_noise_std
)
# rays_densities.shape = [minibatch x ... x 1] in [0-1]
rays_colors = self._get_colors(features, ray_bundle.directions)
# rays_colors.shape = [minibatch x ... x 3] in [0-1]
current_stream.wait_stream(other_stream)
else:
# Same calculation as above, just serial.
rays_densities = self._get_densities(
features, ray_bundle.lengths, density_noise_std
)
rays_colors = self._get_colors(features, ray_bundle.directions)
return rays_densities, rays_colors
forward()
参数
ray_bundle: RayBundle, # origins, directions and lengths of the sampled rays
density_noise_std: float = 0.0, # standard deviation (not variance) of the random normal noise added to the opacity output; intended to prevent floating artifacts
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
ray_bundle 包含了光线的原点、方向和采样长度等信息:
- origins:(minibatch, ..., 3)
- directions:(minibatch, ..., 3)
- lengths:(minibatch, ..., num_points_per_ray),即光线被采样的长度
返回:
- rays_densities:(minibatch, ..., num_points_per_ray, 1),denoting the opacity of each ray point.
- rays_colors:(minibatch, ..., num_points_per_ray, 3),denoting the color of each ray point.
转换参数为世界坐标系
# Pipeline: ray bundle -> world-space points -> harmonic embedding -> position
# MLP features -> (densities, colors).
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x 3]
# For each 3D world coordinate, we obtain its harmonic embedding.
embeds_xyz = self.harmonic_embedding_xyz(rays_points_world)
# embeds_xyz.shape = [minibatch x ... x self.n_harmonic_functions*6 + 3]
# self.mlp maps each harmonic embedding to a latent feature space.
# The embedding is passed twice: as the MLP input and as the skip tensor.
features = self.mlp_xyz(embeds_xyz, embeds_xyz)
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
rays_densities, rays_colors = self._get_densities_and_colors(
features, ray_bundle, density_noise_std
)
return rays_densities, rays_colors
implicit_function.py —— class MLPWithInputSkips()
实现NeRF的多层感知机
构造函数
def __init__(
    self,
    n_layers: int,
    input_dim: int,
    output_dim: int,
    skip_dim: int,
    hidden_dim: int,
    input_skips: Tuple[int] = (),
):
    """
    Args:
        n_layers: The number of linear layers of the MLP.
        input_dim: The number of channels of the input tensor.
        output_dim: The number of channels of the output.
        skip_dim: The number of channels of the tensor `z` appended when
            evaluating the skip layers.
        hidden_dim: The number of hidden units of the MLP.
        input_skips: The list of layer indices at which we append the skip
            tensor `z`.

    Every layer is a ``Linear`` followed by a ``ReLU``. Layer 0 reads
    ``input_dim`` channels and every later layer reads ``hidden_dim``. A layer
    listed in ``input_skips`` additionally reads the ``skip_dim`` channels of
    ``z``, which ``forward`` concatenates to its input (this includes layer 0).
    Every layer outputs ``hidden_dim`` channels except the last one, which
    outputs ``output_dim``.
    """
    super().__init__()
    layers = []
    for layeri in range(n_layers):
        # Input width: the raw input for the first layer, hidden features
        # afterwards, plus the skip tensor wherever `forward` concatenates it.
        dimin = input_dim if layeri == 0 else hidden_dim
        if layeri in input_skips:
            dimin += skip_dim
        # Output width: the final layer honors `output_dim`; all others keep
        # the hidden width.
        dimout = output_dim if layeri == n_layers - 1 else hidden_dim
        linear = torch.nn.Linear(dimin, dimout)
        _xavier_init(linear)
        layers.append(torch.nn.Sequential(linear, torch.nn.ReLU(True)))
    self.mlp = torch.nn.ModuleList(layers)
    self._input_skips = set(input_skips)
forward()
def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: The input tensor of shape `(..., input_dim)`.
        z: The input skip tensor of shape `(..., skip_dim)` which is appended
            to layers whose indices are specified by `input_skips`.
    Returns:
        y: The output tensor of shape `(..., output_dim)`.
    """
    hidden = x
    for layer_idx, layer in enumerate(self.mlp):
        # Skip connection: feed `z` back in next to the current features,
        # concatenated along the channel (last) axis.
        if layer_idx in self._input_skips:
            layer_input = torch.cat((hidden, z), dim=-1)
        else:
            layer_input = hidden
        hidden = layer(layer_input)
    return hidden
其他文件
参看文献
https://github.com/facebookresearch/pytorch3d/tree/main/projects/nerf
来源:_Summer tree