BART中文摘要生成,(nplcc与LCSTS数据集)

 1 训练

import tqdm 
from datasets import load_dataset
import lawrouge

import datasets
import random
import pandas as pd

from datasets import dataset_dict
import datasets

from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

import warnings
from pathlib import Path
from typing import List, Tuple, Union

from torch import nn

import jieba
import numpy as np
import lawrouge

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel
from transformers.utils import logging


dataset = load_dataset('json', data_files='nlpcc_data.json', field='data')

def flatten(example):
    return {
        "document": example["content"],
        "summary": example["title"],
        "id":"0"
    }
dataset = dataset["train"].map(flatten, remove_columns=["title", "content"]) # , remove_columns=["title", "content"]


TokenModel = "bert-base-chinese"

from transformers import AutoTokenizer, BertConfig
tokenizer = AutoTokenizer.from_pretrained(TokenModel)

config = BertConfig.from_pretrained(TokenModel)

model_checkpoint = "fnlp/bart-large-chinese"

print(model_checkpoint)

max_input_length = 512 # input, source text 注意长度,复旦BART中文预训练模型使用的bert tokenizer
max_target_length = 128 # summary, target text

def preprocess_function(examples):
    inputs = [doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs



raw_datasets = dataset
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.1,shuffle=True,seed=42).values()
train_data_txt, test_data_tex = train_data_txt.train_test_split(test_size=0.1,shuffle=True,seed=42).values()
# 装载数据
dd = datasets.DatasetDict({"train":train_data_txt,"validation": validation_data_txt,"test":test_data_tex }) 

raw_datasets = dd
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)


model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)


logger = logging.get_logger(__name__)


batch_size = 4
args = Seq2SeqTrainingArguments(
    output_dir="results",
    num_train_epochs=50,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=batch_size,  # demo
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-04,
    warmup_steps=500,
    weight_decay=0.001,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=500,
    evaluation_strategy="epoch",
    save_total_limit=3,
    
    # generation_max_length最大生成长度,系统默认20 generation_num_beams=1表示贪心解码,大于1为树搜索
    generation_max_length=64,
    generation_num_beams=1,
)



data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)


# 这里用的是中文lawrouge 至于字符级还是词级计算看自己调整 这里是字符级
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    decoded_preds = ["".join(pred.replace(" ", "")) for pred in decoded_preds]
    decoded_labels = ["".join(label.replace(" ", "")) for label in decoded_labels]
    # Rouge with jieba cut
    # decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]
    # decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]

    labels_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in labels]
    # length = len(prediction_lens)

    # print(decoded_preds)
    # print(decoded_labels)
    rouge = lawrouge.Rouge()

    result = rouge.get_scores(decoded_preds, decoded_labels,avg=True)
    # print(result)
    print(result)
    result = {'rouge-1': result['rouge-1']['f'], 'rouge-2': result['rouge-2']['f'], 'rouge-l': result['rouge-l']['f']}

    result = {key: value * 100 for key, value in result.items()}
    return result;

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

train_result = trainer.train()
print(train_result)

trainer.save_model()
metrics = train_result.metrics
trainer.log_metrics("train",metrics)
trainer.save_metrics("train",metrics)
trainer.save_state()

import torch
model.load_state_dict(torch.load('./results/pytorch_model.bin'))

def generate_summary(test_samples, model):
    inputs = tokenizer(
        test_samples,
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
 
    attention_mask = inputs.attention_mask.to(model.device)
  
    outputs = model.generate(input_ids, attention_mask=attention_mask,max_length=128)
    print(outputs)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str

_,x = generate_summary("20日凌晨,寒风刺骨,两名年纪相仿的婴儿相继被狠心的父母遗弃在翔安的两个角落,一个在莲花总院厕所里,一个在东园社区一榕树下。两名婴儿被发现时间相距不过10分钟,莲河边防派出所民警连夜走访,未寻得婴儿家属。目前,一名婴儿已被送往福利院,另一名暂时安置在村民家中。据悉,经医生初步检查,两名婴儿均身体健康,无残疾、无疾病。记者陈佩珊通讯员蔡美娟林才龙",model)
print(x)
print(len(x[0]))

'''
tensor([[ 102,  101, 1336, 7305, 5425, 2128,  697, 1399, 2399, 5279, 4685,  820,
         4638, 2048, 1036, 4685, 5326, 6158, 6890, 2461, 1762,  697,  702, 6235,
         5862,  117,  671,  782, 1762, 5813, 5709, 2600, 7368, 1329, 2792, 7027,
          117,  671,  702, 1762,  691, 1736, 4852, 1277,  671, 3525, 3409,  678,
          511,  102]], device='cuda:0')
['厦 门 翔 安 两 名 年 纪 相 仿 的 婴 儿 相 继 被 遗 弃 在 两 个 角 落, 一 人 在 莲 花 总 院 厕 所 里, 一 个 在 东 园 社 区 一 榕 树 下 。']
91
'''

eval_results = trainer.evaluate()
print(eval_results)

2 测试

import logging

import jieba
import lawrouge
import numpy as np
import datasets
import torch
from datasets import load_dataset, Dataset
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge
from torch.utils.data import dataloader
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, BertTokenizer, \
    BartForConditionalGeneration



logger = logging.getLogger("bart-base-chinese")
logging.basicConfig(level=logging.INFO)

dataset = load_dataset('json', data_files=[r'E:\zwj\nlp\datasets\Summarization\NLPCC2017\evaluation_with_ground_truth_origianl.txt'], field='data')

tokenizer = BertTokenizer.from_pretrained(r'checkpoint-337500') 
model = BartForConditionalGeneration.from_pretrained(r'checkpoint-337500')

def flatten(example):
    return {
        "document": example["article"],
        "summary": example["summarization"],
        "id": "0"
    }

dataset = dataset["train"].map(flatten, remove_columns=["summarization", "article"])

max_input_length = 512
max_target_length = 64
model_inputs = tokenizer(dataset[0]["document"], max_length=max_input_length,padding="max_length", truncation=True)



def preprocess_function(examples):
    inputs = [doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return  model_inputs

# 装载数据
dd = datasets.DatasetDict({"test": dataset})
test_batch_size = 1
tokenized_datasets = dd.map(preprocess_function, batched=True)
args = Seq2SeqTrainingArguments(
    fp16 = True,
    output_dir=r'./',
    do_eval=True,
    per_device_eval_batch_size=test_batch_size,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
)
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["test"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    tokenizer=tokenizer,
)

eval_dataloader = trainer.get_eval_dataloader()

model.to("cuda:0")

rouge = Rouge()
rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0

num_return_sequences = 4

for i,batch in enumerate(tqdm(eval_dataloader)):
    model.train()
    output = model.generate(
        input_ids=batch["input_ids"].to("cuda:0"),
        attention_mask=batch["attention_mask"].to("cuda:0"),
        early_stopping=True,
        num_beams=num_return_sequences,
        length_penalty=1.,
        # no_repeat_ngram_size=0,
        # diversity_penalty=0.,
        num_return_sequences=1,
        num_beam_groups=1,
        max_length=64,
    )
    #
    lsp = []
    for s in range(test_batch_size):
        lsp.append(int(s * num_return_sequences)) # num_return_sequences
    outputs = output[lsp,:]
    labels = batch["labels"]
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    decoded_preds = [" ".join((pred.replace(" ", ""))) for pred in decoded_preds]
    decoded_labels = [" ".join((label.replace(" ", ""))) for label in decoded_labels]
    for decoded_label, decoded_pred in zip(decoded_labels, decoded_preds):
        scores = rouge.get_scores(hyps=decoded_pred, refs=decoded_label)
        rouge_1 += scores[0]['rouge-1']['f']
        rouge_2 += scores[0]['rouge-2']['f']
        rouge_l += scores[0]['rouge-l']['f']
        bleu += sentence_bleu(
            references=[decoded_label.split(' ')],
            hypothesis=decoded_pred.split(' '),
            smoothing_function=SmoothingFunction().method1
        )
bleu /= len(dataset)
rouge_1 /= len(dataset)
rouge_2 /= len(dataset)
rouge_l /= len(dataset)



3 数据格式

这里强调一下,这里使用的复旦fnlp/bart-large-chinese的bart中文预训练模型,数据集是nlpcc,五万条数据,只训练epoch=1,rouge-1 49还算可以吧,感谢复旦提供预训练模型。

lcsts的数据集也可以直接训练,预处理方式见上一片博客。

# 用PART_I做训练的,PART_II做验证的,epoch = 1

# 给几个测试的结果,有的确实很不错,但有的就很糟糕了,颠倒是非了
'''
(1)
中信证券员工迎娶世界小姐张梓琳的消息,瞬间引爆投行圈。何许人能娶到世界小姐?据爆料,张梓琳的新郎名叫聂磊(Neil),目前是中信证券债务资本市场部的SVP(高级副总裁)。有投行圈人士感叹“我追高圆圆也有戏啊”。
生成:中 信 证 券 员 工 迎 娶 世 界 小 姐 投 行 圈 人 士
参考:中信员工迎娶世界小姐张梓琳投行男直呼励志
(2)此前曾宣布退出线下POS市场的支付宝,近日正在宁夏、江西等地布局线下支付业务。此举被一些业内人士解读为向“银联发起总攻”。一位业内分析人士更是指出,支付宝甚至可能挑战中国银联在线下支付市场的主导地位,“成为第二家‘银联’”
生成:支 付 宝 向 银 联 发 起 总 攻
参考:支付宝重返线下市场二维码支付监管标准公布可期
(3)上海质局对上海生产和销售的水嘴产品质量进行了专项监督抽查,共抽查水嘴产品68批次,不合格产品达21批次,其中包括成霖洁具等。3成不合格产品中6批次产品经检测析出过量的铅。过量的铅会损害人的神经系统、造血系统,甚至生殖系统。
生成:成 霖 洁 具 等 6 批 次 水 嘴 产 品 检 出 过 量 铅
参考:成霖洁具等被爆铅超标或损害造血生殖系统
(4)韩方应对路径可以概括为:企业道歉担责;政府公正不护短;民间祈福关怀。他们深知形象的重要,竭力呵护企业品牌和国家形象。正如有评论,韩国“政府+企业+民众”三位一体式呵护韩国国家形象的“苦心经营”,的确有值得我们借鉴之处。
生成:韩 国 国 家 形 象 的 苦 心 经 营
参考:从韩亚航空事故看其应对路径
(5)63岁退休教师谢淑华,拉着人力板车,历时1年,走了2万4千里路,带着年过九旬的妈妈环游中国,完成了妈妈“一辈子在锅台边转,也想出去走走”的心愿。她说:“妈妈愿意出去走走,我就愿意拉着,孝心不能等,能走多远就走多远。
生成:63 岁 女 教 师 拉 板 车 带 九 旬 母 亲 游 中 国
参考:女子用板车拉九旬老母环游中国1年走2万4千里
'''

我还测试了复旦cpt模型,效果感觉不太行阿。

细节不同的地方如下:

from modeling_cpt import CPTForConditionalGeneration
model_checkpoint = "fnlp/cpt-large"

注意不要搞错了,具体见下一篇

ps:注意计算metric是的rouge,有两个库都可以计算,rouge和lawrouge

(1)不分词:直接调用lawrouge中文计算函数就行不用分词

(2)rouge库需要将中文用空格分开(不分词+空格隔开),如果要计算分词后的rouge值也要用rouge库(分词+空格隔开)

(3)注意:lawrouge比rouge库计算指标低一些,做论文一定要注意!

import lawrouge
from rouge import Rouge
hypothesis = ['我 爱 最 美 的 中 国','老 夫 子']
reference = ['我 爱 中 国','你 好 老 夫子']

rouge = Rouge()
scores = rouge.get_scores(hypothesis, reference)
print(scores)

rouge = lawrouge.Rouge()
scores = rouge.get_scores(['我爱最美的中国','老夫子'], ['我爱中国','你好老夫子'])
print(scores)

'''
[{'rouge-1': {'r': 1.0, 'p': 0.5714285714285714, 'f': 0.7272727226446282}, 'rouge-2': {'r': 0.6666666666666666, 'p': 0.3333333333333333, 'f': 0.44444444000000005}, 'rouge-l': {'r': 1.0, 'p': 0.5714285714285714, 'f': 0.7272727226446282}}, {'rouge-1': {'r': 0.6, 'p': 1.0, 'f': 0.7499999953125}, 'rouge-2': {'r': 0.5, 'p': 1.0, 'f': 0.6666666622222223}, 'rouge-l': {'r': 0.6, 'p': 1.0, 'f': 0.7499999953125}}]
[{'rouge-1': {'f': 0.7272727226446282, 'p': 0.5714285714285714, 'r': 1.0}, 'rouge-2': {'f': 0.44444444000000005, 'p': 0.3333333333333333, 'r': 0.6666666666666666}, 'rouge-l': {'f': 0.7272727226446282, 'p': 0.5714285714285714, 'r': 1.0}}, {'rouge-1': {'f': 0.7499999953125, 'p': 1.0, 'r': 0.6}, 'rouge-2': {'f': 0.6666666622222223, 'p': 1.0, 'r': 0.5}, 'rouge-l': {'f': 0.7499999953125, 'p': 1.0, 'r': 0.6}}]
'''

pps:transformers==4.4.1 pytorch = 1.10.0


给大家提供一个经过简单清洗的CNewSum中文摘要数据集,数据集共包含275596条摘要数据,源文件是经过分词的,我这里没有分词,先提供一个训练集。

百度网盘链接:CNewSum_trainhttps://pan.baidu.com/s/1FXq_deJ0Yi9rfSAElQLoIQ 

提取码:b6f9


nlpcc2017清洗数据

链接:nlpcc2017_cleanhttps://pan.baidu.com/s/1qXC81XmcWY9GprQKe5i58whttps://pan.baidu.com/s/1qXC81XmcWY9GprQKe5i58w 
提取码:knci


提供一个经过层次位置分解编码的BART中文预训练权重支持最大输入长度为1024,修改配置文件config.json的max_position_embeddings=1024即可

百度网盘链接:bart-large-chinese-1024https://pan.baidu.com/s/1xisndfl27sOm1YyZzp-vYw

提取码:b6f9


参考文献:pytorch 使用BART模型进行中文自动摘要_keep-hungry的博客-CSDN博客

复旦bart中文预训练模型:小: https://huggingface.co/fnlp/bart-base-chinese/tree/main

大:

复旦cpt中文预训练模型:fnlp/cpt-large · Hugging Face

来源:道天翁

物联沃分享整理
物联沃-IOTWORD物联网 » BART中文摘要生成,(nplcc与LCSTS数据集)

发表评论