Browse Source

Move examples to unit tests

pull/202/head
Clemens Wolff 5 years ago
parent
commit
340209a90e
2 changed files with 154 additions and 113 deletions
  1. 135
      app/server/tests/test_utils.py
  2. 132
      app/server/utils.py

135
app/server/tests/test_utils.py

@ -1,6 +1,10 @@
from django.test import TestCase
from server.utils import Color
from seqeval.metrics.sequence_labeling import get_entities
from ..models import Label, Document
from ..utils import BaseStorage, ClassificationStorage, SequenceLabelingStorage, Seq2seqStorage, CoNLLParser
from ..utils import Color
class TestColor(TestCase):
@ -20,3 +24,132 @@ class TestColor(TestCase):
color = Color(red=199, green=21, blue=133)
self.assertEqual(color.contrast_color.hex, '#ffffff')
class TestBaseStorage(TestCase):
def test_extract_label(self):
data = [{'labels': ['positive']}, {'labels': ['negative']}]
actual = BaseStorage.extract_label(data)
self.assertEqual(actual, [['positive'], ['negative']])
def test_exclude_created_labels(self):
labels = ['positive', 'negative']
created = {'positive': Label(text='positive')}
actual = BaseStorage.exclude_created_labels(labels, created)
self.assertEqual(actual, ['negative'])
def test_to_serializer_format(self):
labels = ['positive']
created = {}
actual = BaseStorage.to_serializer_format(labels, created, random_seed=123)
self.assertEqual(actual, [{
'text': 'positive',
'prefix_key': None,
'suffix_key': 'p',
'background_color': '#0d1668',
'text_color': '#ffffff',
}])
def test_get_shortkey_without_existing_shortkey(self):
label = 'positive'
created = {}
actual = BaseStorage.get_shortkey(label, created)
self.assertEqual(actual, ('p', None))
def test_get_shortkey_with_existing_shortkey(self):
label = 'positive'
created = {('p', None)}
actual = BaseStorage.get_shortkey(label, created)
self.assertEqual(actual, ('p', 'ctrl'))
def test_update_saved_labels(self):
saved = {'positive': Label(text='positive', text_color='#000000')}
new = [Label(text='positive', text_color='#ffffff')]
actual = BaseStorage.update_saved_labels(saved, new)
self.assertEqual(actual['positive'].text_color, '#ffffff')
class TestClassificationStorage(TestCase):
def test_extract_unique_labels(self):
labels = [['positive'], ['positive', 'negative'], ['negative']]
actual = ClassificationStorage.extract_unique_labels(labels)
self.assertCountEqual(actual, ['positive', 'negative'])
def test_make_annotations(self):
docs = [Document(text='a', id=1), Document(text='b', id=2), Document(text='c', id=3)]
labels = [['positive'], ['positive', 'negative'], ['negative']]
saved_labels = {'positive': Label(text='positive', id=1), 'negative': Label(text='negative', id=2)}
actual = ClassificationStorage.make_annotations(docs, labels, saved_labels)
self.assertCountEqual(actual, [
{'document': 1, 'label': 1},
{'document': 2, 'label': 1},
{'document': 2, 'label': 2},
{'document': 3, 'label': 2},
])
class TestSequenceLabelingStorage(TestCase):
def test_extract_unique_labels(self):
labels = [[[0, 1, 'LOC']], [[3, 4, 'ORG']]]
actual = SequenceLabelingStorage.extract_unique_labels(labels)
self.assertCountEqual(actual, ['LOC', 'ORG'])
def test_make_annotations(self):
docs = [Document(text='a', id=1), Document(text='b', id=2)]
labels = [[[0, 1, 'LOC']], [[3, 4, 'ORG']]]
saved_labels = {'LOC': Label(text='LOC', id=1), 'ORG': Label(text='ORG', id=2)}
actual = SequenceLabelingStorage.make_annotations(docs, labels, saved_labels)
self.assertEqual(actual, [
{'document': 1, 'label': 1, 'start_offset': 0, 'end_offset': 1},
{'document': 2, 'label': 2, 'start_offset': 3, 'end_offset': 4},
])
class TestSeq2seqStorage(TestCase):
def test_make_annotations(self):
docs = [Document(text='a', id=1), Document(text='b', id=2)]
labels = [['Hello!'], ['How are you?', "What's up?"]]
actual = Seq2seqStorage.make_annotations(docs, labels)
self.assertEqual(actual, [
{'document': 1, 'text': 'Hello!'},
{'document': 2, 'text': 'How are you?'},
{'document': 2, 'text': "What's up?"},
])
class TestCoNLLParser(TestCase):
def test_calc_char_offset(self):
words = ['EU', 'rejects', 'German', 'call']
tags = ['B-ORG', 'O', 'B-MISC', 'O']
entities = get_entities(tags)
actual = CoNLLParser.calc_char_offset(words, tags)
self.assertEqual(entities, [('ORG', 0, 0), ('MISC', 2, 2)])
self.assertEqual(actual, {
'text': 'EU rejects German call',
'labels': [[0, 2, 'ORG'], [11, 17, 'MISC']]
})

132
app/server/utils.py

@ -4,7 +4,6 @@ import itertools
import json
import re
from collections import defaultdict
from math import floor
from random import Random
from django.db import transaction
@ -55,40 +54,16 @@ class BaseStorage(object):
annotation = serializer.save(user=user)
return annotation
def extract_label(self, data):
"""Extract labels from parsed data.
Example:
>>> data = [{"labels": ["positive"]}, {"labels": ["negative"]}]
>>> self.extract_label(data)
[["positive"], ["negative"]]
"""
@classmethod
def extract_label(cls, data):
return [d.get('labels', []) for d in data]
def exclude_created_labels(self, labels, created):
"""Exclude created labels.
Example:
>>> labels = ["positive", "negative"]
>>> created = {"positive": ...}
>>> self.exclude_created_labels(labels, created)
["negative"]
"""
@classmethod
def exclude_created_labels(cls, labels, created):
return [label for label in labels if label not in created]
@classmethod
def to_serializer_format(cls, labels, created):
"""Convert a label to model dictionary.
Also assigns shortkeys for each label that don't clash with existing
label shortkeys.
Example:
>>> labels = ["positive"]
>>> created = {}
>>> BaseStorage.to_serializer_format(labels, created)
[{"text": "positive", "suffix_key": "p", "prefix_key": None}]
"""
def to_serializer_format(cls, labels, created, random_seed=None):
existing_shortkeys = {(label.suffix_key, label.prefix_key)
for label in created.values()}
@ -103,7 +78,7 @@ class BaseStorage(object):
serializer_label['prefix_key'] = shortkey[1]
existing_shortkeys.add(shortkey)
color = Color.random()
color = Color.random(seed=random_seed)
serializer_label['background_color'] = color.hex
serializer_label['text_color'] = color.contrast_color.hex
@ -113,16 +88,6 @@ class BaseStorage(object):
@classmethod
def get_shortkey(cls, label, existing_shortkeys):
"""Find the first non existing shortkey for the label.
Example without existing shortkey:
>>> BaseStorage.get_shortkey("positive", set())
("p", None)
Example with existing shortkey:
>>> BaseStorage.get_shortkey("positive", {("p", None)})
("p", "ctrl")
"""
model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
prefix_keys = [None] + model_prefix_keys
@ -135,13 +100,8 @@ class BaseStorage(object):
return None
def update_saved_labels(self, saved, new):
"""Update saved labels.
Example:
>>> saved = {'positive': ...}
>>> new = [<Label: positive>]
"""
@classmethod
def update_saved_labels(cls, saved, new):
for label in new:
saved[label.text] = label
return saved
@ -176,27 +136,12 @@ class ClassificationStorage(BaseStorage):
annotations = self.make_annotations(docs, labels, saved_labels)
self.save_annotation(annotations, user)
def extract_unique_labels(self, labels):
"""Extract unique labels
Example:
>>> labels = [["positive"], ["positive", "negative"], ["negative"]]
>>> self.extract_unique_labels(labels)
["positive", "negative"]
"""
@classmethod
def extract_unique_labels(cls, labels):
return set(itertools.chain(*labels))
def make_annotations(self, docs, labels, saved_labels):
"""Make list of annotation obj for serializer.
Example:
>>> docs = ["<Document: a>", "<Document: b>", "<Document: c>"]
>>> labels = [["positive"], ["positive", "negative"], ["negative"]]
>>> saved_labels = {"positive": "<Label: positive>", 'negative': "<Label: negative>"}
>>> self.make_annotations(docs, labels, saved_labels)
[{"document": 1, "label": 1}, {"document": 2, "label": 1}
{"document": 2, "label": 2}, {"document": 3, "label": 2}]
"""
@classmethod
def make_annotations(cls, docs, labels, saved_labels):
annotations = []
for doc, label in zip(docs, labels):
for name in label:
@ -226,29 +171,12 @@ class SequenceLabelingStorage(BaseStorage):
annotations = self.make_annotations(docs, labels, saved_labels)
self.save_annotation(annotations, user)
def extract_unique_labels(self, labels):
"""Extract unique labels
Example:
>>> labels = [[[0, 1, "LOC"]], [[3, 4, "ORG"]]]
>>> self.extract_unique_labels(labels)
["LOC", "ORG"]
"""
@classmethod
def extract_unique_labels(cls, labels):
return set([label for _, _, label in itertools.chain(*labels)])
def make_annotations(self, docs, labels, saved_labels):
"""Make list of annotation obj for serializer.
Example:
>>> docs = ["<Document: a>", "<Document: b>"]
>>> labels = labels = [[[0, 1, "LOC"]], [[3, 4, "ORG"]]]
>>> saved_labels = {"LOC": "<Label: LOC>", 'ORG': "<Label: ORG>"}
>>> self.make_annotations(docs, labels, saved_labels)
[
{"document": 1, "label": 1, "start_offset": 0, "end_offset": 1}
{"document": 2, "label": 2, "start_offset": 3, "end_offset": 4}
]
"""
@classmethod
def make_annotations(cls, docs, labels, saved_labels):
annotations = []
for doc, spans in zip(docs, labels):
for span in spans:
@ -276,16 +204,8 @@ class Seq2seqStorage(BaseStorage):
annotations = self.make_annotations(doc, labels)
self.save_annotation(annotations, user)
def make_annotations(self, docs, labels):
"""Make list of annotation obj for serializer.
Example:
>>> docs = ["<Document: a>", "<Document: b>"]
>>> labels = [["Hello!"], ["How are you?", "What's up?"]]
>>> self.make_annotations(docs, labels)
[{"document": 1, "text": "Hello"}, {"document": 2, "text": "How are you?"}
{"document": 2, "text": "What's up?"}]
"""
@classmethod
def make_annotations(cls, docs, labels):
annotations = []
for doc, texts in zip(docs, labels):
for text in texts:
@ -352,20 +272,8 @@ class CoNLLParser(FileParser):
data.append(j)
yield data
def calc_char_offset(self, words, tags):
"""
Examples:
>>> words = ['EU', 'rejects', 'German', 'call']
>>> tags = ['B-ORG', 'O', 'B-MISC', 'O']
>>> entities = get_entities(tags)
>>> entities
[['ORG', 0, 0], ['MISC', 2, 2]]
>>> self.calc_char_offset(words, tags)
{
'text': 'EU rejects German call',
'labels': [[0, 2, 'ORG'], [11, 17, 'MISC']]
}
"""
@classmethod
def calc_char_offset(cls, words, tags):
doc = ' '.join(words)
j = {'text': ' '.join(words), 'labels': []}
pos = defaultdict(int)

Loading…
Cancel
Save