详解pytorch中的常见的Tensor数据类型以及类型转换

文章目录

  • 概览
  • Tensor的构建
  • 补充
  • 类型转换
  • 附录
  • 概览

    本文主要讲pytorch中的常见的Tensor数据类型,例如:float32float64int32int64。构造他们分别使用如下函数:torch.FloatTensor()torch.DoubleTensor(), torch.IntTensor(), torch.LongTensor()

    Tensor的构建

    1.32-bit floating point:

    a=torch.FloatTensor([1.0,3.0])#a=torch.Tensor([1.0,3.0])和前面等价
    print(a.dtype)
    #torch.float32
    

    2.64-bit floating point

    a=torch.DoubleTensor([1,3])
    print(a.dtype)
    #torch.float64
    

    3.32-bit integer (signed)

    a=torch.IntTensor([1,3])
    print(a.dtype)
    #torch.int32
    

    4.64-bit integer (signed)

    a=torch.LongTensor([1,3])
    print(a.dtype)
    #torch.int64
    

    补充

    type(a)
    #torch.Tensor
    

    torch.Tensor作为一个对象,你创建的所有Tensor,不管是什么数据类型,都是torch.Tensor类,其所有元素都只能是单一数据类型。即:

    A torch.Tensor is a multi-dimensional matrix containing elements of a single data type

    即使你给的数据有多种类型,其会自动转换。比如:

    a=torch.LongTensor([1,3.1])
    print(a.dtype)
    #torch.int64
    a
    #tensor([1, 3])
    

    除了用上述构建方法构建torch.Tensor之外,还可以用torch.tensor()来构建,我个人比较喜欢这个,因为其功能更加强大。上面那种torch.xxxTensor()的方式要记好几种,这里有像numpy那样的dtype参数直接指定,可以看作是前面的升级版吧。

    a=torch.tensor([1,3.1],dtype=torch.int32)#指定是哪种数据类型就是哪种。
    print(a.dtype)
    #torch.int32
    a
    tensor([1, 3], dtype=torch.int32)
    

    本节最后,上面4中数据类型一般够用,其他还有torch.int8,torch.uint8,torch.bool。如果不够?可以参考:
    https://pytorch.org/docs/stable/tensors.html

    类型转换

    下面会介绍两种方法:long()type(torch.int64),显然后者比较容易记住,是我比较喜欢的。

    a=torch.tensor([1,2],dtype=torch.int32)
    print(a)
    b=a.long()
    print(b)
    print(b.dtype)
    c=a.type(torch.int64)
    print(c)
    print(c.dtype)
    print(a)
    


    有人会疑惑,为什么第二行输出没有显示dtype=,因为默认就是torch.int64

    附录

    list,numpy,tensor之间相互转换的方法:

    a=[[1,2],[3,4]]#list
    print(a)
    b=np.array(a)#list->numpy
    print(b)
    c=torch.tensor(a)#list->tensor
    print(c)
    print(b.tolist())#numpy->list
    print(c.tolist())#tensor->list
    print(torch.tensor(b))#numpy->tensor
    print(torch.from_numpy(b))#同上
    print(c.numpy())#tensor->numpy
    


    对了,温馨提示,tensor可以在GPU上运行,其他两个都不可以,这就是为什么你用GPU运行的时候有时会报不是tensor的错误,必须先转化为tensor。
    还有,GPU上的tensor不能直接转为numpy,需要先放到CPU上。

    a=a.cpu()#放到CPU
    a.numpy()#这才对。
    

    放回GPU上?

    a=a.cuda()#放到GPU
    

    当然啦,cuda(),cpu()还有一种写法:

    device=torch.device("cpu")#torch.device("cuda")
    a=torch.tensor([1,2])
    a.to(device)#放到CPU
    

    完结撒花


    来源:音程

    物联沃分享整理
    物联沃-IOTWORD物联网 » 详解pytorch中的常见的Tensor数据类型以及类型转换

    发表评论