深入理解torch.clamp()函数的用法和原理

torch.clamp()函数用于对输入张量进行截断操作,将张量中的每个元素限制在指定的范围内。

其语法为:

torch.clamp(input, min, max, out=None) -> Tensor

其中,参数的含义如下:

  • input:输入张量。
  • min:张量中的最小值。如果为None,则表示不对最小值进行限制。
  • max:张量中的最大值。如果为None,则表示不对最大值进行限制。
  • out:输出张量。
  • torch.clamp()函数返回一个新的张量,其中每个元素都被截断在[min, max]的范围内。如果minmaxNone,则对应的限制条件被忽略。

    下面是一个使用torch.clamp()函数的示例:

    import torch
    
    x = torch.randn(2, 3)
    print(x)
    y = torch.clamp(x, min=-0.5, max=0.5)
    print(y)
    

    输出结果为:

    tensor([[-0.3138, -0.1604, -0.4374],
            [-1.0861, -0.2837,  1.1688]])
    tensor([[-0.3138, -0.1604, -0.4374],
            [-0.5000, -0.2837,  0.5000]])
    

    可以看到,torch.clamp()函数将x张量中的元素限制在了[-0.5, 0.5]的范围内。

    物联沃分享整理
    物联沃-IOTWORD物联网 » 深入理解torch.clamp()函数的用法和原理

    发表评论