文本预训练
导入依赖的包
from easytexminer.applications import get_application_predictor
from easytexminer.core import Evaluator, PredictorManager
from easytexminer.core.trainer import Trainer
from easytexminer.data import get_dataset
from easytexminer.losses import cross_entropy
from easytexminer.model_zoo import BertPreTrainedModel, BertModel
from easytexminer.model_zoo.models.bert.modeling_bert import BertOnlyMLMHead
from easytexminer.utils import initialize_easytexminer, get_args
定义预训练需要的BertForMaskedLM
class BertForMaskedLM(BertPreTrainedModel):
    """BERT with a masked-language-modeling head for pre-training.

    Wraps a pooling-free ``BertModel`` encoder plus the ``BertOnlyMLMHead``
    projection, and exposes ``compute_loss`` so the trainer can derive the
    MLM cross-entropy directly from the forward outputs.
    """

    def __init__(self, config, **kwargs):
        # NOTE(review): model_name is assigned before super().__init__ to
        # mirror the original ordering — presumably the base constructor
        # does not read it; confirm before reordering.
        self.model_name = "language_modeling_bert"
        super().__init__(config)
        # Token-level MLM prediction needs no pooled [CLS] representation.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.lm_head = BertOnlyMLMHead(config)
        self.init_weights()

    def forward(self, inputs):
        """Encode the batch and score every vocabulary token per position.

        Args:
            inputs: mapping holding ``input_ids`` and ``attention_mask``.

        Returns:
            dict with ``logits`` — per-position vocabulary scores.
        """
        encoder_outputs = self.bert(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"])
        hidden_states = encoder_outputs[0]
        return {
            "logits": self.lm_head(hidden_states)
        }

    def compute_loss(self, model_outputs, inputs):
        """Cross-entropy between predicted token scores and ``label_ids``."""
        scores = model_outputs["logits"]
        labels = inputs["label_ids"]
        masked_lm_loss = cross_entropy(
            scores.view(-1, self.config.vocab_size), labels.view(-1))
        return {"loss": masked_lm_loss}
初始化EasyTexMiner平台
# Initialize the EasyTexMiner runtime before any model/data construction.
initialize_easytexminer()
# Bug fix: `cfg` is referenced by every subsequent step but was never
# defined in the original script; `get_args` is imported above for exactly
# this purpose. TODO(review): confirm get_args() returns the config object
# with the attributes used below (tables, seq_length, micro_batch_size, ...).
cfg = get_args()
创建预训练模型
# Restore the masked-LM model weights from the configured checkpoint.
model_kwargs = {
    "pretrained_model_name_or_path": cfg.pretrained_model_name_or_path,
    "model_name": cfg.model_name,
}
model = BertForMaskedLM.from_pretrained(**model_kwargs)
创建数据集
# `cfg.tables` is a comma-separated list of input files: the first entry is
# used for training, the last for validation.
table_files = cfg.tables.split(",")
train_dataset = get_dataset(
    model_type=cfg.model_name,
    pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
    data_file=table_files[0],
    max_seq_length=cfg.seq_length,
    is_training=True)
valid_dataset = get_dataset(
    model_type=cfg.model_name,
    pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
    data_file=table_files[-1],
    max_seq_length=cfg.seq_length,
    is_training=False)
创建训练引擎
# Wire the model and both datasets into the training engine.
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
)
开始训练
trainer.train()  # Launch the full training loop.
创建评估引擎
# The validation dataset declares which metrics it should be scored with.
eval_metrics = valid_dataset.eval_metrics
evaluator = Evaluator(metrics=eval_metrics)
开始评估
# Score the trained model on the validation dataset.
evaluator.evaluate(model=model, valid_dataset=valid_dataset, eval_batch_size=cfg.micro_batch_size)
开始预测
# Build an application-level predictor from the saved checkpoint directory.
predictor_kwargs = dict(
    model_type=cfg.model_name,
    model_dir=cfg.checkpoint_dir,
    first_sequence=cfg.first_sequence,
    second_sequence=cfg.second_sequence,
    sequence_length=cfg.seq_length,
)
predictor = get_application_predictor(**predictor_kwargs)
# Stream the validation file through the predictor and write predictions
# to the configured output according to the input/output schemas.
predictor_manager = PredictorManager(
    predictor=predictor,
    input_file=cfg.tables.split(",")[-1],
    input_schema=cfg.input_schema,
    output_file=cfg.outputs,
    output_schema=cfg.output_schema,
    append_cols=cfg.append_cols,
    batch_size=cfg.micro_batch_size,
)
predictor_manager.run()