多任务学习(Multi-task)tensorflow2.0及keras实现

多目标任务存在很多场景中,如多目标检测,推荐系统中的多任务学习。

多任务学习(Multi-task learning)简介

多任务学习背景:只专注于单个模型可能会忽略一些相关任务中可能提升目标任务的潜在信息,通过进行一定程度的共享不同任务之间的参数,可能会使原任务泛化更好。广义的讲,只要loss有多个就算MTL,一些别名(joint learning,learning to learn,learning with auxiliary task)

多任务学习(Multitask learning)定义:基于共享表示(shared representation),把多个相关的任务放在一起学习的一种机器学习方法。

多任务学习(Multitask Learning)是一种推导迁移学习方法,主任务(main tasks)使用相关任务(related tasks)的训练信号(training signal)所拥有的领域相关信息(domain-specific information),做为一直推导偏差(inductive bias)来提升主任务(main tasks)泛化效果(generalization performance)的一种机器学习方法。

多任务学习目标:通过权衡主任务与辅助的相关任务中的训练信息来提升模型的泛化性与表现。从机器学习的视角来看,MTL可以看作一种inductive transfer(先验知识),通过提供inductive bias(某种对模型的先验假设)来提升模型效果。比如,使用L1正则,我们对模型的假设模型偏向于sparse solution(参数要少)。在MTL中,这种先验是通过auxiliary task来提供,更灵活,告诉模型偏向一些其他任务,最终导致模型会泛化得更好。

多任务学习(Multi-task learning)的两种模式

深度学习中两种多任务学习模式:隐层参数的硬共享与软共享。

隐层参数硬共享,指的是多个任务之间共享网络的同几层隐藏层,只不过在网络的靠近输出部分开始分叉去做不同的任务。

隐层参数软共享,不同的任务使用不同的网络,但是不同任务的网络参数,采用距离(L1,L2)等作为约束,鼓励参数相似化。

而本次的代码实现采用的是隐层参数硬共享,也就是两个任务共享网络浅层的参数。

上图是美团使用的多任务学习框架

在使用XGBoost进行单目标训练的时候,通过把点击的样本和下单的样本都作为正样本,并对下单的样本进行上采样或者加权,来平衡点击率和下单率。但这种样本的加权方式也会有一些缺点,例如调整下单权重或者采样率的成本较高,每次调整都需要重新训练,并且对于模型来说较难用同一套参数来表达这两种混合的样本分布。针对上述问题,可以利用DNN灵活的网络结构引入了Multi-task训练。

根据业务目标,我们把点击率和下单率拆分出来,形成两个独立的训练目标,分别建立各自的Loss Function,作为对模型训练的监督和指导。DNN网络的前几层作为共享层,点击任务和下单任务共享其表达,并在BP阶段根据两个任务算出的梯度共同进行参数更新。网络在最后一个全连接层进行拆分,单独学习对应Loss的参数,从而更好地专注于拟合各自Label的分布。

Multi-task DNN的网络结构如上图所示。线上预测时,我们将Click-output和Pay-output做一个线性融合。

# 搭建双任务并训练

def get_model():

    """函数式API搭建双塔DNN模型"""

    # 输入

    user_id = tf.keras.layers.Input(shape=(1,), name="user_id")

    store_id = tf.keras.layers.Input(shape=(1,), name="store_id")

    sku_id = tf.keras.layers.Input(shape=(1,), name="sku_id")

    search_keyword = tf.keras.layers.Input(shape=(1,), name="search_keyword")

    category_id = tf.keras.layers.Input(shape=(1,), name="category_id")

    brand_id = tf.keras.layers.Input(shape=(1,), name="brand_id")

    ware_type = tf.keras.layers.Input(shape=(1,), name="ware_type")

    # user特征

    user_vector = tf.keras.layers.concatenate([

        tf.keras.layers.Embedding(num_user_ids, 32)(user_id),

        tf.keras.layers.Embedding(num_store_ids, 8)(store_id),

        tf.keras.layers.Embedding(num_search_keywords, 16)(search_keyword)

    ])

    user_vector = tf.keras.layers.Dense(32, activation='relu')(user_vector)

    user_vector = tf.keras.layers.Dense(8, activation='relu',

                              name="user_embedding", kernel_regularizer='l2')(user_vector)

    # item特征

    movie_vector = tf.keras.layers.concatenate([

        tf.keras.layers.Embedding(num_sku_ids, 32)(sku_id),

        tf.keras.layers.Embedding(num_category_ids, 8)(category_id),

        tf.keras.layers.Embedding(num_brand_ids, 8)(brand_id),

        tf.keras.layers.Embedding(num_ware_types, 2)(ware_type)

    ])

    movie_vector = tf.keras.layers.Dense(32, activation='relu')(movie_vector)

    movie_vector = tf.keras.layers.Dense(8, activation='relu',

                                name="movie_embedding", kernel_regularizer='l2')(movie_vector)

    x = tf.keras.layers.concatenate([user_vector,movie_vector])

    out1 = tf.keras.layers.Dense(16,activation = 'relu')(x)

    out1 = tf.keras.layers.Dense(8,activation = 'relu')(out1)

    out1 = tf.keras.layers.Dense(1, activation='sigmoid',name = 'out1')(out1)


    out2 = tf.keras.layers.Dense(16,activation = 'relu')(x)

    out2 = tf.keras.layers.Dense(8,activation = 'relu')(out2)

    out2 = tf.keras.layers.Dense(1, activation='sigmoid',name = 'out2')(out2)

    return tf.keras.models.Model(inputs=[user_id, sku_id, store_id, search_keyword, category_id, brand_id,ware_type],

                              outputs=[out1,out2])

模型代码部分

这里模型构建有两点需要注意:

各个任务的输出层一定要命名,比如笔者这个模型的点击率输出层Dense(1, activation='sigmod',name = "out1")(out1)中的name ="out1",以及下单率输出层Dense(1, activation='sigmod',name = "out2")(out2)中的name ="out2"不能省略。

第二个就是model.compile中的loss和loss的权重需要和任务输出层的name进行对应,如下:

loss={'out1': loss,'out2': loss}

loss_weights={'out1':1, 'crf_output': 1}

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,189评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,577评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,857评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,703评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,705评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,620评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,995评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,656评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,898评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,639评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,720评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,395评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,982评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,953评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,195评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,907评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,472评论 2 342

推荐阅读更多精彩内容