文本预训练

导入依赖的包

from easytexminer.applications import get_application_predictor
from easytexminer.core import Evaluator, PredictorManager
from easytexminer.core.trainer import Trainer
from easytexminer.data import get_dataset
from easytexminer.losses import cross_entropy
from easytexminer.model_zoo import BertPreTrainedModel, BertModel
from easytexminer.model_zoo.models.bert.modeling_bert import BertOnlyMLMHead
from easytexminer.utils import initialize_easytexminer, get_args

定义预训练需要的BertForMaskedLM

class BertForMaskedLM(BertPreTrainedModel):
    """BERT encoder with a masked-language-modeling head on top.

    ``forward`` returns the head's prediction scores under the key
    ``"logits"``; ``compute_loss`` combines those scores with the batch's
    ``"label_ids"`` into a cross-entropy loss.
    """

    def __init__(self, config, **kwargs):
        """Build the encoder and the MLM head from ``config``.

        Args:
            config: Configuration object passed to the base class,
                ``BertModel`` and ``BertOnlyMLMHead``.
            **kwargs: Accepted for call compatibility; not used here.
        """
        # Kept ahead of the base-class initializer on purpose: the ordering
        # is preserved from the original. NOTE(review): presumably something
        # in the base class / framework reads it -- confirm before moving.
        self.model_name = "language_modeling_bert"
        super().__init__(config)

        # The encoder is created without its pooling layer.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.lm_head = BertOnlyMLMHead(config)

        self.init_weights()

    def forward(self, inputs):
        """Run the encoder and the MLM head on one batch.

        Args:
            inputs: Mapping that provides ``"input_ids"`` and
                ``"attention_mask"``.

        Returns:
            A dict with a single key, ``"logits"``, holding the output of
            the MLM head.
        """
        encoder_outputs = self.bert(
            inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )

        # Only the first element of the encoder output is fed to the head
        # (presumably the per-token hidden states -- TODO confirm).
        token_states = encoder_outputs[0]
        return {"logits": self.lm_head(token_states)}

    def compute_loss(self, model_outputs, inputs):
        """Compute the masked-LM loss for one batch.

        Args:
            model_outputs: Dict produced by ``forward``; ``"logits"`` is read.
            inputs: Batch mapping; ``"label_ids"`` is read as the targets.

        Returns:
            A dict with a single key, ``"loss"``.
        """
        scores = model_outputs["logits"]
        targets = inputs["label_ids"]

        # Collapse everything but the vocabulary axis so scores and targets
        # line up one row per position.
        flat_scores = scores.view(-1, self.config.vocab_size)
        flat_targets = targets.view(-1)

        return {"loss": cross_entropy(flat_scores, flat_targets)}

初始化EasyTexMiner平台

# Set up the EasyTexMiner runtime before anything else touches the framework.
initialize_easytexminer()

# Every later step reads its settings from `cfg` (model name, data tables,
# sequence length, batch size, ...), but the name was never assigned, so the
# script failed with NameError on first use. `get_args` is already imported
# for this purpose.
cfg = get_args()

创建预训练模型

# Instantiate the masked-LM model through `from_pretrained`, using the
# pretrained path and the model name taken from the run configuration.
model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
                                        model_name=cfg.model_name)

创建数据集

# `cfg.tables` is a comma-separated list of data files: the first entry is
# used for training, the last one for validation. With a single entry both
# datasets are built from the same file.
train_dataset = get_dataset(
    model_type=cfg.model_name,
    pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
    data_file=cfg.tables.split(",")[0],
    max_seq_length=cfg.seq_length,
    is_training=True,
)

# Same settings as the training set, but built with is_training=False.
valid_dataset = get_dataset(
    model_type=cfg.model_name,
    pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
    data_file=cfg.tables.split(",")[-1],
    max_seq_length=cfg.seq_length,
    is_training=False,
)

创建训练引擎

# The trainer receives the model together with both datasets.
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
)

开始训练

# Start training with the trainer built from the model and datasets.
trainer.train()

创建评估引擎

# The evaluator uses the metrics exposed by the validation dataset.
evaluator = Evaluator(metrics=valid_dataset.eval_metrics)

开始评估

# Evaluate the model on the validation dataset; the batch size comes from
# cfg.micro_batch_size.
evaluator.evaluate(model=model, valid_dataset=valid_dataset, eval_batch_size=cfg.micro_batch_size)

开始预测

# Build the predictor from the checkpoint directory and the sequence
# settings in the run configuration.
predictor = get_application_predictor(model_type=cfg.model_name,
                                      model_dir=cfg.checkpoint_dir,
                                      first_sequence=cfg.first_sequence,
                                      second_sequence=cfg.second_sequence,
                                      sequence_length=cfg.seq_length)

# The manager reads the last entry of `cfg.tables` (the same file used for
# validation) and writes results to `cfg.outputs`.
predictor_manager = PredictorManager(predictor=predictor,
                                     input_file=cfg.tables.split(",")[-1],
                                     input_schema=cfg.input_schema,
                                     output_file=cfg.outputs,
                                     output_schema=cfg.output_schema,
                                     append_cols=cfg.append_cols,
                                     batch_size=cfg.micro_batch_size)

# Run prediction.
predictor_manager.run()