# coding=utf-8
# Copyright (c) 2020 Alibaba PAI team and The HuggingFace Inc. team and Facebook, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import random

import torch

from .base import BaseDataset
from ..model_zoo import AutoTokenizer

class WMMLanguageModelDataset(BaseDataset):
    """Whole Word Masking (WWM) language model dataset.

    Each line of ``data_file`` is a JSON object whose "text" field holds the
    sentences to tokenize.
    """

    def __init__(self,
                 pretrained_model_name_or_path,
                 data_file,
                 max_seq_length,
                 mlm_mask_prop=0.15,
                 **kwargs):
        super().__init__(data_file, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        vocab = self.tokenizer.get_vocab()
        self.vocab_size = len(vocab)
        self.cls_ids = vocab["[CLS]"]
        self.pad_idx = vocab["[PAD]"]
        self.mask_idx = vocab["[MASK]"]
        self.sep_ids = vocab["[SEP]"]
        self.fp16 = False
        self.mlm_mask_prop = mlm_mask_prop
        self.max_seq_length = max_seq_length
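
    # The special-token ids above depend on the loaded vocabulary. For
    # reference, bert-base-uncased maps [PAD]=0, [CLS]=101, [SEP]=102,
    # [MASK]=103; any BERT-style vocab defining these four tokens works.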

    @property
    def eval_metrics(self):
        return ('mlm_accuracy',)

    @property
    def label_enumerate_values(self):
        return []

    def convert_single_row_to_example(self, row):
        """Convert one JSON line into (token_ids, mask_labels, mask_span_indices)."""
        text = json.loads(row.strip())['text']
        token_ids = [self.cls_ids]
        for sentence in text:
            sentence_tokens = self.tokenizer.tokenize(sentence)
            token_ids.extend(self.tokenizer.convert_tokens_to_ids(sentence_tokens))
        # Truncate to leave room for the trailing [SEP].
        token_ids = token_ids[:self.max_seq_length - 1]
        token_ids.append(self.sep_ids)
        ref_tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        mask_labels, mask_span_indices = self._whole_word_mask(ref_tokens)
        return token_ids, mask_labels, mask_span_indices
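
    # Hypothetical example, assuming a WordPiece vocab that splits "affable"
    # into "aff", "##able":
    #   row = '{"text": ["An affable man."]}'
    #   tokens            -> [CLS] an aff ##able man . [SEP]
    #   mask_labels       -> e.g. [0, 0, 1, 1, 0, 0, 0]  (whole word masked)
    #   mask_span_indices -> [[2, 3]]                     (sub-token indices)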

    def batch_fn(self, batch):
        """Collate examples: pad, build the attention mask, and apply masking."""
        token_ids = [t[0] for t in batch]
        mask_labels = [t[1] for t in batch]
        lengths = [len(t[0]) for t in batch]
        # Pad every sequence to the longest one in the batch.
        max_seq_len_ = max(lengths)
        assert max_seq_len_ <= self.max_seq_length
        # Token ids are padded with [PAD]; mask labels are padded with 0 so
        # that padding positions can never be selected for masking.
        padded_token_ids = [t + [self.pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
        padded_mask_labels = [t + [0] * (max_seq_len_ - len(t)) for t in mask_labels]
        assert len(padded_token_ids) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in padded_token_ids)
        assert all(len(t) == max_seq_len_ for t in padded_mask_labels)
        token_ids = torch.LongTensor(padded_token_ids)
        mask_labels = torch.LongTensor(padded_mask_labels)
        lengths = torch.tensor(lengths)  # (bs,)
        # attn_mask[i, j] is 1 for real tokens (j < lengths[i]), 0 for padding.
        attn_mask = torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]
        attn_mask = attn_mask.long()
        input_ids, label_ids = self.mask_tokens(token_ids, mask_labels)
        return {
            "input_ids": input_ids,
            "attention_mask": attn_mask,
            "label_ids": label_ids,
            "mask_span_indices": [t[2] for t in batch]
        }
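
    # Worked padding example: for lengths [5, 3] the padded batch has
    # max_seq_len_ = 5 and the attention mask is
    #   [[1, 1, 1, 1, 1],
    #    [1, 1, 1, 0, 0]]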

    def _whole_word_mask(self, input_tokens, max_predictions=512):
        """
        Get 0/1 labels for masked tokens with the whole word mask proxy.
        """
        # Group WordPiece sub-tokens into whole-word candidates: a "##"-prefixed
        # token is appended to the preceding word's index list.
        cand_indexes = []
        for i, token in enumerate(input_tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue
            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])
        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_mask_prop))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)
        assert len(covered_indexes) == len(masked_lms)
        mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
        mask_span_indices = [t for t in cand_indexes if t[0] in covered_indexes]
        return mask_labels, mask_span_indices
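
    # Budget arithmetic: with 128 input tokens and the default
    # mlm_mask_prop=0.15, num_to_predict = min(512, max(1, round(128 * 0.15)))
    # = 19 sub-token positions, filled whole word by whole word.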

    def mask_tokens(self, inputs, mask_labels):
        """
        Prepare masked-token inputs/labels for masked language modeling:
        80% [MASK], 10% random token, 10% original token. Because
        `mask_labels` comes from whole word masking (WWM), positions are
        masked directly according to the precomputed reference rather than
        sampled per token.
        """
        labels = inputs.clone()
        # Positions to mask were already sampled in _whole_word_mask with
        # probability self.mlm_mask_prop (default 0.15, as in BERT/RoBERTa).
        probability_matrix = mask_labels
        padding_mask = labels.eq(self.pad_idx)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
        # 80% of the time, we replace masked input tokens with [MASK].
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.mask_idx
        # 10% of the time, we replace masked input tokens with a random word
        # (half of the remaining 20%).
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(self.vocab_size, labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]
        # The rest of the time (10%), we keep the masked input tokens unchanged.
        return inputs, labels
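

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library. Assumptions (all
    # hypothetical): "bert-base-uncased" is resolvable by AutoTokenizer,
    # "train.jsonl" holds one JSON object per line such as
    # {"text": ["first sentence", "second sentence"]}, and BaseDataset
    # implements the torch Dataset protocol so that a DataLoader can drive
    # convert_single_row_to_example and collate batches with batch_fn.
    from torch.utils.data import DataLoader

    dataset = WMMLanguageModelDataset(
        pretrained_model_name_or_path="bert-base-uncased",
        data_file="train.jsonl",
        max_seq_length=128,
    )
    loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.batch_fn)
    batch = next(iter(loader))
    print(batch["input_ids"].shape)       # (8, max_seq_len_ of this batch)
    print(batch["attention_mask"].sum())  # number of non-padding tokens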