基于BERT的文本分类
导入依赖的包
import torch
from torch import nn
from easytexminer.core import Trainer
from easytexminer.core import Evaluator
from easytexminer.core import PredictorManager
from easytexminer.data import BertClassificationDataset
from easytexminer.losses import cross_entropy
from easytexminer.model_zoo import BertModel, BertPreTrainedModel
from easytexminer.applications import get_application_predictor
from easytexminer.utils import initialize_easytexminer, get_args
定义文本分类所需的类BertTextClassify
class BertTextClassify(BertPreTrainedModel):
    """BERT encoder with a dropout + linear head for text classification.

    The pooled output of ``BertModel`` goes through dropout and a single
    ``nn.Linear(config.hidden_size, config.num_labels)`` layer to produce
    the class logits.
    """

    def __init__(self, config, **kwargs):
        """Build the encoder and the classification head.

        Args:
            config: Model configuration; ``hidden_size``, ``num_labels`` and
                ``hidden_dropout_prob`` are read from it.
            **kwargs: Accepted for call compatibility; not used here.
        """
        super(BertTextClassify, self).__init__(config)
        self.model_name = "text_classify_bert"
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Initialise weights only after every sub-module has been registered.
        self.init_weights()

    def forward(self, inputs):
        """Run the encoder and the classification head.

        Args:
            inputs: Mapping that provides ``input_ids``, ``token_type_ids``
                and ``attention_mask``.

        Returns:
            A dict with the keys ``hidden`` (the encoder's ``hidden_states``),
            ``logits``, ``predictions`` (argmax of the logits over the last
            dimension) and ``probabilities`` (softmax of the logits over the
            last dimension).
        """
        outputs = self.bert(
            input_ids=inputs['input_ids'],
            token_type_ids=inputs['token_type_ids'],
            attention_mask=inputs['attention_mask'],
        )
        pooler_output = outputs.pooler_output
        hidden_states = outputs.hidden_states

        pooler_output = self.dropout(pooler_output)
        logits = self.classifier(pooler_output)

        return {
            "hidden": hidden_states,
            "logits": logits,
            "predictions": torch.argmax(logits, dim=-1),
            "probabilities": torch.softmax(logits, dim=-1),
        }

    def compute_loss(self, model_outputs, inputs):
        """Compute the cross-entropy loss between the logits and the labels.

        Args:
            model_outputs: Dict returned by ``forward``; ``logits`` is read.
            inputs: Mapping that provides ``label_ids``.

        Returns:
            A dict with the single key ``loss``.
        """
        logits = model_outputs["logits"]
        label_ids = inputs["label_ids"]
        return {
            "loss": cross_entropy(logits, label_ids)
        }
初始化EasyTexMiner平台
# Initialise the EasyTexMiner runtime before anything else touches it.
initialize_easytexminer()
# `get_args` is imported at the top of the file but was never called, while
# the statements below read their settings from both `args` and `cfg`.
# Bind both names to the same object so those lookups resolve.
# NOTE(review): assumes get_args() returns the parsed run arguments once
# initialize_easytexminer() has run -- confirm against easytexminer.utils.
args = get_args()
cfg = args
创建预训练模型
# Instantiate BertTextClassify through from_pretrained; the number of labels
# is taken from the validation dataset.
# NOTE(review): `valid_dataset` is only assigned further down in this file, so
# running the file top to bottom raises NameError here -- the dataset creation
# has to be executed before this statement.
model = BertTextClassify.from_pretrained(
    pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
    num_labels=valid_dataset.max_num_labels,
)
创建数据集
# The two datasets differ only in their input file and the training flag, so
# the keyword arguments they have in common are collected once.
_table_paths = cfg.tables.split(",")
_shared_dataset_kwargs = dict(
    pretrained_model_name_or_path=cfg.pretrained_model_name_or_path,
    max_seq_length=cfg.sequence_length,
    input_schema=cfg.input_schema,
    first_sequence=cfg.first_sequence,
    second_sequence=cfg.second_sequence,
    label_name=cfg.label_name,
    label_enumerate_values=cfg.label_enumerate_values,
)

# Training data: first entry of the comma-separated `cfg.tables` value.
train_dataset = BertClassificationDataset(
    data_file=_table_paths[0],
    is_training=True,
    **_shared_dataset_kwargs
)

# Validation data: last entry of the comma-separated `cfg.tables` value.
valid_dataset = BertClassificationDataset(
    data_file=_table_paths[-1],
    is_training=False,
    **_shared_dataset_kwargs
)
创建训练引擎
# Wire the model and both datasets into a Trainer.
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    valid_dataset=valid_dataset,
)
开始训练
# Run training with the trainer constructed above.
trainer.train()
创建评估引擎
# Build an Evaluator from the metrics exposed by the validation dataset.
evaluator = Evaluator(metrics=valid_dataset.eval_metrics)
开始评估
# Evaluate the model on the validation dataset, using args.micro_batch_size
# as the evaluation batch size.
evaluator.evaluate(
    model=model,
    valid_dataset=valid_dataset,
    eval_batch_size=args.micro_batch_size,
)
开始预测
# Build the application predictor for the model stored in args.checkpoint_dir.
predictor = get_application_predictor(
    model_type=args.model_name,
    model_dir=args.checkpoint_dir,
    first_sequence=args.first_sequence,
    second_sequence=args.second_sequence,
    # Fixed: the attribute was misspelled `sequencee_length`; the dataset
    # construction earlier in the file reads this setting as
    # `sequence_length`.
    sequence_length=args.sequence_length,
)

# Feed the last entry of the comma-separated `args.tables` value through the
# predictor and write the results to args.outputs.
predictor_manager = PredictorManager(
    predictor=predictor,
    input_file=args.tables.split(",")[-1],
    input_schema=args.input_schema,
    output_file=args.outputs,
    output_schema=args.output_schema,
    append_cols=args.append_cols,
    batch_size=args.micro_batch_size,
)
predictor_manager.run()