|
|
import csv import io import itertools import json import re from collections import defaultdict from random import Random
from django.db import transaction from rest_framework.renderers import JSONRenderer from seqeval.metrics.sequence_labeling import get_entities
from app.settings import IMPORT_BATCH_SIZE from .exceptions import FileParseException from .models import Label from .serializers import DocumentSerializer, LabelSerializer
def extract_label(tag):
    """Strip a chunking-scheme prefix from a tag.

    A tag that starts with ``B-``, ``I-``, ``E-`` or ``S-`` followed by at
    least one more character yields everything after that prefix. Any other
    tag is returned unchanged.
    """
    matched = re.match(r'(B|I|E|S)-(.+)', tag)
    if matched is None:
        return tag
    return matched.group(2)
class BaseStorage(object):
    """Common persistence helpers shared by the task-specific storages.

    ``data`` is the iterable of uploaded batches and ``project`` is the
    project every saved document and label is attached to.
    """

    def __init__(self, data, project):
        self.data = data
        self.project = project

    @transaction.atomic
    def save(self, user):
        """Persist ``self.data``; concrete storages must override this."""
        raise NotImplementedError()

    def save_doc(self, data):
        """Validate and save a batch of documents for this project."""
        serializer = DocumentSerializer(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        return serializer.save(project=self.project)

    def save_label(self, data):
        """Validate and save a batch of labels for this project."""
        serializer = LabelSerializer(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        return serializer.save(project=self.project)

    def save_annotation(self, data, user):
        """Validate and save a batch of annotations owned by ``user``.

        The serializer class is obtained from the project itself.
        """
        serializer_class = self.project.get_annotation_serializer()
        serializer = serializer_class(data=data, many=True)
        serializer.is_valid(raise_exception=True)
        return serializer.save(user=user)

    @classmethod
    def extract_label(cls, data):
        """Return the ``labels`` entry of every record (``[]`` if absent)."""
        extracted = []
        for record in data:
            extracted.append(record.get('labels', []))
        return extracted

    @classmethod
    def exclude_created_labels(cls, labels, created):
        """Return the labels that are not already present in ``created``."""
        remaining = []
        for label in labels:
            if label not in created:
                remaining.append(label)
        return remaining

    @classmethod
    def to_serializer_format(cls, labels, created, random_seed=None):
        """Build serializer payloads for the given label texts.

        Labels are processed in sorted order. Each one gets a free
        (suffix_key, prefix_key) shortcut when one is available, plus a
        background colour and a contrasting text colour.

        NOTE(review): ``random_seed`` is handed unchanged to ``Color.random``
        for every label, so a non-``None`` seed presumably yields the same
        colour for all labels -- confirm this is intended.
        """
        taken_shortkeys = set()
        for existing in created.values():
            taken_shortkeys.add((existing.suffix_key, existing.prefix_key))

        payloads = []
        for text in sorted(labels):
            payload = {'text': text}

            shortkey = cls.get_shortkey(text, taken_shortkeys)
            if shortkey:
                payload['suffix_key'], payload['prefix_key'] = shortkey
                taken_shortkeys.add(shortkey)

            color = Color.random(seed=random_seed)
            payload['background_color'] = color.hex
            payload['text_color'] = color.contrast_color.hex

            payloads.append(payload)

        return payloads

    @classmethod
    def get_shortkey(cls, label, existing_shortkeys):
        """Return the first unused (suffix_key, prefix_key) pair, or ``None``.

        Suffix candidates are the characters of the lower-cased label that
        are valid suffix keys, in order of appearance. For each of them the
        prefix candidates are ``None`` first, then the model's prefix keys.
        """
        prefix_keys = [None]
        prefix_keys.extend(key for key, _ in Label.PREFIX_KEYS)

        allowed_suffixes = {key for key, _ in Label.SUFFIX_KEYS}
        suffix_keys = [ch for ch in label.lower() if ch in allowed_suffixes]

        candidates = itertools.product(suffix_keys, prefix_keys)
        return next(
            (pair for pair in candidates if pair not in existing_shortkeys),
            None,
        )

    @classmethod
    def update_saved_labels(cls, saved, new):
        """Register ``new`` labels in ``saved`` (keyed by text) and return it."""
        for created in new:
            saved[created.text] = created
        return saved
class PlainStorage(BaseStorage):
    """Storage that saves documents only, without labels or annotations."""

    @transaction.atomic
    def save(self, user):
        """Save every uploaded batch of documents."""
        for batch in self.data:
            self.save_doc(batch)
class ClassificationStorage(BaseStorage):
    """Persist uploads for text classification.

    Each uploaded record has the form:
    {"text": "Python is awesome!", "labels": ["positive"]}
    ...
    """

    @transaction.atomic
    def save(self, user):
        """Save documents, any labels not yet known, and the annotations."""
        known_labels = {}
        for existing in self.project.labels.all():
            known_labels[existing.text] = existing

        for batch in self.data:
            docs = self.save_doc(batch)
            labels = self.extract_label(batch)

            pending = self.extract_unique_labels(labels)
            pending = self.exclude_created_labels(pending, known_labels)
            pending = self.to_serializer_format(pending, known_labels)
            created = self.save_label(pending)
            known_labels = self.update_saved_labels(known_labels, created)

            annotations = self.make_annotations(docs, labels, known_labels)
            self.save_annotation(annotations, user)

    @classmethod
    def extract_unique_labels(cls, labels):
        """Return the set of label names used across all documents."""
        return set(itertools.chain.from_iterable(labels))

    @classmethod
    def make_annotations(cls, docs, labels, saved_labels):
        """Pair each document with the ids of its labels."""
        annotations = []
        for doc, names in zip(docs, labels):
            for name in names:
                matched = saved_labels[name]
                annotations.append({'document': doc.id, 'label': matched.id})
        return annotations
class SequenceLabelingStorage(BaseStorage):
    """Persist uploads for sequence labeling.

    Each uploaded record has the form:
    {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
    ...
    """

    @transaction.atomic
    def save(self, user):
        """Save documents, any labels not yet known, and the span annotations."""
        known_labels = {}
        for existing in self.project.labels.all():
            known_labels[existing.text] = existing

        for batch in self.data:
            docs = self.save_doc(batch)
            labels = self.extract_label(batch)

            pending = self.extract_unique_labels(labels)
            pending = self.exclude_created_labels(pending, known_labels)
            pending = self.to_serializer_format(pending, known_labels)
            created = self.save_label(pending)
            known_labels = self.update_saved_labels(known_labels, created)

            annotations = self.make_annotations(docs, labels, known_labels)
            self.save_annotation(annotations, user)

    @classmethod
    def extract_unique_labels(cls, labels):
        """Return the set of label names found in all ``[start, end, name]`` spans."""
        return {name for _, _, name in itertools.chain.from_iterable(labels)}

    @classmethod
    def make_annotations(cls, docs, labels, saved_labels):
        """Turn each document's spans into annotation payloads."""
        annotations = []
        for doc, spans in zip(docs, labels):
            for start_offset, end_offset, name in spans:
                matched = saved_labels[name]
                annotations.append({
                    'document': doc.id,
                    'label': matched.id,
                    'start_offset': start_offset,
                    'end_offset': end_offset,
                })
        return annotations
class Seq2seqStorage(BaseStorage):
    """Persist uploads for seq2seq tasks.

    Each uploaded record has the form:
    {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
    ...
    """

    @transaction.atomic
    def save(self, user):
        """Save each batch of documents together with its text annotations."""
        for batch in self.data:
            docs = self.save_doc(batch)
            labels = self.extract_label(batch)
            self.save_annotation(self.make_annotations(docs, labels), user)

    @classmethod
    def make_annotations(cls, docs, labels):
        """Pair each document id with every target text given for it."""
        return [
            {'document': doc.id, 'text': text}
            for doc, texts in zip(docs, labels)
            for text in texts
        ]
class FileParser(object):
    """Base class for upload parsers; subclasses implement ``parse``."""

    def parse(self, file):
        """Parse ``file``; must be overridden by concrete parsers."""
        raise NotImplementedError()
class CoNLLParser(FileParser):
    """Parse an uploaded CoNLL-format file.

    Every non-blank line holds one token and its tag separated by a single
    tab character. A blank line ends the current sentence. For example
    (the two columns are tab-separated):
    ```
    EU      B-ORG
    rejects O
    German  B-MISC
    call    O
    to      O
    boycott O
    British B-MISC
    lamb    O
    .       O

    Peter       B-PER
    Blackburn   I-PER
    ...
    ```
    """

    def parse(self, file):
        """Yield batches of sequence-labeling documents read from ``file``.

        ``file`` is iterated line by line and each line is decoded as UTF-8,
        so it must yield ``bytes``. Every yielded batch is a non-empty list
        of at most ``IMPORT_BATCH_SIZE`` documents of the form:
        {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
        ...

        Raises:
            FileParseException: a non-blank line does not consist of exactly
                two tab-separated fields.
        """
        words, tags = [], []
        batch = []
        for line_num, raw_line in enumerate(file, start=1):
            line = raw_line.decode('utf-8').strip()
            if line:
                try:
                    word, tag = line.split('\t')
                except ValueError:
                    raise FileParseException(line_num=line_num, line=line)
                words.append(word)
                tags.append(tag)
                continue

            # Blank line: close the current sentence. Leading or repeated
            # blank lines carry no tokens and must not create empty documents.
            if words:
                batch.append(self.calc_char_offset(words, tags))
                words, tags = [], []
            if len(batch) >= IMPORT_BATCH_SIZE:
                yield batch
                batch = []

        # The file may end without a trailing blank line.
        if words:
            batch.append(self.calc_char_offset(words, tags))
        if batch:
            yield batch

    @classmethod
    def calc_char_offset(cls, words, tags):
        """Build one document from parallel token and tag lists.

        The text is the tokens joined by single spaces. Entity spans from
        ``get_entities`` (token indices, end inclusive) are converted to
        character offsets ``[start, end, label]`` with an exclusive end.

        Offsets are derived from token positions rather than by searching
        the text, so an entity whose surface form also occurs earlier in the
        sentence (e.g. "is" inside "Paris") is still located correctly.
        """
        # Character offset at which each token starts in the joined text.
        token_starts = []
        cursor = 0
        for word in words:
            token_starts.append(cursor)
            cursor += len(word) + 1  # +1 for the joining space

        labels = []
        for label, first_token, last_token in get_entities(tags):
            char_left = token_starts[first_token]
            char_right = token_starts[last_token] + len(words[last_token])
            labels.append([char_left, char_right, label])

        return {'text': ' '.join(words), 'labels': labels}
class PlainTextParser(FileParser):
    """Parse an uploaded plain-text file, one document per line.

    The file format is as follows:
    ```
    EU rejects German call to boycott British lamb.
    President Obama is speaking at the White House.
    ...
    ```
    """

    def parse(self, file):
        """Yield batches of ``{'text': ...}`` documents.

        ``file`` is wrapped as UTF-8 text. Each batch covers up to
        ``IMPORT_BATCH_SIZE`` lines, with surrounding whitespace stripped
        from every line.
        """
        reader = io.TextIOWrapper(file, encoding='utf-8')

        def next_batch():
            return list(itertools.islice(reader, IMPORT_BATCH_SIZE))

        # iter() with a sentinel stops as soon as an empty batch comes back.
        for lines in iter(next_batch, []):
            yield [{'text': line.strip()} for line in lines]
class CSVParser(FileParser):
    """Parse an uploaded CSV file.

    The file format is comma separated values. Column names are required
    at the top of the file. For example:
    ```
    text, label
    "EU rejects German call to boycott British lamb.",Politics
    "President Obama is speaking at the White House.",Politics
    "He lives in Newark, Ohio.",Other
    ...
    ```
    """

    def parse(self, file):
        """Yield batches of classification documents read from ``file``.

        The first row is the header. In every following row the first two
        fields are the text and its label; any further fields are stored,
        keyed by their column name, as a JSON string under ``meta``. Each
        batch is a non-empty list of
        ``{'text': ..., 'labels': [label], 'meta': ...}`` dicts. An empty
        file yields nothing.

        Raises:
            FileParseException: a row has fewer than two fields or a
                different number of fields than the header.
        """
        text_stream = io.TextIOWrapper(file, encoding='utf-8')
        reader = csv.reader(text_stream)

        # A bare next() on an empty file would raise StopIteration inside
        # this generator, which Python turns into RuntimeError (PEP 479).
        columns = next(reader, None)
        if columns is None:
            return

        batch = []
        # Data rows start on line 2; line 1 is the header.
        for line_num, row in enumerate(reader, start=2):
            if len(batch) >= IMPORT_BATCH_SIZE:
                yield batch
                batch = []
            if len(row) != len(columns) or len(row) < 2:
                raise FileParseException(line_num=line_num, line=row)
            text, label = row[:2]
            meta = json.dumps(dict(zip(columns[2:], row[2:])))
            batch.append({'text': text, 'labels': [label], 'meta': meta})
        if batch:
            yield batch
class JSONParser(FileParser):
    """Parse an uploaded JSON Lines file: one JSON object per line."""

    def parse(self, file):
        """Yield batches of documents read from ``file``.

        Every line must be a JSON object. Its ``meta`` entry (``{}`` when
        absent) is re-encoded as a JSON string. Each batch is a non-empty
        list of at most ``IMPORT_BATCH_SIZE`` objects.

        Raises:
            FileParseException: a line is not valid JSON, cannot be decoded,
                or is valid JSON but not an object.
        """
        batch = []
        for line_num, line in enumerate(file, start=1):
            if len(batch) >= IMPORT_BATCH_SIZE:
                yield batch
                batch = []
            try:
                record = json.loads(line)
            except ValueError:
                # Covers JSONDecodeError and UnicodeDecodeError, both of
                # which are ValueError subclasses.
                raise FileParseException(line_num=line_num, line=line)
            # Lists, numbers, strings and null are valid JSON but cannot
            # carry the 'text'/'meta' keys a document needs.
            if not isinstance(record, dict):
                raise FileParseException(line_num=line_num, line=line)
            record['meta'] = json.dumps(record.get('meta', {}))
            batch.append(record)
        if batch:
            yield batch
class JSONLRenderer(JSONRenderer):
    """Renderer that emits one JSON document per line (JSON Lines)."""

    def render(self, data, accepted_media_type=None, renderer_context=None):
        """Yield ``data`` as JSON Lines.

        This method is a generator: it yields one ``str`` per item, each
        terminated by a newline, and yields nothing when ``data`` is
        ``None``. A single non-list value is treated as a one-item list.
        ``accepted_media_type`` and ``renderer_context`` are accepted for
        interface compatibility and are not used here.
        """
        if data is None:
            # A generator cannot hand back a bytestring; ending it here
            # simply produces an empty stream.
            return

        if not isinstance(data, list):
            data = [data]

        for item in data:
            yield json.dumps(item,
                             cls=self.encoder_class,
                             ensure_ascii=self.ensure_ascii,
                             allow_nan=not self.strict) + '\n'
class JSONPainter(object):
    """Prepare serialized documents for JSON export."""

    def paint(self, documents):
        """Serialize ``documents`` and tidy them for export.

        Each document's ``meta`` JSON string is decoded, and the ``id``,
        ``prob`` and ``document`` keys are removed from every annotation.
        """
        serialized = DocumentSerializer(documents, many=True).data
        painted = []
        for doc in serialized:
            doc['meta'] = json.loads(doc['meta'])
            for annotation in doc['annotations']:
                for field in ('id', 'prob', 'document'):
                    annotation.pop(field)
            painted.append(doc)
        return painted
class CSVPainter(JSONPainter):
    """Prepare serialized documents for CSV export, one row per annotation."""

    def paint(self, documents):
        """Flatten each document into one merged dict per annotation.

        The document's fields (without ``annotations``) are merged with the
        annotation's fields, the latter winning on key clashes. A document
        with no annotations therefore contributes no rows.
        """
        # d.pop() runs once per document, before its rows are built, so the
        # merged dicts never contain the 'annotations' key.
        return [
            {**d, **a}
            for d in super().paint(documents)
            for a in d.pop('annotations')
        ]
class Color:
    """An RGB colour with helpers for hex output and contrast selection."""

    def __init__(self, red, green, blue):
        self.red, self.green, self.blue = red, green, blue

    @property
    def contrast_color(self):
        """Return white or black, whichever contrasts with this colour.

        Keeps text readable on this background for people with colour
        deficits or on black-and-white screens.

        Algorithm from w3c:
        * https://www.w3.org/TR/AERT/#color-contrast
        """
        if self.brightness < 128:
            return Color.white()
        return Color.black()

    @property
    def brightness(self):
        """Weighted brightness of the colour (0 to 255 for 8-bit channels)."""
        weighted = self.red * 299 + self.green * 587 + self.blue * 114
        return weighted / 1000

    @property
    def hex(self):
        """The colour as a ``#rrggbb`` lowercase hex string."""
        channels = (self.red, self.green, self.blue)
        return '#{:02x}{:02x}{:02x}'.format(*channels)

    @classmethod
    def white(cls):
        """Return pure white."""
        return cls(red=255, green=255, blue=255)

    @classmethod
    def black(cls):
        """Return pure black."""
        return cls(red=0, green=0, blue=0)

    @classmethod
    def random(cls, seed=None):
        """Return a random colour; the same non-``None`` seed gives the same colour."""
        generator = Random(seed)
        red, green, blue = generator.choices(range(256), k=3)
        return cls(red, green, blue)
|