Triple (SPO) extraction task, based on the "half-pointer, half-tagging" structure

Article introduction: https://kexue.fm/archives/7161

Dataset: http://ai.baidu.com/broad/download?dataset=sked

Best F1 = 0.82198

Code source: the bert4keras examples

Su Jianlin (苏神) encourages everyone to make small modifications and publish their own write-ups.

Baidu Research Open-Access Dataset: ai.baidu.com

This is a Chinese dataset for triple extraction; each record looks like the following:

{ 
    "text": "《新駌鸯蝴蝶梦》是黄安的音乐作品,收录在《流金十载全记录》专辑中", 
    "spo_list": [ 
        { 
            "subject": "新駌鸯蝴蝶梦", 
            "predicate": "所属专辑", 
            "object": "流金十载全记录", 
            "subject_type": "歌曲", 
            "object_type": "音乐专辑" 
        }, 
        { 
            "subject": "新駌鸯蝴蝶梦", 
            "predicate": "歌手", 
            "object": "黄安", 
            "subject_type": "歌曲", 
            "object_type": "人物" 
        } 
    ] 
}
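Each line of the raw data file is one JSON object in exactly this format. As a quick sanity check, here is a minimal sketch for peeking at a few records (the kg_huge/ path matches the training script below and is only an assumption about where the download was unpacked):

import json
import codecs

# print the first few records of the training file (path is an assumption)
with codecs.open('kg_huge/train_data.json', encoding='utf-8') as f:
    for i, line in enumerate(f):
        record = json.loads(line)  # one JSON object per line
        print(record['text'])
        for spo in record['spo_list']:
            print('  ', spo['subject'], spo['predicate'], spo['object'])
        if i >= 2:  # only peek at the first three records
            break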

Install bert4keras:

pip install git+https://www.github.com/bojone/bert4keras.git

The training code is as follows:

 
import json 
import codecs 
import numpy as np 
import tensorflow as tf 
from bert4keras.backend import keras, set_gelu, K 
from bert4keras.layers import LayerNormalization 
from bert4keras.tokenizer import Tokenizer 
from bert4keras.bert import build_bert_model 
from bert4keras.optimizers import Adam, ExponentialMovingAverage 
from bert4keras.snippets import sequence_padding, DataGenerator 
from keras.layers import * 
from keras.models import Model 
from tqdm import tqdm 
 
 
maxlen = 128 
batch_size = 64 
config_path = 'wwm/bert_config.json' 
checkpoint_path = 'wwm/bert_model.ckpt' 
dict_path = 'wwm/vocab.txt' 
 
 
def load_data(filename): 
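    """Load the data; one JSON object per line.
    Each item becomes {'text': text, 'spo_list': [(s, p, o), ...]}.
    """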
    D = [] 
    with codecs.open(filename, encoding='utf-8') as f: 
        for l in f: 
            l = json.loads(l) 
            D.append({ 
                'text': l['text'], 
                'spo_list': [ 
                    (spo['subject'], spo['predicate'], spo['object']) 
                    for spo in l['spo_list'] 
                ] 
            }) 
    return D 
 
 
# Load the datasets
train_data = load_data('kg_huge/train_data.json') 
valid_data = load_data('kg_huge/dev_data.json') 
predicate2id, id2predicate = {}, {} 
 
with codecs.open('kg_huge/all_50_schemas') as f: 
    for l in f: 
        l = json.loads(l) 
        if l['predicate'] not in predicate2id: 
            id2predicate[len(predicate2id)] = l['predicate'] 
            predicate2id[l['predicate']] = len(predicate2id) 
 
# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True) 
 
 
def search(pattern, sequence): 
    """从sequence中寻找子串pattern 
    如果找到,返回第一个下标;否则返回-1。 
    """ 
    n = len(pattern) 
    for i in range(len(sequence)): 
        if sequence[i:i + n] == pattern: 
            return i 
    return -1 
 
 
class data_generator(DataGenerator): 
    """数据生成器 
    """ 
    def __iter__(self, random=False): 
        idxs = list(range(len(self.data))) 
        if random: 
            np.random.shuffle(idxs) 
        batch_token_ids, batch_segment_ids = [], [] 
        batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [] 
        for i in idxs: 
            d = self.data[i] 
            token_ids, segment_ids = tokenizer.encode(d['text'], max_length=maxlen) 
            # collect the triples as {s: [(o, p)]}
            spoes = {} 
            for s, p, o in d['spo_list']: 
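                # token ids of subject/object, with the [CLS]/[SEP] ids stripped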
                s = tokenizer.encode(s)[0][1:-1] 
                p = predicate2id[p] 
                o = tokenizer.encode(o)[0][1:-1] 
                s_idx = search(s, token_ids) 
                o_idx = search(o, token_ids) 
                if s_idx != -1 and o_idx != -1: 
                    s = (s_idx, s_idx + len(s) - 1) 
                    o = (o_idx, o_idx + len(o) - 1, p) 
                    if s not in spoes: 
                        spoes[s] = [] 
                    spoes[s].append(o) 
            if spoes: 
                # subject labels (start/end pointers)
                subject_labels = np.zeros((len(token_ids), 2)) 
                for s in spoes: 
                    subject_labels[s[0], 0] = 1 
                    subject_labels[s[1], 1] = 1 
                # randomly sample one subject
                start, end = np.array(list(spoes.keys())).T 
                start = np.random.choice(start) 
                end = np.random.choice(end[end >= start]) 
                subject_ids = (start, end) 
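                # note: the sampled (start, end) may pair the start of one subject with the
                # end of another; spoes.get() below then finds no objects, which effectively
                # makes this a negative sample for the object model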
                # object labels for the sampled subject
                object_labels = np.zeros((len(token_ids), len(predicate2id), 2)) 
                for o in spoes.get(subject_ids, []): 
                    object_labels[o[0], o[2], 0] = 1 
                    object_labels[o[1], o[2], 1] = 1 
                # build the batch
                batch_token_ids.append(token_ids) 
                batch_segment_ids.append(segment_ids) 
                batch_subject_labels.append(subject_labels) 
                batch_subject_ids.append(subject_ids) 
                batch_object_labels.append(object_labels) 
                if len(batch_token_ids) == self.batch_size or i == idxs[-1]: 
                    batch_token_ids = sequence_padding(batch_token_ids) 
                    batch_segment_ids = sequence_padding(batch_segment_ids) 
                    batch_subject_labels = sequence_padding(batch_subject_labels, padding=np.zeros(2)) 
                    batch_subject_ids = np.array(batch_subject_ids) 
                    batch_object_labels = sequence_padding(batch_object_labels, padding=np.zeros((len(predicate2id), 2))) 
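                    # targets are None: the losses are attached to the model via add_loss below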
                    yield [ 
                        batch_token_ids, batch_segment_ids, 
                        batch_subject_labels, batch_subject_ids, batch_object_labels 
                    ], None 
                    batch_token_ids, batch_segment_ids = [], [] 
                    batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [] 
 
 
def batch_gather(params, indices): 
    """params.shape=[b, n, d],indices.shape=[b] 
    从params的第i个序列中选出第indices[i]个向量,返回shape=[b, d]。 
    """ 
    indices = K.cast(indices, 'int32') 
    batch_idxs = K.arange(0, K.shape(indices)[0]) 
    indices = K.stack([batch_idxs, indices], 1) 
    return tf.gather_nd(params, indices) 
 
 
def extrac_subject(inputs): 
    """根据subject_ids从output中取出subject的向量表征 
    """ 
    output, subject_ids = inputs 
    start = batch_gather(output, subject_ids[:, 0]) 
    end = batch_gather(output, subject_ids[:, 1]) 
    subject = K.concatenate([start, end], 1) 
    return subject 
 
 
# Additional inputs
subject_labels = Input(shape=(None, 2), name='Subject-Labels') 
subject_ids = Input(shape=(2, ), name='Subject-Ids') 
object_labels = Input(shape=(None, len(predicate2id), 2), name='Object-Labels') 
 
# Load the pre-trained model
bert = build_bert_model( 
    config_path=config_path, 
    checkpoint_path=checkpoint_path, 
    return_keras_model=False, 
) 
 
# Predict subjects
output = Dense(units=2, 
               activation='sigmoid', 
               kernel_initializer=bert.initializer)(bert.model.output) 
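# squaring the sigmoid output (and raising it to the 4th power for objects below) pushes
# low-confidence predictions closer to 0; a trick from the original article, presumably
# to offset the heavy imbalance between 0 and 1 labels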
subject_preds = Lambda(lambda x: x**2)(output) 
 
subject_model = Model(bert.model.inputs, subject_preds) 
 
# Feed in the subject and predict the objects
# Conditional Layer Normalization injects the subject into the object prediction
output = bert.model.layers[-2].get_output_at(-1) 
subject = Lambda(extrac_subject)([output, subject_ids]) 
output = LayerNormalization(conditional=True)([output, subject]) 
output = Dense(units=len(predicate2id) * 2, 
               activation='sigmoid', 
               kernel_initializer=bert.initializer)(output) 
output = Reshape((-1, len(predicate2id), 2))(output) 
object_preds = Lambda(lambda x: x**4)(output) 
 
object_model = Model(bert.model.inputs + [subject_ids], object_preds) 
 
# Training model
train_model = Model(bert.model.inputs + [subject_labels, subject_ids, object_labels], 
                    [subject_preds, object_preds]) 
 
mask = bert.model.get_layer('Sequence-Mask').output 
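# both losses below are token-level binary cross-entropy, averaged only over
# real (non-padded) positions via the sequence mask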
 
subject_loss = K.binary_crossentropy(subject_labels, subject_preds) 
subject_loss = K.mean(subject_loss, 2) 
subject_loss = K.sum(subject_loss * mask) / K.sum(mask) 
 
object_loss = K.binary_crossentropy(object_labels, object_preds) 
object_loss = K.sum(K.mean(object_loss, 3), 2) 
object_loss = K.sum(object_loss * mask) / K.sum(mask) 
 
train_model.add_loss(subject_loss + object_loss) 
train_model.compile(optimizer=Adam(1e-5)) 
 
 
def extract_spoes(text): 
    """抽取输入text所包含的三元组 
    """ 
    tokens = tokenizer.tokenize(text, max_length=maxlen) 
    token_ids, segment_ids = tokenizer.encode(text, max_length=maxlen) 
    # extract subjects
    subject_preds = subject_model.predict([[token_ids], [segment_ids]]) 
    start = np.where(subject_preds[0, :, 0] > 0.6)[0] 
    end = np.where(subject_preds[0, :, 1] > 0.5)[0] 
    subjects = [] 
    for i in start: 
        j = end[end >= i] 
        if len(j) > 0: 
            j = j[0] 
            subjects.append((i, j)) 
    if subjects: 
        spoes = [] 
        token_ids = np.repeat([token_ids], len(subjects), 0) 
        segment_ids = np.repeat([segment_ids], len(subjects), 0) 
        subjects = np.array(subjects) 
        # feed in the subjects and extract objects and predicates
        object_preds = object_model.predict([token_ids, segment_ids, subjects]) 
        for subject, object_pred in zip(subjects, object_preds): 
            start = np.where(object_pred[:, :, 0] > 0.6) 
            end = np.where(object_pred[:, :, 1] > 0.5) 
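            # pair each predicted object start with the nearest end (same predicate) at or after it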
            for _start, predicate1 in zip(*start): 
                for _end, predicate2 in zip(*end): 
                    if _start <= _end and predicate1 == predicate2: 
                        spoes.append((subject, predicate1, (_start, _end))) 
                        break 
        return [ 
            ( 
                tokenizer.decode(token_ids[0, s[0]:s[1] + 1], tokens[s[0]:s[1] + 1]), 
                id2predicate[p], 
                tokenizer.decode(token_ids[0, o[0]:o[1] + 1], tokens[o[0]:o[1] + 1]) 
            ) for s, p, o in spoes 
        ] 
    else: 
        return [] 
 
 
class SPO(tuple): 
    """用来存三元组的类 
    表现跟tuple基本一致,只是重写了 __hash__ 和 __eq__ 方法, 
    使得在判断两个三元组是否等价时容错性更好。 
    """ 
    def __init__(self, spo): 
        self.spox = ( 
            tuple(tokenizer.tokenize(spo[0])), 
            spo[1], 
            tuple(tokenizer.tokenize(spo[2])), 
        ) 
 
    def __hash__(self): 
        return self.spox.__hash__() 
 
    def __eq__(self, spo): 
        return self.spox == spo.spox 
 
 
def evaluate(data): 
    """评估函数,计算f1、precision、recall 
    """ 
    X, Y, Z = 1e-10, 1e-10, 1e-10 
    f = codecs.open('dev_pred.json', 'w', encoding='utf-8') 
    pbar = tqdm() 
    for d in data: 
        R = set([SPO(spo) for spo in extract_spoes(d['text'])]) 
        T = set([SPO(spo) for spo in d['spo_list']]) 
        X += len(R & T) 
        Y += len(R) 
        Z += len(T) 
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z 
        pbar.update() 
        pbar.set_description('f1: %.5f, precision: %.5f, recall: %.5f' % 
                             (f1, precision, recall)) 
        s = json.dumps( 
            { 
                'text': d['text'], 
                'spo_list': list(T), 
                'spo_list_pred': list(R), 
                'new': list(R - T), 
                'lack': list(T - R), 
            }, 
            ensure_ascii=False, 
            indent=4) 
        f.write(s + '\n')
    pbar.close() 
    f.close() 
    return f1, precision, recall 
 
 
class Evaluator(keras.callbacks.Callback): 
    """评估和保存模型 
    """ 
    def __init__(self): 
        self.best_val_f1 = 0. 
 
    def on_epoch_end(self, epoch, logs=None): 
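        # evaluate with the EMA-smoothed weights, then restore the raw training weights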
        EMAer.apply_ema_weights() 
        f1, precision, recall = evaluate(valid_data) 
        if f1 >= self.best_val_f1: 
            self.best_val_f1 = f1 
            train_model.save_weights('best_model.weights') 
        EMAer.reset_old_weights() 
        print('f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
              (f1, precision, recall, self.best_val_f1)) 
 
 
if __name__ == '__main__': 
 
    train_generator = data_generator(train_data, batch_size) 
    evaluator = Evaluator() 
    EMAer = ExponentialMovingAverage(0.999) 
 
    train_model.fit_generator(train_generator.forfit(), 
                             steps_per_epoch=len(train_generator), 
                             epochs=20, 
                             callbacks=[evaluator, EMAer]) 
 
else: 
 
    train_model.load_weights('best_model.weights')
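Once the script above has run and best_model.weights has been saved, inference can reuse extract_spoes directly. A minimal sketch (the example sentence is the sample record shown earlier; this calling pattern is an assumption, not part of the original script):

train_model.load_weights('best_model.weights')
text = '《新駌鸯蝴蝶梦》是黄安的音乐作品,收录在《流金十载全记录》专辑中'
for s, p, o in extract_spoes(text):
    print(s, p, o)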

Download link for the Chinese BERT-wwm weights:

ymcui/Chinese-BERT-wwm: https://github.com/ymcui/Chinese-BERT-wwm

Training screenshot with the wwm weights on a small subset of the data


First epoch on the full dataset


After just one epoch the score already reaches 79.5.


By the 16th epoch the loss has dropped to 0.0122, and the best F1 score is 82.41.
