pytorch回归网络不能使用数据并行

Song • 492 次浏览 • 0 个回复 • 2018年04月27日

在具有 DataParalleldata_parallel() 的模块中使用 pack sequence -> recurrent network -> unpack sequence 模式时有一个非常微妙的地方。每个设备上的forward()的输入只会是整个输入的一部分。由于默认情况下,解包操作 torch.nn.utils.rnn.pad_packed_sequence() 仅填充到其所见的最长输入,即该特定设备上的最长输入,所以在将结果收集在一起时会发生尺寸的不匹配。因此,您可以利用pad_packed_sequence()total_length参数来确保forward()调用返回相同长度

的序列。例如,你可以写:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_squence

class MyModule(nn.Module):
    #  ... __init__, 以及其他访求

    # padding_input 的形状是[B x T x *](batch_first 模式),包含按长度排序的序列
    # B 是批量大小
    # T 是最大序列长度
    def forward(self, padded_input, input_lengths):
        total_length = padded_input.size(1)  # 获取最大序列长度
        packed_input = pack_padded_sequence(padded_input, input_lengths,
                                            batch_first=True)
        packed_output, _ = self.my_lstm(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True,
                                        total_length=total_length)
        return output

m = MyModule().cuda()
dp_m = nn.DataParallel(m)

此外,在批量的维度为dim 1 (第1轴)(即batch_first = False)时需要额外注意数据的并行性。在这种情况下,pack_padded_sequence 函数的的第一个参数 padding_input 维度将是[T x B x *],并且应该沿dim 1 (第1轴)分散,但第二个参数 input_lengths 的维度为 [B],应该沿dim 0 (第0轴)分散。需要额外的代码来操纵张量的维度。


原创文章,转载请注明 :pytorch回归网络不能使用数据并行 - pytorch中文网
原文出处: https://ptorch.com/news/165.html
问题交流群 :168117787
提交评论
要回复文章请先登录注册
用户评论
  • 没有评论
Pytorch是什么?关于Pytorch! pytorch Windows常见问题汇总