Browse Source

Merge pull request #199 from CatalystCode/enhancement/auto-generate-label-shortkeys

Enhancement/Auto-generate label shortkeys and colors on corpus import
pull/202/head
Hiroki Nakayama 5 years ago
committed by GitHub
parent
commit
17b9d3b083
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 175 additions and 16 deletions
  1. 2
      app/server/models.py
  2. 17
      app/server/serializers.py
  3. 2
      app/server/static/js/label.vue
  4. 1
      app/server/tests/data/classification.jsonl
  5. 40
      app/server/tests/test_api.py
  6. 22
      app/server/tests/test_utils.py
  7. 107
      app/server/utils.py

2
app/server/models.py

@ -142,7 +142,7 @@ class Label(models.Model):
('shift', 'shift'),
('ctrl shift', 'ctrl shift')
)
SUFFIX_KEYS = (
SUFFIX_KEYS = tuple(
(c, c) for c in string.ascii_lowercase
)

17
app/server/serializers.py

@ -34,12 +34,17 @@ class LabelSerializer(serializers.ModelSerializer):
raise ValidationError('Shortcut key may not have a suffix key.')
# Don't allow to save same shortcut key when prefix_key is null.
context = self.context['request'].parser_context
project_id = context['kwargs'].get('project_id')
if Label.objects.filter(suffix_key=suffix_key,
prefix_key__isnull=True,
project=project_id).exists():
raise ValidationError('Duplicate key.')
try:
context = self.context['request'].parser_context
project_id = context['kwargs']['project_id']
except (AttributeError, KeyError):
pass # unit tests don't always have the correct context set up
else:
if Label.objects.filter(suffix_key=suffix_key,
prefix_key__isnull=True,
project=project_id).exists():
raise ValidationError('Duplicate key.')
return super().validate(attrs)
class Meta:

2
app/server/static/js/label.vue

@ -203,7 +203,7 @@ export default {
methods: {
generateColor() {
const color = (Math.random() * 0xFFFFFF | 0).toString(16); // eslint-disable-line no-bitwise
const color = Math.floor(Math.random() * 0xFFFFFF).toString(16);
const randomColor = '#' + ('000000' + color).slice(-6);
return randomColor;
},

1
app/server/tests/data/classification.jsonl

@ -1,3 +1,4 @@
{"text": "example", "labels": ["positive"], "meta": {"wikiPageID": 1}}
{"text": "example", "labels": ["positive", "negative"], "meta": {"wikiPageID": 2}}
{"text": "example", "labels": ["negative"], "meta": {"wikiPageID": 3}}
{"text": "example", "labels": ["neutral"], "meta": {"wikiPageID": 4}}

40
app/server/tests/test_api.py

@ -682,7 +682,9 @@ class TestUploader(APITestCase):
users=[super_user], project_type=SEQUENCE_LABELING)
cls.seq2seq_project = mommy.make('server.Seq2seqProject', users=[super_user], project_type=SEQ2SEQ)
cls.classification_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id])
cls.classification_labels_url = reverse(viewname='label_list', args=[cls.classification_project.id])
cls.labeling_url = reverse(viewname='doc_uploader', args=[cls.labeling_project.id])
cls.labeling_labels_url = reverse(viewname='label_list', args=[cls.labeling_project.id])
cls.seq2seq_url = reverse(viewname='doc_uploader', args=[cls.seq2seq_project.id])
def setUp(self):
@ -694,6 +696,20 @@ class TestUploader(APITestCase):
response = self.client.post(url, data={'file': f, 'format': format})
self.assertEqual(response.status_code, expected_status)
def label_test_helper(self, url, expected_labels, expected_label_keys):
expected_keys = {key for label in expected_labels for key in label}
response = self.client.get(url).json()
actual_labels = [{key: value for (key, value) in label.items() if key in expected_keys}
for label in response]
self.assertCountEqual(actual_labels, expected_labels)
for label in response:
for expected_label_key in expected_label_keys:
self.assertIsNotNone(label.get(expected_label_key))
def test_can_upload_conll_format_file(self):
self.upload_test_helper(url=self.labeling_url,
filename='labeling.conll',
@ -736,12 +752,36 @@ class TestUploader(APITestCase):
format='json',
expected_status=status.HTTP_201_CREATED)
self.label_test_helper(
url=self.classification_labels_url,
expected_labels=[
{'text': 'positive', 'suffix_key': 'p', 'prefix_key': None},
{'text': 'negative', 'suffix_key': 'n', 'prefix_key': None},
{'text': 'neutral', 'suffix_key': 'n', 'prefix_key': 'ctrl'},
],
expected_label_keys=[
'background_color',
'text_color',
])
def test_can_upload_labeling_jsonl(self):
self.upload_test_helper(url=self.labeling_url,
filename='labeling.jsonl',
format='json',
expected_status=status.HTTP_201_CREATED)
self.label_test_helper(
url=self.labeling_labels_url,
expected_labels=[
{'text': 'LOC', 'suffix_key': 'l', 'prefix_key': None},
{'text': 'ORG', 'suffix_key': 'o', 'prefix_key': None},
{'text': 'PER', 'suffix_key': 'p', 'prefix_key': None},
],
expected_label_keys=[
'background_color',
'text_color',
])
def test_can_upload_seq2seq_jsonl(self):
self.upload_test_helper(url=self.seq2seq_url,
filename='seq2seq.jsonl',

22
app/server/tests/test_utils.py

@ -0,0 +1,22 @@
from django.test import TestCase
from server.utils import Color
class TestColor(TestCase):
def test_random_color(self):
color = Color.random()
self.assertTrue(0 <= color.red <= 255)
self.assertTrue(0 <= color.green <= 255)
self.assertTrue(0 <= color.blue <= 255)
def test_hex(self):
color = Color(red=255, green=192, blue=203)
self.assertEqual(color.hex, '#ffc0cb')
def test_contrast_color(self):
color = Color(red=255, green=192, blue=203)
self.assertEqual(color.contrast_color.hex, '#000000')
color = Color(red=199, green=21, blue=133)
self.assertEqual(color.contrast_color.hex, '#ffffff')

107
app/server/utils.py

@ -4,6 +4,8 @@ import itertools
import json
import re
from collections import defaultdict
from math import floor
from random import Random
from django.db import transaction
from rest_framework.renderers import JSONRenderer
@ -74,16 +76,64 @@ class BaseStorage(object):
"""
return [label for label in labels if label not in created]
def to_serializer_format(self, labels):
"""Exclude created labels.
@classmethod
def to_serializer_format(cls, labels, created):
"""Convert a label to model dictionary.
Also assigns shortkeys for each label that don't clash with existing
label shortkeys.
Example:
>>> labels = ["positive"]
>>> self.to_serializer_format(labels)
[{"text": "negative"}]
```
>>> created = {}
>>> BaseStorage.to_serializer_format(labels, created)
[{"text": "positive", "suffix_key": "p", "prefix_key": None}]
"""
return [{'text': label} for label in labels]
existing_shortkeys = {(label.suffix_key, label.prefix_key)
for label in created.values()}
serializer_labels = []
for label in sorted(labels):
serializer_label = {'text': label}
shortkey = cls.get_shortkey(label, existing_shortkeys)
if shortkey:
serializer_label['suffix_key'] = shortkey[0]
serializer_label['prefix_key'] = shortkey[1]
existing_shortkeys.add(shortkey)
color = Color.random()
serializer_label['background_color'] = color.hex
serializer_label['text_color'] = color.contrast_color.hex
serializer_labels.append(serializer_label)
return serializer_labels
@classmethod
def get_shortkey(cls, label, existing_shortkeys):
"""Find the first non existing shortkey for the label.
Example without existing shortkey:
>>> BaseStorage.get_shortkey("positive", set())
("p", None)
Example with existing shortkey:
>>> BaseStorage.get_shortkey("positive", {("p", None)})
("p", "ctrl")
"""
model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
prefix_keys = [None] + model_prefix_keys
model_suffix_keys = {key for (key, _) in Label.SUFFIX_KEYS}
suffix_keys = [key for key in label.lower() if key in model_suffix_keys]
for shortkey in itertools.product(suffix_keys, prefix_keys):
if shortkey not in existing_shortkeys:
return shortkey
return None
def update_saved_labels(self, saved, new):
"""Update saved labels.
@ -120,7 +170,7 @@ class ClassificationStorage(BaseStorage):
labels = self.extract_label(data)
unique_labels = self.extract_unique_labels(labels)
unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
unique_labels = self.to_serializer_format(unique_labels)
unique_labels = self.to_serializer_format(unique_labels, saved_labels)
new_labels = self.save_label(unique_labels)
saved_labels = self.update_saved_labels(saved_labels, new_labels)
annotations = self.make_annotations(docs, labels, saved_labels)
@ -170,7 +220,7 @@ class SequenceLabelingStorage(BaseStorage):
labels = self.extract_label(data)
unique_labels = self.extract_unique_labels(labels)
unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
unique_labels = self.to_serializer_format(unique_labels)
unique_labels = self.to_serializer_format(unique_labels, saved_labels)
new_labels = self.save_label(unique_labels)
saved_labels = self.update_saved_labels(saved_labels, new_labels)
annotations = self.make_annotations(docs, labels, saved_labels)
@ -444,3 +494,44 @@ class CSVPainter(JSONPainter):
for a in annotations:
res.append({**d, **a})
return res
class Color:
def __init__(self, red, green, blue):
self.red = red
self.green = green
self.blue = blue
@property
def contrast_color(self):
"""Generate black or white color.
Ensure that text and background color combinations provide
sufficient contrast when viewed by someone having color deficits or
when viewed on a black and white screen.
Algorithm from w3c:
* https://www.w3.org/TR/AERT/#color-contrast
"""
return Color.white() if self.brightness < 128 else Color.black()
@property
def brightness(self):
return ((self.red * 299) + (self.green * 587) + (self.blue * 114)) / 1000
@property
def hex(self):
return '#{:02x}{:02x}{:02x}'.format(self.red, self.green, self.blue)
@classmethod
def white(cls):
return cls(red=255, green=255, blue=255)
@classmethod
def black(cls):
return cls(red=0, green=0, blue=0)
@classmethod
def random(cls, seed=None):
rgb = Random(seed).choices(range(256), k=3)
return cls(*rgb)
Loading…
Cancel
Save