下面是一段简单的神经网络的代码,super()函数的定义以前一直不太懂
class NetC(torch.nn.Module): # 定义神经网络 def __init__(self, n_feature, n_hidden, n_output): """ 初始化神经网络 参数: - n_feature: 输入特征的数量 - n_hidden: 隐藏层神经元的数量 - n_output: 输出层神经元的数量 """ super(NetC, self).__init__() self.h1 = nn.Linear(n_feature, n_hidden) self.relu1 = nn.ReLU() self.out = nn.Linear(n_hidden, n_output) self.softmax = nn.Softmax(dim=1) #定义前向运算 def forward(self, x): """ 前向传播函数 参数: - x: 输入数据 返回值: - out: 输出结果 """ # 得到的数据格式torch.Size([64, 1, 28, 28])需要转变为(64,784) x = x.view(x.size()[0],-1) # -1表示自动匹配 h1 = self.h1(x) a1 = self.relu1(h1) out = self.out(a1) a_out = self.softmax(out) return out
描述
super() 函数是用于调用父类(超类)的一个方法。
super() 是用来解决多重继承问题的,直接用类名调用父类方法在使用单继承的时候没问题,但是如果使用多继承,会涉及到查找顺序(MRO)、重复调用(钻石继承)等种种问题。
语法
以下是 super() 方法的语法:
super(type[, object-or-type])
参数
- type — 类。
- object-or-type — 类,一般是 self
Python3.x 和 Python2.x 的一个区别是: Python 3 可以使用直接使用 super().xxx 代替 super(Class, self).xxx :
参考链接: Python super() 函数 | 菜鸟教程 https://www.runoob.com/python/python-func-super.html