easytexminer.applications

BERT Text Classify

class easytexminer.applications.classification.bert.BertTextClassify(config, **kwargs)[source]

Bert Model with a classification head on top (a linear layer on top of the pooled output).

This model inherits from BertPreTrainedModel.

Parameters

config -- Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration.

forward(inputs)[source]

Forward Method.

Parameters
  • inputs -- The input of the model.

  • inputs['input_ids'] --

    torch.LongTensor of shape (batch_size, sequence_length).

    Indices of input sequence tokens in the vocabulary.

    Indices can be obtained using BertTokenizer.

  • inputs['token_type_ids'] --

    torch.LongTensor of shape (batch_size, sequence_length).

    Segment token indices to indicate first and second portions of the inputs. Indices are selected in [0,1]:

    • 0 corresponds to a sentence A token,

    • 1 corresponds to a sentence B token.

  • inputs['attention_mask'] --

    torch.FloatTensor of shape (batch_size, sequence_length).

    Mask to avoid performing attention on padding token indices. Mask values are selected in [0, 1]:

    • 1 for tokens that are not masked,

    • 0 for tokens that are masked.

Returns: A dict containing four elements.
hidden: tuple(torch.FloatTensor)

Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

logits: torch.FloatTensor of shape (batch_size, num_labels)

Classification scores (before softmax), obtained by passing the last-layer hidden state of the first token of the sequence (the classification token) through a linear layer (the classifier).

predictions: torch.FloatTensor of shape (batch_size, 1)

The label index of each sample, obtained by applying argmax to the logits.

probabilities: torch.FloatTensor of shape (batch_size, num_labels)

The class probabilities, obtained by applying softmax to the logits, rescaling them so that the elements of each row lie in the range [0, 1] and sum to 1.
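The relationship between the returned logits, probabilities, and predictions can be sketched in plain Python for a single sample (illustrative only; the library computes these with torch operations, and the logit values below are made up):

```python
import math

def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(logits):
    """Index of the largest logit, i.e. the predicted label."""
    return max(range(len(logits)), key=lambda i: logits[i])

# One sample with num_labels = 3; values are made up for illustration.
logits = [2.0, 0.5, -1.0]
probabilities = softmax(logits)  # each entry in [0, 1], row sums to 1
prediction = argmax(logits)      # 0, since logits[0] is the largest
```

Per row of the batch, `probabilities` sums to 1 and `prediction` picks the highest-scoring label, matching the shapes documented above.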

compute_loss(model_outputs, inputs)[source]

Compute the cross-entropy loss between the predicted logits and the ground-truth labels.

Parameters
  • model_outputs -- The output of BertTextClassify.

  • inputs -- The input of BertTextClassify.
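What cross-entropy measures can be sketched for a single sample in plain Python (a sketch only; the library presumably uses torch's loss, which applies log-softmax to the logits internally):

```python
import math

def cross_entropy(logits, label):
    """Cross-entropy for one sample: -log(softmax(logits)[label]),
    computed via the log-sum-exp trick for numerical stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]

# A confident, correct prediction yields a small loss;
# the same logits scored against the wrong label yield a large one.
loss_correct = cross_entropy([4.0, 0.0, 0.0], label=0)
loss_wrong = cross_entropy([4.0, 0.0, 0.0], label=1)
```

The loss is always positive and shrinks toward 0 as the probability assigned to the true label approaches 1.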

class easytexminer.applications.classification.bert.BertTextClassifyPredictor(model_dir, model_cls=None, *args, **kwargs)[source]
preprocess(in_data)[source]
predict(in_data)[source]
postprocess(result)[source]
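The predictor exposes a three-stage pipeline: preprocess turns raw input into the model's input dict, predict runs the model, and postprocess reduces raw outputs to a result. A minimal stand-in sketches that contract; everything here except the three method names is hypothetical and is not BertTextClassifyPredictor itself:

```python
class ToyPredictor:
    """Illustrates the preprocess -> predict -> postprocess contract
    with a toy stand-in for the real model."""

    def __init__(self, vocab):
        self.vocab = vocab  # token -> index; 0 stands for unknown

    def preprocess(self, in_data):
        # Turn raw text into the model's input dict.
        tokens = in_data.split()
        return {"input_ids": [self.vocab.get(t, 0) for t in tokens]}

    def predict(self, in_data):
        # Stand-in "model": score class 0 by the count of known tokens.
        score = sum(1 for i in in_data["input_ids"] if i != 0)
        return {"logits": [float(score), 1.0]}

    def postprocess(self, result):
        # Reduce raw model outputs to a label index.
        logits = result["logits"]
        return {"prediction": max(range(len(logits)), key=lambda i: logits[i])}

predictor = ToyPredictor({"good": 1, "movie": 2})
out = predictor.postprocess(predictor.predict(predictor.preprocess("good movie !")))
```

The real predictor follows the same shape: each stage's output is the next stage's input.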

TextCNN Text Classify

class easytexminer.applications.classification.cnn.CNNTextClassify(config, **kwargs)[source]

CNN Model with a classification head on top (a linear layer on top of the pooled output).

This model inherits from TextCNNEncoder.

Parameters

config -- Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration.

config_class

alias of easytexminer.model_zoo.models.cnn.configuration_cnn.TextCNNConfig

forward(inputs)[source]

Forward Method.

Parameters
  • inputs -- The input of the model, a dict that must contain input_ids.

  • inputs['input_ids'] --

    torch.LongTensor of shape (batch_size, sequence_length).

    Indices of input sequence tokens in the vocabulary.

    Indices can be obtained using BertTokenizer.

Returns: A dict containing three elements.
logits: torch.FloatTensor of shape (batch_size, num_labels)

Classification scores (before softmax), obtained by passing the pooled output through a linear layer (the classifier).

predictions: torch.FloatTensor of shape (batch_size, 1)

The label index of each sample, obtained by applying argmax to the logits.

probabilities: torch.FloatTensor of shape (batch_size, num_labels)

The class probabilities, obtained by applying softmax to the logits, rescaling them so that the elements of each row lie in the range [0, 1] and sum to 1.
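Building the input_ids tensor amounts to mapping tokens to vocabulary indices and padding each sequence to a fixed length; in practice BertTokenizer does this, but a toy vocabulary (hypothetical, for illustration only) shows the resulting (batch_size, sequence_length) shape:

```python
# Toy vocabulary; in practice BertTokenizer provides this mapping.
vocab = {"[PAD]": 0, "[UNK]": 1, "the": 2, "film": 3, "was": 4, "great": 5}

def encode(sentence, seq_len):
    """Map whitespace tokens to vocabulary indices and pad to seq_len."""
    ids = [vocab.get(tok, vocab["[UNK]"]) for tok in sentence.split()]
    ids = ids[:seq_len]
    return ids + [vocab["[PAD]"]] * (seq_len - len(ids))

# A batch of two sentences -> shape (batch_size=2, sequence_length=6);
# the real model would receive this as a torch.LongTensor.
batch = [encode("the film was great", 6), encode("great film", 6)]
inputs = {"input_ids": batch}
```

Shorter sequences are right-padded with the [PAD] index so every row in the batch has the same length.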

compute_loss(model_outputs, inputs)[source]
class easytexminer.applications.classification.cnn.CNNTextClassifyPredictor(model_dir, model_cls=None, *args, **kwargs)[source]
preprocess(in_data)[source]
predict(in_data)[source]
postprocess(result)[source]