1. Initialize in the model's __init__ method:
# Inside the model's __init__: Xavier init for the LSTM weight matrices, zeros for the biases
self.rnn = nn.LSTM(input_size=embedding_size, hidden_size=128, num_layers=1, bidirectional=False)
for name, param in self.rnn.named_parameters():
    if name.startswith("weight"):  # weight_ih_l0 / weight_hh_l0
        nn.init.xavier_normal_(param)
    else:  # bias_ih_l0 / bias_hh_l0
        nn.init.zeros_(param)
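For reference, a single-layer unidirectional LSTM exposes exactly four named parameters, which is why the startswith("weight") test above cleanly separates weights from biases. A quick way to list them (input_size=32 is an arbitrary placeholder):

import torch.nn as nn

rnn = nn.LSTM(input_size=32, hidden_size=128, num_layers=1, bidirectional=False)
for name, param in rnn.named_parameters():
    # prints weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0 and their shapes
    print(name, tuple(param.shape))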
class Net(nn.Module):
    def __init__(self, input_channels, n_classes):
        super(Net, self).__init__()  # fixed: super() must name this class, not another one
        self.gru = nn.GRU(1, 64, 1, bidirectional=False)
        self.gru_bn = nn.BatchNorm1d(64 * input_channels)
        # walk all submodules and initialize them by layer type
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):  # note: nn.BatchNorm1d is not matched by this branch
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
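To sanity-check that the loop actually ran, instantiate the model and inspect its parameters; a minimal sketch, where input_channels=16 and n_classes=9 are placeholder values chosen only for illustration:

net = Net(input_channels=16, n_classes=9)  # hypothetical sizes
for name, param in net.named_parameters():
    print(name, tuple(param.shape), "mean=%.4f" % param.mean().item())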
2. Use apply(): nn.Module.apply(fn) calls fn recursively on every submodule and finally on the module itself, so a single type-dispatching function can initialize the whole network:
def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, GCNConv):
        # recent PyTorch Geometric keeps the weight on an inner Linear (m.lin);
        # older versions exposed it directly as m.weight
        nn.init.xavier_normal_(m.lin.weight)
    elif isinstance(m, GatedGraphConv):
        nn.init.xavier_normal_(m.weight)

model = Net(input_channels, n_classes).to(device)
model.apply(weight_init)
Alternatively, define weight_init() inside the model class itself and add self.apply(self.weight_init) at the end of __init__, so initialization runs automatically whenever the model is constructed.
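A minimal sketch of that pattern, assuming a single nn.Linear layer as a stand-in body (the layer and its sizes are placeholders):

import torch.nn as nn

class Net(nn.Module):
    def __init__(self, input_channels, n_classes):
        super().__init__()
        self.fc = nn.Linear(input_channels, n_classes)  # hypothetical body
        # apply() runs weight_init on self and every submodule at construction time
        self.apply(self.weight_init)

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)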