基于Bert的文本分类

导入依赖的包

import torch
from torch import nn
from easytexminer.core import Trainer
from easytexminer.core import Evaluator
from easytexminer.core import PredictorManager
from easytexminer.data import BertClassificationDataset
from easytexminer.losses import cross_entropy
from easytexminer.model_zoo import BertModel, BertPreTrainedModel
from easytexminer.applications import get_application_predictor
from easytexminer.utils import initialize_easytexminer, get_args

定义文本分类所需的类BertTextClassify

class BertTextClassify(BertPreTrainedModel):
    """BERT-based text classifier: pooled [CLS] output -> dropout -> linear head.

    Fix: the three methods below were dedented to module level in the
    original source (an IndentationError under `class ...:`); they are
    now properly indented as methods of the class.
    """

    def __init__(self, config, **kwargs):
        super(BertTextClassify, self).__init__(config)
        self.model_name = "text_classify_bert"
        # Backbone encoder built from the (pretrained) BERT config.
        self.bert = BertModel(config)
        # Classification head: pooled hidden vector -> `num_labels` logits.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.init_weights()

    def forward(self, inputs):
        """Run the encoder and classification head.

        Args:
            inputs: dict with keys 'input_ids', 'token_type_ids' and
                'attention_mask' (presumably LongTensors of shape
                (batch, seq_len) — TODO confirm against
                BertClassificationDataset).

        Returns:
            dict with the encoder hidden states, raw logits, argmax
            predictions and softmax probabilities.
        """
        outputs = self.bert(input_ids=inputs['input_ids'],
                            token_type_ids=inputs['token_type_ids'],
                            attention_mask=inputs['attention_mask'])

        pooler_output = outputs.pooler_output
        # NOTE(review): `hidden_states` is typically only populated when the
        # config enables output_hidden_states — verify for this model zoo.
        hidden_states = outputs.hidden_states
        pooler_output = self.dropout(pooler_output)
        logits = self.classifier(pooler_output)
        return {
            "hidden": hidden_states,
            "logits": logits,
            "predictions": torch.argmax(logits, dim=-1),
            "probabilities": torch.softmax(logits, dim=-1)
        }

    def compute_loss(self, model_outputs, inputs):
        """Cross-entropy loss between predicted logits and gold `label_ids`."""
        logits = model_outputs["logits"]
        label_ids = inputs["label_ids"]
        return {
            "loss": cross_entropy(logits, label_ids)
        }

初始化EasyTexMiner平台

# Initialise the EasyTexMiner runtime before building any model or dataset
# (exact behaviour is defined by the framework, not visible in this file).
initialize_easytexminer()

创建预训练模型

# Build the classifier from a pretrained BERT checkpoint.
# NOTE(review): `cfg` is never defined anywhere in this file (`get_args` is
# imported but never called), and `valid_dataset` is only created further
# down — as written this line raises NameError. The dataset-creation code
# should run before this block; confirm and reorder.
model = BertTextClassify.from_pretrained(
        pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
        num_labels=valid_dataset.max_num_labels)

创建数据集

# Training split: the first file of the comma-separated `tables` argument.
_train_file = cfg.tables.split(",")[0]
train_dataset = BertClassificationDataset(
    pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
    data_file=_train_file,
    max_seq_length=cfg.sequence_length,
    input_schema=cfg.input_schema,
    first_sequence=cfg.first_sequence,
    second_sequence=cfg.second_sequence,
    label_name=cfg.label_name,
    label_enumerate_values=cfg.label_enumerate_values,
    is_training=True)

# Validation split: the last file of the comma-separated `tables` argument.
_valid_file = cfg.tables.split(",")[-1]
valid_dataset = BertClassificationDataset(
    pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
    data_file=_valid_file,
    max_seq_length=cfg.sequence_length,
    input_schema=cfg.input_schema,
    first_sequence=cfg.first_sequence,
    second_sequence=cfg.second_sequence,
    label_name=cfg.label_name,
    label_enumerate_values=cfg.label_enumerate_values,
    is_training=False)

创建训练引擎

# Wire the model and both dataset splits into the training engine.
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
)

开始训练

# Kick off the training loop.
trainer.train()

创建评估引擎

# Evaluation engine driven by the metrics the validation dataset declares.
evaluator = Evaluator(metrics=valid_dataset.eval_metrics)

开始评估

# NOTE(review): `args` is never assigned in this file — presumably an
# `args = get_args()` call is missing; as written this raises NameError.
evaluator.evaluate(model=model, valid_dataset=valid_dataset, eval_batch_size=args.micro_batch_size)

开始预测

# Build a task-specific predictor from the trained checkpoint and stream the
# last input table through it, writing predictions to `args.outputs`.
# NOTE(review): `args` is never assigned in this file (presumably an
# `args = get_args()` call is missing) — confirm before running.
predictor = get_application_predictor(
    model_type=args.model_name, model_dir=args.checkpoint_dir,
    first_sequence=args.first_sequence,
    second_sequence=args.second_sequence,
    # Fixed typo: was `args.sequencee_length`, which would raise
    # AttributeError at runtime.
    sequence_length=args.sequence_length)
predictor_manager = PredictorManager(
    predictor=predictor,
    # Predict on the last table — the same split used for evaluation above.
    input_file=args.tables.split(",")[-1],
    input_schema=args.input_schema,
    output_file=args.outputs,
    output_schema=args.output_schema,
    append_cols=args.append_cols,
    batch_size=args.micro_batch_size
)
predictor_manager.run()