代码补全快餐教程(4) - 训练语言模型

代码补全快餐教程(4) - 训练语言模型

一个强大的语言模型可以是其他任务的良好基础。预训练的模型就为我们提供了一个强大的语言模型基础,在些基础上,我们进行微调,就可以实现满足特殊需求的模型。
我们先做实操,然后再讲解相关理论。

代码数据准备

严格来讲,进行代码数据准备需要做代码的排重,后面讲到相关论文时我们会讲到。
现在我们就用个最简单的办法,将代码先拼接在一起。

我们写个小脚本,将transformer库中的python文件都读出来连接在一起:

import os


def walkPrograms(dir, datafile, wildcard):
    exts = wildcard.split(" ")
    for root, subdirs, files in os.walk(dir):
        for name in files:
            for ext in exts:
                if name.endswith(ext):
                    print(root)
                    # print(subdirs)
                    print(name)
                    filename = os.path.join(root, name)
                    print(filename)
                    try:
                        f1 = open(filename, 'r', encoding='utf-8')
                        datafile.writelines(f1.readlines())
                    except UnicodeDecodeError:
                        continue
                    break


outfile = open('transformer.data', 'w', encoding='utf-8')
wildcard = '.py'
walkPrograms('/home/xulun/github/transformers/', outfile, wildcard)

最后会生成一个transformer.data文件,其中是python文件的组合。

语言模型fine-tuning

进行训练之前,我们先安装下transformer库,先cd到transformers的下载目录,然后执行

pip3 install -e . --user

安装成功之后,我们就可以使用transformers下的examples中的run_lm_finetuning.py脚本来进行fine-tuning:

python3 run_lm_finetuning.py \
    --output_dir=/home/xulun/out_trans \
    --model_type=gpt2 \
    --model_name_or_path=gpt2 \
    --per_gpu_train_batch_size=1 \
    --do_train \
    --train_data_file=/home/xulun/github/lusinga/localcomplete/server/transformer.data \
    --block_size=512 --save_steps=500 --overwrite_output_dir

在新版中,这个脚本名字改成了:run_language_modeling.py。

我们来介绍下这些参数的含义:

  • output_dir: 最终我们要保存的是权值,这里给出保存权值的目录
  • model_type: 模型的大类,比如gpt2或者其他
  • model_name_or_path: 模型的小类,比如gpt2-medium, gpt2-large, gpt2-xl等
  • per_gpu_train_batch_size: 多CPU训练时每个CPU批次的大小
  • do_train: 只有指定了这个才会进行训练
  • train_data_file: 要训练的文件名
  • block_size: 分块的大小,如果GPU内存大就多选点,我用的是NVidia 2060 GPU,内存较小,所以我选了个相对较小的值
  • save_steps: 训练多少步保存一次,默认值是50,我觉得有点小,这里改成500
  • overwrite_output_dir: 输出目录不为空时覆盖之,节省存储空间

验证效果

我们做个补全效果测试吧,还是我们之前的代码,我们先用gpt2试试效果:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# MODEL = '/home/xulun/out_trans/'
MODEL = 'gpt2'

# 加载词汇表
tokenizer = GPT2Tokenizer.from_pretrained(MODEL)

# 输入待补全的文本
text = '    indexed_tokens = tokenizer.'
predicted_text = text

# 加载模型中预训练好的权值
model = GPT2LMHeadModel.from_pretrained(MODEL)

# 设置为eval模式,这样就不会执行训练模式下的Dropout过程
model.eval()
#model.to('cuda')

# 每一个只能补一个token出来,补一句话需要多次,30次是我拍脑袋的
for i in range(0,30):

    # 以上次预测结果作为本次的输入,所谓的自回归
    indexed_tokens = tokenizer.encode(predicted_text)

    # 将读出的索引标记转化成PyTorch向量
    tokens_tensor = torch.tensor([indexed_tokens])

    # 使用GPU进行加速,诚实地讲速度不太快
    #tokens_tensor = tokens_tensor.to('cuda')

    # 进行推理
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    # 获取预测的下一个子词
    predicted_index = torch.argmax(predictions[0, -1, :]).item()
    # 解码成我们都读懂的文本
    predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
    # 打印输入结果
    print(predicted_text)

输出如下:

indexed_tokens = tokenizer.get_tokenizer_id(tokenizer.get_tokenizer_id(), tokenizer.get_tokenizer_id(), tokenizer.

下面我们换成我们刚才训练的模型,就是让MODEL从gpt2换成刚才我们训练好的目录:

MODEL = '/home/xulun/out_trans/'

好吧,有同学要完整的:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

MODEL = '/home/xulun/out_trans/'
# MODEL = 'gpt2'

# 加载词汇表
tokenizer = GPT2Tokenizer.from_pretrained(MODEL)

# 输入待补全的文本
#text = 'function walk(dir, fn) { if (fs.existsSync(dir)) { let stat ='
#text = 'if (stat.isDirectory()) {fs.readdirSync(dir).'
#text = 'mediaFileText.color ='
#text = 'mediaFileText.top ='
text = '    indexed_tokens = tokenizer.'
predicted_text = text

# 加载模型中预训练好的权值
model = GPT2LMHeadModel.from_pretrained(MODEL)

# 设置为eval模式,这样就不会执行训练模式下的Dropout过程
model.eval()
#model.to('cuda')

# 每一个只能补一个token出来,补一句话需要多次,30次是我拍脑袋的
for i in range(0,30):

    # 以上次预测结果作为本次的输入,所谓的自回归
    indexed_tokens = tokenizer.encode(predicted_text)

    # 将读出的索引标记转化成PyTorch向量
    tokens_tensor = torch.tensor([indexed_tokens])

    # 使用GPU进行加速,诚实地讲速度不太快
    #tokens_tensor = tokens_tensor.to('cuda')

    # 进行推理
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    # 获取预测的下一个子词
    predicted_index = torch.argmax(predictions[0, -1, :]).item()
    # 解码成我们都读懂的文本
    predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
    # 打印输入结果
    print(predicted_text)

输出结果如下:

indexed_tokens = tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)

看起来是比原始模型更懂transformers。我们可以用更多的代码进行训练,这样就能对于写python代码的效果更好。
如果要支持其他语言,我们将训练集换成其他语言就可以了。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,723评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,485评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,998评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,323评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,355评论 5 374
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,079评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,389评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,019评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,519评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,971评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,100评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,738评论 4 324
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,293评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,289评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,517评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,547评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,834评论 2 345

推荐阅读更多精彩内容

  • 2106年底初识BM特训营,开始了我的第一次成长,让我重新审视了自己,确定了自己的发展方向,让我在2017年遇到了...
    麦子育儿说阅读 2,400评论 7 5
  • 作为一个25岁的女生,一直被男票安利护肤公号却一直不认真看的我,终于醒悟了。毕竟,25岁了,再也不是怎么吃也不胖,...
    戴家小呆阅读 2,382评论 9 53
  • 3 进派出所的流程,我也摸了个大概。毕竟我也算是“二进宫”了。“二进宫”的“二”字是个修饰词,实际上我进派出所...
    叶本阅读 325评论 0 0