pytorch LeNet5网络的实现

大大卷 • 125 次浏览 • 1 个回复 • 2018年12月07日

LeNet网络是CNN的开创新网络,在这个网络中形成了现在CNN的主要框架,在这里实现了LeNet5网络的框架,与原文略有不同。

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.convnet = nn.Sequential(OrderedDict([
            ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),    # 6 * 28 * 28
            ('relu1', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)), # 6 * 14 * 14
            ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),       # 16 * 10 * 10
            ('relu3', nn.ReLU()),
            ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)), # 16 * 5 * 5
            ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),     # 这还是一个卷积,卷积和是5*5
            ('relu5', nn.ReLU())
        ]))

        self.fc = nn.Sequential(OrderedDict([
            ('f6', nn.Linear(120, 84)),         #   全连接
            ('relu6', nn.ReLU()),
            ('f7', nn.Linear(84, 10)),
            ('sig7', nn.LogSoftmax())
        ]))

    def forward(self, img):
        output = self.convnet(img)
        output = output.view(-1, 120)
        output = self.fc(output)
        return output

原创文章,转载请注明 :pytorch LeNet5网络的实现 - pytorch中文网
原文出处: https://ptorch.com/news/221.html
问题交流群 :168117787
提交评论
要回复文章请先登录注册
用户评论
Pytorch是什么?关于Pytorch! pytorch v1.0正式版发布,增加jit和C++ API以及全新的分布式包和Torch HUB