Source code for easytexminer.applications.classification.bert

import json
import os
import uuid

import numpy as np
import torch
import torch.nn as nn

from ... import losses
from ...core.predictor import Predictor, get_model_predictor
from ...utils import io
from ...model_zoo import BertModel, BertPreTrainedModel
from ...model_zoo import BertTokenizer

class BertTextClassify(BertPreTrainedModel):
    """
    Bert Model with a classification head on top (a linear layer on top of the pooled output).

    This model inherits from BertPreTrainedModel.

    Parameters:
        config: Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with
            the model, only the configuration.
    """

    def __init__(self, config, **kwargs):
        super(BertTextClassify, self).__init__(config)
        self.model_name = "text_classify_bert"
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.init_weights()
    def forward(self, inputs):
        """Forward method.

        Args:
            inputs: The input of the model.

                inputs['input_ids']: :obj:`torch.LongTensor` of shape (batch_size, sequence_length).
                    Indices of input sequence tokens in the vocabulary.
                    Indices can be obtained using BertTokenizer.

                inputs['token_type_ids']: :obj:`torch.LongTensor` of shape (batch_size, sequence_length).
                    Segment token indices to indicate the first and second portions of the inputs.
                    Indices are selected in [0, 1]:

                    - 0 corresponds to a sentence A token,
                    - 1 corresponds to a sentence B token.

                inputs['attention_mask']: :obj:`torch.FloatTensor` of shape (batch_size, sequence_length).
                    Mask to avoid performing attention on padding token indices.
                    Mask values are selected in [0, 1]:

                    - 1 for tokens that are not masked,
                    - 0 for tokens that are masked.

        Returns:
            A dict containing four elements.

            hidden: :obj:`tuple(torch.FloatTensor)`
                Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for
                the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            logits: :obj:`torch.FloatTensor` of shape (batch_size, num_labels)
                Pooled output (hidden state of the first, classification token) further processed
                by a linear layer (the classifier).

            predictions: :obj:`torch.LongTensor` of shape (batch_size,)
                Applies the argmax function to the logits to obtain the predicted label of each sample.

            probabilities: :obj:`torch.FloatTensor` of shape (batch_size, num_labels)
                Applies the softmax function to the logits, rescaling them so that the elements
                lie in the range [0, 1] and sum to 1.
        """
        outputs = self.bert(input_ids=inputs['input_ids'],
                            token_type_ids=inputs['token_type_ids'],
                            attention_mask=inputs['attention_mask'])
        pooler_output = outputs.pooler_output
        hidden_states = outputs.hidden_states
        pooler_output = self.dropout(pooler_output)
        logits = self.classifier(pooler_output)
        return {
            "hidden": hidden_states,
            "logits": logits,
            "predictions": torch.argmax(logits, dim=-1),
            "probabilities": torch.softmax(logits, dim=-1)
        }
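    # Illustrative sketch (not part of the module): building the ``inputs`` dict with
    # BertTokenizer and running a forward pass. The checkpoint path, the ``model``
    # variable, and the sample text are assumptions for the example.
    #
    #   tokenizer = BertTokenizer.from_pretrained("/path/to/model_dir")
    #   encoded = tokenizer(["a sample sentence"], padding="max_length",
    #                       max_length=128, return_tensors="pt")
    #   inputs = {"input_ids": encoded["input_ids"],
    #             "token_type_ids": encoded["token_type_ids"],
    #             "attention_mask": encoded["attention_mask"]}
    #   outputs = model(inputs)          # model: a BertTextClassify instance
    #   outputs["predictions"]           # (batch_size,) predicted label ids
    #   outputs["probabilities"]         # (batch_size, num_labels) softmax scores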
    def compute_loss(self, model_outputs, inputs):
        """Compute the cross entropy loss between the predicted logits and the ground-truth labels.

        Args:
            model_outputs: The output of BertTextClassify.
            inputs: The input of BertTextClassify.
        """
        logits = model_outputs["logits"]
        label_ids = inputs["label_ids"]
        return {
            "loss": losses.cross_entropy(logits, label_ids)
        }
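    # Illustrative training-step sketch (assumptions: ``model`` is a BertTextClassify
    # instance, ``batch`` carries the tokenized input tensors plus ``label_ids``, and
    # ``optimizer`` is any torch optimizer over ``model.parameters()``):
    #
    #   model_outputs = model(batch)
    #   loss = model.compute_loss(model_outputs, batch)["loss"]
    #   loss.backward()
    #   optimizer.step()
    #   optimizer.zero_grad()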
class BertTextClassifyPredictor(Predictor):

    def __init__(self, model_dir, model_cls=None, *args, **kwargs):
        super(BertTextClassifyPredictor, self).__init__(*args, **kwargs)
        self.bert_tokenizer = BertTokenizer.from_pretrained(model_dir)
        self.model_predictor = get_model_predictor(
            model_dir=model_dir,
            model_cls=model_cls,
            input_keys=[("input_ids", torch.LongTensor),
                        ("attention_mask", torch.LongTensor),
                        ("token_type_ids", torch.LongTensor)],
            output_keys=["predictions", "probabilities", "logits"])
        self.label_path = os.path.join(model_dir, "label_mapping.json")
        with io.open(self.label_path) as f:
            self.label_mapping = json.load(f)
        self.label_id_to_name = {idx: name for name, idx in self.label_mapping.items()}
        self.first_sequence = kwargs.pop("first_sequence", "first_sequence")
        self.second_sequence = kwargs.pop("second_sequence", "second_sequence")
        self.sequence_length = kwargs.pop("sequence_length", 128)
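    # Per the constructor above, ``model_dir`` is expected to hold the tokenizer files,
    # the model checkpoint, and a ``label_mapping.json`` that maps label names to integer
    # ids. A hypothetical example of that file's contents (labels are illustrative only):
    #
    #   {"negative": 0, "positive": 1}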
    def preprocess(self, in_data):
        if not in_data:
            raise RuntimeError("Input data should not be None.")
        if not isinstance(in_data, list):
            in_data = [in_data]
        rst = {
            "id": [],
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": []
        }
        max_seq_length = -1
        for record in in_data:
            if "sequence_length" not in record:
                break
            max_seq_length = max(max_seq_length, record["sequence_length"])
        max_seq_length = self.sequence_length if (max_seq_length == -1) else max_seq_length
        for record in in_data:
            text_a = record[self.first_sequence]
            text_b = record.get(self.second_sequence, None)
            feature = self.bert_tokenizer(text_a, text_b,
                                          padding='max_length',
                                          max_length=max_seq_length)
            rst["id"].append(record.get("id", str(uuid.uuid4())))
            rst["input_ids"].append(feature["input_ids"])
            rst["attention_mask"].append(feature["attention_mask"])
            rst["token_type_ids"].append(feature["token_type_ids"])
        return rst
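    # Illustrative sketch of the record format ``preprocess`` expects, assuming
    # ``predictor`` is a BertTextClassifyPredictor instance. The key names below are
    # the defaults; ``id``, ``second_sequence`` and ``sequence_length`` are optional,
    # and the sample values are assumptions for the example.
    #
    #   record = {"id": "sample-1",
    #             "first_sequence": "text of sentence A",
    #             "second_sequence": "optional text of sentence B",
    #             "sequence_length": 128}
    #   features = predictor.preprocess([record])
    #   # features: dict of lists keyed by "id", "input_ids",
    #   # "attention_mask" and "token_type_ids"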
    def predict(self, in_data):
        return self.model_predictor.predict(in_data)
    def postprocess(self, result):
        probs = result["probabilities"]
        logits = result["logits"]
        predictions = np.argsort(-probs, axis=-1)
        new_results = list()
        for b, preds in enumerate(predictions):
            new_result = list()
            for pred in preds:
                new_result.append({
                    "pred": self.label_id_to_name[pred],
                    "prob": float(probs[b][pred]),
                    "logit": float(logits[b][pred])
                })
            new_results.append({
                "id": result["id"][b] if "id" in result else str(uuid.uuid4()),
                "output": new_result,
                "predictions": new_result[0]["pred"],
                "probabilities": ",".join([str(t) for t in result["probabilities"][b]]),
                "logits": ",".join([str(t) for t in result["logits"][b]])
            })
        if len(new_results) == 1:
            new_results = new_results[0]
        return new_results
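
# Illustrative end-to-end sketch (not part of the module): running the predictor on a
# single record. The model directory path and the sample text are assumptions for the
# example.
#
#   predictor = BertTextClassifyPredictor(model_dir="/path/to/model_dir",
#                                         model_cls=BertTextClassify)
#   features = predictor.preprocess({"first_sequence": "a sample sentence"})
#   raw = predictor.predict(features)
#   result = predictor.postprocess(raw)
#   # result["predictions"]: best label name
#   # result["output"]: per-label {"pred", "prob", "logit"} dicts, ranked by probability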