基于CNN的文本分类¶
TextCNN文本分类使用基于CNN网络的深度学习模型,输出为分类标签。具体实现脚本参见 scripts/cnn_classify。
导入依赖的包
import torch
from torch import nn
from easytexminer import losses
from easytexminer.core import Trainer
from easytexminer.core import Evaluator
from easytexminer.core import PredictorManager
from easytexminer.data import CNNClassificationDataset
from easytexminer.applications import get_application_predictor
from easytexminer.utils import initialize_easytexminer, get_args
from easytexminer.model_zoo.modeling_utils import PreTrainedModel
from easytexminer.model_zoo.models.cnn import TextCNNEncoder, TextCNNConfig
构建TextCNN分类模型
class CNNTextClassify(PreTrainedModel):
    """TextCNN-based text classification model.

    Encodes token ids with a TextCNN encoder and projects the encoder
    output to ``config.num_labels`` classes through a single linear layer.
    """

    config_class = TextCNNConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model_name = "text_classify_cnn"
        self.cnn_encoder = TextCNNEncoder(config)
        self.classifier = nn.Linear(config.linear_hidden_size, config.num_labels)
        # CNN models have no pretrained checkpoint; weights are randomly
        # initialized via _init_weights below.
        self.init_weights()

    def forward(self, inputs):
        """Forward pass.

        Args:
            inputs: dict containing "input_ids" (token-id tensor).

        Returns:
            dict with raw "logits", argmax "predictions" and softmax
            "probabilities" (both taken over the last dimension).
        """
        sequence_output = self.cnn_encoder(inputs["input_ids"])
        logits = self.classifier(sequence_output)
        return {
            "logits": logits,
            "predictions": torch.argmax(logits, dim=-1),
            "probabilities": torch.softmax(logits, dim=-1),
        }

    def compute_loss(self, model_outputs, inputs):
        """Cross-entropy loss between "logits" and inputs["label_ids"]."""
        logits = model_outputs["logits"]
        label_ids = inputs["label_ids"]
        return {
            "loss": losses.cross_entropy(logits, label_ids)
        }

    def _init_weights(self, module):
        """Initialize weights: N(0, 0.02) for weights, zeros for biases.

        Bug fix: the original tested ``getattr(module, "bias", False) is not
        False``, which is truthy when the attribute exists but is None
        (e.g. ``nn.Linear(..., bias=False)`` sets ``bias = None``), so the
        subsequent ``.data`` access would raise AttributeError. Compare
        against None instead.
        """
        weight = getattr(module, "weight", None)
        if weight is not None:
            weight.data.normal_(mean=0.0, std=0.02)
        bias = getattr(module, "bias", None)
        if bias is not None:
            bias.data.zero_()
初始化EasyTexMiner,初始化前需要构造get_cnn_args函数传入TextCNN相比于Transformers类模型所要额外引入的参数
def get_cnn_args(parser):
    """Register the extra command-line arguments a TextCNN task needs.

    Adds a 'textcnn-args' argument group with the hyper-parameters TextCNN
    requires beyond the common Transformer-style options, then returns the
    (mutated) parser for chaining.
    """
    cnn_group = parser.add_argument_group('textcnn-args')
    # Arguments for cnn text classification: (flag, type, default, help).
    arg_specs = (
        ("--conv_dim", int, 100, 'convolution dimension'),
        ("--kernel_sizes", str, None, 'kernel size of cnn task.'),
        ("--linear_hidden_size", int, None,
         'linear hidden dimensions for the feed-forward layer'),
        ("--embed_size", int, None, 'embedding dimension'),
        ("--vocab_size", int, None, 'vocab dimension'),
    )
    for flag, arg_type, default, help_text in arg_specs:
        cnn_group.add_argument(flag, type=arg_type, default=default,
                               help=help_text)
    return parser
# Parse command-line options, registering the TextCNN-specific arguments
# on top of EasyTexMiner's common options.
initialize_easytexminer(extra_args_provider=get_cnn_args)
# NOTE(review): every statement below reads `args`, but the original script
# never assigned it even though `get_args` was imported — presumably this
# call was dropped during editing; confirm against the library's examples.
args = get_args()
创建数据集
def _build_dataset(data_file, is_training):
    """Construct a CNNClassificationDataset for one data split; all other
    settings come from the parsed command-line `args`."""
    return CNNClassificationDataset(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        data_file=data_file,
        max_seq_length=args.sequence_length,
        input_schema=args.input_schema,
        first_sequence=args.first_sequence,
        second_sequence=args.second_sequence,
        label_name=args.label_name,
        label_enumerate_values=args.label_enumerate_values,
        is_training=is_training)


# Train on the first table in args.tables, validate on the last.
train_dataset = _build_dataset(args.tables.split(",")[0], True)
valid_dataset = _build_dataset(args.tables.split(",")[-1], False)
创建训练模型与Trainer。与BERT类模型不同点在于,CNN类模型没有相应的预训练模型可供载入,所以需要在设置TextCNNConfig后随机初始化。
# The tokenizer's vocabulary size is only known once a dataset has been
# built, so propagate it into the parsed args before creating the config.
args.vocab_size = len(valid_dataset.tokenizer)
# NOTE(review): TextCNNConfig arguments are positional — assumed order is
# (conv_dim, kernel_sizes, linear_hidden_size, embed_size, vocab_size,
# sequence_length); confirm against TextCNNConfig's signature.
config = TextCNNConfig(args.conv_dim, args.kernel_sizes, args.linear_hidden_size, args.embed_size, args.vocab_size, args.sequence_length)
# No pretrained checkpoint exists for CNN models; weights are randomly
# initialized from the config.
model = CNNTextClassify(config=config)
trainer = Trainer(model=model,
                  train_dataset=train_dataset,
                  valid_dataset=valid_dataset)
开始训练
# Launch the training loop.
trainer.train()
创建评估引擎
# Evaluate the trained model on the validation dataset, using the metrics
# the dataset itself declares.
evaluator = Evaluator(metrics=valid_dataset.eval_metrics)
evaluator.evaluate(model=model, valid_dataset=valid_dataset, eval_batch_size=args.micro_batch_size)
预测
# Build a task-specific predictor from the saved checkpoint directory.
predictor = get_application_predictor(
    model_type=args.model_name, model_dir=args.checkpoint_dir,
    first_sequence=args.first_sequence,
    second_sequence=args.second_sequence,
    sequence_length=args.sequence_length)
# Run batch prediction: read the last table in args.tables (the same split
# used for validation above) and write results to args.outputs.
predictor_manager = PredictorManager(
    predictor=predictor,
    input_file=args.tables.split(",")[-1],
    input_schema=args.input_schema,
    output_file=args.outputs,
    output_schema=args.output_schema,
    append_cols=args.append_cols,
    batch_size=args.micro_batch_size
)
predictor_manager.run()