irpas技术客

BERT代码解析_世界划水锦标赛冠军_bert代码详解

大大的周 2903

一、bert的原理 1、最核心的一点是:MLM损失函数的计算

什么是MLM损失函数? 损失函数就是用来表现预测与实际数据的差距程度-----我根据数据预测出一个函数来预测我之后的变化,而损失函数就是L=(Y-f(x))2,最后计算一个平均损失函数求和的值来表示差距。

MLM损失函数: 这个任务就是将sentence中一些token进行掩盖,模型会输出这些掩盖的token的隐藏状态,将这些隐藏状态输入softmax可以得到候选单词的概率分布,这样根据ground truth就可以计算cross entropy了 计算mlm损失的时候使用的是那部分数据

2、使用TRM的编码部分


BERT核心目的就是:是把下游具体NLP任务的活逐渐移到预训练产生词向量上

bert的亮点: 1、双向的transformers–同时考虑上下文 2、句子级别的应用 3、适用于不同任务----google已经预预训练好了模型,我们要做的就是根据不同的任务,按照bert的输入要求(后面会看到)输入我们的数据,然后获取输出,在输出层加一层(通常情况下)全连接层就OK啦,整个训练过程就是基于预训练模型的微调

主要分成文本分类,关系抽取等等的句子级别任务和命名实体识别,知识问答等token级别任务 代码解析

tokenization是对原始句子内容的解析,分为一下两部分

BasicTokenizer

主要是进行unicode转换、标点符号分割、中文字符分割、去除重音符号等操作, 最后返回的是关于词的数组(中文是字的数组)

WordpieceTokenizer

WordpieceTokenizer的目的是将合成词分解成类似词根一样的词片。 run_classifier.py

BERT主要分为两个部分。 一个是训练语言模型(language model)的预训练(run_pretraining.py)部分(已公布)。 另一个是训练具体任务(task)的fine-tune部分(fine tune就是用别人训练好的模型,加上我们自己的数据,来训练新的模型)

用run_classifier.py对自己的训练集数据进行预处理,让其能够输入到后续的模型中,解读run_classifier.py: (直接看class)

InputExample类 InputExample类主要定义了一些数据预处理后要生成的字段名

guid就是一个id号,一般将数据处理成train、dev、test数据集,那么这里定义方式就可以是相应的数据集+行号(句子) text_a 就是当前的句子,text_b是另一个句子,因为有的任务需要两个两个句子,如果任务中没有的话,可以将text_b设为None label就是标签

InputFeatures类 InputFeatures类主要是定义了bert的输入格式,且还会通过一些代码将InputExample转化为InputFeatures,这才是bert最终使用的数据格式

label_id是计算loss时候用到的, input_ids,segment_ids分别对应单词id和句子(上下句标示), Input_mask就是记录的是填充信息

DataProcessor类 DataProcessor,这是一个数据预处理的基类,里面定义了一些基本方法 XnliProcessor、MnliProcessor、MrpcProcessor、ColaProcessor四个类是对DataProcessor的具体实现-------一般包含get_train_examples,get_dev_examples,get_test_examples,get_labels,_create_examples方法

get_train_examples,get_dev_examples,get_test_examples------通过调用_create_examples返回一个InputExample类数据结构 get_labels就是返回类别

这里的tokenization的convert_to_unicode就是将文本转化为utf-8编码

对数据的预处理,除了DataProcessor类还有四个方法:

convert_single_example 负责bert的输入部分,负责加上了[CLS]和SEP]标示。最后返回的就是一个InputFeatures类

file_based_convert_examples_to_features 该函数主要就是将上述返回的InputFeatures类数据,保存成一个TFrecords数据格式。为了在训练时读写快速方便

file_based_input_fn_builder 对应的就是从TFrecords 解析读取数据

truncate_seq_pair 就是来限制text_a和text_b总长度的,当超过的话,会轮番pop掉tokens

model_fn_builder

整个模型过程采用了tf.contrib.tpu.TPUEstimator这一高级封装的API model_fn_builder是壳,create_model是核心,其内部定义了loss,预测概率以及预测结果等等。

首先调用create_model得到total_loss, per_example_loss, logits, probabilities等等,然后针对不同的状态返回不同的结果(output_spec),如果是train则返回loss,train_op等,如果是dev则返回一些评价指标如accuracy,如果是test则返回预测结果

create_model 首先调用modeling.BertModel得到bert模型 bert模型的输入:input_ids,input_mask,segment_ids config是bert的配置文件,在开头下载的中文模型中里面有,直接加载即可 use_one_hot_embeddings是根据是不是用GPU而定的 输出有两种: 第一种输出结果是[batch_size, seq_length, embedding_size] 第二种输出结果是[batch_size, embedding_size] 第二种结果是第一种结果在第二个维度上面进行了池化,要是形象点比喻的话,第一种结果得到是tokens级别的结果,第二种是句子级别的,其实就是一个池化

定义部分,根据自己的任务进行自定义

main 主要就是通过人为定义的一些配置值(FLAGS)将上面的流程整个组合起来

预处理 这里就是定义数据预处理器的,记得把自己定义的预处理包含进来

使用tf.contrib.tpu.TPUEstimator定义模型

根据不同模式(train/dev/test,这也是运行时可以指定的)运行estimator.train,estimator.evaluate,estimator.predict

训练集(training set):训练算法。 开发集(development set):调整参数、选择特征,以及对学习算法作出其它决定。 测试集(test set):开发集中选出的最优的模型在测试集上进行评估。不会据此改变学习算法或参数。


1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,会注明原创字样,如未注明都非原创,如有侵权请联系删除!;3.作者投稿可能会经我们编辑修改或补充;4.本站不提供任何储存功能只提供收集或者投稿人的网盘链接。

标签: #bert代码详解 #truth就可以计算cross #Entropy