import json
import os
import uuid
import numpy as np
import torch
import torch.nn as nn
from ... import losses
from ...core.predictor import Predictor, get_model_predictor
from ...utils import io
from ...model_zoo import BertModel, BertPreTrainedModel
from ...model_zoo import BertTokenizer


class BertTextClassify(BertPreTrainedModel):
    """
    Bert Model with a classification head on top (a linear layer on top of the pooled output).
    This model inherits from BertPreTrainedModel.

    Parameters:
        config: Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model,
            only the configuration.
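
    Example:
        A minimal usage sketch; the model directory below is a hypothetical path, and the
        weights are loaded through the ``from_pretrained`` helper inherited from
        BertPreTrainedModel::

            model = BertTextClassify.from_pretrained("path/to/finetuned_text_classify_bert")
            model.eval()  # disable dropout for inference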
"""
def __init__(self, config, **kwargs):
super(BertTextClassify, self).__init__(config)
self.model_name = "text_classify_bert"
self.bert = BertModel(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.init_weights()

    def forward(self, inputs):
"""Forward Method.
Args:
inputs: The input of the model.
inputs['input_ids']: :obj:`torch.LongTensor` of shape (batch_size, sequence_length))
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using BertTokenizer.
inputs['token_type_ids']: :obj:`torch.LongTensor` of shape (batch_size, sequence_length).
Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0,1]:
- 0 corresponds to a sentence A token,
- 1 corresponds to a sentence B token.
inputs['attention_mask']: :obj:`torch.FloatTensor` of shape (batch_size, sequence_length).
Mask to avoid performing attention on padding token indices. Mask values selected in [0,1]:
- 0 for token that are not masked,
- 1 for token that are masked.
Returns: A Dict contains four elements.
hidden: :obj:`tuple(torch.FloatTensor)`
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
logits: :obj:`torch.FloatTensor` of shape (batch_size, num_labels)
Last layer hidden state of the first token of the sequence (classification token) further processed by a
linear layer (classifier).
predictions: :obj:`torch.FloatTensor` of shape (batch_size, 1)
Applies the Argmax function to pooler_output tensor to obtain the label of each sample.
probabilities: :obj:`torch.FloatTensor` of shape (batch_size, num_labels)
Applies the Softmax function to pooler_output tensor rescaling them so that the elements of the
output Tensor lie in the range [0,1] and sum to 1.
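
        Example:
            A minimal sketch of building ``inputs`` and calling the model; the tokenizer
            directory path is a hypothetical placeholder::

                tokenizer = BertTokenizer.from_pretrained("path/to/model_dir")
                feature = tokenizer("sentence A", "sentence B",
                                    padding='max_length', truncation=True,
                                    max_length=16, return_tensors='pt')
                inputs = {"input_ids": feature["input_ids"],
                          "token_type_ids": feature["token_type_ids"],
                          "attention_mask": feature["attention_mask"]}
                outputs = model(inputs)                    # dict with hidden/logits/predictions/probabilities
                label_id = int(outputs["predictions"][0])  # argmax label id of the first sample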
"""
outputs = self.bert(input_ids=inputs['input_ids'],
token_type_ids=inputs['token_type_ids'],
attention_mask=inputs['attention_mask'])
pooler_output = outputs.pooler_output
hidden_states = outputs.hidden_states
pooler_output = self.dropout(pooler_output)
logits = self.classifier(pooler_output)
return {
"hidden": hidden_states,
"logits": logits,
"predictions": torch.argmax(logits, dim=-1),
"probabilities": torch.softmax(logits, dim=-1)
}

    def compute_loss(self, model_outputs, inputs):
"""
Compute the cross entropy loss of the predicted label and the ground truth label.
Args:
model_outputs: The output of BertTextClassify.
inputs: The input of BertTextClassify.
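
        Example:
            A minimal shape sketch (the tensors below are illustrative only)::

                model_outputs = model(inputs)               # contains "logits" of shape (batch_size, num_labels)
                inputs["label_ids"] = torch.tensor([1, 0])  # ground-truth label ids, shape (batch_size,)
                loss = model.compute_loss(model_outputs, inputs)["loss"]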
"""
logits = model_outputs["logits"]
label_ids = inputs["label_ids"]
return {
"loss": losses.cross_entropy(logits, label_ids)
}


class BertTextClassifyPredictor(Predictor):
    """Predictor that runs end-to-end text classification inference with a trained BertTextClassify model."""

    def __init__(self, model_dir, model_cls=None, *args, **kwargs):
        super(BertTextClassifyPredictor, self).__init__(*args, **kwargs)
        self.bert_tokenizer = BertTokenizer.from_pretrained(model_dir)
        self.model_predictor = get_model_predictor(
            model_dir=model_dir,
            model_cls=model_cls,
            input_keys=[("input_ids", torch.LongTensor),
                        ("attention_mask", torch.LongTensor),
                        ("token_type_ids", torch.LongTensor)],
            output_keys=["predictions", "probabilities", "logits"])
        # label_mapping.json stores the label name -> label id mapping saved at training time.
        self.label_path = os.path.join(model_dir, "label_mapping.json")
        with io.open(self.label_path) as f:
            self.label_mapping = json.load(f)
        self.label_id_to_name = {idx: name for name, idx in self.label_mapping.items()}
        self.first_sequence = kwargs.pop("first_sequence", "first_sequence")
        self.second_sequence = kwargs.pop("second_sequence", "second_sequence")
        self.sequence_length = kwargs.pop("sequence_length", 128)

    def preprocess(self, in_data):
        if not in_data:
            raise RuntimeError("Input data should not be None.")
        if not isinstance(in_data, list):
            in_data = [in_data]
        rst = {
            "id": [],
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": []
        }
        # Use the largest per-record "sequence_length" when the records provide one;
        # otherwise fall back to the predictor's default sequence length.
        max_seq_length = -1
        for record in in_data:
            if "sequence_length" not in record:
                break
            max_seq_length = max(max_seq_length, record["sequence_length"])
        max_seq_length = self.sequence_length if (max_seq_length == -1) else max_seq_length
        for record in in_data:
            text_a = record[self.first_sequence]
            text_b = record.get(self.second_sequence, None)
            # truncation=True keeps every example at max_seq_length so the batch stays rectangular.
            feature = self.bert_tokenizer(text_a,
                                          text_b,
                                          padding='max_length',
                                          truncation=True,
                                          max_length=max_seq_length)
            rst["id"].append(record.get("id", str(uuid.uuid4())))
            rst["input_ids"].append(feature["input_ids"])
            rst["attention_mask"].append(feature["attention_mask"])
            rst["token_type_ids"].append(feature["token_type_ids"])
        return rst

    def predict(self, in_data):
        return self.model_predictor.predict(in_data)

    def postprocess(self, result):
probs = result["probabilities"]
logits = result["logits"]
predictions = np.argsort(-probs, axis=-1)
new_results = list()
for b, preds in enumerate(predictions):
new_result = list()
for pred in preds:
new_result.append({
"pred": self.label_id_to_name[pred],
"prob": float(probs[b][pred]),
"logit": float(logits[b][pred])
})
new_results.append({
"id": result["id"][b] if "id" in result else str(uuid.uuid4()),
"output": new_result,
"predictions": new_result[0]["pred"],
"probabilities": ",".join([str(t) for t in result["probabilities"][b]]),
"logits": ",".join([str(t) for t in result["logits"][b]])
})
if len(new_results) == 1:
new_results = new_results[0]
return new_results
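

# A minimal end-to-end inference sketch, not part of the library API. The model
# directory, the model_cls choice, and the input record are assumptions; the
# directory is expected to contain the fine-tuned weights, the tokenizer vocab,
# and the label_mapping.json produced at training time.
if __name__ == "__main__":
    predictor = BertTextClassifyPredictor(model_dir="path/to/finetuned_text_classify_bert",
                                          model_cls=BertTextClassify)
    batch = predictor.preprocess([{"first_sequence": "this movie is great"}])
    outputs = predictor.predict(batch)
    print(predictor.postprocess(outputs))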