Pytorch的高级训练,数据增强和实用程序(torchsample/Keras)

Song • 660 次浏览 • 0 个回复 • 2018年02月01日

Pytorch的高级训练,数据增强和实用程序

v0.1.3刚刚发布 - 包含重大改进,错误修复和其他支持。从版本获取它,或者拉出主分支。

这个包提供了一些东西:

  • Keras类似的高级模块,带有回调,约束和规则的训练。
  • 全面的数据增强,转换,采样和加载
  • 效用张量和变量函数,所以你不经常需要numpy

有任何功能要求?可以提交问题!我会做到这一点。特别是,数据增加,数据加载或采样功能。

想贡献?检查问题页面 标有[捐助欢迎]的标签。 项目地址ncullen93/torchsample

ModuleTrainer

ModuleTrainer类提供了一个高级训练界面,在提供回调,约束,初始化程序,规则化程序等的同时提取训练循环。

例:

from torchsample.modules import ModuleTrainer

# Define your model EXACTLY as normal
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1600)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

model = Network()
trainer = ModuleTrainer(model)

trainer.compile(loss='nll_loss',
                optimizer='adadelta')

trainer.fit(x_train, y_train, 
            val_data=(x_test, y_test),
            num_epoch=20, 
            batch_size=128,
            verbose=1)

您还可以访问标准的评估和预测功能:

loss = model.evaluate(x_train, y_train)
y_pred = model.predict(x_train)

Torchsample提供了广泛的回调,通常模仿Keras中的以下接口:

  • EarlyStopping
  • ModelCheckpoint
  • LearningRateScheduler
  • ReduceLROnPlateau
  • CSVLogger
from torchsample.callbacks import EarlyStopping

callbacks = [EarlyStopping(monitor='val_loss', patience=5)]
model.set_callbacks(callbacks)

Torchsample还提供regularizers

  • L1Regularizer
  • L2Regularizer
  • L1L2Regularizer 和约束:

  • UnitNorm
  • MaxNorm
  • NonNeg 正则表达式和module_filter参数都可以选择性地将规则化器和约束条件应用于层。约束可以是以任意批次或时期频率应用的显式(硬)约束,或者可以是与正规化器相似的隐式(软)约束,其中约束偏差作为惩罚被添加到总模型损失。
from torchsample.constraints import MaxNorm, NonNeg
from torchsample.regularizers import L1Regularizer

# hard constraint applied every 5 batches
hard_constraint = MaxNorm(value=2., frequency=5, unit='batch', module_filter='*fc*')
# implicit constraint added as a penalty term to model loss
soft_constraint = NonNeg(lagrangian=True, scale=1e-3, module_filter='*fc*')
constraints = [hard_constraint, soft_constraint]
model.set_constraints(constraints)

regularizers = [L1Regularizer(scale=1e-4, module_filter='*conv*')]
model.set_regularizers(regularizers)

你也可以直接适应一个torch.utils.data.DataLoader,也可以有一个验证集:

from torchsample import TensorDataset
from torch.utils.data import DataLoader

train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32)

val_dataset = TensorDataset(x_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

trainer.fit_loader(loader, val_loader=val_loader, num_epoch=100)

实用功能

最后,torchsample采样提供了一些不常见的实用功能:

张量函数

  • th_iterproduct (模仿itertools.product
  • th_gather_ndtorch.gatherN维版本)
  • th_random_choice (模仿np.random.choice
  • th_pearsonr (模仿scipy.stats.pearsonr
  • th_corrcoef (模仿np.corrcoef
  • th_affine2dth_affine3dTorch上的仿射变换。传感器)

变量函数

  • F_affine2dF_affine3d
  • F_map_coordinates2dF_map_coordinates3d

, 数据扩充和数据集

torchsample包提供了大量数据加载和转换工具,可以在数据加载过程中使用。该软件包还提供了灵活性 TensorDatasetFolderDataset类来处理大多数数据集的需求。

Torch变换

这些转换直接在Torch张量上工作

  • Compose()
  • AddChannel()
  • SwapDims()
  • RangeNormalize()
  • StdNormalize()
  • Slice2D()
  • RandomCrop()
  • SpecialCrop()
  • Pad()
  • RandomFlip()
  • ToTensor()

仿射变换

pytprch仿射变换

以下转换对Torch张量执行仿射(仿射)变换。

  • Rotate()
  • Translate()
  • Shear()
  • Zoom() 我们还提供了一个将多个仿射变换串联在一起的类,以便只进行一次插值:

  • Affine()
  • AffineCompose()

    数据集和抽样

    我们提供以下数据集,这些数据集提供通用结构和迭代器,用于对内存中或内存不足数据进行采样和使用变换:

  • TensorDataset()
  • FolderDataset()

原创文章,转载请注明 :Pytorch的高级训练,数据增强和实用程序(torchsample/Keras) - pytorch中文网
原文出处: https://ptorch.com/news/121.html
问题交流群 :168117787
提交评论
要回复文章请先登录注册
用户评论
  • 没有评论
Pytorch是什么?关于Pytorch! PyTorch周围的实用程序库Inferno