多GPU示例

数据并行是当我们将小批量样品分成多个较小的批量批次,并且对每个较小的小批量并行运行计算。

数据并行使用torch.nn.DataParallel。一个可以包装一个模块DataParallel,它将在批量维度中的多个GPU并行化。

数据并行

import torch.nn as nn

class DataParallelModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)

        # wrap block2 in DataParallel
        self.block2 = nn.Linear(20, 20)
        self.block2 = nn.DataParallel(self.block2)

        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

在CPU模式下不需要更改代码。

DataParallel的文档在 这里

DataParallel实现的基元:

一般来说,pytorchnn.parallel原语可以独立使用。我们实现了简单的类似MPI的原语:

  • 复制:在多个设备上复制模块
  • 散点:在第一维中分配输入
  • 收集:收集并连接第一维中的输入
  • parallel_apply:将一组已经分布的输入应用于一组已经分布的模型。 为了给出更好的清晰度,这里的功能data_parallel使用这些集合

    def data_parallel(module, input, device_ids, output_device=None):
    if not device_ids:
        return module(input)
    
    if output_device is None:
        output_device = device_ids[0]
    
    replicas = nn.parallel.replicate(module, device_ids)
    inputs = nn.parallel.scatter(input, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    return nn.parallel.gather(outputs, output_device)

    部分型号在CPU上,部分在GPU上

    我们来看一个实现网络的一个小例子,其中一部分是在CPU上,部分在GPU上

    class DistributedModel(nn.Module):
    
    def __init__(self):
        super().__init__(
            embedding=nn.Embedding(1000, 10),
            rnn=nn.Linear(10, 10).cuda(0),
        )
    
    def forward(self, x):
        # Compute embedding on CPU
        x = self.embedding(x)
    
        # Transfer to GPU
        x = x.cuda(0)
    
        # Compute RNN on GPU
        x = self.rnn(x)
        return x

    这是PyTorch对前Torch用户的一个小介绍。还有更多的学习。 看看我们更全面的入门教程,介绍optim包,数据加载器等:使用PyTorch进行深度学习:60分钟闪电
    还看看:

  • 训练神经网络玩视频游戏
  • 在imagenet上培训最先进的ResNet网络
  • 使用生成对抗网络训练面部发生器
  • 使用Recurrent LSTM网络训练一个单词级语言模型
  • 更多例子
  • 更多教程
  • 在论坛上讨论PyTorch
  • 在Slack上与其他用户聊天

脚本的总运行时间:(0分0.002秒)

下载Python源代码:parallelism_tutorial.py

下载jupyter笔记:parallelism_tutorial.ipynb