Browse Source

Merge pull request #199 from CatalystCode/enhancement/auto-generate-label-shortkeys

Enhancement/Auto-generate label shortkeys and colors on corpus import
pull/202/head
Hiroki Nakayama 5 years ago
committed by GitHub
parent
commit
17b9d3b083
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 175 additions and 16 deletions
  1. 2
      app/server/models.py
  2. 17
      app/server/serializers.py
  3. 2
      app/server/static/js/label.vue
  4. 1
      app/server/tests/data/classification.jsonl
  5. 40
      app/server/tests/test_api.py
  6. 22
      app/server/tests/test_utils.py
  7. 107
      app/server/utils.py

2
app/server/models.py

@ -142,7 +142,7 @@ class Label(models.Model):
('shift', 'shift'), ('shift', 'shift'),
('ctrl shift', 'ctrl shift') ('ctrl shift', 'ctrl shift')
) )
SUFFIX_KEYS = (
SUFFIX_KEYS = tuple(
(c, c) for c in string.ascii_lowercase (c, c) for c in string.ascii_lowercase
) )

17
app/server/serializers.py

@ -34,12 +34,17 @@ class LabelSerializer(serializers.ModelSerializer):
raise ValidationError('Shortcut key may not have a suffix key.') raise ValidationError('Shortcut key may not have a suffix key.')
# Don't allow to save same shortcut key when prefix_key is null. # Don't allow to save same shortcut key when prefix_key is null.
context = self.context['request'].parser_context
project_id = context['kwargs'].get('project_id')
if Label.objects.filter(suffix_key=suffix_key,
prefix_key__isnull=True,
project=project_id).exists():
raise ValidationError('Duplicate key.')
try:
context = self.context['request'].parser_context
project_id = context['kwargs']['project_id']
except (AttributeError, KeyError):
pass # unit tests don't always have the correct context set up
else:
if Label.objects.filter(suffix_key=suffix_key,
prefix_key__isnull=True,
project=project_id).exists():
raise ValidationError('Duplicate key.')
return super().validate(attrs) return super().validate(attrs)
class Meta: class Meta:

2
app/server/static/js/label.vue

@ -203,7 +203,7 @@ export default {
methods: { methods: {
generateColor() { generateColor() {
const color = (Math.random() * 0xFFFFFF | 0).toString(16); // eslint-disable-line no-bitwise
const color = Math.floor(Math.random() * 0xFFFFFF).toString(16);
const randomColor = '#' + ('000000' + color).slice(-6); const randomColor = '#' + ('000000' + color).slice(-6);
return randomColor; return randomColor;
}, },

1
app/server/tests/data/classification.jsonl

@ -1,3 +1,4 @@
{"text": "example", "labels": ["positive"], "meta": {"wikiPageID": 1}} {"text": "example", "labels": ["positive"], "meta": {"wikiPageID": 1}}
{"text": "example", "labels": ["positive", "negative"], "meta": {"wikiPageID": 2}} {"text": "example", "labels": ["positive", "negative"], "meta": {"wikiPageID": 2}}
{"text": "example", "labels": ["negative"], "meta": {"wikiPageID": 3}} {"text": "example", "labels": ["negative"], "meta": {"wikiPageID": 3}}
{"text": "example", "labels": ["neutral"], "meta": {"wikiPageID": 4}}

40
app/server/tests/test_api.py

@ -682,7 +682,9 @@ class TestUploader(APITestCase):
users=[super_user], project_type=SEQUENCE_LABELING) users=[super_user], project_type=SEQUENCE_LABELING)
cls.seq2seq_project = mommy.make('server.Seq2seqProject', users=[super_user], project_type=SEQ2SEQ) cls.seq2seq_project = mommy.make('server.Seq2seqProject', users=[super_user], project_type=SEQ2SEQ)
cls.classification_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id]) cls.classification_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id])
cls.classification_labels_url = reverse(viewname='label_list', args=[cls.classification_project.id])
cls.labeling_url = reverse(viewname='doc_uploader', args=[cls.labeling_project.id]) cls.labeling_url = reverse(viewname='doc_uploader', args=[cls.labeling_project.id])
cls.labeling_labels_url = reverse(viewname='label_list', args=[cls.labeling_project.id])
cls.seq2seq_url = reverse(viewname='doc_uploader', args=[cls.seq2seq_project.id]) cls.seq2seq_url = reverse(viewname='doc_uploader', args=[cls.seq2seq_project.id])
def setUp(self): def setUp(self):
@ -694,6 +696,20 @@ class TestUploader(APITestCase):
response = self.client.post(url, data={'file': f, 'format': format}) response = self.client.post(url, data={'file': f, 'format': format})
self.assertEqual(response.status_code, expected_status) self.assertEqual(response.status_code, expected_status)
def label_test_helper(self, url, expected_labels, expected_label_keys):
expected_keys = {key for label in expected_labels for key in label}
response = self.client.get(url).json()
actual_labels = [{key: value for (key, value) in label.items() if key in expected_keys}
for label in response]
self.assertCountEqual(actual_labels, expected_labels)
for label in response:
for expected_label_key in expected_label_keys:
self.assertIsNotNone(label.get(expected_label_key))
def test_can_upload_conll_format_file(self): def test_can_upload_conll_format_file(self):
self.upload_test_helper(url=self.labeling_url, self.upload_test_helper(url=self.labeling_url,
filename='labeling.conll', filename='labeling.conll',
@ -736,12 +752,36 @@ class TestUploader(APITestCase):
format='json', format='json',
expected_status=status.HTTP_201_CREATED) expected_status=status.HTTP_201_CREATED)
self.label_test_helper(
url=self.classification_labels_url,
expected_labels=[
{'text': 'positive', 'suffix_key': 'p', 'prefix_key': None},
{'text': 'negative', 'suffix_key': 'n', 'prefix_key': None},
{'text': 'neutral', 'suffix_key': 'n', 'prefix_key': 'ctrl'},
],
expected_label_keys=[
'background_color',
'text_color',
])
def test_can_upload_labeling_jsonl(self): def test_can_upload_labeling_jsonl(self):
self.upload_test_helper(url=self.labeling_url, self.upload_test_helper(url=self.labeling_url,
filename='labeling.jsonl', filename='labeling.jsonl',
format='json', format='json',
expected_status=status.HTTP_201_CREATED) expected_status=status.HTTP_201_CREATED)
self.label_test_helper(
url=self.labeling_labels_url,
expected_labels=[
{'text': 'LOC', 'suffix_key': 'l', 'prefix_key': None},
{'text': 'ORG', 'suffix_key': 'o', 'prefix_key': None},
{'text': 'PER', 'suffix_key': 'p', 'prefix_key': None},
],
expected_label_keys=[
'background_color',
'text_color',
])
def test_can_upload_seq2seq_jsonl(self): def test_can_upload_seq2seq_jsonl(self):
self.upload_test_helper(url=self.seq2seq_url, self.upload_test_helper(url=self.seq2seq_url,
filename='seq2seq.jsonl', filename='seq2seq.jsonl',

22
app/server/tests/test_utils.py

@ -0,0 +1,22 @@
from django.test import TestCase
from server.utils import Color
class TestColor(TestCase):
def test_random_color(self):
color = Color.random()
self.assertTrue(0 <= color.red <= 255)
self.assertTrue(0 <= color.green <= 255)
self.assertTrue(0 <= color.blue <= 255)
def test_hex(self):
color = Color(red=255, green=192, blue=203)
self.assertEqual(color.hex, '#ffc0cb')
def test_contrast_color(self):
color = Color(red=255, green=192, blue=203)
self.assertEqual(color.contrast_color.hex, '#000000')
color = Color(red=199, green=21, blue=133)
self.assertEqual(color.contrast_color.hex, '#ffffff')

107
app/server/utils.py

@ -4,6 +4,8 @@ import itertools
import json import json
import re import re
from collections import defaultdict from collections import defaultdict
from math import floor
from random import Random
from django.db import transaction from django.db import transaction
from rest_framework.renderers import JSONRenderer from rest_framework.renderers import JSONRenderer
@ -74,16 +76,64 @@ class BaseStorage(object):
""" """
return [label for label in labels if label not in created] return [label for label in labels if label not in created]
def to_serializer_format(self, labels):
"""Exclude created labels.
@classmethod
def to_serializer_format(cls, labels, created):
"""Convert a label to model dictionary.
Also assigns shortkeys for each label that don't clash with existing
label shortkeys.
Example: Example:
>>> labels = ["positive"] >>> labels = ["positive"]
>>> self.to_serializer_format(labels)
[{"text": "negative"}]
```
>>> created = {}
>>> BaseStorage.to_serializer_format(labels, created)
[{"text": "positive", "suffix_key": "p", "prefix_key": None}]
""" """
return [{'text': label} for label in labels]
existing_shortkeys = {(label.suffix_key, label.prefix_key)
for label in created.values()}
serializer_labels = []
for label in sorted(labels):
serializer_label = {'text': label}
shortkey = cls.get_shortkey(label, existing_shortkeys)
if shortkey:
serializer_label['suffix_key'] = shortkey[0]
serializer_label['prefix_key'] = shortkey[1]
existing_shortkeys.add(shortkey)
color = Color.random()
serializer_label['background_color'] = color.hex
serializer_label['text_color'] = color.contrast_color.hex
serializer_labels.append(serializer_label)
return serializer_labels
@classmethod
def get_shortkey(cls, label, existing_shortkeys):
"""Find the first non existing shortkey for the label.
Example without existing shortkey:
>>> BaseStorage.get_shortkey("positive", set())
("p", None)
Example with existing shortkey:
>>> BaseStorage.get_shortkey("positive", {("p", None)})
("p", "ctrl")
"""
model_prefix_keys = [key for (key, _) in Label.PREFIX_KEYS]
prefix_keys = [None] + model_prefix_keys
model_suffix_keys = {key for (key, _) in Label.SUFFIX_KEYS}
suffix_keys = [key for key in label.lower() if key in model_suffix_keys]
for shortkey in itertools.product(suffix_keys, prefix_keys):
if shortkey not in existing_shortkeys:
return shortkey
return None
def update_saved_labels(self, saved, new): def update_saved_labels(self, saved, new):
"""Update saved labels. """Update saved labels.
@ -120,7 +170,7 @@ class ClassificationStorage(BaseStorage):
labels = self.extract_label(data) labels = self.extract_label(data)
unique_labels = self.extract_unique_labels(labels) unique_labels = self.extract_unique_labels(labels)
unique_labels = self.exclude_created_labels(unique_labels, saved_labels) unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
unique_labels = self.to_serializer_format(unique_labels)
unique_labels = self.to_serializer_format(unique_labels, saved_labels)
new_labels = self.save_label(unique_labels) new_labels = self.save_label(unique_labels)
saved_labels = self.update_saved_labels(saved_labels, new_labels) saved_labels = self.update_saved_labels(saved_labels, new_labels)
annotations = self.make_annotations(docs, labels, saved_labels) annotations = self.make_annotations(docs, labels, saved_labels)
@ -170,7 +220,7 @@ class SequenceLabelingStorage(BaseStorage):
labels = self.extract_label(data) labels = self.extract_label(data)
unique_labels = self.extract_unique_labels(labels) unique_labels = self.extract_unique_labels(labels)
unique_labels = self.exclude_created_labels(unique_labels, saved_labels) unique_labels = self.exclude_created_labels(unique_labels, saved_labels)
unique_labels = self.to_serializer_format(unique_labels)
unique_labels = self.to_serializer_format(unique_labels, saved_labels)
new_labels = self.save_label(unique_labels) new_labels = self.save_label(unique_labels)
saved_labels = self.update_saved_labels(saved_labels, new_labels) saved_labels = self.update_saved_labels(saved_labels, new_labels)
annotations = self.make_annotations(docs, labels, saved_labels) annotations = self.make_annotations(docs, labels, saved_labels)
@ -444,3 +494,44 @@ class CSVPainter(JSONPainter):
for a in annotations: for a in annotations:
res.append({**d, **a}) res.append({**d, **a})
return res return res
class Color:
def __init__(self, red, green, blue):
self.red = red
self.green = green
self.blue = blue
@property
def contrast_color(self):
"""Generate black or white color.
Ensure that text and background color combinations provide
sufficient contrast when viewed by someone having color deficits or
when viewed on a black and white screen.
Algorithm from w3c:
* https://www.w3.org/TR/AERT/#color-contrast
"""
return Color.white() if self.brightness < 128 else Color.black()
@property
def brightness(self):
return ((self.red * 299) + (self.green * 587) + (self.blue * 114)) / 1000
@property
def hex(self):
return '#{:02x}{:02x}{:02x}'.format(self.red, self.green, self.blue)
@classmethod
def white(cls):
return cls(red=255, green=255, blue=255)
@classmethod
def black(cls):
return cls(red=0, green=0, blue=0)
@classmethod
def random(cls, seed=None):
rgb = Random(seed).choices(range(256), k=3)
return cls(*rgb)
Loading…
Cancel
Save