# A very simple model, based on code from GitHub; the classification results are not very good.
from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import time
%matplotlib inline
from IPython import display
input_size = 2  # number of input features per sample (x1, x2)
output_size = 1  # single decision score (linear SVM output)
learning_rate = 0.00001  # NOTE(review): unused — the optimizer below is constructed with lr=0.01; confirm which rate is intended
def load_data(filename):
    """Read whitespace-separated samples ``[x1, x2, t]`` from *filename*.

    Returns an (N, 3) float ndarray with columns x1, x2, t.
    NOTE(review): the first line of the file is read and discarded —
    presumably a header row; confirm against the actual data files.
    """
    samples = []
    with open(filename, 'r') as fh:
        fh.readline()  # skip first line (assumed header)
        for raw in fh:
            fields = raw.strip().split()
            samples.append([float(fields[0]), float(fields[1]), int(fields[2])])
    return np.array(samples)
train_file = 'data/train_linear.txt'
test_file = 'data/test_linear.txt'
data_train = load_data(train_file) # 数据格式[x1, x2, t]
data_test = load_data(test_file)
print(type(data_train))
print(data_train.shape)
print(data_test.shape)
<class 'numpy.ndarray'>
(200, 3)
(200, 3)
class Model(nn.Module):
    """A single fully-connected (affine) layer: score = W x + b.

    Used here as the hypothesis of a soft-margin linear SVM (the hinge
    loss is applied by the caller). Generalized so the layer dimensions
    are constructor parameters; the defaults (2 -> 1) match the original
    hard-coded module-level ``input_size`` / ``output_size``, so
    ``Model()`` behaves exactly as before.
    """

    def __init__(self, in_features=2, out_features=1):
        super(Model, self).__init__()
        self.linear = nn.Linear(in_features, out_features)  # one linear layer, no activation

    def forward(self, x):
        # Raw affine score; no nonlinearity — the hinge loss handles the margin.
        y_pred = self.linear(x)
        return y_pred
# ---- Prepare training tensors -------------------------------------------
x_train = data_train[:, :2]   # features (x1, x2)
y_train = data_train[:, 2:3]  # labels, kept as an (N, 1) column
X = torch.FloatTensor(x_train)
Y = torch.FloatTensor(y_train)
# Standardize features. NOTE(review): this pools *all* elements of X into
# one mean/std rather than per-feature statistics — confirm intended.
X = (X - X.mean()) / X.std()
# Hinge loss expects labels in {-1, +1}; remap 0 -> -1.
# (Boolean-mask assignment replaces the original np.where indexing on a
# torch tensor — same effect, idiomatic torch.)
Y[Y == 0] = -1
N = len(Y)

model = Model()
# NOTE(review): lr=0.01 is hard-coded; the module-level learning_rate
# constant (1e-5) is never used — confirm which is intended.
optimizer = optim.SGD(model.parameters(), lr=0.01)
model.train()
for epoch in range(50):
    perm = torch.randperm(N)  # fresh sample order each epoch (SGD, batch size 1)
    sum_loss = 0.0
    for i in range(0, N):
        x = X[perm[i : i + 1]]
        y = Y[perm[i : i + 1]]
        optimizer.zero_grad()
        output = model(x)
        # max(0, 1 - y * f(x)) hinge loss plus an L2 penalty on the
        # weights -> a soft-margin linear SVM.
        loss = torch.mean(torch.clamp(1 - output.t() * y, min=0))  # hinge loss
        loss += 0.01 * torch.mean(model.linear.weight ** 2)  # l2 penalty
        loss.backward()
        optimizer.step()
        # .item() replaces the deprecated loss.data.cpu().numpy() access.
        sum_loss += loss.item()
    print("Epoch:{:4d}\tloss:{}".format(epoch, sum_loss / N))
Epoch: 0 loss:0.3788770362269133
Epoch: 1 loss:0.18692774163559078
Epoch: 2 loss:0.16527784419246017
Epoch: 3 loss:0.1560688596777618
Epoch: 4 loss:0.15181147237773984
Epoch: 5 loss:0.14811894194222985
Epoch: 6 loss:0.14536839919630437
Epoch: 7 loss:0.14435229473747313
Epoch: 8 loss:0.14317100788466633
Epoch: 9 loss:0.14245451985858382
Epoch: 10 loss:0.14157551057636739
Epoch: 11 loss:0.14109212066046894
Epoch: 12 loss:0.1407695705164224
Epoch: 13 loss:0.1400472893193364
Epoch: 14 loss:0.13964955242350696
Epoch: 15 loss:0.1392783078365028
Epoch: 16 loss:0.138907381426543
Epoch: 17 loss:0.1387982941698283
Epoch: 18 loss:0.13861130747012795
Epoch: 19 loss:0.13863415602594614
Epoch: 20 loss:0.13828472116030752
Epoch: 21 loss:0.13822870085015893
Epoch: 22 loss:0.13854863058775663
Epoch: 23 loss:0.1381820786278695
Epoch: 24 loss:0.13801106195896864
Epoch: 25 loss:0.13826215670444073
Epoch: 26 loss:0.13842669501900673
Epoch: 27 loss:0.13817614743486048
Epoch: 28 loss:0.1382398795802146
Epoch: 29 loss:0.13823835723102093
Epoch: 30 loss:0.1382795405294746
Epoch: 31 loss:0.1377215893007815
Epoch: 32 loss:0.13821476998738944
Epoch: 33 loss:0.13820640606805681
Epoch: 34 loss:0.13815249134786428
Epoch: 35 loss:0.13808728583157062
Epoch: 36 loss:0.1381826154794544
Epoch: 37 loss:0.138189296182245
Epoch: 38 loss:0.13807438237592579
Epoch: 39 loss:0.13820569985546172
Epoch: 40 loss:0.13821616280823945
Epoch: 41 loss:0.13809766220860184
Epoch: 42 loss:0.13808201501145959
Epoch: 43 loss:0.138220365839079
Epoch: 44 loss:0.13818617248907686
Epoch: 45 loss:0.13783584020100534
Epoch: 46 loss:0.13833872705698014
Epoch: 47 loss:0.13807488267309964
Epoch: 48 loss:0.13824151220731437
Epoch: 49 loss:0.13813724058680235
# ---- Visualize the learned decision boundary and margins ----------------
# .detach() replaces the deprecated .data attribute access.
W = model.linear.weight[0].detach().cpu().numpy()  # learned weight vector, shape (2,)
b = model.linear.bias[0].detach().cpu().numpy()    # learned bias (scalar)
print(W,b)
delta = 0.01
# Dense grid covering the (standardized) feature range.
x = np.arange(X[:, 0].min(), X[:, 0].max(), delta)
y = np.arange(X[:, 1].min(), X[:, 1].max(), delta)
x, y = np.meshgrid(x, y)
xy = list(map(np.ravel, [x, y]))  # two flattened coordinate arrays, each length H*W
print(xy)
z = (W.dot(xy) + b).reshape(x.shape)  # decision value f(x) at every grid point
# Quantize decision values into 4 bands for contourf: beyond +margin,
# (0, +1], (-1, 0], and beyond -margin. The sequential in-place masking
# is safe because the replacement values 4..1 never fall into a later
# band's range. (Boolean masks replace the original np.where indexing.)
z[z > 1.] = 4
z[(z > 0.) & (z <= 1.)] = 3
z[(z > -1.) & (z <= 0.)] = 2
z[z <= -1.] = 1
plt.xlim([X[:, 0].min() + delta, X[:, 0].max() - delta])
plt.ylim([X[:, 1].min() + delta, X[:, 1].max() - delta])
plt.contourf(x, y, z, alpha=0.8, cmap="Greys")
plt.scatter(x=X[:, 0], y=X[:, 1], c="black", s=10)
plt.tight_layout()
plt.show()
[ 1.7399527 -1.2409829] 1.0133485
[array([-2.75684214, -2.74684215, -2.73684216, ..., 1.66315365,
1.67315364, 1.68315363]), array([-1.42996836, -1.42996836, -1.42996836, ..., 2.42002797,
2.42002797, 2.42002797])]