广播语义

许多pytorch操作都支持NumPy广播语义

简而言之,如果Pytorch操作支持广播,则其张量参数可以自动扩展为相同大小(不需要复制数据)。

一般语义

如果pytorch张量满足以下条件,那么就可以广播:

  • 每个张量至少有一个维度。
  • 在遍历维度大小时, 从尾部维度开始遍历, 并且二者维度必须相等, 它们其中一个要么是1要么不存在.

例如:

>>> x=torch.FloatTensor(5,7,3)
>>> y=torch.FloatTensor(5,7,3)
# 相同形状的质量可以被广播(上述规则总是成立的)

>>> x=torch.FloatTensor()
>>> y=torch.FloatTensor(2,2)
# x和y不能被广播,因为x没有维度

# can line up trailing dimensions
>>> x=torch.FloatTensor(5,3,4,1)
>>> y=torch.FloatTensor(  3,1,1)
# x和y可以被广播
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# 但是:
>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor(  3,1,1)
# x和y不能被广播,因为在`3rd`中
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

如果xy可以被广播,得到的张量大小的计算方法如下:

  • 如果维数xy不相等,在维度较少的张量的维度前加上1使它们相等的长度。
  • 然后,对于每个维度的大小,生成维度的大小是xy的最大值。

例如:

# 可以排列尾部维度,使阅读更容易
>>> x=torch.FloatTensor(5,1,4,1)
>>> y=torch.FloatTensor(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# 但不是必要的:
>>> x=torch.FloatTensor(1)
>>> y=torch.FloatTensor(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.FloatTensor(5,2,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

直接语义(In-place语义)

一个复杂的问题是in-place操作不允许in-place张量像广播那样改变形状。

例如:

>>> x=torch.FloatTensor(5,3,4,1)
>>> y=torch.FloatTensor(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.FloatTensor(1,3,1)
>>> y=torch.FloatTensor(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

向后兼容性

以前版本的PyTorch只要张量中的元素数量是相等的, 便允许某些点状pointwise函数在不同的形状的张量上执行, 其中点状操作是通过将每个张量视为1维来执行。PyTorch现在支持广播语义并且不推荐使用点状函数操作向量,并且张量不支持广播但具有相同数量的元素将生成一个Python警告。

注意,广播的引入可能会导致向后不兼容,因为两个张量的形状不同,但是可以被广播且具有相同数量的元素。

例如:

>>> torch.add(torch.ones(4,1), torch.randn(4))

事先生成一个torch.Size([4,1])的张量,然后再提供一个torch.Size([4,4])的张量。为了帮助识别你代码中可能存在的由引入广播语义的向后不兼容情况, 您可以设置torch.utils.backcompat.broadcast_warning.enabledTrue,这种情况下会生成一个python警告。

例如:

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.