Pytorch实现SVM二分类

很简单的一个模型,参照github上的代码做的,分类效果并不是很好

from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim

import time
%matplotlib inline
from IPython import display

input_size = 2
output_size = 1
learning_rate = 0.00001
def load_data(filename):
    with open(filename, 'r') as f:
        data=[]
        line=f.readline()
        for line in f:
            line=line.strip().split()
            x1=float(line[0])
            x2=float(line[1])
            t=int(line[2])
            data.append([x1,x2,t])
        return np.array(data)

train_file = 'data/train_linear.txt'
test_file = 'data/test_linear.txt'
data_train = load_data(train_file)  # 数据格式[x1, x2, t]
data_test = load_data(test_file)
print(type(data_train))
print(data_train.shape)
print(data_test.shape)

<class 'numpy.ndarray'>
(200, 3)
(200, 3)
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(input_size, output_size) # One in and one out

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

x_train=data_train[:,:2]
y_train=data_train[:,2:3]
X = torch.FloatTensor(x_train)
Y = torch.FloatTensor(y_train)

X = (X - X.mean()) / X.std()
Y[np.where(Y == 0)] = -1

N = len(Y)
model=Model()
optimizer = optim.SGD(model.parameters(), lr=0.01)
model.train()
for epoch in range(50):
    perm = torch.randperm(N)
    sum_loss = 0
    for i in range(0, N):
        x = X[perm[i : i + 1]]
        y = Y[perm[i : i + 1]]

        optimizer.zero_grad()
        output = model(x)

        loss = torch.mean(torch.clamp(1 - output.t() * y, min=0))  # hinge loss
        loss += 0.01 * torch.mean(model.linear.weight ** 2)  # l2 penalty
        loss.backward()
        optimizer.step()

        sum_loss += loss.data.cpu().numpy()

    print("Epoch:{:4d}\tloss:{}".format(epoch, sum_loss / N))
Epoch:   0  loss:0.3788770362269133
Epoch:   1  loss:0.18692774163559078
Epoch:   2  loss:0.16527784419246017
Epoch:   3  loss:0.1560688596777618
Epoch:   4  loss:0.15181147237773984
Epoch:   5  loss:0.14811894194222985
Epoch:   6  loss:0.14536839919630437
Epoch:   7  loss:0.14435229473747313
Epoch:   8  loss:0.14317100788466633
Epoch:   9  loss:0.14245451985858382
Epoch:  10  loss:0.14157551057636739
Epoch:  11  loss:0.14109212066046894
Epoch:  12  loss:0.1407695705164224
Epoch:  13  loss:0.1400472893193364
Epoch:  14  loss:0.13964955242350696
Epoch:  15  loss:0.1392783078365028
Epoch:  16  loss:0.138907381426543
Epoch:  17  loss:0.1387982941698283
Epoch:  18  loss:0.13861130747012795
Epoch:  19  loss:0.13863415602594614
Epoch:  20  loss:0.13828472116030752
Epoch:  21  loss:0.13822870085015893
Epoch:  22  loss:0.13854863058775663
Epoch:  23  loss:0.1381820786278695
Epoch:  24  loss:0.13801106195896864
Epoch:  25  loss:0.13826215670444073
Epoch:  26  loss:0.13842669501900673
Epoch:  27  loss:0.13817614743486048
Epoch:  28  loss:0.1382398795802146
Epoch:  29  loss:0.13823835723102093
Epoch:  30  loss:0.1382795405294746
Epoch:  31  loss:0.1377215893007815
Epoch:  32  loss:0.13821476998738944
Epoch:  33  loss:0.13820640606805681
Epoch:  34  loss:0.13815249134786428
Epoch:  35  loss:0.13808728583157062
Epoch:  36  loss:0.1381826154794544
Epoch:  37  loss:0.138189296182245
Epoch:  38  loss:0.13807438237592579
Epoch:  39  loss:0.13820569985546172
Epoch:  40  loss:0.13821616280823945
Epoch:  41  loss:0.13809766220860184
Epoch:  42  loss:0.13808201501145959
Epoch:  43  loss:0.138220365839079
Epoch:  44  loss:0.13818617248907686
Epoch:  45  loss:0.13783584020100534
Epoch:  46  loss:0.13833872705698014
Epoch:  47  loss:0.13807488267309964
Epoch:  48  loss:0.13824151220731437
Epoch:  49  loss:0.13813724058680235
W = model.linear.weight[0].data.cpu().numpy()
b = model.linear.bias[0].data.cpu().numpy()
print(W,b)
delta = 0.01
x = np.arange(X[:, 0].min(), X[:, 0].max(), delta)
y = np.arange(X[:, 1].min(), X[:, 1].max(), delta)
x, y = np.meshgrid(x, y)
xy = list(map(np.ravel, [x, y]))
print(xy)
z = (W.dot(xy) + b).reshape(x.shape)
z[np.where(z > 1.)] = 4
z[np.where((z > 0.) & (z <= 1.))] = 3
z[np.where((z > -1.) & (z <= 0.))] = 2
z[np.where(z <= -1.)] = 1

plt.xlim([X[:, 0].min() + delta, X[:, 0].max() - delta])
plt.ylim([X[:, 1].min() + delta, X[:, 1].max() - delta])
plt.contourf(x, y, z, alpha=0.8, cmap="Greys")
plt.scatter(x=X[:, 0], y=X[:, 1], c="black", s=10)
plt.tight_layout()
plt.show()
[ 1.7399527 -1.2409829] 1.0133485
[array([-2.75684214, -2.74684215, -2.73684216, ...,  1.66315365,
        1.67315364,  1.68315363]), array([-1.42996836, -1.42996836, -1.42996836, ...,  2.42002797,
        2.42002797,  2.42002797])]
output_5_1.png
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 205,033评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 87,725评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,473评论 0 338
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,846评论 1 277
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,848评论 5 368
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,691评论 1 282
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,053评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,700评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 42,856评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,676评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,787评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,430评论 4 321
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,034评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,990评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,218评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,174评论 2 352
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,526评论 2 343

推荐阅读更多精彩内容

  • 雨垂直飘落拔弄起大地这架大钢琴琴键声响起惊醒沉睡的风风看着雨独自享乐玩心顿起托起树叶和花瓣凌乱了雨的节奏我在江畔漫...
    天元_ae03阅读 2,612评论 56 64
  • 10月17日,周三,晴。今晚国一作业写的还行,除了因为字不好让她重写的,不到八点就完成了。我们俩一起分析了一会语文...
    国一妈妈阅读 160评论 0 0
  • 职场从来不是脉脉温情的地方,讲究的是效率与产出,这是二度回归职场才撞到的南墙。而不久前,我还是那个视年长的老板作父...
    春宴归阅读 163评论 0 0
  • 1奶茶总跟我说,偶像和粉丝最好的距离,就是台上与台下的距离。但我偏不信,我跟她说,总有一天我会见到我偶像。“行行行...
    北不应阅读 404评论 0 0
  • 数据集中第一人,互联时代亦先身。金融科技能无我?虚位如今待凤麟!
    轩若临风阅读 391评论 1 1