# coding=utf-8
# Copyright (c) 2019 Alibaba PAI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import six
import numpy as np
import tensorflow as tf
from easytransfer.engines.distribution import Process
class ClassificationPostprocessor(Process):
    """ Postprocessor for text classification, convert label_id to the label_name

    Predictions stored under ``prediction_colname`` are mapped from integer
    label ids to label names taken from the comma-separated
    ``label_enumerate_values`` (position ``i`` in that string names label id
    ``i``). A prediction that is a list/tuple/``np.ndarray`` is treated as a
    multi-label indicator vector: the names of every position whose value
    equals 1 are joined with ",". Only the columns listed in ``output_schema``
    are returned.
    """

    def __init__(self,
                 label_enumerate_values,
                 output_schema,
                 thread_num=None,
                 input_queue=None,
                 output_queue=None,
                 prediction_colname="predictions",
                 job_name='CLSpostprocessor'):
        """
        Args:
            label_enumerate_values (`str` or None): comma-separated label names,
                in label-id order. If None, ``process`` returns its input as-is.
            output_schema (`str`): comma-separated names of the columns to emit.
            thread_num, input_queue, output_queue: forwarded to ``Process``.
            prediction_colname (`str`): key of the predictions in the input dict.
            job_name (`str`): forwarded to ``Process``.
        """
        super(ClassificationPostprocessor, self).__init__(
            job_name, thread_num, input_queue, output_queue, batch_size=1)
        self.prediction_colname = prediction_colname
        self.label_enumerate_values = label_enumerate_values
        self.output_schema = output_schema
        if label_enumerate_values is not None:
            self.idx_label_map = dict()
            for i, label in enumerate(label_enumerate_values.split(",")):
                # Under Python 2 the labels are emitted as byte strings. Only
                # text (unicode) labels need encoding: calling .encode() on a
                # py2 ``str`` that already holds UTF-8 bytes performs an
                # implicit ASCII decode first and raises UnicodeDecodeError.
                if six.PY2 and isinstance(label, six.text_type):
                    label = label.encode("utf8")
                self.idx_label_map[i] = label

    def process(self, in_data):
        """ Post-process the model outputs

        Args:
            in_data (`dict`): a dict of model outputs
        Returns:
            ret (`dict`): a dict of post-processed model outputs, restricted to
                the columns named in ``output_schema``. When
                ``label_enumerate_values`` is None, ``in_data`` itself is
                returned unchanged (not filtered by ``output_schema``).
        Raises:
            KeyError: if a predicted label id has no entry in the label map.
        """
        if self.label_enumerate_values is None:
            return in_data
        # Shallow copy so the caller's dict is not mutated.
        tmp = dict(in_data)
        if self.prediction_colname in tmp:
            new_preds = []
            for raw_pred in tmp[self.prediction_colname]:
                if isinstance(raw_pred, (list, tuple, np.ndarray)):
                    # Multi-label: raw_pred is an indicator vector; keep the
                    # names of all positions whose value equals 1.
                    pred = ",".join(
                        self.idx_label_map[idx]
                        for idx, val in enumerate(raw_pred) if val == 1)
                else:
                    pred = self.idx_label_map[int(raw_pred)]
                new_preds.append(pred)
            tmp[self.prediction_colname] = np.array(new_preds)
        ret = dict()
        for output_col_name in self.output_schema.split(","):
            # Tolerate whitespace around names, e.g. "predictions, logits".
            output_col_name = output_col_name.strip()
            if output_col_name in tmp:
                ret[output_col_name] = tmp[output_col_name]
        return ret
class MultiTaskClassificationPostprocessor(Process):
    """ Postprocessor for multi-task text classification, convert each label_id
    to the label_name defined for the task the example belongs to.

    ``label_enumerate_values`` is the path (surrounding "^" characters are
    stripped) of a JSON file containing a list of per-task entries, each with a
    ``"taskIndex"`` and a ``"labelMap"`` list whose position ``i`` names label
    id ``i`` for that task. ``process`` reads the task of every example from
    the ``"task_ids"`` column. Only the columns listed in ``output_schema`` are
    returned.
    """

    def __init__(self,
                 label_enumerate_values,
                 output_schema,
                 thread_num=None,
                 input_queue=None,
                 output_queue=None,
                 prediction_colname="predictions",
                 job_name='CLSpostprocessor'):
        """
        Args:
            label_enumerate_values (`str`): path to the JSON label meta file.
            output_schema (`str`): comma-separated names of the columns to emit.
            thread_num, input_queue, output_queue: forwarded to ``Process``.
            prediction_colname (`str`): key of the predictions in the input dict.
            job_name (`str`): forwarded to ``Process``.
        Raises:
            ValueError: if the label meta file contains no task entries.
        """
        super(MultiTaskClassificationPostprocessor, self).__init__(
            job_name, thread_num, input_queue, output_queue, batch_size=1)
        self.prediction_colname = prediction_colname
        self.label_meta_info_path = label_enumerate_values
        self.output_schema = output_schema
        self.task_id_to_label_inv_mapping = dict()
        meta_path = self.label_meta_info_path.strip("^")
        with tf.gfile.Open(meta_path) as f:
            label_meta_info = json.load(f)
        if not label_meta_info:
            # max() below would otherwise fail with an opaque message.
            raise ValueError(
                "Label meta info file %s contains no task entries" % meta_path)
        # Size of the largest per-task label set.
        self.max_label_size = max([len(t["labelMap"]) for t in label_meta_info])
        for task_label_info in label_meta_info:
            labels = task_label_info["labelMap"]
            if six.PY2:
                # Emit byte strings under Python 2; only text labels need
                # encoding (see ClassificationPostprocessor for the same rule).
                labels = [label.encode("utf-8")
                          if isinstance(label, six.text_type) else label
                          for label in labels]
            self.task_id_to_label_inv_mapping[task_label_info["taskIndex"]] = \
                dict(enumerate(labels))

    def process(self, in_data):
        """ Post-process the model outputs

        Args:
            in_data (`dict`): a dict of model outputs; when it holds the
                prediction column it must also hold a ``"task_ids"`` column of
                the same length.
        Returns:
            ret (`dict`): a dict of post-processed model outputs, restricted to
                the columns named in ``output_schema``.
        Raises:
            KeyError: if ``"task_ids"`` is missing or a task id is unknown.
        """
        # Shallow copy so the caller's dict is not mutated.
        tmp = dict(in_data)
        if self.prediction_colname in tmp:
            new_preds = []
            for raw_pred, task_id in zip(tmp[self.prediction_colname],
                                         tmp["task_ids"]):
                # NOTE(review): task ids are looked up as-is; this assumes they
                # hash/compare equal to the JSON "taskIndex" values (e.g. numpy
                # integer scalars vs Python ints) — confirm against the reader.
                label_inv_map = self.task_id_to_label_inv_mapping[task_id]
                idx = int(raw_pred)
                # Fall back to label id 0 when the prediction is outside this
                # task's label range — presumably because the model scores a
                # label space shared across tasks (max_label_size); verify.
                if not 0 <= idx < len(label_inv_map):
                    idx = 0
                new_preds.append(label_inv_map[idx])
            tmp[self.prediction_colname] = np.array(new_preds)
        ret = dict()
        for output_col_name in self.output_schema.split(","):
            # Tolerate whitespace around names, e.g. "predictions, task_ids".
            output_col_name = output_col_name.strip()
            if output_col_name in tmp:
                ret[output_col_name] = tmp[output_col_name]
        return ret