pytorch 中 混合精度训练(真香)

一、什么是混合精度训练

在pytorch的tensor中,默认的类型是float32,神经网络训练过程中,网络权重以及其他参数,默认都是float32,即单精度,为了节省内存,部分操作使用float16,即半精度,训练过程既有float32,又有float16,因此叫混合精度训练。

二、如何进行混合精度训练

pytorch中是自动混合精度训练,使用 torch.cuda.amp.autocasttorch.cuda.amp.GradScaler 这两个模块。
torch.cuda.amp.autocast:在选择的区域中自动进行数据精度之间的转换,即提高了运算效率,又保证了网络的性能。
torch.cuda.amp.GradScaler:来解决数据溢出问题,即数据溢出问题:Overflow / Underflow

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

三、哪些运算操作可以自动转换,哪些不可以

首先:只有 CUDA 操作有资格进行自动转换

下面这些操作可以自动转换为float16:

matmul, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, prelu, RNNCell

下面这些操作可以自动转换为float32:
pow, rdiv, rpow, rtruediv, acos, asin, binary_cross_entropy_with_logits, cosh, cosine_embedding_loss, cdist, cosine_similarity, cross_entropy, cumprod, cumsum, dist, erfinv, exp, expm1, group_norm, hinge_embedding_loss, kl_div, l1_loss, layer_norm, log, log_softmax, log10, log1p, log2, margin_ranking_loss, mse_loss, multilabel_margin_loss, multi_margin_loss, nll_loss, norm, normalize, pdist, poisson_nll_loss, pow, prod, reciprocal, rsqrt, sinh, smooth_l1_loss, soft_margin_loss, softmax, softmin, softplus, sum, renorm, tan, triplet_margin_loss

有些操作并没有指定是float16还是float32,但是需要输入的数据类型一致,如果所有的输入都是float16,操作就是在float16中进行,如果输入中的任何一个是float32,操作就是在float32中进行。

四、遇到不能自动转换的操作怎么办

例如下面这句代码,在自动转换的区域,where操作中,tensor phi是float16,但是cosine是float32,where操作没有自动转换的能力,因此就会出现数据类型匹配,报错!
注意:下面的代码在非混合精度训练中,没有问题,因为所有生成的Tensor数据都是float32类型

phi = torch.where(cosine > self.th, phi, cosine - self.mm)

把上面一句代码改为下面的代码就可以了

phi = torch.where(cosine.to(dtype=phi.dtype) > self.th, phi, cosine.to(dtype=phi.dtype) - self.mm)

来源:仙女修炼史

物联沃分享整理
物联沃-IOTWORD物联网 » pytorch 中 混合精度训练(真香)

发表评论