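These notes cover problems I ran into while working on CS224N Assignment 5. They are mainly about how tensors are defined and how their shapes are changed.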
Defining Tensors
Things to watch out for when you define a tensor yourself:
device
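In most cases your program will run on a GPU, so always pass `device` when you define a tensor! A tensor created without it lives on the CPU by default, and combining it with GPU tensors in an operation typically raises a device-mismatch error.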
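A minimal sketch of device-aware tensor creation. The `device` variable below is an assumption (it falls back to the CPU when no GPU is available), not code taken from the assignment:

```python
import torch

# Pick the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pass device= at creation time so the tensor is allocated there directly
x = torch.tensor([1.0, 2.0, 3.0], device=device)
y = torch.zeros(3, device=device)

# Both operands live on the same device, so this works on CPU or GPU
z = x + y
print(z.device)

# An existing tensor can also be moved with .to(device), which returns a copy
# on the target device (or the tensor itself if it is already there)
w = torch.ones(3).to(device)
```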
Data types
The PyTorch documentation on type promotion states: "A floating point scalar operand has dtype torch.get_default_dtype() and an integral non-boolean scalar operand has dtype torch.int64."
Getting the default dtype (this refers to the default floating-point type; the default integer type is always `torch.int64`):
import torch
torch.get_default_dtype()
# Output: torch.float32
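To see the quoted rule in action, here is a small sketch (the tensors are made up for illustration). It assumes the default dtype has not been changed with `torch.set_default_dtype`:

```python
import torch

a = torch.tensor([1.5, 2.5])   # float literals -> default float dtype
b = torch.tensor([1, 2])       # integer literals -> torch.int64
print(a.dtype, b.dtype)        # torch.float32 torch.int64

c = torch.tensor([1, 2], dtype=torch.int32)
# A Python int scalar is in the same category (integer) as c,
# so it does not widen c's dtype
print((c + 1).dtype)           # torch.int32
# A Python float scalar belongs to a higher category, so the result
# takes the default float dtype
print((c + 1.5).dtype)         # torch.float32
```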
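In Assignment 5, sentences have to be padded with `<pad>` so that every word reaches the max word length and every sentence has the same length. The padded entries are integer indices, so the resulting tensor is created with `torch.long`: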
sents = torch.tensor(sents_padded, dtype=torch.long, device=device).contiguous()
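A simplified, word-level sketch of that step. The helper `pad_sents`, the index lists and `pad_id=0` are all hypothetical (the real assignment also pads at the character level):

```python
import torch

def pad_sents(sents, pad_id):
    """Hypothetical helper: right-pad every sentence (a list of token
    indices) with pad_id up to the length of the longest sentence."""
    max_len = max(len(s) for s in sents)
    return [s + [pad_id] * (max_len - len(s)) for s in sents]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sents_padded = pad_sents([[4, 7, 2], [5, 9], [3]], pad_id=0)
# The entries are integer indices, so build the tensor as torch.long (int64)
sents = torch.tensor(sents_padded, dtype=torch.long, device=device)
print(sents.shape)   # torch.Size([3, 3])
```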
PyTorch's tensor data types are listed below (reference: the official TORCH.TENSOR documentation):
Data type | dtype | CPU tensor | GPU tensor
---|---|---|---
32-bit floating point | torch.float32 or torch.float | torch.FloatTensor | torch.cuda.FloatTensor
64-bit floating point | torch.float64 or torch.double | torch.DoubleTensor | torch.cuda.DoubleTensor
16-bit floating point | torch.float16 or torch.half | torch.HalfTensor | torch.cuda.HalfTensor
16-bit floating point (bfloat16) | torch.bfloat16 | torch.BFloat16Tensor | torch.cuda.BFloat16Tensor
32-bit complex | torch.complex32 | |
64-bit complex | torch.complex64 | |
128-bit complex | torch.complex128 or torch.cdouble | |
8-bit integer (unsigned) | torch.uint8 | torch.ByteTensor | torch.cuda.ByteTensor
8-bit integer (signed) | torch.int8 | torch.CharTensor | torch.cuda.CharTensor
16-bit integer (signed) | torch.int16 or torch.short | torch.ShortTensor | torch.cuda.ShortTensor
32-bit integer (signed) | torch.int32 or torch.int | torch.IntTensor | torch.cuda.IntTensor
64-bit integer (signed) | torch.int64 or torch.long | torch.LongTensor | torch.cuda.LongTensor
Boolean | torch.bool | torch.BoolTensor | torch.cuda.BoolTensor
quantized 8-bit integer (unsigned) | torch.quint8 | torch.ByteTensor | /
quantized 8-bit integer (signed) | torch.qint8 | torch.CharTensor | /
quantized 32-bit integer (signed) | torch.qint32 | torch.IntTensor | /
quantized 4-bit integer (unsigned) | torch.quint4x2 | torch.ByteTensor | /
Note that `torch.long` and `torch.int64` are the same dtype, and `torch.float32` can likewise be written as `torch.float`.
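A quick check of these aliases, plus two common ways to convert a tensor's dtype (the sample values are made up):

```python
import torch

# The alias names refer to the same dtypes
print(torch.long == torch.int64)     # True
print(torch.float == torch.float32)  # True

x = torch.tensor([1.7, -2.2])
# .long() and .to(torch.int64) both convert to int64; the fractional part
# is truncated toward zero
y = x.long()
z = x.to(torch.int64)
print(y, y.dtype)
```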
Reshaping Tensors
To give inputs and outputs the shapes a model expects, we constantly need to change a tensor's shape in our programs.
view
Official documentation: torch.Tensor.view
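```python
# Again, the example from the assignment.
# view() only works when the tensor's memory layout is compatible with the new
# shape (a contiguous tensor always is); .contiguous() returns a contiguous
# tensor, or the tensor itself if it is already contiguous
sents = torch.tensor(sents_padded, dtype=torch.long, device=device).contiguous()
```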
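A small sketch of how `view()` behaves, using a made-up `torch.arange` tensor instead of the assignment data. It shares memory with the original, and it fails on a non-contiguous tensor until `.contiguous()` copies the data into a fresh block:

```python
import torch

x = torch.arange(12)       # shape (12,), contiguous
v = x.view(3, 4)           # no copy: v and x share the same storage
v[0, 0] = 100              # so writing through v also changes x
print(x[0])                # tensor(100)

t = v.t()                  # transpose: shape (4, 3), not contiguous
print(t.is_contiguous())   # False

view_failed = False
try:
    t.view(12)             # strides are incompatible with a flat view
except RuntimeError:
    view_failed = True
print(view_failed)         # True

flat = t.contiguous().view(12)   # copy into contiguous memory, then view
print(flat)
```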
reshape
Official documentation: torch.reshape
```python
import torch

# Define a tensor and picture it as a cube with six faces, each face a 2-D matrix.
# To make the results easy to read, every row vector is different.
cube = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                     [[2, 3, 4], [5, 6, 7], [8, 9, 10]],
                     [[3, 4, 5], [6, 7, 8], [9, 10, 11]],
                     [[9, 8, 7], [6, 5, 4], [3, 2, 1]],
                     [[8, 7, 6], [5, 4, 3], [2, 1, 0]],
                     [[7, 6, 5], [4, 3, 2], [1, 0, -1]]])
```
First, look at cube's shape and dtype:
print(cube.shape)
print(cube.dtype)
# Output: torch.Size([6, 3, 3])
# Output: torch.int64
Next, what happens if we reshape it to (3, 6, 3)? Is the result what you would expect?
torch.reshape(cube, (3, 6, 3))
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [ 2,  3,  4],
         [ 5,  6,  7],
         [ 8,  9, 10]],

        [[ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11],
         [ 9,  8,  7],
         [ 6,  5,  4],
         [ 3,  2,  1]],

        [[ 8,  7,  6],
         [ 5,  4,  3],
         [ 2,  1,  0],
         [ 7,  6,  5],
         [ 4,  3,  2],
         [ 1,  0, -1]]])
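As you can see, `reshape` reads the elements in their original order (row-major, with the last dimension changing fastest) and simply fills them into the new shape in that same order.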
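This "fill in order" behavior can be checked directly: reshaping never changes the flattened sequence of elements. A minimal sketch with a made-up `torch.arange` tensor (it also shows `-1`, which asks PyTorch to infer that dimension):

```python
import torch

x = torch.arange(1, 19).reshape(6, 3)   # 18 elements: 1, 2, ..., 18
y = x.reshape(3, 6)

# Flattening before and after gives the same sequence: reshape only regroups
print(torch.equal(x.flatten(), y.flatten()))   # True
print(y[0])                                    # tensor([1, 2, 3, 4, 5, 6])

# -1 lets PyTorch infer one dimension from the total element count (18 / 9 = 2)
z = x.reshape(-1, 9)
print(z.shape)                                 # torch.Size([2, 9])
```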
Reference: "PyTorch:view() 与 reshape() 区别详解" (a detailed comparison of view() and reshape() in PyTorch)
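The official `reshape` documentation says that, when possible, the returned tensor is a view of the input: if the input is contiguous or its strides are compatible with the new shape, `reshape` behaves like `view` and shares the input's data without copying it. Otherwise, `reshape` returns a copy. The documentation also warns that you should not depend on whether a view or a copy is returned.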
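One way to observe the view-versus-copy behavior is to compare `data_ptr()`, the address of a tensor's first element. This is only an illustrative sketch; as the documentation warns, real code should not rely on it:

```python
import torch

x = torch.arange(6).reshape(2, 3)   # contiguous 2x3 tensor: [[0, 1, 2], [3, 4, 5]]

r1 = x.reshape(3, 2)                # compatible layout -> returned as a view
print(r1.data_ptr() == x.data_ptr())   # True: same underlying memory

t = x.t()                           # transpose: shape (3, 2), not contiguous
r2 = t.reshape(6)                   # cannot be expressed as a view -> copied
print(r2.data_ptr() == x.data_ptr())   # False: new memory
print(r2)                           # tensor([0, 3, 1, 4, 2, 5])
```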
permute
```python
cube
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[ 2,  3,  4],
         [ 5,  6,  7],
         [ 8,  9, 10]],

        [[ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]],

        [[ 9,  8,  7],
         [ 6,  5,  4],
         [ 3,  2,  1]],

        [[ 8,  7,  6],
         [ 5,  4,  3],
         [ 2,  1,  0]],

        [[ 7,  6,  5],
         [ 4,  3,  2],
         [ 1,  0, -1]]])
# With reshape:
torch.reshape(cube, (3, 6, 3))
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [ 2,  3,  4],
         [ 5,  6,  7],
         [ 8,  9, 10]],

        [[ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11],
         [ 9,  8,  7],
         [ 6,  5,  4],
         [ 3,  2,  1]],

        [[ 8,  7,  6],
         [ 5,  4,  3],
         [ 2,  1,  0],
         [ 7,  6,  5],
         [ 4,  3,  2],
         [ 1,  0, -1]]])
# With permute:
cube.permute((1, 0, 2))
tensor([[[ 1,  2,  3],
         [ 2,  3,  4],
         [ 3,  4,  5],
         [ 9,  8,  7],
         [ 8,  7,  6],
         [ 7,  6,  5]],

        [[ 4,  5,  6],
         [ 5,  6,  7],
         [ 6,  7,  8],
         [ 6,  5,  4],
         [ 5,  4,  3],
         [ 4,  3,  2]],

        [[ 7,  8,  9],
         [ 8,  9, 10],
         [ 9, 10, 11],
         [ 3,  2,  1],
         [ 2,  1,  0],
         [ 1,  0, -1]]])
```
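As you can see, the two results have the same shape but different values. `reshape` keeps cube's elements in their original order and only regroups them, while `permute` actually swaps dimensions 0 and 1, so that `result[i][j][k] == cube[j][i][k]`. Never assume the two are equivalent or use them interchangeably.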
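The difference can also be checked programmatically. This sketch uses a made-up `torch.arange` tensor with the same (6, 3, 3) shape as `cube`:

```python
import torch

cube = torch.arange(54).reshape(6, 3, 3)   # stand-in for the cube above

p = cube.permute(1, 0, 2)   # swap dims 0 and 1 -> shape (3, 6, 3)
r = cube.reshape(3, 6, 3)   # same shape, elements kept in original order

print(p.shape == r.shape)   # True
print(torch.equal(p, r))    # False: same shape, different contents

# permute moves indices: p[i, j, k] is cube[j, i, k]
print(p[2, 4, 1] == cube[4, 2, 1])   # tensor(True)

# Swapping just two dimensions with permute is the same as transpose
print(torch.equal(p, cube.transpose(0, 1)))   # True

# permute only changes strides, so the result is usually not contiguous;
# use reshape (or .contiguous().view) rather than view afterwards
print(p.is_contiguous())    # False
flat = p.reshape(3, 18)
```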
Document information
- Author: weownthenight
- Link: https://weownthenight.github.io/2021/07/21/PyTorch%E6%80%BB%E7%BB%93-view-reshape-permute/
- License: free to repost, non-commercial, no derivatives, attribution required (Creative Commons 3.0 license)