PyTorch流行的预训练模型和数据集列表pytorch-playground

Song • 338 次浏览 • 0 个回复 • 2018年05月13日

pytorch-playground包含基础预训练模型和pytorch中的数据集(MNIST,SVHN,CIFAR10,CIFAR100,STL10,AlexNet,VGG16,VGG19,ResNet,Inception,SqueezeNet)

这是pytorch初学者的游乐场(即资源列表,你可以随意使用如下模型),其中包含流行数据集的预定义模型。目前支持如下模型:

  • mnist,svhn
  • cifar10,cifar100
  • stl10
  • alexnet
  • vgg16,vgg16_bn,vgg19,vgg19_bn
  • resnet18,resnet34,resnet50,resnet101,resnet152
  • squeezenet_v0,squeezenet_v1
  • inception_v3

下面是MNIST数据集的例子。下面的代码将自动下载数据集和预先训练的模型。

import torch
from torch.autograd import Variable
from utee import selector
model_raw, ds_fetcher, is_imagenet = selector.select('mnist')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
    data =  Variable(torch.FloatTensor(data)).cuda()
    output = model_raw(data)

另外,如果想在mnist上训练MLP模型,只需运行python mnist/train.py即可。

一、安装

二、ImageNet数据集

我们提供224x224x3大小的预训练imagenet验证数据集。我们首先将较短尺寸的图像调整为256,然后在中心剪裁224x224图像。然后我们将裁剪后的图像编码为jpg字符串并转储到pickle

三、量化

我们还提供了一个简单的演示,使用几种方法将这些模型量化为指定的位宽,包括线性方法,最小最大值方法和非线性方法。

python quantize.py --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1

四、Top1准确度

我们用线性量化方法评估流行数据集和模型的性能。BN中的运行均值和运行方差的比特宽度对于所有结果都是10比特。(32-float除外)

模型32-float12-bit10-bit8-bit6-bit
MNIST98.4298.4398.4498.4498.32
SVHN96.0396.0396.0496.0295.46
CIFAR1093.7893.7993.8093.5890.86
CIFAR10074.2774.2174.1973.7066.32
STL1077.5977.6577.7077.5973.40
AlexNet55.70/78.4255.66/78.4155.54/78.3954.17/77.2918.19/36.25
VGG1670.44/89.4370.45/89.4370.44/89.3369.99/89.1753.33/76.32
VGG1971.36/89.9471.35/89.9371.34/89.8870.88/89.6256.00/78.62
ResNet1868.63/88.3168.62/88.3368.49/88.2566.80/87.2019.14/36.49
ResNet3472.50/90.8672.46/90.8272.45/90.8571.47/90.0032.25/55.71
ResNet5074.98/92.1774.94/92.1274.91/92.0972.54/90.442.43/5.36
ResNet10176.69/93.3076.66/93.2576.22/92.9065.69/79.541.41/1.18
ResNet15277.55/93.5977.51/93.6277.40/93.5474.95/92.469.29/16.75
SqueezeNetV056.73/79.3956.75/79.4056.70/79.2753.93/77.0414.21/29.74
SqueezeNetV156.52/79.1356.52/79.1556.24/79.0354.56/77.3317.10/32.46
InceptionV376.41/92.7876.43/92.7176.44/92.7373.67/91.341.50/4.82


注意:`ImageNet 32-float`模型直接来自`torchvision` ### 五、定义参数 在`quantize.py`可以定义参数
参数默认值描述 & 参数
typecifar10mnist,svhn,cifar10,cifar100,stl10,alexnet,vgg16,vgg16_bn,vgg19,vgg19_bn,resent18,resent34,resnet50,resnet101,resnet152,squeezenet_v0,squeezenet_v1,inception_v3
quant_methodlinearquantization method:linear,minmax,log,tanh
param_bits8bit-width of weights and bias
fwd_bits8bit-width of activation
bn_bits32bit-width of running mean and running vairance
overflow_rate0.0overflow rate threshold for linear quantization method
n_samples20number of samples to make statistics for activation


项目地址:[aaron-xichen/pytorch-playground](https://github.com/aaron-xichen/pytorch-playground)
原创文章,转载请注明 :PyTorch流行的预训练模型和数据集列表pytorch-playground - pytorch中文网
原文出处: https://ptorch.com/news/171.html
问题交流群 :168117787
提交评论
要回复文章请先登录注册
用户评论
  • 没有评论
Pytorch是什么?关于Pytorch! pytorch中autograd以及hook函数详解