PyTorch周围的实用程序库Inferno

Song • 607 次浏览 • 0 个回复 • 2018年02月04日

Inferno

Inferno是一个提供PyTorch实用程序和便利功能/类的小型库。程序库还在开发中,第一个稳定版本是0.2

特征

目前的功能包括:

  • 一个基本的Trainer类来封装训练样板(迭代/时代循环,验证和检查点创建),
  • 一个图形API,由networkx提供,具有复杂体系结构的模型。
  • 简单的数据并行性在多个GPU上,
  • 一个用于torch.nn.Module级参数初始化的子模块,
  • 数据预处理/转换子模块,
  • 支持Tensorboard(最好ATLEAST tensorflow CPU的安装)
  • 一个回调API,以便与教练进行灵活的交互,
  • 各种实用程序层正在进行中,
  • 体积数据集的子模块等等。
import torch.nn as nn
from inferno.io.box.cifar import get_cifar10_loaders
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
from inferno.extensions.layers.convolutional import ConvELU2D
from inferno.extensions.layers.reshape import Flatten

# Fill these in:
LOG_DIRECTORY = '...'
SAVE_DIRECTORY = '...'
DATASET_DIRECTORY = '...'
DOWNLOAD_CIFAR = True
USE_CUDA = True

# Build torch model
model = nn.Sequential(
    ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
    nn.MaxPool2d(kernel_size=2, stride=2),
    Flatten(),
    nn.Linear(in_features=(256 * 4 * 4), out_features=10),
    nn.Softmax()
)

# Load loaders
train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
                                                    download=DOWNLOAD_CIFAR)

# Build trainer
trainer = Trainer(model) \
  .build_criterion('CrossEntropyLoss') \
  .build_metric('CategoricalError') \
  .build_optimizer('Adam') \
  .validate_every((2, 'epochs')) \
  .save_every((5, 'epochs')) \
  .save_to_directory(SAVE_DIRECTORY) \
  .set_max_num_epochs(10) \
  .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                  log_images_every='never'),
                log_directory=LOG_DIRECTORY)

# Bind loaders
trainer \
    .bind_loader('train', train_loader) \
    .bind_loader('validate', validate_loader)

if USE_CUDA:
  trainer.cuda()

# Go!
trainer.fit()

为了显示训练进度,导航到LOG_DIRECTORY并使用张量板

$ tensorboard --logdir=${PWD} --port=6007

并使用浏览器导航到localhost:6007

安装

适用于linuxmacconda软件包(只有python 3)可以通过如下代码:

$ conda install -c inferno-pytorch inferno

开发计划:

计划的功能包括:

  • 一个封装Hogwild的类!在多个GPU上进行训练,
  • 最小形状推断的dry-run
  • 正确的包装和文件,

原创文章,转载请注明 :PyTorch周围的实用程序库Inferno - pytorch中文网
原文出处: https://ptorch.com/news/122.html
问题交流群 :168117787
提交评论
要回复文章请先登录注册
用户评论
  • 没有评论
Pytorch是什么?关于Pytorch! PyTorch周围的实用程序库Inferno