SMO算法实现

这里根据SMO算法原论文中的伪代码实现了SMO算法。算法和数据已经上传到了git

伪代码

target = desired output vector
point = training point matrix
procedure takeStep(i1,i2)
    if (i1 == i2) return 0
    alph1 = Lagrange multiplier for i1
    y1 = target[i1]
    E1 = SVM output on point[i1] – y1 (check in error cache)
    s = y1*y2
    Compute L, H via equations (13) and (14)
    if (L == H)
        return 0
    k11 = kernel(point[i1],point[i1])
    k12 = kernel(point[i1],point[i2])
    k22 = kernel(point[i2],point[i2])
    eta = k11+k22-2*k12
    if (eta > 0)
    {
        a2 = alph2 + y2*(E1-E2)/eta
        if (a2 < L) a2 = L
        else if (a2 > H) a2 = H
    }
    else
    {
        Lobj = objective function at a2=L
        Hobj = objective function at a2=H
        if (Lobj < Hobj-eps)
            a2 = L
        else if (Lobj > Hobj+eps)
            a2 = H
        else
            a2 = alph2
    }
    if (|a2-alph2| < eps*(a2+alph2+eps))
        return 0
    a1 = alph1+s*(alph2-a2)
    Update threshold to reflect change in Lagrange multipliers
    Update weight vector to reflect change in a1 & a2, if SVM is linear
    Update error cache using new Lagrange multipliers
    Store a1 in the alpha array
    Store a2 in the alpha array
    return 1
endprocedure

procedure examineExample(i2)
    y2 = target[i2]
    alph2 = Lagrange multiplier for i2
    E2 = SVM output on point[i2] – y2 (check in error cache)
    r2 = E2*y2
    if ((r2 < -tol && alph2 < C) || (r2 > tol && alph2 > 0))
    {
        if (number of non-zero & non-C alpha > 1)
        {
            i1 = result of second choice heuristic (section 2.2)
            if takeStep(i1,i2)
                return 1
        }
        loop over all non-zero and non-C alpha, starting at a random point
        {
            i1 = identity of current alpha
            if takeStep(i1,i2)
                return 1
        }
        loop over all possible i1, starting at a random point
        {
            i1 = loop variable
            if (takeStep(i1,i2))
                return 1
        }
    }
    return 0
endprocedure

main routine:
    numChanged = 0
    examineAll = 1
    while (numChanged > 0 | examineAll)
    {
        numChanged = 0;
        if (examineAll)
            loop I over all training examples
            numChanged += examineExample(I)
        else
            loop I over examples where alpha is not 0 & not C
            numChanged += examineExample(I)
        if (examineAll == 1)
            examineAll = 0
        else if (numChanged == 0)
            examineAll = 1
    }

python实现

# @Author  : lightXu
# @File    : smo_paper.py

import numpy as np
import matplotlib.pyplot as plt
import random
import copy


class OptStruct:
    """
    Mutable container holding every value the SMO routines read or update.
    Parameters:
        data_x - training samples; the number of rows is taken from data_x.shape[0]
        label - per-sample class labels
        C - upper bound applied to every Lagrange multiplier
        toler - tolerance used when checking the KKT conditions
    """

    def __init__(self, data_x, label, C, toler):
        n_samples = data_x.shape[0]

        # Problem definition (stored as given, not copied).
        self.X, self.label = data_x, label
        self.C, self.toler = C, toler
        self.row = n_samples

        # Solver state: multipliers, threshold and error cache all start at zero.
        self.alpha = np.zeros(n_samples)
        self.b = 0
        self.e_cache = np.zeros(n_samples)


def cal_Ek(ost, k):
    """
    Evaluate the current linear decision function on sample k and its error.
    Parameters:
        ost - solver state (reads X, label, alpha and b)
        k - row index of the sample
    Returns:
        (Ek, fxk) - the error fxk - label[k] and the output fxk,
                    both passed through round_float
    """
    sample = ost.X[k, :]
    inner_products = np.dot(ost.X, sample)  # linear kernel against every row
    signed_alpha = ost.alpha * ost.label
    output = np.dot(signed_alpha.T, inner_products) + ost.b
    error = output - ost.label[k]
    return round_float(error), round_float(output)


def round_float(value):
    """Return value rounded to 8 decimal places."""
    decimals = 8
    return round(value, decimals)


def load_data(file_name):
    """
    Read a tab-separated text file into a feature matrix and a label vector.
    Every non-blank line holds the feature values followed by the label in
    the last column.
    Parameters:
        file_name - path of the text file
    Returns:
        data_x - float array with one row of features per non-blank line
        label - float array with the last column of every non-blank line
    """
    data_x = []
    data_y = []

    with open(file_name, 'r') as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                # A blank (often trailing) line would otherwise become a
                # ragged row and break the array conversion below.
                continue
            fields = stripped.split('\t')
            data_x.append(fields[:-1])
            data_y.append(fields[-1])

    # The np.float alias was removed in NumPy 1.24; the builtin float is the
    # same dtype (float64) on every NumPy version.
    data_x = np.array(data_x, dtype=float)
    label = np.array(data_y, dtype=float)

    return data_x, label


def select_j_random(i, m):
    """
    Draw a random integer index in [0, m) that differs from i.
    Parameters:
        i - index that must not be returned
        m - number of available indices
    Returns:
        j - random index with 0 <= j < m and j != i
    Raises:
        ValueError - when no index other than i exists in [0, m)
    """
    # Without this guard the rejection loop below would never terminate.
    if m <= 0 or (m == 1 and i == 0):
        raise ValueError("no index other than %r available in range(%r)" % (i, m))

    j = i
    while j == i:
        # randrange draws the integer directly instead of truncating a float.
        j = random.randrange(m)

    return j


def select_j(eligible_list, i, ost, Ei):
    """
    Second-choice heuristic: pick the partner whose error is farthest from Ei.
    A negative Ei selects the candidate with the largest error, a positive Ei
    the one with the smallest error, and Ei == 0 the one with the largest
    absolute error. Ties go to the earliest candidate.
    Parameters:
        eligible_list - iterable of candidate sample indices
        i - index of the sample already chosen (never returned)
        ost - solver state passed on to cal_Ek
        Ei - error of sample i
    Returns:
        (max_k, Ej) - the chosen index and its error
    Raises:
        ValueError - when no candidate other than i is available
    """
    candidates = [k for k in eligible_list if k != i]
    if not candidates:
        raise ValueError("select_j needs at least one candidate other than i")

    # Evaluate each candidate's error exactly once.
    errors = [cal_Ek(ost, k)[0] for k in candidates]
    positions = range(len(candidates))

    if Ei < 0:
        best = max(positions, key=errors.__getitem__)
    elif Ei > 0:
        best = min(positions, key=errors.__getitem__)
    else:
        # Rank by |E| directly on positions: looking the absolute maximum up
        # in the signed list fails whenever it comes from a negative error.
        best = max(positions, key=lambda pos: abs(errors[pos]))

    return candidates[best], errors[best]


def updateEk(ost, k):
    """
    Recompute the error of sample k and store it in the error cache.
    Parameters:
        ost - solver state; ost.e_cache[k] is overwritten
        k - index of the sample
    Returns:
        None
    """
    ost.e_cache[k] = cal_Ek(ost, k)[0]


def clip_alpha(alpha, L, H):
    """
    Clamp alpha into [L, H].
    The upper bound is applied first and the lower bound second, so L wins
    if the bounds are crossed.
    """
    return max(min(alpha, H), L)


def cal_w(data_x, label, alpha):
    """
    Build the weight vector of the linear classifier.
    Parameters:
        data_x - training samples, one per row
        label - per-sample class labels
        alpha - per-sample Lagrange multipliers
    Returns:
        w - dot product of (alpha * label) with data_x
    """
    signed_alpha = alpha * label
    return np.dot(signed_alpha.T, data_x)


def objective_func(ost, i1, i2, alpha1, alpha2, L, H):
    """
    Evaluate the two-variable dual objective at both ends of the feasible
    segment, i.e. with the second multiplier set to L and to H.
    Terms that do not depend on the pair (i1, i2) are omitted, so only the
    difference between the two returned values is meaningful.
    Parameters:
        ost - solver state (reads X, label and alpha); ost.alpha is expected
              to still hold alpha1 and alpha2 at positions i1 and i2
        i1, i2 - indices of the two samples being optimised
        alpha1, alpha2 - current multipliers of i1 and i2
        L, H - lower and upper end of the feasible range of the second multiplier
    Returns:
        (obj_L, obj_H) - objective value at a2 = L and at a2 = H
    """
    y1 = ost.label[i1]
    y2 = ost.label[i2]
    s = y1 * y2

    x1 = ost.X[i1, :]
    x2 = ost.X[i2, :]
    k11 = np.dot(x1, x1)
    k12 = np.dot(x1, x2)
    k22 = np.dot(x2, x2)

    # Kernel expansion without the threshold. Computing it directly avoids
    # any dependence on the sign convention used for b elsewhere.
    signed_alpha = ost.alpha * ost.label
    u1 = np.dot(signed_alpha, np.dot(ost.X, x1))
    u2 = np.dot(signed_alpha, np.dot(ost.X, x2))

    # Linear coefficients of the restricted objective; the contribution of
    # the pair itself is subtracted so only the other samples remain.
    f1 = y1 * (u1 - y1) - alpha1 * k11 - s * alpha2 * k12
    f2 = y2 * (u2 - y2) - s * alpha1 * k12 - alpha2 * k22

    # Value of the first multiplier implied by the equality constraint.
    L1 = alpha1 + s * (alpha2 - L)
    H1 = alpha1 + s * (alpha2 - H)

    obj_L = L1 * f1 + L * f2 + 0.5 * L1 * L1 * k11 + 0.5 * L * L * k22 + s * L * L1 * k12
    obj_H = H1 * f1 + H * f2 + 0.5 * H1 * H1 * k11 + 0.5 * H * H * k22 + s * H * H1 * k12

    return obj_L, obj_H


def take_step(ost, i1, i2, E2):
    """
    Jointly optimise the multipliers of samples i1 and i2 (takeStep in the
    pseudo-code) and update the threshold and the error cache.
    Parameters:
        ost - solver state; alpha, b and e_cache are modified on success
        i1 - index of the first sample
        i2 - index of the second sample
        E2 - current error of sample i2
    Returns:
        1 if the pair was changed, 0 if no progress was possible
    """
    if i1 == i2:
        return 0

    # One tolerance for both the objective tie-break and the minimum step.
    eps = 0.0001

    alpha1 = ost.alpha[i1].copy()
    alpha2 = ost.alpha[i2].copy()
    y1 = ost.label[i1]
    y2 = ost.label[i2]
    E1, _ = cal_Ek(ost, i1)
    s = y1 * y2

    # Ends of the segment on which the second multiplier may move.
    if y1 != y2:
        L = max(0, alpha2 - alpha1)
        H = min(ost.C, ost.C + alpha2 - alpha1)
    else:
        L = max(0, alpha2 + alpha1 - ost.C)
        H = min(ost.C, alpha2 + alpha1)
    if L == H:
        return 0

    k11 = np.dot(ost.X[i1, :], ost.X[i1, :])
    k12 = np.dot(ost.X[i1, :], ost.X[i2, :])
    k22 = np.dot(ost.X[i2, :], ost.X[i2, :])
    eta = round_float(k11 + k22 - 2 * k12)

    if eta > 0:
        # Unconstrained optimum along the constraint line, clipped to [L, H].
        a2 = clip_alpha(alpha2 + y2 * (E1 - E2) / eta, L, H)
    else:
        # Degenerate curvature: compare the objective at both segment ends,
        # evaluated from the CURRENT multipliers.
        Lobj, Hobj = objective_func(ost, i1, i2, alpha1, alpha2, L, H)
        if Lobj < Hobj - eps:
            a2 = L
        elif Lobj > Hobj + eps:
            a2 = H
        else:
            a2 = alpha2

    if abs(a2 - alpha2) < eps:
        return 0

    a1 = round_float(alpha1 + s * (alpha2 - a2))
    a2 = round_float(a2)

    # The threshold update must use the change (new - old) of each
    # multiplier; the new values are not yet stored in ost.alpha here.
    delta1 = y1 * (a1 - alpha1)
    delta2 = y2 * (a2 - alpha2)
    b1 = ost.b - E1 - delta1 * k11 - delta2 * k12
    b2 = ost.b - E2 - delta1 * k12 - delta2 * k22

    # Choose the threshold according to which NEW multiplier is non-bound.
    if 0 < a1 < ost.C:
        ost.b = b1
    elif 0 < a2 < ost.C:
        ost.b = b2
    else:
        ost.b = (b1 + b2) / 2.0

    # Store the multipliers before refreshing the cache so that the cached
    # errors reflect the new state.
    ost.alpha[i1] = a1
    ost.alpha[i2] = a2

    updateEk(ost, i1)
    updateEk(ost, i2)

    return 1


def violate_kkt(ost, alpha2, E2, fx2, y2):
    """
    Check whether a sample violates the KKT conditions by more than ost.toler.
    With m = y2 * fx2 - 1, a violation is reported when
        alpha2 == 0      and m < -toler, or
        0 < alpha2 < C   and |m| > toler, or
        alpha2 == C      and m > toler.
    Parameters:
        ost - solver state (reads toler and C)
        alpha2 - multiplier of the sample
        E2 - error of the sample (not used by the check)
        fx2 - decision function output for the sample
        y2 - label of the sample
    Returns:
        truthy when one of the conditions above is violated
    """
    tol = ost.toler
    margin = y2 * fx2 - 1

    violates_at_zero = (not margin >= -tol) and alpha2 == 0
    violates_inside = (not abs(margin) <= tol) and 0 < alpha2 < ost.C
    violates_at_bound = (not margin <= tol) and alpha2 == ost.C

    return violates_at_zero or violates_inside or violates_at_bound


def examine_example(ost, i2):
    """
    Try to make progress on sample i2 (examineExample in the pseudo-code).
    If i2 violates the KKT conditions, partners are tried in this order:
    the second-choice heuristic, every multiplier that is neither 0 nor C
    in random order, then every sample in random order.
    Parameters:
        ost - solver state
        i2 - index of the sample to examine
    Returns:
        1 if some pair was optimised, otherwise 0
    """
    label2 = ost.label[i2]
    current_alpha2 = ost.alpha[i2]
    E2, fx2 = cal_Ek(ost, i2)

    # Nothing to do for samples that already satisfy the KKT conditions.
    if not violate_kkt(ost, current_alpha2, E2, fx2, label2):
        return 0

    non_bound = np.where((ost.alpha != 0) & (ost.alpha != ost.C))[0]
    if len(non_bound) > 1:
        i1, _ = select_j(non_bound, i2, ost, E2)
        if take_step(ost, i1, i2, E2):
            return 1

    # Fallbacks: draw partners without replacement, first from the
    # non-bound set and then from the whole training set.
    for pool in (non_bound.copy().tolist(), list(range(0, ost.row))):
        while len(pool) > 0:
            i1 = random.choice(pool)
            if take_step(ost, i1, i2, E2):
                return 1
            pool.remove(i1)

    return 0


def main(dataMatIn, classLabels, C, toler, maxIter):
    """
    Outer loop of SMO: alternate between a sweep over every sample and
    repeated sweeps over the non-bound samples until nothing changes or
    maxIter sweeps have been performed.
    Parameters:
        dataMatIn - training samples, one per row
        classLabels - per-sample class labels
        C - upper bound on every multiplier
        toler - tolerance for the KKT check
        maxIter - maximum number of sweeps (full or non-bound)
    Returns:
        (b, alpha) - threshold and multipliers of the trained model
    """
    ost = OptStruct(dataMatIn, classLabels, C, toler)
    iter_num = 0
    num_changed = 0
    examine_all = 1

    # The parentheses matter: without them `and` binds tighter than `or`
    # and a pending full sweep would bypass the maxIter limit.
    while iter_num < maxIter and (num_changed > 0 or examine_all):
        num_changed = 0
        if examine_all:
            # Full sweep: examine every training sample.
            for i in range(ost.row):
                num_changed = num_changed + examine_example(ost, i)
                print("全样本遍历:第%d次迭代 样本:%d, alpha优化次数:%d" % (iter_num, i, num_changed))
        else:
            # Restricted sweep: only samples whose multiplier lies strictly
            # between 0 and C.
            non_bound_index = np.where((0 < ost.alpha) & (ost.alpha < ost.C))[0]
            for i in non_bound_index:
                num_changed = num_changed + examine_example(ost, i)
                print("非边界:第%d次迭代 样本:%d, alpha优化次数:%d" % (iter_num, i, num_changed))
        iter_num += 1

        # After a full sweep switch to non-bound sweeps; return to a full
        # sweep once a non-bound sweep changes nothing.
        if examine_all:
            examine_all = 0
        elif num_changed == 0:
            examine_all = 1

    return ost.b, ost.alpha


def show_classifier(data_x, label, w, b, alpha, seed):
    """
    Plot the samples, the separating line and the support vectors.
    Expects exactly two feature columns: w and every row of data_x are
    unpacked into two values.
    Parameters:
        data_x - training samples, one per row
        label - per-sample labels; rows equal to 1 are drawn green,
                rows equal to -1 pink
        w - weight vector of the line w[0]*x + w[1]*y + b = 0
        b - threshold
        alpha - multipliers; every non-zero entry is printed and circled in red
        seed - not used by the plotting code
    """
    # Split the samples per class first, then draw them.
    per_class = [
        (data_x[np.where(label == cls)[0]], colour)
        for cls, colour in ((1, 'green'), (-1, 'pink'))
    ]
    for points, colour in per_class:
        plt.scatter(points[:, 0], points[:, 1],
                    s=30, alpha=0.7, c=colour)

    # Separating line between the smallest and largest first coordinate.
    right = np.max(data_x, axis=0)[0]
    left = np.min(data_x, axis=0)[0]
    w_x, w_y = w
    b = float(b)
    y_right = (-b - w_x * right) / w_y
    y_left = (-b - w_x * left) / w_y
    plt.plot([right, left], [y_right, y_left])

    # Mark the support vectors (samples with a non-zero multiplier).
    for idx, multiplier in enumerate(alpha):
        if abs(multiplier) > 0:
            print(idx)
            sv_x, sv_y = data_x[idx]
            plt.scatter([sv_x], [sv_y], s=150, c='none', alpha=0.7, linewidth=1.5, edgecolor='red')

    plt.show()
    plt.close()


def show_classifier1(data_x, label):
    """
    Scatter-plot the raw samples by class.
    Parameters:
        data_x - training samples; the first two columns are plotted
        label - per-sample labels; rows equal to 1 are drawn green,
                rows equal to -1 pink
    """
    # Split the samples per class first, then draw them.
    per_class = [
        (data_x[np.where(label == cls)[0]], colour)
        for cls, colour in ((1, 'green'), (-1, 'pink'))
    ]
    for points, colour in per_class:
        plt.scatter(points[:, 0], points[:, 1],
                    s=30, alpha=0.7, c=colour)

    plt.show()


if __name__ == '__main__':
    # Seed the RNG so the random partner selection is reproducible.
    seed = 10
    random.seed(seed)

    # Train on the bundled data set, then derive and plot the linear model.
    dataMat, labelMat = load_data('testSet.txt')
    b, alphas = main(dataMatIn=dataMat, classLabels=labelMat,
                     C=0.6, toler=0.0001, maxIter=100)
    w = cal_w(data_x=dataMat, label=labelMat, alpha=alphas)
    print(w, b)
    show_classifier(data_x=dataMat, label=labelMat, w=w, b=b,
                    alpha=alphas, seed=seed)

分类结果如下:


smo.png

补充

第一个参数选择需要判断是否违反原始KKT条件, 这部分的原理可以参考博客。论文作者在这部分引入了eps加快训练, 推导过程大家直接看代码就行了。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,711评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 83,932评论 2 376
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,770评论 0 330
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,799评论 1 271
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,697评论 5 359
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,069评论 1 276
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,535评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,200评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,353评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,290评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,331评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,020评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,610评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,694评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,927评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,330评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,904评论 2 341