torch.nn 之 Module、ModuleList 和 Sequential
本文参考自 神经网络工具箱 torch.nn 之 Module、ModuleList 和 Sequential (opens new window)
PyTorch 提供了集成度更高的模块化接口 torch.nn
,该接口构建于 Autograd 之上,提供了网络模块、优化器和初始化策略等一系列功能。
PyTorch
把与深度学习模型搭建相关的全部类全部在 torch.nn
这个子模块中。根据类的功能分类,常用的有如下部分:
- Containers:容器类,如
torch.nn.Module
;torch.nn.ModuleList
;torch.nn.Sequential()
; - Convolution Layers:卷积层,如
torch.nn.Conv2d
; - Pooling Layers:池化层,如
torch.nn.MaxPool2d
; - Non-linear activations:非线性激活层,如
torch.nn.ReLU
; - Normalization layers:归一化层,如
torch.nn.BatchNorm2d
; - Recurrent layers:循环神经层,如
torch.nn.LSTM
; - Transformer layers:transformer 层,如
torch.nn.TransformerEncoder
; - Linear layers:线性连接层,如
torch.nn.Linear
; - Dropout layers:dropout 层,如
torch.nn.Dropout
; - ……
本次主要讲解 torch.nn.Module
、torch.nn.ModuleList
、torch.nn.Sequential()
。
# 1. nn.Module()
nn.Module
是 PyTorch 提供的神经网络类,并在类中实现了网络各层的定义及前向计算与反向传播机制。在实际使用时,如果想要实现某个神经网络,只需继承 nn.Module
,在初始化中定义模型结构与参数,在函数 forward()
中编写网络前向过程即可。
示例 1:
from torch import nn
class MnistLogistic(nn.Module):
def __init__(self):
super().__init__() # 调用nn.Module的构造函数
# 使用 nn.Parameter 来构造需要学习的参数
self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))
self.bias = nn.Parameter(torch.zeros(10))
def forward(self, xb):
# 在forward中实现前向传播过程
return xb @ self.weights + self.bias
2
3
4
5
6
7
8
9
10
11
12
notice:
- 在类的
__init__()
中需要定义网络学习的参数,在此使用nn.Parameter()
函数定义了参数weights
和bias
,这是一种特殊的Tensor
的构造方法,默认需要求导,即requires_grad
为True
。 - 在 PyTorch 中,还有一个库为
nn.functional
,同样也提供了很多网络层与函数功能,但与nn.Module
不同的是,利用nn.functional
定义的网络层不可自动学习参数,还需要使用nn.Parameter
封装。nn.functional
的设计初衷是对于一些不需要学习参数的层,如激活层、BN(Batch Normalization)层,可以使用nn.functional
,这样这些层就不需要在nn.Module
中定义了。
示例 2:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
2
3
4
5
6
7
8
9
10
11
12
# 2. nn.ModuleList()
ModuleList
就像一个普通的 Python 的List
,我们可以使用下标来访问它,好处是传入的ModuleList
的所有Module
都会注册到 PyTorch 里,这样 Optimizer
就能找到其中的参数,从而用梯度下降进行更新。但是nn.ModuleList
并不是Module
(的子类),因此它没有forward
等方法,通常会被放到某个Module
里。
示例 3:
import torch
import torch.nn as nn
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.combine = nn.ModuleList()
self.combine.append(nn.Linear(200,100))
self.combine.append(nn.Linear(100,50))
Net = MyNet()
print(Net)
2
3
4
5
6
7
8
9
10
11
12
13
14
结果如下:
MyNet(
(combine): ModuleList(
(0): Linear(in_features=200, out_features=100, bias=True)
(1): Linear(in_features=100, out_features=50, bias=True)
)
)
2
3
4
5
6
可以看到 PyTorch 自动识别 nn.ModuleList
中的参数,注意如果换成普通的 python list 则无法识别。
注意:nn.ModuleList
并没有定义一个网络,而是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言。
示例 4:
import torch
import torch.nn as nn
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.combine = nn.ModuleList()
self.combine.append(nn.Linear(10,100))
self.combine.append(nn.Linear(100,60))
self.combine.append(nn.Linear(50,10))
def forward(self, x):
x = self.combine[2](x)
x = self.combine[0](x)
x = self.combine[1](x)
return x
MyNet = MyNet()
print(MyNet)
input = torch.randn(32, 50)
print('MyNet(input).shape: ',MyNet(input).shape)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
结果如下:
MyNet(
(combine): ModuleList(
(0): Linear(in_features=10, out_features=100, bias=True)
(1): Linear(in_features=100, out_features=60, bias=True)
(2): Linear(in_features=50, out_features=10, bias=True)
)
)
MyNet(input).shape: torch.Size([32, 60])
2
3
4
5
6
7
8
# 3. nn.Sequential()
当模型中只是简单的前馈网络时,即上一层的输出直接作为下一层的输入,这时可以采用nn.Sequential()
模块来快速搭建模型,而不必手动在forward()
函数中一层一层地前向传播。因此,如果想快速搭建模型而不考虑中间过程的话,推荐使用nn.Sequential()
模块。
接下来用nn.Sequential()
改写nn.Module()
中的示例 2,示例 5 如下:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.layer = nn.Sequential(
nn.Conv2d(1, 20, 5),
nn.Conv2d(20, 20, 5)
)
def forward(self, x):
x = self.layer(x)
return x
Model = Model()
print(Model)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
Sequential 的第二种写法 —— 指定每个 module 的名字,也就是通过 seq.add_module(层名, 层 class 的实例)
:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
# 第二种
self.layer = nn.Sequential()
# 添加名字和层
self.layer.add_module('conv1', nn.Conv2d(1, 20, 5))
self.layer.add_module('conv2', nn.Conv2d(20, 20, 5))
def forward(self, x):
x = self.layer(x)
return x
Model = Model()
print(Model)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Sequential 的第三种写法:nn.Sequential(OrderedDict([多个(层名,层class的实例)]))
:
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.layer = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1, 20, 5)),
('conv2', nn.Conv2d(20, 20, 5))
]))
def forward(self, x):
x = self.layer(x)
return x
Model = Model()
print(Model)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
总结:
- 可以把
Sequential
当作list
来看,允许构建序列化的模块。 nn.sequential()
是一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行
# 4. nn.ModuleList 和 nn.Sequential 的区别
示例 3 采用 ModuleList
,没有实现 forward()
方法,若传入输入数据时,会发生什么?请看示例 5:
import torch
import torch.nn as nn
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.combine = nn.ModuleList()
self.combine.append(nn.Linear(200,100))
self.combine.append(nn.Linear(100,50))
Net = MyNet()
inputs = torch.ones(200)
outputs = Net(inputs)
print(outputs)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
显示报错:NotImplementedError: Module [MyNet] is missing the required "forward" function
。
若按照如下示例 6 实现 forward()
方法,会发生什么呢?
import torch
import torch.nn as nn
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.combine = nn.ModuleList()
self.combine.append(nn.Linear(200,100))
self.combine.append(nn.Linear(100,50))
def forward(self, x):
x = self.combine(x)
return x
Net = MyNet()
inputs = torch.ones(200)
outputs = Net(inputs)
print(outputs)
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
仍然显示跟 示例 9 一样的报错信息 NotImplementedError: Module [MyNet] is missing the required "forward" function
。
因为 nn.ModuleList
是一个无序性的序列,不能直接通过 x = self.combine(x)
来实现 forward()
。可按照示例 5 的写法实现。
总结:
nn.Sequential
定义的网络中各层会按照定义的顺序进行级联,因此需要保证各层的输入和输出之间要衔接。nn.Sequential
实现了farward()
方法,因此可以直接通过类似于x=self.combine(x)
实现forward()
。nn.ModuleList
则没有顺序性要求,并且也没有实现forward()
方法。